Abhinav Singh committed on
Commit
8464b63
·
1 Parent(s): 0ebbc88
Files changed (2)
  1. app.py +460 -0
  2. requirements.txt +10 -0
app.py ADDED
@@ -0,0 +1,460 @@
+ import os
+ import torch
+ import gradio as gr
+ import numpy as np
+ from PIL import Image
+ import torchvision.transforms as T
+ from torchvision.transforms.functional import InterpolationMode
+ from transformers import AutoTokenizer, AutoModel
+ from decord import VideoReader, cpu
+ import tempfile
+ import json
+ from typing import List, Tuple, Optional, Union
+ import logging
+
+ # Configure logging
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ # Constants
+ IMAGENET_MEAN = (0.485, 0.456, 0.406)
+ IMAGENET_STD = (0.229, 0.224, 0.225)
+ MODEL_PATH = "OpenGVLab/InternVL2_5-4B"
+
+ class InternVLChatBot:
+     def __init__(self):
+         self.model = None
+         self.tokenizer = None
+         self.device = "cuda" if torch.cuda.is_available() else "cpu"
+         self.generation_config = dict(max_new_tokens=1024, do_sample=True)
+         self.load_model()
+
+     def load_model(self):
+         """Load the InternVL model and tokenizer"""
+         try:
+             logger.info("Loading InternVL2.5-4B model...")
+             self.model = AutoModel.from_pretrained(
+                 MODEL_PATH,
+                 torch_dtype=torch.bfloat16,
+                 low_cpu_mem_usage=True,
+                 trust_remote_code=True,
+                 use_flash_attn=True if self.device == "cuda" else False,
+                 device_map="auto" if self.device == "cuda" else None
+             )
+             self.tokenizer = AutoTokenizer.from_pretrained(
+                 MODEL_PATH, trust_remote_code=True
+             )
+             logger.info("Model loaded successfully!")
+         except Exception as e:
+             logger.error(f"Error loading model: {str(e)}")
+             raise e
+
+     def find_closest_aspect_ratio(self, aspect_ratio, target_ratios, width, height, image_size):
+         """Find the closest aspect ratio from target ratios"""
+         best_ratio_diff = float('inf')
+         best_ratio = (1, 1)
+         area = width * height
+
+         for ratio in target_ratios:
+             target_aspect_ratio = ratio[0] / ratio[1]
+             ratio_diff = abs(aspect_ratio - target_aspect_ratio)
+             if ratio_diff < best_ratio_diff:
+                 best_ratio_diff = ratio_diff
+                 best_ratio = ratio
+             elif ratio_diff == best_ratio_diff:
+                 if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
+                     best_ratio = ratio
+         return best_ratio
+
+     def dynamic_preprocess(self, image, min_num=1, max_num=12, image_size=448, use_thumbnail=False):
+         """Dynamically preprocess image based on aspect ratio"""
+         orig_width, orig_height = image.size
+         aspect_ratio = orig_width / orig_height
+
+         # Calculate target ratios
+         target_ratios = set(
+             (i, j) for n in range(min_num, max_num + 1)
+             for i in range(1, n + 1)
+             for j in range(1, n + 1)
+             if i * j <= max_num and i * j >= min_num
+         )
+         target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
+
+         # Find closest aspect ratio
+         target_aspect_ratio = self.find_closest_aspect_ratio(
+             aspect_ratio, target_ratios, orig_width, orig_height, image_size
+         )
+
+         # Calculate target dimensions
+         target_width = image_size * target_aspect_ratio[0]
+         target_height = image_size * target_aspect_ratio[1]
+         blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
+
+         # Resize and split image
+         resized_img = image.resize((target_width, target_height))
+         processed_images = []
+
+         for i in range(blocks):
+             box = (
+                 (i % (target_width // image_size)) * image_size,
+                 (i // (target_width // image_size)) * image_size,
+                 ((i % (target_width // image_size)) + 1) * image_size,
+                 ((i // (target_width // image_size)) + 1) * image_size
+             )
+             split_img = resized_img.crop(box)
+             processed_images.append(split_img)
+
+         if use_thumbnail and len(processed_images) != 1:
+             thumbnail_img = image.resize((image_size, image_size))
+             processed_images.append(thumbnail_img)
+
+         return processed_images
+
+     def build_transform(self, input_size):
+         """Build image transformation pipeline"""
+         transform = T.Compose([
+             T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
+             T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
+             T.ToTensor(),
+             T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
+         ])
+         return transform
+
+     def load_image(self, image_path, input_size=448, max_num=12):
+         """Load and preprocess image"""
+         if isinstance(image_path, str):
+             image = Image.open(image_path).convert('RGB')
+         else:
+             image = image_path.convert('RGB')
+
+         transform = self.build_transform(input_size=input_size)
+         images = self.dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
+         pixel_values = [transform(img) for img in images]
+         pixel_values = torch.stack(pixel_values)
+         return pixel_values
+
+     def get_index(self, bound, fps, max_frame, first_idx=0, num_segments=32):
+         """Get frame indices for video processing"""
+         if bound:
+             start, end = bound[0], bound[1]
+         else:
+             start, end = -100000, 100000
+
+         start_idx = max(first_idx, round(start * fps))
+         end_idx = min(round(end * fps), max_frame)
+         seg_size = float(end_idx - start_idx) / num_segments
+
+         frame_indices = np.array([
+             int(start_idx + (seg_size / 2) + np.round(seg_size * idx))
+             for idx in range(num_segments)
+         ])
+         return frame_indices
+
+     def load_video(self, video_path, bound=None, input_size=448, max_num=1, num_segments=32):
+         """Load and preprocess video"""
+         vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
+         max_frame = len(vr) - 1
+         fps = float(vr.get_avg_fps())
+
+         pixel_values_list, num_patches_list = [], []
+         transform = self.build_transform(input_size=input_size)
+         frame_indices = self.get_index(bound, fps, max_frame, first_idx=0, num_segments=num_segments)
+
+         for frame_index in frame_indices:
+             img = Image.fromarray(vr[frame_index].asnumpy()).convert('RGB')
+             img = self.dynamic_preprocess(img, image_size=input_size, use_thumbnail=True, max_num=max_num)
+             pixel_values = [transform(tile) for tile in img]
+             pixel_values = torch.stack(pixel_values)
+             num_patches_list.append(pixel_values.shape[0])
+             pixel_values_list.append(pixel_values)
+
+         pixel_values = torch.cat(pixel_values_list)
+         return pixel_values, num_patches_list
+
+     def chat(self, message, history, image=None, video=None):
+         """Main chat function"""
+         try:
+             pixel_values = None
+             num_patches_list = None
+
+             # Process image if provided
+             if image is not None:
+                 pixel_values = self.load_image(image, max_num=12)
+                 if self.device == "cuda":
+                     pixel_values = pixel_values.to(torch.bfloat16).cuda()
+                 message = f"<image>\n{message}"
+
+             # Process video if provided
+             elif video is not None:
+                 pixel_values, num_patches_list = self.load_video(video, num_segments=8, max_num=1)
+                 if self.device == "cuda":
+                     pixel_values = pixel_values.to(torch.bfloat16).cuda()
+                 video_prefix = ''.join([f'Frame{i+1}: <image>\n' for i in range(len(num_patches_list))])
+                 message = f"{video_prefix}{message}"
+
+             # Convert history to the expected format
+             chat_history = []
+             if history:
+                 for item in history:
+                     if len(item) == 2:
+                         chat_history.append((item[0], item[1]))
+
+             # Generate response
+             if num_patches_list is not None:
+                 response, new_history = self.model.chat(
+                     self.tokenizer,
+                     pixel_values,
+                     message,
+                     self.generation_config,
+                     num_patches_list=num_patches_list,
+                     history=chat_history,
+                     return_history=True
+                 )
+             else:
+                 response, new_history = self.model.chat(
+                     self.tokenizer,
+                     pixel_values,
+                     message,
+                     self.generation_config,
+                     history=chat_history,
+                     return_history=True
+                 )
+
+             # Update history
+             history.append([message, response])
+
+             return "", history, None, None
+
+         except Exception as e:
+             logger.error(f"Error in chat: {str(e)}")
+             error_msg = f"Sorry, I encountered an error: {str(e)}"
+             history.append([message, error_msg])
+             return "", history, None, None
+
+ # Initialize the chatbot
+ chatbot = InternVLChatBot()
+
+ # Create Gradio interface
+ def create_interface():
+     """Create the Gradio interface"""
+
+     # Custom CSS for better styling
+     custom_css = """
+     .gradio-container {
+         font-family: 'Arial', sans-serif;
+     }
+     .chat-message {
+         padding: 10px;
+         margin: 5px 0;
+         border-radius: 10px;
+     }
+     .user-message {
+         background-color: #e3f2fd;
+         margin-left: 20px;
+     }
+     .bot-message {
+         background-color: #f5f5f5;
+         margin-right: 20px;
+     }
+     """
+
+     with gr.Blocks(css=custom_css, title="InternVL2.5-4B Chat") as interface:
+         gr.Markdown("""
+         # 🤖 InternVL2.5-4B Multimodal Chat
+
+         Welcome to the InternVL2.5-4B chat interface! This AI assistant can:
+         - 💬 Have conversations with text
+         - 🖼️ Analyze and describe images
+         - 🎥 Process and understand videos
+         - 📝 Extract text from images (OCR)
+         - 🎯 Answer questions about visual content
+
+         **Instructions:**
+         1. Type your message in the text box
+         2. Optionally upload an image or video
+         3. Click Send to get a response
+         4. Use "Clear" to reset the conversation
+         """)
+
+         with gr.Row():
+             with gr.Column(scale=3):
+                 chatbot_interface = gr.Chatbot(
+                     label="Chat History",
+                     height=500,
+                     show_copy_button=True,
+                     avatar_images=["👤", "🤖"]
+                 )
+
+                 with gr.Row():
+                     msg = gr.Textbox(
+                         label="Your Message",
+                         placeholder="Type your message here... You can ask about images, videos, or just chat!",
+                         lines=2,
+                         scale=4
+                     )
+                     send_btn = gr.Button("Send 📤", scale=1, variant="primary")
+
+                 with gr.Row():
+                     clear_btn = gr.Button("Clear 🗑️", scale=1)
+
+             with gr.Column(scale=1):
+                 gr.Markdown("### 📎 Upload Media")
+
+                 image_input = gr.Image(
+                     label="Upload Image",
+                     type="pil",
+                     height=200
+                 )
+
+                 video_input = gr.Video(
+                     label="Upload Video",
+                     height=200
+                 )
+
+                 gr.Markdown("""
+                 **Supported formats:**
+                 - Images: JPG, PNG, WEBP, GIF
+                 - Videos: MP4, AVI, MOV, WEBM
+
+                 **Tips:**
+                 - For images: Ask about content, extract text, or describe what you see
+                 - For videos: Ask for descriptions, analysis, or specific details
+                 - You can upload one media file at a time
+                 """)
+
+         # Example prompts
+         gr.Markdown("### 💡 Example Prompts")
+         with gr.Row():
+             example_btn1 = gr.Button("👋 Hello, introduce yourself")
+             example_btn2 = gr.Button("🖼️ Describe this image")
+             example_btn3 = gr.Button("📝 Extract text from image")
+             example_btn4 = gr.Button("🎥 Analyze this video")
+
+         # Event handlers
+         def submit_message(message, history, image, video):
+             if not message.strip():
+                 return "", history, image, video
+             return chatbot.chat(message, history, image, video)
+
+         def clear_chat():
+             return [], None, None
+
+         def set_example_prompt(prompt):
+             return prompt
+
+         # Wire up the interface
+         send_btn.click(
+             fn=submit_message,
+             inputs=[msg, chatbot_interface, image_input, video_input],
+             outputs=[msg, chatbot_interface, image_input, video_input]
+         )
+
+         msg.submit(
+             fn=submit_message,
+             inputs=[msg, chatbot_interface, image_input, video_input],
+             outputs=[msg, chatbot_interface, image_input, video_input]
+         )
+
+         clear_btn.click(
+             fn=clear_chat,
+             outputs=[chatbot_interface, image_input, video_input]
+         )
+
+         # Example button handlers
+         example_btn1.click(
+             fn=set_example_prompt,
+             inputs=[gr.State("Hello, who are you?")],
+             outputs=[msg]
+         )
+
+         example_btn2.click(
+             fn=set_example_prompt,
+             inputs=[gr.State("Please describe this image in detail.")],
+             outputs=[msg]
+         )
+
+         example_btn3.click(
+             fn=set_example_prompt,
+             inputs=[gr.State("Extract the exact text provided in the image.")],
+             outputs=[msg]
+         )
+
+         example_btn4.click(
+             fn=set_example_prompt,
+             inputs=[gr.State("Describe this video in detail.")],
+             outputs=[msg]
+         )
+
+         # Footer
+         gr.Markdown("""
+         ---
+         **About InternVL2.5-4B:** A powerful multimodal AI model developed by Shanghai AI Lab, Tsinghua University and partners.
+
+         **API Usage:** This interface supports API calls. The chat endpoint accepts JSON with `message`, `image`, and `video` fields.
+         """)
+
+     return interface
+
+ # API endpoint for external integrations
+ def api_chat(message: str, image: Optional[str] = None, video: Optional[str] = None, history: Optional[List] = None):
+     """
+     API endpoint for chat functionality
+
+     Args:
+         message: Text message
+         image: Base64 encoded image or image path
+         video: Video file path
+         history: Chat history as list of [user_msg, bot_msg] pairs
+
+     Returns:
+         Dictionary with response and updated history
+     """
+     try:
+         if history is None:
+             history = []
+
+         # Process image if provided (handle base64 or file path)
+         image_obj = None
+         if image:
+             try:
+                 if image.startswith('data:image'):
+                     # Handle base64 image
+                     import base64
+                     from io import BytesIO
+                     image_data = image.split(',')[1]
+                     image_bytes = base64.b64decode(image_data)
+                     image_obj = Image.open(BytesIO(image_bytes))
+                 else:
+                     # Handle file path
+                     image_obj = Image.open(image)
+             except Exception as e:
+                 logger.error(f"Error processing image: {str(e)}")
+
+         # Chat with the model
+         _, updated_history, _, _ = chatbot.chat(message, history, image_obj, video)
+
+         return {
+             "response": updated_history[-1][1] if updated_history else "",
+             "history": updated_history,
+             "status": "success"
+         }
+     except Exception as e:
+         logger.error(f"API Error: {str(e)}")
+         return {
+             "response": f"Error: {str(e)}",
+             "history": history,
+             "status": "error"
+         }
+
+ if __name__ == "__main__":
+     # Create and launch the interface
+     interface = create_interface()
+
+     # Queue requests; the old enable_queue launch argument was removed in Gradio 4.x
+     interface.queue()
+
+     # Launch with API access enabled
+     interface.launch(
+         server_name="0.0.0.0",
+         server_port=7860,
+         share=True,
+         show_api=True
+     )
requirements.txt ADDED
@@ -0,0 +1,10 @@
+ torch>=2.0.0
+ transformers>=4.37.0
+ gradio>=4.0.0
+ torchvision>=0.15.0
+ pillow>=9.0.0
+ numpy>=1.21.0
+ decord>=0.6.0
+ accelerate>=0.20.0
+ bitsandbytes>=0.41.0
+ flash-attn>=2.3.0
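
A minimal usage sketch for the `api_chat` helper added in `app.py` above, assuming the module is importable from the working directory and using a hypothetical local file `example.jpg` as the image (both are illustrative assumptions, not part of the commit):

import base64
from app import api_chat  # assumption: app.py is on the import path; importing it loads the model

# Build a data URI so api_chat takes its base64 branch (it checks startswith('data:image'))
with open("example.jpg", "rb") as f:  # hypothetical sample image
    data_uri = "data:image/jpeg;base64," + base64.b64encode(f.read()).decode()

result = api_chat("What is shown in this picture?", image=data_uri)
print(result["status"])
print(result["response"])

Per its docstring, the same helper also accepts a plain file path for `image`, a `video` path, and an optional `history` list of [user_msg, bot_msg] pairs.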