# FishEye8K / app.py
# Hugging Face Space by ntsc207 — commit 30c59c4 ("Update app.py", verified), 14.6 kB.
import spaces
import gradio as gr
from detect_deepsort import run_deepsort
from detect_strongsort import run_strongsort
from detect import run
import os
import torch
import seaborn as sns
from PIL import Image
import cv2
import numpy as np
import matplotlib.pyplot as plt
import threading
from scipy.interpolate import make_interp_spline
import pandas as pd
# Module-level flag. Nothing in this file reads or writes it; it is kept only
# so that any external code importing it keeps working.
should_continue = True

# Output-file extensions used to decide whether the result is shown as an
# image or as a video.
_IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.gif')
_VID_EXTENSIONS = ('.mp4', '.avi', '.mov', '.mkv')
# Class label -> colour used by the seaborn count plots.
_CLASS_PALETTE = {"Bus": "red", "Bike": "blue", "Car": "green", "Pedestrian": "yellow", "Truck": "purple"}


def _save_image_input(image, path='output.png'):
    """Write the uploaded image to *path* so the detector can read it from disk.

    Args:
        image: the array delivered by ``gr.Image`` (passed to ``Image.fromarray``).
        path: destination file name.

    Returns:
        The path that was written.
    """
    Image.fromarray(image).save(path)
    return path


def _reencode_video(src_path, dst_path='output.mp4', default_fps=30.0):
    """Copy the video at *src_path* frame by frame into an ``mp4v`` file at *dst_path*.

    Frames are streamed straight from the reader to the writer, so the whole
    video is never held in memory. The source frame rate is preserved. If the
    container does not report a usable rate, *default_fps* is used instead
    (30, the value this code previously always used).

    Returns:
        *dst_path*.

    Raises:
        gr.Error: if the video cannot be opened or has no readable frames.
    """
    cap = cv2.VideoCapture(src_path)
    if not cap.isOpened():
        cap.release()
        raise gr.Error("Error opening video file")
    writer = None
    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
        # `not fps > 0` also rejects NaN, which some containers report.
        if not fps > 0:
            fps = default_fps
        while True:
            ok, frame = cap.read()
            if not ok:
                break
            if writer is None:
                # The first frame fixes the output size, as frames[0] did before.
                height, width = frame.shape[:2]
                writer = cv2.VideoWriter(dst_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
            writer.write(frame)
    finally:
        cap.release()
        if writer is not None:
            writer.release()
    if writer is None:
        raise gr.Error("The video file contains no readable frames")
    return dst_path


def _run_video_detection(model_id, input_path, tracking_algorithm, image_size, conf_threshold, iou_threshold):
    """Run detection on a video, optionally with a tracker.

    ``'deep_sort'`` and ``'strong_sort'`` select the matching tracker; any
    other value (including the UI's ``"None"``) uses plain ``run``.

    Returns:
        The ``(output_path, df, frame_counts_df)`` tuple from the chosen runner.
    """
    imgsz = (image_size, image_size)
    if tracking_algorithm == 'deep_sort':
        return run_deepsort(weights=model_id, imgsz=imgsz, conf_thres=conf_threshold, iou_thres=iou_threshold, source=input_path, device='0', draw_trails=True)
    if tracking_algorithm == 'strong_sort':
        device_strongsort = torch.device('cuda:0')
        return run_strongsort(yolo_weights=model_id, imgsz=imgsz, conf_thres=conf_threshold, iou_thres=iou_threshold, source=input_path, device=device_strongsort, strong_sort_weights="osnet_x0_25_msmt17.pt", hide_conf=True, hide_labels=True)
    return run(weights=model_id, imgsz=imgsz, conf_thres=conf_threshold, iou_thres=iou_threshold, source=input_path, device='0', hide_conf=True, hide_labels=True)


def _plot_image_counts(df):
    """Bar chart of object counts per class; *df* needs 'label' and 'count' columns."""
    plt.style.use("ggplot")
    fig, ax = plt.subplots(figsize=(10, 6), dpi=300)
    sns.barplot(ax=ax, data=df, x='label', y='count', palette=_CLASS_PALETTE, hue='label')
    ax.set_title('Number of Objects', fontsize=20, pad=20)
    ax.set_xlabel('Object Class', fontsize=16)
    ax.set_ylabel('Object Count', fontsize=16)
    # Rotate class names so longer labels do not overlap.
    ax.tick_params(axis='x', rotation=45, labelsize=12)
    ax.tick_params(axis='y', labelsize=12)
    sns.despine(ax=ax)
    # Light dotted grid drawn behind the bars.
    ax.grid(True, linestyle=':', linewidth=0.6, color='gray', alpha=0.6)
    ax.set_axisbelow(True)
    ax.legend(fontsize=10)
    fig.tight_layout()
    return fig


def _plot_video_counts(frame_counts_df):
    """Line chart of per-class object counts over frames ('frame', 'count', 'label' columns)."""
    plt.style.use("ggplot")
    fig, ax = plt.subplots(figsize=(10, 6), dpi=300)
    sns.lineplot(ax=ax, data=frame_counts_df, x='frame', y='count', hue='label', palette=_CLASS_PALETTE, linewidth=2.5)
    ax.set_title('Object count over frame', fontsize=20, pad=20)
    ax.set_xlabel('Frame', fontsize=16)
    ax.set_ylabel('Object Count', fontsize=16)
    ax.tick_params(axis='x', labelsize=12)
    ax.tick_params(axis='y', labelsize=12)
    # Light dotted grid behind the lines on a lighter background.
    ax.grid(True, linestyle=':', linewidth=0.6, color='gray', alpha=0.6)
    ax.set_axisbelow(True)
    ax.set_facecolor('#F0F0F0')
    ax.legend(fontsize=10)
    fig.tight_layout()
    return fig


@spaces.GPU(duration=240)
def yolov9_inference(model_id, img_path=None, vid_path=None, tracking_algorithm=None):
    """Run YOLOv9 detection (and optional tracking) on an uploaded image or video.

    Args:
        model_id: weights file name forwarded to the detector.
        img_path: the image array from ``gr.Image`` (an array, despite the
            name). It takes precedence when both inputs are given.
        vid_path: path of the uploaded video.
        tracking_algorithm: ``'deep_sort'``, ``'strong_sort'`` or anything else
            for plain detection. Only used for videos.

    Returns:
        ``(output_image, output_video, fig)``. One of the first two is the
        detector's output path and the other is None. ``fig`` is a matplotlib
        Figure with the object counts.

    Raises:
        gr.Error: if no input is supplied, the video is unreadable, or the
            detector produced an unrecognised file type.
    """
    image_size = 640
    conf_threshold = 0.5
    iou_threshold = 0.5

    if img_path is not None:
        input_path = _save_image_input(img_path)
        # Tracking only applies to videos; images always use plain detection.
        output_path, df, frame_counts_df = run(weights=model_id, imgsz=(image_size, image_size), conf_thres=conf_threshold, iou_thres=iou_threshold, source=input_path, device='0', hide_conf=True, hide_labels=True)
    elif vid_path is not None:
        input_path = _reencode_video(vid_path)
        output_path, df, frame_counts_df = _run_video_detection(model_id, input_path, tracking_algorithm, image_size, conf_threshold, iou_threshold)
    else:
        raise gr.Error("Please provide an image or a video.")

    output_extension = os.path.splitext(output_path)[1].lower()
    if output_extension in _IMG_EXTENSIONS:
        output_image, output_video = output_path, None
        fig = _plot_image_counts(df)
    elif output_extension in _VID_EXTENSIONS:
        output_image, output_video = None, output_path
        fig = _plot_video_counts(frame_counts_df)
    else:
        raise gr.Error(f"Unsupported output file type: {output_path}")

    # Remove the figure from pyplot's registry so figures do not pile up
    # across requests. The Figure object itself stays usable for gr.Plot.
    plt.close(fig)
    return output_image, output_video, fig
# Example files offered under the configuration panel.
_IMAGE_EXAMPLES = ['./img_examples/Exam_1.png', './img_examples/Exam_2.png', './img_examples/Exam_3.png', './img_examples/Exam_4.png', './img_examples/Exam_5.png']
_VIDEO_EXAMPLES = ['./video_examples/video_1.mp4', './video_examples/video_2.mp4', './video_examples/video_3.mp4', './video_examples/video_4.mp4', './video_examples/video_5.mp4']

# Inline colour key listing the five detection classes.
_CLASS_LEGEND_HTML = """
<p style="text-align:center; font-family:Arial; font-size:16px;">
<span style="display:inline-block; width:8px; height:8px; background:#FF3333;"></span> Bus &nbsp;
<span style="display:inline-block; width:8px; height:8px; background:#3358FF;"></span> Bike &nbsp;
<span style="display:inline-block; width:8px; height:8px; background:#33FF33;"></span> Car &nbsp;
<span style="display:inline-block; width:8px; height:8px; background:#F6FF33;"></span> Pedestrian &nbsp;
<span style="display:inline-block; width:8px; height:8px; background:#9F33FF;"></span> Truck
</p>
"""


def app():
    """Build the detection UI and wire the "Inference" button to ``yolov9_inference``.

    Lays out three columns: inputs (image / video), outputs (image, video,
    count plot), and configuration (model, tracker, button, class legend,
    examples). The unused ``classes.png`` load from earlier versions has been
    removed, so startup no longer reads that file.
    """
    # NOTE(review): this Blocks is opened inside the caller's Blocks; Gradio
    # presumably ignores a nested Blocks' title/css — confirm before relying on them.
    with gr.Blocks(title="YOLOv9: Real-time Object Detection", css=".gradio-container {background:lightyellow;}"):
        with gr.Row():
            with gr.Column():
                gr.HTML("<h2>Input</h2>")
                img_path = gr.Image(label="Image", height=260, width=410)
                vid_path = gr.Video(label="Video", height=260, width=410)
            with gr.Column(min_width=270):
                gr.HTML("<h2>Output</h2>")
                output_image = gr.Image(type="numpy", label="Output Image", height=260, width=410)
                output_video = gr.Video(label="Output Video", height=260, width=410)
                fig = gr.Plot(label="label")
            with gr.Column():
                gr.HTML("<h2>Configuration</h2>")
                model_id = gr.Dropdown(
                    label="Model",
                    choices=[
                        "Our_model-e.pt",
                        "Our_model-c-dev.pt",
                        "yolov9-e_trained.pt",
                        "yolov9-c_trained.pt",
                    ],
                    value="Our_model-e.pt",
                )
                tracking_algorithm = gr.Dropdown(
                    label="Tracking Algorithm",
                    choices=[
                        "None",
                        "deep_sort",
                        "strong_sort",
                    ],
                    value="None",
                )
                yolov9_infer = gr.Button(value="Inference")
                gr.HTML(_CLASS_LEGEND_HTML)
                gr.Examples(_IMAGE_EXAMPLES, inputs=img_path, label="Image Example", cache_examples=False, examples_per_page=4)
                gr.Examples(_VIDEO_EXAMPLES, inputs=vid_path, label="Video Example", cache_examples=False, examples_per_page=4)
        # Input order must match yolov9_inference(model_id, img_path, vid_path, tracking_algorithm).
        yolov9_infer.click(
            fn=yolov9_inference,
            inputs=[
                model_id,
                img_path,
                vid_path,
                tracking_algorithm,
            ],
            outputs=[output_image, output_video, fig],
        )
# Page-level styling. It used to be defined but never passed to gr.Blocks,
# so it had no effect.
css = """
body {
background-color: #f0f0f0;
}
h1 {
color: #4CAF50;
}
"""

# Top-level page: a centred title above the detection UI built by app().
gradio_app = gr.Blocks(title="YOLOv9-FishEye", css=css)
with gradio_app:
    gr.HTML(
        """
<h1 style='text-align: center'>
YOLOv9-FishEye: Improving model for realtime fisheye camera object detection
</h1>
""")
    with gr.Row():
        with gr.Column():
            app()

# Launch only when run as a script, so importing this module does not start
# a blocking server.
if __name__ == "__main__":
    gradio_app.launch(debug=True, favicon_path="fisheye_icon.png")