import cv2
import sys
import os
import numpy as np
from sahi import AutoDetectionModel
from sahi.predict import get_sliced_prediction, get_prediction
import supervision as sv
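# Dependencies (assumed install line; versions not pinned here):
#   pip install sahi supervision ultralytics opencv-python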
# Check the number of command-line arguments
if len(sys.argv) != 8:
print("Usage: python yolov8_video_inference.py <model_path> <input_path> <output_path> <slice_height> <slice_width> <overlap_height_ratio> <overlap_width_ratio>")
sys.exit(1)
# Get command-line arguments
model_path = sys.argv[1]
input_path = sys.argv[2]
output_path = sys.argv[3]
slice_height = int(sys.argv[4])
slice_width = int(sys.argv[5])
overlap_height_ratio = float(sys.argv[6])
overlap_width_ratio = float(sys.argv[7])
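# Example invocation (file names and slice values are illustrative):
#   python yolov8_video_inference.py yolov8n.pt input.mp4 output.mp4 512 512 0.2 0.2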
# Load YOLOv8 model with SAHI
detection_model = AutoDetectionModel.from_pretrained(
    model_type='yolov8',       # newer SAHI releases may expect 'ultralytics' instead
    model_path=model_path,
    confidence_threshold=0.1,  # low threshold; raise it to filter out weak detections
    device="cpu"               # or "cuda" to run on a GPU
)
# Annotators
box_annotator = sv.BoxCornerAnnotator(thickness=2)
label_annotator = sv.LabelAnnotator(text_scale=0.5, text_thickness=2)
def annotate_image(image, object_predictions):
"""
Given an OpenCV image and a list of object predictions from SAHI,
returns an annotated copy of that image.
"""
if not object_predictions:
return image.copy()
xyxy, confidences, class_ids, class_names = [], [], [], []
for pred in object_predictions:
bbox = pred.bbox.to_xyxy() # [x1, y1, x2, y2]
xyxy.append(bbox)
confidences.append(pred.score.value)
class_ids.append(pred.category.id)
class_names.append(pred.category.name)
xyxy = np.array(xyxy, dtype=np.float32)
confidences = np.array(confidences, dtype=np.float32)
class_ids = np.array(class_ids, dtype=int)
detections = sv.Detections(
xyxy=xyxy,
confidence=confidences,
class_id=class_ids
)
labels = [f"{cn} {conf:.2f}" for cn, conf in zip(class_names, confidences)]
annotated = image.copy()
annotated = box_annotator.annotate(scene=annotated, detections=detections)
annotated = label_annotator.annotate(scene=annotated, detections=detections, labels=labels)
return annotated
def run_sliced_inference(image):
    result = get_sliced_prediction(
        image=image,
        detection_model=detection_model,
        slice_height=slice_height,
        slice_width=slice_width,
        overlap_height_ratio=overlap_height_ratio,
        overlap_width_ratio=overlap_width_ratio
    )
    return annotate_image(image, result.object_prediction_list)
def run_full_inference(image):
    # Standard full-frame inference without slicing
    result = get_prediction(
        image=image,
        detection_model=detection_model
        # postprocess_match_threshold=0.5,  # optionally adjust the post-processing threshold
    )
    return annotate_image(image, result.object_prediction_list)
# Determine if the input is an image or video based on file extension
_, ext = os.path.splitext(input_path.lower())
image_extensions = [".png", ".jpg", ".jpeg", ".bmp"]
if ext in image_extensions:
    # ----- IMAGE PROCESSING -----
    image = cv2.imread(input_path)
    if image is None:
        print(f"Error loading image: {input_path}")
        sys.exit(1)
    h, w = image.shape[:2]
    # Slice only when the image exceeds the slice dimensions;
    # otherwise a single full-frame pass is sufficient
    if (h > slice_height) or (w > slice_width):
        annotated_image = run_sliced_inference(image)
    else:
        annotated_image = run_full_inference(image)
    cv2.imwrite(output_path, annotated_image)
    print(f"Inference complete. Annotated image saved at '{output_path}'")
else:
    # ----- VIDEO PROCESSING -----
    cap = cv2.VideoCapture(input_path)
    if not cap.isOpened():
        print(f"Error opening video: {input_path}")
        sys.exit(1)
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = cap.get(cv2.CAP_PROP_FPS) or 30.0  # some containers report 0 FPS; fall back to an assumed 30
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")  # "mp4v" suits .mp4 output containers
    out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
    frame_count = 0
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break
        # Video frames are usually larger than the slice dimensions, so run
        # sliced inference on every frame; swap in run_full_inference(frame)
        # if slicing is not needed
        annotated_frame = run_sliced_inference(frame)
        out.write(annotated_frame)
        frame_count += 1
        print(f"Processed frame {frame_count}", end='\r')
    cap.release()
    out.release()
    print(f"\nInference complete. Video saved at '{output_path}'")