qubvel-hf HF Staff commited on
Commit
42718c3
·
1 Parent(s): 32e3eb4

Modify state initialization

Browse files
Files changed (1) hide show
  1. app.py +24 -8
app.py CHANGED
@@ -163,10 +163,30 @@ class RunningResult:
163
  return self.predictions[-1][1] if self.predictions else "Starting..."
164
 
165
 
166
- def process_frames(image: np.ndarray, running_frames_cache: RunningFramesCache, running_result: RunningResult):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
167
  image = np.flip(image, axis=1).copy()
168
  running_frames_cache.add_frame(image)
169
 
 
170
  if (
171
  running_frames_cache.counter % UPDATE_EVERY_N_FRAMES == 0
172
  and len(running_frames_cache) >= model.config.frames_per_clip
@@ -187,6 +207,7 @@ def process_frames(image: np.ndarray, running_frames_cache: RunningFramesCache,
187
  class_name = model.config.id2label[top_index]
188
  running_result.add_prediction(class_name)
189
 
 
190
  formatted_predictions = running_result.get_formatted_predictions()
191
  last_prediction = running_result.get_last_prediction()
192
  image = add_text_on_image(image, last_prediction)
@@ -197,13 +218,8 @@ async def get_credentials():
197
  return await get_cloudflare_turn_credentials_async(hf_token=HF_TOKEN)
198
 
199
 
200
- frames_cache = gr.State(
201
- RunningFramesCache(
202
- save_every_k_frame=128 / frames_per_clip,
203
- max_frames=frames_per_clip,
204
- )
205
- )
206
- result_cache = gr.State(RunningResult(4))
207
 
208
  # Initialize the video stream with processing callback
209
  stream = Stream(
 
163
  return self.predictions[-1][1] if self.predictions else "Starting..."
164
 
165
 
166
+ def process_frames(image: np.ndarray, frames_state: list, result_state: list):
167
+
168
+ # Initialize frames cache if not exists (and put in gradio state)
169
+ if not frames_state:
170
+ running_frames_cache = RunningFramesCache(
171
+ save_every_k_frame=128 / frames_per_clip,
172
+ max_frames=frames_per_clip,
173
+ )
174
+ frames_state.append(running_frames_cache)
175
+ else:
176
+ running_frames_cache = frames_state[0]
177
+
178
+ # Initialize result cache if not exists (and put in gradio state)
179
+ if not result_state:
180
+ running_result = RunningResult(4)
181
+ result_state.append(running_result)
182
+ else:
183
+ running_result = result_state[0]
184
+
185
+ # Add frame to frames cache
186
  image = np.flip(image, axis=1).copy()
187
  running_frames_cache.add_frame(image)
188
 
189
+ # Run model if enough frames are available
190
  if (
191
  running_frames_cache.counter % UPDATE_EVERY_N_FRAMES == 0
192
  and len(running_frames_cache) >= model.config.frames_per_clip
 
207
  class_name = model.config.id2label[top_index]
208
  running_result.add_prediction(class_name)
209
 
210
+ # Get formatted predictions and last prediction
211
  formatted_predictions = running_result.get_formatted_predictions()
212
  last_prediction = running_result.get_last_prediction()
213
  image = add_text_on_image(image, last_prediction)
 
218
  return await get_cloudflare_turn_credentials_async(hf_token=HF_TOKEN)
219
 
220
 
221
+ frames_cache = gr.State([])
222
+ result_cache = gr.State([])
 
 
 
 
 
223
 
224
  # Initialize the video stream with processing callback
225
  stream = Stream(