Update app.py
app.py (CHANGED)
@@ -163 +163 @@
-    "https://placehold.co/800x600/
+    "https://placehold.co/800x600/FFF/333?text=A+busy+street+scene",  # Corrected URL (removed space)
import gradio as gr
import torch
from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation
from PIL import Image, ImageDraw, ImageFont
import numpy as np
import random

# --- 1. Global Setup & Model Loading ---

# Check for GPU availability
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Load the image processor and the model
# The model is loaded once and reused for all subsequent inference calls
print("Loading model...")
processor = AutoImageProcessor.from_pretrained("facebook/mask2former-swin-large-coco-instance")
model = Mask2FormerForUniversalSegmentation.from_pretrained("facebook/mask2former-swin-large-coco-instance").to(device)
print("Model loaded successfully.")

# Define the classes we are interested in.
# Note: "building" is not a class in the COCO-instance dataset.
TARGET_CLASSES = ['cat', 'dog', 'car', 'truck', 'bus', 'person']
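# (Aside, assumption) To see every label this checkpoint knows, you can
# inspect the id-to-name mapping on the config, e.g.:
#     print(model.config.id2label)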

# --- 2. Visualization & Drawing Logic ---

# Generate a consistent color for each class label.
# This ensures that, for example, all 'car' masks are the same color.
label_to_color = {}

def get_label_color(label):
    """Returns a random but consistent color for a given label."""
    if label not in label_to_color:
        # Generate a random color
        color = (random.randint(50, 255), random.randint(50, 200), random.randint(50, 255))
        label_to_color[label] = color
    return label_to_color[label]

def draw_segmentation(image, segments_info):
    """
    Draws masks, bounding boxes, and labels on the image.

    Args:
        image (PIL.Image.Image): The original input image.
        segments_info (list): A list of dictionaries, each containing info about a detected segment.
    """
    # Make an RGBA copy of the image to draw on
    annotated_image = image.convert("RGBA")
    draw = ImageDraw.Draw(annotated_image)

    # Load a font
    try:
        font = ImageFont.truetype("arial.ttf", size=20)
    except IOError:
        print("Arial font not found, using default font.")
        font = ImageFont.load_default()
    for segment in segments_info:
        label = segment['label']
        score = segment['score']
        mask = segment['mask']
        box = segment['box']

        # Get the color for this label
        color = get_label_color(label)

        # --- Draw the mask ---
        # Create a colored mask image
        mask_image = Image.new("RGBA", image.size)
        mask_draw = ImageDraw.Draw(mask_image)

        # Convert the boolean mask array to a PIL-drawable format;
        # we draw wherever the mask is True
        pil_mask = Image.fromarray(mask.astype('uint8') * 255)

        # Draw the mask with semi-transparency
        mask_draw.bitmap((0, 0), pil_mask, fill=color + (150,))  # RGBA with transparency

        # Composite the mask onto the main image
        annotated_image.alpha_composite(mask_image)

        # --- Draw the bounding box ---
        draw.rectangle(box, outline=color, width=3)

        # --- Draw the label and score ---
        text = f"{label}: {score:.2f}"
        text_bbox = draw.textbbox((box[0], box[1]), text, font=font)
        # Draw a small background behind the text for better readability
        draw.rectangle(text_bbox, fill=color)
        draw.text((box[0], box[1]), text, fill="white", font=font)

    return annotated_image
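# (Aside, assumption) The annotated image stays in RGBA mode; Gradio's
# gr.Image output component can display an RGBA PIL image directly, so no
# conversion back to RGB is needed before returning it.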

# --- 3. Main Prediction Function ---

def predict(input_image):
    """
    The main function that runs inference and orchestrates the process.
    This function is called by the Gradio interface.
    """
    if input_image is None:
        return None, "Please upload an image."

    print("Processing image...")
    # Preprocess the image
    inputs = processor(images=input_image, return_tensors="pt").to(device)

    # Perform inference
    with torch.no_grad():
        outputs = model(**inputs)
    # Post-process the outputs to get instance segmentation results.
    # target_sizes expects (height, width); PIL's .size is (width, height), hence [::-1].
    # The result is a dict with a "segmentation" id map (H x W) and a
    # "segments_info" list (one entry per instance, with "id", "label_id",
    # and "score" keys); `threshold` filters out low-confidence instances.
    result = processor.post_process_instance_segmentation(
        outputs, threshold=0.9, target_sizes=[input_image.size[::-1]]
    )[0]
    segmentation = result["segmentation"].cpu().numpy()

    # Filter results by class and compute a bounding box for each kept mask
    segments_info = []
    for info in result["segments_info"]:
        score = info["score"]
        label_name = model.config.id2label[info["label_id"]]

        # Keep only the classes we are interested in
        if label_name in TARGET_CLASSES:
            # Boolean mask of the pixels belonging to this instance
            mask = segmentation == info["id"]

            # Calculate bounding box from mask
            pos = np.where(mask)
            if pos[0].size > 0 and pos[1].size > 0:  # Ensure mask is not empty
                xmin, xmax = int(np.min(pos[1])), int(np.max(pos[1]))
                ymin, ymax = int(np.min(pos[0])), int(np.max(pos[0]))

                segments_info.append({
                    "score": score,
                    "label": label_name,
                    "mask": mask,
                    "box": [xmin, ymin, xmax, ymax]
                })
    print(f"Found {len(segments_info)} objects.")

    # Draw the results on the image
    if not segments_info:
        return input_image, "No objects from the target classes were detected with high confidence."

    annotated_image = draw_segmentation(input_image, segments_info)

    return annotated_image, f"Successfully processed. Found {len(segments_info)} objects."

# --- 4. Gradio Interface Definition ---

# Example images for the demo.
# Note: these are remote URLs; alternatively, upload images to the Space
# repository (e.g., in a folder named 'examples') and list local paths here.
example_paths = [
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/cats-vs-dogs.png",
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/instance-segmentation-input.jpg",
    "https://placehold.co/800x600/FFF/333?text=A+busy+street+scene",  # Corrected URL (removed space)
]

# Build the Gradio interface
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload Image"),
    outputs=[
        gr.Image(type="pil", label="Segmented Image"),
        gr.Textbox(label="Status")
    ],
    title="Advanced Instance Segmentation with Mask2Former",
    description="""
    Upload an image or click an example to see instance segmentation in action.
    The model identifies objects from the classes: **car, bus, truck, person, dog, cat**.
    Each object is highlighted with a colored mask, a bounding box, and a label.
    *Note: The free CPU can be slow; please allow up to 30 seconds for processing.*
    """,
    examples=example_paths,
    cache_examples=True  # Cache results for the examples for a faster demo
)

if __name__ == "__main__":
    demo.launch()
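A minimal local smoke test, as a sketch: it assumes the file above is saved as app.py, that gradio, torch, transformers, Pillow, and numpy are installed, and that test.jpg is a hypothetical stand-in for any local image. Importing app runs the module-level setup (model download and load, and example caching via cache_examples=True), so the first run takes a while:

# smoke_test.py -- hypothetical helper script, not part of the Space
from PIL import Image

from app import predict  # importing app loads the model at import time

img = Image.open("test.jpg")         # 'test.jpg' is a stand-in for any local image
annotated, status = predict(img)
print(status)
if annotated is not None:
    annotated.save("annotated.png")  # RGBA output saves cleanly as PNG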