Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -28,15 +28,15 @@ def get_or_load_model():
|
|
28 |
checkpoint_path = hf_hub_download(repo_id=MODEL_REPO, filename=T3_CHECKPOINT_FILE, token=os.environ["HUGGING_FACE_HUB_TOKEN"])
|
29 |
t3_state = torch.load(checkpoint_path)
|
30 |
|
31 |
-
if "cond_enc.spkr_enc.0.weight" in t3_state:
|
32 |
-
t3_state["cond_enc.spkr_enc.weight"] = t3_state.pop("cond_enc.spkr_enc.0.weight")
|
33 |
-
if "cond_enc.spkr_enc.0.bias" in t3_state:
|
34 |
-
t3_state["cond_enc.spkr_enc.bias"] = t3_state.pop("cond_enc.spkr_enc.0.bias")
|
35 |
-
|
36 |
cleaned_state_dict = {
|
37 |
key.replace("_orig_mod.", ""): value
|
38 |
for key, value in t3_state.items()
|
39 |
}
|
|
|
|
|
|
|
|
|
|
|
40 |
|
41 |
MODEL.t3.load_state_dict(cleaned_state_dict)
|
42 |
|
|
|
28 |
checkpoint_path = hf_hub_download(repo_id=MODEL_REPO, filename=T3_CHECKPOINT_FILE, token=os.environ["HUGGING_FACE_HUB_TOKEN"])
|
29 |
t3_state = torch.load(checkpoint_path)
|
30 |
|
|
|
|
|
|
|
|
|
|
|
31 |
cleaned_state_dict = {
|
32 |
key.replace("_orig_mod.", ""): value
|
33 |
for key, value in t3_state.items()
|
34 |
}
|
35 |
+
|
36 |
+
if "cond_enc.spkr_enc.0.weight" in cleaned_state_dict:
|
37 |
+
cleaned_state_dict["cond_enc.spkr_enc.weight"] = cleaned_state_dict.pop("cond_enc.spkr_enc.0.weight")
|
38 |
+
if "cond_enc.spkr_enc.0.bias" in t3_state:
|
39 |
+
cleaned_state_dict["cond_enc.spkr_enc.bias"] = cleaned_state_dict.pop("cond_enc.spkr_enc.0.bias")
|
40 |
|
41 |
MODEL.t3.load_state_dict(cleaned_state_dict)
|
42 |
|