ntsc207 committed on
Commit
ed52d42
·
verified ·
1 Parent(s): c29b005

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +83 -56
app.py CHANGED
@@ -1,18 +1,18 @@
1
- import spaces
2
  import gradio as gr
3
  from detect_deepsort import run_deepsort
4
  from detect_strongsort import run_strongsort
5
  from detect import run
6
  import os
7
- import torch
8
  from PIL import Image
 
9
  import numpy as np
10
  import threading
11
- import cv2
12
 
13
  should_continue = True
14
 
15
- @spaces.GPU(duration=120)
 
16
  def yolov9_inference(model_id, img_path=None, vid_path=None, tracking_algorithm = None):
17
  global should_continue
18
  img_extensions = ['.jpg', '.jpeg', '.png', '.gif'] # Add more image extensions if needed
@@ -31,7 +31,7 @@ def yolov9_inference(model_id, img_path=None, vid_path=None, tracking_algorithm
31
  img.save(img_path)
32
  input_path = img_path
33
  print(input_path)
34
- output_path = run(weights=model_id, imgsz=(image_size,image_size), conf_thres=conf_threshold, iou_thres=iou_threshold, source=input_path, device='0', hide_conf= True)
35
  elif vid_path is not None:
36
  vid_name = 'output.mp4'
37
 
@@ -68,64 +68,91 @@ def yolov9_inference(model_id, img_path=None, vid_path=None, tracking_algorithm
68
  out.release()
69
  input_path = vid_name
70
  if tracking_algorithm == 'deep_sort':
71
- output_path = run_deepsort(weights=model_id, imgsz=(image_size,image_size), conf_thres=conf_threshold, iou_thres=iou_threshold, source=input_path, device='0', draw_trails=True)
72
  elif tracking_algorithm == 'strong_sort':
73
- device_strongsort = torch.device('cuda:0')
74
- output_path = run_strongsort(yolo_weights=model_id, imgsz=(image_size,image_size), conf_thres=conf_threshold, iou_thres=iou_threshold, source=input_path, device=device_strongsort, strong_sort_weights = "osnet_x0_25_msmt17.pt", hide_conf= True)
75
  else:
76
- output_path = run(weights=model_id, imgsz=(image_size,image_size), conf_thres=conf_threshold, iou_thres=iou_threshold, source=input_path, device='0', hide_conf= True)
77
  # Assuming output_path is the path to the output file
78
  _, output_extension = os.path.splitext(output_path)
79
  if output_extension.lower() in img_extensions:
80
- output_image = output_path # Load the image file here
81
- output_video = None
82
  elif output_extension.lower() in vid_extensions:
83
- output_image = None
84
- output_video = output_path # Load the video file here
85
-
86
- return output_image, output_video, output_path
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
 
88
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
 
90
- def app(model_id, img_path, vid_path, tracking_algorithm):
91
- return yolov9_inference(model_id, img_path, vid_path, tracking_algorithm)
92
 
93
- iface = gr.Interface(
94
- fn=app,
95
- inputs=[
96
- gr.Dropdown(
97
- label="Model",
98
- choices=[
99
- "our-converted.pt",
100
- "yolov9_e_trained-converted.pt",
101
- "last_best_model.pt"
102
- ],
103
- value="our-converted.pt"
104
- ),
105
- gr.Image(label="Image"),
106
- gr.Video(label="Video"),
107
- gr.Dropdown(
108
- label= "Tracking Algorithm",
109
- choices=[
110
- "None",
111
- "deep_sort",
112
- "strong_sort"
113
- ],
114
- value="None"
115
- )
116
- ],
117
- outputs=[
118
- gr.Image(type="numpy",label="Output Image"),
119
- gr.Video(label="Output Video"),
120
- gr.Textbox(label="Output path")
121
- ],
122
- examples=[
123
- ["last_best_model.pt", "camera1_A_133.png", None, "deep_sort"],
124
- ["last_best_model.pt", None, "test.mp4", "strong_sort"]
125
- ],
126
- title='YOLOv9: Real-time Object Detection',
127
- description='This is a real-time object detection system using YOLOv9.',
128
- theme='huggingface'
129
- )
130
-
131
- iface.launch(debug=True)
 
1
+ #import spaces
2
  import gradio as gr
3
  from detect_deepsort import run_deepsort
4
  from detect_strongsort import run_strongsort
5
  from detect import run
6
  import os
 
7
  from PIL import Image
8
+ import cv2
9
  import numpy as np
10
  import threading
 
11
 
12
  should_continue = True
13
 
14
+
15
+ #@spaces.GPU(duration=120)
16
  def yolov9_inference(model_id, img_path=None, vid_path=None, tracking_algorithm = None):
17
  global should_continue
18
  img_extensions = ['.jpg', '.jpeg', '.png', '.gif'] # Add more image extensions if needed
 
31
  img.save(img_path)
32
  input_path = img_path
33
  print(input_path)
34
+ output_path = run(weights=model_id, imgsz=(image_size,image_size), conf_thres=conf_threshold, iou_thres=iou_threshold, source=input_path, device='cpu', hide_conf= True)
35
  elif vid_path is not None:
36
  vid_name = 'output.mp4'
37
 
 
68
  out.release()
69
  input_path = vid_name
70
  if tracking_algorithm == 'deep_sort':
71
+ output_path = run_deepsort(weights=model_id, imgsz=(image_size,image_size), conf_thres=conf_threshold, iou_thres=iou_threshold, source=input_path, device='cpu', draw_trails=True)
72
  elif tracking_algorithm == 'strong_sort':
73
+ output_path = run_strongsort(yolo_weights=model_id, imgsz=(image_size,image_size), conf_thres=conf_threshold, iou_thres=iou_threshold, source=input_path, device='cpu', strong_sort_weights = "osnet_x0_25_msmt17.pt", hide_conf= True)
 
74
  else:
75
+ output_path = run(weights=model_id, imgsz=(image_size,image_size), conf_thres=conf_threshold, iou_thres=iou_threshold, source=input_path, device='cpu', hide_conf= True)
76
  # Assuming output_path is the path to the output file
77
  _, output_extension = os.path.splitext(output_path)
78
  if output_extension.lower() in img_extensions:
79
+ output = output_path # Load the image file here
 
80
  elif output_extension.lower() in vid_extensions:
81
+ output = output_path # Load the video file here
82
+
83
+ return output, output_path
84
+
85
+
86
+ def app():
87
+ with gr.Blocks(title="YOLOv9: Real-time Object Detection", css=".gradio-container {background:lightyellow;}"):
88
+ with gr.Row():
89
+ with gr.Column():
90
+ gr.HTML("<h2>Input Parameters</h2>")
91
+ img_path = gr.Image(label="Image", height = 370, width = 600)
92
+ vid_path = gr.Video(label="Video", height = 370, width = 600)
93
+ model_id = gr.Dropdown(
94
+ label="Model",
95
+ choices=[
96
+ "our-converted.pt",
97
+ "yolov9_e_trained-converted.pt"
98
+ ],
99
+ value="our-converted.pt"
100
+
101
+ )
102
+ tracking_algorithm = gr.Dropdown(
103
+ label= "Tracking Algorithm",
104
+ choices=[
105
+ "None",
106
+ "deep_sort",
107
+ "strong_sort"
108
+ ],
109
+ value="None"
110
+ )
111
+ gr.Examples(['camera1_A_133.png'], inputs=img_path,label = "Image Example")
112
+ gr.Examples(['test.mp4'], inputs=vid_path, label = "Video Example")
113
+ yolov9_infer = gr.Button(value="Inference")
114
+ with gr.Column():
115
+ gr.HTML("<h2>Output</h2>")
116
+ if img_path is not None:
117
+ output_image = gr.Image(type="numpy",label="Output")
118
+ output = output_image
119
+ else:
120
+ output_video = gr.Video(label="Output")
121
+ output = output_video
122
+ output_path = gr.Textbox(label="Output path")
123
+
124
+ yolov9_infer.click(
125
+ fn=yolov9_inference,
126
+ inputs=[
127
+ model_id,
128
+ img_path,
129
+ vid_path,
130
+ tracking_algorithm
131
+ ],
132
+ outputs=[output, output_path],
133
+ )
134
 
135
 
136
+ gradio_app = gr.Blocks()
137
+ with gradio_app:
138
+ gr.HTML(
139
+ """
140
+ <h1 style='text-align: center'>
141
+ YOLOv9: Real-time Object Detection
142
+ </h1>
143
+ """)
144
+ css = """
145
+ body {
146
+ background-color: #f0f0f0;
147
+ }
148
+ h1 {
149
+ color: #4CAF50;
150
+ }
151
+ """
152
+ with gr.Row():
153
+ with gr.Column():
154
+ app()
155
+
156
+ gradio_app.launch(debug=True)
157
 
 
 
158