Spaces:
Running
on
Zero
Running
on
Zero
File size: 4,720 Bytes
68e3bf5 f4c379b 839f10e 9e56ba5 9327140 74a077f 9e56ba5 74a077f 9e56ba5 c5064a3 839f10e 9e56ba5 839f10e 9e56ba5 839f10e 9e56ba5 74a077f 839f10e a1225a7 9e56ba5 c5064a3 839f10e a1225a7 839f10e a1225a7 839f10e a1225a7 839f10e 9e56ba5 74a077f 9e56ba5 74a077f 7a54e80 597b8ca 74a077f 7a54e80 74a077f f4c379b 74a077f f4c379b 74a077f c5064a3 74a077f 9e56ba5 74a077f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 |
import spaces
import gradio as gr
from detect_deepsort import run_deepsort
from detect_strongsort import run_strongsort
from detect import run
import os
import torch
from PIL import Image
import numpy as np
import threading
import cv2
# Module-level flag. Nothing in this file reads or writes it; presumably it was
# meant for a stop/cancel control. NOTE(review): confirm it is still needed.
should_continue = True

# File extensions used to route the detector's output to the image or video slot.
_IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.gif')  # add more image extensions if needed
_VID_EXTENSIONS = ('.mp4', '.avi', '.mov', '.mkv')  # add more video extensions if needed


def _save_image_input(image_array, path='output.png'):
    """Write the image array received from Gradio to *path* and return *path*."""
    Image.fromarray(image_array).save(path)
    return path


def _reencode_video(src_path, dst_path='output.mp4', fallback_fps=30.0):
    """Copy every frame of *src_path* into an mp4v-encoded file at *dst_path*.

    Frames are streamed one at a time rather than buffered in memory. The source
    frame rate is preserved; when the container does not report a positive fps,
    *fallback_fps* is used (30 fps was the previous hard-coded value).

    Returns:
        dst_path.

    Raises:
        gr.Error: if the video cannot be opened, the output file cannot be
            created, or the video contains no readable frames.
    """
    cap = cv2.VideoCapture(src_path)
    writer = None
    try:
        if not cap.isOpened():
            raise gr.Error("Error opening video file")
        fps = cap.get(cv2.CAP_PROP_FPS)
        if not (fps > 0):  # 0 (or NaN) means the container did not report a rate
            fps = fallback_fps
        while True:
            ok, frame = cap.read()
            if not ok:
                break
            if writer is None:
                # The writer needs the frame size, so create it lazily on the first frame.
                height, width = frame.shape[:2]
                writer = cv2.VideoWriter(
                    dst_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height)
                )
                if not writer.isOpened():
                    raise gr.Error("Could not create the intermediate video file")
            writer.write(frame)
    finally:
        cap.release()
        if writer is not None:
            writer.release()
    if writer is None:
        raise gr.Error("The uploaded video contains no readable frames")
    return dst_path


def _split_output(output_path):
    """Return ``(image, video)``: *output_path* in the slot matching its extension.

    Both slots are None when the path is missing or has an unrecognised extension.
    """
    if output_path is None:
        return None, None
    extension = os.path.splitext(output_path)[1].lower()
    if extension in _IMG_EXTENSIONS:
        return output_path, None
    if extension in _VID_EXTENSIONS:
        return None, output_path
    return None, None


@spaces.GPU(duration=120)
def yolov9_inference(model_id, img_path=None, vid_path=None, tracking_algorithm=None):
    """Run YOLOv9 detection, optionally with object tracking, on an image or a video.

    Args:
        model_id: Weights file passed to the detector / tracker runners.
        img_path: Despite the name, the input image as an array accepted by
            ``PIL.Image.fromarray``, or None. Takes precedence over *vid_path*.
        vid_path: Path to the uploaded video file, or None.
        tracking_algorithm: ``'deep_sort'`` or ``'strong_sort'`` selects a tracker
            for videos. Any other value (including the string ``"None"``) runs
            plain detection. Ignored for images.

    Returns:
        ``(output_image, output_video, output_path)``. Whichever of the first two
        matches the output file's extension holds *output_path*; the other is None.

    Raises:
        gr.Error: if no input is provided or the video cannot be read.
    """
    image_size = 640
    conf_threshold = 0.5
    iou_threshold = 0.5
    imgsz = (image_size, image_size)

    if img_path is not None:
        input_path = _save_image_input(img_path)
        print(input_path)
        output_path = run(weights=model_id, imgsz=imgsz, conf_thres=conf_threshold,
                          iou_thres=iou_threshold, source=input_path, device='0',
                          hide_conf=True)
    elif vid_path is not None:
        # Re-encode the upload into a local mp4 before handing it to the runners.
        input_path = _reencode_video(vid_path)
        if tracking_algorithm == 'deep_sort':
            output_path = run_deepsort(weights=model_id, imgsz=imgsz,
                                       conf_thres=conf_threshold, iou_thres=iou_threshold,
                                       source=input_path, device='0', draw_trails=True)
        elif tracking_algorithm == 'strong_sort':
            # StrongSORT takes a torch.device rather than the string '0'.
            device_strongsort = torch.device('cuda:0')
            output_path = run_strongsort(yolo_weights=model_id, imgsz=imgsz,
                                         conf_thres=conf_threshold, iou_thres=iou_threshold,
                                         source=input_path, device=device_strongsort,
                                         strong_sort_weights="osnet_x0_25_msmt17.pt",
                                         hide_conf=True)
        else:
            output_path = run(weights=model_id, imgsz=imgsz, conf_thres=conf_threshold,
                              iou_thres=iou_threshold, source=input_path, device='0',
                              hide_conf=True)
    else:
        raise gr.Error("Please provide an image or a video.")

    output_image, output_video = _split_output(output_path)
    return output_image, output_video, output_path
def app(model_id, img_path, vid_path, tracking_algorithm):
    """Gradio callback: forward the four UI inputs to ``yolov9_inference``.

    The arguments arrive in the order of the interface's ``inputs`` list (model
    dropdown, image, video, tracking dropdown). The three returned values feed
    the image, video and textbox outputs.
    """
    results = yolov9_inference(
        model_id,
        img_path=img_path,
        vid_path=vid_path,
        tracking_algorithm=tracking_algorithm,
    )
    return results
# Web UI: the four inputs map positionally onto app(model_id, img_path, vid_path,
# tracking_algorithm); the three outputs receive its (image, video, path) tuple.
iface = gr.Interface(
    fn=app,
    inputs=[
        gr.Dropdown(
            label="Model",
            choices=[
                "our-converted.pt",
                "yolov9_e_trained-converted.pt",
                "last_best_model.pt",
            ],
            value="our-converted.pt",
        ),
        gr.Image(label="Image"),
        gr.Video(label="Video"),
        gr.Dropdown(
            label="Tracking Algorithm",
            # "None" is a string, not Python None; the inference function treats
            # any value other than 'deep_sort'/'strong_sort' as plain detection.
            choices=[
                "None",
                "deep_sort",
                "strong_sort",
            ],
            value="None",
        ),
    ],
    outputs=[
        gr.Image(type="numpy", label="Output Image"),
        gr.Video(label="Output Video"),
        gr.Textbox(label="Output path"),
    ],
    examples=[
        # The image example always runs plain detection; tracking applies to videos only.
        ["last_best_model.pt", "camera1_A_133.png", None, "deep_sort"],
        ["last_best_model.pt", None, "test.mp4", "strong_sort"],
    ],
    title='YOLOv9: Real-time Object Detection',
    description='This is a real-time object detection system using YOLOv9.',
    theme='huggingface',
)

# Launch only when run as a script (as Spaces does), not when imported.
if __name__ == "__main__":
    iface.launch(debug=True)