"""Gradio web demo for LlavaGuard / QwenGuard image safety assessment."""
import argparse
import datetime
import gc
import hashlib
import json
import logging
import os
import sys
import time

import gradio as gr
import torch
from PIL import Image
from transformers import (
    AutoProcessor,
    AutoTokenizer,
    Qwen2_5_VLForConditionalGeneration,
    LlavaOnevisionForConditionalGeneration,
)

from taxonomy import policy_v1

# Set up logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler("gradio_web_server.log"),
        logging.StreamHandler(),
    ],
)
logger = logging.getLogger("gradio_web_server")

# Constants
LOGDIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "logs")
os.makedirs(os.path.join(LOGDIR, "serve_images"), exist_ok=True)
default_taxonomy = policy_v1


class Conversation:
    """Minimal chat history for the single-turn safety-assessment demo.

    Each entry of ``messages`` is a mutable ``[role, message]`` pair. A user
    message is either a plain string or a ``(text, PIL.Image, image_process_mode)``
    tuple (as built by ``add_text``); an assistant message is a string, or
    ``None`` while its response is still pending.
    """

    def __init__(self):
        self.messages = []
        self.roles = ["user", "assistant"]
        self.offset = 0
        self.skip_next = False

    def append_message(self, role, message):
        """Append a ``[role, message]`` pair to the history."""
        self.messages.append([role, message])

    def to_gradio_chatbot(self):
        """Convert the history into gr.Chatbot's list of ``[user, bot]`` pairs.

        ``None`` messages (pending responses) are skipped.

        Raises:
            ValueError: if a message has a role other than those in ``self.roles``.
        """
        ret = []
        for role, message in self.messages:
            if message is None:
                continue
            if role == self.roles[0]:
                text = message[0] if isinstance(message, tuple) else message
                ret.append([self.render_user_message(text), None])
            elif role == self.roles[1]:
                # Attach the reply to the preceding, still-open user turn; guard
                # against an assistant message with no user turn before it.
                if ret and ret[-1][1] is None:
                    ret[-1][1] = message
                else:
                    ret.append([None, message])
            else:
                raise ValueError(f"Invalid role: {role}")
        return ret

    def render_user_message(self, message):
        """Return user text for display, with any literal ``<image>`` placeholder removed.

        Nothing in this file inserts that placeholder, so normally the text is
        returned unchanged.
        """
        return message.replace("<image>", "")

    def dict(self):
        """Return a JSON-serializable snapshot of the conversation.

        A user tuple whose second element is a PIL image is reduced to
        ``(text, "[IMAGE_IGNORED]")``; every other message is kept as-is.
        """
        serialized_messages = []
        for role, message in self.messages:
            if (isinstance(message, tuple) and len(message) > 1
                    and isinstance(message[1], Image.Image)):
                message = (message[0], "[IMAGE_IGNORED]")
            serialized_messages.append([role, message])
        return {
            "messages": serialized_messages,
            "roles": self.roles,
            "offset": self.offset,
            "skip_next": self.skip_next,
        }

    def get_prompt(self):
        """Render the history as a plain ``USER:`` / ``ASSISTANT:`` transcript.

        Kept for callers that want a text transcript. Do not feed it to
        ``run_inference``: that function applies the model's own chat template,
        so these role labels would end up nested inside the user turn.
        """
        prompt = ""
        for role, message in self.messages:
            if message is None:
                continue
            if isinstance(message, tuple):
                message = message[0]
            if role == self.roles[0]:
                prompt += f"USER: {message}\n"
            else:
                prompt += f"ASSISTANT: {message}\n"
        return prompt + "ASSISTANT: "

    def get_last_user_text(self):
        """Return the text of the most recent user message ("" if there is none)."""
        for role, message in reversed(self.messages):
            if role == self.roles[0] and message is not None:
                return message[0] if isinstance(message, tuple) else message
        return ""

    def get_images(self, return_pil=False):
        """Return every PIL image attached to a message, in history order.

        ``return_pil`` is accepted for backward compatibility; images are always
        returned as PIL objects.
        """
        return [
            message[1]
            for _, message in self.messages
            if isinstance(message, tuple) and len(message) > 1
            and isinstance(message[1], Image.Image)
        ]

    def copy(self):
        """Return a copy whose ``[role, message]`` pairs can be mutated independently."""
        new_conv = Conversation()
        # Copy each pair so in-place edits (e.g. in ``regenerate``) on one
        # conversation never leak into the other.
        new_conv.messages = [list(pair) for pair in self.messages]
        new_conv.roles = self.roles.copy()
        new_conv.offset = self.offset
        new_conv.skip_next = self.skip_next
        return new_conv


default_conversation = Conversation()

# Model and processor storage
tokenizer = None
model = None
processor = None
context_len = 8048
# Path/ID of the checkpoint currently held in ``model`` (None when none is loaded).
current_model_path = None


# Helper functions
def clear_conv(conv):
    """Remove all messages from ``conv`` in place and return it."""
    conv.messages = []
    return conv


def wrap_taxonomy(text):
    """Wraps user input with taxonomy if not already present"""
    if policy_v1 not in text:
        return policy_v1 + "\n\n" + text
    return text


# UI component states
no_change_btn = gr.Button()
enable_btn = gr.Button(interactive=True)
disable_btn = gr.Button(interactive=False)


def _release_model():
    """Drop references to the loaded model/processor/tokenizer and free cached GPU memory."""
    global tokenizer, model, processor, current_model_path
    model = None
    processor = None
    tokenizer = None
    current_model_path = None
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


# Model loading function
def load_model(model_path):
    """Load ``model_path`` into the module-level model/processor/tokenizer globals.

    Paths containing "qwenguard" (case-insensitive) load as Qwen2.5-VL; anything
    else loads as LLaVA-OneVision. Loading the already-loaded path is a no-op.
    Otherwise the previous model is released first, so two checkpoints are
    never resident at the same time.

    Returns:
        True on success (or if already loaded); False on failure, in which case
        no model remains loaded (unless ``model_path`` was empty, which leaves
        the current model untouched).
    """
    global tokenizer, model, processor, context_len, current_model_path
    if not model_path:
        logger.error("No model path given; keeping the current model")
        return False
    if model is not None and model_path == current_model_path:
        logger.info(f"Model {model_path} already loaded")
        return True

    logger.info(f"Loading model: {model_path}")
    _release_model()
    use_cuda = torch.cuda.is_available()
    dtype = torch.float16 if use_cuda else torch.float32
    device_map = "auto" if use_cuda else None
    try:
        # Check if it's a Qwen model
        if "qwenguard" in model_path.lower():
            new_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
                model_path,
                torch_dtype=dtype,
                device_map=device_map,
            )
            new_processor = AutoProcessor.from_pretrained(model_path)
            new_tokenizer = new_processor.tokenizer
        # Otherwise assume it's a LlavaGuard model
        else:
            new_model = LlavaOnevisionForConditionalGeneration.from_pretrained(
                model_path,
                torch_dtype=dtype,
                device_map=device_map,
                trust_remote_code=True,
            )
            new_tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
            new_processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
    except Exception:
        logger.exception(f"Error loading model {model_path}")
        return False

    # Publish all globals together so a model is never paired with another
    # checkpoint's processor/tokenizer.
    model, processor, tokenizer = new_model, new_processor, new_tokenizer
    context_len = getattr(model.config, "max_position_embeddings", 8048)
    current_model_path = model_path
    logger.info(f"Model {model_path} loaded successfully")
    return True


def get_model_list():
    """Return the checkpoint IDs offered in the model dropdown."""
    models = [
        'AIML-TUDA/LlavaGuard-v1.2-0.5B-OV-hf',
        'AIML-TUDA/LlavaGuard-v1.2-7B-OV-hf',
        'AIML-TUDA/QwenGuard-v1.2-3B',
        'AIML-TUDA/QwenGuard-v1.2-7B',
    ]
    return models


def get_conv_log_filename():
    """Return today's ``LOGDIR/<YYYY-MM-DD>-conv.json`` path, creating its directory."""
    t = datetime.datetime.now()
    name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json")
    os.makedirs(os.path.dirname(name), exist_ok=True)
    return name


def _build_model_inputs(cur_model, cur_processor, prompt, image):
    """Apply the processor's chat template to ``prompt`` + ``image`` and tokenize.

    Returns the processor's BatchFeature (still on CPU).
    """
    if isinstance(cur_model, Qwen2_5_VLForConditionalGeneration):
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image},
                    {"type": "text", "text": prompt},
                ],
            }
        ]
        text = cur_processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        return cur_processor(text=[text], images=[image], padding=True, return_tensors="pt")

    # Otherwise assume it's a LlavaGuard model
    conversation = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": prompt},
            ],
        },
    ]
    text_prompt = cur_processor.apply_chat_template(conversation, add_generation_prompt=True)
    return cur_processor(text=text_prompt, images=image, return_tensors="pt")


# Inference function
def run_inference(prompt, image, temperature=0.2, top_p=0.95, max_tokens=512):
    """Generate the model's safety assessment for ``image`` under ``prompt``.

    Args:
        prompt: raw user text (the policy); the model's chat template is applied here.
        image: PIL image to assess.
        temperature: sampling temperature; ``<= 0`` selects greedy decoding.
        top_p: nucleus-sampling threshold (only used when sampling).
        max_tokens: maximum number of new tokens (clamped to at least 1).

    Returns:
        The stripped decoded response, or a human-readable error string (this
        function does not raise).
    """
    # Hold local references so a concurrent reload cannot swap components mid-request.
    cur_model, cur_processor, cur_tokenizer = model, processor, tokenizer
    if cur_model is None or cur_processor is None:
        return "Model not loaded. Please select a model first."

    try:
        is_qwen = isinstance(cur_model, Qwen2_5_VLForConditionalGeneration)
        # BatchFeature.to() moves the tensors and keeps the BatchFeature type;
        # the previous dict comprehension returned a plain dict, which broke the
        # later ``inputs.input_ids`` access whenever CUDA was in use.
        inputs = _build_model_inputs(cur_model, cur_processor, prompt, image).to(cur_model.device)

        do_sample = temperature > 0
        gen_kwargs = {"do_sample": do_sample, "max_new_tokens": max(1, int(max_tokens))}
        if do_sample:
            # Sampling parameters are only meaningful (and only accepted
            # without warnings) when sampling is enabled.
            gen_kwargs["temperature"] = temperature
            gen_kwargs["top_p"] = top_p

        with torch.no_grad():
            generated_ids = cur_model.generate(**inputs, **gen_kwargs)

        # Keep only the newly generated tokens (batch size is 1).
        new_tokens = generated_ids[:, inputs["input_ids"].shape[1]:]
        if is_qwen:
            response = cur_processor.batch_decode(
                new_tokens, skip_special_tokens=True, clean_up_tokenization_spaces=False
            )[0]
        else:
            response = cur_tokenizer.batch_decode(new_tokens, skip_special_tokens=True)[0]
        return response.strip()
    except Exception as e:
        logger.exception("Error during inference")
        return f"Error during inference: {e}"


# Gradio UI functions
get_window_url_params = """
function() {
    const params = new URLSearchParams(window.location.search);
    url_params = Object.fromEntries(params);
    console.log(url_params);
    return url_params;
    }
"""


def load_demo(url_params, request: gr.Request):
    """Initialise a session, preselecting (and loading) ``?model=...`` if it is a known model."""
    logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")
    models = get_model_list()
    dropdown_update = gr.Dropdown(visible=True)
    if "model" in url_params:
        requested_model = url_params["model"]
        if requested_model in models:
            dropdown_update = gr.Dropdown(value=requested_model, visible=True)
            load_model(requested_model)
    state = default_conversation.copy()
    return state, dropdown_update


def load_demo_refresh_model_list(request: gr.Request):
    """Initialise a session and refresh the dropdown, preselecting the loaded model."""
    logger.info(f"load_demo. ip: {request.client.host}")
    models = get_model_list()
    state = default_conversation.copy()
    # Show the checkpoint that is actually loaded when it is in the list, so
    # the dropdown does not disagree with the shared global model.
    if current_model_path in models:
        selected = current_model_path
    else:
        selected = models[0] if len(models) > 0 else ""
    dropdown_update = gr.Dropdown(choices=models, value=selected)
    return state, dropdown_update


def vote_last_response(state, vote_type, model_selector, request: gr.Request):
    """Append a vote/flag record for the current conversation to today's log file."""
    with open(get_conv_log_filename(), "a", encoding="utf-8") as fout:
        data = {
            "tstamp": round(time.time(), 4),
            "type": vote_type,
            "model": model_selector,
            "state": state.dict(),
            "ip": request.client.host,
        }
        fout.write(json.dumps(data) + "\n")


def upvote_last_response(state, model_selector, request: gr.Request):
    """Log an upvote, restore the default policy text and disable the vote buttons."""
    logger.info(f"upvote. ip: {request.client.host}")
    vote_last_response(state, "upvote", model_selector, request)
    return (default_taxonomy,) + (disable_btn,) * 3


def downvote_last_response(state, model_selector, request: gr.Request):
    """Log a downvote, restore the default policy text and disable the vote buttons."""
    logger.info(f"downvote. ip: {request.client.host}")
    vote_last_response(state, "downvote", model_selector, request)
    return (default_taxonomy,) + (disable_btn,) * 3


def flag_last_response(state, model_selector, request: gr.Request):
    """Log a flag, restore the default policy text and disable the vote buttons."""
    logger.info(f"flag. ip: {request.client.host}")
    vote_last_response(state, "flag", model_selector, request)
    return (default_taxonomy,) + (disable_btn,) * 3


def regenerate(state, image_process_mode, request: gr.Request):
    """Clear the last response so ``llava_bot`` can produce a new one."""
    logger.info(f"regenerate. ip: {request.client.host}")
    state.messages[-1][-1] = None
    prev_human_msg = state.messages[-2]
    if isinstance(prev_human_msg[1], (tuple, list)):
        prev_human_msg[1] = (*prev_human_msg[1][:2], image_process_mode)
    state.skip_next = False
    return (state, state.to_gradio_chatbot(), default_taxonomy, None) + (disable_btn,) * 5


def clear_history(request: gr.Request):
    """Start a fresh conversation and reset the inputs to their defaults."""
    logger.info(f"clear_history. ip: {request.client.host}")
    state = default_conversation.copy()
    return (state, state.to_gradio_chatbot(), default_taxonomy, None) + (disable_btn,) * 5


def add_text(state, text, image, image_process_mode, request: gr.Request):
    """Start a new single-turn assessment from the textbox text and uploaded image.

    If the text is empty or no image is given, the turn is marked to be skipped
    and the user's inputs are left in place.
    """
    logger.info(f"add_text. ip: {request.client.host}. len: {len(text)}")
    if len(text) <= 0 or image is None:
        state.skip_next = True
        # Keep what the user entered instead of wiping it.
        return (state, state.to_gradio_chatbot(), text, image) + (no_change_btn,) * 5

    text = wrap_taxonomy(text)
    # Every submission is a fresh, single-turn assessment.
    state = clear_conv(default_conversation.copy())
    state.append_message(state.roles[0], (text, image, image_process_mode))
    state.append_message(state.roles[1], None)
    state.skip_next = False
    return (state, state.to_gradio_chatbot(), default_taxonomy, None) + (disable_btn,) * 5


def _save_images_for_log(images):
    """Save each image under ``LOGDIR/serve_images/<YYYY-MM-DD>/<md5>.jpg`` (best effort).

    Returns:
        The md5 hex digest of each image's raw pixel bytes, in input order.
    """
    image_hashes = [hashlib.md5(image.tobytes()).hexdigest() for image in images]
    t = datetime.datetime.now()
    day_dir = os.path.join(LOGDIR, "serve_images", f"{t.year}-{t.month:02d}-{t.day:02d}")
    for image, hash_val in zip(images, image_hashes):
        filename = os.path.join(day_dir, f"{hash_val}.jpg")
        if os.path.isfile(filename):
            continue
        try:
            os.makedirs(day_dir, exist_ok=True)
            # JPEG cannot store alpha/palette modes; normalise to RGB first.
            image.convert("RGB").save(filename)
        except OSError:
            # Logging must not abort the user's request.
            logger.exception(f"Failed to save image log {filename}")
    return image_hashes


def llava_bot(state, model_selector, temperature, top_p, max_new_tokens, request: gr.Request):
    """Run the selected model on the pending turn and stream the result to the UI.

    Yields ``(state, chatbot_pairs, *5 button updates)`` tuples and appends a
    "chat" record to today's conversation log afterwards.
    """
    start_tstamp = time.time()

    if state.skip_next:
        # This generate call is skipped due to invalid inputs
        yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
        return

    # The model's chat template is applied in run_inference, so pass the raw
    # user text rather than a USER:/ASSISTANT: transcript.
    prompt = state.get_last_user_text()
    all_images = state.get_images(return_pil=True)

    if not all_images:
        state.messages[-1][-1] = "Error: No image provided"
        yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
        return

    all_image_hash = _save_images_for_log(all_images)

    # No-op when the selected model is already the loaded one.
    load_model(model_selector)

    # Run inference
    output = run_inference(prompt, all_images[0], temperature, top_p, max_new_tokens)
    state.messages[-1][-1] = output
    # Enable vote/flag/regenerate/clear now that there is a response to act on.
    yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5

    finish_tstamp = time.time()
    logger.info(f"Generated response in {finish_tstamp - start_tstamp:.2f}s")

    with open(get_conv_log_filename(), "a", encoding="utf-8") as fout:
        data = {
            "tstamp": round(finish_tstamp, 4),
            "type": "chat",
            "model": model_selector,
            "start": round(start_tstamp, 4),
            "finish": round(finish_tstamp, 4),
            "state": state.dict(),
            "images": all_image_hash,
            "ip": request.client.host,
        }
        fout.write(json.dumps(data) + "\n")


# UI Components
title_markdown = """
# LLAVAGUARD: VLM-based Safeguard for Vision Dataset Curation and Safety Assessment
[[Project Page](https://ml-research.github.io/human-centered-genai/projects/llavaguard/index.html)] [[Code](https://github.com/ml-research/LlavaGuard)] [[Model](https://huggingface.co/collections/AIML-TUDA/llavaguard-665b42e89803408ee8ec1086)] [[Dataset](https://huggingface.co/datasets/aiml-tuda/llavaguard)] [[LlavaGuard](https://arxiv.org/abs/2406.05113)]
"""

tos_markdown = """
### Terms of use
By using this service, users are required to agree to the following terms:
The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research.
Please click the "Flag" button if you get any inappropriate answer! We will collect those to keep improving our moderator.
For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality.
"""

learn_more_markdown = """
### License
The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA, [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI, and [Privacy Practices](https://chrome.google.com/webstore/detail/sharegpt-share-your-chatg/daiacboceoaocpibfodeljbdfacokfjb) of ShareGPT. Please contact us if you find any potential violation.
"""

block_css = """

#buttons button {
    min-width: min(120px,100%);
}

"""


def build_demo(embed_mode, cur_dir=None, concurrency_count=10):
    """Assemble the Gradio Blocks UI and wire its event handlers.

    Args:
        embed_mode: when True, omit the title, terms-of-use and license markdown.
        cur_dir: directory holding ``examples/image{1..5}.png``; defaults to
            this file's directory.
        concurrency_count: ``concurrency_limit`` for the generation events.

    Returns:
        The (not yet launched) ``gr.Blocks`` app.
    """
    models = get_model_list()

    with gr.Blocks(title="LlavaGuard", theme=gr.themes.Default(), css=block_css) as demo:
        state = gr.State()

        if not embed_mode:
            gr.Markdown(title_markdown)

        with gr.Row():
            with gr.Column(scale=3):
                with gr.Row(elem_id="model_selector_row"):
                    model_selector = gr.Dropdown(
                        choices=models,
                        value=models[0] if len(models) > 0 else "",
                        interactive=True,
                        show_label=False,
                        container=False,
                    )

                imagebox = gr.Image(type="pil", label="Image", container=False)
                image_process_mode = gr.Radio(
                    ["Crop", "Resize", "Pad", "Default"],
                    value="Default",
                    label="Preprocess for non-square image",
                    visible=False,
                )

                if cur_dir is None:
                    cur_dir = os.path.dirname(os.path.abspath(__file__))
                gr.Examples(
                    examples=[
                        [f"{cur_dir}/examples/image{i}.png"]
                        for i in range(1, 6)
                        if os.path.exists(f"{cur_dir}/examples/image{i}.png")
                    ],
                    inputs=imagebox,
                )

                with gr.Accordion("Parameters", open=False) as parameter_row:
                    temperature = gr.Slider(
                        minimum=0.0, maximum=1.0, value=0.2, step=0.1,
                        interactive=True, label="Temperature",
                    )
                    top_p = gr.Slider(
                        minimum=0.0, maximum=1.0, value=0.95, step=0.1,
                        interactive=True, label="Top P",
                    )
                    max_output_tokens = gr.Slider(
                        minimum=0, maximum=1024, value=512, step=64,
                        interactive=True, label="Max output tokens",
                    )

                with gr.Accordion("Safety Risk Taxonomy", open=False):
                    # NOTE(review): this textbox is not wired to any event
                    # handler, so edits here do not change the prompt; the
                    # policy sent to the model comes from the message textbox
                    # (via wrap_taxonomy).
                    taxonomy_textbox = gr.Textbox(
                        label="Safety Risk Taxonomy",
                        show_label=True,
                        placeholder="Enter your safety policy here",
                        value=default_taxonomy,
                        lines=20,
                    )

            with gr.Column(scale=8):
                chatbot = gr.Chatbot(
                    elem_id="chatbot",
                    label="LLavaGuard Safety Assessment",
                    height=650,
                    layout="panel",
                )
                with gr.Row():
                    with gr.Column(scale=8):
                        textbox = gr.Textbox(
                            show_label=False,
                            placeholder="Enter your message here",
                            container=True,
                            value=default_taxonomy,
                            lines=3,
                        )
                    with gr.Column(scale=1, min_width=50):
                        submit_btn = gr.Button(value="Send", variant="primary")
                with gr.Row(elem_id="buttons") as button_row:
                    upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
                    downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
                    flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
                    regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
                    clear_btn = gr.Button(value="🗑️ Clear", interactive=False)

        if not embed_mode:
            gr.Markdown(tos_markdown)
            gr.Markdown(learn_more_markdown)
        url_params = gr.JSON(visible=False)

        # Register listeners
        btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]
        upvote_btn.click(
            upvote_last_response,
            [state, model_selector],
            [textbox, upvote_btn, downvote_btn, flag_btn],
        )
        downvote_btn.click(
            downvote_last_response,
            [state, model_selector],
            [textbox, upvote_btn, downvote_btn, flag_btn],
        )
        flag_btn.click(
            flag_last_response,
            [state, model_selector],
            [textbox, upvote_btn, downvote_btn, flag_btn],
        )

        model_selector.change(
            load_model,
            [model_selector],
            None,
        )

        regenerate_btn.click(
            regenerate,
            [state, image_process_mode],
            [state, chatbot, textbox, imagebox] + btn_list,
        ).then(
            llava_bot,
            [state, model_selector, temperature, top_p, max_output_tokens],
            [state, chatbot] + btn_list,
            concurrency_limit=concurrency_count,
        )

        clear_btn.click(
            clear_history,
            None,
            [state, chatbot, textbox, imagebox] + btn_list,
            queue=False,
        )

        textbox.submit(
            add_text,
            [state, textbox, imagebox, image_process_mode],
            [state, chatbot, textbox, imagebox] + btn_list,
            queue=False,
        ).then(
            llava_bot,
            [state, model_selector, temperature, top_p, max_output_tokens],
            [state, chatbot] + btn_list,
            concurrency_limit=concurrency_count,
        )

        submit_btn.click(
            add_text,
            [state, textbox, imagebox, image_process_mode],
            [state, chatbot, textbox, imagebox] + btn_list,
        ).then(
            llava_bot,
            [state, model_selector, temperature, top_p, max_output_tokens],
            [state, chatbot] + btn_list,
            concurrency_limit=concurrency_count,
        )

        demo.load(
            load_demo_refresh_model_list,
            None,
            [state, model_selector],
            queue=False,
        )

    return demo


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int)
    parser.add_argument("--concurrency-count", type=int, default=5)
    parser.add_argument("--share", action="store_true")
    parser.add_argument("--moderate", action="store_true")  # currently unused
    parser.add_argument("--embed", action="store_true")
    args = parser.parse_args()

    # Create log directory if it doesn't exist
    os.makedirs(LOGDIR, exist_ok=True)

    # GPU Check
    if torch.cuda.is_available():
        logger.info(f"CUDA available with {torch.cuda.device_count()} devices")
    else:
        logger.warning("CUDA not available! Models will run on CPU which may be very slow.")

    # Hugging Face token handling
    api_key = os.getenv("token")
    if api_key:
        from huggingface_hub import login
        login(token=api_key)
        logger.info("Logged in to Hugging Face Hub")

    # Load initial model
    models = get_model_list()
    model_path = os.getenv("model", models[0])
    logger.info(f"Initial model selected: {model_path}")
    load_model(model_path)

    # Launch Gradio app
    try:
        demo = build_demo(embed_mode=args.embed, cur_dir='./', concurrency_count=args.concurrency_count)
        demo.queue(
            status_update_rate=10,
            api_open=False,
        ).launch(
            server_name=args.host,
            server_port=args.port,
            # Honour the --share flag instead of always creating a public link.
            share=args.share,
        )
    except Exception:
        logger.exception("Error launching demo")
        sys.exit(1)