"""Gradio chat demo for Alibaba Cloud's Qwen3 models with streaming responses,
built to run on Hugging Face Spaces with ZeroGPU."""

import os
import time
import gc
import threading
from datetime import datetime

import gradio as gr
import torch
from transformers import pipeline, TextIteratorStreamer
import spaces

# Event used to signal cancellation of an in-progress generation from the UI.
cancel_event = threading.Event()

MODELS = {
    "Qwen3-8B": {"repo_id": "Qwen/Qwen3-8B", "description": "Qwen3-8B - Largest model with highest capabilities"},
    "Qwen3-4B": {"repo_id": "Qwen/Qwen3-4B", "description": "Qwen3-4B - Good balance of performance and efficiency"},
    "Qwen3-1.7B": {"repo_id": "Qwen/Qwen3-1.7B", "description": "Qwen3-1.7B - Smaller model for faster responses"},
    "Qwen3-0.6B": {"repo_id": "Qwen/Qwen3-0.6B", "description": "Qwen3-0.6B - Ultra-lightweight model"}
}

# Cache of loaded text-generation pipelines, keyed by model name.
PIPELINES = {}

def load_pipeline(model_name):
    """
    Load and cache a transformers pipeline for text generation.
    Tries bfloat16, falls back to float16 or float32 if unsupported.
    """
    global PIPELINES
    if model_name in PIPELINES:
        return PIPELINES[model_name]
    repo = MODELS[model_name]["repo_id"]
    for dtype in (torch.bfloat16, torch.float16, torch.float32):
        try:
            pipe = pipeline(
                task="text-generation",
                model=repo,
                tokenizer=repo,
                trust_remote_code=True,
                torch_dtype=dtype,
                device_map="auto"
            )
            PIPELINES[model_name] = pipe
            return pipe
        except Exception:
            continue

    pipe = pipeline(
        task="text-generation",
        model=repo,
        tokenizer=repo,
        trust_remote_code=True,
        device_map="auto"
    )
    PIPELINES[model_name] = pipe
    return pipe


def format_conversation(history, system_prompt):
    """
    Flatten chat history and system prompt into a single string.
    """
    prompt = system_prompt.strip() + "\n"

    for turn in history:
        user_msg, assistant_msg = turn
        prompt += "User: " + user_msg.strip() + "\n"
        if assistant_msg:
            prompt += "Assistant: " + assistant_msg.strip() + "\n"

    if not prompt.strip().endswith("Assistant:"):
        prompt += "Assistant: "
    return prompt


@spaces.GPU(duration=60)  # request a ZeroGPU slot for up to 60 seconds per call
def chat_response(user_msg, history, system_prompt,
                  model_name, max_tokens, temperature,
                  top_k, top_p, repeat_penalty):
    """
    Generate a streaming chat response using the standard (user, assistant) history format.
    """
    cancel_event.clear()

    history = history + [[user_msg, None]]

    prompt = format_conversation(history, system_prompt)

    try:
        pipe = load_pipeline(model_name)
        streamer = TextIteratorStreamer(pipe.tokenizer,
                                        skip_prompt=True,
                                        skip_special_tokens=True)

        # Run generation in a background thread so tokens can be streamed as they arrive.
        gen_thread = threading.Thread(
            target=pipe,
            args=(prompt,),
            kwargs={
                'max_new_tokens': max_tokens,
                'do_sample': True,  # make temperature/top-k/top-p take effect
                'temperature': temperature,
                'top_k': top_k,
                'top_p': top_p,
                'repetition_penalty': repeat_penalty,
                'streamer': streamer,
                'return_full_text': False
            }
        )
        gen_thread.start()

        assistant_text = ''
        for chunk in streamer:
            if cancel_event.is_set():
                break
            assistant_text += chunk
            history[-1][1] = assistant_text
            yield history

        gen_thread.join()
    except Exception as e:
        history[-1][1] = f"Error: {e}"
        yield history
    finally:
        gc.collect()


def cancel_generation():
    cancel_event.set()
    return 'Generation cancelled.'


def get_default_system_prompt():
    today = datetime.now().strftime('%Y-%m-%d')
    return f"""You are Qwen3, a helpful and friendly AI assistant created by Alibaba Cloud.
Today is {today}.
Be concise, accurate, and helpful in your responses."""


css = """
.gradio-container {
    background-color: #f5f7fb !important;
}
.qwen-header {
    background: linear-gradient(90deg, #0099FF, #0066CC);
    padding: 20px;
    border-radius: 10px;
    margin-bottom: 20px;
    text-align: center;
    color: white;
    box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
}
.qwen-container {
    border-radius: 10px;
    box-shadow: 0 2px 4px rgba(0, 0, 0, 0.05);
    background: white;
    padding: 20px;
    margin-bottom: 20px;
}
.controls-container {
    background: #f0f4fa;
    border-radius: 10px;
    padding: 15px;
    margin-bottom: 15px;
}
.model-select {
    border: 2px solid #0099FF !important;
    border-radius: 8px !important;
}
.button-primary {
    background-color: #0099FF !important;
    color: white !important;
}
.button-secondary {
    background-color: #6c757d !important;
    color: white !important;
}
.footer {
    text-align: center;
    margin-top: 20px;
    font-size: 0.8em;
    color: #666;
}
"""


def get_model_name(full_selection):
    # Dropdown entries look like "Qwen3-4B - description"; keep only the model key.
    return full_selection.split(" - ")[0]


def clear_chat():
    return [], ""


def submit_message(user_input, history, system_prompt, model_name, max_tokens, temp, k, p, rp):
    # Append the user message to the history and clear the input box;
    # the assistant reply is produced by the follow-up event.
    return "", history + [[user_input, None]]

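# Named generator wrapper for the chat events, used below in place of the original
# inline lambdas: Gradio only streams partial outputs from handlers that are
# themselves generator functions, so delegating to chat_response via `yield from`
# keeps the token-by-token chatbot updates.
def stream_response(history, system_prompt, model_selection, max_tokens, temp, k, p, rp):
    yield from chat_response(
        history[-1][0], history[:-1], system_prompt,
        get_model_name(model_selection), max_tokens, temp, k, p, rp
    )
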
with gr.Blocks(title="Qwen3 Chat", css=css) as demo:
    gr.HTML("""
    <div class="qwen-header">
        <h1>🤖 Qwen3 Chat</h1>
        <p>Interact with Alibaba Cloud's Qwen3 language models</p>
    </div>
    """)

    chatbot = gr.Chatbot(height=500)

    with gr.Row():
        with gr.Column(scale=3):
            with gr.Group(elem_classes="qwen-container"):
                model_dd = gr.Dropdown(
                    label="Select Qwen3 Model",
                    choices=[f"{k} - {v['description']}" for k, v in MODELS.items()],
                    value=f"{list(MODELS.keys())[0]} - {MODELS[list(MODELS.keys())[0]]['description']}",
                    elem_classes="model-select"
                )

            with gr.Group(elem_classes="controls-container"):
                gr.Markdown("### ⚙️ Generation Parameters")
                sys_prompt = gr.Textbox(label="System Prompt", lines=5, value=get_default_system_prompt())
                with gr.Row():
                    max_tok = gr.Slider(64, 1024, value=512, step=32, label="Max Tokens")
                with gr.Row():
                    temp = gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature")
                    p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-P")
                with gr.Row():
                    k = gr.Slider(1, 100, value=40, step=1, label="Top-K")
                    rp = gr.Slider(1.0, 2.0, value=1.1, step=0.1, label="Repetition Penalty")

            with gr.Row():
                clr = gr.Button("Clear Chat", elem_classes="button-secondary")
                cnl = gr.Button("Cancel Generation", elem_classes="button-secondary")

        with gr.Column(scale=7):
            with gr.Row():
                msg = gr.Textbox(
                    placeholder="Type your message and press Enter...",
                    lines=2,
                    show_label=False
                )
                send_btn = gr.Button("Send", variant="primary", elem_classes="button-primary")

    gr.HTML("""
    <div class="footer">
        <p>Qwen3 models developed by Alibaba Cloud. Interface powered by Gradio and ZeroGPU.</p>
    </div>
    """)

    clr.click(fn=clear_chat, outputs=[chatbot, msg])
    cnl.click(fn=cancel_generation)

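    # Sending a message is a two-step chain: submit_message records the user turn and
    # clears the textbox, then stream_response streams the assistant reply into the
    # last history entry.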
    msg.submit(
        fn=submit_message,
        inputs=[msg, chatbot, sys_prompt, model_dd, max_tok, temp, k, p, rp],
        outputs=[msg, chatbot]
    ).then(
        fn=stream_response,
        inputs=[chatbot, sys_prompt, model_dd, max_tok, temp, k, p, rp],
        outputs=chatbot
    )

    send_btn.click(
        fn=submit_message,
        inputs=[msg, chatbot, sys_prompt, model_dd, max_tok, temp, k, p, rp],
        outputs=[msg, chatbot]
    ).then(
        fn=stream_response,
        inputs=[chatbot, sys_prompt, model_dd, max_tok, temp, k, p, rp],
        outputs=chatbot
    )


if __name__ == "__main__":
    demo.launch()