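"""Gradio demo: instance segmentation with Mask2Former.

Loads the facebook/mask2former-swin-large-coco-instance checkpoint, keeps
high-confidence detections from a small set of target classes, and draws
masks, bounding boxes, and labels on the input image with PIL.
"""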
import gradio as gr
import torch
from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation
from PIL import Image, ImageDraw, ImageFont
import numpy as np
import random

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

print("Loading model...")
processor = AutoImageProcessor.from_pretrained("facebook/mask2former-swin-large-coco-instance")
model = Mask2FormerForUniversalSegmentation.from_pretrained("facebook/mask2former-swin-large-coco-instance").to(device)
print("Model loaded successfully.")
TARGET_CLASSES = ['cat', 'dog', 'car', 'truck', 'bus', 'person']
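
# Label -> RGB color cache so every instance of a class is drawn in the same color.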
label_to_color = {}


def get_label_color(label):
    """Returns a random, consistent color for a given label."""
    if label not in label_to_color:
        color = (random.randint(50, 255), random.randint(50, 200), random.randint(50, 255))
        label_to_color[label] = color
    return label_to_color[label]


def draw_segmentation(image, segments_info):
    """
    Draws masks, bounding boxes, and labels on the image.

    Args:
        image (PIL.Image.Image): The original input image.
        segments_info (list): A list of dictionaries, each containing info about a detected segment.
    """
    annotated_image = image.convert("RGBA")
    draw = ImageDraw.Draw(annotated_image)

    try:
        font = ImageFont.truetype("arial.ttf", size=20)
    except IOError:
        print("Arial font not found, using default font.")
        font = ImageFont.load_default()
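
    # Each segment gets a translucent mask overlay plus an outlined, labeled box.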
    for segment in segments_info:
        label = segment['label']
        score = segment['score']
        mask = segment['mask']
        box = segment['box']

        color = get_label_color(label)

        # Paint the mask on its own RGBA layer, then alpha-composite it
        # so the overlay stays translucent.
        mask_image = Image.new("RGBA", image.size)
        mask_draw = ImageDraw.Draw(mask_image)
        pil_mask = Image.fromarray(mask.astype('uint8') * 255)
        mask_draw.bitmap((0, 0), pil_mask, fill=color + (150,))
        annotated_image.alpha_composite(mask_image)

        draw.rectangle(box, outline=color, width=3)

        # Draw the label text on a solid background so it stays readable.
        text = f"{label}: {score:.2f}"
        text_bbox = draw.textbbox((box[0], box[1]), text, font=font)
        draw.rectangle(text_bbox, fill=color)
        draw.text((box[0], box[1]), text, fill="white", font=font)

    return annotated_image


def predict(input_image):
    """
    The main function that runs inference and orchestrates the process.
    This function is called by the Gradio interface.
    """
    if input_image is None:
        return None, "Please upload an image."

    print("Processing image...")
    inputs = processor(images=input_image, return_tensors="pt").to(device)

    with torch.no_grad():
        outputs = model(**inputs)
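
    # target_sizes expects (height, width); PIL's .size is (width, height), hence the reversal.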
    result = processor.post_process_instance_segmentation(outputs, target_sizes=[input_image.size[::-1]])[0]

    # post_process_instance_segmentation returns "segmentation", an (H, W) map
    # of per-pixel segment ids, and "segments_info", one entry per detected
    # instance with its id, label_id, and score.
    segmentation_map = result["segmentation"].cpu().numpy()
    segments_info = []

    for info in result["segments_info"]:
        score = info["score"]
        label_name = model.config.id2label[info["label_id"]]

        if score > 0.9 and label_name in TARGET_CLASSES:
            # Recover this instance's binary mask from the segment-id map.
            mask = segmentation_map == info["id"]

            # Derive a bounding box from the mask's extremal pixels.
            pos = np.where(mask)
            if pos[0].size > 0 and pos[1].size > 0:
                xmin = int(np.min(pos[1]))
                xmax = int(np.max(pos[1]))
                ymin = int(np.min(pos[0]))
                ymax = int(np.max(pos[0]))

                segments_info.append({
                    "score": score,
                    "label": label_name,
                    "mask": mask,
                    "box": [xmin, ymin, xmax, ymax]
                })

    print(f"Found {len(segments_info)} objects.")

    if not segments_info:
        return input_image, "No objects from the target classes were detected with high confidence."

    annotated_image = draw_segmentation(input_image, segments_info)
    return annotated_image, f"Successfully processed. Found {len(segments_info)} objects."


example_paths = [
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/cats-vs-dogs.png",
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/instance-segmentation-input.jpg",
    "https://placehold.co/800x600/FFF/333?text=A+busy+street+scene",
]
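
# cache_examples=True pre-computes results for the example inputs at startup.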
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload Image"),
    outputs=[
        gr.Image(type="pil", label="Segmented Image"),
        gr.Textbox(label="Status")
    ],
    title="Advanced Instance Segmentation with Mask2Former",
    description="""
    Upload an image or click an example to see instance segmentation in action.
    The model identifies objects from the classes: **car, bus, truck, person, dog, cat**.
    Each object is highlighted with a colored mask, a bounding box, and a label.
    *Note: The free CPU can be slow; please allow up to 30 seconds for processing.*
    """,
    examples=example_paths,
    cache_examples=True
)

if __name__ == "__main__":
    demo.launch()