MisterAI commited on
Commit
93a327d
·
verified ·
1 Parent(s): e475772

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +91 -0
app.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #https://huggingface.co/spaces/MisterAI/Docker_AutoTrain_02
2
+ #app.py_01
3
+ #just POC
4
+
5
+
6
+ import gradio as gr
7
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
8
+ from datasets import load_dataset
9
+ import torch
10
+
11
+ # Interface Gradio avec gr.Blocks
12
+ with gr.Blocks() as demo:
13
+ gr.Markdown("# Entraînement de modèle de transformateur")
14
+
15
+ # Bloc pour sélectionner le modèle à entraîner
16
+ model_name = gr.Textbox(label="Nom du modèle à entraîner")
17
+ model_name.placeholder = "Nom du modèle à entraîner"
18
+ model_name.value = "MisterAI/AIForce3"
19
+
20
+ # Bloc pour sélectionner le jeu de données à utiliser
21
+ dataset_path = gr.Textbox(label="Chemin du jeu de données")
22
+ dataset_path.placeholder = "Chemin du jeu de données"
23
+ dataset_path.value = "path/to/your/dataset"
24
+
25
+ # Bloc pour entrer le nom du modèle une fois qu'il est entraîné
26
+ model_name_checked = gr.Textbox(label="Nom du modèle entraîné")
27
+ model_name_checked.placeholder = "Nom du modèle entraîné"
28
+ model_name_checked.value = "Mistral-7B-Instruct-v0.3"
29
+
30
+ # Bloc pour entrer l'emplacement où enregistrer le modèle entraîné
31
+ model_path = gr.Textbox(label="Emplacement pour enregistrer le modèle entraîné")
32
+ model_path.placeholder = "Emplacement pour enregistrer le modèle entraîné"
33
+ model_path.value = "path/to/save/directory"
34
+
35
+ # Bouton pour lancer l'entraînement
36
+ submit = gr.Button("Lancer l'entraînement")
37
+
38
+ # Bloc pour afficher les résultats de l'entraînement
39
+ results = gr.Textbox(label="Résultats de l'entraînement")
40
+ results.placeholder = "Résultats de l'entraînement"
41
+ results.value = ""
42
+
43
+ # Fonction pour entraîner le modèle
44
+ def train_model(model_name, dataset_path, model_name_checked, model_path):
45
+ # Charger le modèle pré-entraîné
46
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
47
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
48
+
49
+ # Charger le jeu de données
50
+ dataset = load_dataset(dataset_path)
51
+
52
+ # Prétraiter les données pour l'entraînement
53
+ def preprocess_function(examples):
54
+ input_text = tokenizer.batch_encode([example["input_text"] for example in examples])
55
+ target_text = tokenizer.batch_encode([example["target_text"] for example in examples])
56
+ return {"input_ids": input_text, "attention_mask": input_text, "labels": target_text}
57
+ dataset = dataset.map(preprocess_function, batched=True)
58
+
59
+ # Configurer l'entraînement
60
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
61
+ model.to(device)
62
+ optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
63
+ criterion = torch.nn.CrossEntropyLoss()
64
+
65
+ # Entraîner le modèle
66
+ num_epochs = 3
67
+ for epoch in range(num_epochs):
68
+ for batch in dataset:
69
+ input_ids = batch["input_ids"].to(device)
70
+ attention_mask = batch["attention_mask"].to(device)
71
+ labels = batch["labels"].to(device)
72
+
73
+ optimizer.zero_grad()
74
+ outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
75
+ loss = outputs.loss
76
+ loss.backward()
77
+ optimizer.step()
78
+
79
+ print(f"Epoch {epoch+1}/{num_epochs}, Loss = {loss.item():.4f}")
80
+
81
+ # Enregistrer le modèle entraîné
82
+ model.save_pretrained(model_path)
83
+
84
+ # Afficher les résultats de l'entraînement
85
+ results.value = f"Modèle entraîné avec succès !\nNom du modèle : {model_name_checked}\nEmplacement : {model_path}"
86
+
87
+ # Associer la fonction d'entraînement du modèle au bouton de soumission
88
+ submit.click(fn=train_model, inputs=[model_name, dataset_path, model_name_checked, model_path], outputs=results)
89
+
90
+ # Lancer l'interface Gradio
91
+ demo.launch()