from openai import OpenAI
import logging
from typing import List
import os 
from constant import HYPERBOLIC_MODELS, MODEL_MAPPING


def model_name_mapping(model_name):
    """Translate a user-facing model name via ``MODEL_MAPPING``.

    Args:
        model_name: Key to look up in ``MODEL_MAPPING``.

    Returns:
        The value ``MODEL_MAPPING`` stores for ``model_name``.

    Raises:
        ValueError: If ``model_name`` is not a key of ``MODEL_MAPPING``.
    """
    try:
        return MODEL_MAPPING[model_name]
    except KeyError:
        # Single formatted message: passing two positional args to ValueError
        # made str(exc) render as a tuple. `from None` hides the internal KeyError.
        raise ValueError(f"Invalid model name: {model_name}") from None


def urial_template(urial_prompt, history, message):
    """Build a URIAL-style plain-text prompt from a conversation.

    The prompt is ``urial_prompt`` followed by a newline and then one
    ``# Query:`` / ``# Answer:`` section per ``(user_msg, ai_msg)`` pair in
    ``history``, each wrapped in triple quotes. The new ``message`` is appended
    as a final query whose answer section is left open (no closing quotes).

    Args:
        urial_prompt: Preamble text placed before all query/answer sections.
        history: Iterable of ``(user_msg, ai_msg)`` pairs, rendered in order.
        message: The latest user query.

    Returns:
        The assembled prompt string.
    """
    parts = [urial_prompt + "\n"]
    parts.extend(
        f'# Query:\n"""\n{user_msg}\n"""\n\n# Answer:\n"""\n{ai_msg}\n"""\n\n'
        for user_msg, ai_msg in history
    )
    # The final answer block is deliberately left unterminated.
    parts.append(f'# Query:\n"""\n{message}\n"""\n\n# Answer:\n"""\n')
    return "".join(parts)

def chat_template(history, message):
    """Convert a conversation into a list of chat-completion messages.

    Each ``(user_msg, ai_msg)`` pair in ``history`` becomes a ``"user"``
    message followed by an ``"assistant"`` message. The new ``message`` is
    appended last as a ``"user"`` message.

    Args:
        history: Iterable of ``(user_msg, ai_msg)`` pairs, in order.
        message: The latest user message.

    Returns:
        A list of ``{"role": ..., "content": ...}`` dicts.
    """
    turns = []
    for question, answer in history:
        turns.extend(
            (
                {"role": "user", "content": question},
                {"role": "assistant", "content": answer},
            )
        )
    turns.append({"role": "user", "content": message})
    return turns

def openai_base_request(
    model: str=None,
    temperature: float=0,
    max_tokens: int=512,
    top_p: float=1.0,
    prompt: str=None,
    n: int=1,
    repetition_penalty: float=1.0,
    stop: List[str]=None,
    api_key: str=None,
    ):
    """Start a streaming text-completion request on Hyperbolic or Together.

    Models listed in ``HYPERBOLIC_MODELS`` are sent to the Hyperbolic endpoint;
    every other model is sent to the Together endpoint. Both are accessed
    through the ``OpenAI`` client with a custom ``base_url``.

    Args:
        model: Model identifier sent to the API; also selects the provider.
        temperature: Sampling temperature (coerced to float).
        max_tokens: Maximum tokens to generate (coerced to int).
        top_p: Nucleus-sampling threshold (coerced to float).
        prompt: Raw prompt text for the completions endpoint.
        n: Number of completions to request.
        repetition_penalty: Sent through ``extra_body`` because it is not a
            named argument of the OpenAI client.
        stop: Optional stop sequences.
        api_key: Explicit API key. When None, the provider's environment
            variable (``HYPERBOLIC_API_KEY`` or ``TOGETHER_API_KEY``) is used.

    Returns:
        The result of ``client.completions.create(..., stream=True)``, i.e. a
        streaming response to be iterated by the caller.

    Raises:
        OpenAIError: If no API key was passed and the provider's environment
            variable is unset or empty.
    """
    from openai import OpenAIError

    if model in HYPERBOLIC_MODELS:
        base_url = "https://api.hyperbolic.xyz/v1"
        key_env_var = "HYPERBOLIC_API_KEY"
    else:
        base_url = "https://api.together.xyz/v1"
        key_env_var = "TOGETHER_API_KEY"

    if api_key is None:
        api_key = os.getenv(key_env_var)
    # Fail explicitly instead of handing None to OpenAI(): the client would
    # silently fall back to OPENAI_API_KEY and send that credential to a
    # third-party host.
    if not api_key:
        raise OpenAIError(
            f"No API key provided and {key_env_var} is not set "
            f"(required for {base_url})."
        )

    client = OpenAI(api_key=api_key, base_url=base_url)
    # Lazy %-style arguments: formatting is skipped when INFO is disabled.
    logging.info("Requesting base completion from OpenAI API with model %s", model)
    logging.info("Prompt: %s", prompt)
    logging.info("Temperature: %s", temperature)
    logging.info("Max tokens: %s", max_tokens)
    logging.info("Top-p: %s", top_p)
    logging.info("Repetition penalty: %s", repetition_penalty)
    logging.info("Stop: %s", stop)

    request = client.completions.create(
        model=model,
        prompt=prompt,
        temperature=float(temperature),
        max_tokens=int(max_tokens),
        top_p=float(top_p),
        n=n,
        extra_body={'repetition_penalty': float(repetition_penalty)},
        stop=stop,
        stream=True
    )

    return request



def openai_chat_request(
    model: str=None,
    temperature: float=0,
    max_tokens: int=512,
    top_p: float=1.0,
    messages=None,
    n: int=1,
    repetition_penalty: float=1.0,
    stop: List[str]=None,
    api_key: str=None,
    ):
    """Start a streaming chat-completion request on Hyperbolic or Together.

    Models listed in ``HYPERBOLIC_MODELS`` are sent to the Hyperbolic endpoint;
    every other model is sent to the Together endpoint. Both are accessed
    through the ``OpenAI`` client with a custom ``base_url``.

    Args:
        model: Model identifier sent to the API; also selects the provider.
        temperature: Sampling temperature (coerced to float).
        max_tokens: Maximum tokens to generate (coerced to int).
        top_p: Nucleus-sampling threshold (coerced to float).
        messages: Chat messages forwarded unchanged to the API (presumably
            the ``{"role", "content"}`` dicts built by ``chat_template`` --
            confirm at call sites).
        n: Number of completions to request.
        repetition_penalty: Sent through ``extra_body`` because it is not a
            named argument of the OpenAI client.
        stop: Optional stop sequences.
        api_key: Explicit API key. When None, the provider's environment
            variable (``HYPERBOLIC_API_KEY`` or ``TOGETHER_API_KEY``) is used.

    Returns:
        The result of ``client.chat.completions.create(..., stream=True)``,
        i.e. a streaming response to be iterated by the caller.

    Raises:
        OpenAIError: If no API key was passed and the provider's environment
            variable is unset or empty.
    """
    from openai import OpenAIError

    if model in HYPERBOLIC_MODELS:
        base_url = "https://api.hyperbolic.xyz/v1"
        key_env_var = "HYPERBOLIC_API_KEY"
    else:
        base_url = "https://api.together.xyz/v1"
        key_env_var = "TOGETHER_API_KEY"

    if api_key is None:
        api_key = os.getenv(key_env_var)
    # Fail explicitly instead of handing None to OpenAI(): the client would
    # silently fall back to OPENAI_API_KEY and send that credential to a
    # third-party host.
    if not api_key:
        raise OpenAIError(
            f"No API key provided and {key_env_var} is not set "
            f"(required for {base_url})."
        )

    logging.info("Requesting chat completion from OpenAI API with model %s", model)

    client = OpenAI(api_key=api_key, base_url=base_url)

    request = client.chat.completions.create(
        model=model,
        messages=messages,
        temperature=float(temperature),
        max_tokens=int(max_tokens),
        top_p=float(top_p),
        n=n,
        extra_body={'repetition_penalty': float(repetition_penalty)},
        stop=stop,
        stream=True
    )
    return request