John Ho committed on
Commit 05d862a · 1 Parent(s): 2a081a5

testing object detection with confidence

Files changed (1)
  1. app.py +248 -110
app.py CHANGED
@@ -5,7 +5,7 @@ import spaces
  import requests
  import copy

- from PIL import Image, ImageDraw, ImageFont
  import io
  import matplotlib.pyplot as plt
  import matplotlib.patches as patches
@@ -14,36 +14,88 @@ import random
  import numpy as np

  import subprocess
- subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)

  models = {
-     'microsoft/Florence-2-large-ft': AutoModelForCausalLM.from_pretrained('microsoft/Florence-2-large-ft', trust_remote_code=True).to("cuda").eval(),
-     'microsoft/Florence-2-large': AutoModelForCausalLM.from_pretrained('microsoft/Florence-2-large', trust_remote_code=True).to("cuda").eval(),
-     'microsoft/Florence-2-base-ft': AutoModelForCausalLM.from_pretrained('microsoft/Florence-2-base-ft', trust_remote_code=True).to("cuda").eval(),
-     'microsoft/Florence-2-base': AutoModelForCausalLM.from_pretrained('microsoft/Florence-2-base', trust_remote_code=True).to("cuda").eval(),
  }

  processors = {
-     'microsoft/Florence-2-large-ft': AutoProcessor.from_pretrained('microsoft/Florence-2-large-ft', trust_remote_code=True),
-     'microsoft/Florence-2-large': AutoProcessor.from_pretrained('microsoft/Florence-2-large', trust_remote_code=True),
-     'microsoft/Florence-2-base-ft': AutoProcessor.from_pretrained('microsoft/Florence-2-base-ft', trust_remote_code=True),
-     'microsoft/Florence-2-base': AutoProcessor.from_pretrained('microsoft/Florence-2-base', trust_remote_code=True),
  }


  DESCRIPTION = "# [Florence-2 Demo](https://huggingface.co/microsoft/Florence-2-large)"

- colormap = ['blue','orange','green','purple','brown','pink','gray','olive','cyan','red',
-     'lime','indigo','violet','aqua','magenta','coral','gold','tan','skyblue']

  def fig_to_pil(fig):
      buf = io.BytesIO()
-     fig.savefig(buf, format='png')
      buf.seek(0)
      return Image.open(buf)

  @spaces.GPU
- def run_example(task_prompt, image, text_input=None, model_id='microsoft/Florence-2-large'):
      model = models[model_id]
      processor = processors[model_id]
      if text_input is None:
@@ -58,38 +110,89 @@ def run_example(task_prompt, image, text_input=None, model_id='microsoft/Florenc
          early_stopping=False,
          do_sample=False,
          num_beams=3,
-         output_scores= True
      )
      generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
      parsed_answer = processor.post_process_generation(
-         generated_text,
          task=task_prompt,
-         image_size=(image.width, image.height)
      )
      return parsed_answer

  def plot_bbox(image, data):
      fig, ax = plt.subplots()
      ax.imshow(image)
-     for bbox, label in zip(data['bboxes'], data['labels']):
          x1, y1, x2, y2 = bbox
-         rect = patches.Rectangle((x1, y1), x2-x1, y2-y1, linewidth=1, edgecolor='r', facecolor='none')
          ax.add_patch(rect)
-         plt.text(x1, y1, label, color='white', fontsize=8, bbox=dict(facecolor='red', alpha=0.5))
-     ax.axis('off')
      return fig

  def draw_polygons(image, prediction, fill_mask=False):

      draw = ImageDraw.Draw(image)
      scale = 1
-     for polygons, label in zip(prediction['polygons'], prediction['labels']):
          color = random.choice(colormap)
          fill_color = random.choice(colormap) if fill_mask else None
          for _polygon in polygons:
              _polygon = np.array(_polygon).reshape(-1, 2)
              if len(_polygon) < 3:
-                 print('Invalid polygon:', _polygon)
                  continue
              _polygon = (_polygon * scale).reshape(-1).tolist()
              if fill_mask:
@@ -99,157 +202,176 @@ def draw_polygons(image, prediction, fill_mask=False):
              draw.text((_polygon[0] + 8, _polygon[1] + 2), label, fill=color)
      return image

  def convert_to_od_format(data):
-     bboxes = data.get('bboxes', [])
-     labels = data.get('bboxes_labels', [])
-     od_results = {
-         'bboxes': bboxes,
-         'labels': labels
-     }
      return od_results

  def draw_ocr_bboxes(image, prediction):
      scale = 1
      draw = ImageDraw.Draw(image)
-     bboxes, labels = prediction['quad_boxes'], prediction['labels']
      for box, label in zip(bboxes, labels):
          color = random.choice(colormap)
          new_box = (np.array(box) * scale).tolist()
          draw.polygon(new_box, width=3, outline=color)
-         draw.text((new_box[0]+8, new_box[1]+2),
-                   "{}".format(label),
-                   align="right",
-                   fill=color)
      return image

- def process_image(image, task_prompt, text_input=None, model_id='microsoft/Florence-2-large'):
      image = Image.fromarray(image) # Convert NumPy array to PIL Image
-     if task_prompt == 'Caption':
-         task_prompt = '<CAPTION>'
          results = run_example(task_prompt, image, model_id=model_id)
          return results, None
-     elif task_prompt == 'Detailed Caption':
-         task_prompt = '<DETAILED_CAPTION>'
          results = run_example(task_prompt, image, model_id=model_id)
          return results, None
-     elif task_prompt == 'More Detailed Caption':
-         task_prompt = '<MORE_DETAILED_CAPTION>'
          results = run_example(task_prompt, image, model_id=model_id)
          return results, None
-     elif task_prompt == 'Caption + Grounding':
-         task_prompt = '<CAPTION>'
          results = run_example(task_prompt, image, model_id=model_id)
          text_input = results[task_prompt]
-         task_prompt = '<CAPTION_TO_PHRASE_GROUNDING>'
          results = run_example(task_prompt, image, text_input, model_id)
-         results['<CAPTION>'] = text_input
-         fig = plot_bbox(image, results['<CAPTION_TO_PHRASE_GROUNDING>'])
          return results, fig_to_pil(fig)
-     elif task_prompt == 'Detailed Caption + Grounding':
-         task_prompt = '<DETAILED_CAPTION>'
          results = run_example(task_prompt, image, model_id=model_id)
          text_input = results[task_prompt]
-         task_prompt = '<CAPTION_TO_PHRASE_GROUNDING>'
          results = run_example(task_prompt, image, text_input, model_id)
-         results['<DETAILED_CAPTION>'] = text_input
-         fig = plot_bbox(image, results['<CAPTION_TO_PHRASE_GROUNDING>'])
          return results, fig_to_pil(fig)
-     elif task_prompt == 'More Detailed Caption + Grounding':
-         task_prompt = '<MORE_DETAILED_CAPTION>'
          results = run_example(task_prompt, image, model_id=model_id)
          text_input = results[task_prompt]
-         task_prompt = '<CAPTION_TO_PHRASE_GROUNDING>'
          results = run_example(task_prompt, image, text_input, model_id)
-         results['<MORE_DETAILED_CAPTION>'] = text_input
-         fig = plot_bbox(image, results['<CAPTION_TO_PHRASE_GROUNDING>'])
          return results, fig_to_pil(fig)
-     elif task_prompt == 'Object Detection':
-         task_prompt = '<OD>'
-         results = run_example(task_prompt, image, model_id=model_id)
-         fig = plot_bbox(image, results['<OD>'])
          return results, fig_to_pil(fig)
-     elif task_prompt == 'Dense Region Caption':
-         task_prompt = '<DENSE_REGION_CAPTION>'
          results = run_example(task_prompt, image, model_id=model_id)
-         fig = plot_bbox(image, results['<DENSE_REGION_CAPTION>'])
          return results, fig_to_pil(fig)
-     elif task_prompt == 'Region Proposal':
-         task_prompt = '<REGION_PROPOSAL>'
          results = run_example(task_prompt, image, model_id=model_id)
-         fig = plot_bbox(image, results['<REGION_PROPOSAL>'])
          return results, fig_to_pil(fig)
-     elif task_prompt == 'Caption to Phrase Grounding':
-         task_prompt = '<CAPTION_TO_PHRASE_GROUNDING>'
          results = run_example(task_prompt, image, text_input, model_id)
-         fig = plot_bbox(image, results['<CAPTION_TO_PHRASE_GROUNDING>'])
          return results, fig_to_pil(fig)
-     elif task_prompt == 'Referring Expression Segmentation':
-         task_prompt = '<REFERRING_EXPRESSION_SEGMENTATION>'
          results = run_example(task_prompt, image, text_input, model_id)
          output_image = copy.deepcopy(image)
-         output_image = draw_polygons(output_image, results['<REFERRING_EXPRESSION_SEGMENTATION>'], fill_mask=True)
          return results, output_image
-     elif task_prompt == 'Region to Segmentation':
-         task_prompt = '<REGION_TO_SEGMENTATION>'
          results = run_example(task_prompt, image, text_input, model_id)
          output_image = copy.deepcopy(image)
-         output_image = draw_polygons(output_image, results['<REGION_TO_SEGMENTATION>'], fill_mask=True)
          return results, output_image
-     elif task_prompt == 'Open Vocabulary Detection':
-         task_prompt = '<OPEN_VOCABULARY_DETECTION>'
          results = run_example(task_prompt, image, text_input, model_id)
-         bbox_results = convert_to_od_format(results['<OPEN_VOCABULARY_DETECTION>'])
          fig = plot_bbox(image, bbox_results)
          return results, fig_to_pil(fig)
-     elif task_prompt == 'Region to Category':
-         task_prompt = '<REGION_TO_CATEGORY>'
          results = run_example(task_prompt, image, text_input, model_id)
          return results, None
-     elif task_prompt == 'Region to Description':
-         task_prompt = '<REGION_TO_DESCRIPTION>'
          results = run_example(task_prompt, image, text_input, model_id)
          return results, None
-     elif task_prompt == 'OCR':
-         task_prompt = '<OCR>'
          results = run_example(task_prompt, image, model_id=model_id)
          return results, None
-     elif task_prompt == 'OCR with Region':
-         task_prompt = '<OCR_WITH_REGION>'
          results = run_example(task_prompt, image, model_id=model_id)
          output_image = copy.deepcopy(image)
-         output_image = draw_ocr_bboxes(output_image, results['<OCR_WITH_REGION>'])
          return results, output_image
      else:
          return "", None # Return empty string and None for unknown task prompts

  css = """
  #output {
-     height: 500px;
-     overflow: auto;
-     border: 1px solid #ccc;
  }
  """


- single_task_list =[
-     'Caption', 'Detailed Caption', 'More Detailed Caption', 'Object Detection',
-     'Dense Region Caption', 'Region Proposal', 'Caption to Phrase Grounding',
-     'Referring Expression Segmentation', 'Region to Segmentation',
-     'Open Vocabulary Detection', 'Region to Category', 'Region to Description',
-     'OCR', 'OCR with Region'
  ]

- cascased_task_list =[
-     'Caption + Grounding', 'Detailed Caption + Grounding', 'More Detailed Caption + Grounding'
  ]


  def update_task_dropdown(choice):
-     if choice == 'Cascased task':
-         return gr.Dropdown(choices=cascased_task_list, value='Caption + Grounding')
      else:
-         return gr.Dropdown(choices=single_task_list, value='Caption')
-


  with gr.Blocks(css=css) as demo:
@@ -258,10 +380,22 @@ with gr.Blocks(css=css) as demo:
      with gr.Row():
          with gr.Column():
              input_img = gr.Image(label="Input Picture")
-             model_selector = gr.Dropdown(choices=list(models.keys()), label="Model", value='microsoft/Florence-2-large')
-             task_type = gr.Radio(choices=['Single task', 'Cascased task'], label='Task type selector', value='Single task')
-             task_prompt = gr.Dropdown(choices=single_task_list, label="Task Prompt", value="Caption")
-             task_type.change(fn=update_task_dropdown, inputs=task_type, outputs=task_prompt)
              text_input = gr.Textbox(label="Text Input (optional)")
              submit_btn = gr.Button(value="Submit")
          with gr.Column():
@@ -270,16 +404,20 @@ with gr.Blocks(css=css) as demo:

      gr.Examples(
          examples=[
-             ["image1.jpg", 'Object Detection'],
-             ["image2.jpg", 'OCR with Region']
          ],
          inputs=[input_img, task_prompt],
          outputs=[output_text, output_img],
          fn=process_image,
          cache_examples=True,
-         label='Try examples'
      )

-     submit_btn.click(process_image, [input_img, task_prompt, text_input, model_selector], [output_text, output_img])

  demo.launch(debug=True)
 
  import requests
  import copy

+ from PIL import Image, ImageDraw, ImageFont
  import io
  import matplotlib.pyplot as plt
  import matplotlib.patches as patches

  import numpy as np

  import subprocess
+
+ subprocess.run(
+     "pip install flash-attn --no-build-isolation",
+     env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
+     shell=True,
+ )

  models = {
+     "microsoft/Florence-2-large-ft": AutoModelForCausalLM.from_pretrained(
+         "microsoft/Florence-2-large-ft", trust_remote_code=True
+     )
+     .to("cuda")
+     .eval(),
+     "microsoft/Florence-2-large": AutoModelForCausalLM.from_pretrained(
+         "microsoft/Florence-2-large", trust_remote_code=True
+     )
+     .to("cuda")
+     .eval(),
+     "microsoft/Florence-2-base-ft": AutoModelForCausalLM.from_pretrained(
+         "microsoft/Florence-2-base-ft", trust_remote_code=True
+     )
+     .to("cuda")
+     .eval(),
+     "microsoft/Florence-2-base": AutoModelForCausalLM.from_pretrained(
+         "microsoft/Florence-2-base", trust_remote_code=True
+     )
+     .to("cuda")
+     .eval(),
  }

  processors = {
+     "microsoft/Florence-2-large-ft": AutoProcessor.from_pretrained(
+         "microsoft/Florence-2-large-ft", trust_remote_code=True
+     ),
+     "microsoft/Florence-2-large": AutoProcessor.from_pretrained(
+         "microsoft/Florence-2-large", trust_remote_code=True
+     ),
+     "microsoft/Florence-2-base-ft": AutoProcessor.from_pretrained(
+         "microsoft/Florence-2-base-ft", trust_remote_code=True
+     ),
+     "microsoft/Florence-2-base": AutoProcessor.from_pretrained(
+         "microsoft/Florence-2-base", trust_remote_code=True
+     ),
  }


  DESCRIPTION = "# [Florence-2 Demo](https://huggingface.co/microsoft/Florence-2-large)"

+ colormap = [
+     "blue",
+     "orange",
+     "green",
+     "purple",
+     "brown",
+     "pink",
+     "gray",
+     "olive",
+     "cyan",
+     "red",
+     "lime",
+     "indigo",
+     "violet",
+     "aqua",
+     "magenta",
+     "coral",
+     "gold",
+     "tan",
+     "skyblue",
+ ]
+

  def fig_to_pil(fig):
      buf = io.BytesIO()
+     fig.savefig(buf, format="png")
      buf.seek(0)
      return Image.open(buf)

+
  @spaces.GPU
+ def run_example(
+     task_prompt, image, text_input=None, model_id="microsoft/Florence-2-large"
+ ):
      model = models[model_id]
      processor = processors[model_id]
      if text_input is None:

          early_stopping=False,
          do_sample=False,
          num_beams=3,
      )
      generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
      parsed_answer = processor.post_process_generation(
+         generated_text, task=task_prompt, image_size=(image.width, image.height)
+     )
+     return parsed_answer
+
+
+ def run_example_with_score(
+     task_prompt, image, text_input=None, model_id="microsoft/Florence-2-large"
+ ):
+     model = models[model_id]
+     processor = processors[model_id]
+     if text_input is None:
+         prompt = task_prompt
+     else:
+         prompt = task_prompt + text_input
+     inputs = processor(text=prompt, images=image, return_tensors="pt").to("cuda")
+     generated_ids = model.generate(
+         input_ids=inputs["input_ids"],
+         pixel_values=inputs["pixel_values"],
+         max_new_tokens=1024,
+         num_beams=3,
+         return_dict_in_generate=True,
+         output_scores=True,
+     )
+     generated_text = processor.batch_decode(
+         generated_ids.sequences, skip_special_tokens=False
+     )[0]
+
+     prediction, scores, beam_indices = (
+         generated_ids.sequences,
+         generated_ids.scores,
+         generated_ids.beam_indices,
+     )
+     transition_beam_scores = model.compute_transition_scores(
+         sequences=prediction,
+         scores=scores,
+         beam_indices=beam_indices,
+     )
+
+     parsed_answer = processor.post_process_generation(
+         sequence=generated_ids.sequences[0],
+         transition_beam_score=transition_beam_scores[0],
          task=task_prompt,
+         image_size=(image.width, image.height),
      )
+
      return parsed_answer

+
  def plot_bbox(image, data):
      fig, ax = plt.subplots()
      ax.imshow(image)
+     for bbox, label in zip(data["bboxes"], data["labels"]):
          x1, y1, x2, y2 = bbox
+         rect = patches.Rectangle(
+             (x1, y1), x2 - x1, y2 - y1, linewidth=1, edgecolor="r", facecolor="none"
+         )
          ax.add_patch(rect)
+         plt.text(
+             x1,
+             y1,
+             label,
+             color="white",
+             fontsize=8,
+             bbox=dict(facecolor="red", alpha=0.5),
+         )
+     ax.axis("off")
      return fig

+
  def draw_polygons(image, prediction, fill_mask=False):

      draw = ImageDraw.Draw(image)
      scale = 1
+     for polygons, label in zip(prediction["polygons"], prediction["labels"]):
          color = random.choice(colormap)
          fill_color = random.choice(colormap) if fill_mask else None
          for _polygon in polygons:
              _polygon = np.array(_polygon).reshape(-1, 2)
              if len(_polygon) < 3:
+                 print("Invalid polygon:", _polygon)
                  continue
              _polygon = (_polygon * scale).reshape(-1).tolist()
              if fill_mask:

              draw.text((_polygon[0] + 8, _polygon[1] + 2), label, fill=color)
      return image

+
  def convert_to_od_format(data):
+     bboxes = data.get("bboxes", [])
+     labels = data.get("bboxes_labels", [])
+     od_results = {"bboxes": bboxes, "labels": labels}
      return od_results

+
  def draw_ocr_bboxes(image, prediction):
      scale = 1
      draw = ImageDraw.Draw(image)
+     bboxes, labels = prediction["quad_boxes"], prediction["labels"]
      for box, label in zip(bboxes, labels):
          color = random.choice(colormap)
          new_box = (np.array(box) * scale).tolist()
          draw.polygon(new_box, width=3, outline=color)
+         draw.text(
+             (new_box[0] + 8, new_box[1] + 2),
+             "{}".format(label),
+             align="right",
+             fill=color,
+         )
      return image

+
+ def process_image(
+     image, task_prompt, text_input=None, model_id="microsoft/Florence-2-large"
+ ):
      image = Image.fromarray(image) # Convert NumPy array to PIL Image
+     if task_prompt == "Caption":
+         task_prompt = "<CAPTION>"
          results = run_example(task_prompt, image, model_id=model_id)
          return results, None
+     elif task_prompt == "Detailed Caption":
+         task_prompt = "<DETAILED_CAPTION>"
          results = run_example(task_prompt, image, model_id=model_id)
          return results, None
+     elif task_prompt == "More Detailed Caption":
+         task_prompt = "<MORE_DETAILED_CAPTION>"
          results = run_example(task_prompt, image, model_id=model_id)
          return results, None
+     elif task_prompt == "Caption + Grounding":
+         task_prompt = "<CAPTION>"
          results = run_example(task_prompt, image, model_id=model_id)
          text_input = results[task_prompt]
+         task_prompt = "<CAPTION_TO_PHRASE_GROUNDING>"
          results = run_example(task_prompt, image, text_input, model_id)
+         results["<CAPTION>"] = text_input
+         fig = plot_bbox(image, results["<CAPTION_TO_PHRASE_GROUNDING>"])
          return results, fig_to_pil(fig)
+     elif task_prompt == "Detailed Caption + Grounding":
+         task_prompt = "<DETAILED_CAPTION>"
          results = run_example(task_prompt, image, model_id=model_id)
          text_input = results[task_prompt]
+         task_prompt = "<CAPTION_TO_PHRASE_GROUNDING>"
          results = run_example(task_prompt, image, text_input, model_id)
+         results["<DETAILED_CAPTION>"] = text_input
+         fig = plot_bbox(image, results["<CAPTION_TO_PHRASE_GROUNDING>"])
          return results, fig_to_pil(fig)
+     elif task_prompt == "More Detailed Caption + Grounding":
+         task_prompt = "<MORE_DETAILED_CAPTION>"
          results = run_example(task_prompt, image, model_id=model_id)
          text_input = results[task_prompt]
+         task_prompt = "<CAPTION_TO_PHRASE_GROUNDING>"
          results = run_example(task_prompt, image, text_input, model_id)
+         results["<MORE_DETAILED_CAPTION>"] = text_input
+         fig = plot_bbox(image, results["<CAPTION_TO_PHRASE_GROUNDING>"])
          return results, fig_to_pil(fig)
+     elif task_prompt == "Object Detection":
+         task_prompt = "<OD>"
+         results = run_example_with_score(task_prompt, image, model_id=model_id)
+         fig = plot_bbox(image, results["<OD>"])
          return results, fig_to_pil(fig)
+     elif task_prompt == "Dense Region Caption":
+         task_prompt = "<DENSE_REGION_CAPTION>"
          results = run_example(task_prompt, image, model_id=model_id)
+         fig = plot_bbox(image, results["<DENSE_REGION_CAPTION>"])
          return results, fig_to_pil(fig)
+     elif task_prompt == "Region Proposal":
+         task_prompt = "<REGION_PROPOSAL>"
          results = run_example(task_prompt, image, model_id=model_id)
+         fig = plot_bbox(image, results["<REGION_PROPOSAL>"])
          return results, fig_to_pil(fig)
+     elif task_prompt == "Caption to Phrase Grounding":
+         task_prompt = "<CAPTION_TO_PHRASE_GROUNDING>"
          results = run_example(task_prompt, image, text_input, model_id)
+         fig = plot_bbox(image, results["<CAPTION_TO_PHRASE_GROUNDING>"])
          return results, fig_to_pil(fig)
+     elif task_prompt == "Referring Expression Segmentation":
+         task_prompt = "<REFERRING_EXPRESSION_SEGMENTATION>"
          results = run_example(task_prompt, image, text_input, model_id)
          output_image = copy.deepcopy(image)
+         output_image = draw_polygons(
+             output_image, results["<REFERRING_EXPRESSION_SEGMENTATION>"], fill_mask=True
+         )
          return results, output_image
+     elif task_prompt == "Region to Segmentation":
+         task_prompt = "<REGION_TO_SEGMENTATION>"
          results = run_example(task_prompt, image, text_input, model_id)
          output_image = copy.deepcopy(image)
+         output_image = draw_polygons(
+             output_image, results["<REGION_TO_SEGMENTATION>"], fill_mask=True
+         )
          return results, output_image
+     elif task_prompt == "Open Vocabulary Detection":
+         task_prompt = "<OPEN_VOCABULARY_DETECTION>"
          results = run_example(task_prompt, image, text_input, model_id)
+         bbox_results = convert_to_od_format(results["<OPEN_VOCABULARY_DETECTION>"])
          fig = plot_bbox(image, bbox_results)
          return results, fig_to_pil(fig)
+     elif task_prompt == "Region to Category":
+         task_prompt = "<REGION_TO_CATEGORY>"
          results = run_example(task_prompt, image, text_input, model_id)
          return results, None
+     elif task_prompt == "Region to Description":
+         task_prompt = "<REGION_TO_DESCRIPTION>"
          results = run_example(task_prompt, image, text_input, model_id)
          return results, None
+     elif task_prompt == "OCR":
+         task_prompt = "<OCR>"
          results = run_example(task_prompt, image, model_id=model_id)
          return results, None
+     elif task_prompt == "OCR with Region":
+         task_prompt = "<OCR_WITH_REGION>"
          results = run_example(task_prompt, image, model_id=model_id)
          output_image = copy.deepcopy(image)
+         output_image = draw_ocr_bboxes(output_image, results["<OCR_WITH_REGION>"])
          return results, output_image
      else:
          return "", None # Return empty string and None for unknown task prompts

+
  css = """
  #output {
+     height: 500px;
+     overflow: auto;
+     border: 1px solid #ccc;
  }
  """


+ single_task_list = [
+     "Caption",
+     "Detailed Caption",
+     "More Detailed Caption",
+     "Object Detection",
+     "Dense Region Caption",
+     "Region Proposal",
+     "Caption to Phrase Grounding",
+     "Referring Expression Segmentation",
+     "Region to Segmentation",
+     "Open Vocabulary Detection",
+     "Region to Category",
+     "Region to Description",
+     "OCR",
+     "OCR with Region",
  ]

+ cascased_task_list = [
+     "Caption + Grounding",
+     "Detailed Caption + Grounding",
+     "More Detailed Caption + Grounding",
  ]


  def update_task_dropdown(choice):
+     if choice == "Cascased task":
+         return gr.Dropdown(choices=cascased_task_list, value="Caption + Grounding")
      else:
+         return gr.Dropdown(choices=single_task_list, value="Caption")


  with gr.Blocks(css=css) as demo:

      with gr.Row():
          with gr.Column():
              input_img = gr.Image(label="Input Picture")
+             model_selector = gr.Dropdown(
+                 choices=list(models.keys()),
+                 label="Model",
+                 value="microsoft/Florence-2-large",
+             )
+             task_type = gr.Radio(
+                 choices=["Single task", "Cascased task"],
+                 label="Task type selector",
+                 value="Single task",
+             )
+             task_prompt = gr.Dropdown(
+                 choices=single_task_list, label="Task Prompt", value="Caption"
+             )
+             task_type.change(
+                 fn=update_task_dropdown, inputs=task_type, outputs=task_prompt
+             )
              text_input = gr.Textbox(label="Text Input (optional)")
              submit_btn = gr.Button(value="Submit")
          with gr.Column():

      gr.Examples(
          examples=[
+             ["image1.jpg", "Object Detection"],
+             ["image2.jpg", "OCR with Region"],
          ],
          inputs=[input_img, task_prompt],
          outputs=[output_text, output_img],
          fn=process_image,
          cache_examples=True,
+         label="Try examples",
      )

+     submit_btn.click(
+         process_image,
+         [input_img, task_prompt, text_input, model_selector],
+         [output_text, output_img],
+     )

  demo.launch(debug=True)
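
Note on the confidence scores introduced above: run_example_with_score asks generate for beam-search scores and converts them with compute_transition_scores before handing them to post_process_generation. A minimal, self-contained sketch of how those transition scores relate to probabilities is shown below; it assumes the standard transformers beam-search output, and the names generated, token_confidences, and sequence_confidence are illustrative, not identifiers from app.py.

import numpy as np

# Assuming `model` and `inputs` are a loaded Florence-2 model and processed inputs,
# generate with beam search and keep the per-step scores.
generated = model.generate(
    input_ids=inputs["input_ids"],
    pixel_values=inputs["pixel_values"],
    max_new_tokens=1024,
    num_beams=3,
    return_dict_in_generate=True,
    output_scores=True,
)

# Per-step log-probabilities of the tokens actually chosen by each returned beam.
transition_scores = model.compute_transition_scores(
    sequences=generated.sequences,
    scores=generated.scores,
    beam_indices=generated.beam_indices,
)

top_beam = transition_scores[0].cpu().numpy()
token_confidences = np.exp(top_beam)                  # one probability per generated token
sequence_confidence = float(np.exp(top_beam.sum()))   # joint probability of the whole sequence

How the Florence-2 processor turns the per-token values into per-box confidences is handled inside its remote post_process_generation code and is not reproduced here.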