ntsc207 commited on
Commit
898ef8d
·
verified ·
1 Parent(s): 0527830

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -2
app.py CHANGED
@@ -5,7 +5,7 @@ from detect_strongsort import run_strongsort
5
  from detect import run
6
  import os
7
  import threading
8
-
9
  should_continue = True
10
 
11
 
@@ -34,7 +34,8 @@ def yolov9_inference(model_id, img_path=None, vid_path=None, tracking_algorithm
34
  if tracking_algorithm == 'deep_sort':
35
  output_path = run_deepsort(weights=model_id, imgsz=(image_size,image_size), conf_thres=conf_threshold, iou_thres=iou_threshold, source=input_path, device='0', draw_trails=True)
36
  elif tracking_algorithm == 'strong_sort':
37
- output_path = run_strongsort(yolo_weights=model_id, imgsz=(image_size,image_size), conf_thres=conf_threshold, iou_thres=iou_threshold, source=input_path, device='0', strong_sort_weights = "osnet_x0_25_msmt17.pt", hide_conf= True)
 
38
  else:
39
  output_path = run(weights=model_id, imgsz=(image_size,image_size), conf_thres=conf_threshold, iou_thres=iou_threshold, source=input_path, device='0', hide_conf= True)
40
  # Assuming output_path is the path to the output file
 
5
  from detect import run
6
  import os
7
  import threading
8
+ import torch
9
  should_continue = True
10
 
11
 
 
34
  if tracking_algorithm == 'deep_sort':
35
  output_path = run_deepsort(weights=model_id, imgsz=(image_size,image_size), conf_thres=conf_threshold, iou_thres=iou_threshold, source=input_path, device='0', draw_trails=True)
36
  elif tracking_algorithm == 'strong_sort':
37
+ device_strongsort = torch.device('cuda:0')
38
+ output_path = run_strongsort(yolo_weights=model_id, imgsz=(image_size,image_size), conf_thres=conf_threshold, iou_thres=iou_threshold, source=input_path, device=device_strongsort, strong_sort_weights = "osnet_x0_25_msmt17.pt", hide_conf= True)
39
  else:
40
  output_path = run(weights=model_id, imgsz=(image_size,image_size), conf_thres=conf_threshold, iou_thres=iou_threshold, source=input_path, device='0', hide_conf= True)
41
  # Assuming output_path is the path to the output file