ntsc207 commited on
Commit
fd60a59
·
verified ·
1 Parent(s): 912adfe

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -21
app.py CHANGED
@@ -33,7 +33,7 @@ def yolov9_inference(model_id, img_path=None, vid_path=None, tracking_algorithm
33
  img.save(img_path)
34
  input_path = img_path
35
 
36
- output_path, df, frame_counts_df = run(weights=model_id, imgsz=(image_size,image_size), conf_thres=conf_threshold, iou_thres=iou_threshold, source=input_path, device='0', hide_conf= True, hide_labels = True)
37
  elif vid_path is not None:
38
  vid_name = 'output.mp4'
39
 
@@ -70,12 +70,12 @@ def yolov9_inference(model_id, img_path=None, vid_path=None, tracking_algorithm
70
  out.release()
71
  input_path = vid_name
72
  if tracking_algorithm == 'deep_sort':
73
- output_path, df, frame_counts_df = run_deepsort(weights=model_id, imgsz=(image_size,image_size), conf_thres=conf_threshold, iou_thres=iou_threshold, source=input_path, device='0', draw_trails=True)
74
  elif tracking_algorithm == 'strong_sort':
75
- device_strongsort = torch.device('cuda:0')
76
- output_path, df, frame_counts_df = run_strongsort(yolo_weights=model_id, imgsz=(image_size,image_size), conf_thres=conf_threshold, iou_thres=iou_threshold, source=input_path, device=device_strongsort, strong_sort_weights = "osnet_x0_25_msmt17.pt", hide_conf= True, hide_labels = True)
77
  else:
78
- output_path, df, frame_counts_df = run(weights=model_id, imgsz=(image_size,image_size), conf_thres=conf_threshold, iou_thres=iou_threshold, source=input_path, device='0', hide_conf= True, hide_labels = True)
79
  # Assuming output_path is the path to the output file
80
  _, output_extension = os.path.splitext(output_path)
81
  palette = {"Bus": "red", "Bike": "blue", "Car": "green", "Pedestrian": "yellow", "Truck": "purple"}
@@ -86,16 +86,25 @@ def yolov9_inference(model_id, img_path=None, vid_path=None, tracking_algorithm
86
  fig, ax = plt.subplots(figsize=(10, 6))
87
  #for label in labels:
88
  #df_label = frame_counts_df[frame_counts_df['label'] == label]
89
- sns.barplot(ax = ax, data = df, x = 'label', y = 'count', palette=palette, hue = 'label', legend = False)
 
90
 
91
  # Customizations
92
- ax.set_title('Number of Objects', fontsize=20)
93
- ax.set_xlabel('Object Class', fontsize=15)
94
- ax.set_ylabel('Object Count', fontsize=15)
95
- ax.tick_params(axis='x', rotation=45) # Rotate x-axis labels for better readability
 
96
  sns.despine() # Remove the top and right spines from plot
97
- #ax.legend()
98
- ax.grid(True)
 
 
 
 
 
 
 
99
  #ax.set_facecolor('#D3D3D3')
100
  elif output_extension.lower() in vid_extensions:
101
  output_video = output_path # Load the video file here
@@ -104,14 +113,25 @@ def yolov9_inference(model_id, img_path=None, vid_path=None, tracking_algorithm
104
  fig, ax = plt.subplots(figsize=(10, 6))
105
  #for label in labels:
106
  #df_label = frame_counts_df[frame_counts_df['label'] == label]
107
- sns.lineplot(ax = ax, data = frame_counts_df[::30], x = 'frame', y = 'count', hue = 'label', palette = palette)
108
-
109
- ax.set_xlabel('Second')
110
- ax.set_ylabel('Object Count')
111
- ax.set_title('Number of Objects over Seconds')
112
- ax.legend()
113
- ax.grid(True)
114
- ax.set_facecolor('#D3D3D3')
 
 
 
 
 
 
 
 
 
 
 
115
  return output_image, output_video, fig
116
 
117
  def app():
@@ -167,6 +187,7 @@ def app():
167
  outputs=[output_image, output_video, fig],
168
  )
169
 
 
170
  gradio_app = gr.Blocks()
171
  with gradio_app:
172
  gr.HTML(
@@ -188,4 +209,3 @@ with gradio_app:
188
  app()
189
 
190
  gradio_app.launch(debug=True)
191
-
 
33
  img.save(img_path)
34
  input_path = img_path
35
 
36
+ output_path, df, frame_counts_df = run(weights=model_id, imgsz=(image_size,image_size), conf_thres=conf_threshold, iou_thres=iou_threshold, source=input_path, device='cpu', hide_conf= True)
37
  elif vid_path is not None:
38
  vid_name = 'output.mp4'
39
 
 
70
  out.release()
71
  input_path = vid_name
72
  if tracking_algorithm == 'deep_sort':
73
+ output_path, df, frame_counts_df = run_deepsort(weights=model_id, imgsz=(image_size,image_size), conf_thres=conf_threshold, iou_thres=iou_threshold, source=input_path, device='cpu', draw_trails=True)
74
  elif tracking_algorithm == 'strong_sort':
75
+ device_strongsort = torch.device('cpu')
76
+ output_path, df, frame_counts_df = run_strongsort(yolo_weights=model_id, imgsz=(image_size,image_size), conf_thres=conf_threshold, iou_thres=iou_threshold, source=input_path, device=device_strongsort, strong_sort_weights = "osnet_x0_25_msmt17.pt", hide_conf= True)
77
  else:
78
+ output_path, df, frame_counts_df = run(weights=model_id, imgsz=(image_size,image_size), conf_thres=conf_threshold, iou_thres=iou_threshold, source=input_path, device='cpu', hide_conf= True)
79
  # Assuming output_path is the path to the output file
80
  _, output_extension = os.path.splitext(output_path)
81
  palette = {"Bus": "red", "Bike": "blue", "Car": "green", "Pedestrian": "yellow", "Truck": "purple"}
 
86
  fig, ax = plt.subplots(figsize=(10, 6))
87
  #for label in labels:
88
  #df_label = frame_counts_df[frame_counts_df['label'] == label]
89
+
90
+ sns.barplot(ax=ax, data=df, x='label', y='count', palette=palette, hue='label')
91
 
92
  # Customizations
93
+ ax.set_title('Count of Labels', fontsize=20, pad=20) # Increase padding for the title
94
+ ax.set_xlabel('Label', fontsize=16) # Increase font size
95
+ ax.set_ylabel('Count', fontsize=16) # Increase font size
96
+ ax.tick_params(axis='x', rotation=45, labelsize=12) # Increase label size and rotate x-axis labels for better readability
97
+ ax.tick_params(axis='y', labelsize=12) # Increase label size for y-axis
98
  sns.despine() # Remove the top and right spines from plot
99
+
100
+ # Add grid but make it lighter and put it behind bars
101
+ ax.grid(True, linestyle=':', linewidth=0.6, color='gray', alpha=0.6)
102
+ ax.set_axisbelow(True)
103
+
104
+ # Add a legend with a smaller font size
105
+ ax.legend(fontsize=10)
106
+
107
+ plt.tight_layout() # Ensure the entire plot fits into the figure area
108
  #ax.set_facecolor('#D3D3D3')
109
  elif output_extension.lower() in vid_extensions:
110
  output_video = output_path # Load the video file here
 
113
  fig, ax = plt.subplots(figsize=(10, 6))
114
  #for label in labels:
115
  #df_label = frame_counts_df[frame_counts_df['label'] == label]
116
+ sns.lineplot(ax = ax, data = frame_counts_df[::4], x = 'frame', y = 'count', hue = 'label', palette=palette, linewidth=2.5)
117
+
118
+ ax.set_title('Count of Labels over Frames', fontsize=20, pad=20) # Increase padding for the title
119
+ ax.set_xlabel('Frame', fontsize=16) # Increase font size
120
+ ax.set_ylabel('Count', fontsize=16) # Increase font size
121
+ ax.tick_params(axis='x', labelsize=12) # Increase label size for x-axis
122
+ ax.tick_params(axis='y', labelsize=12) # Increase label size for y-axis
123
+
124
+ # Add grid but make it lighter and put it behind bars
125
+ ax.grid(True, linestyle=':', linewidth=0.6, color='gray', alpha=0.6)
126
+ ax.set_axisbelow(True)
127
+
128
+ # Change the background color to a lighter shade
129
+ ax.set_facecolor('#F0F0F0')
130
+
131
+ # Add a legend with a smaller font size
132
+ ax.legend(fontsize=10)
133
+
134
+ plt.tight_layout() # Ensure the entire
135
  return output_image, output_video, fig
136
 
137
  def app():
 
187
  outputs=[output_image, output_video, fig],
188
  )
189
 
190
+
191
  gradio_app = gr.Blocks()
192
  with gradio_app:
193
  gr.HTML(
 
209
  app()
210
 
211
  gradio_app.launch(debug=True)