Spaces:
Running
on
Zero
Running
on
Zero
File size: 6,035 Bytes
b7ef483 9d593b2 0b3e025 8d3d225 9d593b2 11aa5da 1af0819 0b3e025 9386371 0b3e025 9386371 0b3e025 8d3d225 a5c442f 7be4a2e e4f0518 09a240e 2f43ddf 09a240e 096206c e4f0518 8d3d225 9386371 0b3e025 9386371 0b3e025 9386371 0b3e025 9386371 9d593b2 9386371 9d593b2 0b3e025 3dab9c0 9d593b2 9386371 0b3e025 9386371 0b3e025 9386371 0b3e025 977d14c 0b3e025 9d593b2 3dab9c0 0b3e025 9d593b2 9386371 b5cf7f2 c3cf266 9386371 604a75f cec0ccc 604a75f 9386371 9d593b2 9386371 84c64ee 9386371 bf4bbc3 9386371 e2da991 9386371 54f0382 9386371 58ffee2 9d593b2 b5cf7f2 9d593b2 9386371 9d593b2 58ffee2 9d593b2 9386371 9d593b2 9386371 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 |
import os
import random
import torch
import spaces
import numpy as np
import gradio as gr
from chatterbox.src.chatterbox.tts import ChatterboxTTS
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
# Use the GPU whenever PyTorch reports one, otherwise run on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Hub location of the T3 checkpoint that get_or_load_model() loads on top of
# the base ChatterboxTTS model.
MODEL_REPO = "SebastianBodza/Kartoffelbox-v0.3"
T3_CHECKPOINT_FILE = "t3_final.pt"

print(f"🚀 Running on device: {DEVICE}")

# --- Global Model Initialization ---
# Process-wide model cache; stays None until get_or_load_model() fills it in.
MODEL = None
def _clean_t3_state_dict(t3_state):
    """Return a copy of a raw T3 checkpoint state dict with keys renamed to match ``MODEL.t3``.

    Two renames are applied:
    * every ``_orig_mod.`` fragment is removed from the keys (presumably left
      over from saving a ``torch.compile``-wrapped module — TODO confirm);
    * ``cond_enc.spkr_enc.0.{weight,bias}`` become ``cond_enc.spkr_enc.{weight,bias}``.
    """
    cleaned = {key.replace("_orig_mod.", ""): value for key, value in t3_state.items()}
    for suffix in ("weight", "bias"):
        old_key = f"cond_enc.spkr_enc.0.{suffix}"
        if old_key in cleaned:
            cleaned[f"cond_enc.spkr_enc.{suffix}"] = cleaned.pop(old_key)
    return cleaned


def get_or_load_model():
    """Return the process-wide ChatterboxTTS model, loading it on first use.

    On the first call this:
    * builds the base ChatterboxTTS model for ``DEVICE``;
    * downloads ``T3_CHECKPOINT_FILE`` from ``MODEL_REPO`` on the Hugging Face Hub;
    * loads it into ``model.t3``;
    * moves the model to ``DEVICE`` if its reported device differs.

    Later calls return the cached instance.

    The global ``MODEL`` is only assigned after every step has succeeded. A
    failed download or state-dict load therefore does not leave a base model
    without the fine-tuned T3 weights cached; the next call retries from scratch.

    Returns:
        The loaded ChatterboxTTS instance.

    Raises:
        Exception: any error raised while building, downloading or loading the
            model is printed and re-raised unchanged.
    """
    global MODEL
    if MODEL is not None:
        return MODEL

    print("Model not loaded, initializing...")
    try:
        model = ChatterboxTTS.from_pretrained(DEVICE)
        checkpoint_path = hf_hub_download(
            repo_id=MODEL_REPO,
            filename=T3_CHECKPOINT_FILE,
            # .get(): when the variable is unset, pass None and let
            # huggingface_hub resolve a token itself instead of raising KeyError.
            token=os.environ.get("HUGGING_FACE_HUB_TOKEN"),
        )
        # Load onto the CPU so a checkpoint saved from a GPU still opens on a
        # CPU-only host; load_state_dict copies values into the existing
        # parameters wherever they live.
        # NOTE(review): consider weights_only=True since this file comes from the Hub.
        t3_state = torch.load(checkpoint_path, map_location="cpu")
        model.t3.load_state_dict(_clean_t3_state_dict(t3_state))
        if hasattr(model, "to") and str(model.device) != DEVICE:
            model.to(DEVICE)
    except Exception as e:
        print(f"Error loading model: {e}")
        raise

    # Publish only the fully initialized model.
    MODEL = model
    print(f"Model loaded successfully. Internal device: {getattr(MODEL, 'device', 'N/A')}")
    return MODEL
# Warm the model cache at import time. A failure here is reported but not
# fatal, so the UI still starts; generate_tts_audio() calls
# get_or_load_model() again on every request.
try:
    get_or_load_model()
except Exception as startup_error:
    print(f"CRITICAL: Failed to load model on startup. Application may not function. Error: {startup_error}")
def set_seed(seed: int):
    """Seed torch (all devices), NumPy and ``random`` for reproducible generation.

    ``numpy.random.seed`` only accepts values in ``[0, 2**32 - 1]``. The seed is
    therefore reduced modulo ``2**32`` first, so a negative or very large seed
    (e.g. typed into the UI) maps deterministically into range instead of
    raising. Seeds already in range are used unchanged.

    Args:
        seed: Integer seed; values convertible with ``int()`` are accepted.
    """
    seed = int(seed) % (2 ** 32)
    # torch.manual_seed already seeds every device, CUDA included, so no
    # separate torch.cuda.manual_seed*/DEVICE check is needed.
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)
@spaces.GPU
def generate_tts_audio(
    text_input: str,
    audio_prompt_path_input: str,
    exaggeration_input: float,
    temperature_input: float,
    seed_num_input: int,
    cfgw_input: float
) -> tuple[int, np.ndarray]:
    """
    Generates TTS audio using the ChatterboxTTS model.

    Args:
        text_input: The text to synthesize; only the first 300 characters are used.
        audio_prompt_path_input: Path to the reference audio file (the UI marks
            it optional, so it may be None — assumes generate() accepts that; confirm).
        exaggeration_input: Exaggeration parameter for the model.
        temperature_input: Temperature parameter for the model.
        seed_num_input: Random seed (0 or empty for random).
        cfgw_input: CFG/Pace weight.

    Returns:
        A tuple containing the sample rate (int) and the audio waveform (numpy.ndarray).

    Raises:
        gr.Error: if the text is empty or whitespace only.
        RuntimeError: if the TTS model is not loaded.
    """
    # Same limit as the "max chars 300" promise in the UI label and above.
    max_chars = 300
    if not text_input or not text_input.strip():
        raise gr.Error("Please enter some text to synthesize.")
    text_input = text_input[:max_chars]

    current_model = get_or_load_model()
    if current_model is None:
        raise RuntimeError("TTS model is not loaded.")

    # A cleared gr.Number yields None; treat it like 0 (no fixed seed) instead
    # of crashing in int(None).
    if seed_num_input:
        set_seed(int(seed_num_input))

    print(f"Generating audio for text: '{text_input[:50]}...'")
    wav = current_model.generate(
        text_input,
        audio_prompt_path=audio_prompt_path_input,
        exaggeration=exaggeration_input,
        temperature=temperature_input,
        cfg_weight=cfgw_input,
    )
    print("Audio generation complete.")
    # detach().cpu() is a no-op for CPU tensors and keeps .numpy() safe should
    # generate() ever return a tensor that lives on the GPU.
    return (current_model.sr, wav.squeeze(0).detach().cpu().numpy())
# Gradio UI: text, reference audio and generation controls on the left,
# synthesized audio on the right; the button runs generate_tts_audio.
with gr.Blocks() as demo:
    # The example sentence is wrapped in a code span so its `<hahaha>` tag is
    # shown literally; as bare text Markdown parses it as raw inline HTML and
    # it is not displayed.
    gr.Markdown(
        """
# Kartoffel-TTS (Based on Chatterbox) - German Text-to-Speech Demo
**PREVIEW SPACE:** Mainly for me to test new finetunes.
Generate high-quality speech from text with reference audio styling.
# Limited support of vocal expressions:
- `<haha>`, `<hahaha>`, `<hahahaha>`, `<chuckle>`
- `<wuhuuu>`, `<wow>`
- `<hmm_neugierig>`, `<hmph>`
- `<huh>`
- `<ohhh>`, `<oooh>`, `<ughh>`
- `<eeehhh>`, `<aaaaaaah>`, `<aaach>`
Examples:
`"Muss das denn immer so nass enden?" Fips konnte sich vor Lachen kaum halten. <hahaha>, das war einfach zu komisch.`
"""
    )
    with gr.Row():
        with gr.Column():
            text = gr.Textbox(
                value="Eines Nachmittags sprang er auf den Rücken seines Freundes, des Otters, und zusammen sausten sie den alten Bachlaufs hinunter. Fips warf die Arme in die Luft und schrie <wuhuuu>",
                label="Text to synthesize (max chars 300)",
                max_lines=5
            )
            ref_wav = gr.Audio(
                sources=["upload", "microphone"],
                type="filepath",
                label="Reference Audio File (Optional)",
                value="https://storage.googleapis.com/chatterbox-demo-samples/prompts/female_shadowheart4.flac"
            )
            exaggeration = gr.Slider(
                0.25, 6, step=.05, label="Exaggeration (Neutral = 0.5, extreme values can be unstable)", value=.5
            )
            cfg_weight = gr.Slider(
                0.0, 1, step=.05, label="CFG/Pace", value=0.8
            )
            with gr.Accordion("More options", open=False):
                seed_num = gr.Number(value=0, label="Random seed (0 for random)")
                temp = gr.Slider(0.05, 5, step=.05, label="Temperature", value=.6)
            run_btn = gr.Button("Generate", variant="primary")
        with gr.Column():
            audio_output = gr.Audio(label="Output Audio")

    # Input order must match generate_tts_audio's positional parameters.
    run_btn.click(
        fn=generate_tts_audio,
        inputs=[
            text,
            ref_wav,
            exaggeration,
            temp,
            seed_num,
            cfg_weight,
        ],
        outputs=[audio_output],
    )

demo.launch()