openfree committed (verified)
Commit ce4f027 · 1 Parent(s): 31fd291

Update app.py

Files changed (1): app.py (+45, -49)
app.py CHANGED
@@ -74,47 +74,62 @@ def format_conversation(history, system_prompt):
     prompt += "Assistant: "
     return prompt
 
-def generate_response(user_input, history, system_prompt, model_name, max_tokens, temperature, top_k, top_p, repeat_penalty):
+# Function to get just the model name from the dropdown selection
+def get_model_name(full_selection):
+    return full_selection.split(" - ")[0]
+
+# User input handling function
+def user_input(user_message, history):
+    return "", history + [(user_message, None)]
+
+@spaces.GPU(duration=60)
+def bot_response(history, system_prompt, model_selection, max_tokens, temperature, top_k, top_p, repetition_penalty):
     """
-    Generate a complete response (non-streaming).
+    Generate AI response to user input
     """
     cancel_event.clear()
-    full_history = history.copy()
 
-    # Format conversation for the model
-    conversation = format_conversation(full_history, system_prompt)
+    # Extract the latest user message
+    user_message = history[-1][0]
+    history_without_last = history[:-1]
+
+    # Get model name from selection
+    model_name = get_model_name(model_selection)
+
+    # Format the conversation
+    conversation = format_conversation(history_without_last, system_prompt)
+    conversation += "User: " + user_message + "\nAssistant: "
 
     try:
         pipe = load_pipeline(model_name)
-        output = pipe(
+        response = pipe(
             conversation,
             max_new_tokens=max_tokens,
             temperature=temperature,
             top_k=top_k,
             top_p=top_p,
-            repetition_penalty=repeat_penalty,
+            repetition_penalty=repetition_penalty,
             return_full_text=False
         )[0]["generated_text"]
 
-        # Return the updated history
-        history.append((user_input, output))
+        # Update the last message pair with the response
+        history[-1] = (user_message, response)
         return history
     except Exception as e:
-        history.append((user_input, f"Error: {e}"))
+        history[-1] = (user_message, f"Error: {e}")
         return history
     finally:
         gc.collect()
 
-def cancel_generation():
-    cancel_event.set()
-    return 'Generation cancelled.'
-
 def get_default_system_prompt():
     today = datetime.now().strftime('%Y-%m-%d')
     return f"""You are Qwen3, a helpful and friendly AI assistant created by Alibaba Cloud.
 Today is {today}.
 Be concise, accurate, and helpful in your responses."""
 
+def clear_chat():
+    return []
+
 # CSS for improved visual style
 css = """
 .gradio-container {
@@ -162,10 +177,6 @@ css = """
 }
 """
 
-# Function to get just the model name from the dropdown selection
-def get_model_name(full_selection):
-    return full_selection.split(" - ")[0]
-
 # ------------------------------
 # Gradio UI
 # ------------------------------
@@ -217,51 +228,36 @@ with gr.Blocks(title="Qwen3 Chat", css=css) as demo:
         </div>
     """)
 
-    # Define event handlers
-    def user_input(user_message, history):
-        return "", history + [(user_message, None)]
-
-    def bot_response(history, sys_prompt, model, max_tok, temp, k, p, rp):
-        user_message = history[-1][0]
-        bot_message = generate_response(
-            user_message,
-            history[:-1],
-            sys_prompt,
-            get_model_name(model),
-            max_tok,
-            temp,
-            k,
-            p,
-            rp
-        )[-1][1]
-
-        history[-1] = (user_message, bot_message)
-        return history
-
-    # Connect everything
+    # Connect UI elements to functions
     submit_btn.click(
         user_input,
-        [txt, chatbot],
-        [txt, chatbot],
+        inputs=[txt, chatbot],
+        outputs=[txt, chatbot],
         queue=False
     ).then(
         bot_response,
-        [chatbot, sys_prompt, model_dd, max_tok, temp, k, p, rp],
-        [chatbot]
+        inputs=[chatbot, sys_prompt, model_dd, max_tok, temp, k, p, rp],
+        outputs=chatbot,
+        api_name="generate"
     )
 
     txt.submit(
         user_input,
-        [txt, chatbot],
-        [txt, chatbot],
+        inputs=[txt, chatbot],
+        outputs=[txt, chatbot],
         queue=False
     ).then(
         bot_response,
-        [chatbot, sys_prompt, model_dd, max_tok, temp, k, p, rp],
-        [chatbot]
+        inputs=[chatbot, sys_prompt, model_dd, max_tok, temp, k, p, rp],
+        outputs=chatbot,
+        api_name="generate"
    )
 
-    clear_btn.click(lambda: None, None, chatbot, queue=False)
+    clear_btn.click(
+        clear_chat,
+        outputs=[chatbot],
+        queue=False
+    )
 
 if __name__ == "__main__":
     demo.launch()
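Note that both the old generate_response and the new bot_response call a load_pipeline(model_name) helper that is defined elsewhere in app.py and is not part of this diff. For context, a minimal sketch of what such a helper might look like, assuming the standard transformers text-generation pipeline and a simple in-process cache; only the function name and its call signature come from the diff, everything else here is an assumption rather than the Space's actual implementation:

import torch
from transformers import pipeline

_PIPELINES = {}  # cache so repeated requests reuse an already-loaded model

def load_pipeline(model_name):
    # Hypothetical sketch: return a cached text-generation pipeline for model_name.
    if model_name not in _PIPELINES:
        _PIPELINES[model_name] = pipeline(
            "text-generation",
            model=model_name,
            torch_dtype=torch.bfloat16,  # assumption; the Space may use a different dtype
            device_map="auto",
        )
    return _PIPELINES[model_name]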
 
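The new user_input, bot_response, and clear_chat handlers all operate on the tuple-based history that gr.Chatbot uses here: a list of (user_message, assistant_message) pairs, where a None assistant entry marks a turn that has not been answered yet. A standalone illustration of that data flow (the two definitions are copied from the new app.py; the reply string is made up):

# Definitions copied from the new app.py, for illustration only:
def user_input(user_message, history):
    return "", history + [(user_message, None)]

def clear_chat():
    return []

history = []

# user_input clears the textbox and appends the new turn with a None placeholder
cleared_text, history = user_input("Hello!", history)
print(cleared_text)  # ""
print(history)       # [("Hello!", None)]

# bot_response later replaces the placeholder with the generated reply
history[-1] = ("Hello!", "Hi! How can I help you today?")

# clear_chat resets the conversation
history = clear_chat()
print(history)       # []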
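Because the bot_response step is now registered with api_name="generate", the same generation path can also be driven programmatically once the Space is running. A hedged sketch using gradio_client: the Space id, the dropdown label format, and the sampling values below are placeholders, and the positional argument order simply mirrors the inputs list wired above.

from gradio_client import Client

# Placeholder Space id; substitute the Space this commit belongs to.
client = Client("your-username/your-space")

history = client.predict(
    [["Hello!", None]],                 # chatbot history with the new user turn appended
    "You are a helpful assistant.",     # system prompt
    "Qwen/Qwen3-4B - example entry",    # dropdown value; "<model id> - <description>" format assumed
    512,                                # max_tokens
    0.7,                                # temperature
    40,                                 # top_k
    0.9,                                # top_p
    1.1,                                # repetition_penalty
    api_name="/generate",
)
print(history[-1][1])                   # assistant reply from the updated history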