SebastianBodza commited on
Commit
e4f0518
·
verified ·
1 Parent(s): b5f2fee

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -4
app.py CHANGED
@@ -10,8 +10,8 @@ from safetensors.torch import load_file
10
 
11
 
12
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
13
- MODEL_REPO = "SebastianBodza/Kartoffelbox-v0.1"
14
- T3_CHECKPOINT_FILE = "t3_kartoffelbox.safetensors"
15
  print(f"🚀 Running on device: {DEVICE}")
16
 
17
  # --- Global Model Initialization ---
@@ -26,8 +26,12 @@ def get_or_load_model():
26
  try:
27
  MODEL = ChatterboxTTS.from_pretrained(DEVICE)
28
  checkpoint_path = hf_hub_download(repo_id=MODEL_REPO, filename=T3_CHECKPOINT_FILE, token=os.environ["HUGGING_FACE_HUB_TOKEN"])
29
- t3_state = load_file(checkpoint_path, device="cpu")
30
- MODEL.t3.load_state_dict(t3_state)
 
 
 
 
31
 
32
  if hasattr(MODEL, 'to') and str(MODEL.device) != DEVICE:
33
  MODEL.to(DEVICE)
 
10
 
11
 
12
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
13
+ MODEL_REPO = "SebastianBodza/Kartoffelbox-v0.2"
14
+ T3_CHECKPOINT_FILE = " t3_global_step_50.pt"
15
  print(f"🚀 Running on device: {DEVICE}")
16
 
17
  # --- Global Model Initialization ---
 
26
  try:
27
  MODEL = ChatterboxTTS.from_pretrained(DEVICE)
28
  checkpoint_path = hf_hub_download(repo_id=MODEL_REPO, filename=T3_CHECKPOINT_FILE, token=os.environ["HUGGING_FACE_HUB_TOKEN"])
29
+ t3_state = torch.load(checkpoint_path, device="cpu")
30
+ cleaned_state_dict = {
31
+ key.replace("_orig_mod.", ""): value
32
+ for key, value in t3_state.items()
33
+ }
34
+ MODEL.t3.load_state_dict(cleaned_state_dict)
35
 
36
  if hasattr(MODEL, 'to') and str(MODEL.device) != DEVICE:
37
  MODEL.to(DEVICE)