import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# Tải model và tokenizer khi ứng dụng khởi động
model_name = "Qwen/Qwen2.5-0.5B"
try:
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # Đặt pad_token_id nếu chưa có
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype="auto",
        device_map="auto",
        attn_implementation="eager"  # Tránh cảnh báo sdpa
    )
    print("Model and tokenizer loaded successfully!")
except Exception as e:
    print(f"Error loading model: {e}")
    raise

# Hàm sinh văn bản (dùng cho cả UI và API)
def generate_text(prompt, max_length, state):
    try:
        # Mã hóa đầu vào với attention_mask
        inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True).to(model.device)
        outputs = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_length=max_length,
            num_return_sequences=1,
            no_repeat_ngram_size=2,
            do_sample=True,
            top_k=50,
            top_p=0.95,
            pad_token_id=tokenizer.pad_token_id
        )
        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        # Cập nhật state với kết quả mới
        state.append(generated_text)
        return state, generated_text  # Trả về state và output để hiển thị
    except Exception as e:
        error_msg = f"Error: {str(e)}"
        state.append(error_msg)
        return state, error_msg

# Hàm hiển thị thông tin API
def get_api_info():
    base_url = "https://<your-space-name>.hf.space"
    return (
        "Welcome to Qwen2.5-0.5B API!\n"
        f"API Base URL: {base_url} (Replace '<your-space-name>' with your actual Space name)\n"
        "Endpoints:\n"
        f"- GET {base_url}/api/health_check (Check API status)\n"
        f"- POST {base_url}/api/generate (Generate text)\n"
        "To use the generate API, send a POST request with JSON:\n"
        '{"0": "your prompt", "1": 150}'
    )

# Hàm kiểm tra sức khỏe (dành cho API)
def health_check():
    return "Qwen2.5-0.5B API is running!"

# Tạo giao diện Gradio
with gr.Blocks(title="Qwen2.5-0.5B Text Generator") as demo:
    gr.Markdown("# Qwen2.5-0.5B Text Generator")
    gr.Markdown("Enter a prompt below or use the API!")
    
    # State để lưu trữ lịch sử kết quả
    state = gr.State(value=[])  # Khởi tạo state là danh sách rỗng
    
    # Hiển thị thông tin API
    gr.Markdown("### API Information")
    api_info = gr.Textbox(label="API Details", value=get_api_info(), interactive=False)
    
    # Giao diện sinh văn bản
    gr.Markdown("### Generate Text")
    with gr.Row():
        prompt_input = gr.Textbox(label="Prompt", placeholder="Type something...")
        max_length_input = gr.Slider(50, 500, value=100, step=10, label="Max Length")
    
    generate_button = gr.Button("Generate")
    output_text = gr.Textbox(label="Generated Text History", interactive=False, lines=10)
    
    # Liên kết button với hàm generate_text
    generate_button.click(
        fn=generate_text,
        inputs=[prompt_input, max_length_input, state],
        outputs=[state, output_text]  # Cập nhật cả state và output_text
    )

# Định nghĩa API endpoints với Gradio
interface = gr.Interface(
    fn=lambda prompt, max_length: generate_text(prompt, max_length, [])[1],  # Chỉ lấy output, không dùng state cho API
    inputs=["text", "number"],
    outputs="text",
    title="Qwen2.5-0.5B API",
    api_name="/generate"
).queue()

health_interface = gr.Interface(
    fn=health_check,
    inputs=None,
    outputs="text",
    api_name="/health_check"
)

# Gắn các interface vào demo
demo = gr.TabbedInterface([interface, health_interface], ["Generate Text", "Health Check"])

# Chạy ứng dụng
demo.launch(server_name="0.0.0.0", server_port=7860)