qubvel-hf (HF Staff) committed
Commit d569c73 · 1 Parent(s): 9fae04f
Files changed (3)
  1. README.md +1 -1
  2. app.py +158 -0
  3. requirements.txt +6 -0
README.md CHANGED
@@ -1,5 +1,5 @@
  ---
- title: Vjepa2 Streaming Video Classification
+ title: V-JEPA 2 - Streaming Video Classification
  emoji: 🏠
  colorFrom: blue
  colorTo: indigo
app.py ADDED
@@ -0,0 +1,158 @@
+ import cv2
+ import time
+ import torch
+ import gradio as gr
+ import numpy as np
+
+ from fastrtc import Stream, VideoStreamHandler, AdditionalOutputs
+ from transformers import VJEPA2ForVideoClassification, AutoVideoProcessor
+
+
+ CHECKPOINT = "qubvel-hf/vjepa2-vitl-fpc16-256-ssv2"
+ TORCH_DTYPE = torch.float16
+ TORCH_DEVICE = "cuda"
+ UPDATE_EVERY_N_FRAMES = 64
+
+
+ def add_text_on_image(image, text):
+
+     # Add a black background to the text
+     image[:70] = 0
+
+     line_spacing = 10
+     top_margin = 20
+
+     font = cv2.FONT_HERSHEY_SIMPLEX
+     font_scale = 0.5
+     thickness = 1
+     color = (255, 255, 255)  # White
+
+     words = text.split()
+     lines = []
+     current_line = ""
+
+     img_width = image.shape[1]
+
+     # Build lines that fit within the image width
+     for word in words:
+         test_line = current_line + (" " if current_line else "") + word
+         (test_width, _), _ = cv2.getTextSize(test_line, font, font_scale, thickness)
+         if test_width > img_width - 20:  # 20 px margin
+             lines.append(current_line)
+             current_line = word
+         else:
+             current_line = test_line
+     if current_line:
+         lines.append(current_line)
+
+     # Draw each line, centered
+     y = top_margin
+     for line in lines:
+         (line_width, line_height), _ = cv2.getTextSize(line, font, font_scale, thickness)
+         x = (img_width - line_width) // 2
+         cv2.putText(image, line, (x, y + line_height), font, font_scale, color, thickness, cv2.LINE_AA)
+         y += line_height + line_spacing
+
+     return image
+
+
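A quick way to preview the overlay helper above (a hypothetical standalone snippet, not part of the commit; it assumes `add_text_on_image` is in scope and the frame and label are made up):

```python
import cv2
import numpy as np

# Hypothetical preview: draw a made-up label onto a blank 256x256 RGB frame.
frame = np.zeros((256, 256, 3), dtype=np.uint8)
frame = add_text_on_image(frame, "Throwing something in the air")
cv2.imwrite("overlay_preview.png", frame)  # write the banner out for inspection
```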
+ class RunningFramesCache:
+
+     def __init__(self, save_every_k_frame: int = 1, max_frames: int = 16):
+         self.save_every_k_frame = max(int(save_every_k_frame), 1)
+         self.max_frames = max_frames
+         self._frames = []
+         self._frames_seen = 0
+
+     def add_frame(self, frame: np.ndarray):
+         # Keep only every k-th frame so the cached clip spans a longer
+         # time window than just the last max_frames frames
+         self._frames_seen += 1
+         if self._frames_seen % self.save_every_k_frame != 0:
+             return
+         self._frames.append(frame)
+         if len(self._frames) > self.max_frames:
+             self._frames.pop(0)
+
+     def get_frames(self) -> list[np.ndarray]:
+         return self._frames
+
+     def __len__(self) -> int:
+         return len(self._frames)
+
+
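A small sanity check of the cache semantics (hypothetical snippet, assuming `RunningFramesCache` is in scope): with `save_every_k_frame=8` and `max_frames=16`, the cache ends up holding 16 frames drawn from roughly the last 128 frames pushed in.

```python
import numpy as np

cache = RunningFramesCache(save_every_k_frame=8, max_frames=16)
for i in range(256):
    cache.add_frame(np.full((2, 2, 3), i, dtype=np.uint8))  # dummy frames
assert len(cache) == 16  # capped at max_frames
assert int(cache.get_frames()[-1][0, 0, 0]) == 255  # newest kept frame
```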
+ class RunningResult:
+
+     def __init__(self, max_predictions: int = 4):
+         self.predictions = []
+         self.max_predictions = max_predictions
+
+     def add_prediction(self, prediction: str):
+         # Timestamp the prediction in HH:MM:SS (UTC)
+         current_time_formatted = time.strftime("%H:%M:%S", time.gmtime())
+         self.predictions.append((current_time_formatted, prediction))
+         if len(self.predictions) > self.max_predictions:
+             self.predictions.pop(0)
+
+     def get_formatted_predictions(self) -> str:
+         if not self.predictions:
+             return "Starting..."
+
+         # Newest prediction first, then the recent history
+         current, *past = self.predictions[::-1]
+         text = f">>> {current[1]}\n\n" + "\n".join([
+             f"[{time_formatted}] {prediction}"
+             for time_formatted, prediction in past
+         ])
+         return text
+
+     def get_last_prediction(self) -> str:
+         return self.predictions[-1][1] if self.predictions else "Starting..."
+
+
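The text panel produced by `get_formatted_predictions` looks like this (hypothetical labels and timestamp, assuming `RunningResult` and its imports are in scope):

```python
result = RunningResult(max_predictions=4)
result.add_prediction("Throwing something in the air")
result.add_prediction("Catching something")
print(result.get_formatted_predictions())
# >>> Catching something
#
# [12:34:56] Throwing something in the air
```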
+ class FrameProcessingCallback:
+     def __init__(self):
+
+         # Load model and processor (TORCH_DTYPE keeps the weights in the same
+         # dtype the inputs are cast to in __call__)
+         self.model = VJEPA2ForVideoClassification.from_pretrained(
+             CHECKPOINT, torch_dtype=TORCH_DTYPE
+         ).to(TORCH_DEVICE)
+         self.video_processor = AutoVideoProcessor.from_pretrained(CHECKPOINT)
+
+         # Init frames cache: frames_per_clip frames sampled from the last ~128
+         self.running_frames_cache = RunningFramesCache(
+             save_every_k_frame=128 / self.model.config.frames_per_clip,
+             max_frames=self.model.config.frames_per_clip,
+         )
+         self.running_result = RunningResult(max_predictions=4)
+         self.frame_count = 0
+
+     def __call__(self, image: np.ndarray):
+         # Mirror the webcam frame and add it to the rolling cache
+         image = np.flip(image, axis=1).copy()
+         self.running_frames_cache.add_frame(image)
+         self.frame_count += 1
+         print(f"Frame {self.frame_count}, n frames: {len(self.running_frames_cache)}")
+
+         if self.frame_count % UPDATE_EVERY_N_FRAMES == 0 and len(self.running_frames_cache) == self.model.config.frames_per_clip:
+             # Prepare frames for the model
+             frames = self.running_frames_cache.get_frames()
+             frames = np.array(frames)
+             inputs = self.video_processor(frames, device=TORCH_DEVICE, return_tensors="pt").to(dtype=TORCH_DTYPE)
+
+             # Run the model
+             with torch.no_grad():
+                 logits = self.model(**inputs).logits
+             top_index = logits.argmax(dim=-1).item()
+             class_name = self.model.config.id2label[top_index]
+             self.running_result.add_prediction(class_name)
+
+         formatted_predictions = self.running_result.get_formatted_predictions()
+         last_prediction = self.running_result.get_last_prediction()
+         image = add_text_on_image(image, last_prediction)
+         return image, AdditionalOutputs(formatted_predictions)
+
+
+ stream = Stream(
+     handler=VideoStreamHandler(FrameProcessingCallback(), skip_frames=True),
+     modality="video",
+     mode="send-receive",
+     additional_outputs=[gr.TextArea(label="Actions", value="", lines=5)],
+     additional_outputs_handler=lambda _, output: output,
+ )
+
+
+ if __name__ == "__main__":
+     stream.ui.launch()
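For reference, the same inference path can be exercised offline on a fixed clip. This is a minimal sketch mirroring the code above; the random clip is a stand-in for real frames, and a CUDA device is assumed:

```python
import numpy as np
import torch
from transformers import VJEPA2ForVideoClassification, AutoVideoProcessor

checkpoint = "qubvel-hf/vjepa2-vitl-fpc16-256-ssv2"
model = VJEPA2ForVideoClassification.from_pretrained(
    checkpoint, torch_dtype=torch.float16
).to("cuda")
processor = AutoVideoProcessor.from_pretrained(checkpoint)

# 16 frames of HxWx3 uint8, as the rolling cache would provide
clip = np.random.randint(0, 256, (16, 256, 256, 3), dtype=np.uint8)
inputs = processor(clip, device="cuda", return_tensors="pt").to(dtype=torch.float16)

with torch.no_grad():
    logits = model(**inputs).logits
print(model.config.id2label[logits.argmax(-1).item()])
```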
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ gradio
+ transformers @ git+https://github.com/huggingface/transformers
+ torch
+ torchvision
+ opencv-python-headless
+ fastrtc>=0.0.28
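With these dependencies installed (`pip install -r requirements.txt`), running `python app.py` starts the webcam demo via `stream.ui.launch()`.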