import spaces
import gradio as gr
from detect_deepsort import run_deepsort
from detect_strongsort import run_strongsort
from detect import run
import os
import torch
from PIL import Image
import numpy as np
import threading
import cv2

should_continue = True
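

# Note: `spaces` is imported for Hugging Face ZeroGPU support. On ZeroGPU hardware the
# GPU-bound entry point (yolov9_inference here) would typically be decorated with
# @spaces.GPU; no such decorator appears in this file, so that is left as an assumption.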
def yolov9_inference(model_id, img_path=None, vid_path=None, tracking_algorithm=None):
    global should_continue
    img_extensions = ['.jpg', '.jpeg', '.png', '.gif']  # Add more image extensions if needed
    vid_extensions = ['.mp4', '.avi', '.mov', '.mkv']  # Add more video extensions if needed
    # assert img_path is not None or vid_path is not None, "Either img_path or vid_path must be provided."
    image_size = 640
    conf_threshold = 0.5
    iou_threshold = 0.5
    input_path = None
    output_path = None
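    # Gradio delivers the image input as a numpy array, so it is written to disk first,
    # because the detection scripts expect a file path as `source`.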
    if img_path is not None:
        # Convert the numpy array to an image
        img = Image.fromarray(img_path)
        img_path = 'output.png'
        # Save the image
        img.save(img_path)
        input_path = img_path
        print(input_path)
        output_path = run(weights=model_id, imgsz=(image_size, image_size), conf_thres=conf_threshold, iou_thres=iou_threshold, source=input_path, device='0', hide_conf=True)
    elif vid_path is not None:
        vid_name = 'output.mp4'
        # Create a VideoCapture object
        cap = cv2.VideoCapture(vid_path)
        # Check if video opened successfully
        if not cap.isOpened():
            print("Error opening video file")
        # Read the video frame by frame
        frames = []
        while cap.isOpened():
            ret, frame = cap.read()
            if ret:
                frames.append(frame)
            else:
                break
        # Release the VideoCapture object
        cap.release()
        # Convert the list of frames to a numpy array (currently unused)
        vid_data = np.array(frames)
        # Create a VideoWriter object (fixed 30 fps, same resolution as the source frames)
        out = cv2.VideoWriter(vid_name, cv2.VideoWriter_fourcc(*'mp4v'), 30, (frames[0].shape[1], frames[0].shape[0]))
        # Write the frames to the output video file
        for frame in frames:
            out.write(frame)
        # Release the VideoWriter object
        out.release()
        input_path = vid_name
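        # Dispatch to the selected tracker. run_deepsort and run_strongsort come from this
        # repo's detect_deepsort.py / detect_strongsort.py and are assumed, like run(), to
        # return the path of the annotated output file.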
        if tracking_algorithm == 'deep_sort':
            output_path = run_deepsort(weights=model_id, imgsz=(image_size, image_size), conf_thres=conf_threshold, iou_thres=iou_threshold, source=input_path, device='0', draw_trails=True)
        elif tracking_algorithm == 'strong_sort':
            device_strongsort = torch.device('cuda:0')
            output_path = run_strongsort(yolo_weights=model_id, imgsz=(image_size, image_size), conf_thres=conf_threshold, iou_thres=iou_threshold, source=input_path, device=device_strongsort, strong_sort_weights="osnet_x0_25_msmt17.pt", hide_conf=True)
        else:
            output_path = run(weights=model_id, imgsz=(image_size, image_size), conf_thres=conf_threshold, iou_thres=iou_threshold, source=input_path, device='0', hide_conf=True)
    # Route the result to the matching output component based on its file extension.
    output_image = None
    output_video = None
    _, output_extension = os.path.splitext(output_path)
    if output_extension.lower() in img_extensions:
        output_image = output_path  # Path to the annotated image
    elif output_extension.lower() in vid_extensions:
        output_video = output_path  # Path to the annotated video
    return output_image, output_video, output_path


def app(model_id, img_path, vid_path, tracking_algorithm):
    return yolov9_inference(model_id, img_path, vid_path, tracking_algorithm)
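

# The model choices below are assumed to be .pt checkpoints stored in the Space
# repository alongside this script; the example image/video filenames are likewise
# expected to exist in the repo root.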
iface = gr.Interface(
    fn=app,
    inputs=[
        gr.Dropdown(
            label="Model",
            choices=[
                "our-converted.pt",
                "yolov9_e_trained-converted.pt",
                "last_best_model.pt"
            ],
            value="our-converted.pt"
        ),
        gr.Image(label="Image"),
        gr.Video(label="Video"),
        gr.Dropdown(
            label="Tracking Algorithm",
            choices=[
                "None",
                "deep_sort",
                "strong_sort"
            ],
            value="None"
        )
    ],
    outputs=[
        gr.Image(type="numpy", label="Output Image"),
        gr.Video(label="Output Video"),
        gr.Textbox(label="Output path")
    ],
    examples=[
        ["last_best_model.pt", "camera1_A_133.png", None, "deep_sort"],
        ["last_best_model.pt", None, "test.mp4", "strong_sort"]
    ],
    title='YOLOv9: Real-time Object Detection',
    description='This is a real-time object detection system using YOLOv9.',
    theme='huggingface'
)

iface.launch(debug=True)
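
# launch() blocks and serves the UI. On Hugging Face Spaces the platform runs this
# script automatically; run locally, it serves at http://127.0.0.1:7860 by default.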