qubvel-hf's picture
qubvel-hf HF Staff
Add comments and docstrings
79f197e
raw
history blame
7.52 kB
"""
Real-time video classification using the V-JEPA 2 model with streaming capabilities.
This module implements a real-time video classification system that:
1. Captures video frames from a webcam
2. Processes batches of frames using the V-JEPA 2 model
3. Displays predictions overlaid on the video stream
4. Maintains a history of recent predictions
The system uses FastRTC for video streaming and Gradio for the web interface.
"""
import time
from collections import deque

import cv2
import gradio as gr
import numpy as np
import torch
from fastrtc import Stream, VideoStreamHandler, AdditionalOutputs
from transformers import VJEPA2ForVideoClassification, AutoVideoProcessor
# Model configuration
CHECKPOINT = "qubvel-hf/vjepa2-vitl-fpc16-256-ssv2"  # Pre-trained V-JEPA 2 model checkpoint
# Prefer the GPU, but fall back to CPU so the app can still start without CUDA.
TORCH_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Half precision for faster inference on GPU; float32 on CPU, where half-precision
# kernels may be slow or unsupported.
TORCH_DTYPE = torch.float16 if TORCH_DEVICE == "cuda" else torch.float32
UPDATE_EVERY_N_FRAMES = 64  # Run the model once every N incoming frames
def add_text_on_image(image, text):
    """
    Overlay word-wrapped, centered white text on a black bar at the top of an image.

    The image is modified in place: its top 70 rows are blacked out and the
    text is drawn over them. The same array is returned for convenience.

    Args:
        image (np.ndarray): Image to draw on. It is modified in place, and
            ``image.shape[1]`` is used as its width.
        text (str): Text to overlay. It is split on whitespace and wrapped
            greedily so each line fits within the image width minus a 20 px
            margin. A single word wider than that gets a line of its own.

    Returns:
        np.ndarray: The same ``image`` object, with the text drawn on it.
    """
    # Black background bar behind the text
    image[:70] = 0

    line_spacing = 10
    top_margin = 20
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.5
    thickness = 1
    color = (255, 255, 255)  # White
    img_width = image.shape[1]
    max_line_width = img_width - 20  # 20 px margin

    # Greedily pack words into lines that fit within max_line_width.
    lines = []
    current_line = ""
    for word in text.split():
        test_line = f"{current_line} {word}" if current_line else word
        (test_width, _), _ = cv2.getTextSize(test_line, font, font_scale, thickness)
        if test_width > max_line_width and current_line:
            # Flush only a non-empty line. Otherwise a lone over-wide word
            # would first emit a blank line and push the text down.
            lines.append(current_line)
            current_line = word
        else:
            current_line = test_line
    if current_line:
        lines.append(current_line)

    # Draw each line horizontally centered, stacking downward from top_margin.
    y = top_margin
    for line in lines:
        (line_width, line_height), _ = cv2.getTextSize(line, font, font_scale, thickness)
        x = (img_width - line_width) // 2
        cv2.putText(image, line, (x, y + line_height), font, font_scale, color, thickness, cv2.LINE_AA)
        y += line_height + line_spacing
    return image
class RunningFramesCache:
    """
    Rolling buffer of the most recent (optionally subsampled) video frames.

    Frames are offered one at a time through ``add_frame``. Only every
    ``save_every_k_frame``-th offered frame is stored. Once ``max_frames``
    frames are stored, each new frame evicts the oldest one.

    Args:
        save_every_k_frame (int | float): Keep one frame out of every k offered
            frames. Fractional values are supported: a frame is stored at the
            first offered index >= 0, k, 2k, ... Values <= 1 keep every frame.
        max_frames (int): Maximum number of frames to keep in the cache.
            Values below 0 are treated as 0, meaning nothing is stored.
    """

    def __init__(self, save_every_k_frame: float = 1, max_frames: int = 16):
        self.save_every_k_frame = save_every_k_frame
        self.max_frames = max_frames
        # deque(maxlen=...) drops the oldest frame automatically in O(1).
        self._frames = deque(maxlen=max(max_frames, 0))
        self._num_offered = 0  # number of frames passed to add_frame so far
        self._next_keep_at = 0.0  # offered index at which the next frame is stored

    def add_frame(self, frame: np.ndarray) -> None:
        """Offer a frame; it is stored only if it falls on the subsampling grid."""
        if self._num_offered >= self._next_keep_at:
            self._frames.append(frame)
            self._next_keep_at += self.save_every_k_frame
        self._num_offered += 1

    def get_last_n_frames(self, n: int) -> list[np.ndarray]:
        """Return up to ``n`` most recent stored frames, oldest first ([] for n <= 0)."""
        if n <= 0:
            # Explicit guard: frames[-0:] would return the whole buffer.
            return []
        return list(self._frames)[-n:]

    def __len__(self) -> int:
        return len(self._frames)
class RunningResult:
    """
    Rolling history of the model's most recent predictions.

    Each entry is a ``(timestamp, label)`` tuple. The timestamp is UTC
    wall-clock time formatted as ``HH:MM:SS``. Only the newest
    ``max_predictions`` entries are retained.

    Args:
        max_predictions (int): Maximum number of predictions to keep in history
    """

    def __init__(self, max_predictions: int = 4):
        self.predictions = []
        self.max_predictions = max_predictions

    def add_prediction(self, prediction: str):
        """Record ``prediction`` with a UTC HH:MM:SS stamp, evicting the oldest entry if over capacity."""
        stamp = time.strftime("%H:%M:%S", time.gmtime(time.time()))
        self.predictions.append((stamp, prediction))
        # At most one entry is evicted per call.
        if len(self.predictions) > self.max_predictions:
            del self.predictions[0]

    def get_formatted_predictions(self) -> str:
        """
        Render the history for display.

        The output has three parts, in order:

        - the newest label, prefixed with ``>>>``;
        - a blank line;
        - the older entries, newest to oldest, one per line as ``[HH:MM:SS] label``.

        Returns ``"Starting..."`` while the history is empty.
        """
        if not self.predictions:
            return "Starting..."
        newest_label = self.predictions[-1][1]
        older_lines = [
            f"[{stamp}] {label}" for stamp, label in reversed(self.predictions[:-1])
        ]
        return f">>> {newest_label}\n\n" + "\n".join(older_lines)

    def get_last_prediction(self) -> str:
        """Return the newest label, or ``"Starting..."`` if nothing was predicted yet."""
        if not self.predictions:
            return "Starting..."
        return self.predictions[-1][1]
class FrameProcessingCallback:
    """
    Per-frame callback that caches frames, periodically classifies them with
    V-JEPA 2, and overlays the latest prediction on the outgoing frame.

    The model, processor, frame cache and prediction history are created in
    ``__init__``, which downloads and loads the checkpoint. After that, each
    call does four things:

    1. mirrors and caches the incoming frame;
    2. runs the model every ``UPDATE_EVERY_N_FRAMES`` frames, once at least
       ``frames_per_clip`` frames are cached;
    3. draws the most recent label on a copy of the frame;
    4. returns that copy together with the formatted prediction history.
    """

    def __init__(self):
        # Load model and processor. The model uses TORCH_DTYPE, the same dtype
        # the inputs are cast to in _classify_clip. A mismatch between the
        # weights' dtype and the inputs' dtype makes the forward pass fail.
        self.model = VJEPA2ForVideoClassification.from_pretrained(CHECKPOINT, torch_dtype=TORCH_DTYPE)
        self.model = self.model.to(TORCH_DEVICE)
        self.video_processor = AutoVideoProcessor.from_pretrained(CHECKPOINT)

        # Init frames cache. The intent is to keep every
        # (128 / frames_per_clip)-th frame, so one clip spans about 128
        # incoming frames.
        self.frames_per_clip = self.model.config.frames_per_clip
        self.running_frames_cache = RunningFramesCache(
            save_every_k_frame=128 / self.frames_per_clip,
            max_frames=self.frames_per_clip,
        )
        self.running_result = RunningResult(max_predictions=4)
        self.frame_count = 0

    def _classify_clip(self) -> str:
        """Run the model on the most recent cached clip and return the top-1 label."""
        frames = np.array(self.running_frames_cache.get_last_n_frames(self.frames_per_clip))
        inputs = self.video_processor(frames, device=TORCH_DEVICE, return_tensors="pt")
        inputs = inputs.to(dtype=TORCH_DTYPE)
        with torch.inference_mode():
            logits = self.model(**inputs).logits
        top_index = logits.argmax(dim=-1).item()
        return self.model.config.id2label[top_index]

    def __call__(self, image: np.ndarray):
        """
        Process one incoming frame.

        Args:
            image (np.ndarray): Incoming video frame. Presumably (H, W, C);
                TODO confirm against the FastRTC handler.

        Returns:
            tuple: The annotated frame and an ``AdditionalOutputs`` carrying
            the formatted prediction history.
        """
        # Mirror left-right along axis 1. copy() yields a contiguous, writable
        # array instead of a flipped view.
        image = np.flip(image, axis=1).copy()
        self.running_frames_cache.add_frame(image)
        self.frame_count += 1

        if (
            self.frame_count % UPDATE_EVERY_N_FRAMES == 0
            and len(self.running_frames_cache) >= self.frames_per_clip
        ):
            self.running_result.add_prediction(self._classify_clip())

        formatted_predictions = self.running_result.get_formatted_predictions()
        last_prediction = self.running_result.get_last_prediction()
        # add_text_on_image draws in place. Draw on a copy so the frame stored
        # in the cache (the same array) does not get the black bar and caption
        # fed back into the model.
        annotated = add_text_on_image(image.copy(), last_prediction)
        return annotated, AdditionalOutputs(formatted_predictions)
# Wire the per-frame callback into a FastRTC video stream. The callback is
# constructed here at import time, and its __init__ loads the model.
frame_handler = VideoStreamHandler(FrameProcessingCallback(), skip_frames=True)

# Text area that shows the recent-predictions history next to the video.
actions_box = gr.TextArea(label="Actions", value="", lines=5)


def _latest_actions(_, output):
    """Show the newest additional output as-is; the first argument is ignored."""
    return output


stream = Stream(
    handler=frame_handler,
    modality="video",
    mode="send-receive",
    additional_outputs=[actions_box],
    additional_outputs_handler=_latest_actions,
)

if __name__ == "__main__":
    stream.ui.launch()