# main.py
import os
import streamlit as st
import anthropic
from requests import JSONDecodeError
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
from langchain_community.vectorstores import SupabaseVectorStore
from langchain_community.llms import HuggingFaceEndpoint
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
from supabase import Client, create_client
from streamlit.logger import get_logger

from stats import get_usage, add_usage

# ─────── supabase + secrets ────────────────────────────────────────────────────
supabase_url = st.secrets.SUPABASE_URL
supabase_key = st.secrets.SUPABASE_KEY
openai_api_key = st.secrets.openai_api_key
anthropic_api_key = st.secrets.anthropic_api_key
hf_api_key = st.secrets.hf_api_key
username = st.secrets.username

supabase: Client = create_client(supabase_url, supabase_key)
logger = get_logger(__name__)

# ─────── embeddings ─────────────────────────────────────────────────────────────
# Switch to local BGE embeddings (no JSONDecode errors, no HTTP-batch issues)
embeddings = HuggingFaceBgeEmbeddings(
    model_name="BAAI/bge-large-en-v1.5",
    model_kwargs={"device": "cpu"},
    encode_kwargs={"normalize_embeddings": True},
)

# ─────── vector store + memory ─────────────────────────────────────────────────
vector_store = SupabaseVectorStore(
    client=supabase,
    embedding=embeddings,
    query_name="match_documents",
    table_name="documents",
)

memory = ConversationBufferMemory(
    memory_key="chat_history",
    input_key="question",
    output_key="answer",
    return_messages=True,
)

# ─────── LLM setup ──────────────────────────────────────────────────────────────
model = "mistralai/Mixtral-8x7B-Instruct-v0.1"
temperature = 0.1
max_tokens = 500


def response_generator(query: str) -> str:
    """Ask the RAG chain to answer `query`, with JSON-error fallback."""
    # log usage
    add_usage(supabase, "chat", "prompt:" + query, {"model": model, "temperature": temperature})
    logger.info("Using HF model %s", model)

    # prepare HF text-generation LLM
    hf = HuggingFaceEndpoint(
        # endpoint_url=f"https://api-inference.huggingface.co/models/{model}",
        endpoint_url=f"https://router.huggingface.co/hf-inference/models/{model}",
        task="text-generation",
        huggingfacehub_api_token=hf_api_key,
        model_kwargs={
            "temperature": temperature,
            "max_new_tokens": max_tokens,
            "return_full_text": False,
        },
    )

    # conversational RAG chain
    qa = ConversationalRetrievalChain.from_llm(
        llm=hf,
        retriever=vector_store.as_retriever(
            search_kwargs={"score_threshold": 0.6, "k": 4, "filter": {"user": username}}
        ),
        memory=memory,
        verbose=True,
        return_source_documents=True,
    )

    try:
        result = qa({"question": query})
    except JSONDecodeError as e:
        # fallback logging
        logger.error("Embedding JSONDecodeError: %s", e)
        return "Sorry, I had trouble understanding the embedded data. Please try again."

    answer = result.get("answer", "")
    sources = result.get("source_documents", [])

    if not sources:
        return (
            "I’m sorry, I don’t have enough information to answer that. "
            "If you have a public data source to add, please email copilot@securade.ai."
        )
    return answer


# ─────── Streamlit UI ───────────────────────────────────────────────────────────
st.set_page_config(
    page_title="Securade.ai - Safety Copilot",
    page_icon="https://securade.ai/favicon.ico",
    layout="centered",
    initial_sidebar_state="collapsed",
    menu_items={
        "About": "# Securade.ai Safety Copilot v0.1\n[https://securade.ai](https://securade.ai)",
        "Get Help": "https://securade.ai",
        "Report a Bug": "mailto:hello@securade.ai",
    },
)

st.title("👷‍♂️ Safety Copilot 🦺")
stats = get_usage(supabase)
st.markdown(f"_{stats} queries answered!_")
st.markdown(
    "Chat with your personal safety assistant about any health & safety related queries. "
    "[[blog](https://securade.ai/blog/how-securade-ai-safety-copilot-transforms-worker-safety.html)"
    "|[paper](https://securade.ai/assets/pdfs/Securade.ai-Safety-Copilot-Whitepaper.pdf)]"
)

if "chat_history" not in st.session_state:
    st.session_state.chat_history = []

# show history
for msg in st.session_state.chat_history:
    with st.chat_message(msg["role"]):
        st.markdown(msg["content"])

# new user input
if prompt := st.chat_input("Ask a question"):
    st.session_state.chat_history.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)

    with st.spinner("Safety briefing in progress..."):
        answer = response_generator(prompt)

    with st.chat_message("assistant"):
        st.markdown(answer)
    st.session_state.chat_history.append({"role": "assistant", "content": answer})
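
# ─────── appendix: pieces referenced above but not included in this file ───────
# The app reads six values from Streamlit secrets. A minimal .streamlit/secrets.toml
# sketch; the keys come from the code above, the values are placeholders only:
#
#   SUPABASE_URL = "https://<project>.supabase.co"
#   SUPABASE_KEY = "<supabase-key>"
#   openai_api_key = "<openai-key>"
#   anthropic_api_key = "<anthropic-key>"
#   hf_api_key = "<huggingface-token>"
#   username = "<value used to filter retrieved documents>"
#
# main.py also imports get_usage/add_usage from a local stats module that is not
# shown. Below is a minimal sketch of what that module could look like, assuming
# usage events are stored in a Supabase table; the table name "stats" and its
# columns ("action", "details", "metadata") are assumptions, not confirmed by the
# code above — only the call signatures match how main.py uses these helpers.

# stats.py (hypothetical sketch)
from supabase import Client


def add_usage(supabase: Client, action: str, details: str, metadata: dict) -> None:
    # Record a single usage event; assumes a "stats" table with matching columns.
    supabase.table("stats").insert(
        {"action": action, "details": details, "metadata": metadata}
    ).execute()


def get_usage(supabase: Client) -> int:
    # Return the total number of recorded events; count="exact" asks PostgREST
    # for an exact row count without transferring the rows themselves.
    response = supabase.table("stats").select("id", count="exact").execute()
    return response.count or 0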