sergiopaniego HF Staff commited on
Commit
8d73cc6
·
verified ·
1 Parent(s): 29d9240

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +1 -35
app.py CHANGED
@@ -4,37 +4,16 @@ import torch
4
 
5
  from PIL import Image
6
  import requests
7
- #from transformers import DetrImageProcessor
8
- #from transformers import DetrForObjectDetection
9
  from transformers import pipeline
10
  import matplotlib.pyplot as plt
11
  import io
12
 
13
-
14
- #processor = DetrImageProcessor.from_pretrained("sergiopaniego/detr-resnet-50-dc5-fashionpedia-finetuned")
15
- #model = DetrForObjectDetection.from_pretrained("sergiopaniego/detr-resnet-50-dc5-fashionpedia-finetuned")
16
  model_pipeline = pipeline("object-detection", model="sergiopaniego/detr-resnet-50-dc5-fashionpedia-finetuned")
17
 
18
 
19
  COLORS = [[0.000, 0.447, 0.741], [0.850, 0.325, 0.098], [0.929, 0.694, 0.125],
20
  [0.494, 0.184, 0.556], [0.466, 0.674, 0.188], [0.301, 0.745, 0.933]]
21
 
22
- '''
23
- def get_output_figure(pil_img, scores, labels, boxes, threshold):
24
- plt.figure(figsize=(16, 10))
25
- plt.imshow(pil_img)
26
- ax = plt.gca()
27
- colors = COLORS * 100
28
- for score, label, (xmin, ymin, xmax, ymax), c in zip(scores.tolist(), labels.tolist(), boxes.tolist(), colors):
29
- if score > threshold:
30
- ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin, fill=False, color=c, linewidth=3))
31
- text = f'{model.config.id2label[label]}: {score:0.2f}'
32
- ax.text(xmin, ymin, text, fontsize=15,
33
- bbox=dict(facecolor='yellow', alpha=0.5))
34
- plt.axis('off')
35
-
36
- return plt.gcf()
37
- '''
38
 
39
  def get_output_figure(pil_img, results, threshold):
40
  plt.figure(figsize=(16, 10))
@@ -58,23 +37,10 @@ def get_output_figure(pil_img, results, threshold):
58
 
59
  @spaces.GPU
60
  def detect(image):
61
- #encoding = processor(image, return_tensors='pt')
62
- #print(encoding.keys())
63
-
64
- #with torch.no_grad():
65
- # outputs = model(**encoding)
66
-
67
-
68
  results = model_pipeline(image)
69
  print(results)
70
 
71
- #width, height = image.size
72
- #postprocessed_outputs = processor.post_process_object_detection(outputs, target_sizes=[(height, width)], threshold=0.5)
73
- #results = postprocessed_outputs[0]
74
-
75
-
76
- #output_figure = get_output_figure(image, results['scores'], results['labels'], results['boxes'], threshold=0.5)
77
- output_figure = get_output_figure(image, results, threshold=0.5)
78
 
79
  buf = io.BytesIO()
80
  output_figure.savefig(buf, bbox_inches='tight')
 
4
 
5
  from PIL import Image
6
  import requests
 
 
7
  from transformers import pipeline
8
  import matplotlib.pyplot as plt
9
  import io
10
 
 
 
 
11
  model_pipeline = pipeline("object-detection", model="sergiopaniego/detr-resnet-50-dc5-fashionpedia-finetuned")
12
 
13
 
14
  COLORS = [[0.000, 0.447, 0.741], [0.850, 0.325, 0.098], [0.929, 0.694, 0.125],
15
  [0.494, 0.184, 0.556], [0.466, 0.674, 0.188], [0.301, 0.745, 0.933]]
16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
  def get_output_figure(pil_img, results, threshold):
19
  plt.figure(figsize=(16, 10))
 
37
 
38
  @spaces.GPU
39
  def detect(image):
 
 
 
 
 
 
 
40
  results = model_pipeline(image)
41
  print(results)
42
 
43
+ output_figure = get_output_figure(image, results, threshold=0.7)
 
 
 
 
 
 
44
 
45
  buf = io.BytesIO()
46
  output_figure.savefig(buf, bbox_inches='tight')