import gradio as gr
from functools import lru_cache
from transformers import pipeline
from huggingface_hub import login

def hf_login(token):
    """Log in to the Hugging Face Hub with a user access token."""
    if token:
        try:
            login(token)
            return "Successfully logged in to Hugging Face Hub"
        except Exception as e:
            return f"Login error: {e}"
    return "No token provided"
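# Note: huggingface_hub also reads the HF_TOKEN environment variable on its
# own, so the manual login above should only be needed for interactive use.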

# Display name -> (pipeline task, Hub checkpoint) for every supported model
MODELS = {
    "GPT-2 Original": ("text-generation", "gpt2"),
    "GPT-2 Medium": ("text-generation", "gpt2-medium"),
    "DistilGPT-2": ("text-generation", "distilgpt2"),
    "German GPT-2": ("text-generation", "german-nlp-group/german-gpt2"),
    "German Wechsel GPT-2": ("text-generation", "benjamin/gpt2-wechsel-german"),
    "T5 Base": ("text2text-generation", "t5-base"),
    "T5 Large": ("text2text-generation", "t5-large"),
    "Text Classification": ("text-classification", "distilbert-base-uncased-finetuned-sst-2-english"),
    "Sentiment Analysis": ("sentiment-analysis", "nlptown/bert-base-multilingual-uncased-sentiment"),
}

@lru_cache(maxsize=None)
def get_pipeline(model_name):
    """Lazily build a pipeline on first use; lru_cache then keeps it in
    memory so each model is downloaded and loaded at most once per process."""
    if model_name not in MODELS:
        raise ValueError(f"Unknown model: {model_name}")
    task, checkpoint = MODELS[model_name]
    try:
        return pipeline(task, model=checkpoint)
    except Exception as e:
        raise RuntimeError(f"Error loading model {model_name}: {e}") from e
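# To expose another checkpoint, add one MODELS entry; for example
# (hypothetical addition, any Hub checkpoint matching the task should work):
# MODELS["GPT-2 Large"] = ("text-generation", "gpt2-large")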

def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    model_name,
    max_tokens,
    temperature,
    top_p,
):
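    """Route one chat turn to the pipeline selected in the UI and return the
    reply text; generation parameters come from the Gradio sliders."""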
    try:
        # Get the appropriate pipeline
        pipe = get_pipeline(model_name)
        
        # For causal text-generation models
        if model_name in ["GPT-2 Original", "GPT-2 Medium", "DistilGPT-2",
                          "German GPT-2", "German Wechsel GPT-2"]:
            # Fold prior turns into a single prompt string
            full_history = '\n'.join(
                f"User: {user}\nAssistant: {assistant or ''}"
                for user, assistant in history
            ) if history else ''
            full_prompt = f"{system_message}\n{full_history}\nUser: {message}\nAssistant:"

            # max_new_tokens counts generated tokens directly (the old
            # max_length arithmetic conflated words with tokens), and
            # do_sample=True is required for temperature/top_p to take effect
            response = pipe(
                full_prompt,
                max_new_tokens=max_tokens,
                do_sample=True,
                temperature=temperature,
                top_p=top_p,
                num_return_sequences=1,
                return_full_text=False,  # return only the newly generated text
                pad_token_id=pipe.tokenizer.eos_token_id,  # GPT-2 has no pad token
            )
            return response[0]['generated_text'].strip()
        
        # For T5 (encoder-decoder) models
        elif model_name in ["T5 Base", "T5 Large"]:
            # T5 is not a chat model, so pass only the latest message
            # (a task prefix such as "summarize: " usually improves results)
            response = pipe(
                message,
                max_new_tokens=max_tokens,
                do_sample=True,
                temperature=temperature,
                top_p=top_p,
                num_return_sequences=1
            )
            return response[0]['generated_text']
        
        # Classification-style pipelines return a label and a confidence score
        elif model_name in ["Text Classification", "Sentiment Analysis"]:
            result = pipe(message)[0]
            kind = "Sentiment" if model_name == "Sentiment Analysis" else "Classification"
            return f"{kind}: {result['label']} (Confidence: {result['score']:.2f})"

        # Defensive fallback; get_pipeline would already have raised for an
        # unknown name, so this should not normally be reached
        return f"Unsupported model: {model_name}"
            
    except Exception as e:
        return f"Error: {str(e)}"

def create_interface():
    with gr.Blocks(title="Hugging Face Models Demo") as demo:
        gr.Markdown("# Hugging Face Models Chat Interface")
        
        with gr.Accordion("Hugging Face Login", open=False):
            with gr.Row():
                hf_token = gr.Textbox(label="Enter Hugging Face Token", type="password")
                login_btn = gr.Button("Login")
                login_output = gr.Textbox(label="Login Status")
            login_btn.click(hf_login, inputs=[hf_token], outputs=[login_output])
        
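        # gr.ChatInterface passes (message, history) to respond() on its own;
        # the additional_inputs below are appended positionally as the
        # remaining parameters (system_message, model_name, max_tokens,
        # temperature, top_p).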
        chat_interface = gr.ChatInterface(
            respond,
            additional_inputs=[
                gr.Textbox(value="You are a helpful assistant.", label="System message"),
                gr.Dropdown(
                    choices=list(MODELS),  # stays in sync with the MODELS mapping
                    value="GPT-2 Original",
                    label="Select Model"
                ),
                gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
                gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Temperature"),
                gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.95,
                    step=0.05,
                    label="Top-p (nucleus sampling)",
                ),
            ]
        )
    
    return demo

if __name__ == "__main__":
    interface = create_interface()
    interface.launch(share=True)
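    # share=True requests a temporary public *.gradio.live URL; drop it for
    # local-only use. For concurrent users, calling interface.queue() before
    # launch() may help (assumes a recent Gradio version).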