Papaya-Voldemort committed
Commit b0e67d9 (verified)
1 Parent(s): 66a12b0

Update app.py

Files changed (1)
  1. app.py +22 -9
app.py CHANGED
@@ -2,6 +2,14 @@ import gradio as gr
 from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
 from functools import lru_cache
 
+# Pre-selected small models
+MODELS = {
+    "SmolLM2-135M-Instruct": "HuggingFaceTB/SmolLM2-135M-Instruct",
+    "GPT-2 (Small)": "gpt2",
+    "DistilGPT-2": "distilgpt2",
+    "Facebook OPT-125M": "facebook/opt-125m"
+}
+
 # Cache the model and tokenizer to avoid reloading
 @lru_cache(maxsize=1)
 def load_model_cached(model_name):
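Note: the body of load_model_cached sits outside the diff context, so only its signature and the error-string fallback (visible as a context line in the next hunk) are certain. A hypothetical sketch of what the cached loader likely looks like, assuming it wraps transformers.pipeline:

from functools import lru_cache
from transformers import pipeline

@lru_cache(maxsize=1)
def load_model_cached(model_name):
    # Hypothetical reconstruction; the actual body is elided from this diff.
    try:
        # maxsize=1 keeps only the most recently used pipeline, so switching
        # models in the UI evicts the previous one instead of stacking them in RAM.
        return pipeline("text-generation", model=model_name)
    except Exception as e:
        return f"Error loading model: {str(e)}"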
@@ -13,9 +21,14 @@ def load_model_cached(model_name):
         return f"Error loading model: {str(e)}"
 
 # Function to generate a response from the model
-def chat(model_name, user_input, chat_history, system_prompt=""):
-    if model_name.strip() == "":
-        return "Please enter a valid model name.", chat_history
+def chat(selected_model, user_input, chat_history, system_prompt=""):
+    if not selected_model:
+        return "Please select a model from the dropdown.", chat_history
+
+    # Get the model name from the dropdown
+    model_name = MODELS.get(selected_model)
+    if not model_name:
+        return "Invalid model selected.", chat_history
 
     # Load the model (cached)
     generator = load_model_cached(model_name)
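The new chat() now resolves the dropdown label to a Hugging Face repo id before loading anything. A standalone sketch of that lookup (resolve_model is a hypothetical helper for illustration, not part of the commit):

MODELS = {
    "SmolLM2-135M-Instruct": "HuggingFaceTB/SmolLM2-135M-Instruct",
    "GPT-2 (Small)": "gpt2",
    "DistilGPT-2": "distilgpt2",
    "Facebook OPT-125M": "facebook/opt-125m"
}

def resolve_model(selected_model):
    # Mirrors the guards at the top of chat(): reject an empty selection
    # and any label that is not a key of MODELS.
    if not selected_model:
        return None, "Please select a model from the dropdown."
    model_name = MODELS.get(selected_model)
    if not model_name:
        return None, "Invalid model selected."
    return model_name, None

assert resolve_model("GPT-2 (Small)") == ("gpt2", None)
assert resolve_model("") == (None, "Please select a model from the dropdown.")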
@@ -59,13 +72,13 @@ def chat(model_name, user_input, chat_history, system_prompt=""):
 
 # Gradio interface
 with gr.Blocks() as demo:
-    gr.Markdown("# Chat with SmolLM2-135M-Instruct")
+    gr.Markdown("# Chat with Small Language Models")
 
     with gr.Row():
-        model_name = gr.Textbox(
-            label="Enter Hugging Face Model Name",
-            value="HuggingFaceTB/SmolLM2-135M-Instruct", # Default model
-            placeholder="e.g., HuggingFaceTB/SmolLM2-135M-Instruct"
+        selected_model = gr.Dropdown(
+            label="Select a Model",
+            choices=list(MODELS.keys()),
+            value="SmolLM2-135M-Instruct" # Default model
         )
 
     chatbot = gr.Chatbot(label="Chat")
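For context, a minimal self-contained version of the reworked layout; it assumes only the Gradio Blocks API already used in this file:

import gradio as gr

MODELS = {"SmolLM2-135M-Instruct": "HuggingFaceTB/SmolLM2-135M-Instruct"}

with gr.Blocks() as demo:
    gr.Markdown("# Chat with Small Language Models")
    with gr.Row():
        # The dropdown replaces the free-text model box, so users can only
        # pick from the curated MODELS keys rather than typing arbitrary repo ids.
        selected_model = gr.Dropdown(
            label="Select a Model",
            choices=list(MODELS.keys()),
            value="SmolLM2-135M-Instruct"
        )
    chatbot = gr.Chatbot(label="Chat")

demo.launch()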
@@ -78,7 +91,7 @@ with gr.Blocks() as demo:
     clear_button = gr.Button("Clear Chat")
 
     # Define the chat function
-    user_input.submit(chat, [model_name, user_input, chatbot, system_prompt], [user_input, chatbot])
+    user_input.submit(chat, [selected_model, user_input, chatbot, system_prompt], [user_input, chatbot])
     clear_button.click(lambda: [], None, chatbot, queue=False)
 
 # Launch the app
 
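The submit wiring passes the dropdown value as the first input and clears the textbox by returning "" as the first output. A runnable sketch of that wiring with a stub chat() (the real generation code is elided from this diff; the tuple-style Chatbot history is an assumption matching the original code's usage):

import gradio as gr

def chat(selected_model, user_input, chat_history, system_prompt=""):
    # Stub reply; the real chat() generates text with the cached pipeline.
    reply = f"[{selected_model}] echo: {user_input}"
    return "", (chat_history or []) + [(user_input, reply)]

with gr.Blocks() as demo:
    selected_model = gr.Dropdown(label="Select a Model",
                                 choices=["SmolLM2-135M-Instruct"],
                                 value="SmolLM2-135M-Instruct")
    system_prompt = gr.Textbox(label="System Prompt")
    chatbot = gr.Chatbot(label="Chat")
    user_input = gr.Textbox(label="Your Message")
    clear_button = gr.Button("Clear Chat")

    # Same signatures as the commit: inputs include the dropdown, outputs
    # are the (cleared) textbox and the updated chat history.
    user_input.submit(chat, [selected_model, user_input, chatbot, system_prompt],
                      [user_input, chatbot])
    clear_button.click(lambda: [], None, chatbot, queue=False)

demo.launch()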