Update app.py
app.py CHANGED
@@ -31,6 +31,10 @@ if True:
     file_path = snapshot_download(repo_id="lshzhm/Video-to-Audio-and-Piano", local_dir=model_path)
 
     print(f"Model saved at: {file_path}")
+
+    device = "cpu"
+else:
+    device = "cuda"
 
 log = logging.getLogger()
 
@@ -94,7 +98,7 @@ def read_audio_from_video(video_path):
     return waveform
 
 
-def load():
+def load(device):
     #duration_predictor = DurationPredictor(
     #    transformer = dict(
     #        dim = 512,
@@ -140,7 +144,7 @@ def load():
         num_channels = 128,
         sampling_rate = 24000,
     )
-    e2tts = e2tts.to(
+    e2tts = e2tts.to(device)
 
     #checkpoint = torch.load("/ckptstorage/zhanghaomin/e2/e2_tts_experiment_v2a_encodec/3000.pt", map_location="cpu")
     #checkpoint = torch.load("/ckptstorage/zhanghaomin/e2/e2_tts_experiment_v2a_encodec_more/500.pt", map_location="cpu")
@@ -159,7 +163,7 @@ def load():
     for param in e2tts.vocos.parameters():
         param.requires_grad = False
     e2tts.vocos.eval()
-    e2tts.vocos.to(
+    e2tts.vocos.to(device)
 
     #dataset = HFDataset(load_dataset("parquet", data_files={"test": "/ckptstorage/zhanghaomin/tts/GLOBE/data/test-*.parquet"})["test"])
     #sample = dataset[1]
@@ -190,7 +194,7 @@ def load():
     return e2tts, stft
 
 
-e2tts, stft = load()
+e2tts, stft = load(device)
 
 
 def run(e2tts, stft, arg1, arg2, arg3, arg4):
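
For context, a minimal self-contained sketch of the pattern this commit introduces: the device is chosen once at module level (the Space's `if True:` branch pins it to CPU) and is then passed into `load()` rather than being fixed inside the function. The `DummyModel` class below is only a stand-in for the real e2tts/vocos objects assembled in app.py, not part of the actual Space.

import torch

# Mirrors the `if True:` guard in app.py: the Space runs on CPU;
# flipping the condition would select the CUDA path instead.
if True:
    device = "cpu"
else:
    device = "cuda"


class DummyModel(torch.nn.Module):
    """Stand-in for the e2tts / vocos models built inside load()."""

    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(4, 4)


def load(device):
    # Build the model, then move it to the device chosen above
    # instead of fixing the target device inside the function.
    model = DummyModel()
    model = model.to(device)
    model.eval()
    return model


model = load(device)
print(next(model.parameters()).device)  # "cpu" on the Space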