import gradio as gr
import numpy as np
from transformers import AutoTokenizer, AutoModel
from sklearn.metrics.pairwise import cosine_similarity

MODELS = {
    "rubert-tiny2": "cointegrated/rubert-tiny2",
    "sbert": "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
    "LaBSE": "sentence-transformers/LaBSE",
    "ruRoberta": "sberbank-ai/ruRoberta-large"
}

PROMPT_TEMPLATES = {
    "basic": "Товар: {item}. Категория:",
    "examples": "Примеры:\n- Молоток → Инструменты\n- Морковь → Овощи\nТовар: {item} → ",
    "strict": "Выбери категорию из [{categories}]. Товар: {item}. Категория:"
}

def get_embeddings(model, tokenizer, text):
    inputs = tokenizer(text, 
                      padding=True, 
                      truncation=True, 
                      return_tensors="pt",
                      max_length=512)
    outputs = model(**inputs)
    return outputs.last_hidden_state[:, 0].detach().numpy()

def classify(model_name: str, prompt_type: str, item: str, categories: str) -> str:
    tokenizer = AutoTokenizer.from_pretrained(MODELS[model_name])
    model = AutoModel.from_pretrained(MODELS[model_name])
    
    # Формируем промпт
    prompt = PROMPT_TEMPLATES[prompt_type].format(
        item=item,
        categories=", ".join([c.strip() for c in categories.split(",")])
    )
    
    # Эмбеддинги
    item_embedding = get_embeddings(model, tokenizer, prompt)
    category_embeddings = [
        get_embeddings(model, tokenizer, c.strip()) 
        for c in categories.split(",")
    ]
    
    # Сравнение
    similarities = cosine_similarity(item_embedding, np.vstack(category_embeddings))[0]
    best_idx = np.argmax(similarities)
    
    return f"{categories.split(',')[best_idx].strip()} ({similarities[best_idx]:.2f})"

gr.Interface(
    fn=classify,
    inputs=[
        gr.Dropdown(list(MODELS.keys()), label="Модель"),
        gr.Dropdown(list(PROMPT_TEMPLATES.keys()), label="Шаблон промпта"),
        gr.Textbox(label="Товар"),
        gr.Textbox(label="Категории", value="Инструменты, Овощи, Техника")
    ],
    outputs=gr.Textbox()
).launch()