import spaces import gradio as gr from detect_deepsort import run_deepsort from detect_strongsort import run_strongsort from detect import run import os import torch import seaborn as sns from PIL import Image import cv2 import numpy as np import matplotlib.pyplot as plt import threading from scipy.interpolate import make_interp_spline import pandas as pd should_continue = True @spaces.GPU(duration=240) def yolov9_inference(model_id, img_path=None, vid_path=None, tracking_algorithm = None): global should_continue img_extensions = ['.jpg', '.jpeg', '.png', '.gif'] # Add more image extensions if needed vid_extensions = ['.mp4', '.avi', '.mov', '.mkv'] # Add more video extensions if needed #assert img_path is not None or vid_path is not None, "Either img_path or vid_path must be provided." image_size = 640 conf_threshold = 0.5 iou_threshold = 0.5 input_path = None output_path = None if img_path is not None: # Convert the numpy array to an image img = Image.fromarray(img_path) img_path = 'output.png' # Save the image img.save(img_path) input_path = img_path output_path, df, frame_counts_df = run(weights=model_id, imgsz=(image_size,image_size), conf_thres=conf_threshold, iou_thres=iou_threshold, source=input_path, device='0', hide_conf= True) elif vid_path is not None: vid_name = 'output.mp4' # Create a VideoCapture object cap = cv2.VideoCapture(vid_path) # Check if video opened successfully if not cap.isOpened(): print("Error opening video file") # Read the video frame by frame frames = [] while cap.isOpened(): ret, frame = cap.read() if ret: frames.append(frame) else: break # Release the VideoCapture object cap.release() # Convert the list of frames to a numpy array vid_data = np.array(frames) # Create a VideoWriter object out = cv2.VideoWriter(vid_name, cv2.VideoWriter_fourcc(*'mp4v'), 30, (frames[0].shape[1], frames[0].shape[0])) # Write the frames to the output video file for frame in frames: out.write(frame) # Release the VideoWriter object out.release() input_path = vid_name if tracking_algorithm == 'deep_sort': output_path, df, frame_counts_df = run_deepsort(weights=model_id, imgsz=(image_size,image_size), conf_thres=conf_threshold, iou_thres=iou_threshold, source=input_path, device='0', draw_trails=True) elif tracking_algorithm == 'strong_sort': device_strongsort = torch.device('cuda:0') output_path, df, frame_counts_df = run_strongsort(yolo_weights=model_id, imgsz=(image_size,image_size), conf_thres=conf_threshold, iou_thres=iou_threshold, source=input_path, device=device_strongsort, strong_sort_weights = "osnet_x0_25_msmt17.pt", hide_conf= True) else: output_path, df, frame_counts_df = run(weights=model_id, imgsz=(image_size,image_size), conf_thres=conf_threshold, iou_thres=iou_threshold, source=input_path, device='0', hide_conf= True) # Assuming output_path is the path to the output file _, output_extension = os.path.splitext(output_path) palette = {"Bus": "red", "Bike": "blue", "Car": "green", "Pedestrian": "yellow", "Truck": "purple"} if output_extension.lower() in img_extensions: output_image = output_path # Load the image file here output_video = None plt.style.use("ggplot") fig, ax = plt.subplots(figsize=(10, 6)) #for label in labels: #df_label = frame_counts_df[frame_counts_df['label'] == label] sns.barplot(ax=ax, data=df, x='label', y='count', palette=palette, hue='label') # Customizations ax.set_title('Count of Labels', fontsize=20, pad=20) # Increase padding for the title ax.set_xlabel('Label', fontsize=16) # Increase font size ax.set_ylabel('Count', fontsize=16) # Increase font size ax.tick_params(axis='x', rotation=45, labelsize=12) # Increase label size and rotate x-axis labels for better readability ax.tick_params(axis='y', labelsize=12) # Increase label size for y-axis sns.despine() # Remove the top and right spines from plot # Add grid but make it lighter and put it behind bars ax.grid(True, linestyle=':', linewidth=0.6, color='gray', alpha=0.6) ax.set_axisbelow(True) # Add a legend with a smaller font size ax.legend(fontsize=10) plt.tight_layout() # Ensure the entire plot fits into the figure area #ax.set_facecolor('#D3D3D3') elif output_extension.lower() in vid_extensions: output_video = output_path output_image = None # Interpolation preprocessing interpolated_data = [] labels = frame_counts_df['label'].unique() for label in labels: df_label = frame_counts_df[frame_counts_df['label'] == label] # Sort data by frame to ensure smooth interpolation df_label = df_label.sort_values('frame') # Original data points x = df_label['frame'] y = df_label['count'] # Create spline interpolation x_smooth = np.linspace(x.min(), x.max(), 500) spline = make_interp_spline(x, y, k=3) # Cubic spline interpolation y_smooth = spline(x_smooth) # Append the smoothed data to the list interpolated_data.append(pd.DataFrame({'frame': x_smooth, 'count': y_smooth, 'label': label})) # Concatenate all interpolated data into a single DataFrame interpolated_df = pd.concat(interpolated_data) plt.style.use("ggplot") fig, ax = plt.subplots(figsize=(10, 6)) # Plot using Seaborn sns.lineplot(ax=ax, data=interpolated_df, x='frame', y='count', hue='label', palette=palette, linewidth=2.5) ax.set_title('Count of Labels over Frames', fontsize=20, pad=20) # Increase padding for the title ax.set_xlabel('Frame', fontsize=16) # Increase font size ax.set_ylabel('Count', fontsize=16) # Increase font size ax.tick_params(axis='x', labelsize=12) # Increase label size for x-axis ax.tick_params(axis='y', labelsize=12) # Increase label size for y-axis # Add grid but make it lighter and put it behind bars ax.grid(True, linestyle=':', linewidth=0.6, color='gray', alpha=0.6) ax.set_axisbelow(True) # Change the background color to a lighter shade ax.set_facecolor('#F0F0F0') # Add a legend with a smaller font size ax.legend(fontsize=10) plt.tight_layout() # Ensure the entire plot is visible # Return the figure, output_image, and output_video return output_image, output_video, fig def app(): img = Image.open('./img_examples/classes.png') img = img.resize((410, 260), Image.Resampling.LANCZOS) img = np.array(img) with gr.Blocks(title="YOLOv9: Real-time Object Detection", css=".gradio-container {background:lightyellow;}"): with gr.Row(): with gr.Column(): gr.HTML("