Spaces:
Paused
Paused
Update app.py
Browse files
app.py
CHANGED
@@ -48,17 +48,11 @@ def load_model(model_name, progress=gr.Progress()):
|
|
48 |
device_map="auto",
|
49 |
load_in_8bit=True
|
50 |
)
|
51 |
-
elif "llama" in model_name.lower() or "mistral" in model_name.lower():
|
52 |
-
model = AutoModelForCausalLM.from_pretrained(
|
53 |
-
model_name,
|
54 |
-
torch_dtype=torch.float16,
|
55 |
-
device_map="cpu"
|
56 |
-
)
|
57 |
else:
|
58 |
model = AutoModelForCausalLM.from_pretrained(
|
59 |
model_name,
|
60 |
torch_dtype=torch.float16,
|
61 |
-
device_map="
|
62 |
)
|
63 |
|
64 |
if tokenizer.pad_token is None:
|
@@ -87,7 +81,7 @@ def analyze_next_token(input_text, temperature, top_p, top_k):
|
|
87 |
if model is None or tokenizer is None:
|
88 |
return "Veuillez d'abord charger un modèle.", None, None
|
89 |
|
90 |
-
inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True, max_length=512)
|
91 |
|
92 |
try:
|
93 |
with torch.no_grad():
|
@@ -106,7 +100,7 @@ def analyze_next_token(input_text, temperature, top_p, top_k):
|
|
106 |
prob_text += f"{word}: {prob:.2%}\n"
|
107 |
|
108 |
prob_plot = plot_probabilities(prob_data)
|
109 |
-
attention_plot = plot_attention(inputs["input_ids"][0], last_token_logits)
|
110 |
|
111 |
return prob_text, attention_plot, prob_plot
|
112 |
except Exception as e:
|
@@ -118,7 +112,7 @@ def generate_text(input_text, temperature, top_p, top_k):
|
|
118 |
if model is None or tokenizer is None:
|
119 |
return "Veuillez d'abord charger un modèle."
|
120 |
|
121 |
-
inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True, max_length=512)
|
122 |
|
123 |
try:
|
124 |
with torch.no_grad():
|
|
|
48 |
device_map="auto",
|
49 |
load_in_8bit=True
|
50 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
51 |
else:
|
52 |
model = AutoModelForCausalLM.from_pretrained(
|
53 |
model_name,
|
54 |
torch_dtype=torch.float16,
|
55 |
+
device_map="auto"
|
56 |
)
|
57 |
|
58 |
if tokenizer.pad_token is None:
|
|
|
81 |
if model is None or tokenizer is None:
|
82 |
return "Veuillez d'abord charger un modèle.", None, None
|
83 |
|
84 |
+
inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True, max_length=512).to(model.device)
|
85 |
|
86 |
try:
|
87 |
with torch.no_grad():
|
|
|
100 |
prob_text += f"{word}: {prob:.2%}\n"
|
101 |
|
102 |
prob_plot = plot_probabilities(prob_data)
|
103 |
+
attention_plot = plot_attention(inputs["input_ids"][0].cpu(), last_token_logits.cpu())
|
104 |
|
105 |
return prob_text, attention_plot, prob_plot
|
106 |
except Exception as e:
|
|
|
112 |
if model is None or tokenizer is None:
|
113 |
return "Veuillez d'abord charger un modèle."
|
114 |
|
115 |
+
inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True, max_length=512).to(model.device)
|
116 |
|
117 |
try:
|
118 |
with torch.no_grad():
|