Update app.py
app.py CHANGED
@@ -2,12 +2,39 @@ import gradio as gr
 from huggingface_hub import InferenceClient
 import os
 import json
+import base64
+from PIL import Image
+import io
 
 ACCESS_TOKEN = os.getenv("HF_TOKEN")
 print("Access token loaded.")
 
+# Function to encode image to base64
+def encode_image(image):
+    if image is None:
+        return None
+
+    # Convert to PIL Image if needed
+    if not isinstance(image, Image.Image):
+        try:
+            image = Image.open(image)
+        except Exception as e:
+            print(f"Error opening image: {e}")
+            return None
+
+    # Convert to RGB if image has an alpha channel (RGBA)
+    if image.mode == 'RGBA':
+        image = image.convert('RGB')
+
+    # Encode to base64
+    buffered = io.BytesIO()
+    image.save(buffered, format="JPEG")
+    img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
+    return img_str
+
 def respond(
     message,
+    images,  # New parameter for uploaded images
     history: list[tuple[str, str]],
     system_message,
     max_tokens,
@@ -16,26 +43,26 @@ def respond(
     frequency_penalty,
     seed,
     provider,
-    custom_api_key,
+    custom_api_key,
     custom_model,
     model_search_term,
     selected_model
 ):
     print(f"Received message: {message}")
+    print(f"Received {len(images) if images else 0} images")
     print(f"History: {history}")
     print(f"System message: {system_message}")
     print(f"Max tokens: {max_tokens}, Temperature: {temperature}, Top-P: {top_p}")
     print(f"Frequency Penalty: {frequency_penalty}, Seed: {seed}")
     print(f"Selected provider: {provider}")
-    print(f"Custom API Key provided: {bool(custom_api_key.strip())}")
+    print(f"Custom API Key provided: {bool(custom_api_key.strip())}")
     print(f"Selected model (custom_model): {custom_model}")
     print(f"Model search term: {model_search_term}")
     print(f"Selected model from radio: {selected_model}")
 
-    # Determine which token to use
+    # Determine which token to use
     token_to_use = custom_api_key if custom_api_key.strip() != "" else ACCESS_TOKEN
 
-    # Log which token source we're using (without printing the actual token)
     if custom_api_key.strip() != "":
         print("USING CUSTOM API KEY: BYOK token provided by user is being used for authentication")
     else:
@@ -49,6 +76,33 @@ def respond(
     if seed == -1:
         seed = None
 
+    # Create multimodal content if images are present
+    if images and any(images):
+        # Process the user message to include images
+        user_content = []
+
+        # Add text part if there is any
+        if message and message.strip():
+            user_content.append({
+                "type": "text",
+                "text": message
+            })
+
+        # Add image parts
+        for img in images:
+            if img is not None:
+                encoded_image = encode_image(img)
+                if encoded_image:
+                    user_content.append({
+                        "type": "image_url",
+                        "image_url": {
+                            "url": f"data:image/jpeg;base64,{encoded_image}"
+                        }
+                    })
+    else:
+        # Text-only message
+        user_content = message
+
     # Prepare messages in the format expected by the API
     messages = [{"role": "system", "content": system_message}]
     print("Initial messages array constructed.")
@@ -59,14 +113,14 @@ def respond(
         assistant_part = val[1]
         if user_part:
            messages.append({"role": "user", "content": user_part})
-            print(f"Added user message to context: {user_part}")
+            print(f"Added user message to context (type: {type(user_part)})")
         if assistant_part:
            messages.append({"role": "assistant", "content": assistant_part})
            print(f"Added assistant message to context: {assistant_part}")
 
     # Append the latest user message
-    messages.append({"role": "user", "content": message})
-    print("Latest user message appended.")
+    messages.append({"role": "user", "content": user_content})
+    print(f"Latest user message appended (content type: {type(user_content)})")
 
     # Determine which model to use, prioritizing custom_model if provided
     model_to_use = custom_model.strip() if custom_model.strip() != "" else selected_model
@@ -90,15 +144,13 @@ def respond(
     # Use the InferenceClient for making the request
     try:
         # Create a generator for the streaming response
-        # The provider is already set when initializing the client
         stream = client.chat_completion(
            model=model_to_use,
            messages=messages,
            stream=True,
-            **parameters
+            **parameters
        )
 
-        # Print a starting message for token streaming
        print("Received tokens: ", end="", flush=True)
 
        # Process the streaming response
@@ -108,12 +160,10 @@ def respond(
            if hasattr(chunk.choices[0], 'delta') and hasattr(chunk.choices[0].delta, 'content'):
                token_text = chunk.choices[0].delta.content
                if token_text:
-                    # Print tokens inline without newlines
                    print(token_text, end="", flush=True)
                    response += token_text
                    yield response
 
-        # Print a newline at the end of all tokens
        print()
    except Exception as e:
        print(f"Error during inference: {e}")
@@ -124,174 +174,284 @@ def respond(
 
 # Function to validate provider selection based on BYOK
 def validate_provider(api_key, provider):
-    # If no custom API key is provided, only "hf-inference" can be used
     if not api_key.strip() and provider != "hf-inference":
         return gr.update(value="hf-inference")
     return gr.update(value=provider)
 
 # GRADIO UI
 
-[161 removed lines (old 134-294): the previous Gradio UI definition, collapsed in this diff view and not recoverable]
+with gr.Blocks(theme="Nymbo/Nymbo_Theme") as demo:
+    # Create the chatbot component
+    chatbot = gr.Chatbot(
+        height=600,
+        show_copy_button=True,
+        placeholder="Select a model and begin chatting",
+        layout="panel"
+    )
+    print("Chatbot interface created.")
+
+    with gr.Row():
+        # Text input for messages
+        msg = gr.Textbox(
+            placeholder="Type a message...",
+            show_label=False,
+            container=False,
+            scale=9
+        )
+
+        # Image upload button
+        image_upload = gr.Image(
+            type="filepath",
+            label="Upload Image",
+            scale=1
+        )
+
+        # Send button for messages
+        submit_btn = gr.Button("Send", variant="primary")
+
+    # Create tabs for different settings
+    with gr.Accordion("Settings", open=False):
+        # Tab for general settings
+        with gr.Tab("General Settings"):
+            # System message
+            system_message_box = gr.Textbox(
+                value="You are a helpful AI assistant that can understand images and text.",
+                placeholder="You are a helpful assistant.",
+                label="System Prompt"
+            )
+
+            # Generation parameters
+            with gr.Row():
+                with gr.Column():
+                    max_tokens_slider = gr.Slider(
+                        minimum=1,
+                        maximum=4096,
+                        value=512,
+                        step=1,
+                        label="Max tokens"
+                    )
+
+                    temperature_slider = gr.Slider(
+                        minimum=0.1,
+                        maximum=4.0,
+                        value=0.7,
+                        step=0.1,
+                        label="Temperature"
+                    )
+
+                with gr.Column():
+                    top_p_slider = gr.Slider(
+                        minimum=0.1,
+                        maximum=1.0,
+                        value=0.95,
+                        step=0.05,
+                        label="Top-P"
+                    )
+
+                    frequency_penalty_slider = gr.Slider(
+                        minimum=-2.0,
+                        maximum=2.0,
+                        value=0.0,
+                        step=0.1,
+                        label="Frequency Penalty"
+                    )
+
+                    seed_slider = gr.Slider(
+                        minimum=-1,
+                        maximum=65535,
+                        value=-1,
+                        step=1,
+                        label="Seed (-1 for random)"
+                    )
+
+        # Tab for provider and model selection
+        with gr.Tab("Provider & Model"):
+            with gr.Row():
+                with gr.Column():
+                    # Provider selection
+                    providers_list = [
+                        "hf-inference",  # Default Hugging Face Inference
+                        "cerebras",      # Cerebras provider
+                        "together",      # Together AI
+                        "sambanova",     # SambaNova
+                        "novita",        # Novita AI
+                        "cohere",        # Cohere
+                        "fireworks-ai",  # Fireworks AI
+                        "hyperbolic",    # Hyperbolic
+                        "nebius",        # Nebius
+                    ]
+
+                    provider_radio = gr.Radio(
+                        choices=providers_list,
+                        value="hf-inference",
+                        label="Inference Provider",
+                        info="[View all models here](https://huggingface.co/models?inference_provider=all&sort=trending)"
+                    )
+
+                    # New BYOK textbox
+                    byok_textbox = gr.Textbox(
+                        value="",
+                        label="BYOK (Bring Your Own Key)",
+                        info="Enter a custom Hugging Face API key here. When empty, only 'hf-inference' provider can be used.",
+                        placeholder="Enter your Hugging Face API token",
+                        type="password"  # Hide the API key for security
+                    )
+
+                with gr.Column():
+                    # Custom model box
+                    custom_model_box = gr.Textbox(
+                        value="",
+                        label="Custom Model",
+                        info="(Optional) Provide a custom Hugging Face model path. Overrides any selected featured model.",
+                        placeholder="meta-llama/Llama-3.3-70B-Instruct"
+                    )
+
+                    # Model search
+                    model_search_box = gr.Textbox(
+                        label="Filter Models",
+                        placeholder="Search for a featured model...",
+                        lines=1
+                    )
+
+                    # Featured models list
+                    # Updated to include multimodal models
+                    models_list = [
+                        # Multimodal models
+                        "meta-llama/Llama-3.3-70B-Vision",
+                        "Alibaba-NLP/NephilaV-16B-Chat",
+                        "mistralai/Mistral-Large-Vision-2407",
+                        "OpenGVLab/InternVL-Chat-V1-5",
+                        "microsoft/Phi-3.5-vision-instruct",
+                        "Qwen/Qwen2.5-VL-7B-Instruct",
+                        "liuhaotian/llava-v1.6-mistral-7b",
+
+                        # Standard text models
+                        "meta-llama/Llama-3.3-70B-Instruct",
+                        "meta-llama/Llama-3.1-70B-Instruct",
+                        "meta-llama/Llama-3.0-70B-Instruct",
+                        "meta-llama/Llama-3.2-3B-Instruct",
+                        "meta-llama/Llama-3.2-1B-Instruct",
+                        "meta-llama/Llama-3.1-8B-Instruct",
+                        "NousResearch/Hermes-3-Llama-3.1-8B",
+                        "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO",
+                        "mistralai/Mistral-Nemo-Instruct-2407",
+                        "mistralai/Mixtral-8x7B-Instruct-v0.1",
+                        "mistralai/Mistral-7B-Instruct-v0.3",
+                        "mistralai/Mistral-7B-Instruct-v0.2",
+                        "Qwen/Qwen3-235B-A22B",
+                        "Qwen/Qwen3-32B",
+                        "Qwen/Qwen2.5-72B-Instruct",
+                        "Qwen/Qwen2.5-3B-Instruct",
+                        "Qwen/Qwen2.5-0.5B-Instruct",
+                        "Qwen/QwQ-32B",
+                        "Qwen/Qwen2.5-Coder-32B-Instruct",
+                        "microsoft/Phi-3.5-mini-instruct",
+                        "microsoft/Phi-3-mini-128k-instruct",
+                        "microsoft/Phi-3-mini-4k-instruct",
+                    ]
+
+                    featured_model_radio = gr.Radio(
+                        label="Select a model below",
+                        choices=models_list,
+                        value="meta-llama/Llama-3.3-70B-Vision",  # Default to a multimodal model
+                        interactive=True
+                    )
+
+                    gr.Markdown("[View all multimodal models](https://huggingface.co/models?pipeline_tag=image-to-text&sort=trending)")
+
+    # Chat history state
+    chat_history = gr.State([])
+
+    # Function to filter models
+    def filter_models(search_term):
+        print(f"Filtering models with search term: {search_term}")
+        filtered = [m for m in models_list if search_term.lower() in m.lower()]
+        print(f"Filtered models: {filtered}")
+        return gr.update(choices=filtered)
+
+    # Function to set custom model from radio
+    def set_custom_model_from_radio(selected):
+        print(f"Featured model selected: {selected}")
+        return selected
+
+    # Function for the chat interface
+    def user(user_message, image, history):
+        if user_message == "" and image is None:
+            return history
+
+        # Format image reference for display
+        img_placeholder = ""
+        if image is not None:
+            img_placeholder = f"![Image]({image})"
+
+        # Combine text and image reference for display
+        display_message = f"{user_message}\n{img_placeholder}" if img_placeholder else user_message
+
+        # Return updated history
+        return history + [[display_message, None]]
+
+    # Define chat interface
+    def bot(history, images, system_msg, max_tokens, temperature, top_p, freq_penalty, seed, provider, api_key, custom_model, search_term, selected_model):
+        # Extract the last user message
+        user_message = history[-1][0] if history and len(history) > 0 else ""
+
+        # Clean up the user message to remove image reference
+        if "![Image]" in user_message:
+            text_parts = user_message.split("![Image]")[0].strip()
+        else:
+            text_parts = user_message
+
+        # Process message through respond function
+        history[-1][1] = ""
+        for response in respond(
+            text_parts,   # Send only the text part
+            [images],     # Send images separately
+            history[:-1],
+            system_msg,
+            max_tokens,
+            temperature,
+            top_p,
+            freq_penalty,
+            seed,
+            provider,
+            api_key,
+            custom_model,
+            search_term,
+            selected_model
+        ):
+            history[-1][1] = response
+            yield history
+
+    # Event handlers
+    msg.submit(
+        user,
+        [msg, image_upload, chatbot],
+        [chatbot],
+        queue=False
+    ).then(
+        bot,
+        [chatbot, image_upload, system_message_box, max_tokens_slider, temperature_slider, top_p_slider,
+         frequency_penalty_slider, seed_slider, provider_radio, byok_textbox, custom_model_box,
+         model_search_box, featured_model_radio],
+        [chatbot]
+    )
+
+    submit_btn.click(
+        user,
+        [msg, image_upload, chatbot],
+        [chatbot],
+        queue=False
+    ).then(
+        bot,
+        [chatbot, image_upload, system_message_box, max_tokens_slider, temperature_slider, top_p_slider,
+         frequency_penalty_slider, seed_slider, provider_radio, byok_textbox, custom_model_box,
+         model_search_box, featured_model_radio],
+        [chatbot]
+    ).then(
+        lambda: (None, "", None),  # Clear inputs after submission
+        None,
+        [msg, msg, image_upload]
+    )
+
     # Connect the model filter to update the radio choices
     model_search_box.change(
         fn=filter_models,
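
For reference, the multimodal message format this commit adopts can be exercised against huggingface_hub's InferenceClient directly, outside the Gradio app. A minimal sketch, assuming a vision-capable model ID and a local "photo.jpg" (both placeholders, not taken from the commit); it reuses the same base64 data-URI structure that encode_image() above produces:

    # Standalone sketch of the payload format introduced in this change.
    # Assumptions (not from the commit): the model ID and "photo.jpg".
    import base64
    import os

    from huggingface_hub import InferenceClient

    client = InferenceClient(token=os.getenv("HF_TOKEN"))

    # Encode a local image the same way encode_image() does (JPEG -> base64 data URI).
    with open("photo.jpg", "rb") as f:
        b64 = base64.b64encode(f.read()).decode("utf-8")

    messages = [
        {"role": "system", "content": "You are a helpful AI assistant that can understand images and text."},
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Describe this image."},
                {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{b64}"}},
            ],
        },
    ]

    # Stream tokens the same way respond() consumes them.
    for chunk in client.chat_completion(
        model="Qwen/Qwen2.5-VL-7B-Instruct",  # placeholder vision model from the featured list
        messages=messages,
        stream=True,
        max_tokens=512,
    ):
        delta = chunk.choices[0].delta.content
        if delta:
            print(delta, end="", flush=True)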