LukasHug committed
Commit 9f0f2b7 · verified · 1 Parent(s): bd758e2

Update app.py

Files changed (1): app.py (+78, -59)
app.py CHANGED
@@ -12,7 +12,7 @@ import torch
 from PIL import Image
 from transformers import (
     AutoProcessor,
-    AutoTokenizer,
+    AutoTokenizer,
     Qwen2_5_VLForConditionalGeneration,
     LlavaOnevisionForConditionalGeneration
 )
@@ -74,16 +74,16 @@ class SimpleConversation:
     def to_gradio_chatbot(self):
         if not self.messages:
             return []
-
+
         ret = []
         for msg in self.messages:
             prompt = msg[0]
             if isinstance(prompt, tuple) and len(prompt) > 0:
                 prompt = prompt[0]
-
+
             if prompt and isinstance(prompt, str) and "<image>" in prompt:
                 prompt = prompt.replace("<image>", "")
-
+
             ret.append([prompt, msg[1]])
         return ret

@@ -123,6 +123,7 @@ class SimpleConversation:
         new_conv.messages = self.messages.copy() if self.messages else []
         return new_conv

+
 default_conversation = SimpleConversation()

 # Model and processor storage
@@ -131,55 +132,56 @@ model = None
 processor = None
 context_len = 8048

+
 def wrap_taxonomy(text):
     """Wraps user input with taxonomy if not already present"""
     if policy_v1 not in text:
         return policy_v1 + "\n\n" + text
     return text

+
 # UI component states
 no_change_btn = gr.Button()
 enable_btn = gr.Button(interactive=True)
 disable_btn = gr.Button(interactive=False)

+
 # Model loading function
 def load_model(model_path):
     global tokenizer, model, processor, context_len
-
+
     logger.info(f"Loading model: {model_path}")
-
+
     try:
         # Check if it's a Qwen model
         if "qwenguard" in model_path.lower():
             model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
-                model_path,
-                # torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+                model_path,
                 torch_dtype="auto",
                 device_map="auto"
             )
             processor = AutoProcessor.from_pretrained(model_path)
             tokenizer = processor.tokenizer
-
+
         # Otherwise assume it's a LlavaGuard model
         else:
             model = LlavaOnevisionForConditionalGeneration.from_pretrained(
                 model_path,
-                # torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                 torch_dtype="auto",
                 device_map="auto",
                 trust_remote_code=True
             )
             tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
             processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
-
+
         context_len = getattr(model.config, "max_position_embeddings", 8048)
-        logger.info(f"Model {model_path} loaded successfully to device: {model.device}")
-        model = model.to("cuda")
-        return True
+        logger.info(f"Model {model_path} loaded successfully")
+        return  # Remove return value to avoid Gradio warnings

     except Exception as e:
         logger.error(f"Error loading model {model_path}: {str(e)}")
-        return False
+        return  # Remove return value to avoid Gradio warnings
+

 def get_model_list():
     models = [
@@ -190,17 +192,19 @@ def get_model_list():
     ]
     return models

+
 def get_conv_log_filename():
     t = datetime.datetime.now()
     name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json")
     os.makedirs(os.path.dirname(name), exist_ok=True)
     return name

+
 # Inference function
 @spaces.GPU
 def run_inference(prompt, image, temperature=0.2, top_p=0.95, max_tokens=512):
     global model, tokenizer, processor
-
+
     if model is None or processor is None:
         return "Model not loaded. Please select a model first."
     try:
@@ -227,7 +231,7 @@ def run_inference(prompt, image, temperature=0.2, top_p=0.95, max_tokens=512):
                 return_tensors="pt",
             )

-
+
         # Otherwise assume it's a LlavaGuard model
         else:
             conversation = [
@@ -272,6 +276,7 @@ def run_inference(prompt, image, temperature=0.2, top_p=0.95, max_tokens=512):
         logger.error(error_msg)
         return f"Error processing image. Please try again."

+
 # Gradio UI functions
 get_window_url_params = """
 function() {
@@ -282,10 +287,11 @@
 }
 """

+
 def load_demo(url_params, request: gr.Request):
     logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")
     models = get_model_list()
-
+
     dropdown_update = gr.Dropdown(visible=True)
     if "model" in url_params:
         model = url_params["model"]
@@ -296,6 +302,7 @@ def load_demo(url_params, request: gr.Request):
     state = default_conversation.copy()
     return state, dropdown_update

+
 def load_demo_refresh_model_list(request: gr.Request):
     logger.info(f"load_demo. ip: {request.client.host}")
     models = get_model_list()
@@ -306,6 +313,7 @@ def load_demo_refresh_model_list(request: gr.Request):
     )
     return state, dropdown_update

+
 def vote_last_response(state, vote_type, model_selector, request: gr.Request):
     with open(get_conv_log_filename(), "a") as fout:
         data = {
@@ -317,21 +325,25 @@ def vote_last_response(state, vote_type, model_selector, request: gr.Request):
         }
         fout.write(json.dumps(data) + "\n")

+
 def upvote_last_response(state, model_selector, request: gr.Request):
     logger.info(f"upvote. ip: {request.client.host}")
     vote_last_response(state, "upvote", model_selector, request)
     return ("",) + (disable_btn,) * 3

+
 def downvote_last_response(state, model_selector, request: gr.Request):
     logger.info(f"downvote. ip: {request.client.host}")
     vote_last_response(state, "downvote", model_selector, request)
     return ("",) + (disable_btn,) * 3

+
 def flag_last_response(state, model_selector, request: gr.Request):
     logger.info(f"flag. ip: {request.client.host}")
     vote_last_response(state, "flag", model_selector, request)
     return ("",) + (disable_btn,) * 3

+
 def regenerate(state, image_process_mode, request: gr.Request):
     logger.info(f"regenerate. ip: {request.client.host}")
     if state.messages and len(state.messages) > 0:
@@ -344,15 +356,17 @@ def regenerate(state, image_process_mode, request: gr.Request):
         if len(prev_human_msg[0]) >= 3:
             new_msg[0] = (prev_human_msg[0][0], prev_human_msg[0][1], image_process_mode)
         state.messages[-2] = new_msg
-
+
     state.skip_next = False
     return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5

+
 def clear_history(request: gr.Request):
     logger.info(f"clear_history. ip: {request.client.host}")
     state = default_conversation.copy()
     return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5

+
 def add_text(state, text, image, image_process_mode, request: gr.Request):
     logger.info(f"add_text. ip: {request.client.host}. len: {len(text)}")
     if len(text) <= 0 or image is None:
@@ -360,24 +374,25 @@ def add_text(state, text, image, image_process_mode, request: gr.Request):
         return (state, state.to_gradio_chatbot(), "", None) + (no_change_btn,) * 5

     text = wrap_taxonomy(text)
-
+
     # Reset conversation for new image-based query
     if image is not None:
         state = default_conversation.copy()
-
+
     # Set new prompt with image
     prompt = text
     if image is not None:
         prompt = (text, image, image_process_mode)
-
+
     state.set_prompt(prompt=prompt, image=image)
     state.skip_next = False
-
+
     return (state, state.to_gradio_chatbot(), default_taxonomy, None) + (disable_btn,) * 5

+
 def llava_bot(state, model_selector, temperature, top_p, max_new_tokens, request: gr.Request):
     start_tstamp = time.time()
-
+
     if state.skip_next:
         # This generate call is skipped due to invalid inputs
         yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
@@ -386,7 +401,7 @@ def llava_bot(state, model_selector, temperature, top_p, max_new_tokens, request
     # Get the prompt and images
     prompt = state.get_prompt()
     all_images = state.get_image(return_pil=True)
-
+
     if not all_images:
         if not state.messages:
             state.messages = [["Error: No image provided", None]]
@@ -394,14 +409,14 @@ def llava_bot(state, model_selector, temperature, top_p, max_new_tokens, request
             state.messages[-1][-1] = "Error: No image provided"
         yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
         return
-
+
     # Load model if needed
     if model is None or model_selector != getattr(model, "_name_or_path", ""):
         load_model(model_selector)
-
+
     # Run inference
     output = run_inference(prompt, all_images[0], temperature, top_p, max_new_tokens)
-
+
     # Update the response in the conversation state
     if not state.messages:
         state.messages = [[prompt, output]]
@@ -430,6 +445,7 @@ def llava_bot(state, model_selector, temperature, top_p, max_new_tokens, request
     except Exception as e:
         logger.error(f"Error writing log: {str(e)}")

+
 # UI Components
 title_markdown = """
 # LLAVAGUARD: VLM-based Safeguard for Vision Dataset Curation and Safety Assessment
@@ -459,6 +475,7 @@ block_css = """
 }
 """

+
 def build_demo(embed_mode, cur_dir=None, concurrency_count=10):
     models = get_model_list()

@@ -486,17 +503,18 @@ def build_demo(embed_mode, cur_dir=None, concurrency_count=10):

             if cur_dir is None:
                 cur_dir = os.path.dirname(os.path.abspath(__file__))
-
+
             gr.Examples(examples=[
-                [f"{cur_dir}/examples/image{i}.png"] for i in range(1, 6) if os.path.exists(f"{cur_dir}/examples/image{i}.png")
+                [f"{cur_dir}/examples/image{i}.png"] for i in range(1, 6) if
+                os.path.exists(f"{cur_dir}/examples/image{i}.png")
             ], inputs=imagebox)

             with gr.Accordion("Parameters", open=False) as parameter_row:
                 temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.2, step=0.1, interactive=True,
-                                        label="Temperature")
+                                        label="Temperature")
                 top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.95, step=0.1, interactive=True, label="Top P")
                 max_output_tokens = gr.Slider(minimum=0, maximum=1024, value=512, step=64, interactive=True,
-                                              label="Max output tokens")
+                                              label="Max output tokens")

             with gr.Accordion("Safety Risk Taxonomy", open=False):
                 taxonomy_textbox = gr.Textbox(
@@ -538,25 +556,25 @@ def build_demo(embed_mode, cur_dir=None, concurrency_count=10):

         # Register listeners
         btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]
-
+
         upvote_btn.click(
             upvote_last_response,
             [state, model_selector],
             [textbox, upvote_btn, downvote_btn, flag_btn]
         )
-
+
         downvote_btn.click(
             downvote_last_response,
             [state, model_selector],
             [textbox, upvote_btn, downvote_btn, flag_btn]
         )
-
+
         flag_btn.click(
             flag_last_response,
             [state, model_selector],
             [textbox, upvote_btn, downvote_btn, flag_btn]
         )
-
+
         model_selector.change(
             load_model,
             [model_selector],
@@ -626,38 +644,39 @@ if __name__ == "__main__":

     # Create log directory if it doesn't exist
     os.makedirs(LOGDIR, exist_ok=True)
-
+
     # GPU Check
     if torch.cuda.is_available():
         logger.info(f"CUDA available with {torch.cuda.device_count()} devices")
     else:
         logger.warning("CUDA not available! Models will run on CPU which may be very slow.")
-
+
     # Hugging Face token handling
     api_key = os.getenv("token")
     if api_key:
         from huggingface_hub import login
+
         login(token=api_key)
         logger.info("Logged in to Hugging Face Hub")
-
-    # Load initial model
-    models = get_model_list()
-    model_path = os.getenv("model", models[0])
-    logger.info(f"Initial model selected: {model_path}")
-    load_model(model_path)
-
-    # Launch Gradio app
-    try:
-        demo = build_demo(embed_mode=args.embed, cur_dir='./', concurrency_count=args.concurrency_count)
-        demo.queue(
-            status_update_rate=10,
-            api_open=False
-        ).launch(
-            server_name=args.host,
-            server_port=args.port,
-            share=args.share
-        )
-    except Exception as e:
-        logger.error(f"Error launching demo: {e}")
-        sys.exit(1)

+    # Launch Gradio app in a subprocess to avoid CUDA initialization in the main process
+    from torch.multiprocessing import Process
+
+    def launch_demo():
+        try:
+            demo = build_demo(embed_mode=args.embed, cur_dir='./', concurrency_count=args.concurrency_count)
+            demo.queue(
+                status_update_rate=10,
+                api_open=False
+            ).launch(
+                server_name=args.host,
+                server_port=args.port,
+                share=args.share
+            )
+        except Exception as e:
+            logger.error(f"Error launching demo: {e}")
+            sys.exit(1)
+
+    p = Process(target=launch_demo)
+    p.start()
+    p.join()
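
For reference, a minimal, self-contained sketch of the launch pattern adopted at the end of the diff: the Gradio app is built and served inside a torch.multiprocessing child process while the parent only waits on join(), which, as the in-diff comment suggests, keeps CUDA from being initialized in the main process. The helper name build_demo_stub and the server settings below are illustrative placeholders, not part of app.py.

import sys

import gradio as gr
from torch.multiprocessing import Process


def build_demo_stub():
    # Stand-in for build_demo(); any gr.Blocks app works here.
    with gr.Blocks() as demo:
        gr.Markdown("demo placeholder")
    return demo


def launch_demo():
    try:
        # queue() and launch() run the HTTP server entirely inside this child process,
        # so libraries that touch CUDA lazily are never initialized in the parent.
        demo = build_demo_stub()
        demo.queue(api_open=False).launch(server_name="0.0.0.0", server_port=7860, share=False)
    except Exception as e:
        print(f"Error launching demo: {e}")
        sys.exit(1)


if __name__ == "__main__":
    p = Process(target=launch_demo)  # child process owns the Gradio server
    p.start()
    p.join()  # parent blocks until the server process exits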