import pandas as pd
import hashlib
import requests
from typing import List, Optional
from datetime import datetime
from langchain.schema.embeddings import Embeddings
from streamlit.runtime.uploaded_file_manager import UploadedFile
from clickhouse_connect import get_client
from multiprocessing.pool import ThreadPool
from langchain.vectorstores.myscale import MyScaleWithoutJSON, MyScaleSettings
from .helper import create_retriever_tool

# HTTP endpoint that parse_files posts each uploaded file to; the request is
# authenticated with the "unstructured-api-key" header.
parser_url = "https://api.unstructured.io/general/v0/general"


def parse_files(api_key, user_id, files: List[UploadedFile]):
    """Send uploaded files to the parser service and collect paragraph rows.

    Each file is posted to ``parser_url`` from a pool of 8 worker threads.
    Only elements whose ``type`` is ``"NarrativeText"`` and whose text splits
    into more than 10 space-separated pieces are kept.

    Args:
        api_key: Value sent in the ``unstructured-api-key`` request header.
        user_id: Stored verbatim in every returned row under ``user_id``.
        files: Uploaded files to parse.

    Returns:
        A flat list of dicts with the keys ``text``, ``file_name``,
        ``entity_id``, ``user_id`` and ``created_by``. Rows of different
        files appear in completion order, not in input order.

    Raises:
        ValueError: If the parser answers with a status code other than 200.
    """

    def parse_file(file: UploadedFile):
        """Parse one upload and return its list of paragraph rows."""
        headers = {
            "accept": "application/json",
            "unstructured-api-key": api_key,
        }
        data = {"strategy": "auto", "ocr_languages": ["eng"]}
        # Read the buffer once via getvalue(): unlike read(), it does not
        # depend on the current stream position, so the hash always covers
        # the same bytes that are uploaded even if the file was read before.
        content = file.getvalue()
        file_hash = hashlib.sha256(content).hexdigest()
        file_data = {"files": (file.name, content, file.type)}
        # NOTE(review): no timeout is passed, so a stalled parser keeps its
        # worker thread blocked; consider adding one.
        response = requests.post(
            parser_url, headers=headers, data=data, files=file_data
        )
        if response.status_code != 200:
            # Error bodies are not guaranteed to be JSON; fall back to the
            # raw text so the real failure is reported.
            try:
                detail = str(response.json())
            except ValueError:
                detail = response.text
            raise ValueError(detail)
        json_response = response.json()
        return [
            {
                "text": t["text"],
                "file_name": t["metadata"]["filename"],
                # Content-derived id: same file bytes + same paragraph text
                # always yield the same entity_id.
                "entity_id": hashlib.sha256(
                    (file_hash + t["text"]).encode()
                ).hexdigest(),
                "user_id": user_id,
                "created_by": datetime.now(),
            }
            for t in json_response
            if t["type"] == "NarrativeText" and len(t["text"].split(" ")) > 10
        ]

    with ThreadPool(8) as p:
        return [
            row
            for file_rows in p.imap_unordered(parse_file, files)
            for row in file_rows
        ]


def extract_embedding(embeddings: Embeddings, texts):
    """Embed the ``text`` of every row and store the result under ``vector``.

    The rows are updated in place and the same ``texts`` object is returned.

    Args:
        embeddings: Object whose ``embed_documents`` is called once with the
            list of all row texts.
        texts: Sequence of dicts, each with a ``text`` key.

    Returns:
        ``texts``, with a ``vector`` entry added to every row.

    Raises:
        ValueError: If ``texts`` is empty.
    """
    if len(texts) == 0:
        raise ValueError("No texts extracted!")

    vectors = embeddings.embed_documents([row["text"] for row in texts])
    for position, row in enumerate(texts):
        row["vector"] = vectors[position]
    return texts


class PrivateKnowledgeBase:
    """Per-user document store and retrieval-tool registry.

    Two tables are used: ``kb_table`` holds parsed paragraphs together with
    their embedding vectors, and ``tool_table`` maps a named tool to the file
    names it may search. Every query issued by this class is filtered by
    ``user_id``.
    """

    def __init__(
        self,
        host,
        port,
        username,
        password,
        embedding: Embeddings,
        parser_api_key,
        db="chat",
        kb_table="private_kb",
        tool_table="private_tool",
    ) -> None:
        """Create both tables if missing and set up the vector store.

        Args:
            host: Database host.
            port: Database port.
            username: Database user.
            password: Database password.
            embedding: Embedding model handed to the vector store.
            parser_api_key: API key later passed to ``parse_files``.
            db: Database name. Interpolated into SQL unescaped, so it must
                come from trusted configuration.
            kb_table: Name of the paragraph table (same caveat as ``db``).
            tool_table: Name of the tool table (same caveat as ``db``).
        """
        super().__init__()
        # The CHECK constraint fixes the vector length at 768, so the
        # embedding model has to produce vectors of exactly that size.
        kb_schema_ = f"""
            CREATE TABLE IF NOT EXISTS {db}.{kb_table}(
                entity_id String,
                file_name String,
                text String,
                user_id String,
                created_by DateTime,
                vector Array(Float32),
                CONSTRAINT cons_vec_len CHECK length(vector) = 768,
                VECTOR INDEX vidx vector TYPE MSTG('metric_type=Cosine')
            ) ENGINE = ReplacingMergeTree ORDER BY entity_id
        """
        tool_schema_ = f"""
            CREATE TABLE IF NOT EXISTS {db}.{tool_table}(
                tool_id String,
                tool_name String,
                file_names Array(String),
                user_id String,
                created_by DateTime,
                tool_description String
            ) ENGINE = ReplacingMergeTree ORDER BY tool_id
        """
        self.kb_table = kb_table
        self.tool_table = tool_table
        config = MyScaleSettings(
            host=host,
            port=port,
            username=username,
            password=password,
            database=db,
            table=kb_table,
        )
        client = get_client(
            host=config.host,
            port=config.port,
            username=config.username,
            password=config.password,
        )
        # NOTE(review): neither schema visibly uses an object-typed column;
        # presumably this setting is required by the server - confirm.
        client.command("SET allow_experimental_object_type=1")
        # Tables are created before the vector store is constructed,
        # presumably so it finds kb_table with this schema - confirm.
        client.command(kb_schema_)
        client.command(tool_schema_)
        self.parser_api_key = parser_api_key
        self.vstore = MyScaleWithoutJSON(
            embedding=embedding,
            config=config,
            must_have_cols=["file_name", "text", "created_by"],
        )

    @staticmethod
    def _quote(value) -> str:
        """Render ``value`` as a single-quoted SQL string literal.

        Backslashes and single quotes are backslash-escaped so that the value
        cannot end the literal early. Values without those characters render
        exactly as a plain quoted string.
        """
        escaped = str(value).replace("\\", "\\\\").replace("'", "\\'")
        return f"'{escaped}'"

    def _tool_filter(self, user_id, tool_name) -> str:
        """Build the WHERE fragment limiting a search to one tool's files."""
        uid = self._quote(user_id)
        return (
            f"user_id={uid} "
            f"""AND file_name IN (
                                SELECT arrayJoin(file_names) FROM (
                                    SELECT file_names
                                    FROM {self.vstore.config.database}.{self.tool_table}
                                    WHERE user_id = {uid} AND tool_name = {self._quote(tool_name)})
                        )"""
        )

    def list_files(self, user_id, tool_name=None):
        """Return one row per file owned by ``user_id``.

        Each row has ``file_name``, ``num_paragraph`` (number of stored
        paragraphs) and ``max_chars`` (length of the longest paragraph).
        ``tool_name`` is accepted but not used by the query.
        """
        query = f"""
        SELECT DISTINCT file_name, COUNT(entity_id) AS num_paragraph,
            arrayMax(arrayMap(x->length(x), groupArray(text))) AS max_chars
        FROM {self.vstore.config.database}.{self.kb_table}
        WHERE user_id = {self._quote(user_id)} GROUP BY file_name
        """
        return list(self.vstore.client.query(query).named_results())

    def add_by_file(
        self, user_id, files: List[UploadedFile], **kwargs
    ):
        """Parse ``files``, embed their paragraphs and insert them.

        Extra keyword arguments are accepted and ignored. Raises
        ``ValueError`` (from ``extract_embedding``) when parsing yields no
        paragraphs.
        """
        data = parse_files(self.parser_api_key, user_id, files)
        data = extract_embedding(self.vstore.embeddings, data)
        self.vstore.client.insert_df(
            self.kb_table,
            pd.DataFrame(data),
            database=self.vstore.config.database,
        )

    def clear(self, user_id):
        """Delete all paragraphs and all tools that belong to ``user_id``."""
        uid = self._quote(user_id)
        self.vstore.client.command(
            f"DELETE FROM {self.vstore.config.database}.{self.kb_table} "
            f"WHERE user_id={uid}"
        )
        query = f"""DELETE FROM {self.vstore.config.database}.{self.tool_table}
                    WHERE user_id  = {uid}"""
        self.vstore.client.command(query)

    def create_tool(
        self, user_id, tool_name, tool_description, files: Optional[List[str]] = None
    ):
        """Register a tool for ``user_id`` covering the given file names.

        ``tool_id`` is the SHA-256 of ``user_id + tool_name``, so the same
        user and name always map to the same id. Omitting ``files`` stores an
        empty file list.
        """
        # file_names is an Array(String) column; None is not a valid array
        # value, so an omitted argument becomes an empty list.
        file_names = list(files) if files is not None else []
        self.vstore.client.insert_df(
            self.tool_table,
            pd.DataFrame(
                [
                    {
                        "tool_id": hashlib.sha256(
                            (user_id + tool_name).encode("utf-8")
                        ).hexdigest(),
                        "tool_name": tool_name,
                        "file_names": file_names,
                        "user_id": user_id,
                        "created_by": datetime.now(),
                        "tool_description": tool_description,
                    }
                ]
            ),
            database=self.vstore.config.database,
        )

    def list_tools(self, user_id, tool_name=None):
        """Return the tools of ``user_id``, optionally only ``tool_name``.

        Each row carries ``tool_name``, ``tool_description`` and the number
        of files attached to the tool.
        """
        extended_where = (
            f"AND tool_name = {self._quote(tool_name)}" if tool_name else ""
        )
        query = f"""
        SELECT tool_name, tool_description, length(file_names)
        FROM {self.vstore.config.database}.{self.tool_table}
        WHERE user_id = {self._quote(user_id)} {extended_where}
        """
        return list(self.vstore.client.query(query).named_results())

    def remove_tools(self, user_id, tool_names):
        """Delete the named tools of ``user_id``; no-op for an empty list."""
        quoted = [self._quote(t) for t in tool_names]
        if not quoted:
            # Nothing to delete; avoids sending a statement with "IN []".
            return
        joined = ",".join(quoted)
        query = f"""DELETE FROM {self.vstore.config.database}.{self.tool_table}
                    WHERE user_id  = {self._quote(user_id)} AND tool_name IN [{joined}]"""
        self.vstore.client.command(query)

    def as_tools(self, user_id, tool_name=None):
        """Build one retriever tool per registered tool of ``user_id``.

        Returns:
            A dict mapping tool name to the object returned by
            ``create_retriever_tool``. Each retriever only searches rows of
            ``user_id`` whose file name is attached to that tool.
        """
        tools = self.list_tools(user_id=user_id, tool_name=tool_name)
        retrievers = {
            t["tool_name"]: create_retriever_tool(
                self.vstore.as_retriever(
                    search_kwargs={
                        "where_str": self._tool_filter(user_id, t["tool_name"])
                    },
                ),
                name=t["tool_name"],
                description=t["tool_description"],
            )
            for t in tools
        }
        return retrievers