openfree committed on
Commit 31fd291 · verified · 1 Parent(s): c5529eb

Update app.py

Files changed (1)
  1. app.py +67 -88
app.py CHANGED
@@ -66,65 +66,42 @@ def format_conversation(history, system_prompt):
     """
     prompt = system_prompt.strip() + "\n"
 
-    for turn in history:
-        user_msg, assistant_msg = turn
+    for user_msg, assistant_msg in history:
         prompt += "User: " + user_msg.strip() + "\n"
         if assistant_msg:  # might be None or empty
             prompt += "Assistant: " + assistant_msg.strip() + "\n"
 
-    if not prompt.strip().endswith("Assistant:"):
-        prompt += "Assistant: "
+    prompt += "Assistant: "
     return prompt
 
-@spaces.GPU(duration=60)
-def chat_response(user_msg, history, system_prompt,
-                  model_name, max_tokens, temperature,
-                  top_k, top_p, repeat_penalty):
+def generate_response(user_input, history, system_prompt, model_name, max_tokens, temperature, top_k, top_p, repeat_penalty):
     """
-    Generates streaming chat responses using the standard (user, assistant) format.
+    Generate a complete response (non-streaming).
     """
     cancel_event.clear()
+    full_history = history.copy() + [(user_input, None)]  # include the new user message in the prompt
 
-    # Add the user message to history
-    history = history + [[user_msg, None]]
-
-    # Format the conversation for the model
-    prompt = format_conversation(history, system_prompt)
+    # Format conversation for the model
+    conversation = format_conversation(full_history, system_prompt)
 
     try:
         pipe = load_pipeline(model_name)
-        streamer = TextIteratorStreamer(pipe.tokenizer,
-                                        skip_prompt=True,
-                                        skip_special_tokens=True)
-
-        gen_thread = threading.Thread(
-            target=pipe,
-            args=(prompt,),
-            kwargs={
-                'max_new_tokens': max_tokens,
-                'temperature': temperature,
-                'top_k': top_k,
-                'top_p': top_p,
-                'repetition_penalty': repeat_penalty,
-                'streamer': streamer,
-                'return_full_text': False
-            }
-        )
-        gen_thread.start()
-
-        # Stream the response
-        assistant_text = ''
-        for chunk in streamer:
-            if cancel_event.is_set():
-                break
-            assistant_text += chunk
-            history[-1][1] = assistant_text
-            yield history
-
-        gen_thread.join()
+        output = pipe(
+            conversation,
+            max_new_tokens=max_tokens,
+            temperature=temperature,
+            top_k=top_k,
+            top_p=top_p,
+            repetition_penalty=repeat_penalty,
+            return_full_text=False
+        )[0]["generated_text"]
+
+        # Return the updated history
+        history.append((user_input, output))
+        return history
     except Exception as e:
-        history[-1][1] = f"Error: {e}"
-        yield history
+        history.append((user_input, f"Error: {e}"))
+        return history
     finally:
         gc.collect()
 
@@ -189,14 +166,6 @@ css = """
 def get_model_name(full_selection):
     return full_selection.split(" - ")[0]
 
-# Function to clear chat
-def clear_chat():
-    return [], ""
-
-# Function to handle message submission and clear input
-def submit_message(user_input, history, system_prompt, model_name, max_tokens, temp, k, p, rp):
-    return "", history + [[user_input, None]]
-
 # ------------------------------
 # Gradio UI
 # ------------------------------
@@ -208,8 +177,6 @@ with gr.Blocks(title="Qwen3 Chat", css=css) as demo:
     </div>
     """)
 
-    chatbot = gr.Chatbot(height=500)
-
     with gr.Row():
         with gr.Column(scale=3):
             with gr.Group(elem_classes="qwen-container"):
@@ -232,18 +199,17 @@ with gr.Blocks(title="Qwen3 Chat", css=css) as demo:
                 k = gr.Slider(1, 100, value=40, step=1, label="Top-K")
                 rp = gr.Slider(1.0, 2.0, value=1.1, step=0.1, label="Repetition Penalty")
 
-                with gr.Row():
-                    clr = gr.Button("Clear Chat", elem_classes="button-secondary")
-                    cnl = gr.Button("Cancel Generation", elem_classes="button-secondary")
+                clear_btn = gr.Button("Clear Chat", elem_classes="button-secondary")
 
         with gr.Column(scale=7):
+            chatbot = gr.Chatbot()
            with gr.Row():
-                msg = gr.Textbox(
-                    placeholder="Type your message and press Enter...",
-                    lines=2,
-                    show_label=False
+                txt = gr.Textbox(
+                    show_label=False,
+                    placeholder="Type your message here...",
+                    lines=2
                 )
-                send_btn = gr.Button("Send", variant="primary", elem_classes="button-primary")
+                submit_btn = gr.Button("Send", variant="primary", elem_classes="button-primary")
 
     gr.HTML("""
     <div class="footer">
@@ -251,38 +217,51 @@ with gr.Blocks(title="Qwen3 Chat", css=css) as demo:
     </div>
     """)
 
-    # Event handlers
-    clr.click(fn=clear_chat, outputs=[chatbot, msg])
-    cnl.click(fn=cancel_generation)
+    # Define event handlers
+    def user_input(user_message, history):
+        return "", history + [(user_message, None)]
 
-    # Handle sending messages and generating responses
-    msg.submit(
-        fn=submit_message,
-        inputs=[msg, chatbot, sys_prompt, model_dd, max_tok, temp, k, p, rp],
-        outputs=[msg, chatbot]
+    def bot_response(history, sys_prompt, model, max_tok, temp, k, p, rp):
+        user_message = history[-1][0]
+        bot_message = generate_response(
+            user_message,
+            history[:-1],
+            sys_prompt,
+            get_model_name(model),
+            max_tok,
+            temp,
+            k,
+            p,
+            rp
+        )[-1][1]
+
+        history[-1] = (user_message, bot_message)
+        return history
+
+    # Connect everything
+    submit_btn.click(
+        user_input,
+        [txt, chatbot],
+        [txt, chatbot],
+        queue=False
     ).then(
-        fn=lambda history, prompt, model, tok, temp, k, p, rp:
-            chat_response(
-                history[-1][0], history[:-1], prompt,
-                get_model_name(model), tok, temp, k, p, rp
-            ),
-        inputs=[chatbot, sys_prompt, model_dd, max_tok, temp, k, p, rp],
-        outputs=chatbot
+        bot_response,
+        [chatbot, sys_prompt, model_dd, max_tok, temp, k, p, rp],
+        [chatbot]
     )
 
-    send_btn.click(
-        fn=submit_message,
-        inputs=[msg, chatbot, sys_prompt, model_dd, max_tok, temp, k, p, rp],
-        outputs=[msg, chatbot]
+    txt.submit(
+        user_input,
+        [txt, chatbot],
+        [txt, chatbot],
+        queue=False
     ).then(
-        fn=lambda history, prompt, model, tok, temp, k, p, rp:
-            chat_response(
-                history[-1][0], history[:-1], prompt,
-                get_model_name(model), tok, temp, k, p, rp
-            ),
-        inputs=[chatbot, sys_prompt, model_dd, max_tok, temp, k, p, rp],
-        outputs=chatbot
+        bot_response,
+        [chatbot, sys_prompt, model_dd, max_tok, temp, k, p, rp],
+        [chatbot]
     )
+
+    clear_btn.click(lambda: None, None, chatbot, queue=False)
 
 if __name__ == "__main__":
     demo.launch()
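For reference, a minimal sketch (not part of the commit; the history and system prompt below are made-up examples) of the prompt layout the updated format_conversation builds:

# Illustrative only: mirrors the updated format_conversation from the diff above.
def format_conversation(history, system_prompt):
    prompt = system_prompt.strip() + "\n"
    for user_msg, assistant_msg in history:
        prompt += "User: " + user_msg.strip() + "\n"
        if assistant_msg:  # skip turns that have no assistant reply yet
            prompt += "Assistant: " + assistant_msg.strip() + "\n"
    prompt += "Assistant: "
    return prompt

# Hypothetical two-turn history, with the latest turn still unanswered
history = [("Hello", "Hi! How can I help?"), ("What can you do?", None)]
print(format_conversation(history, "You are a helpful assistant."))
# You are a helpful assistant.
# User: Hello
# Assistant: Hi! How can I help?
# User: What can you do?
# Assistant:

Because the trailing "Assistant: " cue is now appended unconditionally, the old endswith("Assistant:") check could be dropped.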
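Likewise, a minimal sketch of the single, non-streaming text-generation pipeline call that generate_response now makes in place of the threaded TextIteratorStreamer loop. The model id and prompt are placeholders, and do_sample=True is an assumption added here so the sampling parameters take effect:

# Sketch only: stands in for load_pipeline() plus the new single pipe(...) call.
from transformers import pipeline

pipe = pipeline("text-generation", model="Qwen/Qwen2.5-0.5B-Instruct")  # placeholder model id

prompt = "You are a helpful assistant.\nUser: Hello\nAssistant: "
output = pipe(
    prompt,
    max_new_tokens=64,
    do_sample=True,          # needed for temperature/top_k/top_p to apply
    temperature=0.7,
    top_k=40,
    top_p=0.9,
    repetition_penalty=1.1,
    return_full_text=False,  # keep only the newly generated continuation
)[0]["generated_text"]
print(output)

With return_full_text=False the pipeline returns only the generated continuation, which the app then appends to the chat history as the assistant turn.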