ntsc207 committed
Commit 70a5a36 · verified · 1 Parent(s): b03f92b

Update detect_strongsort.py

Files changed (1): detect_strongsort.py (+28, -10)
detect_strongsort.py CHANGED

@@ -15,6 +15,7 @@ import torch
 import torch.backends.cudnn as cudnn
 from numpy import random
 from time import time
+import pandas as pd
 
 
 FILE = Path(__file__).resolve()
@@ -44,7 +45,6 @@ VID_FORMATS = 'asf', 'avi', 'gif', 'm4v', 'mkv', 'mov', 'mp4', 'mpeg', 'mpg', 't
 def plot_one_box(x, img, color=None, label=None, line_thickness=3):
     # Plots one bounding box on image img
     tl = line_thickness or round(0.002 * (img.shape[0] + img.shape[1]) / 2) + 1  # line/font thickness
-    #color = color or [random.randint(0, 255) for _ in range(3)]
     c1, c2 = (int(x[0]), int(x[1])), (int(x[2]), int(x[3]))
     cv2.rectangle(img, c1, c2, color, thickness=tl, lineType=cv2.LINE_AA)
     if label:
@@ -55,6 +55,10 @@ def plot_one_box(x, img, color=None, label=None, line_thickness=3):
         cv2.putText(img, label, (c1[0], c1[1] - 2), 0, tl / 3, [225, 255, 255], thickness=tf, lineType=cv2.LINE_AA)
 
 
+
+def convert_to_int(tensor):
+    return tensor.type(torch.int16).item()
+
 @smart_inference_mode()
 def run_strongsort(
         source='0',
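The new convert_to_int helper unwraps a 0-dim tensor count into a plain Python int. A quick standalone check (the definition is copied from the hunk above; note that casting through int16 caps values at 32767, which is ample for per-frame detection counts):

import torch

def convert_to_int(tensor):
    return tensor.type(torch.int16).item()

print(convert_to_int(torch.tensor(7.0)))  # 7 -- a float tensor becomes a plain Python int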
@@ -163,6 +167,8 @@ def run_strongsort(
     model.warmup(imgsz=(1 if pt or model.triton else bs, 3, *imgsz))  # warmup
     seen, windows, dt,sdt = 0, [], (Profile(), Profile(), Profile(), Profile()),[0.0, 0.0, 0.0, 0.0]
     curr_frames, prev_frames = [None] * bs, [None] * bs
+    frame_counts = []
+
     for frame_idx, (path, im, im0s, vid_cap, s) in enumerate(dataset):
         # s = ''
         t1 = time_sync()
@@ -179,20 +185,19 @@ def run_strongsort(
         with dt[1]:
             visualize = increment_path(save_dir / Path(path).stem, mkdir=True) if visualize else False
             pred = model(im, augment=augment, visualize=visualize)
-            # pred = pred[0][1]
+            pred = pred[0][1]
         t3 = time_sync()
         sdt[1] += t3 - t2
 
         # Apply NMS
         with dt[2]:
-            pred = pred[0][1] if isinstance(pred[0], list) else pred[0]  # single model or ensemble
             pred = non_max_suppression(pred, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det)
         sdt[2] += time_sync() - t3
 
         # Second-stage classifier (optional)
         # pred = utils.general.apply_classifier(pred, classifier_model, im, im0s)
 
-
+        counts = {}
         # Process detections
         for i, det in enumerate(pred):  # detections per image
             seen += 1
@@ -227,6 +232,7 @@ def run_strongsort(
             imc = im0.copy() if save_crop else im0  # for save_crop
             annotator = Annotator(im0, line_width=line_thickness, example=str(names))
 
+
             if cfg.STRONGSORT.ECC:  # camera motion compensation
                 strongsort_list[i].tracker.camera_update(prev_frames[i], curr_frames[i])
 
@@ -238,7 +244,7 @@ def run_strongsort(
                 for c in det[:, -1].unique():
                     n = (det[:, -1] == c).sum()  # detections per class
                     s += f"{n} {names[int(c)]}{'s' * (n > 1)}, "  # add to string
-
+                    counts[names[int(c)]] = n
                 xywhs = xyxy2xywh(det[:, 0:4])
                 confs = det[:, 4]
                 clss = det[:, 5]
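The counts dict filled in the hunk above stores raw tensors, because (det[:, -1] == c).sum() returns a 0-dim tensor rather than an int; that is what convert_to_int later unwraps. A dummy illustration (the det tensor is invented for the example):

import torch

det = torch.tensor([[0., 0., 1., 1., 0.9, 2.],   # xyxy, conf, cls
                    [2., 2., 3., 3., 0.8, 2.]])
c = det[:, -1].unique()[0]
n = (det[:, -1] == c).sum()
print(type(n), n)                   # <class 'torch.Tensor'> tensor(2)
print(n.type(torch.int16).item())   # 2 -- the plain int convert_to_int produces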
@@ -248,12 +254,13 @@ def run_strongsort(
                 outputs[i] = strongsort_list[i].update(xywhs.cpu(), confs.cpu(), clss.cpu(), im0)
                 t5 = time_sync()
                 sdt[3] += t5 - t4
-
+
                 # Write results
                 for j, (output, conf) in enumerate(zip(outputs[i], confs)):
                     xyxy = output[0:4]
                     id = output[4]
                     cls = output[5]
+                    label = names[int(cls)]
                     # for *xyxy, conf, cls in reversed(det):
                     if save_txt:  # Write to file
                         xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist()  # normalized xywh
@@ -269,7 +276,7 @@ def run_strongsort(
                         if save_crop:
                             save_one_box(xyxy, imc, file=save_dir / 'crops' / names[c] / f'{p.stem}.jpg', BGR=True)
 
-
+                frame_counts.append({'frame': frame_idx, 'counts': counts.copy()})
                 # # draw boxes for visualization
                 # if len(outputs[i]) > 0:
                 #     for j, (output, conf) in enumerate(zip(outputs[i], confs)):
@@ -305,6 +312,8 @@ def run_strongsort(
 
             # Stream results
             im0 = annotator.result()
+
+
             if view_img:
                 if platform.system() == 'Linux' and p not in windows:
                     windows.append(p)
@@ -334,8 +343,17 @@ def run_strongsort(
 
             prev_frames[i] = curr_frames[i]
 
-            # Print time (inference-only)
-            LOGGER.info(f"{s}{'' if len(det) else '(no detections), '}{dt[1].dt * 1E3:.1f}ms")
+
+            # Print time (inference-only)
+            LOGGER.info(f"{s}{'' if len(det) else '(no detections), '}{dt[1].dt * 1E3:.1f}ms")
+
+    flattened_counts = [
+        {'frame': entry['frame'], 'label': label, 'count': count}
+        for entry in frame_counts for label, count in entry['counts'].items()
+    ]
+    frame_counts_df = pd.DataFrame(flattened_counts)
+    frame_counts_df['count'] = frame_counts_df['count'].apply(convert_to_int)
+    counts_df = None
     # Print results
     LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape, %.1fms StrongSORT' % tuple(1E3 * x / seen for x in sdt))
     if save_txt or save_img:
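After the frame loop, the per-frame records are flattened into a long-format DataFrame with one row per (frame, label) pair. A minimal sketch of that post-loop step, with invented records in the exact shape the tracking loop builds:

import pandas as pd
import torch

def convert_to_int(tensor):
    return tensor.type(torch.int16).item()

frame_counts = [
    {'frame': 0, 'counts': {'person': torch.tensor(2), 'car': torch.tensor(1)}},
    {'frame': 1, 'counts': {'car': torch.tensor(1)}},
]
flattened_counts = [
    {'frame': entry['frame'], 'label': label, 'count': count}
    for entry in frame_counts for label, count in entry['counts'].items()
]
frame_counts_df = pd.DataFrame(flattened_counts)
frame_counts_df['count'] = frame_counts_df['count'].apply(convert_to_int)
print(frame_counts_df)
#    frame   label  count
# 0      0  person      2
# 1      0     car      1
# 2      1     car      1

One caveat: if a run yields no detections at all, flattened_counts is empty and frame_counts_df['count'] raises a KeyError, so callers may want to guard on an empty DataFrame first.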
@@ -343,7 +361,7 @@ def run_strongsort(
         LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}{s}")
     if update:
         strip_optimizer(yolo_weights[0])  # update model (to fix SourceChangeWarning)
-    return save_path
+    return save_path, counts_df, frame_counts_df
 def parse_opt():
     parser = argparse.ArgumentParser()
     parser.add_argument('--yolo-weights', nargs='+', type=str, default=WEIGHTS / 'yolov9.pt', help='model.pt path(s)')
 
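With the new three-value return, a caller might consume the result as sketched below. The keyword names are assumptions read off parse_opt() and run_strongsort()'s visible parameters, and counts_df is set to None in this revision, so only frame_counts_df carries data:

# Hypothetical caller; 'video.mp4' is a placeholder source and the import
# assumes detect_strongsort.py is usable as a module.
from detect_strongsort import run_strongsort, WEIGHTS

save_path, counts_df, frame_counts_df = run_strongsort(
    source='video.mp4',
    yolo_weights=WEIGHTS / 'yolov9.pt',
)

# Example aggregation: the peak simultaneous count observed for each class.
peak_per_class = frame_counts_df.groupby('label')['count'].max()
print(peak_per_class)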