Abhinav Singh committed · Commit 8464b63 · 1 Parent(s): 0ebbc88

init
- app.py +460 -0
- requirements.txt +10 -0
app.py
ADDED
@@ -0,0 +1,460 @@
import os
import torch
import gradio as gr
import numpy as np
from PIL import Image
import torchvision.transforms as T
from torchvision.transforms.functional import InterpolationMode
from transformers import AutoTokenizer, AutoModel
from decord import VideoReader, cpu
import tempfile
import json
from typing import List, Tuple, Optional, Union
import logging

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Constants
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)
MODEL_PATH = "OpenGVLab/InternVL2_5-4B"

class InternVLChatBot:
    def __init__(self):
        self.model = None
        self.tokenizer = None
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.generation_config = dict(max_new_tokens=1024, do_sample=True)
        self.load_model()

    def load_model(self):
        """Load the InternVL model and tokenizer"""
        try:
            logger.info("Loading InternVL2.5-4B model...")
            self.model = AutoModel.from_pretrained(
                MODEL_PATH,
                torch_dtype=torch.bfloat16,
                low_cpu_mem_usage=True,
                trust_remote_code=True,
                use_flash_attn=True if self.device == "cuda" else False,
                device_map="auto" if self.device == "cuda" else None
            )
            self.tokenizer = AutoTokenizer.from_pretrained(
                MODEL_PATH, trust_remote_code=True
            )
            logger.info("Model loaded successfully!")
        except Exception as e:
            logger.error(f"Error loading model: {str(e)}")
            raise e

    def find_closest_aspect_ratio(self, aspect_ratio, target_ratios, width, height, image_size):
        """Find the closest aspect ratio from target ratios"""
        best_ratio_diff = float('inf')
        best_ratio = (1, 1)
        area = width * height

        for ratio in target_ratios:
            target_aspect_ratio = ratio[0] / ratio[1]
            ratio_diff = abs(aspect_ratio - target_aspect_ratio)
            if ratio_diff < best_ratio_diff:
                best_ratio_diff = ratio_diff
                best_ratio = ratio
            elif ratio_diff == best_ratio_diff:
                if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
                    best_ratio = ratio
        return best_ratio

    def dynamic_preprocess(self, image, min_num=1, max_num=12, image_size=448, use_thumbnail=False):
        """Dynamically preprocess image based on aspect ratio"""
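        # Tiling strategy: pick the grid (columns x rows) of image_size x image_size
        # tiles whose aspect ratio is closest to the input image, resize the image
        # to that grid, and crop it into tiles. Illustrative example: a 1600x800
        # image (aspect ratio 2.0) maps to the (2, 1) grid, is resized to 896x448
        # and split into two 448x448 tiles, plus a square thumbnail of the whole
        # image when use_thumbnail=True.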
        orig_width, orig_height = image.size
        aspect_ratio = orig_width / orig_height

        # Calculate target ratios
        target_ratios = set(
            (i, j) for n in range(min_num, max_num + 1)
            for i in range(1, n + 1)
            for j in range(1, n + 1)
            if i * j <= max_num and i * j >= min_num
        )
        target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])

        # Find closest aspect ratio
        target_aspect_ratio = self.find_closest_aspect_ratio(
            aspect_ratio, target_ratios, orig_width, orig_height, image_size
        )

        # Calculate target dimensions
        target_width = image_size * target_aspect_ratio[0]
        target_height = image_size * target_aspect_ratio[1]
        blocks = target_aspect_ratio[0] * target_aspect_ratio[1]

        # Resize and split image
        resized_img = image.resize((target_width, target_height))
        processed_images = []

        for i in range(blocks):
            box = (
                (i % (target_width // image_size)) * image_size,
                (i // (target_width // image_size)) * image_size,
                ((i % (target_width // image_size)) + 1) * image_size,
                ((i // (target_width // image_size)) + 1) * image_size
            )
            split_img = resized_img.crop(box)
            processed_images.append(split_img)

        if use_thumbnail and len(processed_images) != 1:
            thumbnail_img = image.resize((image_size, image_size))
            processed_images.append(thumbnail_img)

        return processed_images

    def build_transform(self, input_size):
        """Build image transformation pipeline"""
        transform = T.Compose([
            T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
            T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
            T.ToTensor(),
            T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
        ])
        return transform

    def load_image(self, image_path, input_size=448, max_num=12):
        """Load and preprocess image"""
        if isinstance(image_path, str):
            image = Image.open(image_path).convert('RGB')
        else:
            image = image_path.convert('RGB')

        transform = self.build_transform(input_size=input_size)
        images = self.dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
        pixel_values = [transform(img) for img in images]
        pixel_values = torch.stack(pixel_values)
        return pixel_values

    def get_index(self, bound, fps, max_frame, first_idx=0, num_segments=32):
        """Get frame indices for video processing"""
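        # Sampling strategy: split the (optionally bounded) clip into
        # num_segments equal windows and take roughly the middle frame of each,
        # so the sampled frames cover the whole clip uniformly.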
        if bound:
            start, end = bound[0], bound[1]
        else:
            start, end = -100000, 100000

        start_idx = max(first_idx, round(start * fps))
        end_idx = min(round(end * fps), max_frame)
        seg_size = float(end_idx - start_idx) / num_segments

        frame_indices = np.array([
            int(start_idx + (seg_size / 2) + np.round(seg_size * idx))
            for idx in range(num_segments)
        ])
        return frame_indices

    def load_video(self, video_path, bound=None, input_size=448, max_num=1, num_segments=32):
        """Load and preprocess video"""
        vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
        max_frame = len(vr) - 1
        fps = float(vr.get_avg_fps())

        pixel_values_list, num_patches_list = [], []
        transform = self.build_transform(input_size=input_size)
        frame_indices = self.get_index(bound, fps, max_frame, first_idx=0, num_segments=num_segments)

        for frame_index in frame_indices:
            img = Image.fromarray(vr[frame_index].asnumpy()).convert('RGB')
            img = self.dynamic_preprocess(img, image_size=input_size, use_thumbnail=True, max_num=max_num)
            pixel_values = [transform(tile) for tile in img]
            pixel_values = torch.stack(pixel_values)
            num_patches_list.append(pixel_values.shape[0])
            pixel_values_list.append(pixel_values)

        pixel_values = torch.cat(pixel_values_list)
        return pixel_values, num_patches_list

    def chat(self, message, history, image=None, video=None):
        """Main chat function"""
        try:
            pixel_values = None
            num_patches_list = None

            # Process image if provided
            if image is not None:
                pixel_values = self.load_image(image, max_num=12)
                # Match the model's bfloat16 dtype; move to GPU only when one is available
                pixel_values = pixel_values.to(torch.bfloat16)
                if self.device == "cuda":
                    pixel_values = pixel_values.cuda()
                message = f"<image>\n{message}"

            # Process video if provided
            elif video is not None:
                pixel_values, num_patches_list = self.load_video(video, num_segments=8, max_num=1)
                pixel_values = pixel_values.to(torch.bfloat16)
                if self.device == "cuda":
                    pixel_values = pixel_values.cuda()
                video_prefix = ''.join([f'Frame{i+1}: <image>\n' for i in range(len(num_patches_list))])
                message = f"{video_prefix}{message}"

            # Convert history to the expected format
            chat_history = []
            if history:
                for item in history:
                    if len(item) == 2:
                        chat_history.append((item[0], item[1]))

            # Generate response
            if num_patches_list is not None:
                response, new_history = self.model.chat(
                    self.tokenizer,
                    pixel_values,
                    message,
                    self.generation_config,
                    num_patches_list=num_patches_list,
                    history=chat_history,
                    return_history=True
                )
            else:
                response, new_history = self.model.chat(
                    self.tokenizer,
                    pixel_values,
                    message,
                    self.generation_config,
                    history=chat_history,
                    return_history=True
                )

            # Update history
            history.append([message, response])

            return "", history, None, None

        except Exception as e:
            logger.error(f"Error in chat: {str(e)}")
            error_msg = f"Sorry, I encountered an error: {str(e)}"
            history.append([message, error_msg])
            return "", history, None, None

# Initialize the chatbot
chatbot = InternVLChatBot()

# Create Gradio interface
def create_interface():
    """Create the Gradio interface"""

    # Custom CSS for better styling
    custom_css = """
    .gradio-container {
        font-family: 'Arial', sans-serif;
    }
    .chat-message {
        padding: 10px;
        margin: 5px 0;
        border-radius: 10px;
    }
    .user-message {
        background-color: #e3f2fd;
        margin-left: 20px;
    }
    .bot-message {
        background-color: #f5f5f5;
        margin-right: 20px;
    }
    """

    with gr.Blocks(css=custom_css, title="InternVL2.5-4B Chat") as interface:
        gr.Markdown("""
# InternVL2.5-4B Multimodal Chat

Welcome to the InternVL2.5-4B chat interface! This AI assistant can:
- Have conversations with text
- Analyze and describe images
- Process and understand videos
- Extract text from images (OCR)
- Answer questions about visual content

**Instructions:**
1. Type your message in the text box
2. Optionally upload an image or video
3. Click Send to get a response
4. Use "Clear" to reset the conversation
""")

        with gr.Row():
            with gr.Column(scale=3):
                chatbot_interface = gr.Chatbot(
                    label="Chat History",
                    height=500,
                    show_copy_button=True
                )

                with gr.Row():
                    msg = gr.Textbox(
                        label="Your Message",
                        placeholder="Type your message here... You can ask about images, videos, or just chat!",
                        lines=2,
                        scale=4
                    )
                    send_btn = gr.Button("Send", scale=1, variant="primary")

                with gr.Row():
                    clear_btn = gr.Button("Clear", scale=1)

            with gr.Column(scale=1):
                gr.Markdown("### Upload Media")

                image_input = gr.Image(
                    label="Upload Image",
                    type="pil",
                    height=200
                )

                video_input = gr.Video(
                    label="Upload Video",
                    height=200
                )

                gr.Markdown("""
**Supported formats:**
- Images: JPG, PNG, WEBP, GIF
- Videos: MP4, AVI, MOV, WEBM

**Tips:**
- For images: Ask about content, extract text, or describe what you see
- For videos: Ask for descriptions, analysis, or specific details
- You can upload one media file at a time
""")

        # Example prompts
        gr.Markdown("### Example Prompts")
        with gr.Row():
            example_btn1 = gr.Button("Hello, introduce yourself")
            example_btn2 = gr.Button("Describe this image")
            example_btn3 = gr.Button("Extract text from image")
            example_btn4 = gr.Button("Analyze this video")

        # Event handlers
        def submit_message(message, history, image, video):
            if not message.strip():
                return "", history, image, video
            return chatbot.chat(message, history, image, video)

        def clear_chat():
            return [], None, None

        def set_example_prompt(prompt):
            return prompt

        # Wire up the interface
        send_btn.click(
            fn=submit_message,
            inputs=[msg, chatbot_interface, image_input, video_input],
            outputs=[msg, chatbot_interface, image_input, video_input]
        )

        msg.submit(
            fn=submit_message,
            inputs=[msg, chatbot_interface, image_input, video_input],
            outputs=[msg, chatbot_interface, image_input, video_input]
        )

        clear_btn.click(
            fn=clear_chat,
            outputs=[chatbot_interface, image_input, video_input]
        )

        # Example button handlers
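        # Passing gr.State("...") as an input supplies a fixed string to
        # set_example_prompt, which simply copies it into the message box.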
        example_btn1.click(
            fn=set_example_prompt,
            inputs=[gr.State("Hello, who are you?")],
            outputs=[msg]
        )

        example_btn2.click(
            fn=set_example_prompt,
            inputs=[gr.State("Please describe this image in detail.")],
            outputs=[msg]
        )

        example_btn3.click(
            fn=set_example_prompt,
            inputs=[gr.State("Extract the exact text provided in the image.")],
            outputs=[msg]
        )

        example_btn4.click(
            fn=set_example_prompt,
            inputs=[gr.State("Describe this video in detail.")],
            outputs=[msg]
        )

        # Footer
        gr.Markdown("""
---
**About InternVL2.5-4B:** A powerful multimodal AI model developed by Shanghai AI Lab, Tsinghua University and partners.

**API Usage:** This interface supports API calls. The chat endpoint accepts JSON with `message`, `image`, and `video` fields.
""")

    return interface

# API endpoint for external integrations
def api_chat(message: str, image: Optional[str] = None, video: Optional[str] = None, history: Optional[List] = None):
    """
    API endpoint for chat functionality

    Args:
        message: Text message
        image: Base64 encoded image or image path
        video: Video file path
        history: Chat history as list of [user_msg, bot_msg] pairs

    Returns:
        Dictionary with response and updated history
    """
    try:
        if history is None:
            history = []

        # Process image if provided (handle base64 or file path)
        image_obj = None
        if image:
            try:
                if image.startswith('data:image'):
                    # Handle base64 image
                    import base64
                    from io import BytesIO
                    image_data = image.split(',')[1]
                    image_bytes = base64.b64decode(image_data)
                    image_obj = Image.open(BytesIO(image_bytes))
                else:
                    # Handle file path
                    image_obj = Image.open(image)
            except Exception as e:
                logger.error(f"Error processing image: {str(e)}")

        # Chat with the model
        _, updated_history, _, _ = chatbot.chat(message, history, image_obj, video)

        return {
            "response": updated_history[-1][1] if updated_history else "",
            "history": updated_history,
            "status": "success"
        }
    except Exception as e:
        logger.error(f"API Error: {str(e)}")
        return {
            "response": f"Error: {str(e)}",
            "history": history,
            "status": "error"
        }

if __name__ == "__main__":
    # Create and launch the interface
    interface = create_interface()

    # Launch with API access enabled (request queueing is on by default in Gradio 4)
    interface.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True,
        show_api=True
    )
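As a quick sanity check, the `api_chat` helper above can be imported and called directly from Python; a minimal sketch (the prompt and image path are placeholder values):

    from app import api_chat

    result = api_chat("Describe this image in detail.", image="example.jpg")
    print(result["status"], result["response"])

Note that importing `app` loads the model once at import time (the `chatbot = InternVLChatBot()` line), so the first call is slow.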
requirements.txt
ADDED
@@ -0,0 +1,10 @@
torch>=2.0.0
transformers>=4.37.0
gradio>=4.0.0
torchvision>=0.15.0
pillow>=9.0.0
numpy>=1.21.0
decord>=0.6.0
accelerate>=0.20.0
bitsandbytes>=0.41.0
flash-attn>=2.3.0