import string
from typing import List, Optional, Tuple

from langchain.chains import LLMChain
from langchain.chains.base import Chain
from loguru import logger

from app.chroma import ChromaDenseVectorDB
from app.config.models.configs import (
    ResponseModel,
    Config, SemanticSearchConfig,
)
from app.ranking import BCEReranker, rerank
from app.splade import SpladeSparseVectorDB


class LLMBundle:
    """Bundle of a QA chain and the hybrid retrieval components that feed it.

    For every configured chunk size, documents are retrieved from a sparse
    (SPLADE) index and a dense (Chroma) index, merged, and re-ranked with
    ``reranker``. The candidate set from the chunk size with the highest
    reranker score is then passed to ``chain`` as context.
    """

    def __init__(
            self,
            chain: Chain,
            dense_db: ChromaDenseVectorDB,
            reranker: BCEReranker,
            sparse_db: SpladeSparseVectorDB,
            chunk_sizes: List[int],
            hyde_chain: Optional[LLMChain] = None
    ) -> None:
        """
        Args:
            chain: Chain invoked with ``input_documents`` and ``question``;
                its result must contain ``output_text`` and ``input_documents``.
            dense_db: Dense vector store used for similarity search and for
                fetching documents by id.
            reranker: Model passed to ``rerank`` to score candidate documents.
            sparse_db: Sparse vector store queried per chunk size.
            chunk_sizes: Chunk sizes to evaluate; each is searched separately.
            hyde_chain: Optional chain producing a hypothetical answer (HyDE)
                that is appended to the query for retrieval. When ``None``,
                HyDE expansion is skipped.
        """
        self.chain = chain
        self.dense_db = dense_db
        self.reranker = reranker
        self.sparse_db = sparse_db
        self.chunk_sizes = chunk_sizes
        self.hyde_chain = hyde_chain

    def _build_dense_filter(self, chunk_size: int, label: str) -> Optional[dict]:
        """Return the dense-search metadata filter, or ``None`` for no filtering."""
        # Chunk size only needs to be filtered on when several sizes are indexed.
        dense_filter = {"chunk_size": chunk_size} if len(self.chunk_sizes) > 1 else {}
        if label:
            dense_filter["label"] = label
        # An empty filter is passed as None (i.e. no filter at all).
        return dense_filter or None

    def get_relevant_documents(
            self,
            original_query: str,
            query: str,
            config: SemanticSearchConfig,
            label: str,
    ) -> Tuple[List, float]:
        """Retrieve, re-rank and budget the most relevant documents.

        Args:
            original_query: The user's question; used for re-ranking.
            query: The retrieval query (possibly HyDE-expanded); used for
                sparse and dense search.
            config: Supplies ``query_prefix``, ``max_k`` and ``max_char_size``.
            label: Optional label restricting the search; falsy means no
                label restriction on the dense side.

        Returns:
            ``(documents, score)``. ``documents`` come from the chunk size whose
            reranker score was highest, truncated to fit ``config.max_char_size``
            characters. ``score`` is that best reranker score, or ``-1e5`` if no
            chunk size scored higher than that sentinel.
        """
        # Apply the retrieval prefix exactly once. Doing this inside the
        # chunk-size loop would prepend it again on every iteration.
        if config.query_prefix:
            logger.info(f"Adding query prefix for retrieval: {config.query_prefix}")
            query = config.query_prefix + query

        best_docs = []
        best_score = -1e5

        for chunk_size in self.chunk_sizes:
            logger.debug("Evaluating query: {}", query)

            sparse_doc_ids, _sparse_scores = self.sparse_db.query(
                search=query, n=config.max_k, label=label, chunk_size=chunk_size
            )
            logger.info(f"Stage 1: Got {len(sparse_doc_ids)} documents.")

            dense_filter = self._build_dense_filter(chunk_size, label)
            logger.info(f"Dense embeddings filter: {dense_filter}")

            dense_results = self.dense_db.similarity_search_with_relevance_scores(
                query, filter=dense_filter
            )
            dense_doc_ids = [r[0].metadata["document_id"] for r in dense_results]

            # Union of both retrievers' hits, de-duplicated by document id.
            candidate_ids = set(sparse_doc_ids) | set(dense_doc_ids)
            candidates = (
                list(self.dense_db.get_documents_by_id(document_ids=list(candidate_ids)))
                if candidate_ids
                else []
            )

            # Re-rank against the user's original question, not the expanded query.
            reranker_score, reranked_docs = rerank(
                rerank_model=self.reranker,
                query=original_query,
                docs=candidates,
            )
            if reranker_score > best_score:
                best_docs = reranked_docs
                best_score = reranker_score

        # Greedily pack documents, in the order returned by `rerank`, while they
        # fit the character budget. An oversized document is skipped so smaller
        # later ones can still be included.
        selected = []
        total_chars = 0
        for doc in best_docs:
            doc_length = len(doc.page_content)
            if total_chars + doc_length < config.max_char_size:
                selected.append(doc)
                total_chars += doc_length

        return selected, best_score

    def get_and_parse_response(
            self,
            query: str,
            config: Config,
            label: str = "",
    ) -> ResponseModel:
        """Answer ``query`` with retrieved context and wrap it in a ResponseModel.

        Args:
            query: The user's question.
            config: Configuration; ``config.semantic_search`` drives retrieval.
            label: Optional label restricting retrieval ("" means unrestricted).

        Returns:
            A ``ResponseModel`` with the chain's answer, the retrieval query,
            the reranker score from ``get_relevant_documents``, the HyDE text
            (empty when no HyDE chain is configured), and the page contents of
            the documents the chain used.
        """
        original_query = query

        # HyDE: append a hypothetical answer to improve retrieval recall.
        # hyde_chain is optional, so skip the expansion when it is absent.
        hyde_response = ""
        if self.hyde_chain is not None:
            hyde_response = self.hyde_chain.run(query)
            # Separate with a space so the last question word and the first
            # HyDE word are not fused into a single token.
            query = f"{query} {hyde_response}"

        logger.info(f"query: {query}")

        most_relevant_docs, score = self.get_relevant_documents(
            original_query, query, config.semantic_search, label
        )

        res = self.chain(
            {"input_documents": most_relevant_docs, "question": original_query},
        )

        out = ResponseModel(
            response=res["output_text"],
            # NOTE(review): this is the retrieval (HyDE-expanded) query, kept as
            # before; confirm whether callers expect the original question here.
            question=query,
            average_score=score,
            hyde_response=hyde_response,
        )
        for doc in res["input_documents"]:
            out.semantic_search.append(doc.page_content)

        return out


class PartialFormatter(string.Formatter):
    """A ``string.Formatter`` that tolerates missing fields and bad format specs.

    Fields that cannot be resolved are rendered as ``missing`` instead of
    raising. Values whose format spec is rejected are rendered as ``bad_fmt``,
    or the error is re-raised when ``bad_fmt`` is ``None``.

    Note: a field whose value is genuinely ``None`` is also rendered as
    ``missing``, because ``None`` is the internal marker for "not found".

    >>> PartialFormatter().format("{a} {b}", a=1)
    '1 ~~'
    """

    def __init__(self, missing="~~", bad_fmt="!!"):
        """
        Args:
            missing: Text substituted for fields that cannot be resolved.
            bad_fmt: Text substituted when a value rejects its format spec;
                ``None`` re-raises the formatting error instead.
        """
        super().__init__()
        self.missing, self.bad_fmt = missing, bad_fmt

    def get_field(self, field_name, args, kwargs):
        """Resolve ``field_name``, returning ``(None, field_name)`` when missing.

        "Missing" covers:
        - an absent keyword or ``[key]`` (``KeyError``);
        - an absent positional or auto-numbered argument such as ``{}`` / ``{0}``,
          or an out-of-range ``[i]`` index (``IndexError``);
        - an absent ``.attr`` (``AttributeError``).
        """
        try:
            return super().get_field(field_name, args, kwargs)
        except (KeyError, IndexError, AttributeError):
            return None, field_name

    def format_field(self, value, spec):
        """Format ``value`` with ``spec``, substituting placeholders on failure."""
        if value is None:
            return self.missing
        try:
            return super().format_field(value, spec)
        except (ValueError, TypeError):
            # ValueError: spec is invalid for the value's type (e.g. "d" on str).
            # TypeError: the type rejects any non-empty spec (e.g. plain object).
            if self.bad_fmt is None:
                raise
            return self.bad_fmt