Spaces: Running on Zero
Update app.py
app.py
CHANGED
@@ -12,7 +12,7 @@ import torch
 from PIL import Image
 from transformers import (
     AutoProcessor,
-    AutoTokenizer,
+    AutoTokenizer,
     Qwen2_5_VLForConditionalGeneration,
     LlavaOnevisionForConditionalGeneration
 )
@@ -74,16 +74,16 @@ class SimpleConversation:
     def to_gradio_chatbot(self):
         if not self.messages:
             return []
-
+
         ret = []
         for msg in self.messages:
             prompt = msg[0]
             if isinstance(prompt, tuple) and len(prompt) > 0:
                 prompt = prompt[0]
-
+
             if prompt and isinstance(prompt, str) and "<image>" in prompt:
                 prompt = prompt.replace("<image>", "")
-
+
             ret.append([prompt, msg[1]])
         return ret
 
@@ -123,6 +123,7 @@ class SimpleConversation:
         new_conv.messages = self.messages.copy() if self.messages else []
         return new_conv
 
+
 default_conversation = SimpleConversation()
 
 # Model and processor storage
@@ -131,55 +132,56 @@ model = None
 processor = None
 context_len = 8048
 
+
 def wrap_taxonomy(text):
     """Wraps user input with taxonomy if not already present"""
     if policy_v1 not in text:
         return policy_v1 + "\n\n" + text
     return text
 
+
 # UI component states
 no_change_btn = gr.Button()
 enable_btn = gr.Button(interactive=True)
 disable_btn = gr.Button(interactive=False)
 
+
 # Model loading function
 def load_model(model_path):
     global tokenizer, model, processor, context_len
-
+
     logger.info(f"Loading model: {model_path}")
-
+
     try:
         # Check if it's a Qwen model
         if "qwenguard" in model_path.lower():
             model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
-                model_path,
-                # torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+                model_path,
                 torch_dtype="auto",
                 device_map="auto"
             )
             processor = AutoProcessor.from_pretrained(model_path)
             tokenizer = processor.tokenizer
-
+
         # Otherwise assume it's a LlavaGuard model
         else:
             model = LlavaOnevisionForConditionalGeneration.from_pretrained(
                 model_path,
-                # torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                 torch_dtype="auto",
                 device_map="auto",
                 trust_remote_code=True
             )
             tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
             processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
-
+
         context_len = getattr(model.config, "max_position_embeddings", 8048)
-        logger.info(f"Model {model_path} loaded successfully")
-
-        return True
+        logger.info(f"Model {model_path} loaded successfully")
+        return  # Remove return value to avoid Gradio warnings
 
     except Exception as e:
         logger.error(f"Error loading model {model_path}: {str(e)}")
-        return
+        return  # Remove return value to avoid Gradio warnings
+
 
 def get_model_list():
     models = [
@@ -190,17 +192,19 @@ def get_model_list():
     ]
     return models
 
+
 def get_conv_log_filename():
     t = datetime.datetime.now()
     name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json")
     os.makedirs(os.path.dirname(name), exist_ok=True)
     return name
 
+
 # Inference function
 @spaces.GPU
 def run_inference(prompt, image, temperature=0.2, top_p=0.95, max_tokens=512):
     global model, tokenizer, processor
-
+
     if model is None or processor is None:
         return "Model not loaded. Please select a model first."
     try:
@@ -227,7 +231,7 @@ def run_inference(prompt, image, temperature=0.2, top_p=0.95, max_tokens=512):
                 return_tensors="pt",
             )
 
-
+
         # Otherwise assume it's a LlavaGuard model
         else:
             conversation = [
@@ -272,6 +276,7 @@ def run_inference(prompt, image, temperature=0.2, top_p=0.95, max_tokens=512):
         logger.error(error_msg)
         return f"Error processing image. Please try again."
 
+
 # Gradio UI functions
 get_window_url_params = """
 function() {
@@ -282,10 +287,11 @@ function() {
 }
 """
 
+
 def load_demo(url_params, request: gr.Request):
     logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")
     models = get_model_list()
-
+
     dropdown_update = gr.Dropdown(visible=True)
     if "model" in url_params:
         model = url_params["model"]
@@ -296,6 +302,7 @@ def load_demo(url_params, request: gr.Request):
     state = default_conversation.copy()
     return state, dropdown_update
 
+
 def load_demo_refresh_model_list(request: gr.Request):
     logger.info(f"load_demo. ip: {request.client.host}")
     models = get_model_list()
@@ -306,6 +313,7 @@ def load_demo_refresh_model_list(request: gr.Request):
     )
     return state, dropdown_update
 
+
 def vote_last_response(state, vote_type, model_selector, request: gr.Request):
     with open(get_conv_log_filename(), "a") as fout:
         data = {
@@ -317,21 +325,25 @@ def vote_last_response(state, vote_type, model_selector, request: gr.Request):
         }
         fout.write(json.dumps(data) + "\n")
 
+
 def upvote_last_response(state, model_selector, request: gr.Request):
     logger.info(f"upvote. ip: {request.client.host}")
     vote_last_response(state, "upvote", model_selector, request)
     return ("",) + (disable_btn,) * 3
 
+
 def downvote_last_response(state, model_selector, request: gr.Request):
     logger.info(f"downvote. ip: {request.client.host}")
     vote_last_response(state, "downvote", model_selector, request)
     return ("",) + (disable_btn,) * 3
 
+
 def flag_last_response(state, model_selector, request: gr.Request):
     logger.info(f"flag. ip: {request.client.host}")
     vote_last_response(state, "flag", model_selector, request)
     return ("",) + (disable_btn,) * 3
 
+
 def regenerate(state, image_process_mode, request: gr.Request):
     logger.info(f"regenerate. ip: {request.client.host}")
     if state.messages and len(state.messages) > 0:
@@ -344,15 +356,17 @@ def regenerate(state, image_process_mode, request: gr.Request):
         if len(prev_human_msg[0]) >= 3:
             new_msg[0] = (prev_human_msg[0][0], prev_human_msg[0][1], image_process_mode)
         state.messages[-2] = new_msg
-
+
     state.skip_next = False
     return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
 
+
 def clear_history(request: gr.Request):
     logger.info(f"clear_history. ip: {request.client.host}")
     state = default_conversation.copy()
     return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
 
+
 def add_text(state, text, image, image_process_mode, request: gr.Request):
     logger.info(f"add_text. ip: {request.client.host}. len: {len(text)}")
     if len(text) <= 0 or image is None:
@@ -360,24 +374,25 @@ def add_text(state, text, image, image_process_mode, request: gr.Request):
         return (state, state.to_gradio_chatbot(), "", None) + (no_change_btn,) * 5
 
     text = wrap_taxonomy(text)
-
+
     # Reset conversation for new image-based query
     if image is not None:
         state = default_conversation.copy()
-
+
     # Set new prompt with image
     prompt = text
     if image is not None:
         prompt = (text, image, image_process_mode)
-
+
     state.set_prompt(prompt=prompt, image=image)
     state.skip_next = False
-
+
     return (state, state.to_gradio_chatbot(), default_taxonomy, None) + (disable_btn,) * 5
 
+
 def llava_bot(state, model_selector, temperature, top_p, max_new_tokens, request: gr.Request):
     start_tstamp = time.time()
-
+
     if state.skip_next:
         # This generate call is skipped due to invalid inputs
         yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
@@ -386,7 +401,7 @@ def llava_bot(state, model_selector, temperature, top_p, max_new_tokens, request: gr.Request):
     # Get the prompt and images
     prompt = state.get_prompt()
     all_images = state.get_image(return_pil=True)
-
+
     if not all_images:
         if not state.messages:
             state.messages = [["Error: No image provided", None]]
@@ -394,14 +409,14 @@ def llava_bot(state, model_selector, temperature, top_p, max_new_tokens, request: gr.Request):
             state.messages[-1][-1] = "Error: No image provided"
         yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
         return
-
+
     # Load model if needed
     if model is None or model_selector != getattr(model, "_name_or_path", ""):
         load_model(model_selector)
-
+
     # Run inference
     output = run_inference(prompt, all_images[0], temperature, top_p, max_new_tokens)
-
+
     # Update the response in the conversation state
     if not state.messages:
         state.messages = [[prompt, output]]
@@ -430,6 +445,7 @@ def llava_bot(state, model_selector, temperature, top_p, max_new_tokens, request: gr.Request):
     except Exception as e:
         logger.error(f"Error writing log: {str(e)}")
 
+
 # UI Components
 title_markdown = """
 # LLAVAGUARD: VLM-based Safeguard for Vision Dataset Curation and Safety Assessment
@@ -459,6 +475,7 @@ block_css = """
 }
 """
 
+
 def build_demo(embed_mode, cur_dir=None, concurrency_count=10):
     models = get_model_list()
 
@@ -486,17 +503,18 @@ def build_demo(embed_mode, cur_dir=None, concurrency_count=10):
 
             if cur_dir is None:
                 cur_dir = os.path.dirname(os.path.abspath(__file__))
-
+
             gr.Examples(examples=[
-                [f"{cur_dir}/examples/image{i}.png"] for i in range(1, 6) if os.path.exists(f"{cur_dir}/examples/image{i}.png")
+                [f"{cur_dir}/examples/image{i}.png"] for i in range(1, 6) if
+                os.path.exists(f"{cur_dir}/examples/image{i}.png")
             ], inputs=imagebox)
 
             with gr.Accordion("Parameters", open=False) as parameter_row:
                 temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.2, step=0.1, interactive=True,
-                                        label="Temperature")
+                                        label="Temperature")
                 top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.95, step=0.1, interactive=True, label="Top P")
                 max_output_tokens = gr.Slider(minimum=0, maximum=1024, value=512, step=64, interactive=True,
-                                              label="Max output tokens")
+                                              label="Max output tokens")
 
             with gr.Accordion("Safety Risk Taxonomy", open=False):
                 taxonomy_textbox = gr.Textbox(
@@ -538,25 +556,25 @@ def build_demo(embed_mode, cur_dir=None, concurrency_count=10):
 
         # Register listeners
         btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]
-
+
         upvote_btn.click(
             upvote_last_response,
             [state, model_selector],
            [textbox, upvote_btn, downvote_btn, flag_btn]
         )
-
+
         downvote_btn.click(
             downvote_last_response,
             [state, model_selector],
             [textbox, upvote_btn, downvote_btn, flag_btn]
         )
-
+
        flag_btn.click(
             flag_last_response,
             [state, model_selector],
             [textbox, upvote_btn, downvote_btn, flag_btn]
         )
-
+
         model_selector.change(
             load_model,
             [model_selector],
@@ -626,38 +644,39 @@ if __name__ == "__main__":
 
     # Create log directory if it doesn't exist
     os.makedirs(LOGDIR, exist_ok=True)
-
+
    # GPU Check
     if torch.cuda.is_available():
         logger.info(f"CUDA available with {torch.cuda.device_count()} devices")
     else:
         logger.warning("CUDA not available! Models will run on CPU which may be very slow.")
-
+
     # Hugging Face token handling
     api_key = os.getenv("token")
     if api_key:
         from huggingface_hub import login
+
         login(token=api_key)
         logger.info("Logged in to Hugging Face Hub")
-
-    # Load initial model
-    models = get_model_list()
-    model_path = os.getenv("model", models[0])
-    logger.info(f"Initial model selected: {model_path}")
-    load_model(model_path)
-
-    # Launch Gradio app
-    try:
-        demo = build_demo(embed_mode=args.embed, cur_dir='./', concurrency_count=args.concurrency_count)
-        demo.queue(
-            status_update_rate=10,
-            api_open=False
-        ).launch(
-            server_name=args.host,
-            server_port=args.port,
-            share=args.share
-        )
-    except Exception as e:
-        logger.error(f"Error launching demo: {e}")
-        sys.exit(1)
 
+    # Launch Gradio app in a subprocess to avoid CUDA initialization in the main process
+    from torch.multiprocessing import Process
+
+    def launch_demo():
+        try:
+            demo = build_demo(embed_mode=args.embed, cur_dir='./', concurrency_count=args.concurrency_count)
+            demo.queue(
+                status_update_rate=10,
+                api_open=False
+            ).launch(
+                server_name=args.host,
+                server_port=args.port,
+                share=args.share
+            )
+        except Exception as e:
+            logger.error(f"Error launching demo: {e}")
+            sys.exit(1)
+
+    p = Process(target=launch_demo)
+    p.start()
+    p.join()
|