# unidrone / run_sliced_inference.py
import cv2
import sys
import os
import numpy as np
from sahi import AutoDetectionModel
from sahi.predict import get_sliced_prediction, get_prediction
import supervision as sv
# Check the number of command-line arguments
if len(sys.argv) != 8:
    print("Usage: python run_sliced_inference.py <model_path> <input_path> <output_path> "
          "<slice_height> <slice_width> <overlap_height_ratio> <overlap_width_ratio>")
    sys.exit(1)
# Get command-line arguments
model_path = sys.argv[1]
input_path = sys.argv[2]
output_path = sys.argv[3]
slice_height = int(sys.argv[4])
slice_width = int(sys.argv[5])
overlap_height_ratio = float(sys.argv[6])
overlap_width_ratio = float(sys.argv[7])
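# Example invocation (illustrative file names and parameter values, not defaults):
#   python run_sliced_inference.py yolov8n.pt flight.mp4 annotated.mp4 640 640 0.2 0.2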
# Load the YOLOv8 model through SAHI's AutoDetectionModel wrapper
detection_model = AutoDetectionModel.from_pretrained(
    model_type='yolov8',
    model_path=model_path,
    confidence_threshold=0.1,
    device="cpu"  # or "cuda" for GPU inference
)
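# Note: the accepted model_type strings vary across SAHI releases; newer
# versions may register the Ultralytics backend under 'ultralytics' rather
# than 'yolov8'. Check your installed SAHI version if model loading fails.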
# Annotators
box_annotator = sv.BoxCornerAnnotator(thickness=2)
label_annotator = sv.LabelAnnotator(text_scale=0.5, text_thickness=2)
def annotate_image(image, object_predictions):
    """
    Given an OpenCV image and a list of object predictions from SAHI,
    returns an annotated copy of that image.
    """
    if not object_predictions:
        return image.copy()
    # Collect box coordinates, scores, and class info from the SAHI predictions
    xyxy, confidences, class_ids, class_names = [], [], [], []
    for pred in object_predictions:
        bbox = pred.bbox.to_xyxy()  # [x1, y1, x2, y2]
        xyxy.append(bbox)
        confidences.append(pred.score.value)
        class_ids.append(pred.category.id)
        class_names.append(pred.category.name)
    xyxy = np.array(xyxy, dtype=np.float32)
    confidences = np.array(confidences, dtype=np.float32)
    class_ids = np.array(class_ids, dtype=int)
    detections = sv.Detections(
        xyxy=xyxy,
        confidence=confidences,
        class_id=class_ids
    )
    labels = [f"{cn} {conf:.2f}" for cn, conf in zip(class_names, confidences)]
    annotated = image.copy()
    annotated = box_annotator.annotate(scene=annotated, detections=detections)
    annotated = label_annotator.annotate(scene=annotated, detections=detections, labels=labels)
    return annotated
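# Example usage of annotate_image (illustrative names; a real `result` comes
# from the SAHI prediction helpers defined below):
#   result = get_prediction(image=frame, detection_model=detection_model)
#   vis = annotate_image(frame, result.object_prediction_list)
#   cv2.imwrite("preview.jpg", vis)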
def run_sliced_inference(image):
    result = get_sliced_prediction(
        image=image,
        detection_model=detection_model,
        slice_height=slice_height,
        slice_width=slice_width,
        overlap_height_ratio=overlap_height_ratio,
        overlap_width_ratio=overlap_width_ratio
    )
    return annotate_image(image, result.object_prediction_list)
def run_full_inference(image):
    # Normal inference without slicing
    result = get_prediction(
        image=image,
        detection_model=detection_model
        # postprocess_match_threshold=0.5,  # adjust the post-processing match threshold if needed
    )
    return annotate_image(image, result.object_prediction_list)
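# Note: get_sliced_prediction tiles the frame into overlapping
# slice_width x slice_height windows, runs the detector on each window, and
# merges the per-window detections back into full-image coordinates (SAHI's
# default postprocess is a greedy NMM-style merge). Slicing helps recover small
# objects in high-resolution aerial imagery at the cost of extra inference
# passes per frame.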
# Determine if the input is an image or video based on file extension
_, ext = os.path.splitext(input_path.lower())
image_extensions = [".png", ".jpg", ".jpeg", ".bmp"]
if ext in image_extensions:
    # ----- IMAGE PROCESSING -----
    image = cv2.imread(input_path)
    if image is None:
        print(f"Error loading image: {input_path}")
        sys.exit(1)
    h, w = image.shape[:2]
    # Slice only when the image is larger than the slice dimensions;
    # a single full-frame pass is enough for smaller images.
    if (h > slice_height) or (w > slice_width):
        annotated_image = run_sliced_inference(image)
    else:
        annotated_image = run_full_inference(image)
    cv2.imwrite(output_path, annotated_image)
    print(f"Inference complete. Annotated image saved at '{output_path}'")
else:
    # ----- VIDEO PROCESSING -----
    cap = cv2.VideoCapture(input_path)
    if not cap.isOpened():
        print(f"Error opening video: {input_path}")
        sys.exit(1)
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = cap.get(cv2.CAP_PROP_FPS)
    if fps <= 0:
        fps = 30.0  # some containers report 0 FPS; assume 30 as a fallback
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
    frame_count = 0
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break
        # Drone footage is typically high-resolution, so every frame goes
        # through sliced inference.
        annotated_frame = run_sliced_inference(frame)
        out.write(annotated_frame)
        frame_count += 1
        print(f"Processed frame {frame_count}", end='\r')
    cap.release()
    out.release()
    print(f"\nInference complete. Video saved at '{output_path}'")