lshzhm committed on
Commit
7b09efd
·
1 Parent(s): 5688bca

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -4
app.py CHANGED
@@ -31,6 +31,10 @@ if True:
31
  file_path = snapshot_download(repo_id="lshzhm/Video-to-Audio-and-Piano", local_dir=model_path)
32
 
33
  print(f"Model saved at: {file_path}")
 
 
 
 
34
 
35
  log = logging.getLogger()
36
 
@@ -94,7 +98,7 @@ def read_audio_from_video(video_path):
94
  return waveform
95
 
96
 
97
- def load():
98
  #duration_predictor = DurationPredictor(
99
  # transformer = dict(
100
  # dim = 512,
@@ -140,7 +144,7 @@ def load():
140
  num_channels = 128,
141
  sampling_rate = 24000,
142
  )
143
- e2tts = e2tts.to("cuda")
144
 
145
  #checkpoint = torch.load("/ckptstorage/zhanghaomin/e2/e2_tts_experiment_v2a_encodec/3000.pt", map_location="cpu")
146
  #checkpoint = torch.load("/ckptstorage/zhanghaomin/e2/e2_tts_experiment_v2a_encodec_more/500.pt", map_location="cpu")
@@ -159,7 +163,7 @@ def load():
159
  for param in e2tts.vocos.parameters():
160
  param.requires_grad = False
161
  e2tts.vocos.eval()
162
- e2tts.vocos.to("cuda")
163
 
164
  #dataset = HFDataset(load_dataset("parquet", data_files={"test": "/ckptstorage/zhanghaomin/tts/GLOBE/data/test-*.parquet"})["test"])
165
  #sample = dataset[1]
@@ -190,7 +194,7 @@ def load():
190
  return e2tts, stft
191
 
192
 
193
- e2tts, stft = load()
194
 
195
 
196
  def run(e2tts, stft, arg1, arg2, arg3, arg4):
 
31
  file_path = snapshot_download(repo_id="lshzhm/Video-to-Audio-and-Piano", local_dir=model_path)
32
 
33
  print(f"Model saved at: {file_path}")
34
+
35
+ device = "cpu"
36
+ else:
37
+ device = "cuda"
38
 
39
  log = logging.getLogger()
40
 
 
98
  return waveform
99
 
100
 
101
+ def load(device):
102
  #duration_predictor = DurationPredictor(
103
  # transformer = dict(
104
  # dim = 512,
 
144
  num_channels = 128,
145
  sampling_rate = 24000,
146
  )
147
+ e2tts = e2tts.to(device)
148
 
149
  #checkpoint = torch.load("/ckptstorage/zhanghaomin/e2/e2_tts_experiment_v2a_encodec/3000.pt", map_location="cpu")
150
  #checkpoint = torch.load("/ckptstorage/zhanghaomin/e2/e2_tts_experiment_v2a_encodec_more/500.pt", map_location="cpu")
 
163
  for param in e2tts.vocos.parameters():
164
  param.requires_grad = False
165
  e2tts.vocos.eval()
166
+ e2tts.vocos.to(device)
167
 
168
  #dataset = HFDataset(load_dataset("parquet", data_files={"test": "/ckptstorage/zhanghaomin/tts/GLOBE/data/test-*.parquet"})["test"])
169
  #sample = dataset[1]
 
194
  return e2tts, stft
195
 
196
 
197
+ e2tts, stft = load(device)
198
 
199
 
200
  def run(e2tts, stft, arg1, arg2, arg3, arg4):