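"""
Run YOLOv8 object detection with SAHI on an image or a video.

Images smaller than the slice dimensions get a single full-frame prediction;
larger images and all video frames go through SAHI's sliced inference, and
detections are drawn with supervision's annotators.
"""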
import cv2
import sys
import os
import numpy as np
from sahi import AutoDetectionModel
from sahi.predict import get_sliced_prediction, get_prediction
import supervision as sv

# Check the number of command-line arguments
if len(sys.argv) != 8:
    print("Usage: python yolov8_video_inference.py <model_path> <input_path> <output_path> <slice_height> <slice_width> <overlap_height_ratio> <overlap_width_ratio>")
    sys.exit(1)

# Get command-line arguments
model_path = sys.argv[1]
input_path = sys.argv[2]
output_path = sys.argv[3]
slice_height = int(sys.argv[4])
slice_width = int(sys.argv[5])
overlap_height_ratio = float(sys.argv[6])
overlap_width_ratio = float(sys.argv[7])

# Load YOLOv8 model with SAHI
detection_model = AutoDetectionModel.from_pretrained(
    model_type='yolov8',  # other SAHI model types include 'yolov5' and 'mmdet'
    model_path=model_path,
    confidence_threshold=0.1,
    device="cpu"  # or "cuda"
)
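# Note: "device" is forwarded to the underlying model, so a value like
# "cuda:0" also works when a GPU is available.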

# Annotators
box_annotator = sv.BoxCornerAnnotator(thickness=2)
label_annotator = sv.LabelAnnotator(text_scale=0.5, text_thickness=2)

def annotate_image(image, object_predictions):
    """
    Given an OpenCV image and a list of object predictions from SAHI,
    returns an annotated copy of that image.
    """
    if not object_predictions:
        return image.copy()
    
    xyxy, confidences, class_ids, class_names = [], [], [], []
    for pred in object_predictions:
        bbox = pred.bbox.to_xyxy()  # [x1, y1, x2, y2]
        xyxy.append(bbox)
        confidences.append(pred.score.value)
        class_ids.append(pred.category.id)
        class_names.append(pred.category.name)

    xyxy = np.array(xyxy, dtype=np.float32)
    confidences = np.array(confidences, dtype=np.float32)
    class_ids = np.array(class_ids, dtype=int)

    detections = sv.Detections(
        xyxy=xyxy,
        confidence=confidences,
        class_id=class_ids
    )

    labels = [f"{cn} {conf:.2f}" for cn, conf in zip(class_names, confidences)]

    annotated = image.copy()
    annotated = box_annotator.annotate(scene=annotated, detections=detections)
    annotated = label_annotator.annotate(scene=annotated, detections=detections, labels=labels)
    return annotated

def run_sliced_inference(image):
    result = get_sliced_prediction(
        image=image,
        detection_model=detection_model,
        slice_height=slice_height,
        slice_width=slice_width,
        overlap_height_ratio=overlap_height_ratio,
        overlap_width_ratio=overlap_width_ratio
    )
    return annotate_image(image, result.object_prediction_list)

def run_full_inference(image):
    # Standard full-frame inference without slicing.
    # (To tune duplicate suppression, pass postprocess_match_threshold to
    # get_sliced_prediction; get_prediction does not accept that argument.)
    result = get_prediction(
        image=image,
        detection_model=detection_model
    )
    return annotate_image(image, result.object_prediction_list)

# Determine if the input is an image or video based on file extension
_, ext = os.path.splitext(input_path.lower())
image_extensions = [".png", ".jpg", ".jpeg", ".bmp"]

if ext in image_extensions:
    # ----- IMAGE PROCESSING -----
    image = cv2.imread(input_path)
    if image is None:
        print(f"Error loading image: {input_path}")
        sys.exit(1)

    h, w = image.shape[:2]

    # Decide whether or not to slice
    if (h > slice_height) or (w > slice_width):
        # The image is larger than the slice dimensions, so use sliced inference
        annotated_image = run_sliced_inference(image)
    else:
        # Otherwise a single full-frame prediction is enough
        annotated_image = run_full_inference(image)

    cv2.imwrite(output_path, annotated_image)
    print(f"Inference complete. Annotated image saved at '{output_path}'")

else:
    # ----- VIDEO PROCESSING -----
    cap = cv2.VideoCapture(input_path)
    if not cap.isOpened():
        print(f"Error opening video: {input_path}")
        sys.exit(1)

    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = cap.get(cv2.CAP_PROP_FPS)
    if fps <= 0:
        fps = 30.0  # some containers report 0 FPS; fall back to a sane default
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
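    # "mp4v" writes MPEG-4 Part 2; "avc1" (H.264) is an alternative if your
    # OpenCV build supports it.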

    out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
    frame_count = 0

    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break

        # Slice every frame; swap in run_full_inference(frame) for faster,
        # full-frame predictions.
        annotated_frame = run_sliced_inference(frame)
        out.write(annotated_frame)

        frame_count += 1
        print(f"Processed frame {frame_count}", end='\r')

    cap.release()
    out.release()
    print(f"\nInference complete. Video saved at '{output_path}'")