lshzhm commited on
Commit
671a2c8
·
1 Parent(s): b967777

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -1
app.py CHANGED
@@ -162,6 +162,8 @@ def load(device):
162
  # del checkpoint['model_state_dict'][key]
163
  e2tts.load_state_dict(checkpoint['model_state_dict'], strict=False)
164
 
 
 
165
  e2tts.vocos = EncodecWrapper("facebook/encodec_24khz")
166
  for param in e2tts.vocos.parameters():
167
  param.requires_grad = False
@@ -198,6 +200,7 @@ def load(device):
198
 
199
 
200
  e2tts, stft = load(device)
 
201
 
202
 
203
  def run(e2tts, stft, arg1, arg2, arg3, arg4):
@@ -288,7 +291,7 @@ def run(e2tts, stft, arg1, arg2, arg3, arg4):
288
  print("paths", video_path, audio_path, video_path_gen)
289
  return video_path_gen
290
  except Exception as e:
291
- print("Exception", video_path, e)
292
  traceback.print_exc()
293
 
294
  if False:
@@ -311,6 +314,13 @@ def video_to_audio(video: gr.Video, prompt: str, num_steps: int):
311
 
312
  video_path = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4').name
313
 
 
 
 
 
 
 
 
314
  video_save_path = run(e2tts, stft, video_path, prompt, len(prompt)==0, num_steps)
315
 
316
  return video_save_path
@@ -320,6 +330,13 @@ def video_to_piano(video: gr.Video, prompt: str, num_steps: int):
320
 
321
  video_path = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4').name
322
 
 
 
 
 
 
 
 
323
  video_save_path = run(e2tts, stft, video_path, prompt, len(prompt)==0, num_steps)
324
 
325
  return video_save_path
 
162
  # del checkpoint['model_state_dict'][key]
163
  e2tts.load_state_dict(checkpoint['model_state_dict'], strict=False)
164
 
165
+ del checkpoint
166
+
167
  e2tts.vocos = EncodecWrapper("facebook/encodec_24khz")
168
  for param in e2tts.vocos.parameters():
169
  param.requires_grad = False
 
200
 
201
 
202
  e2tts, stft = load(device)
203
+ e2tts = e2tts
204
 
205
 
206
  def run(e2tts, stft, arg1, arg2, arg3, arg4):
 
291
  print("paths", video_path, audio_path, video_path_gen)
292
  return video_path_gen
293
  except Exception as e:
294
+ print("Exception", e)
295
  traceback.print_exc()
296
 
297
  if False:
 
314
 
315
  video_path = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4').name
316
 
317
+ if video.startswith("http"):
318
+ data = requests.get(video, timeout=60).content
319
+ with open(video_path, "wb") as fw:
320
+ fw.write(data)
321
+ else:
322
+ shutil.copy(video, video_path)
323
+
324
  video_save_path = run(e2tts, stft, video_path, prompt, len(prompt)==0, num_steps)
325
 
326
  return video_save_path
 
330
 
331
  video_path = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4').name
332
 
333
+ if video.startswith("http"):
334
+ data = requests.get(video, timeout=60).content
335
+ with open(video_path, "wb") as fw:
336
+ fw.write(data)
337
+ else:
338
+ shutil.copy(video, video_path)
339
+
340
  video_save_path = run(e2tts, stft, video_path, prompt, len(prompt)==0, num_steps)
341
 
342
  return video_save_path