ntsc207 committed on
Commit
74a077f
·
verified ·
1 Parent(s): 7e3030e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +53 -80
app.py CHANGED
@@ -4,12 +4,15 @@ from detect_deepsort import run_deepsort
4
  from detect_strongsort import run_strongsort
5
  from detect import run
6
  import os
 
 
7
  import threading
8
- import torch
 
9
  should_continue = True
10
 
11
 
12
- @spaces.GPU(duration=120)
13
  def yolov9_inference(model_id, img_path=None, vid_path=None, tracking_algorithm = None):
14
  global should_continue
15
  img_extensions = ['.jpg', '.jpeg', '.png', '.gif'] # Add more image extensions if needed
@@ -21,23 +24,25 @@ def yolov9_inference(model_id, img_path=None, vid_path=None, tracking_algorithm
21
  input_path = None
22
  output_path = None
23
  if img_path is not None:
24
- #_, img_extension = os.path.splitext(img_path)
25
- #if img_extension.lower() in img_extensions:
 
 
 
26
  input_path = img_path
27
  print(input_path)
28
- output_path = run(weights=model_id, imgsz=(image_size,image_size), conf_thres=conf_threshold, iou_thres=iou_threshold, source=input_path, device='0', hide_conf= True)
29
  elif vid_path is not None:
30
  #_, vid_extension = os.path.splitext(vid_path)
31
  #if vid_extension.lower() in vid_extensions:
32
  input_path = vid_path
33
  print(input_path)
34
  if tracking_algorithm == 'deep_sort':
35
- output_path = run_deepsort(weights=model_id, imgsz=(image_size,image_size), conf_thres=conf_threshold, iou_thres=iou_threshold, source=input_path, device='0', draw_trails=True)
36
  elif tracking_algorithm == 'strong_sort':
37
- device_strongsort = torch.device('cuda:0')
38
- output_path = run_strongsort(yolo_weights=model_id, imgsz=(image_size,image_size), conf_thres=conf_threshold, iou_thres=iou_threshold, source=input_path, device=device_strongsort, strong_sort_weights = "osnet_x0_25_msmt17.pt", hide_conf= True)
39
  else:
40
- output_path = run(weights=model_id, imgsz=(image_size,image_size), conf_thres=conf_threshold, iou_thres=iou_threshold, source=input_path, device='0', hide_conf= True)
41
  # Assuming output_path is the path to the output file
42
  _, output_extension = os.path.splitext(output_path)
43
  if output_extension.lower() in img_extensions:
@@ -51,78 +56,46 @@ def yolov9_inference(model_id, img_path=None, vid_path=None, tracking_algorithm
51
 
52
 
53
 
54
def stop_processing():
    """Request that any in-progress inference stop, and report the action.

    Clears the module-level ``should_continue`` flag, which the inference
    loop is expected to poll between iterations.  Returns a short status
    string for display in the UI.
    """
    global should_continue
    should_continue = False
    return "Stop..."
58
-
59
- def app():
60
- with gr.Blocks():
61
- with gr.Row():
62
- with gr.Column():
63
- gr.HTML("<h2>Input Parameters</h2>")
64
- img_path = gr.File(label="Image")
65
- vid_path = gr.File(label="Video")
66
- model_id = gr.Dropdown(
67
- label="Model",
68
- choices=[
69
- "yolov9_e_trained-converted.pt",
70
- "our-converted.pt",
71
- "last_best_model.pt"
72
- ],
73
- value="our-converted.pt"
74
-
75
- )
76
- tracking_algorithm = gr.Dropdown(
77
- label= "Tracking Algorithm",
78
- choices=[
79
- "None",
80
- "deep_sort",
81
- "strong_sort"
82
- ],
83
- value="None"
84
- )
85
- yolov9_infer = gr.Button(value="Inference")
86
- stop_button = gr.Button(value="Stop")
87
- with gr.Column():
88
- gr.HTML("<h2>Output</h2>")
89
- output_image = gr.Image(type="numpy",label="Output Image")
90
- output_video = gr.Video(label="Output Video")
91
- output_path = gr.Textbox(label="Output path")
92
 
93
- yolov9_infer.click(
94
- fn=yolov9_inference,
95
- inputs=[
96
- model_id,
97
- img_path,
98
- vid_path,
99
- tracking_algorithm
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
  ],
101
- outputs=[output_image, output_video, output_path],
102
  )
103
- stop_button.click(stop_processing)
104
-
105
-
106
- gradio_app = gr.Blocks()
107
- with gradio_app:
108
- gr.HTML(
109
- """
110
- <h1 style='text-align: center'>
111
- YOLOv9: Real-time Object Detection
112
- </h1>
113
- """)
114
- css = """
115
- body {
116
- background-color: #f0f0f0;
117
- }
118
- h1 {
119
- color: #4CAF50;
120
- }
121
- """
122
- with gr.Row():
123
- with gr.Column():
124
- app()
125
-
126
- gradio_app.launch(debug=True)
127
-
128
 
 
 
4
  from detect_strongsort import run_strongsort
5
  from detect import run
6
  import os
7
+ from PIL import Image
8
+ import numpy as np
9
  import threading
10
+ import cv2
11
+
12
  should_continue = True
13
 
14
 
15
+ @spaces.GPU(duration=120)
16
  def yolov9_inference(model_id, img_path=None, vid_path=None, tracking_algorithm = None):
17
  global should_continue
18
  img_extensions = ['.jpg', '.jpeg', '.png', '.gif'] # Add more image extensions if needed
 
24
  input_path = None
25
  output_path = None
26
  if img_path is not None:
27
+ # Convert the numpy array to an image
28
+ img = Image.fromarray(img_path)
29
+ img_path = 'output.png'
30
+ # Save the image
31
+ img.save(img_path)
32
  input_path = img_path
33
  print(input_path)
34
+ output_path = run(weights=model_id, imgsz=(image_size,image_size), conf_thres=conf_threshold, iou_thres=iou_threshold, source=input_path, device='cpu', hide_conf= True)
35
  elif vid_path is not None:
36
  #_, vid_extension = os.path.splitext(vid_path)
37
  #if vid_extension.lower() in vid_extensions:
38
  input_path = vid_path
39
  print(input_path)
40
  if tracking_algorithm == 'deep_sort':
41
+ output_path = run_deepsort(weights=model_id, imgsz=(image_size,image_size), conf_thres=conf_threshold, iou_thres=iou_threshold, source=input_path, device='cpu', draw_trails=True)
42
  elif tracking_algorithm == 'strong_sort':
43
+ output_path = run_strongsort(yolo_weights=model_id, imgsz=(image_size,image_size), conf_thres=conf_threshold, iou_thres=iou_threshold, source=input_path, device='cpu', strong_sort_weights = "osnet_x0_25_msmt17.pt", hide_conf= True)
 
44
  else:
45
+ output_path = run(weights=model_id, imgsz=(image_size,image_size), conf_thres=conf_threshold, iou_thres=iou_threshold, source=input_path, device='cpu', hide_conf= True)
46
  # Assuming output_path is the path to the output file
47
  _, output_extension = os.path.splitext(output_path)
48
  if output_extension.lower() in img_extensions:
 
56
 
57
 
58
 
59
def app(model_id, img_path=None, vid_path=None, tracking_algorithm=None):
    """Gradio entry point: forward the UI inputs straight to the YOLOv9
    inference routine and return its (image, video, path) outputs."""
    return yolov9_inference(
        model_id,
        img_path=img_path,
        vid_path=vid_path,
        tracking_algorithm=tracking_algorithm,
    )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
 
62
# --- Gradio UI wiring -------------------------------------------------------
# Build the input widgets first so the Interface call stays readable.
_model_selector = gr.Dropdown(
    label="Model",
    choices=[
        "last_best_model.pt",
        "best_model-converted.pt",
        "yolov9_e_trained.pt",
        "best_model-converted-reparams.pt",
    ],
    value="last_best_model.pt",
)
_tracker_selector = gr.Dropdown(
    label= "Tracking Algorithm",
    choices=["None", "deep_sort", "strong_sort"],
    value="None",
)

iface = gr.Interface(
    fn=app,
    inputs=[
        _model_selector,
        gr.Image(label="Image"),
        gr.Video(label="Video"),
        _tracker_selector,
    ],
    outputs=[
        gr.Image(type="numpy",label="Output Image"),
        gr.Video(label="Output Video"),
        gr.Textbox(label="Output path"),
    ],
    # Bundled sample inputs: one image run with DeepSORT, one video run
    # with StrongSORT.  Files are assumed to sit next to app.py.
    examples=[
        ["last_best_model.pt", "camera1_A_133.png", None, "deep_sort"],
        ["last_best_model.pt", None, "test.mp4", "strong_sort"],
    ],
    title='YOLOv9: Real-time Object Detection',
    description='This is a real-time object detection system using YOLOv9.',
    theme='huggingface',
)

iface.launch(debug=True)