import json
import os
import time
import secrets
import threading
from datetime import datetime, timedelta
from config import ACCESS_TOKEN_EXPIRE_HOURS


# Singleton to store tokens
class TokenStore:
    """Process-wide singleton that issues, validates and revokes access tokens.

    Every user holds at most one token. State lives in two dictionaries that
    are persisted together as JSON in ``tokens_file``:

    * ``tokens`` maps ``username -> {"token": str, "created_at": float}`` and
      is the source of truth.
    * ``token_to_user`` maps ``token -> username`` and is only a reverse
      index used for fast lookup.

    A token expires ``ACCESS_TOKEN_EXPIRE_HOURS`` hours after ``created_at``
    (seconds since the epoch, as returned by ``time.time()``).

    The public methods (``create_token``, ``validate_token``,
    ``remove_token``) serialize access through a single class-level lock.
    The underscore-prefixed helpers do not take the lock themselves and must
    be called with it held.
    """

    _instance = None
    # Guards both singleton creation and every mutation of the token maps.
    # It is not re-entrant: never call TokenStore() or a public method while
    # already holding it.
    _lock = threading.Lock()

    def __new__(cls):
        """Return the single shared instance, creating it on first use."""
        with cls._lock:
            if cls._instance is None:
                instance = super().__new__(cls)
                # Set attributes here so they exist even if __init__ is
                # bypassed (e.g. via TokenStore.__new__(TokenStore)).
                instance.tokens = {}  # username -> {token, created_at}
                instance.token_to_user = {}  # token -> username
                instance.tokens_file = "data/tokens.json"
                cls._instance = instance
            return cls._instance

    def __init__(self):
        """Fill in any missing attributes and load persisted tokens once.

        ``__init__`` runs on every ``TokenStore()`` call, so each step is
        guarded to keep repeated construction from resetting state.
        """
        # Re-declared in __init__ to help linters recognize these attributes
        if not hasattr(self, "tokens"):
            self.tokens = {}
        if not hasattr(self, "token_to_user"):
            self.token_to_user = {}
        if not hasattr(self, "tokens_file"):
            self.tokens_file = "data/tokens.json"

        # Loading replaces the shared maps, so do it under the same lock the
        # public methods use and only the first time.
        with self._lock:
            if not hasattr(self, "_loaded"):
                self._load_tokens()
                self._loaded = True

    def _load_tokens(self):
        """Load persisted tokens from ``tokens_file`` if it exists.

        Malformed user entries are skipped, and the reverse index is rebuilt
        from ``tokens`` rather than trusted from disk, so a stale or edited
        ``token_to_user`` section can never authenticate anyone. If the file
        cannot be read or parsed at all, an error is printed and the store
        starts empty.
        """
        directory = os.path.dirname(self.tokens_file)
        if directory:
            os.makedirs(directory, exist_ok=True)
        if not os.path.exists(self.tokens_file):
            return

        try:
            with open(self.tokens_file, "r", encoding="utf-8") as f:
                data = json.load(f)

            if not isinstance(data, dict):
                raise ValueError("token file must contain a JSON object")
            raw_tokens = data.get("tokens", {})
            if not isinstance(raw_tokens, dict):
                raise ValueError('"tokens" must be a JSON object')

            # Keep only well-formed entries; one bad record should not
            # invalidate every other user's session.
            loaded = {
                username: token_data
                for username, token_data in raw_tokens.items()
                if isinstance(token_data, dict)
                and isinstance(token_data.get("token"), str)
                and isinstance(token_data.get("created_at"), (int, float))
            }
            self.tokens = loaded
            self.token_to_user = {
                token_data["token"]: username
                for username, token_data in loaded.items()
            }

            # Clean expired tokens on load
            self._clean_expired_tokens()
        except Exception as e:
            # Deliberately best-effort: a broken token file must not prevent
            # the application from starting; users simply log in again.
            print(f"Error loading tokens: {e}")
            self.tokens = {}
            self.token_to_user = {}

    def _save_tokens(self):
        """Persist both maps to ``tokens_file`` (best-effort).

        The data is written to a sibling temporary file and then moved over
        the real file with ``os.replace``, so an interrupted write cannot
        leave a truncated JSON file behind. Failures are printed, not raised.
        """
        tmp_path = f"{self.tokens_file}.tmp"
        try:
            directory = os.path.dirname(self.tokens_file)
            if directory:
                # The directory may have been removed since load time.
                os.makedirs(directory, exist_ok=True)
            with open(tmp_path, "w", encoding="utf-8") as f:
                json.dump(
                    {"tokens": self.tokens, "token_to_user": self.token_to_user},
                    f,
                    indent=4,
                )
            os.replace(tmp_path, self.tokens_file)
        except Exception as e:
            print(f"Error saving tokens: {e}")
            # Do not leave a half-written temporary file lying around.
            try:
                os.remove(tmp_path)
            except OSError:
                pass

    def _clean_expired_tokens(self):
        """Drop every expired token from both maps and save if any were removed."""
        now = time.time()
        max_age = ACCESS_TOKEN_EXPIRE_HOURS * 3600

        # Collect first: the dict must not change size while being iterated.
        expired_usernames = [
            username
            for username, token_data in self.tokens.items()
            if now - token_data.get("created_at", 0) > max_age
        ]

        for username in expired_usernames:
            token = self.tokens.pop(username).get("token")
            self.token_to_user.pop(token, None)

        # Save changes if any tokens were removed
        if expired_usernames:
            self._save_tokens()

    def create_token(self, username):
        """Issue a fresh token for ``username``, replacing any existing one.

        Args:
            username: Key under which the token is stored.

        Returns:
            The new token, a 64-character random hex string.
        """
        with self._lock:
            # Clean expired tokens first
            self._clean_expired_tokens()

            # Invalidate the user's previous token, if any
            previous = self.tokens.get(username)
            if previous is not None:
                self.token_to_user.pop(previous.get("token"), None)

            token = secrets.token_hex(32)  # 64 character random hex string
            self.tokens[username] = {"token": token, "created_at": time.time()}
            self.token_to_user[token] = username

            self._save_tokens()
            return token

    def validate_token(self, token):
        """Return the username owning ``token``, or ``None`` if it is not valid.

        A token is valid only when it is the user's *current* token and has
        not expired.

        Args:
            token: Token string presented by the caller.

        Returns:
            The username, or ``None`` for unknown, superseded or expired
            tokens.
        """
        with self._lock:
            # Clean expired tokens first
            self._clean_expired_tokens()

            if token not in self.token_to_user:
                return None

            username = self.token_to_user[token]
            token_data = self.tokens.get(username)
            if token_data is None:
                return None

            # The reverse index alone is not proof: the user's record must
            # still hold exactly this token.
            if token_data.get("token") != token:
                return None

            age = time.time() - token_data.get("created_at", 0)
            if age <= ACCESS_TOKEN_EXPIRE_HOURS * 3600:
                return username

            # Token is expired
            return None

    def remove_token(self, token):
        """Revoke ``token``.

        Args:
            token: Token string to revoke.

        Returns:
            ``True`` if the token was known and has been removed, ``False``
            otherwise.
        """
        with self._lock:
            if token not in self.token_to_user:
                return False

            username = self.token_to_user.pop(token)

            # Only drop the user's record when it really holds this token, so
            # revoking a stale token cannot end the user's live session.
            token_data = self.tokens.get(username)
            if token_data is not None and token_data.get("token") == token:
                del self.tokens[username]

            self._save_tokens()
            return True


# Shared module-level instance. TokenStore() always returns the same object,
# and its first construction (here, at import time) loads any tokens already
# persisted in the store's tokens_file.
token_store = TokenStore()