ntsc207 commited on
Commit
e2da6d0
·
verified ·
1 Parent(s): 1729f65

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -9
app.py CHANGED
@@ -34,7 +34,7 @@ def yolov9_inference(model_id, img_path=None, vid_path=None, tracking_algorithm
34
  img.save(img_path)
35
  input_path = img_path
36
 
37
- output_path, df, frame_counts_df = run(weights=model_id, imgsz=(image_size,image_size), conf_thres=conf_threshold, iou_thres=iou_threshold, source=input_path, device='0', hide_conf= True)
38
  elif vid_path is not None:
39
  vid_name = 'output.mp4'
40
 
@@ -74,9 +74,9 @@ def yolov9_inference(model_id, img_path=None, vid_path=None, tracking_algorithm
74
  output_path, df, frame_counts_df = run_deepsort(weights=model_id, imgsz=(image_size,image_size), conf_thres=conf_threshold, iou_thres=iou_threshold, source=input_path, device='0', draw_trails=True)
75
  elif tracking_algorithm == 'strong_sort':
76
  device_strongsort = torch.device('cuda:0')
77
- output_path, df, frame_counts_df = run_strongsort(yolo_weights=model_id, imgsz=(image_size,image_size), conf_thres=conf_threshold, iou_thres=iou_threshold, source=input_path, device=device_strongsort, strong_sort_weights = "osnet_x0_25_msmt17.pt", hide_conf= True)
78
  else:
79
- output_path, df, frame_counts_df = run(weights=model_id, imgsz=(image_size,image_size), conf_thres=conf_threshold, iou_thres=iou_threshold, source=input_path, device='0', hide_conf= True)
80
  # Assuming output_path is the path to the output file
81
  _, output_extension = os.path.splitext(output_path)
82
  palette = {"Bus": "red", "Bike": "blue", "Car": "green", "Pedestrian": "yellow", "Truck": "purple"}
@@ -91,9 +91,9 @@ def yolov9_inference(model_id, img_path=None, vid_path=None, tracking_algorithm
91
  sns.barplot(ax=ax, data=df, x='label', y='count', palette=palette, hue='label')
92
 
93
  # Customizations
94
- ax.set_title('Count of Labels', fontsize=20, pad=20) # Increase padding for the title
95
- ax.set_xlabel('Label', fontsize=16) # Increase font size
96
- ax.set_ylabel('Count', fontsize=16) # Increase font size
97
  ax.tick_params(axis='x', rotation=45, labelsize=12) # Increase label size and rotate x-axis labels for better readability
98
  ax.tick_params(axis='y', labelsize=12) # Increase label size for y-axis
99
  sns.despine() # Remove the top and right spines from plot
@@ -116,9 +116,9 @@ def yolov9_inference(model_id, img_path=None, vid_path=None, tracking_algorithm
116
  #df_label = frame_counts_df[frame_counts_df['label'] == label]
117
  sns.lineplot(ax = ax, data = frame_counts_df, x = 'frame', y = 'count', hue = 'label', palette=palette,linewidth=2.5)
118
 
119
- ax.set_title('Count of Labels over Frames', fontsize=20, pad=20) # Increase padding for the title
120
- ax.set_xlabel('Frame', fontsize=16) # Increase font size
121
- ax.set_ylabel('Count', fontsize=16) # Increase font size
122
  ax.tick_params(axis='x', labelsize=12) # Increase label size for x-axis
123
  ax.tick_params(axis='y', labelsize=12) # Increase label size for y-axis
124
 
 
34
  img.save(img_path)
35
  input_path = img_path
36
 
37
+ output_path, df, frame_counts_df = run(weights=model_id, imgsz=(image_size,image_size), conf_thres=conf_threshold, iou_thres=iou_threshold, source=input_path, device='0', hide_conf= True, hide_labels = True)
38
  elif vid_path is not None:
39
  vid_name = 'output.mp4'
40
 
 
74
  output_path, df, frame_counts_df = run_deepsort(weights=model_id, imgsz=(image_size,image_size), conf_thres=conf_threshold, iou_thres=iou_threshold, source=input_path, device='0', draw_trails=True)
75
  elif tracking_algorithm == 'strong_sort':
76
  device_strongsort = torch.device('cuda:0')
77
+ output_path, df, frame_counts_df = run_strongsort(yolo_weights=model_id, imgsz=(image_size,image_size), conf_thres=conf_threshold, iou_thres=iou_threshold, source=input_path, device=device_strongsort, strong_sort_weights = "osnet_x0_25_msmt17.pt", hide_conf= True,hide_labels = True)
78
  else:
79
+ output_path, df, frame_counts_df = run(weights=model_id, imgsz=(image_size,image_size), conf_thres=conf_threshold, iou_thres=iou_threshold, source=input_path, device='0', hide_conf= True, hide_labels = True)
80
  # Assuming output_path is the path to the output file
81
  _, output_extension = os.path.splitext(output_path)
82
  palette = {"Bus": "red", "Bike": "blue", "Car": "green", "Pedestrian": "yellow", "Truck": "purple"}
 
91
  sns.barplot(ax=ax, data=df, x='label', y='count', palette=palette, hue='label')
92
 
93
  # Customizations
94
+ ax.set_title('Number of Objects', fontsize=20, pad=20) # Increase padding for the title
95
+ ax.set_xlabel('Object Class', fontsize=16) # Increase font size
96
+ ax.set_ylabel('Object Count', fontsize=16) # Increase font size
97
  ax.tick_params(axis='x', rotation=45, labelsize=12) # Increase label size and rotate x-axis labels for better readability
98
  ax.tick_params(axis='y', labelsize=12) # Increase label size for y-axis
99
  sns.despine() # Remove the top and right spines from plot
 
116
  #df_label = frame_counts_df[frame_counts_df['label'] == label]
117
  sns.lineplot(ax = ax, data = frame_counts_df, x = 'frame', y = 'count', hue = 'label', palette=palette,linewidth=2.5)
118
 
119
+ ax.set_title('Number of Objects over Seconds', fontsize=20, pad=20) # Increase padding for the title
120
+ ax.set_xlabel('Second', fontsize=16) # Increase font size
121
+ ax.set_ylabel('Object Count', fontsize=16) # Increase font size
122
  ax.tick_params(axis='x', labelsize=12) # Increase label size for x-axis
123
  ax.tick_params(axis='y', labelsize=12) # Increase label size for y-axis
124