import uuid
from fastapi import FastAPI, HTTPException, Depends
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from typing import List, Optional
import json
import time
from API_provider import API_Inference
from core_logic import (
    check_api_key_validity,
    update_request_count,
    get_rate_limit_status,
    get_subscription_status,
    get_available_models,
    get_model_info,
)

app = FastAPI()
security = HTTPBearer()

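# Request/response schemas follow the OpenAI chat completions format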
class Message(BaseModel):
    role: str
    content: str

class ChatCompletionRequest(BaseModel):
    model: str
    messages: List[Message]
    stream: Optional[bool] = False
    max_tokens: Optional[int] = 4000
    temperature: Optional[float] = 0.5
    top_p: Optional[float] = 0.95

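# Dependency that extracts the bearer token from the "Authorization: Bearer <key>" header and returns it as the API key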
def get_api_key(credentials: HTTPAuthorizationCredentials = Depends(security)):
    return credentials.credentials

@app.post("/v1/chat/completions")
async def chat_completions(request: ChatCompletionRequest, api_key: str = Depends(get_api_key)):
    try:
        # Check API key validity and rate limit
        is_valid, error_message = check_api_key_validity(api_key)
        if not is_valid:
            raise HTTPException(status_code=401, detail=error_message)

        messages = [{"role": msg.role, "content": msg.content} for msg in request.messages]
        
        # Get model info
        model_info = get_model_info(request.model)
        if not model_info:
            raise HTTPException(status_code=400, detail="Invalid model specified")
        
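        # Map public model aliases onto the provider-specific identifiers expected by API_Inference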
        if "meta-llama-405b-turbo" in request.model:
            request.model = "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo"

        if "claude-3.5-sonnet" in request.model:
            request.model = "claude-3-sonnet-20240229"

        if request.stream:
            # Charge credits up front; o1 models cost more per request
            if request.model in {
                "gpt-4o",
                "claude-3-sonnet-20240229",
                "gemini-1.5-pro",
                "gemini-1.5-flash",
                "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo",
            }:
                credits_used = 1
            elif request.model == "o1-mini":
                credits_used = 2
            elif request.model == "o1-preview":
                credits_used = 3
            else:
                credits_used = 0
            if credits_used:
                update_request_count(api_key, credits_used)

            def generate():
                # Emit each chunk as an OpenAI-style SSE delta event
                for chunk in API_Inference(messages, model=request.model, stream=True,
                                           max_tokens=request.max_tokens,
                                           temperature=request.temperature,
                                           top_p=request.top_p):
                    yield f"data: {json.dumps({'choices': [{'delta': {'content': chunk}}]})}\n\n"
                # Report the credit cost as a final data event, then send the [DONE] sentinel
                yield f"data: {json.dumps({'credits_used': credits_used})}\n\n"
                yield "data: [DONE]\n\n"

            return StreamingResponse(generate(), media_type="text/event-stream")
        else:
            response = API_Inference(messages, model=request.model, stream=False,
                                  max_tokens=request.max_tokens,
                                  temperature=request.temperature,
                                  top_p=request.top_p)
            
            # Update request count
            update_request_count(api_key, 1)  # Assume 1 credit per non-streaming request, adjust as needed

            # Rough token counts based on whitespace-delimited words (not true tokenization)
            prompt_tokens = len(' '.join(msg['content'] for msg in messages).split())
            completion_tokens = len(response.split())

            return {
                "id": f"chatcmpl-{uuid.uuid4()}",
                "object": "chat.completion",
                "created": int(uuid.uuid1().time // 1e7),
                "model": request.model,
                "choices": [
                    {
                        "index": 0,
                        "message": {
                            "role": "assistant",
                            "content": response
                        },
                        "finish_reason": "stop"
                    }
                ],
                "usage": {
                    "prompt_tokens": len(' '.join(msg['content'] for msg in messages).split()),
                    "completion_tokens": len(response.split()),
                    "total_tokens": len(' '.join(msg['content'] for msg in messages).split()) + len(response.split())
                },
                "credits_used": 1
            }
    except HTTPException:
        # Preserve intentional HTTP errors (401, 400) instead of converting them to 500s
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

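# Account/introspection endpoints: the key is validated but the rate limit is not consumed (check_rate_limit=False)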
@app.get("/rate_limit/status")
async def get_rate_limit_status_endpoint(api_key: str = Depends(get_api_key)):
    is_valid, error_message = check_api_key_validity(api_key, check_rate_limit=False)
    if not is_valid:
        raise HTTPException(status_code=401, detail=error_message)
    return get_rate_limit_status(api_key)

@app.get("/subscription/status")
async def get_subscription_status_endpoint(api_key: str = Depends(get_api_key)):
    is_valid, error_message = check_api_key_validity(api_key, check_rate_limit=False)
    if not is_valid:
        raise HTTPException(status_code=401, detail=error_message)
    return get_subscription_status(api_key)

@app.get("/models")
async def get_available_models_endpoint(api_key: str = Depends(get_api_key)):
    is_valid, error_message = check_api_key_validity(api_key, check_rate_limit=False)
    if not is_valid:
        raise HTTPException(status_code=401, detail=error_message)
    return {"data": [{"id": model} for model in get_available_models().values()]}

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
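
# A minimal usage sketch (the host, port, and API key below are illustrative; use a key issued by core_logic):
#
#   curl http://localhost:8000/v1/chat/completions \
#     -H "Authorization: Bearer sk-example-key" \
#     -H "Content-Type: application/json" \
#     -d '{"model": "gpt-4o", "messages": [{"role": "user", "content": "Hello"}], "stream": false}'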