chenjoya commited on
Commit
4f38308
·
verified ·
1 Parent(s): 8c46b1e

Update demo/infer.py

Browse files
Files changed (1) hide show
  1. demo/infer.py +2 -1
demo/infer.py CHANGED
@@ -32,12 +32,13 @@ class LiveCCDemoInfer:
32
  streaming_time_interval = streaming_fps_frames / fps
33
  frame_time_interval = 1 / fps
34
 
35
- def __init__(self, model_path: str = None, device_id: int = 0):
36
  self.model = Qwen2VLForConditionalGeneration.from_pretrained(
37
  model_path, torch_dtype="auto",
38
  # device_map=f'cuda:{device_id}',
39
  # attn_implementation='flash_attention_2'
40
  )
 
41
  self.processor = AutoProcessor.from_pretrained(model_path, use_fast=False)
42
  self.streaming_eos_token_id = self.processor.tokenizer(' ...').input_ids[-1]
43
  self.model.prepare_inputs_for_generation = functools.partial(prepare_multiturn_multimodal_inputs_for_generation, self.model)
 
32
  streaming_time_interval = streaming_fps_frames / fps
33
  frame_time_interval = 1 / fps
34
 
35
+ def __init__(self, model_path: str = None, device: str = 'cpu'):
36
  self.model = Qwen2VLForConditionalGeneration.from_pretrained(
37
  model_path, torch_dtype="auto",
38
  # device_map=f'cuda:{device_id}',
39
  # attn_implementation='flash_attention_2'
40
  )
41
+ self.model.to(device)
42
  self.processor = AutoProcessor.from_pretrained(model_path, use_fast=False)
43
  self.streaming_eos_token_id = self.processor.tokenizer(' ...').input_ids[-1]
44
  self.model.prepare_inputs_for_generation = functools.partial(prepare_multiturn_multimodal_inputs_for_generation, self.model)