Brightsun10 committed
Commit 320875d · verified · 1 Parent(s): 7273a0c

Update app.py

Files changed (1)
  1. app.py +187 -188
app.py CHANGED
@@ -1,188 +1,187 @@
 import gradio as gr
 import torch
 from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation
 from PIL import Image, ImageDraw, ImageFont
 import numpy as np
 import random

 # --- 1. Global Setup & Model Loading ---

 # Check for GPU availability
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

 # Load the image processor and the model
 # The model is loaded once and cached for all subsequent inference calls
 print("Loading model...")
 processor = AutoImageProcessor.from_pretrained("facebook/mask2former-swin-large-coco-instance")
 model = Mask2FormerForUniversalSegmentation.from_pretrained("facebook/mask2former-swin-large-coco-instance").to(device)
 print("Model loaded successfully.")

 # Define the classes we are interested in.
 # Note: "building" is not a class in the COCO-instance dataset.
 TARGET_CLASSES = ['cat', 'dog', 'car', 'truck', 'bus', 'person']

 # --- 2. Visualization & Drawing Logic ---

 # Generate a consistent color for each class label
 # This ensures that, for example, all 'car' masks are the same color.
 label_to_color = {}

 def get_label_color(label):
     """Returns a random, consistent color for a given label."""
     if label not in label_to_color:
         # Generate a random color
         color = (random.randint(50, 255), random.randint(50, 200), random.randint(50, 255))
         label_to_color[label] = color
     return label_to_color[label]

 def draw_segmentation(image, segments_info):
     """
     Draws masks, bounding boxes, and labels on the image.

     Args:
         image (PIL.Image.Image): The original input image.
         segments_info (list): A list of dictionaries, each containing info about a detected segment.
     """
     # Make a copy of the image to draw on
     annotated_image = image.convert("RGBA")
     draw = ImageDraw.Draw(annotated_image)

     # Load a font
     try:
         font = ImageFont.truetype("arial.ttf", size=20)
     except IOError:
         print("Arial font not found, using default font.")
         font = ImageFont.load_default()

     for segment in segments_info:
         label = segment['label']
         score = segment['score']
         mask = segment['mask']
         box = segment['box']

         # Get the color for this label
         color = get_label_color(label)

         # --- Draw the mask ---
         # Create a colored mask image
         mask_image = Image.new("RGBA", image.size)
         mask_draw = ImageDraw.Draw(mask_image)

         # Convert mask tensor to a PIL-drawable format
         # The mask tensor is a boolean tensor, we draw where it's True
         pil_mask = Image.fromarray(mask.astype('uint8') * 255)

         # Draw the mask with semi-transparency
         mask_draw.bitmap((0, 0), pil_mask, fill=color + (150,)) # RGBA with transparency

         # Composite the mask onto the main image
         annotated_image.alpha_composite(mask_image)

         # --- Draw the bounding box ---
         draw.rectangle(box, outline=color, width=3)

         # --- Draw the label and score ---
         text = f"{label}: {score:.2f}"
         text_bbox = draw.textbbox((box[0], box[1]), text, font=font)
         # Create a small background for the text for better readability
         draw.rectangle(text_bbox, fill=color)
         draw.text((box[0], box[1]), text, fill="white", font=font)

     return annotated_image

 # --- 3. Main Prediction Function ---

 def predict(input_image):
     """
     The main function that runs inference and orchestrates the process.
     This function is called by the Gradio interface.
     """
     if input_image is None:
         return None, "Please upload an image."

     print("Processing image...")
     # Preprocess the image
     inputs = processor(images=input_image, return_tensors="pt").to(device)

     # Perform inference
     with torch.no_grad():
         outputs = model(**inputs)

     # Post-process the outputs to get instance segmentation results
     # We specify the target image size to scale the masks and boxes correctly
     result = processor.post_process_instance_segmentation(outputs, target_sizes=[input_image.size[::-1]])[0]

     # Filter results by score and class
     segments_info = []
     scores = result['scores'].cpu().numpy()
     labels = result['labels'].cpu().numpy()
     masks = result['masks'].cpu().numpy()

     # Get bounding boxes from masks
     for i in range(len(scores)):
         score = scores[i]
         label_id = labels[i]
         label_name = model.config.id2label[label_id]

         # Filter out low-confidence scores and unwanted classes
         if score > 0.9 and label_name in TARGET_CLASSES:
             mask = masks[i]

             # Calculate bounding box from mask
             pos = np.where(mask)
             if pos[0].size > 0 and pos[1].size > 0: # Ensure mask is not empty
                 xmin = np.min(pos[1])
                 xmax = np.max(pos[1])
                 ymin = np.min(pos[0])
                 ymax = np.max(pos[0])

                 segments_info.append({
                     "score": score,
                     "label": label_name,
                     "mask": mask,
                     "box": [xmin, ymin, xmax, ymax]
                 })

     print(f"Found {len(segments_info)} objects.")

     # Draw the results on the image
     if not segments_info:
         return input_image, "No objects from the target classes were detected with high confidence."

     annotated_image = draw_segmentation(input_image, segments_info)

     return annotated_image, f"Successfully processed. Found {len(segments_info)} objects."

 # --- 4. Gradio Interface Definition ---

 # Load some example images
 # Note: You must upload these images to your Hugging Face Space repository in a folder named 'examples'
 example_paths = [
     "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/cats-vs-dogs.png",
     "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/instance-segmentation-input.jpg",
-    "https://placehold.co/800x600/ FFF/333?text=A+busy+street+scene",
+    "https://placehold.co/800x600/FFF/333?text=A+busy+street+scene", # Corrected URL (removed space)
 ]


 # Build the Gradio interface
 demo = gr.Interface(
     fn=predict,
     inputs=gr.Image(type="pil", label="Upload Image"),
     outputs=[
         gr.Image(type="pil", label="Segmented Image"),
         gr.Textbox(label="Status")
     ],
     title="Advanced Instance Segmentation with Mask2Former",
     description="""
     Upload an image or click an example to see instance segmentation in action.
     The model identifies objects from the classes: **car, bus, truck, person, dog, cat**.
     Each object is highlighted with a colored mask, a bounding box, and a label.
     *Note: The free CPU can be slow; please allow up to 30 seconds for processing.*
     """,
     examples=example_paths,
     cache_examples=True # Cache results for examples for faster demo
 )

 if __name__ == "__main__":
     demo.launch()
-
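
Note: in recent transformers releases, post_process_instance_segmentation returns a dict with "segmentation" and "segments_info" keys rather than the flat 'scores'/'labels'/'masks' arrays that predict() indexes above, so the filtering loop may need an adapter. A minimal sketch, assuming the installed version supports the return_binary_maps flag; the scores/labels/masks names simply mirror the app's own variables:

    # Hedged sketch: adapt the {"segmentation", "segments_info"} result to the
    # flat arrays predict() expects. With return_binary_maps=True, "segmentation"
    # holds one binary mask per detected instance, shaped (num_instances, H, W).
    result = processor.post_process_instance_segmentation(
        outputs,
        target_sizes=[input_image.size[::-1]],
        return_binary_maps=True,
    )[0]
    scores = [info["score"] for info in result["segments_info"]]
    labels = [info["label_id"] for info in result["segments_info"]]
    masks = result["segmentation"].cpu().numpy()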