#https://huggingface.co/spaces/MisterAI/Docker_AutoTrain_02
#app.py_01
#just POC


import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from datasets import load_dataset
import torch

# Interface Gradio avec gr.Blocks
with gr.Blocks() as demo:
    gr.Markdown("# Entraînement de modèle de transformateur")

    # Bloc pour sélectionner le modèle à entraîner
    model_name = gr.Textbox(label="Nom du modèle à entraîner")
    model_name.placeholder = "Nom du modèle à entraîner"
    model_name.value = "MisterAI/AIForce3"

    # Bloc pour sélectionner le jeu de données à utiliser
    dataset_path = gr.Textbox(label="Chemin du jeu de données")
    dataset_path.placeholder = "Chemin du jeu de données"
    dataset_path.value = "path/to/your/dataset"

    # Bloc pour entrer le nom du modèle une fois qu'il est entraîné
    model_name_checked = gr.Textbox(label="Nom du modèle entraîné")
    model_name_checked.placeholder = "Nom du modèle entraîné"
    model_name_checked.value = "Mistral-7B-Instruct-v0.3"

    # Bloc pour entrer l'emplacement où enregistrer le modèle entraîné
    model_path = gr.Textbox(label="Emplacement pour enregistrer le modèle entraîné")
    model_path.placeholder = "Emplacement pour enregistrer le modèle entraîné"
    model_path.value = "path/to/save/directory"

    # Bouton pour lancer l'entraînement
    submit = gr.Button("Lancer l'entraînement")

    # Bloc pour afficher les résultats de l'entraînement
    results = gr.Textbox(label="Résultats de l'entraînement")
    results.placeholder = "Résultats de l'entraînement"
    results.value = ""

    # Fonction pour entraîner le modèle
    def train_model(model_name, dataset_path, model_name_checked, model_path):
        # Charger le modèle pré-entraîné
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

        # Charger le jeu de données
        dataset = load_dataset(dataset_path)

        # Prétraiter les données pour l'entraînement
        def preprocess_function(examples):
            input_text = tokenizer.batch_encode([example["input_text"] for example in examples])
            target_text = tokenizer.batch_encode([example["target_text"] for example in examples])
            return {"input_ids": input_text, "attention_mask": input_text, "labels": target_text}
        dataset = dataset.map(preprocess_function, batched=True)

        # Configurer l'entraînement
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model.to(device)
        optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
        criterion = torch.nn.CrossEntropyLoss()

        # Entraîner le modèle
        num_epochs = 3
        for epoch in range(num_epochs):
            for batch in dataset:
                input_ids = batch["input_ids"].to(device)
                attention_mask = batch["attention_mask"].to(device)
                labels = batch["labels"].to(device)

                optimizer.zero_grad()
                outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
                loss = outputs.loss
                loss.backward()
                optimizer.step()

                print(f"Epoch {epoch+1}/{num_epochs}, Loss = {loss.item():.4f}")

        # Enregistrer le modèle entraîné
        model.save_pretrained(model_path)

        # Afficher les résultats de l'entraînement
        results.value = f"Modèle entraîné avec succès !\nNom du modèle : {model_name_checked}\nEmplacement : {model_path}"

    # Associer la fonction d'entraînement du modèle au bouton de soumission
    submit.click(fn=train_model, inputs=[model_name, dataset_path, model_name_checked, model_path], outputs=results)

# Lancer l'interface Gradio
demo.launch()