import json
import os
import requests
from typing import Optional, List, Any
from pydantic import BaseModel, Field

class LlmPredictParams(BaseModel):
    """
    Параметры для предсказания LLM.
    """
    system_prompt: Optional[str] = Field(None, description="Системный промпт.")
    user_prompt: Optional[str] = Field(None, description="Шаблон промпта для передачи от роли user.")
    n_predict: Optional[int] = None
    temperature: Optional[float] = None
    top_k: Optional[int] = None
    top_p: Optional[float] = None
    min_p: Optional[float] = None
    seed: Optional[int] = None
    repeat_penalty: Optional[float] = None
    repeat_last_n: Optional[int] = None
    retry_if_text_not_present: Optional[str] = None
    retry_count: Optional[int] = None
    presence_penalty: Optional[float] = None
    frequency_penalty: Optional[float] = None
    n_keep: Optional[int] = None
    cache_prompt: Optional[bool] = None
    stop: Optional[List[str]] = None


class LlmParams(BaseModel):
    """
    Основные параметры для LLM.
    """
    url: str
    type: Optional[str] = None
    default: Optional[bool] = None
    template: Optional[str] = None
    predict_params: Optional[LlmPredictParams] = None
    
class LlmApi:
    """
    Класс для работы с API vllm.
    """
    
    params: LlmParams = None
    
    def __init__(self, params: LlmParams):
        self.params = params
        
        
    def get_models(self) -> list[str]:
        """
        Выполняет GET-запрос к API для получения списка доступных моделей.

        Возвращает:
            list[str]: Список идентификаторов моделей.
                       Если произошла ошибка или данные недоступны, возвращается пустой список.

        Исключения:
            Все ошибки HTTP-запросов логируются в консоль, но не выбрасываются дальше.
        """
        
        try:
            response = requests.get(f"{self.params.url}/v1/models", headers={"Content-Type": "application/json"})

            if response.status_code == 200:
                json_data = response.json()
                result = [item['id'] for item in json_data.get('data', [])]
                return result

        except requests.RequestException as error:
            print('OpenAiService.getModels error:')
            print(error)

        return []
    
    def create_messages(self, prompt: str) -> list[dict]:
        """
        Создает сообщения для LLM на основе переданного промпта и системного промпта (если он задан).

        Args:
            prompt (str): Пользовательский промпт.

        Returns:
            list[dict]: Список сообщений с ролями и содержимым.
        """
        actual_prompt = self.apply_llm_template_to_prompt(prompt)
        messages = []

        if self.params.predict_params and self.params.predict_params.system_prompt:
            messages.append({"role": "system", "content": self.params.predict_params.system_prompt})

        messages.append({"role": "user", "content": actual_prompt})
        return messages

    def apply_llm_template_to_prompt(self, prompt: str) -> str:
        """
        Применяет шаблон LLM к переданному промпту, если он задан.

        Args:
            prompt (str): Пользовательский промпт.

        Returns:
            str: Промпт с примененным шаблоном (или оригинальный, если шаблон отсутствует).
        """
        actual_prompt = prompt
        if self.params.template is not None:
            actual_prompt = self.params.template.replace("{{PROMPT}}", actual_prompt)
        return actual_prompt
    
    def tokenize(self, prompt: str) -> Optional[dict]:
        """
        Выполняет токенизацию переданного промпта.

        Args:
            prompt (str): Промпт для токенизации.

        Returns:
            Optional[dict]: Словарь с токенами и максимальной длиной модели, если запрос успешен.
                            Если запрос неуспешен, возвращает None.
        """
        model = self.get_models()[0] if self.get_models() else None
        if not model:
            print("No models available for tokenization.")
            return None

        actual_prompt = self.apply_llm_template_to_prompt(prompt)
        request_data = {
            "model": model,
            "prompt": actual_prompt,
            "add_special_tokens": False,
        }

        try:
            response = requests.post(
                f"{self.params.url}/tokenize",
                json=request_data,
                headers={"Content-Type": "application/json"},
            )

            if response.ok:
                data = response.json()
                if "tokens" in data:
                    return {"tokens": data["tokens"], "maxLength": data.get("max_model_len")}
            elif response.status_code == 404:
                print("Tokenization endpoint not found (404).")
            else:
                print(f"Failed to tokenize: {response.status_code}")
        except requests.RequestException as e:
            print(f"Request failed: {e}")

        return None

    def detokenize(self, tokens: List[int]) -> Optional[str]:
        """
        Выполняет детокенизацию переданных токенов.

        Args:
            tokens (List[int]): Список токенов для детокенизации.

        Returns:
            Optional[str]: Строка, полученная в результате детокенизации, если запрос успешен.
                           Если запрос неуспешен, возвращает None.
        """
        model = self.get_models()[0] if self.get_models() else None
        if not model:
            print("No models available for detokenization.")
            return None

        request_data = {"model": model, "tokens": tokens or []}

        try:
            response = requests.post(
                f"{self.params.url}/detokenize",
                json=request_data,
                headers={"Content-Type": "application/json"},
            )

            if response.ok:
                data = response.json()
                if "prompt" in data:
                    return data["prompt"].strip()
            elif response.status_code == 404:
                print("Detokenization endpoint not found (404).")
            else:
                print(f"Failed to detokenize: {response.status_code}")
        except requests.RequestException as e:
            print(f"Request failed: {e}")

        return None
    
    def create_request(self, prompt: str) -> dict:
        """
        Создает запрос для предсказания на основе параметров LLM.

        Args:
            prompt (str): Промпт для запроса.

        Returns:
            dict: Словарь с параметрами для выполнения запроса.
        """
        llm_params = self.params
        models = self.get_models()
        if not models:
            raise ValueError("No models available to create a request.")
        model = models[0]

        request = {
            "stream": True,
            "model": model,
        }

        predict_params = llm_params.predict_params

        if predict_params:
            if predict_params.stop:
                # Фильтруем пустые строки в stop
                non_empty_stop = list(filter(lambda o: o != "", predict_params.stop))
                if non_empty_stop:
                    request["stop"] = non_empty_stop

            if predict_params.n_predict is not None:
                request["max_tokens"] = int(predict_params.n_predict or 0)

            request["temperature"] = float(predict_params.temperature or 0)

            if predict_params.top_k is not None:
                request["top_k"] = int(predict_params.top_k)

            if predict_params.top_p is not None:
                request["top_p"] = float(predict_params.top_p)

            if predict_params.min_p is not None:
                request["min_p"] = float(predict_params.min_p)

            if predict_params.seed is not None:
                request["seed"] = int(predict_params.seed)

            if predict_params.n_keep is not None:
                request["n_keep"] = int(predict_params.n_keep)

            if predict_params.cache_prompt is not None:
                request["cache_prompt"] = bool(predict_params.cache_prompt)

            if predict_params.repeat_penalty is not None:
                request["repetition_penalty"] = float(predict_params.repeat_penalty)

            if predict_params.repeat_last_n is not None:
                request["repeat_last_n"] = int(predict_params.repeat_last_n)

            if predict_params.presence_penalty is not None:
                request["presence_penalty"] = float(predict_params.presence_penalty)

            if predict_params.frequency_penalty is not None:
                request["frequency_penalty"] = float(predict_params.frequency_penalty)

        # Генерируем сообщения
        request["messages"] = self.create_messages(prompt)

        return request

    
    def trim_sources(self, sources: str, user_request: str, system_prompt: str = None) -> dict:
        """
        Обрезает текст источников, чтобы уложиться в допустимое количество токенов.

        Args:
            sources (str): Текст источников.
            user_request (str): Запрос пользователя с примененным шаблоном без текста источников.
            system_prompt (str): Системный промпт, если нужен.

        Returns:
            dict: Словарь с результатом, количеством токенов до и после обрезки.
        """
        # Токенизация текста источников
        sources_tokens_data = self.tokenize(sources)
        if sources_tokens_data is None:
            raise ValueError("Failed to tokenize sources.")
        max_token_count = sources_tokens_data.get("maxLength", 0)

        # Токены системного промпта
        system_prompt_token_count = 0
        
        if system_prompt is not None:
            system_prompt_tokens = self.tokenize(system_prompt)
            system_prompt_token_count = len(system_prompt_tokens["tokens"]) if system_prompt_tokens else 0

        # Оригинальное количество токенов
        original_token_count = len(sources_tokens_data["tokens"])

        # Токенизация пользовательского промпта
        aux_prompt = self.apply_llm_template_to_prompt(user_request)
        aux_tokens_data = self.tokenize(aux_prompt)

        aux_token_count = len(aux_tokens_data["tokens"]) if aux_tokens_data else 0

        # Максимально допустимое количество токенов для источников
        max_length = (
            max_token_count
            - (self.params.predict_params.n_predict or 0)
            - aux_token_count
            - system_prompt_token_count
        )
        max_length = max(max_length, 0)

        # Обрезка токенов источников
        if "tokens" in sources_tokens_data:
            sources_tokens_data["tokens"] = sources_tokens_data["tokens"][:max_length]
            detokenized_prompt = self.detokenize(sources_tokens_data["tokens"])
            if detokenized_prompt is not None:
                sources = detokenized_prompt
            else:
                sources = sources[:max_length]
        else:
            sources = sources[:max_length]

        # Возврат результата
        return {
            "result": sources,
            "originalTokenCount": original_token_count,
            "slicedTokenCount": len(sources_tokens_data["tokens"]),
        }

    def predict(self, prompt: str) -> str:
        """
        Выполняет SSE-запрос к API и возвращает собранный результат как текст.

        Args:
            prompt (str): Входной текст для предсказания.

        Returns:
            str: Сгенерированный текст.
        
        Raises:
            Exception: Если запрос завершился ошибкой.
        """
        
        # Создание запроса
        request = self.create_request(prompt)

        print(f"Predict request. Url: {self.params.url}")

        response = requests.post(
            f"{self.params.url}/v1/chat/completions",
            headers={"Content-Type": "application/json"},
            json=request,
            stream=True  # Для обработки SSE
        )

        if not response.ok:
            raise Exception(f"Failed to generate text: {response.text}")

        # Обработка SSE-ответа
        generated_text = ""
        for line in response.iter_lines(decode_unicode=True):
            if line.startswith("data: "):
                try:
                    data = json.loads(line[len("data: "):].strip())

                    # Проверка завершения генерации
                    if data == "[DONE]":
                        break

                    # Получение текста из ответа
                    if "choices" in data and data["choices"]:
                        token_value = data["choices"][0].get("delta", {}).get("content", "")
                        generated_text += token_value.replace("</s>", "")

                except json.JSONDecodeError:
                    continue  # Игнорирование строк, которые не удалось декодировать

        return generated_text