# FishEye8K / app.py — YOLOv9 detection demo with optional DeepSORT / StrongSORT tracking.
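"""Gradio demo for the FishEye8K YOLOv9 models.

Runs YOLOv9 object detection on an uploaded image or video; for video,
detections can optionally be tracked with DeepSORT or StrongSORT.
"""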
import os

import cv2
import gradio as gr
import numpy as np
import spaces
import torch
from PIL import Image

from detect import run
from detect_deepsort import run_deepsort
from detect_strongsort import run_strongsort

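# Request a ZeroGPU slot for up to 120 seconds per call (Hugging Face Spaces).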
@spaces.GPU(duration=120)
def yolov9_inference(model_id, img_path=None, vid_path=None, tracking_algorithm=None):
    img_extensions = ['.jpg', '.jpeg', '.png', '.gif']  # add more image extensions if needed
    vid_extensions = ['.mp4', '.avi', '.mov', '.mkv']  # add more video extensions if needed
    # Fixed inference settings shared by all models.
    image_size = 640
    conf_threshold = 0.5
    iou_threshold = 0.5
    input_path = None
    output_path = None
    if img_path is not None:
        # gr.Image delivers a numpy array; save it to disk so the detector can take a file path.
        img = Image.fromarray(img_path)
        img_path = 'output.png'
        img.save(img_path)
        input_path = img_path
        output_path = run(weights=model_id, imgsz=(image_size, image_size),
                          conf_thres=conf_threshold, iou_thres=iou_threshold,
                          source=input_path, device='0', hide_conf=True)
    elif vid_path is not None:
        vid_name = 'output.mp4'
        # Re-encode the upload to mp4 so the detectors/trackers get a predictable container.
        cap = cv2.VideoCapture(vid_path)
        if not cap.isOpened():
            raise ValueError(f"Error opening video file: {vid_path}")
        # Preserve the source frame rate instead of hardcoding one, falling back to 30 fps
        # if the container does not report it.
        fps = cap.get(cv2.CAP_PROP_FPS) or 30
        frames = []
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            frames.append(frame)
        cap.release()
        if not frames:
            raise ValueError(f"No frames could be read from {vid_path}")
        out = cv2.VideoWriter(vid_name, cv2.VideoWriter_fourcc(*'mp4v'), fps,
                              (frames[0].shape[1], frames[0].shape[0]))
        for frame in frames:
            out.write(frame)
        out.release()
        input_path = vid_name
        if tracking_algorithm == 'deep_sort':
            output_path = run_deepsort(weights=model_id, imgsz=(image_size, image_size),
                                       conf_thres=conf_threshold, iou_thres=iou_threshold,
                                       source=input_path, device='0', draw_trails=True)
        elif tracking_algorithm == 'strong_sort':
            device_strongsort = torch.device('cuda:0')
            output_path = run_strongsort(yolo_weights=model_id, imgsz=(image_size, image_size),
                                         conf_thres=conf_threshold, iou_thres=iou_threshold,
                                         source=input_path, device=device_strongsort,
                                         strong_sort_weights="osnet_x0_25_msmt17.pt", hide_conf=True)
        else:
            output_path = run(weights=model_id, imgsz=(image_size, image_size),
                              conf_thres=conf_threshold, iou_thres=iou_threshold,
                              source=input_path, device='0', hide_conf=True)
    if output_path is None:
        raise ValueError("Either an image or a video input must be provided.")
    # Route the result to the matching Gradio output based on its file extension;
    # initialize both to None so an unrecognized extension cannot leave them unbound.
    output_image, output_video = None, None
    _, output_extension = os.path.splitext(output_path)
    if output_extension.lower() in img_extensions:
        output_image = output_path
    elif output_extension.lower() in vid_extensions:
        output_video = output_path
    return output_image, output_video, output_path
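
# Example direct call (a sketch; assumes the example image from the demo below is in
# the working directory and a CUDA device is available for device='0'). gr.Image hands
# this function a numpy array, so an image file must be loaded first:
#   frame = np.array(Image.open("camera1_A_133.png"))
#   image_out, video_out, out_path = yolov9_inference("last_best_model.pt", img_path=frame)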
def app(model_id, img_path, vid_path, tracking_algorithm):
return yolov9_inference(model_id, img_path, vid_path, tracking_algorithm)
iface = gr.Interface(
fn=app,
inputs=[
gr.Dropdown(
label="Model",
choices=[
"our-converted.pt",
"yolov9_e_trained-converted.pt",
"last_best_model.pt"
],
value="our-converted.pt"
),
gr.Image(label="Image"),
gr.Video(label="Video"),
gr.Dropdown(
label= "Tracking Algorithm",
choices=[
"None",
"deep_sort",
"strong_sort"
],
value="None"
)
],
outputs=[
gr.Image(type="numpy",label="Output Image"),
gr.Video(label="Output Video"),
gr.Textbox(label="Output path")
],
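    # The three outputs mirror yolov9_inference's return values:
    # (image result, video result, raw output path).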
examples=[
["last_best_model.pt", "camera1_A_133.png", None, "deep_sort"],
["last_best_model.pt", None, "test.mp4", "strong_sort"]
],
    title='YOLOv9: Real-time Object Detection',
    description='Real-time object detection with YOLOv9, with optional DeepSORT / StrongSORT tracking for video.',
theme='huggingface'
)
if __name__ == '__main__':
    iface.launch(debug=True)
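
# Minimal local run (a sketch; assumes this repo's detect*.py modules, the model
# weight files listed in the dropdown, "osnet_x0_25_msmt17.pt", and a CUDA GPU
# are all available):
#   pip install gradio spaces torch opencv-python pillow numpy
#   python app.py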