openfree committed
Commit c5529eb · verified · 1 Parent(s): 9290b81

Update app.py

Files changed (1): app.py +56 -39
app.py CHANGED
@@ -65,38 +65,38 @@ def format_conversation(history, system_prompt):
     Flatten chat history and system prompt into a single string.
     """
     prompt = system_prompt.strip() + "\n"
-    for msg in history:
-        if msg["role"] == "user":
-            prompt += "User: " + msg["content"].strip() + "\n"
-        elif msg["role"] == "assistant":
-            prompt += "Assistant: " + msg["content"].strip() + "\n"
-        else:
-            prompt += msg["content"].strip() + "\n"
+
+    for turn in history:
+        user_msg, assistant_msg = turn
+        prompt += "User: " + user_msg.strip() + "\n"
+        if assistant_msg:  # might be None or empty
+            prompt += "Assistant: " + assistant_msg.strip() + "\n"
+
     if not prompt.strip().endswith("Assistant:"):
         prompt += "Assistant: "
     return prompt
 
 @spaces.GPU(duration=60)
-def chat_response(user_msg, chat_history, system_prompt,
-                  model_name, max_tokens, temperature,
-                  top_k, top_p, repeat_penalty):
+def chat_response(user_msg, history, system_prompt,
+                  model_name, max_tokens, temperature,
+                  top_k, top_p, repeat_penalty):
     """
-    Generates streaming chat responses.
+    Generates streaming chat responses using the standard (user, assistant) format.
     """
     cancel_event.clear()
-    history = list(chat_history or [])
-    history.append({"role": "user", "content": user_msg})
-
-    # Prepare assistant placeholder
-    history.append({"role": "assistant", "content": ""})
-
+
+    # Add the user message to history
+    history = history + [[user_msg, None]]
+
+    # Format the conversation for the model
+    prompt = format_conversation(history, system_prompt)
+
     try:
-        prompt = format_conversation(history, system_prompt)
-
         pipe = load_pipeline(model_name)
         streamer = TextIteratorStreamer(pipe.tokenizer,
                                         skip_prompt=True,
                                         skip_special_tokens=True)
+
        gen_thread = threading.Thread(
            target=pipe,
            args=(prompt,),
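The first hunk switches format_conversation from role-dict messages to Gradio's tuple-style history, i.e. a list of [user, assistant] pairs in which the assistant slot can still be None while a reply is pending. A minimal sketch (not part of the commit; the sample messages are invented) of how the new loop flattens such a history:

# Sketch of the new tuple-style flattening; the sample history is invented.
def format_conversation(history, system_prompt):
    prompt = system_prompt.strip() + "\n"
    for user_msg, assistant_msg in history:
        prompt += "User: " + user_msg.strip() + "\n"
        if assistant_msg:  # None or empty while the reply is still pending
            prompt += "Assistant: " + assistant_msg.strip() + "\n"
    if not prompt.strip().endswith("Assistant:"):
        prompt += "Assistant: "
    return prompt

history = [["Hello!", "Hi there, how can I help?"],
           ["Summarize Qwen3 in one line.", None]]
print(format_conversation(history, "You are a helpful assistant."))
# The printed prompt ends with "Assistant: ", ready for the model to continue.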
@@ -112,16 +112,18 @@ def chat_response(user_msg, chat_history, system_prompt,
         )
         gen_thread.start()
 
+        # Stream the response
         assistant_text = ''
         for chunk in streamer:
             if cancel_event.is_set():
                 break
             assistant_text += chunk
-            history[-1]["content"] = assistant_text
+            history[-1][1] = assistant_text
             yield history
+
         gen_thread.join()
     except Exception as e:
-        history[-1]["content"] = f"Error: {e}"
+        history[-1][1] = f"Error: {e}"
         yield history
     finally:
         gc.collect()
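The second hunk only changes where the streamed text lands (history[-1][1] instead of a message dict), but the surrounding pattern is the usual transformers one: generation runs in a background thread and TextIteratorStreamer yields decoded chunks that the generator forwards as a growing history. A standalone sketch of that pattern, using model.generate directly rather than the app's load_pipeline, with an arbitrary small instruct model as a placeholder:

# Standalone sketch of thread + TextIteratorStreamer streaming; the model
# choice is an arbitrary placeholder, not something the commit specifies.
import threading
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

model_id = "Qwen/Qwen2.5-0.5B-Instruct"
tok = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

prompt = "User: Say hello in five words.\nAssistant: "
inputs = tok(prompt, return_tensors="pt")
streamer = TextIteratorStreamer(tok, skip_prompt=True, skip_special_tokens=True)

thread = threading.Thread(
    target=model.generate,
    kwargs=dict(**inputs, max_new_tokens=32, streamer=streamer),
)
thread.start()

partial = ""
for chunk in streamer:   # decoded text arrives piece by piece
    partial += chunk
    print(partial)       # the Space instead writes this into history[-1][1] and yields
thread.join()

The app additionally checks cancel_event between chunks so the loop can stop early when the user clicks Cancel Generation.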
@@ -187,12 +189,13 @@ css = """
 def get_model_name(full_selection):
     return full_selection.split(" - ")[0]
 
-# Function to handle message submission
-def submit_message(msg, history, prompt, model, tok, temp, k, p, rp):
-    return chat_response(
-        msg, history, prompt,
-        get_model_name(model), tok, temp, k, p, rp
-    ), ""
+# Function to clear chat
+def clear_chat():
+    return [], ""
+
+# Function to handle message submission and clear input
+def submit_message(user_input, history, system_prompt, model_name, max_tokens, temp, k, p, rp):
+    return "", history + [[user_input, None]]
 
 # ------------------------------
 # Gradio UI
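With this hunk, submit_message no longer calls the model at all: it just clears the textbox and appends a [user_input, None] placeholder pair, and the chained second step (see the last hunk) fills that None slot while streaming. A tiny sketch of the hand-off contract; the extra generation parameters are accepted only so the event's input list can stay unchanged:

# Sketch of the hand-off: the first return value clears the Textbox,
# the second pre-appends the placeholder pair that chat_response will fill.
def submit_message(user_input, history, *gen_params):
    return "", history + [[user_input, None]]

cleared_box, new_history = submit_message("What is Qwen3?", [])
assert cleared_box == ""
assert new_history == [["What is Qwen3?", None]]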
@@ -205,6 +208,8 @@ with gr.Blocks(title="Qwen3 Chat", css=css) as demo:
     </div>
     """)
 
+    chatbot = gr.Chatbot(height=500)
+
     with gr.Row():
         with gr.Column(scale=3):
             with gr.Group(elem_classes="qwen-container"):
@@ -232,9 +237,8 @@ with gr.Blocks(title="Qwen3 Chat", css=css) as demo:
                 cnl = gr.Button("Cancel Generation", elem_classes="button-secondary")
 
         with gr.Column(scale=7):
-            chat = gr.Chatbot(type="messages", height=500)
             with gr.Row():
-                txt = gr.Textbox(
+                msg = gr.Textbox(
                     placeholder="Type your message and press Enter...",
                     lines=2,
                     show_label=False
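These two layout hunks also swap the component itself: gr.Chatbot(type="messages", height=500) is replaced by a plain gr.Chatbot(height=500), which in the Gradio version this Space appears to target defaults to the pair-list history format, matching the new format_conversation and chat_response. A sketch of the two shapes involved (values invented for illustration):

# Old code path, type="messages": history is a list of role/content dicts.
messages_history = [
    {"role": "user", "content": "Hello!"},
    {"role": "assistant", "content": "Hi, how can I help?"},
]

# New code path, default tuple format: a list of [user, assistant] pairs,
# where the assistant slot may be None while a reply is being streamed.
pairs_history = [
    ["Hello!", "Hi, how can I help?"],
    ["Tell me about Qwen3", None],
]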
@@ -248,23 +252,36 @@ with gr.Blocks(title="Qwen3 Chat", css=css) as demo:
     """)
 
     # Event handlers
-    clr.click(fn=lambda: ([], ""), outputs=[chat, txt])
+    clr.click(fn=clear_chat, outputs=[chatbot, msg])
     cnl.click(fn=cancel_generation)
 
-    # Handle submission from Enter key
-    txt.submit(
+    # Handle sending messages and generating responses
+    msg.submit(
         fn=submit_message,
-        inputs=[txt, chat, sys_prompt, model_dd, max_tok, temp, k, p, rp],
-        outputs=[chat, txt],
-        show_progress=True
+        inputs=[msg, chatbot, sys_prompt, model_dd, max_tok, temp, k, p, rp],
+        outputs=[msg, chatbot]
+    ).then(
+        fn=lambda history, prompt, model, tok, temp, k, p, rp:
+            chat_response(
+                history[-1][0], history[:-1], prompt,
+                get_model_name(model), tok, temp, k, p, rp
+            ),
+        inputs=[chatbot, sys_prompt, model_dd, max_tok, temp, k, p, rp],
+        outputs=chatbot
     )
 
-    # Handle submission from Send button
     send_btn.click(
         fn=submit_message,
-        inputs=[txt, chat, sys_prompt, model_dd, max_tok, temp, k, p, rp],
-        outputs=[chat, txt],
-        show_progress=True
+        inputs=[msg, chatbot, sys_prompt, model_dd, max_tok, temp, k, p, rp],
+        outputs=[msg, chatbot]
+    ).then(
+        fn=lambda history, prompt, model, tok, temp, k, p, rp:
+            chat_response(
+                history[-1][0], history[:-1], prompt,
+                get_model_name(model), tok, temp, k, p, rp
+            ),
+        inputs=[chatbot, sys_prompt, model_dd, max_tok, temp, k, p, rp],
+        outputs=chatbot
     )
 
 if __name__ == "__main__":
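Taken together, the event-handler hunk wires each submission as a two-step chain: submit_message updates [msg, chatbot] immediately, then a .then() step calls the chat_response generator so the reply streams into the last history entry. A minimal runnable sketch of the same chain (helper names and the fake word-by-word reply are invented; it assumes a Gradio version where the Chatbot default is the pair-list format):

# Minimal sketch of the .submit(...).then(...) streaming chain; the reply text
# below is a stand-in for the real model call.
import time
import gradio as gr

def add_user_turn(user_input, history):
    # Step 1: clear the textbox and append a [user, None] placeholder.
    return "", history + [[user_input, None]]

def stream_reply(history):
    # Step 2: a generator; every yield re-renders the Chatbot with more text.
    reply = "This stands in for the streamed model output."
    history[-1][1] = ""
    for word in reply.split():
        time.sleep(0.05)
        history[-1][1] += word + " "
        yield history

with gr.Blocks() as demo:
    chatbot = gr.Chatbot(height=300)
    box = gr.Textbox(placeholder="Type your message and press Enter...", show_label=False)
    box.submit(add_user_turn, inputs=[box, chatbot], outputs=[box, chatbot]) \
       .then(stream_reply, inputs=chatbot, outputs=chatbot)

if __name__ == "__main__":
    demo.launch()

Splitting the event this way means the user's message appears and the textbox clears right away, before the comparatively slow GPU generation step starts.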