chenjoya committed on
Commit
a76995d
·
verified ·
1 Parent(s): 44bdfc1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -3
app.py CHANGED
@@ -62,10 +62,51 @@ with gr.Blocks() as demo:
62
  infer = _init_infer()
63
  state['video_path'] = video_path
64
  yield 'finished initialization, responding...', state
65
- if mode == 'Conversation':
66
- yield infer.video_qa(query=message, state=state)
 
 
 
 
 
 
 
 
 
 
67
  else:
68
- return 'waiting video input...'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
  def gr_chatinterface_chatbot_clear_fn():
70
  return {}, {}, 0, 0
71
  gr_chatinterface = gr.ChatInterface(
 
62
  infer = _init_infer()
63
  state['video_path'] = video_path
64
  yield 'finished initialization, responding...', state
65
+ if mode != 'Conversation':
66
+ yield 'waiting video input...', state
67
+ query = message
68
+ if video_path:
69
+ message = {
70
+ "role": "user",
71
+ "content": [
72
+ {"type": "video", "video": video_path},
73
+ {"type": "text", "text": query if query else default_query},
74
+ ],
75
+ }
76
+
77
  else:
78
+ message = {
79
+ "role": "user",
80
+ "content": [
81
+ {"type": "text", "text": query if query else default_query},
82
+ ],
83
+ }
84
+ image_inputs, video_inputs = process_vision_info([message])
85
+ texts = infer.processor.apply_chat_template([message], tokenize=False, add_generation_prompt=True, return_tensors='pt')
86
+ past_ids = state.get('past_ids', None)
87
+ if past_ids is not None:
88
+ texts = '<|im_end|>\n' + texts[infer.system_prompt_offset:]
89
+ inputs = infer.processor(
90
+ text=texts,
91
+ images=image_inputs,
92
+ videos=video_inputs,
93
+ return_tensors="pt",
94
+ )
95
+ inputs.to(infer.model.device)
96
+ if past_ids is not None:
97
+ inputs['input_ids'] = torch.cat([past_ids, inputs.input_ids], dim=1)
98
+ outputs = infer.model.generate(
99
+ **inputs, past_key_values=state.get('past_key_values', None),
100
+ return_dict_in_generate=True, do_sample=do_sample,
101
+ repetition_penalty=repetition_penalty,
102
+ max_new_tokens=512,
103
+ )
104
+ state['past_key_values'] = outputs.past_key_values
105
+ state['past_ids'] = outputs.sequences[:, :-1]
106
+ response = infer.processor.decode(outputs.sequences[0, inputs.input_ids.size(1):], skip_special_tokens=True)
107
+ print(response)
108
+ return response, state
109
+
110
  def gr_chatinterface_chatbot_clear_fn():
111
  return {}, {}, 0, 0
112
  gr_chatinterface = gr.ChatInterface(