import os

import numpy as np
import torch

from pathlib import Path
from typing import Union

from huggingface_hub import hf_hub_download
from numpy.linalg import norm
from onnxruntime import InferenceSession
from tclogger import logger
from transformers import AutoTokenizer, AutoModel

from configs.envs import ENVS
from configs.constants import AVAILABLE_MODELS

# Propagate Hugging Face Hub settings from the project config into the process
# environment so Hub clients pick them up.
# NOTE(review): huggingface_hub is imported above, and it appears to resolve
# HF_ENDPOINT once at import time, so this assignment may not change its
# default endpoint — confirm, and pass `endpoint=` explicitly to Hub calls
# where the mirror matters.
if ENVS["HF_ENDPOINT"]:
    os.environ["HF_ENDPOINT"] = ENVS["HF_ENDPOINT"]
# Only export a token when one is configured: os.environ rejects None values,
# and an empty string would clobber a token already present in the environment.
if ENVS["HF_TOKEN"]:
    os.environ["HF_TOKEN"] = ENVS["HF_TOKEN"]


def cosine_similarity(a, b):
    """Return the cosine similarity of ``a`` and ``b``.

    The dot product ``a @ b.T`` is divided by the product of the two inputs'
    overall norms (``numpy.linalg.norm`` with its default order). For 1-D
    vectors or single-row ``(1, dim)`` matrices this is the usual cosine
    similarity. Multi-row matrices are normalized as a whole, not row by row.
    """
    dot_product = a @ b.T
    magnitude = norm(a) * norm(b)
    return dot_product / magnitude


class JinaAIOnnxEmbedder:
    """Embed text with the quantized ONNX export of jina-embeddings-v2-base-zh.

    The ONNX weights are downloaded once into a local ``.cache`` folder and run
    with onnxruntime. Per-token hidden states are mean-pooled (masked by the
    attention mask) into one vector per input text.

    Reference:
    https://huggingface.co/jinaai/jina-embeddings-v2-base-zh/discussions/6#65bc55a854ab5eb7b6300893
    """

    def __init__(self):
        self.repo_name = "jinaai/jina-embeddings-v2-base-zh"
        self.download_model()
        self.load_model()

    def download_model(self):
        """Make sure ``onnx/model_quantized.onnx`` exists in the local cache.

        The cache folder is ``.cache`` two directory levels above this file's
        folder. It is created if missing, and the model is fetched from the Hub
        only when the file is not already there.
        """
        self.onnx_folder = Path(__file__).parents[2] / ".cache"
        self.onnx_folder.mkdir(parents=True, exist_ok=True)
        self.onnx_filename = "onnx/model_quantized.onnx"
        self.onnx_path = self.onnx_folder / self.onnx_filename
        if not self.onnx_path.exists():
            logger.note("> Downloading ONNX model")
            hf_hub_download(
                repo_id=self.repo_name,
                filename=self.onnx_filename,
                local_dir=self.onnx_folder,
                local_dir_use_symlinks=False,
                # huggingface_hub resolves its default endpoint from HF_ENDPOINT
                # when it is first imported. A mirror exported into os.environ
                # afterwards would be ignored, so pass it explicitly
                # (None -> library default).
                endpoint=os.environ.get("HF_ENDPOINT") or None,
            )
            logger.success(f"+ ONNX model downloaded: {self.onnx_path}")
        else:
            logger.success(f"+ ONNX model loaded: {self.onnx_path}")

    def load_model(self):
        """Load the repo's tokenizer and open an onnxruntime session on the cached model."""
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.repo_name, trust_remote_code=True
        )
        # str(): older onnxruntime releases accept only str/bytes model paths and
        # raise TypeError for pathlib.Path objects.
        self.session = InferenceSession(str(self.onnx_path))

    def mean_pooling(self, model_output, attention_mask):
        """Average token embeddings over the positions where ``attention_mask`` is set.

        Args:
            model_output: token embeddings; the mask is broadcast to its size, so
                presumably ``(batch, seq, hidden)``.
            attention_mask: ``(batch, seq)`` mask of 0/1 values.

        Returns:
            A float tensor with the sequence axis (dim 1) reduced away.
        """
        token_embeddings = model_output
        input_mask_expanded = (
            attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        )
        # The clamp keeps an all-zero mask from dividing by zero.
        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
            input_mask_expanded.sum(1), min=1e-9
        )

    def encode(self, text: Union[str, list[str]]):
        """Embed a single text or a batch of texts.

        Args:
            text: one string, or a list of strings to embed together.

        Returns:
            A ``torch.Tensor`` with one pooled embedding row per input text. A
            single string yields a batch of 1.
        """
        # Pad only for batches: texts of different lengths must share one
        # array. Pad tokens are excluded from the mean by the attention mask.
        inputs = self.tokenizer(
            text, padding=not isinstance(text, str), return_tensors="np"
        )
        # The ONNX graph is fed int64 ids/masks.
        inputs = {
            name: np.array(tensor, dtype=np.int64) for name, tensor in inputs.items()
        }
        (last_hidden_state,) = self.session.run(
            output_names=["last_hidden_state"], input_feed=inputs
        )
        return self.mean_pooling(
            torch.from_numpy(last_hidden_state),
            torch.from_numpy(inputs["attention_mask"]),
        )


class JinaAIEmbedder:
    """Embed text with a model loaded through ``transformers.AutoModel``.

    Model names not listed in ``AVAILABLE_MODELS`` silently fall back to
    ``AVAILABLE_MODELS[0]``.
    """

    def __init__(self, model_name: str = AVAILABLE_MODELS[0]):
        self.model_name = model_name
        self.load_model()

    @staticmethod
    def _resolve_model_name(model_name: str) -> str:
        """Return ``model_name`` if it is supported, otherwise ``AVAILABLE_MODELS[0]``."""
        if model_name in AVAILABLE_MODELS:
            return model_name
        return AVAILABLE_MODELS[0]

    def check_model_name(self):
        """Replace an unsupported ``self.model_name`` with the default model.

        Returns:
            Always ``True``.
        """
        self.model_name = self._resolve_model_name(self.model_name)
        return True

    def load_model(self):
        """(Re)load ``self.model`` for the current, validated ``self.model_name``."""
        self.check_model_name()
        self.model = AutoModel.from_pretrained(self.model_name, trust_remote_code=True)

    def switch_model(self, model_name: str):
        """Load ``model_name`` unless it resolves to the model already loaded.

        The name is resolved (unsupported -> default) before the comparison.
        Otherwise an unknown name would always differ from the current one and
        force a pointless reload of the default model.
        """
        model_name = self._resolve_model_name(model_name)
        if model_name != self.model_name:
            self.model_name = model_name
            self.load_model()

    def encode(self, text: Union[str, list[str]]):
        """Embed one string or a list of strings.

        A single string is wrapped in a one-element list before being passed to
        the model's own ``encode``. Its return value is returned unchanged.
        """
        if isinstance(text, str):
            text = [text]
        return self.model.encode(text)


if __name__ == "__main__":
    # Run as: python -m transforms.embed
    # Replace with `JinaAIEmbedder()` to try the transformers AutoModel variant.
    embedder = JinaAIOnnxEmbedder()
    texts = ["How is the weather today?", "今天天气怎么样?"]
    # Embed an English sentence and its Chinese counterpart, then compare them.
    embeddings = [embedder.encode(text) for text in texts]
    logger.success(embeddings)
    print(cosine_similarity(embeddings[0], embeddings[1]))