import gradio as gr
import torch
from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation
from PIL import Image, ImageDraw, ImageFont
import numpy as np
import random
# --- 1. Global Setup & Model Loading ---
# Check for GPU availability
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Load the image processor and the model
# The model is loaded once and cached for all subsequent inference calls
print("Loading model...")
processor = AutoImageProcessor.from_pretrained("facebook/mask2former-swin-tiny-coco-instance")
model = Mask2FormerForUniversalSegmentation.from_pretrained("facebook/mask2former-swin-tiny-coco-instance").to(device)
print("Model loaded successfully.")
# Define the classes we are interested in.
# Note: "building" is not a class in the COCO-instance dataset.
TARGET_CLASSES = ['cat', 'dog', 'car', 'truck', 'bus', 'person']
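# A quick sanity check (a sketch, not required by the app): the checkpoint's full
# label set can be inspected via the model config, which is useful when deciding
# which TARGET_CLASSES are actually available.
# print(sorted(model.config.id2label.values()))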
# --- 2. Visualization & Drawing Logic ---
# Generate a consistent color for each class label
# This ensures that, for example, all 'car' masks are the same color.
label_to_color = {}
def get_label_color(label):
"""Returns a random, consistent color for a given label."""
if label not in label_to_color:
        # Generate a random color; the per-channel floor of 50 keeps masks from being too dark
        color = (random.randint(50, 255), random.randint(50, 200), random.randint(50, 255))
label_to_color[label] = color
return label_to_color[label]
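# Illustrative behaviour: the first call for a label picks a random color, and
# every later call returns that same tuple, e.g.
# get_label_color("car") == get_label_color("car")  # True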
def draw_segmentation(image, segments_info):
"""
Draws masks, bounding boxes, and labels on the image.
Args:
image (PIL.Image.Image): The original input image.
segments_info (list): A list of dictionaries, each containing info about a detected segment.
"""
# Make a copy of the image to draw on
annotated_image = image.convert("RGBA")
draw = ImageDraw.Draw(annotated_image)
# Load a font
try:
font = ImageFont.truetype("arial.ttf", size=20)
except IOError:
print("Arial font not found, using default font.")
font = ImageFont.load_default()
for segment in segments_info:
label = segment['label']
score = segment['score']
mask = segment['mask']
box = segment['box']
# Get the color for this label
color = get_label_color(label)
# --- Draw the mask ---
# Create a colored mask image
mask_image = Image.new("RGBA", image.size)
mask_draw = ImageDraw.Draw(mask_image)
        # Convert the boolean NumPy mask into a grayscale PIL image
        # (255 where the segment is present) so it can be used as a draw bitmap
        pil_mask = Image.fromarray(mask.astype('uint8') * 255)
# Draw the mask with semi-transparency
mask_draw.bitmap((0, 0), pil_mask, fill=color + (150,)) # RGBA with transparency
# Composite the mask onto the main image
annotated_image.alpha_composite(mask_image)
# --- Draw the bounding box ---
draw.rectangle(box, outline=color, width=3)
# --- Draw the label and score ---
text = f"{label}: {score:.2f}"
text_bbox = draw.textbbox((box[0], box[1]), text, font=font)
# Create a small background for the text for better readability
draw.rectangle(text_bbox, fill=color)
draw.text((box[0], box[1]), text, fill="white", font=font)
return annotated_image
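# Note: the annotated image is returned in RGBA mode; gr.Image(type="pil")
# renders this directly, but .convert("RGB") is a safe fallback if an output
# component ever rejects the alpha channel.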
# --- 3. Main Prediction Function ---
def predict(input_image):
"""
The main function that runs inference and orchestrates the process.
This function is called by the Gradio interface.
"""
if input_image is None:
return None, "Please upload an image."
print("Processing image...")
# Preprocess the image
inputs = processor(images=input_image, return_tensors="pt").to(device)
# Perform inference
with torch.no_grad():
outputs = model(**inputs)
# Post-process the outputs to get instance segmentation results
# We specify the target image size to scale the masks and boxes correctly
result = processor.post_process_instance_segmentation(outputs, target_sizes=[input_image.size[::-1]])[0]
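    # 'result' is a dict: result["segmentation"] is an (H, W) map of segment ids
    # (with -1 where no instance was kept), and result["segments_info"] lists one
    # dict per detected instance, including its 'id', 'label_id', and 'score'.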
# Filter results by score and class
segments_info = []
segmentation = result["segmentation"].cpu().numpy()
segments_info_raw = result["segments_info"]
for segment in segments_info_raw:
score = segment["score"]
label_id = segment["label_id"]
segment_id = segment["id"]
label_name = model.config.id2label[label_id]
if score > 0.5 and label_name in TARGET_CLASSES:
# Create a binary mask for this segment
mask = (segmentation == segment_id)
# Calculate bounding box from mask
pos = np.where(mask)
if pos[0].size > 0 and pos[1].size > 0:
xmin = np.min(pos[1])
xmax = np.max(pos[1])
ymin = np.min(pos[0])
ymax = np.max(pos[0])
segments_info.append({
"score": score,
"label": label_name,
"mask": mask,
"box": [xmin, ymin, xmax, ymax]
})
print(f"Found {len(segments_info)} objects.")
# Draw the results on the image
if not segments_info:
return input_image, "No objects from the target classes were detected with high confidence."
annotated_image = draw_segmentation(input_image, segments_info)
return annotated_image, f"Successfully processed. Found {len(segments_info)} objects."
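# A minimal smoke test for predict() (a sketch; "sample.jpg" is a placeholder --
# substitute any local RGB image):
# out_img, status = predict(Image.open("sample.jpg").convert("RGB"))
# print(status)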
# --- 4. Gradio Interface Definition ---
# Load some example images
# Note: You must upload these images to your Hugging Face Space repository in a folder named 'examples'
example_paths = [
"examples/street1.jpeg",
"examples/street2.jpeg",
"examples/street3.jpeg",
"examples/street4.jpeg",
"examples/street5.jpeg",
"examples/street6.jpeg",
"examples/street7.jpeg",
"examples/street8.jpeg",
"examples/street9.jpeg",
"examples/street10.jpeg",
"examples/catsanddogs1.jpeg",
"examples/catsanddogs2.jpeg",
"examples/catsanddogs3.jpeg",
"examples/catsanddogs4.jpeg",
"examples/catsanddogs5.jpeg",
]
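# Optional guard (a sketch, beyond the original Space setup): when running
# locally without the 'examples' folder, filtering out missing paths keeps the
# demo launchable; if the list ends up empty, consider passing examples=None.
import os
example_paths = [p for p in example_paths if os.path.exists(p)]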
# Build the Gradio interface
demo = gr.Interface(
fn=predict,
inputs=gr.Image(type="pil", label="Upload Image"),
outputs=[
gr.Image(type="pil", label="Segmented Image"),
gr.Textbox(label="Status")
],
title="Advanced Instance Segmentation with Mask2Former",
description="""
Upload an image or click an example to see instance segmentation in action.
The model identifies objects from the classes: **car, bus, truck, person, dog, cat**.
Each object is highlighted with a colored mask, a bounding box, and a label.
*Note: The free CPU can be slow; please allow up to 30 seconds for processing.*
""",
examples=example_paths,
    cache_examples=True  # Pre-compute outputs for the examples so they display instantly
)
if __name__ == "__main__":
    demo.launch(share=True)