import gradio as gr
import torch
from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation
from PIL import Image, ImageDraw, ImageFont
import numpy as np
import random

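# Use the GPU when available; otherwise fall back to CPU.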
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

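# Load the processor and model once at startup so every request reuses them.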
print("Loading model...") |
|
processor = AutoImageProcessor.from_pretrained("facebook/mask2former-swin-tiny-coco-instance") |
|
model = Mask2FormerForUniversalSegmentation.from_pretrained("facebook/mask2former-swin-tiny-coco-instance").to(device) |
|
print("Model loaded successfully.") |
|
|
|
|
|
|
|
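# Only detections from these COCO classes are annotated.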
TARGET_CLASSES = ['cat', 'dog', 'car', 'truck', 'bus', 'person']

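# Maps each label to a color so every instance of a class is drawn consistently.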
label_to_color = {}


def get_label_color(label):
    """Returns a random but consistent color for a given label."""
    if label not in label_to_color:
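        # Keep channel values at 50 or above so colors stay visible on dark regions.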
        color = (random.randint(50, 255), random.randint(50, 200), random.randint(50, 255))
        label_to_color[label] = color
    return label_to_color[label]


def draw_segmentation(image, segments_info):
    """
    Draws masks, bounding boxes, and labels on the image.

    Args:
        image (PIL.Image.Image): The original input image.
        segments_info (list): A list of dictionaries, each describing one detected segment.

    Returns:
        PIL.Image.Image: The annotated image in RGBA mode.
    """
    annotated_image = image.convert("RGBA")
    draw = ImageDraw.Draw(annotated_image)

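    # Prefer a TrueType font; fall back to PIL's built-in bitmap font if it's missing.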
    try:
        font = ImageFont.truetype("arial.ttf", size=20)
    except IOError:
        print("Arial font not found, using default font.")
        font = ImageFont.load_default()

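    # Annotate each segment: mask overlay first, then box and label on top.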
    for segment in segments_info:
        label = segment['label']
        score = segment['score']
        mask = segment['mask']
        box = segment['box']

        color = get_label_color(label)

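        # Convert the boolean mask to a PIL image and paint it as a semi-transparent overlay.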
        mask_image = Image.new("RGBA", image.size)
        mask_draw = ImageDraw.Draw(mask_image)
        pil_mask = Image.fromarray(mask.astype('uint8') * 255)
        mask_draw.bitmap((0, 0), pil_mask, fill=color + (150,))
        annotated_image.alpha_composite(mask_image)

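        # Draw the bounding box, then a filled banner with the label and confidence.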
        draw.rectangle(box, outline=color, width=3)
        text = f"{label}: {score:.2f}"
        text_bbox = draw.textbbox((box[0], box[1]), text, font=font)
        draw.rectangle(text_bbox, fill=color)
        draw.text((box[0], box[1]), text, fill="white", font=font)

    return annotated_image


def predict(input_image):
    """
    Runs inference on the uploaded image and returns the annotated result.
    This function is called by the Gradio interface.
    """
    if input_image is None:
        return None, "Please upload an image."

    print("Processing image...")

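    # Preprocess the image into model-ready tensors.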
    inputs = processor(images=input_image, return_tensors="pt").to(device)

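    # Run the model without tracking gradients.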
    with torch.no_grad():
        outputs = model(**inputs)

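    # Resize predictions back to the original resolution. PIL reports size as
    # (width, height), while the processor expects (height, width), hence [::-1].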
    result = processor.post_process_instance_segmentation(outputs, target_sizes=[input_image.size[::-1]])[0]

    segments_info = []
    segmentation = result["segmentation"].cpu().numpy()
    segments_info_raw = result["segments_info"]

    for segment in segments_info_raw:
        score = segment["score"]
        label_id = segment["label_id"]
        segment_id = segment["id"]
        label_name = model.config.id2label[label_id]

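        # Keep only confident detections belonging to the target classes.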
        if score > 0.5 and label_name in TARGET_CLASSES:
            mask = (segmentation == segment_id)

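            # Derive a bounding box from the mask's nonzero pixel coordinates.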
            pos = np.where(mask)
            if pos[0].size > 0 and pos[1].size > 0:
                # Cast numpy integers to plain ints for PIL's drawing functions.
                xmin = int(np.min(pos[1]))
                xmax = int(np.max(pos[1]))
                ymin = int(np.min(pos[0]))
                ymax = int(np.max(pos[0]))

                segments_info.append({
                    "score": score,
                    "label": label_name,
                    "mask": mask,
                    "box": [xmin, ymin, xmax, ymax]
                })

print(f"Found {len(segments_info)} objects.") |
|
|
|
|
|
if not segments_info: |
|
return input_image, "No objects from the target classes were detected with high confidence." |
|
|
|
annotated_image = draw_segmentation(input_image, segments_info) |
|
|
|
return annotated_image, f"Successfully processed. Found {len(segments_info)} objects." |
|
|
|
|
|
|
|
|
|
|
|
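# Example images shipped alongside the app in the examples/ folder.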
example_paths = [
    "examples/street1.jpeg",
    "examples/street2.jpeg",
    "examples/street3.jpeg",
    "examples/street4.jpeg",
    "examples/street5.jpeg",
    "examples/street6.jpeg",
    "examples/street7.jpeg",
    "examples/street8.jpeg",
    "examples/street9.jpeg",
    "examples/street10.jpeg",
    "examples/catsanddogs1.jpeg",
    "examples/catsanddogs2.jpeg",
    "examples/catsanddogs3.jpeg",
    "examples/catsanddogs4.jpeg",
    "examples/catsanddogs5.jpeg",
]

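# Wire the prediction function into a simple Gradio UI.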
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload Image"),
    outputs=[
        gr.Image(type="pil", label="Segmented Image"),
        gr.Textbox(label="Status")
    ],
    title="Advanced Instance Segmentation with Mask2Former",
    description="""
    Upload an image or click an example to see instance segmentation in action.
    The model identifies objects from the classes: **car, bus, truck, person, dog, cat**.
    Each object is highlighted with a colored mask, a bounding box, and a label.
    *Note: The free CPU tier can be slow; please allow up to 30 seconds for processing.*
    """,
    examples=example_paths,
    cache_examples=True
)

if __name__ == "__main__":
    demo.launch(share=True)