Update app.py
app.py
CHANGED
@@ -2,6 +2,14 @@ import gradio as gr
 from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
 from functools import lru_cache
 
+# Pre-selected small models
+MODELS = {
+    "SmolLM2-135M-Instruct": "HuggingFaceTB/SmolLM2-135M-Instruct",
+    "GPT-2 (Small)": "gpt2",
+    "DistilGPT-2": "distilgpt2",
+    "Facebook OPT-125M": "facebook/opt-125m"
+}
+
 # Cache the model and tokenizer to avoid reloading
 @lru_cache(maxsize=1)
 def load_model_cached(model_name):
@@ -13,9 +21,14 @@ def load_model_cached(model_name):
         return f"Error loading model: {str(e)}"
 
 # Function to generate a response from the model
-def chat(model_name, user_input, chat_history, system_prompt=""):
-    if …
-    return "Please …
+def chat(selected_model, user_input, chat_history, system_prompt=""):
+    if not selected_model:
+        return "Please select a model from the dropdown.", chat_history
+
+    # Get the model name from the dropdown
+    model_name = MODELS.get(selected_model)
+    if not model_name:
+        return "Invalid model selected.", chat_history
 
     # Load the model (cached)
     generator = load_model_cached(model_name)
@@ -59,13 +72,13 @@ def chat(model_name, user_input, chat_history, system_prompt=""):
 
 # Gradio interface
 with gr.Blocks() as demo:
-    gr.Markdown("# Chat with …
+    gr.Markdown("# Chat with Small Language Models")
 
     with gr.Row():
-        …
-            label="…
-            …
-            …
+        selected_model = gr.Dropdown(
+            label="Select a Model",
+            choices=list(MODELS.keys()),
+            value="SmolLM2-135M-Instruct"  # Default model
         )
 
     chatbot = gr.Chatbot(label="Chat")
@@ -78,7 +91,7 @@ with gr.Blocks() as demo:
     clear_button = gr.Button("Clear Chat")
 
     # Define the chat function
-    user_input.submit(chat, […
+    user_input.submit(chat, [selected_model, user_input, chatbot, system_prompt], [user_input, chatbot])
     clear_button.click(lambda: [], None, chatbot, queue=False)
 
 # Launch the app
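The body of `load_model_cached` falls outside these hunks; given the `pipeline` import and the `f"Error loading model: {str(e)}"` return visible above, it presumably wraps a transformers text-generation pipeline. A minimal sketch under that assumption, not the committed implementation:

```python
from functools import lru_cache
from transformers import pipeline

# Hypothetical reconstruction -- the actual body is not shown in this diff.
@lru_cache(maxsize=1)
def load_model_cached(model_name):
    try:
        # pipeline() resolves both the tokenizer and the model in one call.
        return pipeline("text-generation", model=model_name)
    except Exception as e:
        # Matches the error-string return visible in the hunk above.
        return f"Error loading model: {str(e)}"
```

Note that `maxsize=1` keeps only the most recently requested model in memory: switching the dropdown to a different entry evicts the previous pipeline and pays the full load cost again.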
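The remainder of `chat` (new lines 35-71) is likewise elided. Whatever it does, the `user_input.submit` wiring requires it to return two values, which Gradio writes back into `user_input` and `chatbot` in that order. A hypothetical happy-path continuation, where the prompt format, `max_new_tokens=128`, and the empty-string textbox reset are illustrative choices rather than the committed code:

```python
def chat(selected_model, user_input, chat_history, system_prompt=""):
    # Validation and MODELS lookup as in the diff above.
    if not selected_model:
        return "Please select a model from the dropdown.", chat_history
    model_name = MODELS.get(selected_model)
    if not model_name:
        return "Invalid model selected.", chat_history

    # Load the model (cached)
    generator = load_model_cached(model_name)

    # Hypothetical continuation -- committed lines 35-71 are not shown.
    prompt = f"{system_prompt}\n{user_input}" if system_prompt else user_input
    response = generator(prompt, max_new_tokens=128)[0]["generated_text"]
    chat_history = (chat_history or []) + [(user_input, response)]
    # First value clears the textbox; second refreshes the Chatbot.
    return "", chat_history
```

One wiring detail worth noting: on the two error paths the first return value is an error string, and with outputs `[user_input, chatbot]` that string lands in the input textbox rather than in the chat window.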