ntsc207 commited on
Commit
c5064a3
·
verified ·
1 Parent(s): 1633956

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -10
app.py CHANGED
@@ -7,12 +7,12 @@ import os
7
  from PIL import Image
8
  import numpy as np
9
  import threading
 
10
  import cv2
11
 
12
  should_continue = True
13
 
14
-
15
- @spaces.GPU(duration=120)
16
  def yolov9_inference(model_id, img_path=None, vid_path=None, tracking_algorithm = None):
17
  global should_continue
18
  img_extensions = ['.jpg', '.jpeg', '.png', '.gif'] # Add more image extensions if needed
@@ -33,10 +33,40 @@ def yolov9_inference(model_id, img_path=None, vid_path=None, tracking_algorithm
33
  print(input_path)
34
  output_path = run(weights=model_id, imgsz=(image_size,image_size), conf_thres=conf_threshold, iou_thres=iou_threshold, source=input_path, device='0', hide_conf= True)
35
  elif vid_path is not None:
36
- #_, vid_extension = os.path.splitext(vid_path)
37
- #if vid_extension.lower() in vid_extensions:
38
- input_path = vid_path
39
- print(input_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  if tracking_algorithm == 'deep_sort':
41
  output_path = run_deepsort(weights=model_id, imgsz=(image_size,image_size), conf_thres=conf_threshold, iou_thres=iou_threshold, source=input_path, device='0', draw_trails=True)
42
  elif tracking_algorithm == 'strong_sort':
@@ -66,8 +96,10 @@ iface = gr.Interface(
66
  gr.Dropdown(
67
  label="Model",
68
  choices=[
69
- "yolov9_e_trained-converted.pt",
70
- "best_model-converted.pt"
 
 
71
  ],
72
  value="last_best_model.pt"
73
  ),
@@ -89,8 +121,8 @@ iface = gr.Interface(
89
  gr.Textbox(label="Output path")
90
  ],
91
  examples=[
92
- ["best_model-converted.pt", "camera1_A_133.png", None, "deep_sort"],
93
- ["best_model-converted.pt", None, "test.mp4", "strong_sort"]
94
  ],
95
  title='YOLOv9: Real-time Object Detection',
96
  description='This is a real-time object detection system using YOLOv9.',
 
7
  from PIL import Image
8
  import numpy as np
9
  import threading
10
+ import skvideo.io
11
  import cv2
12
 
13
  should_continue = True
14
 
15
+ @spaces.GPU(duration=120)
 
16
  def yolov9_inference(model_id, img_path=None, vid_path=None, tracking_algorithm = None):
17
  global should_continue
18
  img_extensions = ['.jpg', '.jpeg', '.png', '.gif'] # Add more image extensions if needed
 
33
  print(input_path)
34
  output_path = run(weights=model_id, imgsz=(image_size,image_size), conf_thres=conf_threshold, iou_thres=iou_threshold, source=input_path, device='0', hide_conf= True)
35
  elif vid_path is not None:
36
+ vid_name = 'output.mp4'
37
+
38
+ # Create a VideoCapture object
39
+ cap = cv2.VideoCapture(vid_path)
40
+
41
+ # Check if video opened successfully
42
+ if not cap.isOpened():
43
+ print("Error opening video file")
44
+
45
+ # Read the video frame by frame
46
+ frames = []
47
+ while cap.isOpened():
48
+ ret, frame = cap.read()
49
+ if ret:
50
+ frames.append(frame)
51
+ else:
52
+ break
53
+
54
+ # Release the VideoCapture object
55
+ cap.release()
56
+
57
+ # Convert the list of frames to a numpy array
58
+ vid_data = np.array(frames)
59
+
60
+ # Create a VideoWriter object
61
+ out = cv2.VideoWriter(vid_name, cv2.VideoWriter_fourcc(*'mp4v'), 30, (frames[0].shape[1], frames[0].shape[0]))
62
+
63
+ # Write the frames to the output video file
64
+ for frame in frames:
65
+ out.write(frame)
66
+
67
+ # Release the VideoWriter object
68
+ out.release()
69
+ input_path = vid_name
70
  if tracking_algorithm == 'deep_sort':
71
  output_path = run_deepsort(weights=model_id, imgsz=(image_size,image_size), conf_thres=conf_threshold, iou_thres=iou_threshold, source=input_path, device='0', draw_trails=True)
72
  elif tracking_algorithm == 'strong_sort':
 
96
  gr.Dropdown(
97
  label="Model",
98
  choices=[
99
+ "last_best_model.pt",
100
+ "best_model-converted.pt",
101
+ "yolov9_e_trained.pt",
102
+ "best_model-converted-reparams.pt"
103
  ],
104
  value="last_best_model.pt"
105
  ),
 
121
  gr.Textbox(label="Output path")
122
  ],
123
  examples=[
124
+ ["last_best_model.pt", "camera1_A_133.png", None, "deep_sort"],
125
+ ["last_best_model.pt", None, "test.mp4", "strong_sort"]
126
  ],
127
  title='YOLOv9: Real-time Object Detection',
128
  description='This is a real-time object detection system using YOLOv9.',