import os
import sys

import cv2
import numpy as np
import supervision as sv
from sahi import AutoDetectionModel
from sahi.predict import get_prediction, get_sliced_prediction
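
# Dependencies (assumed): opencv-python, numpy, sahi, supervision, and
# ultralytics, which SAHI uses under the hood to load YOLOv8 weights.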

# Require exactly seven arguments in addition to the script name.
if len(sys.argv) != 8:
    print(
        "Usage: python yolov8_video_inference.py <model_path> <input_path> <output_path> "
        "<slice_height> <slice_width> <overlap_height_ratio> <overlap_width_ratio>"
    )
    sys.exit(1)
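
# Example invocation (hypothetical file names; slice sizes and overlap ratios
# are illustrative):
#   python yolov8_video_inference.py yolov8n.pt input.mp4 output.mp4 512 512 0.2 0.2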

model_path = sys.argv[1]
input_path = sys.argv[2]
output_path = sys.argv[3]
slice_height = int(sys.argv[4])
slice_width = int(sys.argv[5])
overlap_height_ratio = float(sys.argv[6])
overlap_width_ratio = float(sys.argv[7])

# Load the YOLOv8 weights through SAHI's unified detection wrapper.
detection_model = AutoDetectionModel.from_pretrained(
    model_type="yolov8",
    model_path=model_path,
    confidence_threshold=0.1,
    device="cpu",
)
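# Note: device="cpu" keeps this example portable; SAHI also accepts CUDA device
# strings such as "cuda:0" when a GPU is available.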

# Supervision annotators used to draw corner boxes and confidence labels.
box_annotator = sv.BoxCornerAnnotator(thickness=2)
label_annotator = sv.LabelAnnotator(text_scale=0.5, text_thickness=2)


def annotate_image(image, object_predictions):
    """
    Given an OpenCV image and a list of object predictions from SAHI,
    return an annotated copy of that image.
    """
    if not object_predictions:
        return image.copy()

    # Unpack SAHI's prediction objects into the flat arrays supervision expects.
    xyxy, confidences, class_ids, class_names = [], [], [], []
    for pred in object_predictions:
        xyxy.append(pred.bbox.to_xyxy())
        confidences.append(pred.score.value)
        class_ids.append(pred.category.id)
        class_names.append(pred.category.name)

    detections = sv.Detections(
        xyxy=np.array(xyxy, dtype=np.float32),
        confidence=np.array(confidences, dtype=np.float32),
        class_id=np.array(class_ids, dtype=int),
    )

    labels = [f"{name} {conf:.2f}" for name, conf in zip(class_names, confidences)]

    annotated = image.copy()
    annotated = box_annotator.annotate(scene=annotated, detections=detections)
    annotated = label_annotator.annotate(scene=annotated, detections=detections, labels=labels)
    return annotated
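
# Note: newer supervision releases also provide sv.Detections.from_sahi(...),
# which builds Detections directly from a SAHI result; the manual conversion
# above is kept explicit for clarity (availability depends on your version).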


def run_sliced_inference(image):
    # SAHI tiles the image into overlapping slices, runs the detector on each
    # slice, and merges the per-slice detections back into image coordinates.
    result = get_sliced_prediction(
        image=image,
        detection_model=detection_model,
        slice_height=slice_height,
        slice_width=slice_width,
        overlap_height_ratio=overlap_height_ratio,
        overlap_width_ratio=overlap_width_ratio,
    )
    return annotate_image(image, result.object_prediction_list)


def run_full_inference(image):
    # Single full-frame pass; appropriate when the input is no larger than one slice.
    result = get_prediction(image=image, detection_model=detection_model)
    return annotate_image(image, result.object_prediction_list)


# Dispatch on the file extension: images get a single pass, videos a frame loop.
_, ext = os.path.splitext(input_path.lower())
image_extensions = [".png", ".jpg", ".jpeg", ".bmp"]

if ext in image_extensions:
    image = cv2.imread(input_path)
    if image is None:
        print(f"Error loading image: {input_path}")
        sys.exit(1)

    h, w = image.shape[:2]

    # Slice only when the image is larger than a single slice; otherwise one
    # full-frame pass suffices.
    if h > slice_height or w > slice_width:
        annotated_image = run_sliced_inference(image)
    else:
        annotated_image = run_full_inference(image)

    cv2.imwrite(output_path, annotated_image)
    print(f"Inference complete. Annotated image saved at '{output_path}'")
else:
    cap = cv2.VideoCapture(input_path)
    if not cap.isOpened():
        print(f"Error opening video: {input_path}")
        sys.exit(1)

    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = cap.get(cv2.CAP_PROP_FPS)
    if fps <= 0:
        # Some containers report 0 FPS; fall back to a common default.
        fps = 25.0
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
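    # Note: the "mp4v" FOURCC pairs with .mp4 output; other containers generally
    # need a different codec (e.g. "XVID" for .avi).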

    frame_count = 0
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break

        # Sliced inference on every frame; this is the slow path for long videos.
        annotated_frame = run_sliced_inference(frame)
        out.write(annotated_frame)

        frame_count += 1
        print(f"Processed frame {frame_count}", end="\r")

    cap.release()
    out.release()
    print(f"\nInference complete. Video saved at '{output_path}'")