import gradio as gr import torch from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation from PIL import Image, ImageDraw, ImageFont import numpy as np import random # --- 1. Global Setup & Model Loading --- # Check for GPU availability device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") # Load the image processor and the model # The model is loaded once and cached for all subsequent inference calls print("Loading model...") processor = AutoImageProcessor.from_pretrained("facebook/mask2former-swin-tiny-coco-instance") model = Mask2FormerForUniversalSegmentation.from_pretrained("facebook/mask2former-swin-tiny-coco-instance").to(device) print("Model loaded successfully.") # Define the classes we are interested in. # Note: "building" is not a class in the COCO-instance dataset. TARGET_CLASSES = ['cat', 'dog', 'car', 'truck', 'bus', 'person'] # --- 2. Visualization & Drawing Logic --- # Generate a consistent color for each class label # This ensures that, for example, all 'car' masks are the same color. label_to_color = {} def get_label_color(label): """Returns a random, consistent color for a given label.""" if label not in label_to_color: # Generate a random color color = (random.randint(50, 255), random.randint(50, 200), random.randint(50, 255)) label_to_color[label] = color return label_to_color[label] def draw_segmentation(image, segments_info): """ Draws masks, bounding boxes, and labels on the image. Args: image (PIL.Image.Image): The original input image. segments_info (list): A list of dictionaries, each containing info about a detected segment. """ # Make a copy of the image to draw on annotated_image = image.convert("RGBA") draw = ImageDraw.Draw(annotated_image) # Load a font try: font = ImageFont.truetype("arial.ttf", size=20) except IOError: print("Arial font not found, using default font.") font = ImageFont.load_default() for segment in segments_info: label = segment['label'] score = segment['score'] mask = segment['mask'] box = segment['box'] # Get the color for this label color = get_label_color(label) # --- Draw the mask --- # Create a colored mask image mask_image = Image.new("RGBA", image.size) mask_draw = ImageDraw.Draw(mask_image) # Convert mask tensor to a PIL-drawable format # The mask tensor is a boolean tensor, we draw where it's True pil_mask = Image.fromarray(mask.astype('uint8') * 255) # Draw the mask with semi-transparency mask_draw.bitmap((0, 0), pil_mask, fill=color + (150,)) # RGBA with transparency # Composite the mask onto the main image annotated_image.alpha_composite(mask_image) # --- Draw the bounding box --- draw.rectangle(box, outline=color, width=3) # --- Draw the label and score --- text = f"{label}: {score:.2f}" text_bbox = draw.textbbox((box[0], box[1]), text, font=font) # Create a small background for the text for better readability draw.rectangle(text_bbox, fill=color) draw.text((box[0], box[1]), text, fill="white", font=font) return annotated_image # --- 3. Main Prediction Function --- def predict(input_image): """ The main function that runs inference and orchestrates the process. This function is called by the Gradio interface. """ if input_image is None: return None, "Please upload an image." print("Processing image...") # Preprocess the image inputs = processor(images=input_image, return_tensors="pt").to(device) # Perform inference with torch.no_grad(): outputs = model(**inputs) # Post-process the outputs to get instance segmentation results # We specify the target image size to scale the masks and boxes correctly result = processor.post_process_instance_segmentation(outputs, target_sizes=[input_image.size[::-1]])[0] # Filter results by score and class segments_info = [] segmentation = result["segmentation"].cpu().numpy() segments_info_raw = result["segments_info"] for segment in segments_info_raw: score = segment["score"] label_id = segment["label_id"] segment_id = segment["id"] label_name = model.config.id2label[label_id] if score > 0.5 and label_name in TARGET_CLASSES: # Create a binary mask for this segment mask = (segmentation == segment_id) # Calculate bounding box from mask pos = np.where(mask) if pos[0].size > 0 and pos[1].size > 0: xmin = np.min(pos[1]) xmax = np.max(pos[1]) ymin = np.min(pos[0]) ymax = np.max(pos[0]) segments_info.append({ "score": score, "label": label_name, "mask": mask, "box": [xmin, ymin, xmax, ymax] }) print(f"Found {len(segments_info)} objects.") # Draw the results on the image if not segments_info: return input_image, "No objects from the target classes were detected with high confidence." annotated_image = draw_segmentation(input_image, segments_info) return annotated_image, f"Successfully processed. Found {len(segments_info)} objects." # --- 4. Gradio Interface Definition --- # Load some example images # Note: You must upload these images to your Hugging Face Space repository in a folder named 'examples' example_paths = [ "examples/street1.jpeg", "examples/street2.jpeg", "examples/street3.jpeg", "examples/street4.jpeg", "examples/street5.jpeg", "examples/street6.jpeg", "examples/street7.jpeg", "examples/street8.jpeg", "examples/street9.jpeg", "examples/street10.jpeg", "examples/catsanddogs1.jpeg", "examples/catsanddogs2.jpeg", "examples/catsanddogs3.jpeg", "examples/catsanddogs4.jpeg", "examples/catsanddogs5.jpeg", ] # Build the Gradio interface demo = gr.Interface( fn=predict, inputs=gr.Image(type="pil", label="Upload Image"), outputs=[ gr.Image(type="pil", label="Segmented Image"), gr.Textbox(label="Status") ], title="Advanced Instance Segmentation with Mask2Former", description=""" Upload an image or click an example to see instance segmentation in action. The model identifies objects from the classes: **car, bus, truck, person, dog, cat**. Each object is highlighted with a colored mask, a bounding box, and a label. *Note: The free CPU can be slow; please allow up to 30 seconds for processing.* """, examples=example_paths, cache_examples=True # Cache results for examples for faster demo ) if __name__ == "__main__": demo.launch('share=True')