from fastapi import FastAPI, HTTPException, Depends, Header, Request
from pydantic import BaseModel
import logging
import os
import secrets
import time
import uuid
from functools import lru_cache
from langchain_community.llms import LlamaCpp
from dotenv import load_dotenv

# Load environment variables
load_dotenv()

# Configure logging
logging.basicConfig(level=logging.INFO)

# API keys from .env
API_KEYS = {
    "user1": os.getenv("API_KEY_USER1"),
    "user2": os.getenv("API_KEY_USER2"),
}
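
# Expected .env layout (illustrative placeholders; generate real keys, e.g.
# with `python -c "import secrets; print(secrets.token_urlsafe(32))"`):
#   API_KEY_USER1=<random-secret-for-user1>
#   API_KEY_USER2=<random-secret-for-user2>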

app = FastAPI()

# API key authentication via the X-API-Key header
def verify_api_key(request: Request, api_key: str = Header(None, alias="X-API-Key")):
    # Log only the path: dumping request.headers here would leak the API key.
    logging.info(f"Auth check for {request.url.path}")
    if not api_key:
        raise HTTPException(status_code=401, detail="API key is missing")

    api_key = api_key.strip()
    # Constant-time comparison; keys whose env vars are unset (None) are skipped.
    if not any(key and secrets.compare_digest(api_key, key) for key in API_KEYS.values()):
        raise HTTPException(status_code=401, detail="Invalid API key")

    return api_key
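
# Compatibility note: clients must send the key in an "X-API-Key" header.
# Official OpenAI SDKs send "Authorization: Bearer <key>" instead, so they
# will not authenticate against this server without a custom header override.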

# OpenAI-compatible request format
class ChatMessage(BaseModel):
    role: str
    content: str

class OpenAIRequest(BaseModel):
    model: str
    messages: list[ChatMessage]
    stream: bool = False  # accepted for compatibility; streaming is not implemented

# Initialize LangChain's Llama.cpp wrapper once and reuse it across requests;
# reloading the GGUF weights on every call would dominate response latency.
@lru_cache(maxsize=1)
def get_llm():
    model_path = "/app/Meta-Llama-3-8B-Instruct.Q4_0.gguf"
    return LlamaCpp(model_path=model_path, n_ctx=2048)
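
# Optional tuning knobs (a sketch, not exhaustive; n_gpu_layers only takes
# effect with a GPU-enabled llama-cpp-python build):
#   LlamaCpp(model_path=model_path, n_ctx=2048, n_gpu_layers=-1,
#            temperature=0.7, max_tokens=512)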

@app.post("/v1/chat/completions")
def generate_text(request: OpenAIRequest, api_key: str = Depends(verify_api_key)):
    try:
        llm = get_llm()

        # Extract last user message
        user_message = next((msg.content for msg in reversed(request.messages) if msg.role == "user"), None)
        if not user_message:
            raise HTTPException(status_code=400, detail="User message is required")
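
        # Only the last user message reaches the model here. A sketch that
        # forwards the whole conversation as a plain-text transcript (not the
        # Llama 3 chat template) would be:
        #   prompt = "\n".join(f"{m.role}: {m.content}" for m in request.messages)
        #   response_text = llm.invoke(prompt)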

        response_text = llm.invoke(user_message)

        response = {
            "id": "chatcmpl-123",
            "object": "chat.completion",
            "created": int(time.time()),
            "model": request.model,
            "choices": [
                {
                    "index": 0,
                    "message": {"role": "assistant", "content": response_text},
                    "finish_reason": "stop",
                }
            ],
            "usage": {
                "prompt_tokens": len(user_message.split()),
                "completion_tokens": len(response_text.split()),
                "total_tokens": len(user_message.split()) + len(response_text.split()),
            }
        }

        return response

    except HTTPException:
        # Let deliberate HTTP errors (e.g. the 400 above) pass through unchanged
        # instead of being masked as a 500 by the generic handler below.
        raise
    except Exception as e:
        logging.error(f"Error generating response: {e}")
        raise HTTPException(status_code=500, detail="Internal server error")
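
# End-to-end client example (a sketch; assumes the server runs on
# localhost:8000, e.g. via `uvicorn main:app`, and that the key below
# matches one configured in .env):
#   import requests
#   resp = requests.post(
#       "http://localhost:8000/v1/chat/completions",
#       headers={"X-API-Key": "<key-from-.env>"},
#       json={"model": "llama-3-8b-instruct",
#             "messages": [{"role": "user", "content": "Hello!"}]},
#   )
#   print(resp.json()["choices"][0]["message"]["content"])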