import ast
import json
import re
import uuid
from typing import Annotated, Optional

from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from langchain_core.messages import (
    BaseMessage,
    HumanMessage,
    trim_messages,
)
from langchain_core.runnables import RunnableConfig
from langchain_core.tools import tool
from langchain_openai import ChatOpenAI
from langgraph.checkpoint.memory import MemorySaver
from langgraph.prebuilt import InjectedState, create_react_agent
from pydantic import BaseModel
from sse_starlette.sse import EventSourceResponse

from document_rag_router import QueryInput, SearchResult, query_collection
from document_rag_router import router as document_rag_router

app = FastAPI()
app.include_router(document_rag_router) 

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

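# Example tool with hard-coded answers; it is not registered with the agent below.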
@tool
def get_user_age(name: str) -> str:
    """Use this tool to find the user's age."""
    if "bob" in name.lower():
        return "42 years old"
    return "41 years old"

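# Retrieval tool exposed to the agent: reads collection_id / user_id from the
# per-request RunnableConfig and searches the user's document collection.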
@tool
async def query_documents(
    query: str,
    config: RunnableConfig,
    #state: Annotated[dict, InjectedState]
) -> str:
    """Use this tool to retrieve relevant data from the collection.
    
    Args:
        query: The search query to find relevant document passages
    """
    # Get collection_id and user_id from config
    thread_config = config.get("configurable", {})
    collection_id = thread_config.get("collection_id")
    user_id = thread_config.get("user_id")
    
    if not collection_id or not user_id:
        return "Error: collection_id and user_id are required in the config"
    try:    
        # Create query input
        input_data = QueryInput(
            collection_id=collection_id,
            query=query,
            user_id=user_id,
            top_k=6
        )
        
        response = await query_collection(input_data)
        results = []
        
        # Access response directly since it's a Pydantic model
        for r in response.results:
            result_dict = {
                "text": r.text,
                "distance": r.distance,
                "metadata": {
                    "document_id": r.metadata.get("document_id"),
                    "chunk_index": r.metadata.get("location", {}).get("chunk_index")
                }
            }
            results.append(result_dict)
        
        return str(results)
    
    except Exception as e:
        print(e)
        return f"Error querying documents: {e} PAUSE AND ASK USER FOR HELP"


async def query_documents_raw(
    query: str,
    config: RunnableConfig,
) -> list[SearchResult] | str:
    """Re-run a retrieval query and return the raw, structured results.

    This is a plain helper used by /chat2 (it is not registered as a tool);
    it returns an error string if the lookup fails.

    Args:
        query: The search query to find relevant document passages
    """
    # Get collection_id and user_id from config
    thread_config = config.get("configurable", {})
    collection_id = thread_config.get("collection_id")
    user_id = thread_config.get("user_id")
    
    if not collection_id or not user_id:
        return "Error: collection_id and user_id are required in the config"
    try:    
        # Create query input
        input_data = QueryInput(
            collection_id=collection_id,
            query=query,
            user_id=user_id,
            top_k=6
        )
        
        response = await query_collection(input_data)
        return response.results
    
    except Exception as e:
        print(e)
        return f"Error querying documents: {e} PAUSE AND ASK USER FOR HELP"

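# In-memory checkpointer: conversation state is kept per thread_id and is lost
# when the process restarts.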
memory = MemorySaver()
model = ChatOpenAI(model="gpt-4o-mini", streaming=True)

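# Trim the message history before each model call. Note: with token_counter=len
# each message counts as 1, so max_tokens=16000 is effectively a cap on the
# number of messages kept, not on tokens.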
def state_modifier(state) -> list[BaseMessage]:
    return trim_messages(
        state["messages"],
        token_counter=len,
        max_tokens=16000,
        strategy="last",
        start_on="human",
        include_system=True,
        allow_partial=False,
    )

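# Prebuilt ReAct agent: the model decides when to call query_documents, and the
# checkpointer persists the conversation per thread_id.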
agent = create_react_agent(
    model,
    tools=[query_documents],
    checkpointer=memory,
    state_modifier=state_modifier,
)

class ChatInput(BaseModel):
    message: str
    thread_id: Optional[str] = None
    collection_id: Optional[str] = None
    user_id: Optional[str] = None

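# Streams an agent run as Server-Sent Events: 'token' events carry model output
# chunks; 'tool_start' / 'tool_end' events bracket tool calls.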
@app.post("/chat")
async def chat(input_data: ChatInput):
    thread_id = input_data.thread_id or str(uuid.uuid4())
    
    config = {
        "configurable": {
            "thread_id": thread_id,
            "collection_id": input_data.collection_id,
            "user_id": input_data.user_id
        }
    }
    
    input_message = HumanMessage(content=input_data.message)
    
    async def generate():
        async for event in agent.astream_events(
            {"messages": [input_message]}, 
            config,
            version="v2"
        ):
            kind = event["event"]
            
            if kind == "on_chat_model_stream":
                content = event["data"]["chunk"].content
                if content:
                    yield f"{json.dumps({'type': 'token', 'content': content})}"

            elif kind == "on_tool_start":
                tool_input = str(event['data'].get('input', ''))
                yield f"{json.dumps({'type': 'tool_start', 'tool': event['name'], 'input': tool_input})}"
            
            elif kind == "on_tool_end":
                tool_output = str(event['data'].get('output', ''))
                yield f"{json.dumps({'type': 'tool_end', 'tool': event['name'], 'output': tool_output})}"
    
    return EventSourceResponse(
        generate(),
        media_type="text/event-stream"
    )

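# Helpers for /chat2: tidy up tool_start / tool_end payloads before streaming.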
async def clean_tool_input(tool_input: str):
    """Extract the first key/value pair from a stringified tool-input dict,
    e.g. "{'query': 'foo'}" -> {'query': 'foo'}."""
    pattern = r"{\s*'([^']+)':\s*'([^']+)'"
    match = re.search(pattern, tool_input)
    if match:
        key, value = match.groups()
        return {key: value}
    # Fall back to the raw input if it doesn't look like a dict
    return [tool_input]

async def clean_tool_response(tool_output: str):
    """Clean and extract the relevant fields from a query_documents tool response."""
    if "query_documents" in tool_output:
        try:
            # Extract the stringified list of result dicts from the content
            start = tool_output.find("[{")
            end = tool_output.rfind("}]") + 2
            if start >= 0 and end > start:
                list_str = tool_output[start:end]

                # Safely evaluate the string as a Python literal
                results = ast.literal_eval(list_str)

                # Return only the relevant fields
                return [{"text": r["text"], "document_id": r["metadata"]["document_id"]}
                        for r in results]

        except SyntaxError as e:
            print(f"Syntax error in parsing: {e}")
            return f"Error parsing document results: {str(e)}"
        except Exception as e:
            print(f"General error: {e}")
            return f"Error processing results: {str(e)}"
    return tool_output

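# Same as /chat, but post-processes tool events: tool inputs are parsed into
# dicts, and query_documents results are re-fetched via query_documents_raw and
# streamed as structured JSON.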
@app.post("/chat2")
async def chat2(input_data: ChatInput):
    thread_id = input_data.thread_id or str(uuid.uuid4())
    
    config = {
        "configurable": {
            "thread_id": thread_id,
            "collection_id": input_data.collection_id,
            "user_id": input_data.user_id
        }
    }
    
    input_message = HumanMessage(content=input_data.message)
    
    async def generate():
        async for event in agent.astream_events(
            {"messages": [input_message]}, 
            config,
            version="v2"
        ):
            kind = event["event"]
            
            if kind == "on_chat_model_stream":
                content = event["data"]["chunk"].content
                if content:
                    yield f"{json.dumps({'type': 'token', 'content': content})}"

            elif kind == "on_tool_start":
                tool_name = event['name']
                tool_input = event['data'].get('input', '')
                clean_input = await clean_tool_input(str(tool_input))
                yield f"{json.dumps({'type': 'tool_start', 'tool': tool_name, 'inputs': clean_input})}"
            
            elif kind == "on_tool_end":
                if "query_documents" in event['name']:
                    print(event)
                    raw_output = await query_documents_raw(str(event['data'].get('input', '')), config)
                    try:
                        serializable_output = [
                            {
                                "text": result.text,
                                "distance": result.distance,
                                "metadata": result.metadata
                            }
                            for result in raw_output
                        ]
                        yield f"{json.dumps({'type': 'tool_end', 'tool': event['name'], 'output': json.dumps(serializable_output)})}"
                    except Exception as e:
                        print(e)
                        yield f"{json.dumps({'type': 'tool_end', 'tool': event['name'], 'output': str(raw_output)})}"
                else:
                    tool_name = event['name']
                    raw_output = str(event['data'].get('output', ''))
                    clean_output = await clean_tool_response(raw_output)
                    yield f"{json.dumps({'type': 'tool_end', 'tool': tool_name, 'output': clean_output})}"
    
    return EventSourceResponse(
        generate(),
        media_type="text/event-stream"
    )

@app.get("/health")
async def health_check():
    return {"status": "healthy"}