ntsc207 commited on
Commit
b03f92b
·
verified ·
1 Parent(s): a355dad

Update detect_deepsort.py

Browse files
Files changed (1) hide show
  1. detect_deepsort.py +63 -53
detect_deepsort.py CHANGED
@@ -6,8 +6,10 @@ from pathlib import Path
6
  import math
7
  import torch
8
  import numpy as np
 
9
  from deep_sort_pytorch.utils.parser import get_config
10
  from deep_sort_pytorch.deep_sort import DeepSort
 
11
  from collections import deque
12
  FILE = Path(__file__).resolve()
13
  ROOT = FILE.parents[0] # YOLO root directory
@@ -22,50 +24,59 @@ from utils.general import (LOGGER, Profile, check_file, check_img_size, check_im
22
  from utils.plots import Annotator, colors, save_one_box
23
  from utils.torch_utils import select_device, smart_inference_mode
24
 
25
- # def initialize_deepsort():
26
- # # Create the Deep SORT configuration object and load settings from the YAML file
27
- # cfg_deep = get_config()
28
- # cfg_deep.merge_from_file("deep_sort_pytorch/configs/deep_sort.yaml")
29
-
30
- # # Initialize the DeepSort tracker
31
- # deepsort = DeepSort(cfg_deep.DEEPSORT.REID_CKPT,
32
- # max_dist=cfg_deep.DEEPSORT.MAX_DIST,
33
- # # min_confidence parameter sets the minimum tracking confidence required for an object detection to be considered in the tracking process
34
- # min_confidence=cfg_deep.DEEPSORT.MIN_CONFIDENCE,
35
- # #nms_max_overlap specifies the maximum allowed overlap between bounding boxes during non-maximum suppression (NMS)
36
- # nms_max_overlap=cfg_deep.DEEPSORT.NMS_MAX_OVERLAP,
37
- # #max_iou_distance parameter defines the maximum intersection-over-union (IoU) distance between object detections
38
- # max_iou_distance=cfg_deep.DEEPSORT.MAX_IOU_DISTANCE,
39
- # # Max_age: If an object's tracking ID is lost (i.e., the object is no longer detected), this parameter determines how many frames the tracker should wait before assigning a new id
40
- # max_age=cfg_deep.DEEPSORT.MAX_AGE, n_init=cfg_deep.DEEPSORT.N_INIT,
41
- # #nn_budget: It sets the budget for the nearest-neighbor search.
42
- # nn_budget=cfg_deep.DEEPSORT.NN_BUDGET,
43
- # use_cuda=False
44
- # )
45
-
46
- # return deepsort
47
-
48
- #deepsort = initialize_deepsort()
49
  data_deque = {}
50
  def classNames():
51
  cocoClassNames = ["Bus", "Bike", "Car", "Pedestrian", "Truck"
52
  ]
53
  return cocoClassNames
54
  className = classNames()
55
-
 
 
 
 
 
 
56
  def colorLabels(classid):
57
  if classid == 0: #Bus
58
  color = (0, 0, 255)
59
  elif classid == 1: #Bike 250, 247, 0
60
- color = (250, 247, 0)
61
  elif classid == 2: #Car
62
  color = (0, 255, 10)
63
  elif classid == 3: #Pedestrian
64
- color = (0,148,255)
65
  else: #Truck
66
  color = (235,0,255)
67
  return tuple(color)
68
 
 
 
 
69
  def draw_boxes(frame, bbox_xyxy, draw_trails, identities=None, categories=None, offset=(0,0)):
70
  height, width, _ = frame.shape
71
  for key in list(data_deque):
@@ -146,27 +157,6 @@ def run_deepsort(
146
  save_dir = increment_path(Path(project) / name, exist_ok=exist_ok) # increment run
147
  save_dir.mkdir(parents=True, exist_ok=True) # make dir
148
 
149
- #Initalize deepsort
150
- # Create the Deep SORT configuration object and load settings from the YAML file
151
- cfg_deep = get_config()
152
- cfg_deep.merge_from_file("deep_sort_pytorch/configs/deep_sort.yaml")
153
-
154
- # Initialize the DeepSort tracker
155
- deepsort = DeepSort(cfg_deep.DEEPSORT.REID_CKPT,
156
- max_dist=cfg_deep.DEEPSORT.MAX_DIST,
157
- # min_confidence parameter sets the minimum tracking confidence required for an object detection to be considered in the tracking process
158
- min_confidence=cfg_deep.DEEPSORT.MIN_CONFIDENCE,
159
- #nms_max_overlap specifies the maximum allowed overlap between bounding boxes during non-maximum suppression (NMS)
160
- nms_max_overlap=cfg_deep.DEEPSORT.NMS_MAX_OVERLAP,
161
- #max_iou_distance parameter defines the maximum intersection-over-union (IoU) distance between object detections
162
- max_iou_distance=cfg_deep.DEEPSORT.MAX_IOU_DISTANCE,
163
- # Max_age: If an object's tracking ID is lost (i.e., the object is no longer detected), this parameter determines how many frames the tracker should wait before assigning a new id
164
- max_age=cfg_deep.DEEPSORT.MAX_AGE, n_init=cfg_deep.DEEPSORT.N_INIT,
165
- #nn_budget: It sets the budget for the nearest-neighbor search.
166
- nn_budget=cfg_deep.DEEPSORT.NN_BUDGET,
167
- use_cuda=True
168
- )
169
-
170
  # Load model
171
  device = select_device(device)
172
  model = DetectMultiBackend(weights, device=device, dnn=dnn, data=data, fp16=half)
@@ -188,6 +178,7 @@ def run_deepsort(
188
  # Run inference
189
  model.warmup(imgsz=(1 if pt or model.triton else bs, 3, *imgsz)) # warmup
190
  seen, windows, dt = 0, [], (Profile(), Profile(), Profile())
 
191
  for path, im, im0s, vid_cap, s in dataset:
192
  with dt[0]:
193
  im = torch.from_numpy(im).to(model.device)
@@ -200,16 +191,15 @@ def run_deepsort(
200
  with dt[1]:
201
  visualize = increment_path(save_dir / Path(path).stem, mkdir=True) if visualize else False
202
  pred = model(im, augment=augment, visualize=visualize)
203
- # pred = pred[0][1]
204
 
205
  # NMS
206
  with dt[2]:
207
- pred = pred[0][1] if isinstance(pred[0], list) else pred[0] # single model or ensemble
208
  pred = non_max_suppression(pred, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det)
209
 
210
  # Second-stage classifier (optional)
211
  # pred = utils.general.apply_classifier(pred, classifier_model, im, im0s)
212
-
213
  # Process predictions
214
  for i, det in enumerate(pred): # per image
215
  seen += 1
@@ -233,6 +223,7 @@ def run_deepsort(
233
  for c in det[:, 5].unique():
234
  n = (det[:, 5] == c).sum() # detections per class
235
  s += f"{n} {names[int(c)]}{'s' * (n > 1)}, " # add to string
 
236
  xywh_bboxs = []
237
  confs = []
238
  oids = []
@@ -253,7 +244,7 @@ def run_deepsort(
253
  classNameInt = int(cls)
254
  oids.append(classNameInt)
255
  xywhs = torch.tensor(xywh_bboxs)
256
- confss = torch.tensor(confs)
257
  outputs = deepsort.update(xywhs, confss, oids, ims)
258
  if len(outputs) > 0:
259
  bbox_xyxy = outputs[:, :4]
@@ -287,9 +278,28 @@ def run_deepsort(
287
 
288
  # Print time (inference-only)
289
  LOGGER.info(f"{s}{'' if len(det) else '(no detections), '}{dt[1].dt * 1E3:.1f}ms")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
290
  if update:
291
  strip_optimizer(weights[0]) # update model (to fix SourceChangeWarning)
292
- return save_path
 
293
 
294
  def parse_opt():
295
  parser = argparse.ArgumentParser()
 
6
  import math
7
  import torch
8
  import numpy as np
9
+ import re
10
  from deep_sort_pytorch.utils.parser import get_config
11
  from deep_sort_pytorch.deep_sort import DeepSort
12
+ import pandas as pd
13
  from collections import deque
14
  FILE = Path(__file__).resolve()
15
  ROOT = FILE.parents[0] # YOLO root directory
 
24
  from utils.plots import Annotator, colors, save_one_box
25
  from utils.torch_utils import select_device, smart_inference_mode
26
 
27
+ def initialize_deepsort():
28
+ # Create the Deep SORT configuration object and load settings from the YAML file
29
+ cfg_deep = get_config()
30
+ cfg_deep.merge_from_file("deep_sort_pytorch/configs/deep_sort.yaml")
31
+
32
+ # Initialize the DeepSort tracker
33
+ deepsort = DeepSort(cfg_deep.DEEPSORT.REID_CKPT,
34
+ max_dist=cfg_deep.DEEPSORT.MAX_DIST,
35
+ # min_confidence parameter sets the minimum tracking confidence required for an object detection to be considered in the tracking process
36
+ min_confidence=cfg_deep.DEEPSORT.MIN_CONFIDENCE,
37
+ #nms_max_overlap specifies the maximum allowed overlap between bounding boxes during non-maximum suppression (NMS)
38
+ nms_max_overlap=cfg_deep.DEEPSORT.NMS_MAX_OVERLAP,
39
+ #max_iou_distance parameter defines the maximum intersection-over-union (IoU) distance between object detections
40
+ max_iou_distance=cfg_deep.DEEPSORT.MAX_IOU_DISTANCE,
41
+ # Max_age: If an object's tracking ID is lost (i.e., the object is no longer detected), this parameter determines how many frames the tracker should wait before assigning a new id
42
+ max_age=cfg_deep.DEEPSORT.MAX_AGE, n_init=cfg_deep.DEEPSORT.N_INIT,
43
+ #nn_budget: It sets the budget for the nearest-neighbor search.
44
+ nn_budget=cfg_deep.DEEPSORT.NN_BUDGET,
45
+ use_cuda=False
46
+ )
47
+
48
+ return deepsort
49
+
50
+ deepsort = initialize_deepsort()
51
  data_deque = {}
52
  def classNames():
53
  cocoClassNames = ["Bus", "Bike", "Car", "Pedestrian", "Truck"
54
  ]
55
  return cocoClassNames
56
  className = classNames()
57
+ # def convert_to_int(x):
58
+ # if isinstance(x, str):
59
+ # # Extract numeric value from tensor string using regular expressions
60
+ # match = re.match(r'tensor\((\d+)\)', x)
61
+ # if match:
62
+ # return int(match.group(1))
63
+ # return x
64
  def colorLabels(classid):
65
  if classid == 0: #Bus
66
  color = (0, 0, 255)
67
  elif classid == 1: #Bike 250, 247, 0
68
+ color = (0,148,255)
69
  elif classid == 2: #Car
70
  color = (0, 255, 10)
71
  elif classid == 3: #Pedestrian
72
+ color = (250,247,0)
73
  else: #Truck
74
  color = (235,0,255)
75
  return tuple(color)
76
 
77
+ def convert_to_int(tensor):
78
+ return tensor.type(torch.int16).item()
79
+
80
  def draw_boxes(frame, bbox_xyxy, draw_trails, identities=None, categories=None, offset=(0,0)):
81
  height, width, _ = frame.shape
82
  for key in list(data_deque):
 
157
  save_dir = increment_path(Path(project) / name, exist_ok=exist_ok) # increment run
158
  save_dir.mkdir(parents=True, exist_ok=True) # make dir
159
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
160
  # Load model
161
  device = select_device(device)
162
  model = DetectMultiBackend(weights, device=device, dnn=dnn, data=data, fp16=half)
 
178
  # Run inference
179
  model.warmup(imgsz=(1 if pt or model.triton else bs, 3, *imgsz)) # warmup
180
  seen, windows, dt = 0, [], (Profile(), Profile(), Profile())
181
+ frame_counts = []
182
  for path, im, im0s, vid_cap, s in dataset:
183
  with dt[0]:
184
  im = torch.from_numpy(im).to(model.device)
 
191
  with dt[1]:
192
  visualize = increment_path(save_dir / Path(path).stem, mkdir=True) if visualize else False
193
  pred = model(im, augment=augment, visualize=visualize)
194
+ pred = pred[0][1]
195
 
196
  # NMS
197
  with dt[2]:
 
198
  pred = non_max_suppression(pred, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det)
199
 
200
  # Second-stage classifier (optional)
201
  # pred = utils.general.apply_classifier(pred, classifier_model, im, im0s)
202
+ counts = {}
203
  # Process predictions
204
  for i, det in enumerate(pred): # per image
205
  seen += 1
 
223
  for c in det[:, 5].unique():
224
  n = (det[:, 5] == c).sum() # detections per class
225
  s += f"{n} {names[int(c)]}{'s' * (n > 1)}, " # add to string
226
+ counts[names[int(c)]] = n
227
  xywh_bboxs = []
228
  confs = []
229
  oids = []
 
244
  classNameInt = int(cls)
245
  oids.append(classNameInt)
246
  xywhs = torch.tensor(xywh_bboxs)
247
+ confss = torch.tensor(confs)
248
  outputs = deepsort.update(xywhs, confss, oids, ims)
249
  if len(outputs) > 0:
250
  bbox_xyxy = outputs[:, :4]
 
278
 
279
  # Print time (inference-only)
280
  LOGGER.info(f"{s}{'' if len(det) else '(no detections), '}{dt[1].dt * 1E3:.1f}ms")
281
+ frame_counts.append((frame, counts)) # Append the counts for each frame
282
+ transformed_data = []
283
+
284
+ # Iterate over frame_counts and transform each entry into a row in the DataFrame
285
+ for frame, counts_dict in frame_counts:
286
+ for label, count in counts_dict.items():
287
+ transformed_data.append((frame, label.capitalize(), count))
288
+
289
+ # Create a DataFrame from the transformed data
290
+ df = pd.DataFrame(transformed_data, columns=['frame', 'label', 'count'])
291
+
292
+ # Convert count column from tensors to integers
293
+ df['count'] = df['count'].apply(convert_to_int)
294
+
295
+ counts_df = pd.DataFrame(counts.items(), columns=['label', 'count'])
296
+ counts_df['count'] = counts_df['count'].apply(convert_to_int)
297
+ counts_df['label'] = counts_df['label'].astype(str)
298
+
299
  if update:
300
  strip_optimizer(weights[0]) # update model (to fix SourceChangeWarning)
301
+ return save_path, counts_df, df
302
+
303
 
304
  def parse_opt():
305
  parser = argparse.ArgumentParser()