ntsc207 committed on
Commit 1f32ffc · verified · 1 Parent(s): fcbb8b0

Update app.py

Files changed (1)
  1. app.py +49 -20
app.py CHANGED
@@ -1,18 +1,20 @@
- import spaces
+ #import spaces
  import gradio as gr
  from detect_deepsort import run_deepsort
  from detect_strongsort import run_strongsort
  from detect import run
  import os
  import torch
+ import seaborn as sns
  from PIL import Image
  import cv2
  import numpy as np
+ import matplotlib.pyplot as plt
  import threading

  should_continue = True

- @spaces.GPU(duration=240)
+ #@spaces.GPU(duration=120)
  def yolov9_inference(model_id, img_path=None, vid_path=None, tracking_algorithm = None):
      global should_continue
      img_extensions = ['.jpg', '.jpeg', '.png', '.gif'] # Add more image extensions if needed
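Beyond the seaborn and matplotlib imports that power the new count plots, this hunk comments out the Hugging Face spaces ZeroGPU decorator (its duration also drops from 240 to 120 in the commented-out form). One caveat when building matplotlib figures inside a Gradio callback on a headless Space: the default backend may try to talk to a GUI. A minimal sketch of the usual safeguard, not part of this commit:

    # Force the non-interactive Agg backend before pyplot is first imported,
    # so figures render off-screen with no display or GUI event loop.
    import matplotlib
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt

    fig, ax = plt.subplots()   # safe to build inside a web callback under Agg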
@@ -30,8 +32,8 @@ def yolov9_inference(model_id, img_path=None, vid_path=None, tracking_algorithm
          # Save the image
          img.save(img_path)
          input_path = img_path
-         print(input_path)
-         output_path = run(weights=model_id, imgsz=(image_size,image_size), conf_thres=conf_threshold, iou_thres=iou_threshold, source=input_path, device='0', hide_conf= True)
+
+         output_path, df, frame_counts_df = run(weights=model_id, imgsz=(image_size,image_size), conf_thres=conf_threshold, iou_thres=iou_threshold, source=input_path, device='0', hide_conf= True)
      elif vid_path is not None:
          vid_name = 'output.mp4'
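run (and, in the next hunk, run_deepsort and run_strongsort) now returns a tuple: the output path plus two DataFrames, df with total counts per label and frame_counts_df with per-frame counts. The actual construction lives in the detect modules; a hypothetical sketch of the shape the plotting code expects, assuming pandas and the column names label, count and frame used by the seaborn calls below:

    import pandas as pd

    # Hypothetical detections as (frame index, class label) pairs.
    detections = [(0, 'car'), (0, 'person'), (1, 'car'), (1, 'car')]
    records = pd.DataFrame(detections, columns=['frame', 'label'])

    # df: one row per label with its total count (drives the bar plot).
    df = records['label'].value_counts().rename_axis('label').reset_index(name='count')

    # frame_counts_df: one row per (frame, label) pair (drives the line plot).
    frame_counts_df = records.groupby(['frame', 'label']).size().reset_index(name='count')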
 
@@ -68,23 +70,48 @@ def yolov9_inference(model_id, img_path=None, vid_path=None, tracking_algorithm
          out.release()
          input_path = vid_name
          if tracking_algorithm == 'deep_sort':
-             output_path = run_deepsort(weights=model_id, imgsz=(image_size,image_size), conf_thres=conf_threshold, iou_thres=iou_threshold, source=input_path, device='0', draw_trails=True)
+             output_path, df, frame_counts_df = run_deepsort(weights=model_id, imgsz=(image_size,image_size), conf_thres=conf_threshold, iou_thres=iou_threshold, source=input_path, device='0', draw_trails=True)
          elif tracking_algorithm == 'strong_sort':
              device_strongsort = torch.device('cuda:0')
-             output_path = run_strongsort(yolo_weights=model_id, imgsz=(image_size,image_size), conf_thres=conf_threshold, iou_thres=iou_threshold, source=input_path, device=device_strongsort, strong_sort_weights = "osnet_x0_25_msmt17.pt", hide_conf= True)
+             output_path, df, frame_counts_df = run_strongsort(yolo_weights=model_id, imgsz=(image_size,image_size), conf_thres=conf_threshold, iou_thres=iou_threshold, source=input_path, device=device_strongsort, strong_sort_weights = "osnet_x0_25_msmt17.pt", hide_conf= True)
          else:
-             output_path = run(weights=model_id, imgsz=(image_size,image_size), conf_thres=conf_threshold, iou_thres=iou_threshold, source=input_path, device='0', hide_conf= True)
+             output_path, df, frame_counts_df = run(weights=model_id, imgsz=(image_size,image_size), conf_thres=conf_threshold, iou_thres=iou_threshold, source=input_path, device='0', hide_conf= True)
      # Assuming output_path is the path to the output file
      _, output_extension = os.path.splitext(output_path)
-     output_image = None
-     output_video = None
      if output_extension.lower() in img_extensions:
          output_image = output_path # Load the image file here
+         output_video = None
+         plt.style.use("ggplot")
+         fig, ax = plt.subplots(figsize=(10, 6))
+         #for label in labels:
+             #df_label = frame_counts_df[frame_counts_df['label'] == label]
+         sns.barplot(ax = ax, data = df, x = 'label', y = 'count', palette='viridis', hue = 'label', legend = False)
+
+         # Customizations
+         ax.set_title('Count of Labels', fontsize=20)
+         ax.set_xlabel('Label', fontsize=15)
+         ax.set_ylabel('Count', fontsize=15)
+         ax.tick_params(axis='x', rotation=45) # Rotate x-axis labels for better readability
+         sns.despine() # Remove the top and right spines from plot
+         #ax.legend()
+         ax.grid(True)
+         #ax.set_facecolor('#D3D3D3')
      elif output_extension.lower() in vid_extensions:
          output_video = output_path # Load the video file here
-
-     return output_image, output_video, output_path
-
+         output_image = None
+         plt.style.use("ggplot")
+         fig, ax = plt.subplots(figsize=(10, 6))
+         #for label in labels:
+             #df_label = frame_counts_df[frame_counts_df['label'] == label]
+         sns.lineplot(ax = ax, data = frame_counts_df, x = 'frame', y = 'count', hue = 'label')
+
+         ax.set_xlabel('Frame')
+         ax.set_ylabel('Count')
+         ax.set_title('Count of Labels over Frames')
+         ax.legend()
+         ax.grid(True)
+         ax.set_facecolor('#D3D3D3')
+     return output_image, output_video, fig

  def app():
      with gr.Blocks(title="YOLOv9: Real-time Object Detection", css=".gradio-container {background:lightyellow;}"):
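Two things worth flagging in this hunk. First, the old function-level output_image = None and output_video = None initialisations are removed and all assignments now happen inside the two branches, so an output file whose extension matches neither list would leave output_image, output_video and fig unbound and the return would raise UnboundLocalError. Second, each request creates a figure that is never closed, so a long-running Space accumulates entries in matplotlib's figure registry. A defensive sketch covering both, not part of the commit (build_outputs and the extension lists are illustrative):

    import matplotlib
    matplotlib.use("Agg")                      # headless-safe backend
    import matplotlib.pyplot as plt

    def build_outputs(output_extension):
        # Pre-initialise so the return below is safe even when the extension
        # matches neither the image nor the video list.
        output_image, output_video, fig = None, None, None
        if output_extension in ('.jpg', '.jpeg', '.png', '.gif'):
            fig, ax = plt.subplots(figsize=(10, 6))
            ax.set_title('Count of Labels')
        plt.close('all')   # deregister figures; the returned Figure still renders
        return output_image, output_video, fig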
@@ -97,8 +124,7 @@ def app():
                  label="Model",
                  choices=[
                      "our-converted.pt",
-                     "yolov9_e_trained-converted.pt",
-                     "our-best-converted-120ep.pt"
+                     "last_best_model.pt"
                  ],
                  value="our-converted.pt"
 
@@ -113,13 +139,17 @@ def app():
                      value="None"
                  )
                  yolov9_infer = gr.Button(value="Inference")
-                 gr.Examples(['./img_examples/Exam_1.png','./img_examples/Exam_2.png','./img_examples/Exam_3.png','./img_examples/Exam_4.png','./img_examples/Exam_5.png'], inputs=img_path,label = "Image Example", cache_examples= False)
-                 gr.Examples(['./video_examples/video_1.mp4', './video_examples/video_2.mp4','./video_examples/video_3.mp4','./video_examples/video_4.mp4','./video_examples/video_5.mp4'], inputs=vid_path, label = "Video Example", cache_examples= False)
+                 gr.Examples(['./img_examples/Exam_1.png','./img_examples/Exam_2.png','./img_examples/Exam_3.png','./img_examples/Exam_4.png','./img_examples/Exam_5.png'], inputs=img_path,label = "Image Example", cache_examples = False)
+                 gr.Examples(['./video_examples/video_1.mp4', './video_examples/video_2.mp4','./video_examples/video_3.mp4','./video_examples/video_4.mp4','./video_examples/video_5.mp4'], inputs=vid_path, label = "Video Example", cache_examples = False)
              with gr.Column():
                  gr.HTML("<h2>Output</h2>")
                  output_image = gr.Image(type="numpy",label="Output")
-                 output_video = gr.Video(label="Output")
-                 output_path = gr.Textbox(label="Output path")
+                 #df = gr.BarPlot(show_label=False, x="label", y="counts", x_title="Labels", y_title="Counts", vertical=False)
+                 output_video = gr.Video(label="Output")
+                 #frame_counts_df = gr.LinePlot(show_label=False, x="frame", y="count", x_title="Frame", y_title="Counts", color="label")
+                 fig = gr.Plot(label = "Plot")
+                 #output_path = gr.Textbox(label="Output path")
+

          yolov9_infer.click(
              fn=yolov9_inference,
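Here the "Output path" textbox gives way to a gr.Plot component, which natively renders the matplotlib figure that yolov9_inference now returns; the next hunk rewires the click handler's outputs to match. A minimal, self-contained sketch of the same pattern (component and function names are illustrative, not from the repo):

    import gradio as gr
    import matplotlib
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt

    def make_plot():
        fig, ax = plt.subplots(figsize=(10, 6))
        ax.bar(['car', 'person'], [3, 5])   # hypothetical label counts
        ax.set_title('Count of Labels')
        return fig

    with gr.Blocks() as demo:
        fig_out = gr.Plot(label="Plot")     # accepts a matplotlib Figure
        btn = gr.Button("Render")
        btn.click(fn=make_plot, outputs=fig_out)

    demo.launch()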
@@ -129,7 +159,7 @@ def app():
                  vid_path,
                  tracking_algorithm
              ],
-             outputs=[output_image, output_video, output_path],
+             outputs=[output_image, output_video, fig],
          )


@@ -155,4 +185,3 @@ with gradio_app:

  gradio_app.launch(debug=True)

-