SebastianBodza commited on
Commit
7be4a2e
·
verified ·
1 Parent(s): f135a1a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -0
app.py CHANGED
@@ -27,6 +27,18 @@ def get_or_load_model():
27
  MODEL = ChatterboxTTS.from_pretrained(DEVICE)
28
  checkpoint_path = hf_hub_download(repo_id=MODEL_REPO, filename=T3_CHECKPOINT_FILE, token=os.environ["HUGGING_FACE_HUB_TOKEN"])
29
  t3_state = torch.load(checkpoint_path)
 
 
 
 
 
 
 
 
 
 
 
 
30
  cleaned_state_dict = {
31
  key.replace("_orig_mod.", ""): value
32
  for key, value in t3_state.items()
 
27
  MODEL = ChatterboxTTS.from_pretrained(DEVICE)
28
  checkpoint_path = hf_hub_download(repo_id=MODEL_REPO, filename=T3_CHECKPOINT_FILE, token=os.environ["HUGGING_FACE_HUB_TOKEN"])
29
  t3_state = torch.load(checkpoint_path)
30
+
31
+ cleaned_state_dict = {
32
+ key.replace("_orig_mod.", ""): value
33
+ for key, value in t3_state.items()
34
+ }
35
+
36
+ if "cond_enc.spkr_enc.0.weight" in t3_state:
37
+ t3_state["cond_enc.spkr_enc.weight"] = t3_state.pop("cond_enc.spkr_enc.0.weight")
38
+ if "cond_enc.spkr_enc.0.bias" in t3_state:
39
+ t3_state["cond_enc.spkr_enc.bias"] = t3_state.pop("cond_enc.spkr_enc.0.bias")
40
+
41
+
42
  cleaned_state_dict = {
43
  key.replace("_orig_mod.", ""): value
44
  for key, value in t3_state.items()