import gradio as gr
import torch
from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation
from PIL import Image, ImageDraw, ImageFont
import numpy as np
import random
# --- 1. Global Setup & Model Loading ---
# Check for GPU availability
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Load the image processor and the model
# The model is loaded once and cached for all subsequent inference calls
print("Loading model...")
processor = AutoImageProcessor.from_pretrained("facebook/mask2former-swin-tiny-coco-instance")
model = Mask2FormerForUniversalSegmentation.from_pretrained("facebook/mask2former-swin-tiny-coco-instance").to(device)
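# Inference-only app: from_pretrained already returns the model in eval mode,
# but calling eval() explicitly documents the intent.
model.eval()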
print("Model loaded successfully.")
# Define the classes we are interested in.
# Note: "building" is not a class in the COCO-instance dataset.
TARGET_CLASSES = ['cat', 'dog', 'car', 'truck', 'bus', 'person']
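# These names must match the values in model.config.id2label exactly;
# all six are among the 80 "thing" classes of the COCO instance dataset.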
# --- 2. Visualization & Drawing Logic ---
# Generate a consistent color for each class label
# This ensures that, for example, all 'car' masks are the same color.
label_to_color = {}
def get_label_color(label):
    """Returns a random, consistent color for a given label."""
    if label not in label_to_color:
        # Generate a random color
        color = (random.randint(50, 255), random.randint(50, 200), random.randint(50, 255))
        label_to_color[label] = color
    return label_to_color[label]
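# e.g. get_label_color('car') returns the same RGB triple on every call within
# a session, so all 'car' instances share one color.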
def draw_segmentation(image, segments_info):
    """
    Draws masks, bounding boxes, and labels on the image.

    Args:
        image (PIL.Image.Image): The original input image.
        segments_info (list): A list of dictionaries, each containing info about a detected segment.
    """
    # Make a copy of the image to draw on
    annotated_image = image.convert("RGBA")
    draw = ImageDraw.Draw(annotated_image)

    # Load a font
    try:
        font = ImageFont.truetype("arial.ttf", size=20)
    except IOError:
        print("Arial font not found, using default font.")
        font = ImageFont.load_default()
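    # Note: on older Pillow versions the built-in default font is a small,
    # fixed-size bitmap font, so labels may render tiny without arial.ttf.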
    for segment in segments_info:
        label = segment['label']
        score = segment['score']
        mask = segment['mask']
        box = segment['box']

        # Get the color for this label
        color = get_label_color(label)

        # --- Draw the mask ---
        # Create a colored mask image
        mask_image = Image.new("RGBA", image.size)
        mask_draw = ImageDraw.Draw(mask_image)

        # Convert the boolean mask array to a PIL-drawable format;
        # we draw wherever the mask is True
        pil_mask = Image.fromarray(mask.astype('uint8') * 255)

        # Draw the mask with semi-transparency
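        # ImageDraw.bitmap() uses pil_mask as a stencil: the RGBA fill color is
        # applied only where the mask pixels are nonzero.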
        mask_draw.bitmap((0, 0), pil_mask, fill=color + (150,))  # RGBA with transparency

        # Composite the mask onto the main image
        annotated_image.alpha_composite(mask_image)

        # --- Draw the bounding box ---
        draw.rectangle(box, outline=color, width=3)

        # --- Draw the label and score ---
        text = f"{label}: {score:.2f}"
        text_bbox = draw.textbbox((box[0], box[1]), text, font=font)

        # Create a small background behind the text for better readability
        draw.rectangle(text_bbox, fill=color)
        draw.text((box[0], box[1]), text, fill="white", font=font)

    return annotated_image
# --- 3. Main Prediction Function ---
def predict(input_image):
    """
    The main function that runs inference and orchestrates the process.
    This function is called by the Gradio interface.
    """
    if input_image is None:
        return None, "Please upload an image."

    print("Processing image...")

    # Preprocess the image
    inputs = processor(images=input_image, return_tensors="pt").to(device)

    # Perform inference
    with torch.no_grad():
        outputs = model(**inputs)

    # Post-process the outputs into instance segmentation results. PIL reports
    # size as (width, height), so we reverse it to the (height, width) that
    # target_sizes expects; this scales the masks back to the input resolution.
    result = processor.post_process_instance_segmentation(outputs, target_sizes=[input_image.size[::-1]])[0]
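    # The result dict contains "segmentation", an (H, W) map in which each pixel
    # holds the id of the instance it belongs to, and "segments_info", one dict
    # per instance with its "id", "label_id", and "score". (The post-processor
    # also accepts a `threshold` argument, so the score filtering below could
    # alternatively be done there.)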
    # Filter results by score and class
    segments_info = []
    segmentation = result["segmentation"].cpu().numpy()
    segments_info_raw = result["segments_info"]

    for segment in segments_info_raw:
        score = segment["score"]
        label_id = segment["label_id"]
        segment_id = segment["id"]
        label_name = model.config.id2label[label_id]

        if score > 0.5 and label_name in TARGET_CLASSES:
            # Create a binary mask for this segment
            mask = (segmentation == segment_id)

            # Calculate the bounding box from the mask
            pos = np.where(mask)
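            # np.where on a 2-D mask returns (row_indices, col_indices), so
            # pos[1] gives x coordinates and pos[0] gives y coordinates.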
            if pos[0].size > 0 and pos[1].size > 0:
                xmin = np.min(pos[1])
                xmax = np.max(pos[1])
                ymin = np.min(pos[0])
                ymax = np.max(pos[0])
                segments_info.append({
                    "score": score,
                    "label": label_name,
                    "mask": mask,
                    "box": [xmin, ymin, xmax, ymax]
                })

    print(f"Found {len(segments_info)} objects.")

    # Draw the results on the image
    if not segments_info:
        return input_image, "No objects from the target classes were detected with high confidence."

    annotated_image = draw_segmentation(input_image, segments_info)
    return annotated_image, f"Successfully processed. Found {len(segments_info)} objects."
# --- 4. Gradio Interface Definition ---
# Load some example images
# Note: You must upload these images to your Hugging Face Space repository in a folder named 'examples'
example_paths = [
    "examples/street1.jpeg",
    "examples/street2.jpeg",
    "examples/street3.jpeg",
    "examples/street4.jpeg",
    "examples/street5.jpeg",
    "examples/street6.jpeg",
    "examples/street7.jpeg",
    "examples/street8.jpeg",
    "examples/street9.jpeg",
    "examples/street10.jpeg",
    "examples/catsanddogs1.jpeg",
    "examples/catsanddogs2.jpeg",
    "examples/catsanddogs3.jpeg",
    "examples/catsanddogs4.jpeg",
    "examples/catsanddogs5.jpeg",
]
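# With cache_examples=True below, Gradio runs predict on each of these files at
# startup to build the cache, so any missing file in examples/ will make the
# Space fail early; keep this list in sync with the examples/ folder.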
# Build the Gradio interface
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload Image"),
    outputs=[
        gr.Image(type="pil", label="Segmented Image"),
        gr.Textbox(label="Status")
    ],
    title="Advanced Instance Segmentation with Mask2Former",
    description="""
    Upload an image or click an example to see instance segmentation in action.
    The model identifies objects from the classes: **car, bus, truck, person, dog, cat**.
    Each object is highlighted with a colored mask, a bounding box, and a label.

    *Note: The free CPU can be slow; please allow up to 30 seconds for processing.*
    """,
    examples=example_paths,
    cache_examples=True  # Cache results for the examples for a faster demo
)
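# Optional: on a busy CPU Space, demo.queue() can be called before launch so
# that concurrent requests are processed sequentially instead of in parallel.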
if __name__ == "__main__":
    demo.launch(share=True)