import gradio as gr
import torch
from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation
from PIL import Image, ImageDraw, ImageFont
import numpy as np
import random

# --- 1. Global Setup & Model Loading ---

# Check for GPU availability
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Load the image processor and the model
# The model is loaded once and cached for all subsequent inference calls
print("Loading model...")
processor = AutoImageProcessor.from_pretrained("facebook/mask2former-swin-tiny-coco-instance")
model = Mask2FormerForUniversalSegmentation.from_pretrained("facebook/mask2former-swin-tiny-coco-instance").to(device)
print("Model loaded successfully.")

# Define the classes we are interested in.
# Note: "building" is not a class in the COCO-instance dataset.
TARGET_CLASSES = ['cat', 'dog', 'car', 'truck', 'bus', 'person']
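# To see every label this checkpoint can predict, you can inspect the model's
# label map (the same transformers config attribute used when filtering results below):
#   print(model.config.id2label)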

# --- 2. Visualization & Drawing Logic ---

# Generate a consistent color for each class label
# This ensures that, for example, all 'car' masks are the same color.
label_to_color = {}

def get_label_color(label):
    """Returns a random, consistent color for a given label."""
    if label not in label_to_color:
        # Generate a random color
        color = (random.randint(50, 255), random.randint(50, 200), random.randint(50, 255))
        label_to_color[label] = color
    return label_to_color[label]
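# For example, get_label_color("car") returns the same tuple on every call within
# a session, so all 'car' masks and boxes are drawn in one color.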

def draw_segmentation(image, segments_info):
    """
    Draws masks, bounding boxes, and labels on the image.
    
    Args:
        image (PIL.Image.Image): The original input image.
        segments_info (list): A list of dictionaries, each containing info about a detected segment.
    """
    # Make a copy of the image to draw on
    annotated_image = image.convert("RGBA")
    draw = ImageDraw.Draw(annotated_image)

    # Load a font
    try:
        font = ImageFont.truetype("arial.ttf", size=20)
    except IOError:
        print("Arial font not found, using default font.")
        font = ImageFont.load_default()

    for segment in segments_info:
        label = segment['label']
        score = segment['score']
        mask = segment['mask']
        box = segment['box']

        # Get the color for this label
        color = get_label_color(label)
        
        # --- Draw the mask ---
        # Create a colored mask image
        mask_image = Image.new("RGBA", image.size)
        mask_draw = ImageDraw.Draw(mask_image)
        
        # Convert the boolean mask (a NumPy array) into a PIL image usable as a
        # drawing bitmap; pixels where the mask is True become 255
        pil_mask = Image.fromarray(mask.astype('uint8') * 255)
        
        # Draw the mask with semi-transparency
        mask_draw.bitmap((0, 0), pil_mask, fill=color + (150,)) # RGBA with transparency
        
        # Composite the mask onto the main image
        annotated_image.alpha_composite(mask_image)
        
        # --- Draw the bounding box ---
        draw.rectangle(box, outline=color, width=3)

        # --- Draw the label and score ---
        text = f"{label}: {score:.2f}"
        text_bbox = draw.textbbox((box[0], box[1]), text, font=font)
        # Create a small background for the text for better readability
        draw.rectangle(text_bbox, fill=color)
        draw.text((box[0], box[1]), text, fill="white", font=font)

    return annotated_image
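# For reference, each entry in segments_info is expected to look like the dicts
# built in predict() below, e.g.:
#   {"score": 0.97, "label": "car", "mask": <HxW bool ndarray>, "box": [x0, y0, x1, y1]}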

# --- 3. Main Prediction Function ---

def predict(input_image):
    """
    The main function that runs inference and orchestrates the process.
    This function is called by the Gradio interface.
    """
    if input_image is None:
        return None, "Please upload an image."

    print("Processing image...")
    # Preprocess the image
    inputs = processor(images=input_image, return_tensors="pt").to(device)

    # Perform inference
    with torch.no_grad():
        outputs = model(**inputs)

    # Post-process the outputs to get instance segmentation results
    # We specify the target image size to scale the masks and boxes correctly
    result = processor.post_process_instance_segmentation(outputs, target_sizes=[input_image.size[::-1]])[0]
    
    # Filter results by score and class
    segments_info = []
    segmentation = result["segmentation"].cpu().numpy()
    segments_info_raw = result["segments_info"]

    for segment in segments_info_raw:
        score = segment["score"]
        label_id = segment["label_id"]
        segment_id = segment["id"]
        label_name = model.config.id2label[label_id]

        if score > 0.5 and label_name in TARGET_CLASSES:
            # Create a binary mask for this segment
            mask = (segmentation == segment_id)

            # Calculate bounding box from mask
            pos = np.where(mask)
            if pos[0].size > 0 and pos[1].size > 0:
                xmin = np.min(pos[1])
                xmax = np.max(pos[1])
                ymin = np.min(pos[0])
                ymax = np.max(pos[0])

                segments_info.append({
                    "score": score,
                    "label": label_name,
                    "mask": mask,
                    "box": [xmin, ymin, xmax, ymax]
                })

    print(f"Found {len(segments_info)} objects.")
    
    # Draw the results on the image
    if not segments_info:
        return input_image, "No objects from the target classes were detected with high confidence."
        
    annotated_image = draw_segmentation(input_image, segments_info)
    
    return annotated_image, f"Successfully processed. Found {len(segments_info)} objects."

# --- 4. Gradio Interface Definition ---

# Load some example images
# Note: You must upload these images to your Hugging Face Space repository in a folder named 'examples'
example_paths = [
    "examples/street1.jpeg",
    "examples/street2.jpeg",
    "examples/street3.jpeg",
    "examples/street4.jpeg",
    "examples/street5.jpeg",
    "examples/street6.jpeg",
    "examples/street7.jpeg",
    "examples/street8.jpeg",
    "examples/street9.jpeg",
    "examples/street10.jpeg",
    "examples/catsanddogs1.jpeg",
    "examples/catsanddogs2.jpeg",
    "examples/catsanddogs3.jpeg",
    "examples/catsanddogs4.jpeg",
    "examples/catsanddogs5.jpeg",
]
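# Optional defensive sketch (assumes the 'examples/' layout above): keep only the
# files that actually exist, so the Space does not fail at startup if an example
# image is missing.
#   import os
#   example_paths = [p for p in example_paths if os.path.exists(p)]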


# Build the Gradio interface
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload Image"),
    outputs=[
        gr.Image(type="pil", label="Segmented Image"),
        gr.Textbox(label="Status")
    ],
    title="Advanced Instance Segmentation with Mask2Former",
    description="""
    Upload an image or click an example to see instance segmentation in action. 
    The model identifies objects from the classes: **car, bus, truck, person, dog, cat**. 
    Each object is highlighted with a colored mask, a bounding box, and a label.
    *Note: The free CPU can be slow; please allow up to 30 seconds for processing.*
    """,
    examples=example_paths,
    cache_examples=True  # Pre-compute example outputs so the demo responds faster
)

if __name__ == "__main__":
    demo.launch(share=True)
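
# A minimal local smoke test (a sketch, assuming one of the example images above
# exists on disk); it exercises predict() without launching the Gradio UI:
#   from PIL import Image
#   img = Image.open("examples/street1.jpeg").convert("RGB")
#   annotated, status = predict(img)
#   print(status)
#   annotated.save("annotated.png")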