import json
import logging
import time
import uuid
from functools import wraps
from typing import List, Optional

from flask import Flask, request, jsonify, Response
from pydantic import BaseModel, ValidationError

from API_provider import API_Inference
from core_logic import (
    check_api_key_validity,
    update_request_count,
    get_rate_limit_status,
    get_subscription_status,
    get_available_models,
    get_model_info,
)

# Flask application object; the @app.route decorators in this module register
# their view functions on it.
app = Flask(__name__)

class Message(BaseModel):
    """One chat message, validated by pydantic: a role label and its text."""

    # Speaker label. Only the type (str) is validated here; no set of allowed
    # values is enforced.
    role: str
    # Text of the message.
    content: str

class ChatCompletionRequest(BaseModel):
    """Validated JSON body of a chat-completion request.

    Only `model` and `messages` are required; every other field falls back to
    the default declared below when omitted.
    """

    # Model name as supplied by the client.
    model: str
    # Conversation so far, each entry validated as a Message.
    messages: List[Message]
    # Whether the client asked for a streamed response (default: False).
    stream: Optional[bool] = False
    # Generation parameters with their defaults. No range checks are declared
    # here, and an explicit null is accepted because the fields are Optional.
    max_tokens: Optional[int] = 4000
    temperature: Optional[float] = 0.5
    top_p: Optional[float] = 0.95

def get_api_key():
    """Extract the bearer token from the current request's Authorization header.

    Returns:
        The token string, or None when the header is missing, does not start
        with ``'Bearer '``, or carries no token after the scheme.
    """
    auth_header = request.headers.get('Authorization')
    if not auth_header or not auth_header.startswith('Bearer '):
        return None
    # Split on runs of whitespace rather than on a single space: with
    # "Bearer   <token>" a plain split(' ') yields an empty second element and
    # the otherwise valid header would be treated as unauthenticated.
    parts = auth_header.split()
    return parts[1] if len(parts) > 1 else None

def requires_api_key(func):
    """Decorate a view so it only runs when a bearer token is present.

    The token returned by ``get_api_key()`` is handed to the wrapped view as
    the ``api_key`` keyword argument. When no token is found, the view is not
    called and a JSON ``{'detail': 'Not authenticated'}`` body is returned
    with status 401.
    """
    @wraps(func)
    def wrapper(*args, **kwargs):
        token = get_api_key()
        if token:
            # Any 'api_key' already present in kwargs is replaced by the token.
            return func(*args, **{**kwargs, 'api_key': token})
        return jsonify({'detail': 'Not authenticated'}), 401
    return wrapper

@app.route('/')
def index():
    """Serve the root path with a fixed plain-text greeting."""
    greeting = 'Hello, World!'
    return greeting

# Client-facing model aliases mapped to the identifier handed to API_Inference.
# Names that are not listed here are forwarded unchanged.
_MODEL_ALIASES = {
    "meta-llama-405b-turbo": "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo",
    "claude-3.5-sonnet": "claude-3-sonnet-20240229",
}

# Credits charged per request, keyed by the resolved model identifier.
# Models that are not listed here are charged 0 credits.
# NOTE(review): "gemini-1-5-flash" uses dashes where "gemini-1.5-pro" uses a
# dot -- confirm the key matches the identifier clients actually send.
_CREDIT_COSTS = {
    "gpt-4o": 1,
    "claude-3-sonnet-20240229": 1,
    "gemini-1.5-pro": 1,
    "gemini-1-5-flash": 1,
    "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo": 1,
    "o1-mini": 2,
    "o1-preview": 3,
}


def _stream_chat_completion(api_key, messages, model_name, chat_request, credits_reduction):
    """Yield the event-stream frames for one streamed completion.

    Each chunk produced by API_Inference is wrapped as
    ``data: {"choices": [{"delta": {"content": <chunk>}}]}``. After the last
    chunk the request is charged and a final ``data: [DONE]`` frame followed
    by a ``Credits used: N`` line is emitted. If anything raises, a single
    ``data: [ERROR] <message>`` frame is emitted instead.
    """
    try:
        for chunk in API_Inference(messages, model=model_name, stream=True,
                                   max_tokens=chat_request.max_tokens,
                                   temperature=chat_request.temperature,
                                   top_p=chat_request.top_p):
            payload = json.dumps({'choices': [{'delta': {'content': chunk}}]})
            yield f"data: {payload}\n\n"
        # Charge before emitting the terminal frame: code placed after the
        # last yield never runs if the consumer closes the generator there.
        update_request_count(api_key, credits_reduction)
        yield f"data: [DONE]\n\nCredits used: {credits_reduction}\n\n"
    except Exception as e:
        logging.exception("Streaming chat completion failed")
        yield f"data: [ERROR] {str(e)}\n\n"


@app.route('/chat/completions', methods=['POST', 'GET'])
@requires_api_key
def chat_completions(api_key):
    """Handle a chat-completion request.

    The JSON body is validated against ChatCompletionRequest, the API key is
    checked with check_api_key_validity, and the requested model is looked up
    with get_model_info before the messages are forwarded to API_Inference.

    Args:
        api_key: Bearer token injected by the ``requires_api_key`` decorator.

    Returns:
        * 400 with ``{'detail': ...}`` when the body is not a JSON object,
          fails validation, or names a model get_model_info does not know.
        * 401 with ``{'detail': ...}`` when the API key is rejected.
        * A ``text/event-stream`` response when ``stream`` is truthy.
        * Otherwise a JSON ``chat.completion`` object containing the reply,
          approximate token usage, and the credits charged.
        * 500 with ``{'detail': str(error)}`` on any unexpected exception.
    """
    try:
        logging.info("Received request for chat completions")

        # silent=True makes a missing, non-JSON or malformed body come back as
        # None instead of raising an HTTP exception, which the catch-all at
        # the bottom would otherwise report as a 500.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            return jsonify({'detail': 'Request body must be a JSON object'}), 400
        try:
            chat_request = ChatCompletionRequest(**data)
        except ValidationError as e:
            return jsonify({'detail': e.errors()}), 400

        # Check API key validity and rate limit
        is_valid, error_message = check_api_key_validity(api_key)
        if not is_valid:
            return jsonify({'detail': error_message}), 401

        messages = [{"role": msg.role, "content": msg.content} for msg in chat_request.messages]

        # Reject models that get_model_info does not recognise.
        if not get_model_info(chat_request.model):
            return jsonify({'detail': 'Invalid model specified'}), 400

        model_name = _MODEL_ALIASES.get(chat_request.model, chat_request.model)
        credits_reduction = _CREDIT_COSTS.get(model_name, 0)

        if chat_request.stream:
            return Response(
                _stream_chat_completion(api_key, messages, model_name,
                                        chat_request, credits_reduction),
                mimetype='text/event-stream',
            )

        response = API_Inference(messages, model=model_name, stream=False,
                                 max_tokens=chat_request.max_tokens,
                                 temperature=chat_request.temperature,
                                 top_p=chat_request.top_p)
        update_request_count(api_key, credits_reduction)

        # Whitespace-separated word counts stand in for token counts; no
        # tokenizer is applied here.
        prompt_tokens = sum(len(msg['content'].split()) for msg in messages)
        completion_tokens = len(response.split())
        return jsonify({
            "id": f"chatcmpl-{str(uuid.uuid4())}",
            "object": "chat.completion",
            # Unix timestamp in seconds. uuid1().time counts 100 ns ticks from
            # 1582-10-15, so deriving "created" from it is off by centuries.
            "created": int(time.time()),
            "model": model_name,
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": response
                    },
                    "finish_reason": "stop"
                }
            ],
            "usage": {
                "prompt_tokens": prompt_tokens,
                "completion_tokens": completion_tokens,
                "total_tokens": prompt_tokens + completion_tokens
            },
            "credits_used": credits_reduction
        })
    except Exception as e:
        # Top-level boundary: record the traceback, then report the failure.
        logging.exception("Unhandled error in chat_completions")
        return jsonify({'detail': str(e)}), 500

@app.route('/rate_limit/status', methods=['GET'])
@requires_api_key
def get_rate_limit_status_endpoint(api_key):
    """Return the rate-limit status for the caller's API key as JSON.

    The key is validated with ``check_rate_limit=False``. A key that fails
    validation gets a 401 whose ``detail`` is the validator's message.
    """
    key_ok, rejection = check_api_key_validity(api_key, check_rate_limit=False)
    if key_ok:
        return jsonify(get_rate_limit_status(api_key))
    return jsonify({'detail': rejection}), 401

@app.route('/subscription/status', methods=['GET'])
@requires_api_key
def get_subscription_status_endpoint(api_key):
    """Return the subscription status for the caller's API key as JSON.

    The key is validated with ``check_rate_limit=False``. A key that fails
    validation gets a 401 whose ``detail`` is the validator's message.
    """
    verdict = check_api_key_validity(api_key, check_rate_limit=False)
    accepted, reason = verdict
    if accepted:
        status_payload = get_subscription_status(api_key)
        return jsonify(status_payload)
    return jsonify({'detail': reason}), 401

@app.route('/models', methods=['GET'])
@requires_api_key
def get_available_models_endpoint(api_key):
    """List the available models as ``{"data": [{"id": ...}, ...]}``.

    The key is validated with ``check_rate_limit=False``. A key that fails
    validation gets a 401 whose ``detail`` is the validator's message. Each
    ``id`` is one of the values of the mapping returned by
    ``get_available_models()``.
    """
    key_ok, rejection = check_api_key_validity(api_key, check_rate_limit=False)
    if not key_ok:
        return jsonify({'detail': rejection}), 401

    listing = []
    for model_id in get_available_models().values():
        listing.append({"id": model_id})
    return jsonify({"data": listing})

# Script entry point: when this file is executed directly, start Flask's
# built-in server listening on every network interface (0.0.0.0), port 8000.
if __name__ == "__main__":
    app.run(host="0.0.0.0", port=8000)