Woziii committed on
Commit
12f46b7
·
verified ·
1 Parent(s): 9255c5c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -10
app.py CHANGED
@@ -48,17 +48,11 @@ def load_model(model_name, progress=gr.Progress()):
48
  device_map="auto",
49
  load_in_8bit=True
50
  )
51
- elif "llama" in model_name.lower() or "mistral" in model_name.lower():
52
- model = AutoModelForCausalLM.from_pretrained(
53
- model_name,
54
- torch_dtype=torch.float16,
55
- device_map="cpu"
56
- )
57
  else:
58
  model = AutoModelForCausalLM.from_pretrained(
59
  model_name,
60
  torch_dtype=torch.float16,
61
- device_map="cpu"
62
  )
63
 
64
  if tokenizer.pad_token is None:
@@ -87,7 +81,7 @@ def analyze_next_token(input_text, temperature, top_p, top_k):
87
  if model is None or tokenizer is None:
88
  return "Veuillez d'abord charger un modèle.", None, None
89
 
90
- inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True, max_length=512)
91
 
92
  try:
93
  with torch.no_grad():
@@ -106,7 +100,7 @@ def analyze_next_token(input_text, temperature, top_p, top_k):
106
  prob_text += f"{word}: {prob:.2%}\n"
107
 
108
  prob_plot = plot_probabilities(prob_data)
109
- attention_plot = plot_attention(inputs["input_ids"][0], last_token_logits)
110
 
111
  return prob_text, attention_plot, prob_plot
112
  except Exception as e:
@@ -118,7 +112,7 @@ def generate_text(input_text, temperature, top_p, top_k):
118
  if model is None or tokenizer is None:
119
  return "Veuillez d'abord charger un modèle."
120
 
121
- inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True, max_length=512)
122
 
123
  try:
124
  with torch.no_grad():
 
48
  device_map="auto",
49
  load_in_8bit=True
50
  )
 
 
 
 
 
 
51
  else:
52
  model = AutoModelForCausalLM.from_pretrained(
53
  model_name,
54
  torch_dtype=torch.float16,
55
+ device_map="auto"
56
  )
57
 
58
  if tokenizer.pad_token is None:
 
81
  if model is None or tokenizer is None:
82
  return "Veuillez d'abord charger un modèle.", None, None
83
 
84
+ inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True, max_length=512).to(model.device)
85
 
86
  try:
87
  with torch.no_grad():
 
100
  prob_text += f"{word}: {prob:.2%}\n"
101
 
102
  prob_plot = plot_probabilities(prob_data)
103
+ attention_plot = plot_attention(inputs["input_ids"][0].cpu(), last_token_logits.cpu())
104
 
105
  return prob_text, attention_plot, prob_plot
106
  except Exception as e:
 
112
  if model is None or tokenizer is None:
113
  return "Veuillez d'abord charger un modèle."
114
 
115
+ inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True, max_length=512).to(model.device)
116
 
117
  try:
118
  with torch.no_grad():