# Qwen3-8B / app.py
# Author: openfree — last commit "Update app.py" (ecffdea, verified)
import os
import time
import gc
import threading
from datetime import datetime
import gradio as gr
import torch
from transformers import pipeline, TextIteratorStreamer
import spaces # Import spaces early to enable ZeroGPU support
# ------------------------------
# Global Cancellation Event
# ------------------------------
# Cleared at the start of every bot_response() call.
# NOTE(review): nothing in this file calls .set() on it, so generation cannot
# currently be cancelled through this event.
cancel_event = threading.Event()

# ------------------------------
# Qwen3 Model Definitions
# ------------------------------
# Model key -> Hugging Face repo id and a human-readable description. The
# dropdown label is "<key> - <description>"; get_model_name() recovers the key.
MODELS = {
    "Qwen3-8B": dict(
        repo_id="Qwen/Qwen3-8B",
        description="Qwen3-8B - Largest model with highest capabilities",
    ),
}

# Lazily-filled cache of loaded pipelines, keyed by model key, so each model is
# loaded at most once per process (see load_pipeline()).
PIPELINES = dict()
def load_pipeline(model_name):
    """
    Load and cache a transformers text-generation pipeline for *model_name*.

    Loading is attempted with torch_dtype bfloat16, then float16, then float32.
    If all three raise, a final attempt is made without torch_dtype so that
    transformers picks the dtype itself. The first pipeline that loads is stored
    in PIPELINES, and later calls return it without reloading.

    Args:
        model_name: a key of MODELS (e.g. "Qwen3-8B").

    Returns:
        The loaded pipeline object.

    Raises:
        KeyError: if model_name is not a key of MODELS.
        Exception: whatever transformers raises if the final, dtype-less
            attempt also fails (earlier failures are logged, not raised).
    """
    import logging  # local import: only needed when a load attempt fails

    if model_name in PIPELINES:
        return PIPELINES[model_name]

    repo = MODELS[model_name]["repo_id"]
    # Arguments shared by every load attempt; only torch_dtype varies.
    common_kwargs = dict(
        task="text-generation",
        model=repo,
        tokenizer=repo,
        trust_remote_code=True,
        device_map="auto",
    )

    for dtype in (torch.bfloat16, torch.float16, torch.float32):
        try:
            pipe = pipeline(torch_dtype=dtype, **common_kwargs)
        except Exception as err:
            # Keep a record of why this dtype failed instead of discarding the
            # error silently, then fall through to the next dtype.
            logging.getLogger(__name__).warning(
                "Loading %s with torch_dtype=%s failed: %r", repo, dtype, err
            )
            continue
        PIPELINES[model_name] = pipe
        return pipe

    # Final fallback: let transformers choose the dtype. Errors here propagate.
    pipe = pipeline(**common_kwargs)
    PIPELINES[model_name] = pipe
    return pipe
def format_conversation(history, system_prompt):
    """
    Flatten the system prompt and chat history into one plain-text prompt.

    The result has the shape::

        <system prompt>
        User: <message>
        Assistant: <reply>
        ...
        Assistant: 

    Args:
        history: iterable of (user_msg, assistant_msg) pairs. assistant_msg may
            be None or empty for a turn that has not been answered yet; such a
            turn contributes only its "User:" line. A None user_msg is treated
            as "".
        system_prompt: text for the first line (stripped); None is treated as "".

    Returns:
        The prompt, always ending with "Assistant: " so the model continues as
        the assistant.
    """
    # Collect lines and join once instead of repeated string concatenation.
    lines = [(system_prompt or "").strip()]
    for user_msg, assistant_msg in history:
        lines.append("User: " + (user_msg or "").strip())
        if assistant_msg:  # None or "" means this turn has no reply yet
            lines.append("Assistant: " + assistant_msg.strip())
    lines.append("Assistant: ")
    return "\n".join(lines)
# Dropdown labels look like "<model key> - <description>"; recover the key.
def get_model_name(full_selection):
    """Return the text before the first " - " in *full_selection* (the whole string if absent)."""
    model_key, _separator, _description = full_selection.partition(" - ")
    return model_key
# Submit handler: moves the typed message into the chat history.
def user_input(user_message, history):
    """
    Append *user_message* as a new, not-yet-answered turn.

    Returns:
        ("", new_history): the empty string clears the input textbox, and
        new_history is a fresh list ending with (user_message, None). The
        incoming history list is not modified.
    """
    pending_turn = (user_message, None)
    return "", [*history, pending_turn]
@spaces.GPU(duration=60)
def bot_response(history, system_prompt, model_selection, max_tokens, temperature, top_k, top_p, repetition_penalty):
    """
    Generate the assistant's reply to the last user message in *history*.

    Args:
        history: list of (user_message, assistant_message) pairs as produced by
            user_input(); the last pair is the turn to answer.
        system_prompt: text placed at the top of the prompt.
        model_selection: dropdown label "<model key> - <description>".
        max_tokens, temperature, top_k, top_p, repetition_penalty: sampling
            settings from the UI sliders, forwarded to the pipeline call.

    Returns:
        The same list with its last pair replaced by (user_message, reply), or
        by (user_message, "Error: <exception>") if loading or generation
        failed. An empty history is returned unchanged.
    """
    cancel_event.clear()
    if not history:
        # Nothing to answer (e.g. the API endpoint was called with an empty
        # history); indexing history[-1] would otherwise raise IndexError.
        return history

    user_message = history[-1][0]
    model_name = get_model_name(model_selection)

    # format_conversation() already ends the prompt with "Assistant: ", so pass
    # the pending turn with no reply instead of appending another
    # "User: ...\nAssistant: " by hand (which produced "Assistant: User: ...").
    conversation = format_conversation([*history[:-1], (user_message, None)], system_prompt)

    try:
        pipe = load_pipeline(model_name)
        outputs = pipe(
            conversation,
            # Slider values may arrive as floats; token counts must be ints.
            max_new_tokens=int(max_tokens),
            # Request sampling explicitly so the temperature/top-k/top-p sliders
            # take effect regardless of the model's default generation config.
            do_sample=True,
            temperature=temperature,
            top_k=int(top_k),
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            return_full_text=False,
        )
        response = outputs[0]["generated_text"]
        # With this plain "User:/Assistant:" prompt format the model may go on
        # to write the next user turn itself; keep only the assistant's reply.
        response = response.split("\nUser:", 1)[0].strip()
        history[-1] = (user_message, response)
    except Exception as e:
        # Show the failure in the chat instead of breaking the UI event.
        history[-1] = (user_message, f"Error: {e}")
    finally:
        gc.collect()
    return history
def get_default_system_prompt():
    """Return the initial text of the "System Prompt" textbox."""
    # Fixed typo ("assistat" -> "assistant") in the prompt sent to the model,
    # and dropped an unused local that computed today's date.
    return "You are Qwen3, a helpful and friendly AI assistant. Be concise, accurate, and helpful in your responses."
def clear_chat():
    """Return a fresh, empty history; used as the output of the "Clear Chat" button."""
    empty_history = list()
    return empty_history
# Custom stylesheet passed to gr.Blocks(css=...) in the UI section below.
# The classes are attached via elem_classes (qwen-container, controls-container,
# model-select, button-primary, button-secondary) or used inside the gr.HTML
# snippets (qwen-header, footer). `.gradio-container` is not defined in this
# file; presumably it is Gradio's own root element — confirm against Gradio docs.
css = """
.gradio-container {
background-color: #f5f7fb !important;
}
.qwen-header {
background: linear-gradient(90deg, #0099FF, #0066CC);
padding: 20px;
border-radius: 10px;
margin-bottom: 20px;
text-align: center;
color: white;
box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
}
.qwen-container {
border-radius: 10px;
box-shadow: 0 2px 4px rgba(0, 0, 0, 0.05);
background: white;
padding: 20px;
margin-bottom: 20px;
}
.controls-container {
background: #f0f4fa;
border-radius: 10px;
padding: 15px;
margin-bottom: 15px;
}
.model-select {
border: 2px solid #0099FF !important;
border-radius: 8px !important;
}
.button-primary {
background-color: #0099FF !important;
color: white !important;
}
.button-secondary {
background-color: #6c757d !important;
color: white !important;
}
.footer {
text-align: center;
margin-top: 20px;
font-size: 0.8em;
color: #666;
}
"""
# ------------------------------
# Gradio UI
# ------------------------------
# Dropdown labels are "<model key> - <description>"; get_model_name() turns a
# selected label back into the MODELS key.
model_choices = [f"{k} - {v['description']}" for k, v in MODELS.items()]

with gr.Blocks(title="Qwen3 Chat", css=css) as demo:
    gr.HTML("""
<div class="qwen-header">
<h1>🤖 Qwen3 Chat</h1>
<p>Interact with Alibaba Cloud's Qwen3 language models</p>
</div>
""")
    with gr.Row():
        # Left column: model choice, generation settings, clear button.
        with gr.Column(scale=3):
            with gr.Group(elem_classes="qwen-container"):
                model_dd = gr.Dropdown(
                    label="Select Qwen3 Model",
                    choices=model_choices,
                    value=model_choices[0],
                    elem_classes="model-select"
                )
            with gr.Group(elem_classes="controls-container"):
                gr.Markdown("### ⚙️ Generation Parameters")
                sys_prompt = gr.Textbox(label="System Prompt", lines=5, value=get_default_system_prompt())
                with gr.Row():
                    max_tok = gr.Slider(64, 1024, value=512, step=32, label="Max Tokens")
                with gr.Row():
                    temp = gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature")
                    p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-P")
                with gr.Row():
                    k = gr.Slider(1, 100, value=40, step=1, label="Top-K")
                    rp = gr.Slider(1.0, 2.0, value=1.1, step=0.1, label="Repetition Penalty")
            clear_btn = gr.Button("Clear Chat", elem_classes="button-secondary")
        # Right column: the conversation and the message input.
        with gr.Column(scale=7):
            chatbot = gr.Chatbot()
            with gr.Row():
                txt = gr.Textbox(
                    show_label=False,
                    placeholder="Type your message here...",
                    lines=2
                )
                submit_btn = gr.Button("Send", variant="primary", elem_classes="button-primary")
    gr.HTML("""
<div class="footer">
<p>Qwen3 models developed by Alibaba Cloud. Interface powered by Gradio and ZeroGPU.</p>
</div>
""")

    # Connect UI elements to functions.
    # Order must match bot_response(history, system_prompt, model_selection,
    # max_tokens, temperature, top_k, top_p, repetition_penalty).
    generation_inputs = [chatbot, sys_prompt, model_dd, max_tok, temp, k, p, rp]

    # The Send button and pressing Enter run the same two-step chain: append the
    # message to the history (unqueued, so it shows immediately), then generate.
    # Only one of them may own the "generate" API name; registering it twice
    # makes Gradio rename the duplicate and warn.
    for trigger, generate_api_name in ((submit_btn.click, "generate"), (txt.submit, False)):
        trigger(
            user_input,
            inputs=[txt, chatbot],
            outputs=[txt, chatbot],
            queue=False
        ).then(
            bot_response,
            inputs=generation_inputs,
            outputs=chatbot,
            api_name=generate_api_name
        )

    clear_btn.click(
        clear_chat,
        outputs=[chatbot],
        queue=False
    )

if __name__ == "__main__":
    demo.launch()