SebastianBodza commited on
Commit
09a240e
·
verified ·
1 Parent(s): 096206c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -5
app.py CHANGED
@@ -28,15 +28,15 @@ def get_or_load_model():
28
  checkpoint_path = hf_hub_download(repo_id=MODEL_REPO, filename=T3_CHECKPOINT_FILE, token=os.environ["HUGGING_FACE_HUB_TOKEN"])
29
  t3_state = torch.load(checkpoint_path)
30
 
31
- if "cond_enc.spkr_enc.0.weight" in t3_state:
32
- t3_state["cond_enc.spkr_enc.weight"] = t3_state.pop("cond_enc.spkr_enc.0.weight")
33
- if "cond_enc.spkr_enc.0.bias" in t3_state:
34
- t3_state["cond_enc.spkr_enc.bias"] = t3_state.pop("cond_enc.spkr_enc.0.bias")
35
-
36
  cleaned_state_dict = {
37
  key.replace("_orig_mod.", ""): value
38
  for key, value in t3_state.items()
39
  }
 
 
 
 
 
40
 
41
  MODEL.t3.load_state_dict(cleaned_state_dict)
42
 
 
28
  checkpoint_path = hf_hub_download(repo_id=MODEL_REPO, filename=T3_CHECKPOINT_FILE, token=os.environ["HUGGING_FACE_HUB_TOKEN"])
29
  t3_state = torch.load(checkpoint_path)
30
 
 
 
 
 
 
31
  cleaned_state_dict = {
32
  key.replace("_orig_mod.", ""): value
33
  for key, value in t3_state.items()
34
  }
35
+
36
+ if "cond_enc.spkr_enc.0.weight" in cleaned_state_dict:
37
+ cleaned_state_dict["cond_enc.spkr_enc.weight"] = cleaned_state_dict.pop("cond_enc.spkr_enc.0.weight")
38
+ if "cond_enc.spkr_enc.0.bias" in t3_state:
39
+ cleaned_state_dict["cond_enc.spkr_enc.bias"] = cleaned_state_dict.pop("cond_enc.spkr_enc.0.bias")
40
 
41
  MODEL.t3.load_state_dict(cleaned_state_dict)
42