hynt committed
Commit b25efa1 · Parent: c4ce1b4

fix module

Note: this view is limited to 50 files because the commit contains too many changes; see the raw diff for the full change set.

Files changed (50)
  1. f5_tts/api.py +59 -50
  2. f5_tts/configs/{E2TTS_Base.yaml → E2TTS_Base_train.yaml} +7 -11
  3. f5_tts/configs/{E2TTS_Small.yaml → E2TTS_Small_train.yaml} +7 -11
  4. f5_tts/configs/{F5TTS_Base.yaml → F5TTS_Base_train.yaml} +11 -15
  5. f5_tts/configs/{F5TTS_Small.yaml → F5TTS_Small_train.yaml} +7 -11
  6. f5_tts/configs/F5TTS_v1_Base.yaml +0 -53
  7. f5_tts/eval/eval_infer_batch.py +27 -22
  8. f5_tts/eval/eval_infer_batch.sh +6 -11
  9. f5_tts/eval/eval_librispeech_test_clean.py +27 -21
  10. f5_tts/eval/eval_seedtts_testset.py +27 -21
  11. f5_tts/eval/eval_utmos.py +16 -14
  12. f5_tts/eval/utils_eval.py +6 -11
  13. f5_tts/infer/README.md +85 -40
  14. f5_tts/infer/SHARED.md +9 -19
  15. f5_tts/infer/__pycache__/infer_cli.cpython-310.pyc +0 -0
  16. f5_tts/infer/__pycache__/utils_infer.cpython-310.pyc +0 -0
  17. f5_tts/infer/examples/basic/basic.toml +3 -3
  18. f5_tts/infer/examples/multi/story.toml +2 -2
  19. f5_tts/infer/infer_cli.py +33 -28
  20. f5_tts/infer/infer_gradio.py +13 -52
  21. f5_tts/infer/speech_edit.py +27 -26
  22. f5_tts/infer/utils_infer.py +76 -105
  23. f5_tts/model/__pycache__/__init__.cpython-310.pyc +0 -0
  24. f5_tts/model/__pycache__/cfm.cpython-310.pyc +0 -0
  25. f5_tts/model/__pycache__/dataset.cpython-310.pyc +0 -0
  26. f5_tts/model/__pycache__/modules.cpython-310.pyc +0 -0
  27. f5_tts/model/__pycache__/trainer.cpython-310.pyc +0 -0
  28. f5_tts/model/__pycache__/utils.cpython-310.pyc +0 -0
  29. f5_tts/model/backbones/README.md +2 -2
  30. f5_tts/model/backbones/__pycache__/dit.cpython-310.pyc +0 -0
  31. f5_tts/model/backbones/__pycache__/mmdit.cpython-310.pyc +0 -0
  32. f5_tts/model/backbones/__pycache__/unett.cpython-310.pyc +0 -0
  33. f5_tts/model/backbones/dit.py +8 -63
  34. f5_tts/model/backbones/mmdit.py +9 -52
  35. f5_tts/model/backbones/unett.py +5 -36
  36. f5_tts/model/cfm.py +2 -3
  37. f5_tts/model/dataset.py +3 -6
  38. f5_tts/model/modules.py +42 -115
  39. f5_tts/model/trainer.py +18 -29
  40. f5_tts/model/utils.py +16 -8
  41. f5_tts/scripts/count_max_epoch.py +1 -1
  42. f5_tts/socket_client.py +0 -61
  43. f5_tts/socket_server.py +98 -169
  44. f5_tts/train/README.md +5 -5
  45. f5_tts/train/__pycache__/finetune_gradio.cpython-310.pyc +0 -0
  46. f5_tts/train/datasets/prepare_csv_wavs.py +2 -2
  47. f5_tts/train/datasets/prepare_emilia.py +4 -4
  48. f5_tts/train/datasets/prepare_libritts.py +10 -5
  49. f5_tts/train/datasets/prepare_metadata.py +12 -0
  50. f5_tts/train/finetune_cli.py +20 -53
f5_tts/api.py CHANGED
@@ -5,43 +5,43 @@ from importlib.resources import files
5
  import soundfile as sf
6
  import tqdm
7
  from cached_path import cached_path
8
- from omegaconf import OmegaConf
9
 
10
  from f5_tts.infer.utils_infer import (
 
 
11
  load_model,
12
  load_vocoder,
13
- transcribe,
14
  preprocess_ref_audio_text,
15
- infer_process,
16
  remove_silence_for_generated_wav,
17
  save_spectrogram,
 
 
18
  )
19
- from f5_tts.model import DiT, UNetT # noqa: F401. used for config
20
  from f5_tts.model.utils import seed_everything
21
 
22
 
23
  class F5TTS:
24
  def __init__(
25
  self,
26
- model="F5TTS_v1_Base",
27
  ckpt_file="",
28
  vocab_file="",
29
  ode_method="euler",
30
  use_ema=True,
31
- vocoder_local_path=None,
 
32
  device=None,
33
  hf_cache_dir=None,
34
  ):
35
- model_cfg = OmegaConf.load(str(files("f5_tts").joinpath(f"configs/{model}.yaml")))
36
- model_cls = globals()[model_cfg.model.backbone]
37
- model_arc = model_cfg.model.arch
38
-
39
- self.mel_spec_type = model_cfg.model.mel_spec.mel_spec_type
40
- self.target_sample_rate = model_cfg.model.mel_spec.target_sample_rate
41
-
42
- self.ode_method = ode_method
43
- self.use_ema = use_ema
44
-
45
  if device is not None:
46
  self.device = device
47
  else:
@@ -58,31 +58,39 @@ class F5TTS:
58
  )
59
 
60
  # Load models
61
- self.vocoder = load_vocoder(
62
- self.mel_spec_type, vocoder_local_path is not None, vocoder_local_path, self.device, hf_cache_dir
 
63
  )
64
 
65
- repo_name, ckpt_step, ckpt_type = "F5-TTS", 1250000, "safetensors"
66
-
67
- # override for previous models
68
- if model == "F5TTS_Base":
69
- if self.mel_spec_type == "vocos":
70
- ckpt_step = 1200000
71
- elif self.mel_spec_type == "bigvgan":
72
- model = "F5TTS_Base_bigvgan"
73
- ckpt_type = "pt"
74
- elif model == "E2TTS_Base":
75
- repo_name = "E2-TTS"
76
- ckpt_step = 1200000
 
 
 
77
  else:
78
- raise ValueError(f"Unknown model type: {model}")
79
 
80
- if not ckpt_file:
81
- ckpt_file = str(
82
- cached_path(f"hf://SWivid/{repo_name}/{model}/model_{ckpt_step}.{ckpt_type}", cache_dir=hf_cache_dir)
83
- )
84
  self.ema_model = load_model(
85
- model_cls, model_arc, ckpt_file, self.mel_spec_type, vocab_file, self.ode_method, self.use_ema, self.device
86
  )
87
 
88
  def transcribe(self, ref_audio, language=None):
@@ -94,8 +102,8 @@ class F5TTS:
94
  if remove_silence:
95
  remove_silence_for_generated_wav(file_wave)
96
 
97
- def export_spectrogram(self, spec, file_spec):
98
- save_spectrogram(spec, file_spec)
99
 
100
  def infer(
101
  self,
@@ -113,16 +121,17 @@ class F5TTS:
113
  fix_duration=None,
114
  remove_silence=False,
115
  file_wave=None,
116
- file_spec=None,
117
- seed=None,
118
  ):
119
- if seed is None:
120
- self.seed = random.randint(0, sys.maxsize)
121
- seed_everything(self.seed)
 
122
 
123
  ref_file, ref_text = preprocess_ref_audio_text(ref_file, ref_text, device=self.device)
124
 
125
- wav, sr, spec = infer_process(
126
  ref_file,
127
  ref_text,
128
  gen_text,
@@ -144,22 +153,22 @@ class F5TTS:
144
  if file_wave is not None:
145
  self.export_wav(wav, file_wave, remove_silence)
146
 
147
- if file_spec is not None:
148
- self.export_spectrogram(spec, file_spec)
149
 
150
- return wav, sr, spec
151
 
152
 
153
  if __name__ == "__main__":
154
  f5tts = F5TTS()
155
 
156
- wav, sr, spec = f5tts.infer(
157
  ref_file=str(files("f5_tts").joinpath("infer/examples/basic/basic_ref_en.wav")),
158
  ref_text="some call me nature, others call me mother nature.",
159
  gen_text="""I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring. Respect me and I'll nurture you; ignore me and you shall face the consequences.""",
160
  file_wave=str(files("f5_tts").joinpath("../../tests/api_out.wav")),
161
- file_spec=str(files("f5_tts").joinpath("../../tests/api_out.png")),
162
- seed=None,
163
  )
164
 
165
  print("seed :", f5tts.seed)
 
5
  import soundfile as sf
6
  import tqdm
7
  from cached_path import cached_path
 
8
 
9
  from f5_tts.infer.utils_infer import (
10
+ hop_length,
11
+ infer_process,
12
  load_model,
13
  load_vocoder,
 
14
  preprocess_ref_audio_text,
 
15
  remove_silence_for_generated_wav,
16
  save_spectrogram,
17
+ transcribe,
18
+ target_sample_rate,
19
  )
20
+ from f5_tts.model import DiT, UNetT
21
  from f5_tts.model.utils import seed_everything
22
 
23
 
24
  class F5TTS:
25
  def __init__(
26
  self,
27
+ model_type="F5-TTS",
28
  ckpt_file="",
29
  vocab_file="",
30
  ode_method="euler",
31
  use_ema=True,
32
+ vocoder_name="vocos",
33
+ local_path=None,
34
  device=None,
35
  hf_cache_dir=None,
36
  ):
37
+ # Initialize parameters
38
+ self.final_wave = None
39
+ self.target_sample_rate = target_sample_rate
40
+ self.hop_length = hop_length
41
+ self.seed = -1
42
+ self.mel_spec_type = vocoder_name
43
+
44
+ # Set device
 
 
45
  if device is not None:
46
  self.device = device
47
  else:
 
58
  )
59
 
60
  # Load models
61
+ self.load_vocoder_model(vocoder_name, local_path=local_path, hf_cache_dir=hf_cache_dir)
62
+ self.load_ema_model(
63
+ model_type, ckpt_file, vocoder_name, vocab_file, ode_method, use_ema, hf_cache_dir=hf_cache_dir
64
  )
65
 
66
+ def load_vocoder_model(self, vocoder_name, local_path=None, hf_cache_dir=None):
67
+ self.vocoder = load_vocoder(vocoder_name, local_path is not None, local_path, self.device, hf_cache_dir)
68
+
69
+ def load_ema_model(self, model_type, ckpt_file, mel_spec_type, vocab_file, ode_method, use_ema, hf_cache_dir=None):
70
+ if model_type == "F5-TTS":
71
+ if not ckpt_file:
72
+ if mel_spec_type == "vocos":
73
+ ckpt_file = str(
74
+ cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors", cache_dir=hf_cache_dir)
75
+ )
76
+ elif mel_spec_type == "bigvgan":
77
+ ckpt_file = str(
78
+ cached_path("hf://SWivid/F5-TTS/F5TTS_Base_bigvgan/model_1250000.pt", cache_dir=hf_cache_dir)
79
+ )
80
+ model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
81
+ model_cls = DiT
82
+ elif model_type == "E2-TTS":
83
+ if not ckpt_file:
84
+ ckpt_file = str(
85
+ cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors", cache_dir=hf_cache_dir)
86
+ )
87
+ model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
88
+ model_cls = UNetT
89
  else:
90
+ raise ValueError(f"Unknown model type: {model_type}")
91
 
 
 
 
 
92
  self.ema_model = load_model(
93
+ model_cls, model_cfg, ckpt_file, mel_spec_type, vocab_file, ode_method, use_ema, self.device
94
  )
95
 
96
  def transcribe(self, ref_audio, language=None):
 
102
  if remove_silence:
103
  remove_silence_for_generated_wav(file_wave)
104
 
105
+ def export_spectrogram(self, spect, file_spect):
106
+ save_spectrogram(spect, file_spect)
107
 
108
  def infer(
109
  self,
 
121
  fix_duration=None,
122
  remove_silence=False,
123
  file_wave=None,
124
+ file_spect=None,
125
+ seed=-1,
126
  ):
127
+ if seed == -1:
128
+ seed = random.randint(0, sys.maxsize)
129
+ seed_everything(seed)
130
+ self.seed = seed
131
 
132
  ref_file, ref_text = preprocess_ref_audio_text(ref_file, ref_text, device=self.device)
133
 
134
+ wav, sr, spect = infer_process(
135
  ref_file,
136
  ref_text,
137
  gen_text,
 
153
  if file_wave is not None:
154
  self.export_wav(wav, file_wave, remove_silence)
155
 
156
+ if file_spect is not None:
157
+ self.export_spectrogram(spect, file_spect)
158
 
159
+ return wav, sr, spect
160
 
161
 
162
  if __name__ == "__main__":
163
  f5tts = F5TTS()
164
 
165
+ wav, sr, spect = f5tts.infer(
166
  ref_file=str(files("f5_tts").joinpath("infer/examples/basic/basic_ref_en.wav")),
167
  ref_text="some call me nature, others call me mother nature.",
168
  gen_text="""I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring. Respect me and I'll nurture you; ignore me and you shall face the consequences.""",
169
  file_wave=str(files("f5_tts").joinpath("../../tests/api_out.wav")),
170
+ file_spect=str(files("f5_tts").joinpath("../../tests/api_out.png")),
171
+ seed=-1, # random seed = -1
172
  )
173
 
174
  print("seed :", f5tts.seed)
f5_tts/configs/{E2TTS_Base.yaml → E2TTS_Base_train.yaml} RENAMED
@@ -1,16 +1,16 @@
1
  hydra:
2
  run:
3
  dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
4
-
5
  datasets:
6
  name: Emilia_ZH_EN # dataset name
7
  batch_size_per_gpu: 38400 # 8 GPUs, 8 * 38400 = 307200
8
- batch_size_type: frame # frame | sample
9
  max_samples: 64 # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models
10
  num_workers: 16
11
 
12
  optim:
13
- epochs: 11
14
  learning_rate: 7.5e-5
15
  num_warmup_updates: 20000 # warmup updates
16
  grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps
@@ -20,29 +20,25 @@ optim:
20
  model:
21
  name: E2TTS_Base
22
  tokenizer: pinyin
23
- tokenizer_path: null # if 'custom' tokenizer, define the path want to use (should be vocab.txt)
24
- backbone: UNetT
25
  arch:
26
  dim: 1024
27
  depth: 24
28
  heads: 16
29
  ff_mult: 4
30
- text_mask_padding: False
31
- pe_attn_head: 1
32
  mel_spec:
33
  target_sample_rate: 24000
34
  n_mel_channels: 100
35
  hop_length: 256
36
  win_length: 1024
37
  n_fft: 1024
38
- mel_spec_type: vocos # vocos | bigvgan
39
  vocoder:
40
  is_local: False # use local offline ckpt or not
41
- local_path: null # local vocoder path
42
 
43
  ckpts:
44
- logger: wandb # wandb | tensorboard | null
45
- log_samples: True # infer random sample per save checkpoint. wip, normal to fail with extra long samples
46
  save_per_updates: 50000 # save checkpoint per updates
47
  keep_last_n_checkpoints: -1 # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints
48
  last_per_updates: 5000 # save last checkpoint per updates
 
1
  hydra:
2
  run:
3
  dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
4
+
5
  datasets:
6
  name: Emilia_ZH_EN # dataset name
7
  batch_size_per_gpu: 38400 # 8 GPUs, 8 * 38400 = 307200
8
+ batch_size_type: frame # "frame" or "sample"
9
  max_samples: 64 # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models
10
  num_workers: 16
11
 
12
  optim:
13
+ epochs: 15
14
  learning_rate: 7.5e-5
15
  num_warmup_updates: 20000 # warmup updates
16
  grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps
 
20
  model:
21
  name: E2TTS_Base
22
  tokenizer: pinyin
23
+ tokenizer_path: None # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt)
 
24
  arch:
25
  dim: 1024
26
  depth: 24
27
  heads: 16
28
  ff_mult: 4
 
 
29
  mel_spec:
30
  target_sample_rate: 24000
31
  n_mel_channels: 100
32
  hop_length: 256
33
  win_length: 1024
34
  n_fft: 1024
35
+ mel_spec_type: vocos # 'vocos' or 'bigvgan'
36
  vocoder:
37
  is_local: False # use local offline ckpt or not
38
+ local_path: None # local vocoder path
39
 
40
  ckpts:
41
+ logger: wandb # wandb | tensorboard | None
 
42
  save_per_updates: 50000 # save checkpoint per updates
43
  keep_last_n_checkpoints: -1 # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints
44
  last_per_updates: 5000 # save last checkpoint per updates
f5_tts/configs/{E2TTS_Small.yaml → E2TTS_Small_train.yaml} RENAMED
@@ -1,16 +1,16 @@
1
  hydra:
2
  run:
3
  dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
4
-
5
  datasets:
6
  name: Emilia_ZH_EN
7
  batch_size_per_gpu: 38400 # 8 GPUs, 8 * 38400 = 307200
8
- batch_size_type: frame # frame | sample
9
  max_samples: 64 # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models
10
  num_workers: 16
11
 
12
  optim:
13
- epochs: 11
14
  learning_rate: 7.5e-5
15
  num_warmup_updates: 20000 # warmup updates
16
  grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps
@@ -20,29 +20,25 @@ optim:
20
  model:
21
  name: E2TTS_Small
22
  tokenizer: pinyin
23
- tokenizer_path: null # if 'custom' tokenizer, define the path want to use (should be vocab.txt)
24
- backbone: UNetT
25
  arch:
26
  dim: 768
27
  depth: 20
28
  heads: 12
29
  ff_mult: 4
30
- text_mask_padding: False
31
- pe_attn_head: 1
32
  mel_spec:
33
  target_sample_rate: 24000
34
  n_mel_channels: 100
35
  hop_length: 256
36
  win_length: 1024
37
  n_fft: 1024
38
- mel_spec_type: vocos # vocos | bigvgan
39
  vocoder:
40
  is_local: False # use local offline ckpt or not
41
- local_path: null # local vocoder path
42
 
43
  ckpts:
44
- logger: wandb # wandb | tensorboard | null
45
- log_samples: True # infer random sample per save checkpoint. wip, normal to fail with extra long samples
46
  save_per_updates: 50000 # save checkpoint per updates
47
  keep_last_n_checkpoints: -1 # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints
48
  last_per_updates: 5000 # save last checkpoint per updates
 
1
  hydra:
2
  run:
3
  dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
4
+
5
  datasets:
6
  name: Emilia_ZH_EN
7
  batch_size_per_gpu: 38400 # 8 GPUs, 8 * 38400 = 307200
8
+ batch_size_type: frame # "frame" or "sample"
9
  max_samples: 64 # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models
10
  num_workers: 16
11
 
12
  optim:
13
+ epochs: 15
14
  learning_rate: 7.5e-5
15
  num_warmup_updates: 20000 # warmup updates
16
  grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps
 
20
  model:
21
  name: E2TTS_Small
22
  tokenizer: pinyin
23
+ tokenizer_path: None # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt)
 
24
  arch:
25
  dim: 768
26
  depth: 20
27
  heads: 12
28
  ff_mult: 4
 
 
29
  mel_spec:
30
  target_sample_rate: 24000
31
  n_mel_channels: 100
32
  hop_length: 256
33
  win_length: 1024
34
  n_fft: 1024
35
+ mel_spec_type: vocos # 'vocos' or 'bigvgan'
36
  vocoder:
37
  is_local: False # use local offline ckpt or not
38
+ local_path: None # local vocoder path
39
 
40
  ckpts:
41
+ logger: wandb # wandb | tensorboard | None
 
42
  save_per_updates: 50000 # save checkpoint per updates
43
  keep_last_n_checkpoints: -1 # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints
44
  last_per_updates: 5000 # save last checkpoint per updates
f5_tts/configs/{F5TTS_Base.yaml → F5TTS_Base_train.yaml} RENAMED
@@ -1,16 +1,16 @@
1
  hydra:
2
  run:
3
  dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
4
-
5
  datasets:
6
- name: your_training_dataset # dataset name
7
- batch_size_per_gpu: 38400 # 8 GPUs, 8 * 38400 = 307200
8
- batch_size_type: frame # frame | sample
9
  max_samples: 64 # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models
10
  num_workers: 16
11
 
12
  optim:
13
- epochs: 11
14
  learning_rate: 7.5e-5
15
  num_warmup_updates: 20000 # warmup updates
16
  grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps
@@ -20,17 +20,14 @@ optim:
20
  model:
21
  name: F5TTS_Base # model name
22
  tokenizer: char # tokenizer type
23
- tokenizer_path: null # if 'custom' tokenizer, define the path want to use (should be vocab.txt)
24
- backbone: DiT
25
  arch:
26
  dim: 1024
27
  depth: 22
28
  heads: 16
29
  ff_mult: 2
30
  text_dim: 512
31
- text_mask_padding: False
32
  conv_layers: 4
33
- pe_attn_head: 1
34
  checkpoint_activations: False # recompute activations and save memory for extra compute
35
  mel_spec:
36
  target_sample_rate: 24000
@@ -38,15 +35,14 @@ model:
38
  hop_length: 256
39
  win_length: 1024
40
  n_fft: 1024
41
- mel_spec_type: vocos # vocos | bigvgan
42
  vocoder:
43
- is_local: False # use local offline ckpt or not
44
- local_path: null # local vocoder path
45
 
46
  ckpts:
47
- logger: tensorboard # wandb | tensorboard | null
48
- log_samples: True # infer random sample per save checkpoint. wip, normal to fail with extra long samples
49
- save_per_updates: 50000 # save checkpoint per updates
50
  keep_last_n_checkpoints: -1 # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints
51
  last_per_updates: 5000 # save last checkpoint per updates
52
  save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}
 
1
  hydra:
2
  run:
3
  dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
4
+
5
  datasets:
6
+ name: vn_1000h # dataset name
7
+ batch_size_per_gpu: 2000 # 8 GPUs, 8 * 38400 = 307200
8
+ batch_size_type: frame # "frame" or "sample"
9
  max_samples: 64 # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models
10
  num_workers: 16
11
 
12
  optim:
13
+ epochs: 200
14
  learning_rate: 7.5e-5
15
  num_warmup_updates: 20000 # warmup updates
16
  grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps
 
20
  model:
21
  name: F5TTS_Base # model name
22
  tokenizer: char # tokenizer type
23
+ tokenizer_path: None # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt)
 
24
  arch:
25
  dim: 1024
26
  depth: 22
27
  heads: 16
28
  ff_mult: 2
29
  text_dim: 512
 
30
  conv_layers: 4
 
31
  checkpoint_activations: False # recompute activations and save memory for extra compute
32
  mel_spec:
33
  target_sample_rate: 24000
 
35
  hop_length: 256
36
  win_length: 1024
37
  n_fft: 1024
38
+ mel_spec_type: vocos # 'vocos' or 'bigvgan'
39
  vocoder:
40
+ is_local: True # use local offline ckpt or not
41
+ local_path: /mnt/i/Project/F5-TTS/ckpts/vocos # local vocoder path
42
 
43
  ckpts:
44
+ logger: tensorboard # wandb | tensorboard | None
45
+ save_per_updates: 30000 # save checkpoint per updates
 
46
  keep_last_n_checkpoints: -1 # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints
47
  last_per_updates: 5000 # save last checkpoint per updates
48
  save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}
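As a sanity check on the renamed training config, a small sketch that loads the file above and prints the fields this commit touches (assuming `omegaconf` is installed, as in the training stack, and the `f5_tts` package is importable):

```python
from importlib.resources import files

from omegaconf import OmegaConf

# Load the renamed training config and inspect the values changed in this commit.
cfg = OmegaConf.load(str(files("f5_tts").joinpath("configs/F5TTS_Base_train.yaml")))

print(cfg.datasets.name)                 # vn_1000h
print(cfg.datasets.batch_size_per_gpu)   # 2000
print(cfg.optim.epochs)                  # 200
print(cfg.model.vocoder.is_local, cfg.model.vocoder.local_path)
print(cfg.ckpts.save_per_updates)        # 30000
```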
f5_tts/configs/{F5TTS_Small.yaml → F5TTS_Small_train.yaml} RENAMED
@@ -1,16 +1,16 @@
1
  hydra:
2
  run:
3
  dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
4
-
5
  datasets:
6
  name: Emilia_ZH_EN
7
  batch_size_per_gpu: 38400 # 8 GPUs, 8 * 38400 = 307200
8
- batch_size_type: frame # frame | sample
9
  max_samples: 64 # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models
10
  num_workers: 16
11
 
12
  optim:
13
- epochs: 11
14
  learning_rate: 7.5e-5
15
  num_warmup_updates: 20000 # warmup updates
16
  grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps
@@ -20,17 +20,14 @@ optim:
20
  model:
21
  name: F5TTS_Small
22
  tokenizer: pinyin
23
- tokenizer_path: null # if 'custom' tokenizer, define the path want to use (should be vocab.txt)
24
- backbone: DiT
25
  arch:
26
  dim: 768
27
  depth: 18
28
  heads: 12
29
  ff_mult: 2
30
  text_dim: 512
31
- text_mask_padding: False
32
  conv_layers: 4
33
- pe_attn_head: 1
34
  checkpoint_activations: False # recompute activations and save memory for extra compute
35
  mel_spec:
36
  target_sample_rate: 24000
@@ -38,14 +35,13 @@ model:
38
  hop_length: 256
39
  win_length: 1024
40
  n_fft: 1024
41
- mel_spec_type: vocos # vocos | bigvgan
42
  vocoder:
43
  is_local: False # use local offline ckpt or not
44
- local_path: null # local vocoder path
45
 
46
  ckpts:
47
- logger: wandb # wandb | tensorboard | null
48
- log_samples: True # infer random sample per save checkpoint. wip, normal to fail with extra long samples
49
  save_per_updates: 50000 # save checkpoint per updates
50
  keep_last_n_checkpoints: -1 # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints
51
  last_per_updates: 5000 # save last checkpoint per updates
 
1
  hydra:
2
  run:
3
  dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
4
+
5
  datasets:
6
  name: Emilia_ZH_EN
7
  batch_size_per_gpu: 38400 # 8 GPUs, 8 * 38400 = 307200
8
+ batch_size_type: frame # "frame" or "sample"
9
  max_samples: 64 # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models
10
  num_workers: 16
11
 
12
  optim:
13
+ epochs: 15
14
  learning_rate: 7.5e-5
15
  num_warmup_updates: 20000 # warmup updates
16
  grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps
 
20
  model:
21
  name: F5TTS_Small
22
  tokenizer: pinyin
23
+ tokenizer_path: None # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt)
 
24
  arch:
25
  dim: 768
26
  depth: 18
27
  heads: 12
28
  ff_mult: 2
29
  text_dim: 512
 
30
  conv_layers: 4
 
31
  checkpoint_activations: False # recompute activations and save memory for extra compute
32
  mel_spec:
33
  target_sample_rate: 24000
 
35
  hop_length: 256
36
  win_length: 1024
37
  n_fft: 1024
38
+ mel_spec_type: vocos # 'vocos' or 'bigvgan'
39
  vocoder:
40
  is_local: False # use local offline ckpt or not
41
+ local_path: None # local vocoder path
42
 
43
  ckpts:
44
+ logger: wandb # wandb | tensorboard | None
 
45
  save_per_updates: 50000 # save checkpoint per updates
46
  keep_last_n_checkpoints: -1 # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints
47
  last_per_updates: 5000 # save last checkpoint per updates
f5_tts/configs/F5TTS_v1_Base.yaml DELETED
@@ -1,53 +0,0 @@
- hydra:
- run:
- dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
-
- datasets:
- name: Emilia_ZH_EN # dataset name
- batch_size_per_gpu: 38400 # 8 GPUs, 8 * 38400 = 307200
- batch_size_type: frame # frame | sample
- max_samples: 64 # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models
- num_workers: 16
-
- optim:
- epochs: 11
- learning_rate: 7.5e-5
- num_warmup_updates: 20000 # warmup updates
- grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps
- max_grad_norm: 1.0 # gradient clipping
- bnb_optimizer: False # use bnb 8bit AdamW optimizer or not
-
- model:
- name: F5TTS_v1_Base # model name
- tokenizer: pinyin # tokenizer type
- tokenizer_path: null # if 'custom' tokenizer, define the path want to use (should be vocab.txt)
- backbone: DiT
- arch:
- dim: 1024
- depth: 22
- heads: 16
- ff_mult: 2
- text_dim: 512
- text_mask_padding: True
- qk_norm: null # null | rms_norm
- conv_layers: 4
- pe_attn_head: null
- checkpoint_activations: False # recompute activations and save memory for extra compute
- mel_spec:
- target_sample_rate: 24000
- n_mel_channels: 100
- hop_length: 256
- win_length: 1024
- n_fft: 1024
- mel_spec_type: vocos # vocos | bigvgan
- vocoder:
- is_local: False # use local offline ckpt or not
- local_path: null # local vocoder path
-
- ckpts:
- logger: wandb # wandb | tensorboard | null
- log_samples: True # infer random sample per save checkpoint. wip, normal to fail with extra long samples
- save_per_updates: 50000 # save checkpoint per updates
- keep_last_n_checkpoints: -1 # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints
- last_per_updates: 5000 # save last checkpoint per updates
- save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}
 
 
f5_tts/eval/eval_infer_batch.py CHANGED
@@ -10,7 +10,6 @@ from importlib.resources import files
10
  import torch
11
  import torchaudio
12
  from accelerate import Accelerator
13
- from omegaconf import OmegaConf
14
  from tqdm import tqdm
15
 
16
  from f5_tts.eval.utils_eval import (
@@ -19,26 +18,36 @@ from f5_tts.eval.utils_eval import (
19
  get_seedtts_testset_metainfo,
20
  )
21
  from f5_tts.infer.utils_infer import load_checkpoint, load_vocoder
22
- from f5_tts.model import CFM, DiT, UNetT # noqa: F401. used for config
23
  from f5_tts.model.utils import get_tokenizer
24
 
25
  accelerator = Accelerator()
26
  device = f"cuda:{accelerator.process_index}"
27
 
28
 
29
- use_ema = True
30
- target_rms = 0.1
31
 
 
 
 
 
 
 
32
 
33
  rel_path = str(files("f5_tts").joinpath("../../"))
34
 
35
 
36
  def main():
 
 
37
  parser = argparse.ArgumentParser(description="batch inference")
38
 
39
  parser.add_argument("-s", "--seed", default=None, type=int)
 
40
  parser.add_argument("-n", "--expname", required=True)
41
- parser.add_argument("-c", "--ckptstep", default=1250000, type=int)
 
 
42
 
43
  parser.add_argument("-nfe", "--nfestep", default=32, type=int)
44
  parser.add_argument("-o", "--odemethod", default="euler")
@@ -49,8 +58,12 @@ def main():
49
  args = parser.parse_args()
50
 
51
  seed = args.seed
 
52
  exp_name = args.expname
53
  ckpt_step = args.ckptstep
 
 
 
54
 
55
  nfe_step = args.nfestep
56
  ode_method = args.odemethod
@@ -64,19 +77,13 @@ def main():
64
  use_truth_duration = False
65
  no_ref_audio = False
66
 
67
- model_cfg = OmegaConf.load(str(files("f5_tts").joinpath(f"configs/{exp_name}.yaml")))
68
- model_cls = globals()[model_cfg.model.backbone]
69
- model_arc = model_cfg.model.arch
70
 
71
- dataset_name = model_cfg.datasets.name
72
- tokenizer = model_cfg.model.tokenizer
73
-
74
- mel_spec_type = model_cfg.model.mel_spec.mel_spec_type
75
- target_sample_rate = model_cfg.model.mel_spec.target_sample_rate
76
- n_mel_channels = model_cfg.model.mel_spec.n_mel_channels
77
- hop_length = model_cfg.model.mel_spec.hop_length
78
- win_length = model_cfg.model.mel_spec.win_length
79
- n_fft = model_cfg.model.mel_spec.n_fft
80
 
81
  if testset == "ls_pc_test_clean":
82
  metalst = rel_path + "/data/librispeech_pc_test_clean_cross_sentence.lst"
@@ -104,6 +111,8 @@ def main():
104
 
105
  # -------------------------------------------------#
106
 
 
 
107
  prompts_all = get_inference_prompt(
108
  metainfo,
109
  speed=speed,
@@ -130,7 +139,7 @@ def main():
130
 
131
  # Model
132
  model = CFM(
133
- transformer=model_cls(**model_arc, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
134
  mel_spec_kwargs=dict(
135
  n_fft=n_fft,
136
  hop_length=hop_length,
@@ -145,10 +154,6 @@ def main():
145
  vocab_char_map=vocab_char_map,
146
  ).to(device)
147
 
148
- ckpt_path = rel_path + f"/ckpts/{exp_name}/model_{ckpt_step}.pt"
149
- if not os.path.exists(ckpt_path):
150
- print("Loading from self-organized training checkpoints rather than released pretrained.")
151
- ckpt_path = rel_path + f"/{model_cfg.ckpts.save_dir}/model_{ckpt_step}.pt"
152
  dtype = torch.float32 if mel_spec_type == "bigvgan" else None
153
  model = load_checkpoint(model, ckpt_path, device, dtype=dtype, use_ema=use_ema)
154
 
 
10
  import torch
11
  import torchaudio
12
  from accelerate import Accelerator
 
13
  from tqdm import tqdm
14
 
15
  from f5_tts.eval.utils_eval import (
 
18
  get_seedtts_testset_metainfo,
19
  )
20
  from f5_tts.infer.utils_infer import load_checkpoint, load_vocoder
21
+ from f5_tts.model import CFM, DiT, UNetT
22
  from f5_tts.model.utils import get_tokenizer
23
 
24
  accelerator = Accelerator()
25
  device = f"cuda:{accelerator.process_index}"
26
 
27
 
28
+ # --------------------- Dataset Settings -------------------- #
 
29
 
30
+ target_sample_rate = 24000
31
+ n_mel_channels = 100
32
+ hop_length = 256
33
+ win_length = 1024
34
+ n_fft = 1024
35
+ target_rms = 0.1
36
 
37
  rel_path = str(files("f5_tts").joinpath("../../"))
38
 
39
 
40
  def main():
41
+ # ---------------------- infer setting ---------------------- #
42
+
43
  parser = argparse.ArgumentParser(description="batch inference")
44
 
45
  parser.add_argument("-s", "--seed", default=None, type=int)
46
+ parser.add_argument("-d", "--dataset", default="Emilia_ZH_EN")
47
  parser.add_argument("-n", "--expname", required=True)
48
+ parser.add_argument("-c", "--ckptstep", default=1200000, type=int)
49
+ parser.add_argument("-m", "--mel_spec_type", default="vocos", type=str, choices=["bigvgan", "vocos"])
50
+ parser.add_argument("-to", "--tokenizer", default="pinyin", type=str, choices=["pinyin", "char"])
51
 
52
  parser.add_argument("-nfe", "--nfestep", default=32, type=int)
53
  parser.add_argument("-o", "--odemethod", default="euler")
 
58
  args = parser.parse_args()
59
 
60
  seed = args.seed
61
+ dataset_name = args.dataset
62
  exp_name = args.expname
63
  ckpt_step = args.ckptstep
64
+ ckpt_path = rel_path + f"/ckpts/{exp_name}/model_{ckpt_step}.pt"
65
+ mel_spec_type = args.mel_spec_type
66
+ tokenizer = args.tokenizer
67
 
68
  nfe_step = args.nfestep
69
  ode_method = args.odemethod
 
77
  use_truth_duration = False
78
  no_ref_audio = False
79
 
80
+ if exp_name == "F5TTS_Base":
81
+ model_cls = DiT
82
+ model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
83
 
84
+ elif exp_name == "E2TTS_Base":
85
+ model_cls = UNetT
86
+ model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
 
 
 
 
 
 
87
 
88
  if testset == "ls_pc_test_clean":
89
  metalst = rel_path + "/data/librispeech_pc_test_clean_cross_sentence.lst"
 
111
 
112
  # -------------------------------------------------#
113
 
114
+ use_ema = True
115
+
116
  prompts_all = get_inference_prompt(
117
  metainfo,
118
  speed=speed,
 
139
 
140
  # Model
141
  model = CFM(
142
+ transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
143
  mel_spec_kwargs=dict(
144
  n_fft=n_fft,
145
  hop_length=hop_length,
 
154
  vocab_char_map=vocab_char_map,
155
  ).to(device)
156
 
 
 
 
 
157
  dtype = torch.float32 if mel_spec_type == "bigvgan" else None
158
  model = load_checkpoint(model, ckpt_path, device, dtype=dtype, use_ema=use_ema)
159
 
f5_tts/eval/eval_infer_batch.sh CHANGED
@@ -1,18 +1,13 @@
  #!/bin/bash

  # e.g. F5-TTS, 16 NFE
- accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "F5TTS_v1_Base" -t "seedtts_test_zh" -nfe 16
- accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "F5TTS_v1_Base" -t "seedtts_test_en" -nfe 16
- accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "F5TTS_v1_Base" -t "ls_pc_test_clean" -nfe 16

  # e.g. Vanilla E2 TTS, 32 NFE
- accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "E2TTS_Base" -c 1200000 -t "seedtts_test_zh" -o "midpoint" -ss 0
- accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "E2TTS_Base" -c 1200000 -t "seedtts_test_en" -o "midpoint" -ss 0
- accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "E2TTS_Base" -c 1200000 -t "ls_pc_test_clean" -o "midpoint" -ss 0
-
- # e.g. evaluate F5-TTS 16 NFE result on Seed-TTS test-zh
- python src/f5_tts/eval/eval_seedtts_testset.py -e wer -l zh --gen_wav_dir results/F5TTS_v1_Base_1250000/seedtts_test_zh/seed0_euler_nfe32_vocos_ss-1_cfg2.0_speed1.0 --gpu_nums 8
- python src/f5_tts/eval/eval_seedtts_testset.py -e sim -l zh --gen_wav_dir results/F5TTS_v1_Base_1250000/seedtts_test_zh/seed0_euler_nfe32_vocos_ss-1_cfg2.0_speed1.0 --gpu_nums 8
- python src/f5_tts/eval/eval_utmos.py --audio_dir results/F5TTS_v1_Base_1250000/seedtts_test_zh/seed0_euler_nfe32_vocos_ss-1_cfg2.0_speed1.0

  # etc.

  #!/bin/bash

  # e.g. F5-TTS, 16 NFE
+ accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "F5TTS_Base" -t "seedtts_test_zh" -nfe 16
+ accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "F5TTS_Base" -t "seedtts_test_en" -nfe 16
+ accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "F5TTS_Base" -t "ls_pc_test_clean" -nfe 16

  # e.g. Vanilla E2 TTS, 32 NFE
+ accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "E2TTS_Base" -t "seedtts_test_zh" -o "midpoint" -ss 0
+ accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "E2TTS_Base" -t "seedtts_test_en" -o "midpoint" -ss 0
+ accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "E2TTS_Base" -t "ls_pc_test_clean" -o "midpoint" -ss 0

  # etc.
f5_tts/eval/eval_librispeech_test_clean.py CHANGED
@@ -53,37 +53,43 @@ def main():
53
  asr_ckpt_dir = "" # auto download to cache dir
54
  wavlm_ckpt_dir = "../checkpoints/UniSpeech/wavlm_large_finetune.pth"
55
 
56
- # --------------------------------------------------------------------------
57
-
58
- full_results = []
59
- metrics = []
60
 
61
  if eval_task == "wer":
 
 
 
62
  with mp.Pool(processes=len(gpus)) as pool:
63
  args = [(rank, lang, sub_test_set, asr_ckpt_dir) for (rank, sub_test_set) in test_set]
64
  results = pool.map(run_asr_wer, args)
65
  for r in results:
66
- full_results.extend(r)
67
- elif eval_task == "sim":
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
  with mp.Pool(processes=len(gpus)) as pool:
69
  args = [(rank, sub_test_set, wavlm_ckpt_dir) for (rank, sub_test_set) in test_set]
70
  results = pool.map(run_sim, args)
71
  for r in results:
72
- full_results.extend(r)
73
- else:
74
- raise ValueError(f"Unknown metric type: {eval_task}")
75
-
76
- result_path = f"{gen_wav_dir}/_{eval_task}_results.jsonl"
77
- with open(result_path, "w") as f:
78
- for line in full_results:
79
- metrics.append(line[eval_task])
80
- f.write(json.dumps(line, ensure_ascii=False) + "\n")
81
- metric = round(np.mean(metrics), 5)
82
- f.write(f"\n{eval_task.upper()}: {metric}\n")
83
-
84
- print(f"\nTotal {len(metrics)} samples")
85
- print(f"{eval_task.upper()}: {metric}")
86
- print(f"{eval_task.upper()} results saved to {result_path}")
87
 
88
 
89
  if __name__ == "__main__":
 
53
  asr_ckpt_dir = "" # auto download to cache dir
54
  wavlm_ckpt_dir = "../checkpoints/UniSpeech/wavlm_large_finetune.pth"
55
 
56
+ # --------------------------- WER ---------------------------
 
 
 
57
 
58
  if eval_task == "wer":
59
+ wer_results = []
60
+ wers = []
61
+
62
  with mp.Pool(processes=len(gpus)) as pool:
63
  args = [(rank, lang, sub_test_set, asr_ckpt_dir) for (rank, sub_test_set) in test_set]
64
  results = pool.map(run_asr_wer, args)
65
  for r in results:
66
+ wer_results.extend(r)
67
+
68
+ wer_result_path = f"{gen_wav_dir}/{lang}_wer_results.jsonl"
69
+ with open(wer_result_path, "w") as f:
70
+ for line in wer_results:
71
+ wers.append(line["wer"])
72
+ json_line = json.dumps(line, ensure_ascii=False)
73
+ f.write(json_line + "\n")
74
+
75
+ wer = round(np.mean(wers) * 100, 3)
76
+ print(f"\nTotal {len(wers)} samples")
77
+ print(f"WER : {wer}%")
78
+ print(f"Results have been saved to {wer_result_path}")
79
+
80
+ # --------------------------- SIM ---------------------------
81
+
82
+ if eval_task == "sim":
83
+ sims = []
84
  with mp.Pool(processes=len(gpus)) as pool:
85
  args = [(rank, sub_test_set, wavlm_ckpt_dir) for (rank, sub_test_set) in test_set]
86
  results = pool.map(run_sim, args)
87
  for r in results:
88
+ sims.extend(r)
89
+
90
+ sim = round(sum(sims) / len(sims), 3)
91
+ print(f"\nTotal {len(sims)} samples")
92
+ print(f"SIM : {sim}")
 
 
 
 
 
 
 
 
 
 
93
 
94
 
95
  if __name__ == "__main__":
f5_tts/eval/eval_seedtts_testset.py CHANGED
@@ -52,37 +52,43 @@ def main():
52
  asr_ckpt_dir = "" # auto download to cache dir
53
  wavlm_ckpt_dir = "../checkpoints/UniSpeech/wavlm_large_finetune.pth"
54
 
55
- # --------------------------------------------------------------------------
56
-
57
- full_results = []
58
- metrics = []
59
 
60
  if eval_task == "wer":
 
 
 
61
  with mp.Pool(processes=len(gpus)) as pool:
62
  args = [(rank, lang, sub_test_set, asr_ckpt_dir) for (rank, sub_test_set) in test_set]
63
  results = pool.map(run_asr_wer, args)
64
  for r in results:
65
- full_results.extend(r)
66
- elif eval_task == "sim":
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
  with mp.Pool(processes=len(gpus)) as pool:
68
  args = [(rank, sub_test_set, wavlm_ckpt_dir) for (rank, sub_test_set) in test_set]
69
  results = pool.map(run_sim, args)
70
  for r in results:
71
- full_results.extend(r)
72
- else:
73
- raise ValueError(f"Unknown metric type: {eval_task}")
74
-
75
- result_path = f"{gen_wav_dir}/_{eval_task}_results.jsonl"
76
- with open(result_path, "w") as f:
77
- for line in full_results:
78
- metrics.append(line[eval_task])
79
- f.write(json.dumps(line, ensure_ascii=False) + "\n")
80
- metric = round(np.mean(metrics), 5)
81
- f.write(f"\n{eval_task.upper()}: {metric}\n")
82
-
83
- print(f"\nTotal {len(metrics)} samples")
84
- print(f"{eval_task.upper()}: {metric}")
85
- print(f"{eval_task.upper()} results saved to {result_path}")
86
 
87
 
88
  if __name__ == "__main__":
 
52
  asr_ckpt_dir = "" # auto download to cache dir
53
  wavlm_ckpt_dir = "../checkpoints/UniSpeech/wavlm_large_finetune.pth"
54
 
55
+ # --------------------------- WER ---------------------------
 
 
 
56
 
57
  if eval_task == "wer":
58
+ wer_results = []
59
+ wers = []
60
+
61
  with mp.Pool(processes=len(gpus)) as pool:
62
  args = [(rank, lang, sub_test_set, asr_ckpt_dir) for (rank, sub_test_set) in test_set]
63
  results = pool.map(run_asr_wer, args)
64
  for r in results:
65
+ wer_results.extend(r)
66
+
67
+ wer_result_path = f"{gen_wav_dir}/{lang}_wer_results.jsonl"
68
+ with open(wer_result_path, "w") as f:
69
+ for line in wer_results:
70
+ wers.append(line["wer"])
71
+ json_line = json.dumps(line, ensure_ascii=False)
72
+ f.write(json_line + "\n")
73
+
74
+ wer = round(np.mean(wers) * 100, 3)
75
+ print(f"\nTotal {len(wers)} samples")
76
+ print(f"WER : {wer}%")
77
+ print(f"Results have been saved to {wer_result_path}")
78
+
79
+ # --------------------------- SIM ---------------------------
80
+
81
+ if eval_task == "sim":
82
+ sims = []
83
  with mp.Pool(processes=len(gpus)) as pool:
84
  args = [(rank, sub_test_set, wavlm_ckpt_dir) for (rank, sub_test_set) in test_set]
85
  results = pool.map(run_sim, args)
86
  for r in results:
87
+ sims.extend(r)
88
+
89
+ sim = round(sum(sims) / len(sims), 3)
90
+ print(f"\nTotal {len(sims)} samples")
91
+ print(f"SIM : {sim}")
 
 
 
 
 
 
 
 
 
 
92
 
93
 
94
  if __name__ == "__main__":
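The two eval scripts above now write per-sample WER to `{gen_wav_dir}/{lang}_wer_results.jsonl` (one JSON object per line with a `"wer"` field) instead of a single combined results file. A small reader sketch for that file, matching how the scripts compute the printed percentage (the path is illustrative):

```python
import json

import numpy as np


def load_wer(jsonl_path):
    """Recompute corpus WER (%) from a {lang}_wer_results.jsonl written by the eval scripts."""
    wers = []
    with open(jsonl_path, "r", encoding="utf-8") as f:
        for line in f:
            wers.append(json.loads(line)["wer"])
    return round(np.mean(wers) * 100, 3)
```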
f5_tts/eval/eval_utmos.py CHANGED
@@ -19,23 +19,25 @@ def main():
  predictor = predictor.to(device)

  audio_paths = list(Path(args.audio_dir).rglob(f"*.{args.ext}"))
  utmos_score = 0

- utmos_result_path = Path(args.audio_dir) / "_utmos_results.jsonl"
  with open(utmos_result_path, "w", encoding="utf-8") as f:
- for audio_path in tqdm(audio_paths, desc="Processing"):
- wav, sr = librosa.load(audio_path, sr=None, mono=True)
- wav_tensor = torch.from_numpy(wav).to(device).unsqueeze(0)
- score = predictor(wav_tensor, sr)
- line = {}
- line["wav"], line["utmos"] = str(audio_path.stem), score.item()
- utmos_score += score.item()
- f.write(json.dumps(line, ensure_ascii=False) + "\n")
- avg_score = utmos_score / len(audio_paths) if len(audio_paths) > 0 else 0
- f.write(f"\nUTMOS: {avg_score:.4f}\n")
-
- print(f"UTMOS: {avg_score:.4f}")
- print(f"UTMOS results saved to {utmos_result_path}")

  if __name__ == "__main__":

  predictor = predictor.to(device)

  audio_paths = list(Path(args.audio_dir).rglob(f"*.{args.ext}"))
+ utmos_results = {}
  utmos_score = 0

+ for audio_path in tqdm(audio_paths, desc="Processing"):
+ wav_name = audio_path.stem
+ wav, sr = librosa.load(audio_path, sr=None, mono=True)
+ wav_tensor = torch.from_numpy(wav).to(device).unsqueeze(0)
+ score = predictor(wav_tensor, sr)
+ utmos_results[str(wav_name)] = score.item()
+ utmos_score += score.item()
+
+ avg_score = utmos_score / len(audio_paths) if len(audio_paths) > 0 else 0
+ print(f"UTMOS: {avg_score}")
+
+ utmos_result_path = Path(args.audio_dir) / "utmos_results.json"
  with open(utmos_result_path, "w", encoding="utf-8") as f:
+ json.dump(utmos_results, f, ensure_ascii=False, indent=4)
+
+ print(f"Results have been saved to {utmos_result_path}")

  if __name__ == "__main__":
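With the change above, `eval_utmos.py` now saves a flat `{wav_stem: utmos_score}` mapping to `utmos_results.json` rather than a `.jsonl` with an appended summary line. A small reader sketch for the new layout:

```python
import json
from pathlib import Path


def load_utmos(audio_dir):
    """Average the per-file UTMOS scores from utmos_results.json written by eval_utmos.py."""
    results = json.loads((Path(audio_dir) / "utmos_results.json").read_text(encoding="utf-8"))
    return sum(results.values()) / len(results) if results else 0
```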
f5_tts/eval/utils_eval.py CHANGED
@@ -389,10 +389,10 @@ def run_sim(args):
  model = model.cuda(device)
  model.eval()

- sim_results = []
- for gen_wav, prompt_wav, truth in tqdm(test_set):
- wav1, sr1 = torchaudio.load(gen_wav)
- wav2, sr2 = torchaudio.load(prompt_wav)

  resample1 = torchaudio.transforms.Resample(orig_freq=sr1, new_freq=16000)
  resample2 = torchaudio.transforms.Resample(orig_freq=sr2, new_freq=16000)
@@ -408,11 +408,6 @@ def run_sim(args):

  sim = F.cosine_similarity(emb1, emb2)[0].item()
  # print(f"VSim score between two audios: {sim:.4f} (-1.0, 1.0).")
- sim_results.append(
- {
- "wav": Path(gen_wav).stem,
- "sim": sim,
- }
- )

- return sim_results

  model = model.cuda(device)
  model.eval()

+ sims = []
+ for wav1, wav2, truth in tqdm(test_set):
+ wav1, sr1 = torchaudio.load(wav1)
+ wav2, sr2 = torchaudio.load(wav2)

  resample1 = torchaudio.transforms.Resample(orig_freq=sr1, new_freq=16000)
  resample2 = torchaudio.transforms.Resample(orig_freq=sr2, new_freq=16000)

  sim = F.cosine_similarity(emb1, emb2)[0].item()
  # print(f"VSim score between two audios: {sim:.4f} (-1.0, 1.0).")
+ sims.append(sim)

+ return sims
f5_tts/infer/README.md CHANGED
@@ -23,24 +23,12 @@ Currently supported features:
23
  - Basic TTS with Chunk Inference
24
  - Multi-Style / Multi-Speaker Generation
25
  - Voice Chat powered by Qwen2.5-3B-Instruct
26
- - [Custom inference with more language support](src/f5_tts/infer/SHARED.md)
27
 
28
  The cli command `f5-tts_infer-gradio` equals to `python src/f5_tts/infer/infer_gradio.py`, which launches a Gradio APP (web interface) for inference.
29
 
30
  The script will load model checkpoints from Huggingface. You can also manually download files and update the path to `load_model()` in `infer_gradio.py`. Currently only load TTS models first, will load ASR model to do transcription if `ref_text` not provided, will load LLM model if use Voice Chat.
31
 
32
- More flags options:
33
-
34
- ```bash
35
- # Automatically launch the interface in the default web browser
36
- f5-tts_infer-gradio --inbrowser
37
-
38
- # Set the root path of the application, if it's not served from the root ("/") of the domain
39
- # For example, if the application is served at "https://example.com/myapp"
40
- f5-tts_infer-gradio --root_path "/myapp"
41
- ```
42
-
43
- Could also be used as a component for larger application:
44
  ```python
45
  import gradio as gr
46
  from f5_tts.infer.infer_gradio import app
@@ -68,16 +56,17 @@ Basically you can inference with flags:
68
  ```bash
69
  # Leave --ref_text "" will have ASR model transcribe (extra GPU memory usage)
70
  f5-tts_infer-cli \
71
- --model F5TTS_v1_Base \
72
  --ref_audio "ref_audio.wav" \
73
- --ref_text "The content, subtitle or transcription of reference audio." \
74
- --gen_text "Some text you want TTS model generate for you."
 
 
 
75
 
76
- # Use BigVGAN as vocoder. Currently only support F5TTS_Base.
77
- f5-tts_infer-cli --model F5TTS_Base --vocoder_name bigvgan --load_vocoder_from_local
78
-
79
- # Use custom path checkpoint, e.g.
80
- f5-tts_infer-cli --ckpt_file ckpts/F5TTS_v1_Base/model_1250000.safetensors
81
 
82
  # More instructions
83
  f5-tts_infer-cli --help
@@ -92,8 +81,8 @@ f5-tts_infer-cli -c custom.toml
92
  For example, you can use `.toml` to pass in variables, refer to `src/f5_tts/infer/examples/basic/basic.toml`:
93
 
94
  ```toml
95
- # F5TTS_v1_Base | E2TTS_Base
96
- model = "F5TTS_v1_Base"
97
  ref_audio = "infer/examples/basic/basic_ref_en.wav"
98
  # If an empty "", transcribes the reference audio automatically.
99
  ref_text = "Some call me nature, others call me mother nature."
@@ -107,8 +96,8 @@ output_dir = "tests"
107
  You can also leverage `.toml` file to do multi-style generation, refer to `src/f5_tts/infer/examples/multi/story.toml`.
108
 
109
  ```toml
110
- # F5TTS_v1_Base | E2TTS_Base
111
- model = "F5TTS_v1_Base"
112
  ref_audio = "infer/examples/multi/main.flac"
113
  # If an empty "", transcribes the reference audio automatically.
114
  ref_text = ""
@@ -128,27 +117,83 @@ ref_text = ""
128
  ```
129
  You should mark the voice with `[main]` `[town]` `[country]` whenever you want to change voice, refer to `src/f5_tts/infer/examples/multi/story.txt`.
130
 
131
- ## Socket Real-time Service
132
 
133
- Real-time voice output with chunk stream:
134
 
135
  ```bash
136
- # Start socket server
137
- python src/f5_tts/socket_server.py
138
-
139
- # If PyAudio not installed
140
- sudo apt-get install portaudio19-dev
141
- pip install pyaudio
142
-
143
- # Communicate with socket client
144
- python src/f5_tts/socket_client.py
145
  ```
146
 
147
- ## Speech Editing
148
-
149
- To test speech editing capabilities, use the following command:
150
 
 
151
  ```bash
152
- python src/f5_tts/infer/speech_edit.py
153
  ```
 
 
 
 
 
23
  - Basic TTS with Chunk Inference
24
  - Multi-Style / Multi-Speaker Generation
25
  - Voice Chat powered by Qwen2.5-3B-Instruct
 
26
 
27
  The cli command `f5-tts_infer-gradio` equals to `python src/f5_tts/infer/infer_gradio.py`, which launches a Gradio APP (web interface) for inference.
28
 
29
  The script will load model checkpoints from Huggingface. You can also manually download files and update the path to `load_model()` in `infer_gradio.py`. Currently only load TTS models first, will load ASR model to do transcription if `ref_text` not provided, will load LLM model if use Voice Chat.
30
 
31
+ Could also be used as a component for larger application.
 
 
32
  ```python
33
  import gradio as gr
34
  from f5_tts.infer.infer_gradio import app
 
56
  ```bash
57
  # Leave --ref_text "" will have ASR model transcribe (extra GPU memory usage)
58
  f5-tts_infer-cli \
59
+ --model "F5-TTS" \
60
  --ref_audio "ref_audio.wav" \
61
+ --ref_text "hình ảnh cực đoan trong em_vi của sơn tùng mờ thành phố bị khán giả chỉ trích" \
62
+ --gen_text "tôi yêu em đến nay chừng thể, ngọn lửa tình chưa hẳn đã tàn phai." \
63
+ --vocoder_name vocos \
64
+ --load_vocoder_from_local \
65
+ --ckpt_file ckpts/F5TTS_Base_vocos_char_vnTTS/model_last.pt
66
 
67
+ # Choose Vocoder
68
+ f5-tts_infer-cli --vocoder_name bigvgan --load_vocoder_from_local --ckpt_file <YOUR_CKPT_PATH, eg:ckpts/F5TTS_Base_bigvgan/model_1250000.pt>
69
+ f5-tts_infer-cli --vocoder_name vocos --load_vocoder_from_local --ckpt_file <YOUR_CKPT_PATH, eg:ckpts/F5TTS_Base/model_1200000.safetensors>
 
 
70
 
71
  # More instructions
72
  f5-tts_infer-cli --help
 
81
  For example, you can use `.toml` to pass in variables, refer to `src/f5_tts/infer/examples/basic/basic.toml`:
82
 
83
  ```toml
84
+ # F5-TTS | E2-TTS
85
+ model = "F5-TTS"
86
  ref_audio = "infer/examples/basic/basic_ref_en.wav"
87
  # If an empty "", transcribes the reference audio automatically.
88
  ref_text = "Some call me nature, others call me mother nature."
 
96
  You can also leverage `.toml` file to do multi-style generation, refer to `src/f5_tts/infer/examples/multi/story.toml`.
97
 
98
  ```toml
99
+ # F5-TTS | E2-TTS
100
+ model = "F5-TTS"
101
  ref_audio = "infer/examples/multi/main.flac"
102
  # If an empty "", transcribes the reference audio automatically.
103
  ref_text = ""
 
117
  ```
118
  You should mark the voice with `[main]` `[town]` `[country]` whenever you want to change voice, refer to `src/f5_tts/infer/examples/multi/story.txt`.
119
 
120
+ ## Speech Editing
121
 
122
+ To test speech editing capabilities, use the following command:
123
 
124
  ```bash
125
+ python src/f5_tts/infer/speech_edit.py
 
 
 
 
 
 
 
 
126
  ```
127
 
128
+ ## Socket Realtime Client
 
 
129
 
130
+ To communicate with socket server you need to run
131
  ```bash
132
+ python src/f5_tts/socket_server.py
133
  ```
134
 
135
+ <details>
136
+ <summary>Then create client to communicate</summary>
137
+
138
+ ``` python
139
+ import socket
140
+ import numpy as np
141
+ import asyncio
142
+ import pyaudio
143
+
144
+ async def listen_to_voice(text, server_ip='localhost', server_port=9999):
145
+ client_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
146
+ client_socket.connect((server_ip, server_port))
147
+
148
+ async def play_audio_stream():
149
+ buffer = b''
150
+ p = pyaudio.PyAudio()
151
+ stream = p.open(format=pyaudio.paFloat32,
152
+ channels=1,
153
+ rate=24000, # Ensure this matches the server's sampling rate
154
+ output=True,
155
+ frames_per_buffer=2048)
156
+
157
+ try:
158
+ while True:
159
+ chunk = await asyncio.get_event_loop().run_in_executor(None, client_socket.recv, 1024)
160
+ if not chunk: # End of stream
161
+ break
162
+ if b"END_OF_AUDIO" in chunk:
163
+ buffer += chunk.replace(b"END_OF_AUDIO", b"")
164
+ if buffer:
165
+ audio_array = np.frombuffer(buffer, dtype=np.float32).copy() # Make a writable copy
166
+ stream.write(audio_array.tobytes())
167
+ break
168
+ buffer += chunk
169
+ if len(buffer) >= 4096:
170
+ audio_array = np.frombuffer(buffer[:4096], dtype=np.float32).copy() # Make a writable copy
171
+ stream.write(audio_array.tobytes())
172
+ buffer = buffer[4096:]
173
+ finally:
174
+ stream.stop_stream()
175
+ stream.close()
176
+ p.terminate()
177
+
178
+ try:
179
+ # Send only the text to the server
180
+ await asyncio.get_event_loop().run_in_executor(None, client_socket.sendall, text.encode('utf-8'))
181
+ await play_audio_stream()
182
+ print("Audio playback finished.")
183
+
184
+ except Exception as e:
185
+ print(f"Error in listen_to_voice: {e}")
186
+
187
+ finally:
188
+ client_socket.close()
189
+
190
+ # Example usage: Replace this with your actual server IP and port
191
+ async def main():
192
+ await listen_to_voice("my name is jenny..", server_ip='localhost', server_port=9998)
193
+
194
+ # Run the main async function
195
+ asyncio.run(main())
196
+ ```
197
+
198
+ </details>
199
+
f5_tts/infer/SHARED.md CHANGED
@@ -16,7 +16,7 @@
16
  <!-- omit in toc -->
17
  ### Supported Languages
18
  - [Multilingual](#multilingual)
19
- - [F5-TTS v1 v0 Base @ zh \& en @ F5-TTS](#f5-tts-v1-v0-base--zh--en--f5-tts)
20
  - [English](#english)
21
  - [Finnish](#finnish)
22
  - [F5-TTS Base @ fi @ AsmoKoskinen](#f5-tts-base--fi--asmokoskinen)
@@ -37,17 +37,7 @@
37
 
38
  ## Multilingual
39
 
40
- #### F5-TTS v1 v0 Base @ zh & en @ F5-TTS
41
- |Model|🤗Hugging Face|Data (Hours)|Model License|
42
- |:---:|:------------:|:-----------:|:-------------:|
43
- |F5-TTS v1 Base|[ckpt & vocab](https://huggingface.co/SWivid/F5-TTS/tree/main/F5TTS_v1_Base)|[Emilia 95K zh&en](https://huggingface.co/datasets/amphion/Emilia-Dataset/tree/fc71e07)|cc-by-nc-4.0|
44
-
45
- ```bash
46
- Model: hf://SWivid/F5-TTS/F5TTS_v1_Base/model_1250000.safetensors
47
- Vocab: hf://SWivid/F5-TTS/F5TTS_v1_Base/vocab.txt
48
- Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "conv_layers": 4}
49
- ```
50
-
51
  |Model|🤗Hugging Face|Data (Hours)|Model License|
52
  |:---:|:------------:|:-----------:|:-------------:|
53
  |F5-TTS Base|[ckpt & vocab](https://huggingface.co/SWivid/F5-TTS/tree/main/F5TTS_Base)|[Emilia 95K zh&en](https://huggingface.co/datasets/amphion/Emilia-Dataset/tree/fc71e07)|cc-by-nc-4.0|
@@ -55,7 +45,7 @@ Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "
55
  ```bash
56
  Model: hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors
57
  Vocab: hf://SWivid/F5-TTS/F5TTS_Base/vocab.txt
58
- Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1}
59
  ```
60
 
61
  *Other infos, e.g. Author info, Github repo, Link to some sampled results, Usage instruction, Tutorial (Blog, Video, etc.) ...*
@@ -74,7 +64,7 @@ Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "
74
  ```bash
75
  Model: hf://AsmoKoskinen/F5-TTS_Finnish_Model/model_common_voice_fi_vox_populi_fi_20241206.safetensors
76
  Vocab: hf://AsmoKoskinen/F5-TTS_Finnish_Model/vocab.txt
77
- Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1}
78
  ```
79
 
80
 
@@ -88,7 +78,7 @@ Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "
88
  ```bash
89
  Model: hf://RASPIAUDIO/F5-French-MixedSpeakers-reduced/model_last_reduced.pt
90
  Vocab: hf://RASPIAUDIO/F5-French-MixedSpeakers-reduced/vocab.txt
91
- Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1}
92
  ```
93
 
94
  - [Online Inference with Hugging Face Space](https://huggingface.co/spaces/RASPIAUDIO/f5-tts_french).
@@ -106,7 +96,7 @@ Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "
106
  ```bash
107
  Model: hf://SPRINGLab/F5-Hindi-24KHz/model_2500000.safetensors
108
  Vocab: hf://SPRINGLab/F5-Hindi-24KHz/vocab.txt
109
- Config: {"dim": 768, "depth": 18, "heads": 12, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1}
110
  ```
111
 
112
  - Authors: SPRING Lab, Indian Institute of Technology, Madras
@@ -123,7 +113,7 @@ Config: {"dim": 768, "depth": 18, "heads": 12, "ff_mult": 2, "text_dim": 512, "t
123
  ```bash
124
  Model: hf://alien79/F5-TTS-italian/model_159600.safetensors
125
  Vocab: hf://alien79/F5-TTS-italian/vocab.txt
126
- Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1}
127
  ```
128
 
129
  - Trained by [Mithril Man](https://github.com/MithrilMan)
@@ -141,7 +131,7 @@ Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "
141
  ```bash
142
  Model: hf://Jmica/F5TTS/JA_25498980/model_25498980.pt
143
  Vocab: hf://Jmica/F5TTS/JA_25498980/vocab_updated.txt
144
- Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1}
145
  ```
146
 
147
 
@@ -158,7 +148,7 @@ Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "
158
  ```bash
159
  Model: hf://hotstone228/F5-TTS-Russian/model_last.safetensors
160
  Vocab: hf://hotstone228/F5-TTS-Russian/vocab.txt
161
- Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1}
162
  ```
163
  - Finetuned by [HotDro4illa](https://github.com/HotDro4illa)
164
  - Any improvements are welcome
 
16
  <!-- omit in toc -->
17
  ### Supported Languages
18
  - [Multilingual](#multilingual)
19
+ - [F5-TTS Base @ zh \& en @ F5-TTS](#f5-tts-base--zh--en--f5-tts)
20
  - [English](#english)
21
  - [Finnish](#finnish)
22
  - [F5-TTS Base @ fi @ AsmoKoskinen](#f5-tts-base--fi--asmokoskinen)
 
37
 
38
  ## Multilingual
39
 
40
+ #### F5-TTS Base @ zh & en @ F5-TTS
 
 
41
  |Model|🤗Hugging Face|Data (Hours)|Model License|
42
  |:---:|:------------:|:-----------:|:-------------:|
43
  |F5-TTS Base|[ckpt & vocab](https://huggingface.co/SWivid/F5-TTS/tree/main/F5TTS_Base)|[Emilia 95K zh&en](https://huggingface.co/datasets/amphion/Emilia-Dataset/tree/fc71e07)|cc-by-nc-4.0|
 
45
  ```bash
46
  Model: hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors
47
  Vocab: hf://SWivid/F5-TTS/F5TTS_Base/vocab.txt
48
+ Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "conv_layers": 4}
49
  ```
50
 
51
  *Other infos, e.g. Author info, Github repo, Link to some sampled results, Usage instruction, Tutorial (Blog, Video, etc.) ...*
 
64
  ```bash
65
  Model: hf://AsmoKoskinen/F5-TTS_Finnish_Model/model_common_voice_fi_vox_populi_fi_20241206.safetensors
66
  Vocab: hf://AsmoKoskinen/F5-TTS_Finnish_Model/vocab.txt
67
+ Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "conv_layers": 4}
68
  ```
69
 
70
 
 
78
  ```bash
79
  Model: hf://RASPIAUDIO/F5-French-MixedSpeakers-reduced/model_last_reduced.pt
80
  Vocab: hf://RASPIAUDIO/F5-French-MixedSpeakers-reduced/vocab.txt
81
+ Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "conv_layers": 4}
82
  ```
83
 
84
  - [Online Inference with Hugging Face Space](https://huggingface.co/spaces/RASPIAUDIO/f5-tts_french).
 
96
  ```bash
97
  Model: hf://SPRINGLab/F5-Hindi-24KHz/model_2500000.safetensors
98
  Vocab: hf://SPRINGLab/F5-Hindi-24KHz/vocab.txt
99
+ Config: {"dim": 768, "depth": 18, "heads": 12, "ff_mult": 2, "text_dim": 512, "conv_layers": 4}
100
  ```
101
 
102
  - Authors: SPRING Lab, Indian Institute of Technology, Madras
 
113
  ```bash
114
  Model: hf://alien79/F5-TTS-italian/model_159600.safetensors
115
  Vocab: hf://alien79/F5-TTS-italian/vocab.txt
116
+ Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "conv_layers": 4}
117
  ```
118
 
119
  - Trained by [Mithril Man](https://github.com/MithrilMan)
 
131
  ```bash
132
  Model: hf://Jmica/F5TTS/JA_25498980/model_25498980.pt
133
  Vocab: hf://Jmica/F5TTS/JA_25498980/vocab_updated.txt
134
+ Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "conv_layers": 4}
135
  ```
136
 
137
 
 
148
  ```bash
149
  Model: hf://hotstone228/F5-TTS-Russian/model_last.safetensors
150
  Vocab: hf://hotstone228/F5-TTS-Russian/vocab.txt
151
+ Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "conv_layers": 4}
152
  ```
153
  - Finetuned by [HotDro4illa](https://github.com/HotDro4illa)
154
  - Any improvements are welcome
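All of the SHARED.md entries above follow the same pattern: an hf:// checkpoint path, an hf:// vocab path, and a JSON config whose keys map onto the DiT backbone arguments. A minimal sketch of how such an entry is typically consumed (the reference clip and transcripts are placeholders, and the positional argument order of infer_process is assumed from how it is used elsewhere in this diff):

```python
import json

from cached_path import cached_path

from f5_tts.infer.utils_infer import infer_process, load_model, load_vocoder
from f5_tts.model import DiT

# values copied from a SHARED.md entry
ckpt_path = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors"))
vocab_path = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/vocab.txt"))
model_cfg = json.loads('{"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "conv_layers": 4}')

ema_model = load_model(DiT, model_cfg, ckpt_path, vocab_file=vocab_path)
vocoder = load_vocoder()

# "ref.wav" and both transcripts below are placeholders
wave, sample_rate, _ = infer_process(
    "ref.wav",
    "Some call me nature, others call me mother nature.",
    "Text to generate.",
    ema_model,
    vocoder,
)
```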
f5_tts/infer/__pycache__/infer_cli.cpython-310.pyc CHANGED
Binary files a/f5_tts/infer/__pycache__/infer_cli.cpython-310.pyc and b/f5_tts/infer/__pycache__/infer_cli.cpython-310.pyc differ
 
f5_tts/infer/__pycache__/utils_infer.cpython-310.pyc CHANGED
Binary files a/f5_tts/infer/__pycache__/utils_infer.cpython-310.pyc and b/f5_tts/infer/__pycache__/utils_infer.cpython-310.pyc differ
 
f5_tts/infer/examples/basic/basic.toml CHANGED
@@ -1,5 +1,5 @@
1
- # F5TTS_v1_Base | E2TTS_Base
2
- model = "F5TTS_v1_Base"
3
  ref_audio = "infer/examples/basic/basic_ref_en.wav"
4
  # If an empty "", transcribes the reference audio automatically.
5
  ref_text = "Some call me nature, others call me mother nature."
@@ -8,4 +8,4 @@ gen_text = "I don't really care what you call me. I've been a silent spectator,
8
  gen_file = ""
9
  remove_silence = false
10
  output_dir = "tests"
11
- output_file = "infer_cli_basic.wav"
 
1
+ # F5-TTS | E2-TTS
2
+ model = "F5-TTS"
3
  ref_audio = "infer/examples/basic/basic_ref_en.wav"
4
  # If an empty "", transcribes the reference audio automatically.
5
  ref_text = "Some call me nature, others call me mother nature."
 
8
  gen_file = ""
9
  remove_silence = false
10
  output_dir = "tests"
11
+ output_file = "infer_cli_basic.wav"
f5_tts/infer/examples/multi/story.toml CHANGED
@@ -1,5 +1,5 @@
1
- # F5TTS_v1_Base | E2TTS_Base
2
- model = "F5TTS_v1_Base"
3
  ref_audio = "infer/examples/multi/main.flac"
4
  # If an empty "", transcribes the reference audio automatically.
5
  ref_text = ""
 
1
+ # F5-TTS | E2-TTS
2
+ model = "F5-TTS"
3
  ref_audio = "infer/examples/multi/main.flac"
4
  # If an empty "", transcribes the reference audio automatically.
5
  ref_text = ""
f5_tts/infer/infer_cli.py CHANGED
@@ -27,7 +27,7 @@ from f5_tts.infer.utils_infer import (
27
  preprocess_ref_audio_text,
28
  remove_silence_for_generated_wav,
29
  )
30
- from f5_tts.model import DiT, UNetT # noqa: F401. used for config
31
 
32
 
33
  parser = argparse.ArgumentParser(
@@ -50,8 +50,7 @@ parser.add_argument(
50
  "-m",
51
  "--model",
52
  type=str,
53
- default="F5TTS_Base",
54
- help="The model name: F5TTS_v1_Base | F5TTS_Base | E2TTS_Base | etc.",
55
  )
56
  parser.add_argument(
57
  "-mc",
@@ -173,7 +172,8 @@ config = tomli.load(open(args.config, "rb"))
173
 
174
  # command-line interface parameters
175
 
176
- model = args.model or config.get("model", "F5TTS_v1_Base")
 
177
  ckpt_file = args.ckpt_file or config.get("ckpt_file", "")
178
  vocab_file = args.vocab_file or config.get("vocab_file", "")
179
 
@@ -236,7 +236,7 @@ if save_chunk:
236
  # load vocoder
237
 
238
  if vocoder_name == "vocos":
239
- vocoder_local_path = "../checkpoints/vocos-mel-24khz"
240
  elif vocoder_name == "bigvgan":
241
  vocoder_local_path = "../checkpoints/bigvgan_v2_24khz_100band_256x"
242
 
@@ -245,32 +245,37 @@ vocoder = load_vocoder(vocoder_name=vocoder_name, is_local=load_vocoder_from_loc
245
 
246
  # load TTS model
247
 
248
- model_cfg = OmegaConf.load(
249
- args.model_cfg or config.get("model_cfg", str(files("f5_tts").joinpath(f"configs/{model}.yaml")))
250
- ).model
251
- model_cls = globals()[model_cfg.backbone]
252
-
253
- repo_name, ckpt_step, ckpt_type = "F5-TTS", 1250000, "safetensors"
254
-
255
- if model != "F5TTS_Base":
256
- assert vocoder_name == model_cfg.mel_spec.mel_spec_type
257
-
258
- # override for previous models
259
- if model == "F5TTS_Base":
260
- if vocoder_name == "vocos":
 
 
 
 
 
 
 
 
 
 
 
 
261
  ckpt_step = 1200000
262
- elif vocoder_name == "bigvgan":
263
- model = "F5TTS_Base_bigvgan"
264
- ckpt_type = "pt"
265
- elif model == "E2TTS_Base":
266
- repo_name = "E2-TTS"
267
- ckpt_step = 1200000
268
-
269
- if not ckpt_file:
270
- ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{model}/model_{ckpt_step}.{ckpt_type}"))
271
 
272
  print(f"Using {model}...")
273
- ema_model = load_model(model_cls, model_cfg.arch, ckpt_file, mel_spec_type=vocoder_name, vocab_file=vocab_file)
274
 
275
 
276
  # inference process
 
27
  preprocess_ref_audio_text,
28
  remove_silence_for_generated_wav,
29
  )
30
+ from f5_tts.model import DiT, UNetT
31
 
32
 
33
  parser = argparse.ArgumentParser(
 
50
  "-m",
51
  "--model",
52
  type=str,
53
+ help="The model name: F5-TTS | E2-TTS",
 
54
  )
55
  parser.add_argument(
56
  "-mc",
 
172
 
173
  # command-line interface parameters
174
 
175
+ model = args.model or config.get("model", "F5-TTS")
176
+ model_cfg = args.model_cfg or config.get("model_cfg", str(files("f5_tts").joinpath("configs/F5TTS_Base_train.yaml")))
177
  ckpt_file = args.ckpt_file or config.get("ckpt_file", "")
178
  vocab_file = args.vocab_file or config.get("vocab_file", "")
179
 
 
236
  # load vocoder
237
 
238
  if vocoder_name == "vocos":
239
+ vocoder_local_path = "ckpts/vocos"
240
  elif vocoder_name == "bigvgan":
241
  vocoder_local_path = "../checkpoints/bigvgan_v2_24khz_100band_256x"
242
 
 
245
 
246
  # load TTS model
247
 
248
+ if model == "F5-TTS":
249
+ model_cls = DiT
250
+ model_cfg = OmegaConf.load(model_cfg).model.arch
251
+ if not ckpt_file: # path not specified, download from repo
252
+ if vocoder_name == "vocos":
253
+ repo_name = "F5-TTS"
254
+ exp_name = "F5TTS_Base"
255
+ ckpt_step = 1200000
256
+ ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
257
+ # ckpt_file = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors; local path
258
+ # ckpt_file = f"ckpts/{exp_name}/model_last.pt" # .pt | .safetensors; local path
259
+ elif vocoder_name == "bigvgan":
260
+ repo_name = "F5-TTS"
261
+ exp_name = "F5TTS_Base_bigvgan"
262
+ ckpt_step = 1250000
263
+ ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.pt"))
264
+
265
+ elif model == "E2-TTS":
266
+ assert args.model_cfg is None, "E2-TTS does not support custom model_cfg yet"
267
+ assert vocoder_name == "vocos", "E2-TTS only supports vocoder vocos yet"
268
+ model_cls = UNetT
269
+ model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
270
+ if not ckpt_file: # path not specified, download from repo
271
+ repo_name = "E2-TTS"
272
+ exp_name = "E2TTS_Base"
273
  ckpt_step = 1200000
274
+ ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
275
+ # ckpt_file = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors; local path
 
 
 
 
 
 
 
276
 
277
  print(f"Using {model}...")
278
+ ema_model = load_model(model_cls, model_cfg, ckpt_file, mel_spec_type=vocoder_name, vocab_file=vocab_file)
279
 
280
 
281
  # inference process
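The restored branching above boils down to a small lookup from (model, vocoder) to a Hugging Face checkpoint. A condensed sketch of the same logic, not a drop-in replacement for the script:

```python
from cached_path import cached_path


def resolve_ckpt(model: str, vocoder_name: str, ckpt_file: str = "") -> str:
    if ckpt_file:  # an explicit local path always wins
        return ckpt_file
    if model == "F5-TTS" and vocoder_name == "vocos":
        repo, exp, step, ext = "F5-TTS", "F5TTS_Base", 1200000, "safetensors"
    elif model == "F5-TTS" and vocoder_name == "bigvgan":
        repo, exp, step, ext = "F5-TTS", "F5TTS_Base_bigvgan", 1250000, "pt"
    elif model == "E2-TTS":
        repo, exp, step, ext = "E2-TTS", "E2TTS_Base", 1200000, "safetensors"
    else:
        raise ValueError(f"unsupported combination: {model} + {vocoder_name}")
    return str(cached_path(f"hf://SWivid/{repo}/{exp}/model_{step}.{ext}"))


print(resolve_ckpt("F5-TTS", "vocos"))
```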
f5_tts/infer/infer_gradio.py CHANGED
@@ -41,12 +41,12 @@ from f5_tts.infer.utils_infer import (
41
  )
42
 
43
 
44
- DEFAULT_TTS_MODEL = "F5-TTS_v1"
45
  tts_model_choice = DEFAULT_TTS_MODEL
46
 
47
  DEFAULT_TTS_MODEL_CFG = [
48
- "hf://SWivid/F5-TTS/F5TTS_v1_Base/model_1250000.safetensors",
49
- "hf://SWivid/F5-TTS/F5TTS_v1_Base/vocab.txt",
50
  json.dumps(dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)),
51
  ]
52
 
@@ -56,15 +56,13 @@ DEFAULT_TTS_MODEL_CFG = [
56
  vocoder = load_vocoder()
57
 
58
 
59
- def load_f5tts():
60
- ckpt_path = str(cached_path(DEFAULT_TTS_MODEL_CFG[0]))
61
- F5TTS_model_cfg = json.loads(DEFAULT_TTS_MODEL_CFG[2])
62
  return load_model(DiT, F5TTS_model_cfg, ckpt_path)
63
 
64
 
65
- def load_e2tts():
66
- ckpt_path = str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors"))
67
- E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4, text_mask_padding=False, pe_attn_head=1)
68
  return load_model(UNetT, E2TTS_model_cfg, ckpt_path)
69
 
70
 
@@ -75,7 +73,7 @@ def load_custom(ckpt_path: str, vocab_path="", model_cfg=None):
75
  if vocab_path.startswith("hf://"):
76
  vocab_path = str(cached_path(vocab_path))
77
  if model_cfg is None:
78
- model_cfg = json.loads(DEFAULT_TTS_MODEL_CFG[2])
79
  return load_model(DiT, model_cfg, ckpt_path, vocab_file=vocab_path)
80
 
81
 
@@ -132,7 +130,7 @@ def infer(
132
 
133
  ref_audio, ref_text = preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=show_info)
134
 
135
- if model == DEFAULT_TTS_MODEL:
136
  ema_model = F5TTS_ema_model
137
  elif model == "E2-TTS":
138
  global E2TTS_ema_model
@@ -764,7 +762,7 @@ If you're having issues, try converting your reference audio to WAV or MP3, clip
764
  """
765
  )
766
 
767
- last_used_custom = files("f5_tts").joinpath("infer/.cache/last_used_custom_model_info_v1.txt")
768
 
769
  def load_last_used_custom():
770
  try:
@@ -823,30 +821,7 @@ If you're having issues, try converting your reference audio to WAV or MP3, clip
823
  custom_model_cfg = gr.Dropdown(
824
  choices=[
825
  DEFAULT_TTS_MODEL_CFG[2],
826
- json.dumps(
827
- dict(
828
- dim=1024,
829
- depth=22,
830
- heads=16,
831
- ff_mult=2,
832
- text_dim=512,
833
- text_mask_padding=False,
834
- conv_layers=4,
835
- pe_attn_head=1,
836
- )
837
- ),
838
- json.dumps(
839
- dict(
840
- dim=768,
841
- depth=18,
842
- heads=12,
843
- ff_mult=2,
844
- text_dim=512,
845
- text_mask_padding=False,
846
- conv_layers=4,
847
- pe_attn_head=1,
848
- )
849
- ),
850
  ],
851
  value=load_last_used_custom()[2],
852
  allow_custom_value=True,
@@ -900,24 +875,10 @@ If you're having issues, try converting your reference audio to WAV or MP3, clip
900
  type=str,
901
  help='The root path (or "mount point") of the application, if it\'s not served from the root ("/") of the domain. Often used when the application is behind a reverse proxy that forwards requests to the application, e.g. set "/myapp" or full URL for application served at "https://example.com/myapp".',
902
  )
903
- @click.option(
904
- "--inbrowser",
905
- "-i",
906
- is_flag=True,
907
- default=False,
908
- help="Automatically launch the interface in the default web browser",
909
- )
910
- def main(port, host, share, api, root_path, inbrowser):
911
  global app
912
  print("Starting app...")
913
- app.queue(api_open=api).launch(
914
- server_name=host,
915
- server_port=port,
916
- share=share,
917
- show_api=api,
918
- root_path=root_path,
919
- inbrowser=inbrowser,
920
- )
921
 
922
 
923
  if __name__ == "__main__":
 
41
  )
42
 
43
 
44
+ DEFAULT_TTS_MODEL = "F5-TTS"
45
  tts_model_choice = DEFAULT_TTS_MODEL
46
 
47
  DEFAULT_TTS_MODEL_CFG = [
48
+ "hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors",
49
+ "hf://SWivid/F5-TTS/F5TTS_Base/vocab.txt",
50
  json.dumps(dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)),
51
  ]
52
 
 
56
  vocoder = load_vocoder()
57
 
58
 
59
+ def load_f5tts(ckpt_path=str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors"))):
60
+ F5TTS_model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
 
61
  return load_model(DiT, F5TTS_model_cfg, ckpt_path)
62
 
63
 
64
+ def load_e2tts(ckpt_path=str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors"))):
65
+ E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
 
66
  return load_model(UNetT, E2TTS_model_cfg, ckpt_path)
67
 
68
 
 
73
  if vocab_path.startswith("hf://"):
74
  vocab_path = str(cached_path(vocab_path))
75
  if model_cfg is None:
76
+ model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
77
  return load_model(DiT, model_cfg, ckpt_path, vocab_file=vocab_path)
78
 
79
 
 
130
 
131
  ref_audio, ref_text = preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=show_info)
132
 
133
+ if model == "F5-TTS":
134
  ema_model = F5TTS_ema_model
135
  elif model == "E2-TTS":
136
  global E2TTS_ema_model
 
762
  """
763
  )
764
 
765
+ last_used_custom = files("f5_tts").joinpath("infer/.cache/last_used_custom_model_info.txt")
766
 
767
  def load_last_used_custom():
768
  try:
 
821
  custom_model_cfg = gr.Dropdown(
822
  choices=[
823
  DEFAULT_TTS_MODEL_CFG[2],
824
+ json.dumps(dict(dim=768, depth=18, heads=12, ff_mult=2, text_dim=512, conv_layers=4)),
 
 
 
825
  ],
826
  value=load_last_used_custom()[2],
827
  allow_custom_value=True,
 
875
  type=str,
876
  help='The root path (or "mount point") of the application, if it\'s not served from the root ("/") of the domain. Often used when the application is behind a reverse proxy that forwards requests to the application, e.g. set "/myapp" or full URL for application served at "https://example.com/myapp".',
877
  )
878
+ def main(port, host, share, api, root_path):
 
 
879
  global app
880
  print("Starting app...")
881
+ app.queue(api_open=api).launch(server_name=host, server_port=port, share=share, show_api=api, root_path=root_path)
 
 
882
 
883
 
884
  if __name__ == "__main__":
f5_tts/infer/speech_edit.py CHANGED
@@ -1,16 +1,13 @@
1
  import os
2
 
3
- os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" # for MPS device compatibility
4
-
5
- from importlib.resources import files
6
 
7
  import torch
8
  import torch.nn.functional as F
9
  import torchaudio
10
- from omegaconf import OmegaConf
11
 
12
  from f5_tts.infer.utils_infer import load_checkpoint, load_vocoder, save_spectrogram
13
- from f5_tts.model import CFM, DiT, UNetT # noqa: F401. used for config
14
  from f5_tts.model.utils import convert_char_to_pinyin, get_tokenizer
15
 
16
  device = (
@@ -24,40 +21,44 @@ device = (
24
  )
25
 
26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  # ---------------------- infer setting ---------------------- #
28
 
29
  seed = None # int | None
30
 
31
- exp_name = "F5TTS_v1_Base" # F5TTS_v1_Base | E2TTS_Base
32
- ckpt_step = 1250000
33
 
34
  nfe_step = 32 # 16, 32
35
  cfg_strength = 2.0
36
  ode_method = "euler" # euler | midpoint
37
  sway_sampling_coef = -1.0
38
  speed = 1.0
39
- target_rms = 0.1
40
-
41
-
42
- model_cfg = OmegaConf.load(str(files("f5_tts").joinpath(f"configs/{exp_name}.yaml")))
43
- model_cls = globals()[model_cfg.model.backbone]
44
- model_arc = model_cfg.model.arch
45
 
46
- dataset_name = model_cfg.datasets.name
47
- tokenizer = model_cfg.model.tokenizer
 
48
 
49
- mel_spec_type = model_cfg.model.mel_spec.mel_spec_type
50
- target_sample_rate = model_cfg.model.mel_spec.target_sample_rate
51
- n_mel_channels = model_cfg.model.mel_spec.n_mel_channels
52
- hop_length = model_cfg.model.mel_spec.hop_length
53
- win_length = model_cfg.model.mel_spec.win_length
54
- n_fft = model_cfg.model.mel_spec.n_fft
55
 
56
-
57
- ckpt_path = str(files("f5_tts").joinpath("../../")) + f"ckpts/{exp_name}/model_{ckpt_step}.safetensors"
58
  output_dir = "tests"
59
 
60
-
61
  # [leverage https://github.com/MahmoudAshraf97/ctc-forced-aligner to get char level alignment]
62
  # pip install git+https://github.com/MahmoudAshraf97/ctc-forced-aligner.git
63
  # [write the origin_text into a file, e.g. tests/test_edit.txt]
@@ -66,7 +67,7 @@ output_dir = "tests"
66
  # [--language "zho" for Chinese, "eng" for English]
67
  # [if local ckpt, set --alignment_model "../checkpoints/mms-300m-1130-forced-aligner"]
68
 
69
- audio_to_edit = str(files("f5_tts").joinpath("infer/examples/basic/basic_ref_en.wav"))
70
  origin_text = "Some call me nature, others call me mother nature."
71
  target_text = "Some call me optimist, others call me realist."
72
  parts_to_edit = [
@@ -105,7 +106,7 @@ vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer)
105
 
106
  # Model
107
  model = CFM(
108
- transformer=model_cls(**model_arc, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
109
  mel_spec_kwargs=dict(
110
  n_fft=n_fft,
111
  hop_length=hop_length,
 
1
  import os
2
 
3
+ os.environ["PYTOCH_ENABLE_MPS_FALLBACK"] = "1" # for MPS device compatibility
 
 
4
 
5
  import torch
6
  import torch.nn.functional as F
7
  import torchaudio
 
8
 
9
  from f5_tts.infer.utils_infer import load_checkpoint, load_vocoder, save_spectrogram
10
+ from f5_tts.model import CFM, DiT, UNetT
11
  from f5_tts.model.utils import convert_char_to_pinyin, get_tokenizer
12
 
13
  device = (
 
21
  )
22
 
23
 
24
+ # --------------------- Dataset Settings -------------------- #
25
+
26
+ target_sample_rate = 24000
27
+ n_mel_channels = 100
28
+ hop_length = 256
29
+ win_length = 1024
30
+ n_fft = 1024
31
+ mel_spec_type = "vocos" # 'vocos' or 'bigvgan'
32
+ target_rms = 0.1
33
+
34
+ tokenizer = "pinyin"
35
+ dataset_name = "Emilia_ZH_EN"
36
+
37
+
38
  # ---------------------- infer setting ---------------------- #
39
 
40
  seed = None # int | None
41
 
42
+ exp_name = "F5TTS_Base" # F5TTS_Base | E2TTS_Base
43
+ ckpt_step = 1200000
44
 
45
  nfe_step = 32 # 16, 32
46
  cfg_strength = 2.0
47
  ode_method = "euler" # euler | midpoint
48
  sway_sampling_coef = -1.0
49
  speed = 1.0
 
 
 
 
 
 
50
 
51
+ if exp_name == "F5TTS_Base":
52
+ model_cls = DiT
53
+ model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
54
 
55
+ elif exp_name == "E2TTS_Base":
56
+ model_cls = UNetT
57
+ model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
 
 
 
58
 
59
+ ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.safetensors"
 
60
  output_dir = "tests"
61
 
 
62
  # [leverage https://github.com/MahmoudAshraf97/ctc-forced-aligner to get char level alignment]
63
  # pip install git+https://github.com/MahmoudAshraf97/ctc-forced-aligner.git
64
  # [write the origin_text into a file, e.g. tests/test_edit.txt]
 
67
  # [--language "zho" for Chinese, "eng" for English]
68
  # [if local ckpt, set --alignment_model "../checkpoints/mms-300m-1130-forced-aligner"]
69
 
70
+ audio_to_edit = "src/f5_tts/infer/examples/basic/basic_ref_en.wav"
71
  origin_text = "Some call me nature, others call me mother nature."
72
  target_text = "Some call me optimist, others call me realist."
73
  parts_to_edit = [
 
106
 
107
  # Model
108
  model = CFM(
109
+ transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
110
  mel_spec_kwargs=dict(
111
  n_fft=n_fft,
112
  hop_length=hop_length,
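As a quick sanity check on the dataset settings restored above (24 kHz audio, hop length 256): one second of audio corresponds to 93.75 mel frames, which is also where the "precompute_max_pos = 4096  # ~44s of 24khz audio" comments in the backbone files further down come from. Worked numbers, not code from the repo:

```python
target_sample_rate = 24000
hop_length = 256

frames_per_second = target_sample_rate / hop_length
print(frames_per_second)                       # 93.75 mel frames per second
print(4096 * hop_length / target_sample_rate)  # ~43.7 s covered by 4096 positions
```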
f5_tts/infer/utils_infer.py CHANGED
@@ -2,9 +2,8 @@
2
  # Make adjustments inside functions, and consider both gradio and cli scripts if need to change func output format
3
  import os
4
  import sys
5
- from concurrent.futures import ThreadPoolExecutor
6
 
7
- os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" # for MPS device compatibility
8
  sys.path.append(f"{os.path.dirname(os.path.abspath(__file__))}/../../third_party/BigVGAN/")
9
 
10
  import hashlib
@@ -110,8 +109,13 @@ def load_vocoder(vocoder_name="vocos", is_local=False, local_path="", device=dev
110
  repo_id = "charactr/vocos-mel-24khz"
111
  config_path = hf_hub_download(repo_id=repo_id, cache_dir=hf_cache_dir, filename="config.yaml")
112
  model_path = hf_hub_download(repo_id=repo_id, cache_dir=hf_cache_dir, filename="pytorch_model.bin")
 
 
 
 
113
  vocoder = Vocos.from_hparams(config_path)
114
  state_dict = torch.load(model_path, map_location="cpu", weights_only=True)
 
115
  from vocos.feature_extractors import EncodecFeatures
116
 
117
  if isinstance(vocoder.feature_extractor, EncodecFeatures):
@@ -301,19 +305,19 @@ def preprocess_ref_audio_text(ref_audio_orig, ref_text, clip_short=True, show_in
301
  )
302
  non_silent_wave = AudioSegment.silent(duration=0)
303
  for non_silent_seg in non_silent_segs:
304
- if len(non_silent_wave) > 6000 and len(non_silent_wave + non_silent_seg) > 12000:
305
  show_info("Audio is over 15s, clipping short. (1)")
306
  break
307
  non_silent_wave += non_silent_seg
308
 
309
  # 2. try to find short silence for clipping if 1. failed
310
- if len(non_silent_wave) > 12000:
311
  non_silent_segs = silence.split_on_silence(
312
  aseg, min_silence_len=100, silence_thresh=-40, keep_silence=1000, seek_step=10
313
  )
314
  non_silent_wave = AudioSegment.silent(duration=0)
315
  for non_silent_seg in non_silent_segs:
316
- if len(non_silent_wave) > 6000 and len(non_silent_wave + non_silent_seg) > 12000:
317
  show_info("Audio is over 15s, clipping short. (2)")
318
  break
319
  non_silent_wave += non_silent_seg
@@ -321,8 +325,8 @@ def preprocess_ref_audio_text(ref_audio_orig, ref_text, clip_short=True, show_in
321
  aseg = non_silent_wave
322
 
323
  # 3. if no proper silence found for clipping
324
- if len(aseg) > 12000:
325
- aseg = aseg[:12000]
326
  show_info("Audio is over 15s, clipping short. (3)")
327
 
328
  aseg = remove_silence_edges(aseg) + AudioSegment.silent(duration=50)
@@ -383,31 +387,29 @@ def infer_process(
383
  ):
384
  # Split the input text into batches
385
  audio, sr = torchaudio.load(ref_audio)
386
- max_chars = int(len(ref_text.encode("utf-8")) / (audio.shape[-1] / sr) * (22 - audio.shape[-1] / sr))
387
  gen_text_batches = chunk_text(gen_text, max_chars=max_chars)
388
  for i, gen_text in enumerate(gen_text_batches):
389
  print(f"gen_text {i}", gen_text)
390
  print("\n")
391
 
392
  show_info(f"Generating audio in {len(gen_text_batches)} batches...")
393
- return next(
394
- infer_batch_process(
395
- (audio, sr),
396
- ref_text,
397
- gen_text_batches,
398
- model_obj,
399
- vocoder,
400
- mel_spec_type=mel_spec_type,
401
- progress=progress,
402
- target_rms=target_rms,
403
- cross_fade_duration=cross_fade_duration,
404
- nfe_step=nfe_step,
405
- cfg_strength=cfg_strength,
406
- sway_sampling_coef=sway_sampling_coef,
407
- speed=speed,
408
- fix_duration=fix_duration,
409
- device=device,
410
- )
411
  )
412
 
413
 
@@ -430,8 +432,6 @@ def infer_batch_process(
430
  speed=1,
431
  fix_duration=None,
432
  device=None,
433
- streaming=False,
434
- chunk_size=2048,
435
  ):
436
  audio, sr = ref_audio
437
  if audio.shape[0] > 1:
@@ -450,12 +450,7 @@ def infer_batch_process(
450
 
451
  if len(ref_text[-1].encode("utf-8")) == 1:
452
  ref_text = ref_text + " "
453
-
454
- def process_batch(gen_text):
455
- local_speed = speed
456
- if len(gen_text.encode("utf-8")) < 10:
457
- local_speed = 0.3
458
-
459
  # Prepare the text
460
  text_list = [ref_text + gen_text]
461
  final_text_list = convert_char_to_pinyin(text_list)
@@ -467,7 +462,7 @@ def infer_batch_process(
467
  # Calculate duration
468
  ref_text_len = len(ref_text.encode("utf-8"))
469
  gen_text_len = len(gen_text.encode("utf-8"))
470
- duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / local_speed)
471
 
472
  # inference
473
  with torch.inference_mode():
@@ -479,88 +474,64 @@ def infer_batch_process(
479
  cfg_strength=cfg_strength,
480
  sway_sampling_coef=sway_sampling_coef,
481
  )
482
- del _
483
 
484
- generated = generated.to(torch.float32) # generated mel spectrogram
485
  generated = generated[:, ref_audio_len:, :]
486
- generated = generated.permute(0, 2, 1)
487
  if mel_spec_type == "vocos":
488
- generated_wave = vocoder.decode(generated)
489
  elif mel_spec_type == "bigvgan":
490
- generated_wave = vocoder(generated)
491
  if rms < target_rms:
492
  generated_wave = generated_wave * rms / target_rms
493
 
494
  # wav -> numpy
495
  generated_wave = generated_wave.squeeze().cpu().numpy()
496
 
497
- if streaming:
498
- for j in range(0, len(generated_wave), chunk_size):
499
- yield generated_wave[j : j + chunk_size], target_sample_rate
500
- else:
501
- generated_cpu = generated[0].cpu().numpy()
502
- del generated
503
- yield generated_wave, generated_cpu
504
-
505
- if streaming:
506
- for gen_text in progress.tqdm(gen_text_batches) if progress is not None else gen_text_batches:
507
- for chunk in process_batch(gen_text):
508
- yield chunk
509
  else:
510
- with ThreadPoolExecutor() as executor:
511
- futures = [executor.submit(process_batch, gen_text) for gen_text in gen_text_batches]
512
- for future in progress.tqdm(futures) if progress is not None else futures:
513
- result = future.result()
514
- if result:
515
- generated_wave, generated_mel_spec = next(result)
516
- generated_waves.append(generated_wave)
517
- spectrograms.append(generated_mel_spec)
518
-
519
- if generated_waves:
520
- if cross_fade_duration <= 0:
521
- # Simply concatenate
522
- final_wave = np.concatenate(generated_waves)
523
- else:
524
- # Combine all generated waves with cross-fading
525
- final_wave = generated_waves[0]
526
- for i in range(1, len(generated_waves)):
527
- prev_wave = final_wave
528
- next_wave = generated_waves[i]
529
-
530
- # Calculate cross-fade samples, ensuring it does not exceed wave lengths
531
- cross_fade_samples = int(cross_fade_duration * target_sample_rate)
532
- cross_fade_samples = min(cross_fade_samples, len(prev_wave), len(next_wave))
533
-
534
- if cross_fade_samples <= 0:
535
- # No overlap possible, concatenate
536
- final_wave = np.concatenate([prev_wave, next_wave])
537
- continue
538
-
539
- # Overlapping parts
540
- prev_overlap = prev_wave[-cross_fade_samples:]
541
- next_overlap = next_wave[:cross_fade_samples]
542
-
543
- # Fade out and fade in
544
- fade_out = np.linspace(1, 0, cross_fade_samples)
545
- fade_in = np.linspace(0, 1, cross_fade_samples)
546
-
547
- # Cross-faded overlap
548
- cross_faded_overlap = prev_overlap * fade_out + next_overlap * fade_in
549
-
550
- # Combine
551
- new_wave = np.concatenate(
552
- [prev_wave[:-cross_fade_samples], cross_faded_overlap, next_wave[cross_fade_samples:]]
553
- )
554
-
555
- final_wave = new_wave
556
-
557
- # Create a combined spectrogram
558
- combined_spectrogram = np.concatenate(spectrograms, axis=1)
559
-
560
- yield final_wave, target_sample_rate, combined_spectrogram
561
 
562
- else:
563
- yield None, target_sample_rate, None
 
 
 
 
564
 
565
 
566
  # remove silence from generated wav
 
2
  # Make adjustments inside functions, and consider both gradio and cli scripts if need to change func output format
3
  import os
4
  import sys
 
5
 
6
+ os.environ["PYTOCH_ENABLE_MPS_FALLBACK"] = "1" # for MPS device compatibility
7
  sys.path.append(f"{os.path.dirname(os.path.abspath(__file__))}/../../third_party/BigVGAN/")
8
 
9
  import hashlib
 
109
  repo_id = "charactr/vocos-mel-24khz"
110
  config_path = hf_hub_download(repo_id=repo_id, cache_dir=hf_cache_dir, filename="config.yaml")
111
  model_path = hf_hub_download(repo_id=repo_id, cache_dir=hf_cache_dir, filename="pytorch_model.bin")
112
+ # print("Download Vocos from huggingface charactr/vocos-mel-24khz")
113
+ # repo_id = "charactr/vocos-mel-24khz"
114
+ # config_path = hf_hub_download(repo_id=repo_id, cache_dir=hf_cache_dir, filename="config.yaml")
115
+ # model_path = hf_hub_download(repo_id=repo_id, cache_dir=hf_cache_dir, filename="pytorch_model.bin")
116
  vocoder = Vocos.from_hparams(config_path)
117
  state_dict = torch.load(model_path, map_location="cpu", weights_only=True)
118
+ # print(state_dict)
119
  from vocos.feature_extractors import EncodecFeatures
120
 
121
  if isinstance(vocoder.feature_extractor, EncodecFeatures):
 
305
  )
306
  non_silent_wave = AudioSegment.silent(duration=0)
307
  for non_silent_seg in non_silent_segs:
308
+ if len(non_silent_wave) > 6000 and len(non_silent_wave + non_silent_seg) > 15000:
309
  show_info("Audio is over 15s, clipping short. (1)")
310
  break
311
  non_silent_wave += non_silent_seg
312
 
313
  # 2. try to find short silence for clipping if 1. failed
314
+ if len(non_silent_wave) > 15000:
315
  non_silent_segs = silence.split_on_silence(
316
  aseg, min_silence_len=100, silence_thresh=-40, keep_silence=1000, seek_step=10
317
  )
318
  non_silent_wave = AudioSegment.silent(duration=0)
319
  for non_silent_seg in non_silent_segs:
320
+ if len(non_silent_wave) > 6000 and len(non_silent_wave + non_silent_seg) > 15000:
321
  show_info("Audio is over 15s, clipping short. (2)")
322
  break
323
  non_silent_wave += non_silent_seg
 
325
  aseg = non_silent_wave
326
 
327
  # 3. if no proper silence found for clipping
328
+ if len(aseg) > 15000:
329
+ aseg = aseg[:15000]
330
  show_info("Audio is over 15s, clipping short. (3)")
331
 
332
  aseg = remove_silence_edges(aseg) + AudioSegment.silent(duration=50)
 
387
  ):
388
  # Split the input text into batches
389
  audio, sr = torchaudio.load(ref_audio)
390
+ max_chars = int(len(ref_text.encode("utf-8")) / (audio.shape[-1] / sr) * (25 - audio.shape[-1] / sr))
391
  gen_text_batches = chunk_text(gen_text, max_chars=max_chars)
392
  for i, gen_text in enumerate(gen_text_batches):
393
  print(f"gen_text {i}", gen_text)
394
  print("\n")
395
 
396
  show_info(f"Generating audio in {len(gen_text_batches)} batches...")
397
+ return infer_batch_process(
398
+ (audio, sr),
399
+ ref_text,
400
+ gen_text_batches,
401
+ model_obj,
402
+ vocoder,
403
+ mel_spec_type=mel_spec_type,
404
+ progress=progress,
405
+ target_rms=target_rms,
406
+ cross_fade_duration=cross_fade_duration,
407
+ nfe_step=nfe_step,
408
+ cfg_strength=cfg_strength,
409
+ sway_sampling_coef=sway_sampling_coef,
410
+ speed=speed,
411
+ fix_duration=fix_duration,
412
+ device=device,
 
 
413
  )
414
 
415
 
 
432
  speed=1,
433
  fix_duration=None,
434
  device=None,
 
 
435
  ):
436
  audio, sr = ref_audio
437
  if audio.shape[0] > 1:
 
450
 
451
  if len(ref_text[-1].encode("utf-8")) == 1:
452
  ref_text = ref_text + " "
453
+ for i, gen_text in enumerate(progress.tqdm(gen_text_batches)):
 
 
 
 
 
454
  # Prepare the text
455
  text_list = [ref_text + gen_text]
456
  final_text_list = convert_char_to_pinyin(text_list)
 
462
  # Calculate duration
463
  ref_text_len = len(ref_text.encode("utf-8"))
464
  gen_text_len = len(gen_text.encode("utf-8"))
465
+ duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed)
466
 
467
  # inference
468
  with torch.inference_mode():
 
474
  cfg_strength=cfg_strength,
475
  sway_sampling_coef=sway_sampling_coef,
476
  )
 
477
 
478
+ generated = generated.to(torch.float32)
479
  generated = generated[:, ref_audio_len:, :]
480
+ generated_mel_spec = generated.permute(0, 2, 1)
481
  if mel_spec_type == "vocos":
482
+ generated_wave = vocoder.decode(generated_mel_spec)
483
  elif mel_spec_type == "bigvgan":
484
+ generated_wave = vocoder(generated_mel_spec)
485
  if rms < target_rms:
486
  generated_wave = generated_wave * rms / target_rms
487
 
488
  # wav -> numpy
489
  generated_wave = generated_wave.squeeze().cpu().numpy()
490
 
491
+ generated_waves.append(generated_wave)
492
+ spectrograms.append(generated_mel_spec[0].cpu().numpy())
493
+
494
+ # Combine all generated waves with cross-fading
495
+ if cross_fade_duration <= 0:
496
+ # Simply concatenate
497
+ final_wave = np.concatenate(generated_waves)
 
 
 
 
 
498
  else:
499
+ final_wave = generated_waves[0]
500
+ for i in range(1, len(generated_waves)):
501
+ prev_wave = final_wave
502
+ next_wave = generated_waves[i]
503
+
504
+ # Calculate cross-fade samples, ensuring it does not exceed wave lengths
505
+ cross_fade_samples = int(cross_fade_duration * target_sample_rate)
506
+ cross_fade_samples = min(cross_fade_samples, len(prev_wave), len(next_wave))
507
+
508
+ if cross_fade_samples <= 0:
509
+ # No overlap possible, concatenate
510
+ final_wave = np.concatenate([prev_wave, next_wave])
511
+ continue
512
+
513
+ # Overlapping parts
514
+ prev_overlap = prev_wave[-cross_fade_samples:]
515
+ next_overlap = next_wave[:cross_fade_samples]
516
+
517
+ # Fade out and fade in
518
+ fade_out = np.linspace(1, 0, cross_fade_samples)
519
+ fade_in = np.linspace(0, 1, cross_fade_samples)
520
+
521
+ # Cross-faded overlap
522
+ cross_faded_overlap = prev_overlap * fade_out + next_overlap * fade_in
523
+
524
+ # Combine
525
+ new_wave = np.concatenate(
526
+ [prev_wave[:-cross_fade_samples], cross_faded_overlap, next_wave[cross_fade_samples:]]
527
+ )
 
 
 
 
 
 
 
 
 
 
 
 
528
 
529
+ final_wave = new_wave
530
+
531
+ # Create a combined spectrogram
532
+ combined_spectrogram = np.concatenate(spectrograms, axis=1)
533
+
534
+ return final_wave, target_sample_rate, combined_spectrogram
535
 
536
 
537
  # remove silence from generated wav
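The cross-fading restored above is a plain linear fade-out/fade-in over the overlap region before concatenation. A standalone numpy version of the same idea (array contents and lengths are arbitrary):

```python
import numpy as np


def cross_fade(prev_wave: np.ndarray, next_wave: np.ndarray, duration: float, sample_rate: int) -> np.ndarray:
    # overlap length in samples, never longer than either chunk
    n = min(int(duration * sample_rate), len(prev_wave), len(next_wave))
    if n <= 0:
        return np.concatenate([prev_wave, next_wave])
    overlap = prev_wave[-n:] * np.linspace(1.0, 0.0, n) + next_wave[:n] * np.linspace(0.0, 1.0, n)
    return np.concatenate([prev_wave[:-n], overlap, next_wave[n:]])


a = np.ones(24000, dtype=np.float32)   # 1 s of dummy audio at 24 kHz
b = np.zeros(24000, dtype=np.float32)
print(cross_fade(a, b, duration=0.15, sample_rate=24000).shape)  # (44400,) = 2 s minus the 0.15 s overlap
```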
f5_tts/model/__pycache__/__init__.cpython-310.pyc CHANGED
Binary files a/f5_tts/model/__pycache__/__init__.cpython-310.pyc and b/f5_tts/model/__pycache__/__init__.cpython-310.pyc differ
 
f5_tts/model/__pycache__/cfm.cpython-310.pyc CHANGED
Binary files a/f5_tts/model/__pycache__/cfm.cpython-310.pyc and b/f5_tts/model/__pycache__/cfm.cpython-310.pyc differ
 
f5_tts/model/__pycache__/dataset.cpython-310.pyc CHANGED
Binary files a/f5_tts/model/__pycache__/dataset.cpython-310.pyc and b/f5_tts/model/__pycache__/dataset.cpython-310.pyc differ
 
f5_tts/model/__pycache__/modules.cpython-310.pyc CHANGED
Binary files a/f5_tts/model/__pycache__/modules.cpython-310.pyc and b/f5_tts/model/__pycache__/modules.cpython-310.pyc differ
 
f5_tts/model/__pycache__/trainer.cpython-310.pyc CHANGED
Binary files a/f5_tts/model/__pycache__/trainer.cpython-310.pyc and b/f5_tts/model/__pycache__/trainer.cpython-310.pyc differ
 
f5_tts/model/__pycache__/utils.cpython-310.pyc CHANGED
Binary files a/f5_tts/model/__pycache__/utils.cpython-310.pyc and b/f5_tts/model/__pycache__/utils.cpython-310.pyc differ
 
f5_tts/model/backbones/README.md CHANGED
@@ -4,7 +4,7 @@
4
  ### unett.py
5
  - flat unet transformer
6
  - structure same as in e2-tts & voicebox paper except using rotary pos emb
7
- - possible abs pos emb & convnextv2 blocks for embedded text before concat
8
 
9
  ### dit.py
10
  - adaln-zero dit
@@ -14,7 +14,7 @@
14
  - possible long skip connection (first layer to last layer)
15
 
16
  ### mmdit.py
17
- - stable diffusion 3 block structure
18
  - timestep as condition
19
  - left stream: text embedded and applied an abs pos emb
20
  - right stream: masked_cond & noised_input concatted and with same conv pos emb as unett
 
4
  ### unett.py
5
  - flat unet transformer
6
  - structure same as in e2-tts & voicebox paper except using rotary pos emb
7
+ - update: allow possible abs pos emb & convnextv2 blocks for embedded text before concat
8
 
9
  ### dit.py
10
  - adaln-zero dit
 
14
  - possible long skip connection (first layer to last layer)
15
 
16
  ### mmdit.py
17
+ - sd3 structure
18
  - timestep as condition
19
  - left stream: text embedded and applied an abs pos emb
20
  - right stream: masked_cond & noised_input concatted and with same conv pos emb as unett
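For readers unfamiliar with the "adaln-zero" wording used for the DiT backbone above: the conditioning vector (here the timestep embedding) regresses a per-channel shift and scale for an affine-free LayerNorm, and the regressor is zero-initialized so training starts from a plain LayerNorm. An illustrative-only sketch, not the repo's exact module:

```python
import torch
import torch.nn as nn


class AdaLayerNormSketch(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.linear = nn.Linear(dim, dim * 2)
        self.norm = nn.LayerNorm(dim, elementwise_affine=False)
        nn.init.zeros_(self.linear.weight)  # "zero": starts out as an unmodulated LayerNorm
        nn.init.zeros_(self.linear.bias)

    def forward(self, x: torch.Tensor, t_emb: torch.Tensor) -> torch.Tensor:
        shift, scale = self.linear(t_emb).chunk(2, dim=-1)
        return self.norm(x) * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)


x = torch.randn(2, 50, 64)   # (batch, mel frames, dim)
t_emb = torch.randn(2, 64)   # timestep embedding
print(AdaLayerNormSketch(64)(x, t_emb).shape)  # torch.Size([2, 50, 64])
```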
f5_tts/model/backbones/__pycache__/dit.cpython-310.pyc CHANGED
Binary files a/f5_tts/model/backbones/__pycache__/dit.cpython-310.pyc and b/f5_tts/model/backbones/__pycache__/dit.cpython-310.pyc differ
 
f5_tts/model/backbones/__pycache__/mmdit.cpython-310.pyc CHANGED
Binary files a/f5_tts/model/backbones/__pycache__/mmdit.cpython-310.pyc and b/f5_tts/model/backbones/__pycache__/mmdit.cpython-310.pyc differ
 
f5_tts/model/backbones/__pycache__/unett.cpython-310.pyc CHANGED
Binary files a/f5_tts/model/backbones/__pycache__/unett.cpython-310.pyc and b/f5_tts/model/backbones/__pycache__/unett.cpython-310.pyc differ
 
f5_tts/model/backbones/dit.py CHANGED
@@ -20,7 +20,7 @@ from f5_tts.model.modules import (
20
  ConvNeXtV2Block,
21
  ConvPositionEmbedding,
22
  DiTBlock,
23
- AdaLayerNorm_Final,
24
  precompute_freqs_cis,
25
  get_pos_embed_indices,
26
  )
@@ -30,12 +30,10 @@ from f5_tts.model.modules import (
30
 
31
 
32
  class TextEmbedding(nn.Module):
33
- def __init__(self, text_num_embeds, text_dim, mask_padding=True, conv_layers=0, conv_mult=2):
34
  super().__init__()
35
  self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token
36
 
37
- self.mask_padding = mask_padding # mask filler and batch padding tokens or not
38
-
39
  if conv_layers > 0:
40
  self.extra_modeling = True
41
  self.precompute_max_pos = 4096 # ~44s of 24khz audio
@@ -51,8 +49,6 @@ class TextEmbedding(nn.Module):
51
  text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
52
  batch, text_len = text.shape[0], text.shape[1]
53
  text = F.pad(text, (0, seq_len - text_len), value=0)
54
- if self.mask_padding:
55
- text_mask = text == 0
56
 
57
  if drop_text: # cfg for text
58
  text = torch.zeros_like(text)
@@ -68,13 +64,7 @@ class TextEmbedding(nn.Module):
68
  text = text + text_pos_embed
69
 
70
  # convnextv2 blocks
71
- if self.mask_padding:
72
- text = text.masked_fill(text_mask.unsqueeze(-1).expand(-1, -1, text.size(-1)), 0.0)
73
- for block in self.text_blocks:
74
- text = block(text)
75
- text = text.masked_fill(text_mask.unsqueeze(-1).expand(-1, -1, text.size(-1)), 0.0)
76
- else:
77
- text = self.text_blocks(text)
78
 
79
  return text
80
 
@@ -113,10 +103,7 @@ class DiT(nn.Module):
113
  mel_dim=100,
114
  text_num_embeds=256,
115
  text_dim=None,
116
- text_mask_padding=True,
117
- qk_norm=None,
118
  conv_layers=0,
119
- pe_attn_head=None,
120
  long_skip_connection=False,
121
  checkpoint_activations=False,
122
  ):
@@ -125,10 +112,7 @@ class DiT(nn.Module):
125
  self.time_embed = TimestepEmbedding(dim)
126
  if text_dim is None:
127
  text_dim = mel_dim
128
- self.text_embed = TextEmbedding(
129
- text_num_embeds, text_dim, mask_padding=text_mask_padding, conv_layers=conv_layers
130
- )
131
- self.text_cond, self.text_uncond = None, None # text cache
132
  self.input_embed = InputEmbedding(mel_dim, text_dim, dim)
133
 
134
  self.rotary_embed = RotaryEmbedding(dim_head)
@@ -137,40 +121,15 @@ class DiT(nn.Module):
137
  self.depth = depth
138
 
139
  self.transformer_blocks = nn.ModuleList(
140
- [
141
- DiTBlock(
142
- dim=dim,
143
- heads=heads,
144
- dim_head=dim_head,
145
- ff_mult=ff_mult,
146
- dropout=dropout,
147
- qk_norm=qk_norm,
148
- pe_attn_head=pe_attn_head,
149
- )
150
- for _ in range(depth)
151
- ]
152
  )
153
  self.long_skip_connection = nn.Linear(dim * 2, dim, bias=False) if long_skip_connection else None
154
 
155
- self.norm_out = AdaLayerNorm_Final(dim) # final modulation
156
  self.proj_out = nn.Linear(dim, mel_dim)
157
 
158
  self.checkpoint_activations = checkpoint_activations
159
 
160
- self.initialize_weights()
161
-
162
- def initialize_weights(self):
163
- # Zero-out AdaLN layers in DiT blocks:
164
- for block in self.transformer_blocks:
165
- nn.init.constant_(block.attn_norm.linear.weight, 0)
166
- nn.init.constant_(block.attn_norm.linear.bias, 0)
167
-
168
- # Zero-out output layers:
169
- nn.init.constant_(self.norm_out.linear.weight, 0)
170
- nn.init.constant_(self.norm_out.linear.bias, 0)
171
- nn.init.constant_(self.proj_out.weight, 0)
172
- nn.init.constant_(self.proj_out.bias, 0)
173
-
174
  def ckpt_wrapper(self, module):
175
  # https://github.com/chuanyangjin/fast-DiT/blob/main/models.py
176
  def ckpt_forward(*inputs):
@@ -179,9 +138,6 @@ class DiT(nn.Module):
179
 
180
  return ckpt_forward
181
 
182
- def clear_cache(self):
183
- self.text_cond, self.text_uncond = None, None
184
-
185
  def forward(
186
  self,
187
  x: float["b n d"], # nosied input audio # noqa: F722
@@ -191,25 +147,14 @@ class DiT(nn.Module):
191
  drop_audio_cond, # cfg for cond audio
192
  drop_text, # cfg for text
193
  mask: bool["b n"] | None = None, # noqa: F722
194
- cache=False,
195
  ):
196
  batch, seq_len = x.shape[0], x.shape[1]
197
  if time.ndim == 0:
198
  time = time.repeat(batch)
199
 
200
- # t: conditioning time, text: text, x: noised audio + cond audio + text
201
  t = self.time_embed(time)
202
- if cache:
203
- if drop_text:
204
- if self.text_uncond is None:
205
- self.text_uncond = self.text_embed(text, seq_len, drop_text=True)
206
- text_embed = self.text_uncond
207
- else:
208
- if self.text_cond is None:
209
- self.text_cond = self.text_embed(text, seq_len, drop_text=False)
210
- text_embed = self.text_cond
211
- else:
212
- text_embed = self.text_embed(text, seq_len, drop_text=drop_text)
213
  x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond)
214
 
215
  rope = self.rotary_embed.forward_from_seq_len(seq_len)
 
20
  ConvNeXtV2Block,
21
  ConvPositionEmbedding,
22
  DiTBlock,
23
+ AdaLayerNormZero_Final,
24
  precompute_freqs_cis,
25
  get_pos_embed_indices,
26
  )
 
30
 
31
 
32
  class TextEmbedding(nn.Module):
33
+ def __init__(self, text_num_embeds, text_dim, conv_layers=0, conv_mult=2):
34
  super().__init__()
35
  self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token
36
 
 
 
37
  if conv_layers > 0:
38
  self.extra_modeling = True
39
  self.precompute_max_pos = 4096 # ~44s of 24khz audio
 
49
  text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
50
  batch, text_len = text.shape[0], text.shape[1]
51
  text = F.pad(text, (0, seq_len - text_len), value=0)
 
 
52
 
53
  if drop_text: # cfg for text
54
  text = torch.zeros_like(text)
 
64
  text = text + text_pos_embed
65
 
66
  # convnextv2 blocks
67
+ text = self.text_blocks(text)
 
 
 
 
 
 
68
 
69
  return text
70
 
 
103
  mel_dim=100,
104
  text_num_embeds=256,
105
  text_dim=None,
 
 
106
  conv_layers=0,
 
107
  long_skip_connection=False,
108
  checkpoint_activations=False,
109
  ):
 
112
  self.time_embed = TimestepEmbedding(dim)
113
  if text_dim is None:
114
  text_dim = mel_dim
115
+ self.text_embed = TextEmbedding(text_num_embeds, text_dim, conv_layers=conv_layers)
 
 
 
116
  self.input_embed = InputEmbedding(mel_dim, text_dim, dim)
117
 
118
  self.rotary_embed = RotaryEmbedding(dim_head)
 
121
  self.depth = depth
122
 
123
  self.transformer_blocks = nn.ModuleList(
124
+ [DiTBlock(dim=dim, heads=heads, dim_head=dim_head, ff_mult=ff_mult, dropout=dropout) for _ in range(depth)]
 
 
 
125
  )
126
  self.long_skip_connection = nn.Linear(dim * 2, dim, bias=False) if long_skip_connection else None
127
 
128
+ self.norm_out = AdaLayerNormZero_Final(dim) # final modulation
129
  self.proj_out = nn.Linear(dim, mel_dim)
130
 
131
  self.checkpoint_activations = checkpoint_activations
132
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
133
  def ckpt_wrapper(self, module):
134
  # https://github.com/chuanyangjin/fast-DiT/blob/main/models.py
135
  def ckpt_forward(*inputs):
 
138
 
139
  return ckpt_forward
140
 
 
 
 
141
  def forward(
142
  self,
143
  x: float["b n d"], # nosied input audio # noqa: F722
 
147
  drop_audio_cond, # cfg for cond audio
148
  drop_text, # cfg for text
149
  mask: bool["b n"] | None = None, # noqa: F722
 
150
  ):
151
  batch, seq_len = x.shape[0], x.shape[1]
152
  if time.ndim == 0:
153
  time = time.repeat(batch)
154
 
155
+ # t: conditioning time, c: context (text + masked cond audio), x: noised input audio
156
  t = self.time_embed(time)
157
+ text_embed = self.text_embed(text, seq_len, drop_text=drop_text)
 
 
 
158
  x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond)
159
 
160
  rope = self.rotary_embed.forward_from_seq_len(seq_len)
f5_tts/model/backbones/mmdit.py CHANGED
@@ -18,7 +18,7 @@ from f5_tts.model.modules import (
18
  TimestepEmbedding,
19
  ConvPositionEmbedding,
20
  MMDiTBlock,
21
- AdaLayerNorm_Final,
22
  precompute_freqs_cis,
23
  get_pos_embed_indices,
24
  )
@@ -28,24 +28,18 @@ from f5_tts.model.modules import (
28
 
29
 
30
  class TextEmbedding(nn.Module):
31
- def __init__(self, out_dim, text_num_embeds, mask_padding=True):
32
  super().__init__()
33
  self.text_embed = nn.Embedding(text_num_embeds + 1, out_dim) # will use 0 as filler token
34
 
35
- self.mask_padding = mask_padding # mask filler and batch padding tokens or not
36
-
37
  self.precompute_max_pos = 1024
38
  self.register_buffer("freqs_cis", precompute_freqs_cis(out_dim, self.precompute_max_pos), persistent=False)
39
 
40
  def forward(self, text: int["b nt"], drop_text=False) -> int["b nt d"]: # noqa: F722
41
- text = text + 1 # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
42
- if self.mask_padding:
43
- text_mask = text == 0
44
-
45
- if drop_text: # cfg for text
46
  text = torch.zeros_like(text)
47
-
48
- text = self.text_embed(text) # b nt -> b nt d
49
 
50
  # sinus pos emb
51
  batch_start = torch.zeros((text.shape[0],), dtype=torch.long)
@@ -55,9 +49,6 @@ class TextEmbedding(nn.Module):
55
 
56
  text = text + text_pos_embed
57
 
58
- if self.mask_padding:
59
- text = text.masked_fill(text_mask.unsqueeze(-1).expand(-1, -1, text.size(-1)), 0.0)
60
-
61
  return text
62
 
63
 
@@ -92,16 +83,13 @@ class MMDiT(nn.Module):
92
  dim_head=64,
93
  dropout=0.1,
94
  ff_mult=4,
95
- mel_dim=100,
96
  text_num_embeds=256,
97
- text_mask_padding=True,
98
- qk_norm=None,
99
  ):
100
  super().__init__()
101
 
102
  self.time_embed = TimestepEmbedding(dim)
103
- self.text_embed = TextEmbedding(dim, text_num_embeds, mask_padding=text_mask_padding)
104
- self.text_cond, self.text_uncond = None, None # text cache
105
  self.audio_embed = AudioEmbedding(mel_dim, dim)
106
 
107
  self.rotary_embed = RotaryEmbedding(dim_head)
@@ -118,33 +106,13 @@ class MMDiT(nn.Module):
118
  dropout=dropout,
119
  ff_mult=ff_mult,
120
  context_pre_only=i == depth - 1,
121
- qk_norm=qk_norm,
122
  )
123
  for i in range(depth)
124
  ]
125
  )
126
- self.norm_out = AdaLayerNorm_Final(dim) # final modulation
127
  self.proj_out = nn.Linear(dim, mel_dim)
128
 
129
- self.initialize_weights()
130
-
131
- def initialize_weights(self):
132
- # Zero-out AdaLN layers in MMDiT blocks:
133
- for block in self.transformer_blocks:
134
- nn.init.constant_(block.attn_norm_x.linear.weight, 0)
135
- nn.init.constant_(block.attn_norm_x.linear.bias, 0)
136
- nn.init.constant_(block.attn_norm_c.linear.weight, 0)
137
- nn.init.constant_(block.attn_norm_c.linear.bias, 0)
138
-
139
- # Zero-out output layers:
140
- nn.init.constant_(self.norm_out.linear.weight, 0)
141
- nn.init.constant_(self.norm_out.linear.bias, 0)
142
- nn.init.constant_(self.proj_out.weight, 0)
143
- nn.init.constant_(self.proj_out.bias, 0)
144
-
145
- def clear_cache(self):
146
- self.text_cond, self.text_uncond = None, None
147
-
148
  def forward(
149
  self,
150
  x: float["b n d"], # nosied input audio # noqa: F722
@@ -154,7 +122,6 @@ class MMDiT(nn.Module):
154
  drop_audio_cond, # cfg for cond audio
155
  drop_text, # cfg for text
156
  mask: bool["b n"] | None = None, # noqa: F722
157
- cache=False,
158
  ):
159
  batch = x.shape[0]
160
  if time.ndim == 0:
@@ -162,17 +129,7 @@ class MMDiT(nn.Module):
162
 
163
  # t: conditioning (time), c: context (text + masked cond audio), x: noised input audio
164
  t = self.time_embed(time)
165
- if cache:
166
- if drop_text:
167
- if self.text_uncond is None:
168
- self.text_uncond = self.text_embed(text, drop_text=True)
169
- c = self.text_uncond
170
- else:
171
- if self.text_cond is None:
172
- self.text_cond = self.text_embed(text, drop_text=False)
173
- c = self.text_cond
174
- else:
175
- c = self.text_embed(text, drop_text=drop_text)
176
  x = self.audio_embed(x, cond, drop_audio_cond=drop_audio_cond)
177
 
178
  seq_len = x.shape[1]
 
18
  TimestepEmbedding,
19
  ConvPositionEmbedding,
20
  MMDiTBlock,
21
+ AdaLayerNormZero_Final,
22
  precompute_freqs_cis,
23
  get_pos_embed_indices,
24
  )
 
28
 
29
 
30
  class TextEmbedding(nn.Module):
31
+ def __init__(self, out_dim, text_num_embeds):
32
  super().__init__()
33
  self.text_embed = nn.Embedding(text_num_embeds + 1, out_dim) # will use 0 as filler token
34
 
 
 
35
  self.precompute_max_pos = 1024
36
  self.register_buffer("freqs_cis", precompute_freqs_cis(out_dim, self.precompute_max_pos), persistent=False)
37
 
38
  def forward(self, text: int["b nt"], drop_text=False) -> int["b nt d"]: # noqa: F722
39
+ text = text + 1
40
+ if drop_text:
 
 
 
41
  text = torch.zeros_like(text)
42
+ text = self.text_embed(text)
 
43
 
44
  # sinus pos emb
45
  batch_start = torch.zeros((text.shape[0],), dtype=torch.long)
 
49
 
50
  text = text + text_pos_embed
51
 
 
 
 
52
  return text
53
 
54
 
 
83
  dim_head=64,
84
  dropout=0.1,
85
  ff_mult=4,
 
86
  text_num_embeds=256,
87
+ mel_dim=100,
 
88
  ):
89
  super().__init__()
90
 
91
  self.time_embed = TimestepEmbedding(dim)
92
+ self.text_embed = TextEmbedding(dim, text_num_embeds)
 
93
  self.audio_embed = AudioEmbedding(mel_dim, dim)
94
 
95
  self.rotary_embed = RotaryEmbedding(dim_head)
 
106
  dropout=dropout,
107
  ff_mult=ff_mult,
108
  context_pre_only=i == depth - 1,
 
109
  )
110
  for i in range(depth)
111
  ]
112
  )
113
+ self.norm_out = AdaLayerNormZero_Final(dim) # final modulation
114
  self.proj_out = nn.Linear(dim, mel_dim)
115
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
116
  def forward(
117
  self,
118
  x: float["b n d"], # nosied input audio # noqa: F722
 
122
  drop_audio_cond, # cfg for cond audio
123
  drop_text, # cfg for text
124
  mask: bool["b n"] | None = None, # noqa: F722
 
125
  ):
126
  batch = x.shape[0]
127
  if time.ndim == 0:
 
129
 
130
  # t: conditioning (time), c: context (text + masked cond audio), x: noised input audio
131
  t = self.time_embed(time)
132
+ c = self.text_embed(text, drop_text=drop_text)
 
 
 
 
133
  x = self.audio_embed(x, cond, drop_audio_cond=drop_audio_cond)
134
 
135
  seq_len = x.shape[1]
f5_tts/model/backbones/unett.py CHANGED
@@ -33,12 +33,10 @@ from f5_tts.model.modules import (
33
 
34
 
35
  class TextEmbedding(nn.Module):
36
- def __init__(self, text_num_embeds, text_dim, mask_padding=True, conv_layers=0, conv_mult=2):
37
  super().__init__()
38
  self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token
39
 
40
- self.mask_padding = mask_padding # mask filler and batch padding tokens or not
41
-
42
  if conv_layers > 0:
43
  self.extra_modeling = True
44
  self.precompute_max_pos = 4096 # ~44s of 24khz audio
@@ -54,8 +52,6 @@ class TextEmbedding(nn.Module):
54
  text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
55
  batch, text_len = text.shape[0], text.shape[1]
56
  text = F.pad(text, (0, seq_len - text_len), value=0)
57
- if self.mask_padding:
58
- text_mask = text == 0
59
 
60
  if drop_text: # cfg for text
61
  text = torch.zeros_like(text)
@@ -71,13 +67,7 @@ class TextEmbedding(nn.Module):
71
  text = text + text_pos_embed
72
 
73
  # convnextv2 blocks
74
- if self.mask_padding:
75
- text = text.masked_fill(text_mask.unsqueeze(-1).expand(-1, -1, text.size(-1)), 0.0)
76
- for block in self.text_blocks:
77
- text = block(text)
78
- text = text.masked_fill(text_mask.unsqueeze(-1).expand(-1, -1, text.size(-1)), 0.0)
79
- else:
80
- text = self.text_blocks(text)
81
 
82
  return text
83
 
@@ -116,10 +106,7 @@ class UNetT(nn.Module):
116
  mel_dim=100,
117
  text_num_embeds=256,
118
  text_dim=None,
119
- text_mask_padding=True,
120
- qk_norm=None,
121
  conv_layers=0,
122
- pe_attn_head=None,
123
  skip_connect_type: Literal["add", "concat", "none"] = "concat",
124
  ):
125
  super().__init__()
@@ -128,10 +115,7 @@ class UNetT(nn.Module):
128
  self.time_embed = TimestepEmbedding(dim)
129
  if text_dim is None:
130
  text_dim = mel_dim
131
- self.text_embed = TextEmbedding(
132
- text_num_embeds, text_dim, mask_padding=text_mask_padding, conv_layers=conv_layers
133
- )
134
- self.text_cond, self.text_uncond = None, None # text cache
135
  self.input_embed = InputEmbedding(mel_dim, text_dim, dim)
136
 
137
  self.rotary_embed = RotaryEmbedding(dim_head)
@@ -150,12 +134,11 @@ class UNetT(nn.Module):
150
 
151
  attn_norm = RMSNorm(dim)
152
  attn = Attention(
153
- processor=AttnProcessor(pe_attn_head=pe_attn_head),
154
  dim=dim,
155
  heads=heads,
156
  dim_head=dim_head,
157
  dropout=dropout,
158
- qk_norm=qk_norm,
159
  )
160
 
161
  ff_norm = RMSNorm(dim)
@@ -178,9 +161,6 @@ class UNetT(nn.Module):
178
  self.norm_out = RMSNorm(dim)
179
  self.proj_out = nn.Linear(dim, mel_dim)
180
 
181
- def clear_cache(self):
182
- self.text_cond, self.text_uncond = None, None
183
-
184
  def forward(
185
  self,
186
  x: float["b n d"], # nosied input audio # noqa: F722
@@ -190,7 +170,6 @@ class UNetT(nn.Module):
190
  drop_audio_cond, # cfg for cond audio
191
  drop_text, # cfg for text
192
  mask: bool["b n"] | None = None, # noqa: F722
193
- cache=False,
194
  ):
195
  batch, seq_len = x.shape[0], x.shape[1]
196
  if time.ndim == 0:
@@ -198,17 +177,7 @@ class UNetT(nn.Module):
198
 
199
  # t: conditioning time, c: context (text + masked cond audio), x: noised input audio
200
  t = self.time_embed(time)
201
- if cache:
202
- if drop_text:
203
- if self.text_uncond is None:
204
- self.text_uncond = self.text_embed(text, seq_len, drop_text=True)
205
- text_embed = self.text_uncond
206
- else:
207
- if self.text_cond is None:
208
- self.text_cond = self.text_embed(text, seq_len, drop_text=False)
209
- text_embed = self.text_cond
210
- else:
211
- text_embed = self.text_embed(text, seq_len, drop_text=drop_text)
212
  x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond)
213
 
214
  # postfix time t to input x, [b n d] -> [b n+1 d]
 
33
 
34
 
35
  class TextEmbedding(nn.Module):
36
+ def __init__(self, text_num_embeds, text_dim, conv_layers=0, conv_mult=2):
37
  super().__init__()
38
  self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token
39
 
 
 
40
  if conv_layers > 0:
41
  self.extra_modeling = True
42
  self.precompute_max_pos = 4096 # ~44s of 24khz audio
 
52
  text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
53
  batch, text_len = text.shape[0], text.shape[1]
54
  text = F.pad(text, (0, seq_len - text_len), value=0)
 
 
55
 
56
  if drop_text: # cfg for text
57
  text = torch.zeros_like(text)
 
67
  text = text + text_pos_embed
68
 
69
  # convnextv2 blocks
70
+ text = self.text_blocks(text)
 
 
 
 
 
 
71
 
72
  return text
73
 
 
106
  mel_dim=100,
107
  text_num_embeds=256,
108
  text_dim=None,
 
 
109
  conv_layers=0,
 
110
  skip_connect_type: Literal["add", "concat", "none"] = "concat",
111
  ):
112
  super().__init__()
 
115
  self.time_embed = TimestepEmbedding(dim)
116
  if text_dim is None:
117
  text_dim = mel_dim
118
+ self.text_embed = TextEmbedding(text_num_embeds, text_dim, conv_layers=conv_layers)
 
 
 
119
  self.input_embed = InputEmbedding(mel_dim, text_dim, dim)
120
 
121
  self.rotary_embed = RotaryEmbedding(dim_head)
 
134
 
135
  attn_norm = RMSNorm(dim)
136
  attn = Attention(
137
+ processor=AttnProcessor(),
138
  dim=dim,
139
  heads=heads,
140
  dim_head=dim_head,
141
  dropout=dropout,
 
142
  )
143
 
144
  ff_norm = RMSNorm(dim)
 
161
  self.norm_out = RMSNorm(dim)
162
  self.proj_out = nn.Linear(dim, mel_dim)
163
 
 
 
 
164
  def forward(
165
  self,
166
  x: float["b n d"], # noised input audio # noqa: F722
 
170
  drop_audio_cond, # cfg for cond audio
171
  drop_text, # cfg for text
172
  mask: bool["b n"] | None = None, # noqa: F722
 
173
  ):
174
  batch, seq_len = x.shape[0], x.shape[1]
175
  if time.ndim == 0:
 
177
 
178
  # t: conditioning time, c: context (text + masked cond audio), x: noised input audio
179
  t = self.time_embed(time)
180
+ text_embed = self.text_embed(text, seq_len, drop_text=drop_text)
 
 
 
 
 
 
 
 
 
 
181
  x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond)
182
 
183
  # postfix time t to input x, [b n d] -> [b n+1 d]
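For reference, the token handling that survives this hunk (curtail to the mel length, pad with filler token 0, zero out when text is dropped for CFG) can be reproduced with a minimal standalone sketch; the function name and toy values below are illustrative and not part of the repository:

```python
import torch
import torch.nn.functional as F


def pad_or_curtail(text_ids: torch.Tensor, seq_len: int, drop_text: bool = False) -> torch.Tensor:
    # Curtail character tokens to the mel length, pad with filler token 0,
    # and zero everything out when text is dropped for classifier-free guidance.
    text_ids = text_ids[:, :seq_len]
    text_ids = F.pad(text_ids, (0, seq_len - text_ids.shape[1]), value=0)
    if drop_text:
        text_ids = torch.zeros_like(text_ids)
    return text_ids


ids = torch.tensor([[5, 9, 3]])
print(pad_or_curtail(ids, seq_len=6))                  # tensor([[5, 9, 3, 0, 0, 0]])
print(pad_or_curtail(ids, seq_len=6, drop_text=True))  # tensor([[0, 0, 0, 0, 0, 0]])
```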
f5_tts/model/cfm.py CHANGED
@@ -162,13 +162,13 @@ class CFM(nn.Module):
162
 
163
  # predict flow
164
  pred = self.transformer(
165
- x=x, cond=step_cond, text=text, time=t, mask=mask, drop_audio_cond=False, drop_text=False, cache=True
166
  )
167
  if cfg_strength < 1e-5:
168
  return pred
169
 
170
  null_pred = self.transformer(
171
- x=x, cond=step_cond, text=text, time=t, mask=mask, drop_audio_cond=True, drop_text=True, cache=True
172
  )
173
  return pred + (pred - null_pred) * cfg_strength
174
 
@@ -195,7 +195,6 @@ class CFM(nn.Module):
195
  t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)
196
 
197
  trajectory = odeint(fn, y0, t, **self.odeint_kwargs)
198
- self.transformer.clear_cache()
199
 
200
  sampled = trajectory[-1]
201
  out = sampled
 
162
 
163
  # predict flow
164
  pred = self.transformer(
165
+ x=x, cond=step_cond, text=text, time=t, mask=mask, drop_audio_cond=False, drop_text=False
166
  )
167
  if cfg_strength < 1e-5:
168
  return pred
169
 
170
  null_pred = self.transformer(
171
+ x=x, cond=step_cond, text=text, time=t, mask=mask, drop_audio_cond=True, drop_text=True
172
  )
173
  return pred + (pred - null_pred) * cfg_strength
174
 
 
195
  t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)
196
 
197
  trajectory = odeint(fn, y0, t, **self.odeint_kwargs)
 
198
 
199
  sampled = trajectory[-1]
200
  out = sampled
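The guidance arithmetic itself is unchanged by this hunk (only the cache flag is dropped). A minimal sketch of that classifier-free-guidance combination, with illustrative tensor shapes and strength value:

```python
import torch


def cfg_combine(pred: torch.Tensor, null_pred: torch.Tensor, cfg_strength: float) -> torch.Tensor:
    # Push the conditional flow prediction away from the unconditional one
    # by cfg_strength, as in CFM.sample above.
    if cfg_strength < 1e-5:  # guidance effectively disabled
        return pred
    return pred + (pred - null_pred) * cfg_strength


# toy tensors shaped [batch, frames, mel_dim]; values are illustrative only
pred, null_pred = torch.randn(2, 10, 100), torch.randn(2, 10, 100)
print(cfg_combine(pred, null_pred, cfg_strength=2.0).shape)  # torch.Size([2, 10, 100])
```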
f5_tts/model/dataset.py CHANGED
@@ -173,7 +173,7 @@ class DynamicBatchSampler(Sampler[list[int]]):
173
  """
174
 
175
  def __init__(
176
- self, sampler: Sampler[int], frames_threshold: int, max_samples=0, random_seed=None, drop_residual: bool = False
177
  ):
178
  self.sampler = sampler
179
  self.frames_threshold = frames_threshold
@@ -208,15 +208,12 @@ class DynamicBatchSampler(Sampler[list[int]]):
208
  batch = []
209
  batch_frames = 0
210
 
211
- if not drop_residual and len(batch) > 0:
212
  batches.append(batch)
213
 
214
  del indices
215
  self.batches = batches
216
 
217
- # Ensure even batches with accelerate BatchSamplerShard cls under frame_per_batch setting
218
- self.drop_last = True
219
-
220
  def set_epoch(self, epoch: int) -> None:
221
  """Sets the epoch for this sampler."""
222
  self.epoch = epoch
@@ -256,7 +253,7 @@ def load_dataset(
256
  print("Loading dataset ...")
257
 
258
  if dataset_type == "CustomDataset":
259
- rel_data_path = str(files("f5_tts").joinpath(f"../../data/{dataset_name}"))
260
  if audio_type == "raw":
261
  try:
262
  train_dataset = load_from_disk(f"{rel_data_path}/raw")
 
173
  """
174
 
175
  def __init__(
176
+ self, sampler: Sampler[int], frames_threshold: int, max_samples=0, random_seed=None, drop_last: bool = False
177
  ):
178
  self.sampler = sampler
179
  self.frames_threshold = frames_threshold
 
208
  batch = []
209
  batch_frames = 0
210
 
211
+ if not drop_last and len(batch) > 0:
212
  batches.append(batch)
213
 
214
  del indices
215
  self.batches = batches
216
 
 
 
 
217
  def set_epoch(self, epoch: int) -> None:
218
  """Sets the epoch for this sampler."""
219
  self.epoch = epoch
 
253
  print("Loading dataset ...")
254
 
255
  if dataset_type == "CustomDataset":
256
+ rel_data_path = str(files("f5_tts").joinpath(f"../../data/{dataset_name}_{tokenizer}"))
257
  if audio_type == "raw":
258
  try:
259
  train_dataset = load_from_disk(f"{rel_data_path}/raw")
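As a rough illustration of what the renamed `drop_last` flag controls, here is a self-contained frame-bucketing sketch; it ignores `max_samples` and the duration-sorted sampler, so treat it as a simplified model of `DynamicBatchSampler`, not the class itself:

```python
from typing import Sequence


def bucket_by_frames(frame_lengths: Sequence[int], frames_threshold: int, drop_last: bool = False):
    # Group sample indices so the total frames per batch stay under frames_threshold;
    # keep the residual batch unless drop_last is requested.
    batches, batch, batch_frames = [], [], 0
    for idx, frames in enumerate(frame_lengths):
        if batch and batch_frames + frames > frames_threshold:
            batches.append(batch)
            batch, batch_frames = [], 0
        batch.append(idx)
        batch_frames += frames
    if not drop_last and batch:
        batches.append(batch)
    return batches


print(bucket_by_frames([300, 400, 500, 200], frames_threshold=800))  # [[0, 1], [2, 3]]
```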
f5_tts/model/modules.py CHANGED
@@ -269,36 +269,11 @@ class ConvNeXtV2Block(nn.Module):
269
  return residual + x
270
 
271
 
272
- # RMSNorm
273
-
274
-
275
- class RMSNorm(nn.Module):
276
- def __init__(self, dim: int, eps: float):
277
- super().__init__()
278
- self.eps = eps
279
- self.weight = nn.Parameter(torch.ones(dim))
280
- self.native_rms_norm = float(torch.__version__[:3]) >= 2.4
281
-
282
- def forward(self, x):
283
- if self.native_rms_norm:
284
- if self.weight.dtype in [torch.float16, torch.bfloat16]:
285
- x = x.to(self.weight.dtype)
286
- x = F.rms_norm(x, normalized_shape=(x.shape[-1],), weight=self.weight, eps=self.eps)
287
- else:
288
- variance = x.to(torch.float32).pow(2).mean(-1, keepdim=True)
289
- x = x * torch.rsqrt(variance + self.eps)
290
- if self.weight.dtype in [torch.float16, torch.bfloat16]:
291
- x = x.to(self.weight.dtype)
292
- x = x * self.weight
293
-
294
- return x
295
-
296
-
297
- # AdaLayerNorm
298
  # return with modulated x for attn input, and params for later mlp modulation
299
 
300
 
301
- class AdaLayerNorm(nn.Module):
302
  def __init__(self, dim):
303
  super().__init__()
304
 
@@ -315,11 +290,11 @@ class AdaLayerNorm(nn.Module):
315
  return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
316
 
317
 
318
- # AdaLayerNorm for final layer
319
  # return only with modulated x for attn input, cuz no more mlp modulation
320
 
321
 
322
- class AdaLayerNorm_Final(nn.Module):
323
  def __init__(self, dim):
324
  super().__init__()
325
 
@@ -366,8 +341,7 @@ class Attention(nn.Module):
366
  dim_head: int = 64,
367
  dropout: float = 0.0,
368
  context_dim: Optional[int] = None, # if not None -> joint attention
369
- context_pre_only: bool = False,
370
- qk_norm: Optional[str] = None,
371
  ):
372
  super().__init__()
373
 
@@ -388,32 +362,18 @@ class Attention(nn.Module):
388
  self.to_k = nn.Linear(dim, self.inner_dim)
389
  self.to_v = nn.Linear(dim, self.inner_dim)
390
 
391
- if qk_norm is None:
392
- self.q_norm = None
393
- self.k_norm = None
394
- elif qk_norm == "rms_norm":
395
- self.q_norm = RMSNorm(dim_head, eps=1e-6)
396
- self.k_norm = RMSNorm(dim_head, eps=1e-6)
397
- else:
398
- raise ValueError(f"Unimplemented qk_norm: {qk_norm}")
399
-
400
  if self.context_dim is not None:
401
- self.to_q_c = nn.Linear(context_dim, self.inner_dim)
402
  self.to_k_c = nn.Linear(context_dim, self.inner_dim)
403
  self.to_v_c = nn.Linear(context_dim, self.inner_dim)
404
- if qk_norm is None:
405
- self.c_q_norm = None
406
- self.c_k_norm = None
407
- elif qk_norm == "rms_norm":
408
- self.c_q_norm = RMSNorm(dim_head, eps=1e-6)
409
- self.c_k_norm = RMSNorm(dim_head, eps=1e-6)
410
 
411
  self.to_out = nn.ModuleList([])
412
  self.to_out.append(nn.Linear(self.inner_dim, dim))
413
  self.to_out.append(nn.Dropout(dropout))
414
 
415
- if self.context_dim is not None and not self.context_pre_only:
416
- self.to_out_c = nn.Linear(self.inner_dim, context_dim)
417
 
418
  def forward(
419
  self,
@@ -433,11 +393,8 @@ class Attention(nn.Module):
433
 
434
 
435
  class AttnProcessor:
436
- def __init__(
437
- self,
438
- pe_attn_head: int | None = None, # number of attention head to apply rope, None for all
439
- ):
440
- self.pe_attn_head = pe_attn_head
441
 
442
  def __call__(
443
  self,
@@ -448,11 +405,19 @@ class AttnProcessor:
448
  ) -> torch.FloatTensor:
449
  batch_size = x.shape[0]
450
 
451
- # `sample` projections
452
  query = attn.to_q(x)
453
  key = attn.to_k(x)
454
  value = attn.to_v(x)
455
 
 
 
 
 
 
 
 
 
456
  # attention
457
  inner_dim = key.shape[-1]
458
  head_dim = inner_dim // attn.heads
@@ -460,25 +425,6 @@ class AttnProcessor:
460
  key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
461
  value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
462
 
463
- # qk norm
464
- if attn.q_norm is not None:
465
- query = attn.q_norm(query)
466
- if attn.k_norm is not None:
467
- key = attn.k_norm(key)
468
-
469
- # apply rotary position embedding
470
- if rope is not None:
471
- freqs, xpos_scale = rope
472
- q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
473
-
474
- if self.pe_attn_head is not None:
475
- pn = self.pe_attn_head
476
- query[:, :pn, :, :] = apply_rotary_pos_emb(query[:, :pn, :, :], freqs, q_xpos_scale)
477
- key[:, :pn, :, :] = apply_rotary_pos_emb(key[:, :pn, :, :], freqs, k_xpos_scale)
478
- else:
479
- query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
480
- key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
481
-
482
  # mask. e.g. inference got a batch with different target durations, mask out the padding
483
  if mask is not None:
484
  attn_mask = mask
@@ -524,36 +470,16 @@ class JointAttnProcessor:
524
 
525
  batch_size = c.shape[0]
526
 
527
- # `sample` projections
528
  query = attn.to_q(x)
529
  key = attn.to_k(x)
530
  value = attn.to_v(x)
531
 
532
- # `context` projections
533
  c_query = attn.to_q_c(c)
534
  c_key = attn.to_k_c(c)
535
  c_value = attn.to_v_c(c)
536
 
537
- # attention
538
- inner_dim = key.shape[-1]
539
- head_dim = inner_dim // attn.heads
540
- query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
541
- key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
542
- value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
543
- c_query = c_query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
544
- c_key = c_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
545
- c_value = c_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
546
-
547
- # qk norm
548
- if attn.q_norm is not None:
549
- query = attn.q_norm(query)
550
- if attn.k_norm is not None:
551
- key = attn.k_norm(key)
552
- if attn.c_q_norm is not None:
553
- c_query = attn.c_q_norm(c_query)
554
- if attn.c_k_norm is not None:
555
- c_key = attn.c_k_norm(c_key)
556
-
557
  # apply rope for context and noised input independently
558
  if rope is not None:
559
  freqs, xpos_scale = rope
@@ -566,10 +492,16 @@ class JointAttnProcessor:
566
  c_query = apply_rotary_pos_emb(c_query, freqs, q_xpos_scale)
567
  c_key = apply_rotary_pos_emb(c_key, freqs, k_xpos_scale)
568
 
569
- # joint attention
570
- query = torch.cat([query, c_query], dim=2)
571
- key = torch.cat([key, c_key], dim=2)
572
- value = torch.cat([value, c_value], dim=2)
 
 
 
 
 
 
573
 
574
  # mask. e.g. inference got a batch with different target durations, mask out the padding
575
  if mask is not None:
@@ -608,17 +540,16 @@ class JointAttnProcessor:
608
 
609
 
610
  class DiTBlock(nn.Module):
611
- def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1, qk_norm=None, pe_attn_head=None):
612
  super().__init__()
613
 
614
- self.attn_norm = AdaLayerNorm(dim)
615
  self.attn = Attention(
616
- processor=AttnProcessor(pe_attn_head=pe_attn_head),
617
  dim=dim,
618
  heads=heads,
619
  dim_head=dim_head,
620
  dropout=dropout,
621
- qk_norm=qk_norm,
622
  )
623
 
624
  self.ff_norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
@@ -654,30 +585,26 @@ class MMDiTBlock(nn.Module):
654
  context_pre_only: last layer only do prenorm + modulation cuz no more ffn
655
  """
656
 
657
- def __init__(
658
- self, dim, heads, dim_head, ff_mult=4, dropout=0.1, context_dim=None, context_pre_only=False, qk_norm=None
659
- ):
660
  super().__init__()
661
- if context_dim is None:
662
- context_dim = dim
663
  self.context_pre_only = context_pre_only
664
 
665
- self.attn_norm_c = AdaLayerNorm_Final(context_dim) if context_pre_only else AdaLayerNorm(context_dim)
666
- self.attn_norm_x = AdaLayerNorm(dim)
667
  self.attn = Attention(
668
  processor=JointAttnProcessor(),
669
  dim=dim,
670
  heads=heads,
671
  dim_head=dim_head,
672
  dropout=dropout,
673
- context_dim=context_dim,
674
  context_pre_only=context_pre_only,
675
- qk_norm=qk_norm,
676
  )
677
 
678
  if not context_pre_only:
679
- self.ff_norm_c = nn.LayerNorm(context_dim, elementwise_affine=False, eps=1e-6)
680
- self.ff_c = FeedForward(dim=context_dim, mult=ff_mult, dropout=dropout, approximate="tanh")
681
  else:
682
  self.ff_norm_c = None
683
  self.ff_c = None
 
269
  return residual + x
270
 
271
 
272
+ # AdaLayerNormZero
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
273
  # return with modulated x for attn input, and params for later mlp modulation
274
 
275
 
276
+ class AdaLayerNormZero(nn.Module):
277
  def __init__(self, dim):
278
  super().__init__()
279
 
 
290
  return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
291
 
292
 
293
+ # AdaLayerNormZero for final layer
294
  # return only with modulated x for attn input, cuz no more mlp modulation
295
 
296
 
297
+ class AdaLayerNormZero_Final(nn.Module):
298
  def __init__(self, dim):
299
  super().__init__()
300
 
 
341
  dim_head: int = 64,
342
  dropout: float = 0.0,
343
  context_dim: Optional[int] = None, # if not None -> joint attention
344
+ context_pre_only=None,
 
345
  ):
346
  super().__init__()
347
 
 
362
  self.to_k = nn.Linear(dim, self.inner_dim)
363
  self.to_v = nn.Linear(dim, self.inner_dim)
364
 
 
 
 
 
 
 
 
 
 
365
  if self.context_dim is not None:
 
366
  self.to_k_c = nn.Linear(context_dim, self.inner_dim)
367
  self.to_v_c = nn.Linear(context_dim, self.inner_dim)
368
+ if self.context_pre_only is not None:
369
+ self.to_q_c = nn.Linear(context_dim, self.inner_dim)
 
 
 
 
370
 
371
  self.to_out = nn.ModuleList([])
372
  self.to_out.append(nn.Linear(self.inner_dim, dim))
373
  self.to_out.append(nn.Dropout(dropout))
374
 
375
+ if self.context_pre_only is not None and not self.context_pre_only:
376
+ self.to_out_c = nn.Linear(self.inner_dim, dim)
377
 
378
  def forward(
379
  self,
 
393
 
394
 
395
  class AttnProcessor:
396
+ def __init__(self):
397
+ pass
 
 
 
398
 
399
  def __call__(
400
  self,
 
405
  ) -> torch.FloatTensor:
406
  batch_size = x.shape[0]
407
 
408
+ # `sample` projections.
409
  query = attn.to_q(x)
410
  key = attn.to_k(x)
411
  value = attn.to_v(x)
412
 
413
+ # apply rotary position embedding
414
+ if rope is not None:
415
+ freqs, xpos_scale = rope
416
+ q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
417
+
418
+ query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
419
+ key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
420
+
421
  # attention
422
  inner_dim = key.shape[-1]
423
  head_dim = inner_dim // attn.heads
 
425
  key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
426
  value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
427
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
428
  # mask. e.g. inference got a batch with different target durations, mask out the padding
429
  if mask is not None:
430
  attn_mask = mask
 
470
 
471
  batch_size = c.shape[0]
472
 
473
+ # `sample` projections.
474
  query = attn.to_q(x)
475
  key = attn.to_k(x)
476
  value = attn.to_v(x)
477
 
478
+ # `context` projections.
479
  c_query = attn.to_q_c(c)
480
  c_key = attn.to_k_c(c)
481
  c_value = attn.to_v_c(c)
482
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
483
  # apply rope for context and noised input independently
484
  if rope is not None:
485
  freqs, xpos_scale = rope
 
492
  c_query = apply_rotary_pos_emb(c_query, freqs, q_xpos_scale)
493
  c_key = apply_rotary_pos_emb(c_key, freqs, k_xpos_scale)
494
 
495
+ # attention
496
+ query = torch.cat([query, c_query], dim=1)
497
+ key = torch.cat([key, c_key], dim=1)
498
+ value = torch.cat([value, c_value], dim=1)
499
+
500
+ inner_dim = key.shape[-1]
501
+ head_dim = inner_dim // attn.heads
502
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
503
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
504
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
505
 
506
  # mask. e.g. inference got a batch with different target durations, mask out the padding
507
  if mask is not None:
 
540
 
541
 
542
  class DiTBlock(nn.Module):
543
+ def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1):
544
  super().__init__()
545
 
546
+ self.attn_norm = AdaLayerNormZero(dim)
547
  self.attn = Attention(
548
+ processor=AttnProcessor(),
549
  dim=dim,
550
  heads=heads,
551
  dim_head=dim_head,
552
  dropout=dropout,
 
553
  )
554
 
555
  self.ff_norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
 
585
  context_pre_only: last layer only do prenorm + modulation cuz no more ffn
586
  """
587
 
588
+ def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1, context_pre_only=False):
 
 
589
  super().__init__()
590
+
 
591
  self.context_pre_only = context_pre_only
592
 
593
+ self.attn_norm_c = AdaLayerNormZero_Final(dim) if context_pre_only else AdaLayerNormZero(dim)
594
+ self.attn_norm_x = AdaLayerNormZero(dim)
595
  self.attn = Attention(
596
  processor=JointAttnProcessor(),
597
  dim=dim,
598
  heads=heads,
599
  dim_head=dim_head,
600
  dropout=dropout,
601
+ context_dim=dim,
602
  context_pre_only=context_pre_only,
 
603
  )
604
 
605
  if not context_pre_only:
606
+ self.ff_norm_c = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
607
+ self.ff_c = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")
608
  else:
609
  self.ff_norm_c = None
610
  self.ff_c = None
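The renamed `AdaLayerNormZero` follows the standard adaLN-Zero pattern: a SiLU + Linear on the conditioning embedding yields shift/scale/gate terms that modulate a parameter-free LayerNorm. A minimal sketch (the exact chunk order used in the repository may differ):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class AdaLNZeroSketch(nn.Module):
    # Conditioning embedding -> SiLU -> Linear -> six modulation terms;
    # the first shift/scale pair modulates the normalized attention input.
    def __init__(self, dim: int):
        super().__init__()
        self.linear = nn.Linear(dim, 6 * dim)
        self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)

    def forward(self, x: torch.Tensor, emb: torch.Tensor):
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.linear(F.silu(emb)).chunk(6, dim=-1)
        x = self.norm(x) * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1)
        return x, gate_msa, shift_mlp, scale_mlp, gate_mlp


block = AdaLNZeroSketch(dim=8)
x, t = torch.randn(2, 5, 8), torch.randn(2, 8)
out, *_ = block(x, t)
print(out.shape)  # torch.Size([2, 5, 8])
```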
f5_tts/model/trainer.py CHANGED
@@ -32,7 +32,7 @@ class Trainer:
32
  save_per_updates=1000,
33
  keep_last_n_checkpoints: int = -1, # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints
34
  checkpoint_path=None,
35
- batch_size_per_gpu=32,
36
  batch_size_type: str = "sample",
37
  max_samples=32,
38
  grad_accumulation_steps=1,
@@ -40,7 +40,7 @@ class Trainer:
40
  noise_scheduler: str | None = None,
41
  duration_predictor: torch.nn.Module | None = None,
42
  logger: str | None = "wandb", # "wandb" | "tensorboard" | None
43
- wandb_project="test_f5-tts",
44
  wandb_run_name="test_run",
45
  wandb_resume_id: str = None,
46
  log_samples: bool = False,
@@ -51,7 +51,6 @@ class Trainer:
51
  mel_spec_type: str = "vocos", # "vocos" | "bigvgan"
52
  is_local_vocoder: bool = False, # use local path vocoder
53
  local_vocoder_path: str = "", # local vocoder path
54
- cfg_dict: dict = dict(), # training config
55
  ):
56
  ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
57
 
@@ -73,23 +72,21 @@ class Trainer:
73
  else:
74
  init_kwargs = {"wandb": {"resume": "allow", "name": wandb_run_name}}
75
 
76
- if not cfg_dict:
77
- cfg_dict = {
 
 
78
  "epochs": epochs,
79
  "learning_rate": learning_rate,
80
  "num_warmup_updates": num_warmup_updates,
81
- "batch_size_per_gpu": batch_size_per_gpu,
82
  "batch_size_type": batch_size_type,
83
  "max_samples": max_samples,
84
  "grad_accumulation_steps": grad_accumulation_steps,
85
  "max_grad_norm": max_grad_norm,
 
86
  "noise_scheduler": noise_scheduler,
87
- }
88
- cfg_dict["gpus"] = self.accelerator.num_processes
89
- self.accelerator.init_trackers(
90
- project_name=wandb_project,
91
- init_kwargs=init_kwargs,
92
- config=cfg_dict,
93
  )
94
 
95
  elif self.logger == "tensorboard":
@@ -114,9 +111,9 @@ class Trainer:
114
  self.save_per_updates = save_per_updates
115
  self.keep_last_n_checkpoints = keep_last_n_checkpoints
116
  self.last_per_updates = default(last_per_updates, save_per_updates)
117
- self.checkpoint_path = default(checkpoint_path, "ckpts/test_f5-tts")
118
 
119
- self.batch_size_per_gpu = batch_size_per_gpu
120
  self.batch_size_type = batch_size_type
121
  self.max_samples = max_samples
122
  self.grad_accumulation_steps = grad_accumulation_steps
@@ -182,7 +179,7 @@ class Trainer:
182
  if (
183
  not exists(self.checkpoint_path)
184
  or not os.path.exists(self.checkpoint_path)
185
- or not any(filename.endswith((".pt", ".safetensors")) for filename in os.listdir(self.checkpoint_path))
186
  ):
187
  return 0
188
 
@@ -194,7 +191,7 @@ class Trainer:
194
  all_checkpoints = [
195
  f
196
  for f in os.listdir(self.checkpoint_path)
197
- if (f.startswith("model_") or f.startswith("pretrained_")) and f.endswith((".pt", ".safetensors"))
198
  ]
199
 
200
  # First try to find regular training checkpoints
@@ -208,16 +205,8 @@ class Trainer:
208
  # If no training checkpoints, use pretrained model
209
  latest_checkpoint = next(f for f in all_checkpoints if f.startswith("pretrained_"))
210
 
211
- if latest_checkpoint.endswith(".safetensors"): # always a pretrained checkpoint
212
- from safetensors.torch import load_file
213
-
214
- checkpoint = load_file(f"{self.checkpoint_path}/{latest_checkpoint}", device="cpu")
215
- checkpoint = {"ema_model_state_dict": checkpoint}
216
- elif latest_checkpoint.endswith(".pt"):
217
- # checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", map_location=self.accelerator.device) # rather use accelerator.load_state ಥ_ಥ
218
- checkpoint = torch.load(
219
- f"{self.checkpoint_path}/{latest_checkpoint}", weights_only=True, map_location="cpu"
220
- )
221
 
222
  # patch for backward compatibility, 305e3ea
223
  for key in ["ema_model.mel_spec.mel_stft.mel_scale.fb", "ema_model.mel_spec.mel_stft.spectrogram.window"]:
@@ -282,7 +271,7 @@ class Trainer:
282
  num_workers=num_workers,
283
  pin_memory=True,
284
  persistent_workers=True,
285
- batch_size=self.batch_size_per_gpu,
286
  shuffle=True,
287
  generator=generator,
288
  )
@@ -291,10 +280,10 @@ class Trainer:
291
  sampler = SequentialSampler(train_dataset)
292
  batch_sampler = DynamicBatchSampler(
293
  sampler,
294
- self.batch_size_per_gpu,
295
  max_samples=self.max_samples,
296
  random_seed=resumable_with_seed, # This enables reproducible shuffling
297
- drop_residual=False,
298
  )
299
  train_dataloader = DataLoader(
300
  train_dataset,
 
32
  save_per_updates=1000,
33
  keep_last_n_checkpoints: int = -1, # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints
34
  checkpoint_path=None,
35
+ batch_size=32,
36
  batch_size_type: str = "sample",
37
  max_samples=32,
38
  grad_accumulation_steps=1,
 
40
  noise_scheduler: str | None = None,
41
  duration_predictor: torch.nn.Module | None = None,
42
  logger: str | None = "wandb", # "wandb" | "tensorboard" | None
43
+ wandb_project="test_e2-tts",
44
  wandb_run_name="test_run",
45
  wandb_resume_id: str = None,
46
  log_samples: bool = False,
 
51
  mel_spec_type: str = "vocos", # "vocos" | "bigvgan"
52
  is_local_vocoder: bool = False, # use local path vocoder
53
  local_vocoder_path: str = "", # local vocoder path
 
54
  ):
55
  ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
56
 
 
72
  else:
73
  init_kwargs = {"wandb": {"resume": "allow", "name": wandb_run_name}}
74
 
75
+ self.accelerator.init_trackers(
76
+ project_name=wandb_project,
77
+ init_kwargs=init_kwargs,
78
+ config={
79
  "epochs": epochs,
80
  "learning_rate": learning_rate,
81
  "num_warmup_updates": num_warmup_updates,
82
+ "batch_size": batch_size,
83
  "batch_size_type": batch_size_type,
84
  "max_samples": max_samples,
85
  "grad_accumulation_steps": grad_accumulation_steps,
86
  "max_grad_norm": max_grad_norm,
87
+ "gpus": self.accelerator.num_processes,
88
  "noise_scheduler": noise_scheduler,
89
+ },
 
 
 
 
 
90
  )
91
 
92
  elif self.logger == "tensorboard":
 
111
  self.save_per_updates = save_per_updates
112
  self.keep_last_n_checkpoints = keep_last_n_checkpoints
113
  self.last_per_updates = default(last_per_updates, save_per_updates)
114
+ self.checkpoint_path = default(checkpoint_path, "ckpts/test_e2-tts")
115
 
116
+ self.batch_size = batch_size
117
  self.batch_size_type = batch_size_type
118
  self.max_samples = max_samples
119
  self.grad_accumulation_steps = grad_accumulation_steps
 
179
  if (
180
  not exists(self.checkpoint_path)
181
  or not os.path.exists(self.checkpoint_path)
182
+ or not any(filename.endswith(".pt") for filename in os.listdir(self.checkpoint_path))
183
  ):
184
  return 0
185
 
 
191
  all_checkpoints = [
192
  f
193
  for f in os.listdir(self.checkpoint_path)
194
+ if (f.startswith("model_") or f.startswith("pretrained_")) and f.endswith(".pt")
195
  ]
196
 
197
  # First try to find regular training checkpoints
 
205
  # If no training checkpoints, use pretrained model
206
  latest_checkpoint = next(f for f in all_checkpoints if f.startswith("pretrained_"))
207
 
208
+ # checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", map_location=self.accelerator.device) # rather use accelerator.load_state ಥ_ಥ
209
+ checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", weights_only=True, map_location="cpu")
 
 
 
 
 
 
 
 
210
 
211
  # patch for backward compatibility, 305e3ea
212
  for key in ["ema_model.mel_spec.mel_stft.mel_scale.fb", "ema_model.mel_spec.mel_stft.spectrogram.window"]:
 
271
  num_workers=num_workers,
272
  pin_memory=True,
273
  persistent_workers=True,
274
+ batch_size=self.batch_size,
275
  shuffle=True,
276
  generator=generator,
277
  )
 
280
  sampler = SequentialSampler(train_dataset)
281
  batch_sampler = DynamicBatchSampler(
282
  sampler,
283
+ self.batch_size,
284
  max_samples=self.max_samples,
285
  random_seed=resumable_with_seed, # This enables reproducible shuffling
286
+ drop_last=False,
287
  )
288
  train_dataloader = DataLoader(
289
  train_dataset,
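The resume logic touched above filters for `model_*`/`pretrained_*` files ending in `.pt`. A standalone sketch of that selection follows; the max-by-update-count step is an assumption based on the surrounding code, and `model_last.pt`-style files are simply ignored here:

```python
import os
import re


def pick_latest_checkpoint(checkpoint_path: str):
    # Prefer the numbered model_<updates>.pt with the largest update count,
    # otherwise fall back to a pretrained_*.pt file (None when nothing exists).
    if not os.path.exists(checkpoint_path):
        return None
    candidates = [
        f
        for f in os.listdir(checkpoint_path)
        if (f.startswith("model_") or f.startswith("pretrained_")) and f.endswith(".pt")
    ]

    def updates(name):
        m = re.search(r"model_(\d+)\.pt$", name)
        return int(m.group(1)) if m else -1

    numbered = [f for f in candidates if updates(f) >= 0]
    if numbered:
        return max(numbered, key=updates)
    return next((f for f in candidates if f.startswith("pretrained_")), None)


print(pick_latest_checkpoint("ckpts/test_e2-tts"))  # None when nothing has been saved yet
```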
f5_tts/model/utils.py CHANGED
@@ -109,7 +109,7 @@ def get_tokenizer(dataset_name, tokenizer: str = "pinyin"):
109
  - if use "byte", set to 256 (unicode byte range)
110
  """
111
  if tokenizer in ["pinyin", "char"]:
112
- tokenizer_path = os.path.join(files("f5_tts").joinpath("../../data"), f"{dataset_name}/vocab.txt")
113
  with open(tokenizer_path, "r", encoding="utf-8") as f:
114
  vocab_char_map = {}
115
  for i, char in enumerate(f):
@@ -133,12 +133,22 @@ def get_tokenizer(dataset_name, tokenizer: str = "pinyin"):
133
 
134
  # convert char to pinyin
135
 
 
 
136
 
137
- def convert_char_to_pinyin(text_list, polyphone=True):
138
- if jieba.dt.initialized is False:
139
- jieba.default_logger.setLevel(50) # CRITICAL
140
- jieba.initialize()
 
 
 
 
 
 
 
141
 
 
142
  final_text_list = []
143
  custom_trans = str.maketrans(
144
  {";": ",", "“": '"', "”": '"', "‘": "'", "’": "'"}
@@ -174,13 +184,11 @@ def convert_char_to_pinyin(text_list, polyphone=True):
174
  else:
175
  char_list.append(c)
176
  final_text_list.append(char_list)
177
-
178
  return final_text_list
179
 
180
-
181
  # filter func for dirty data with many repetitions
182
 
183
-
184
  def repetition_found(text, length=2, tolerance=10):
185
  pattern_count = defaultdict(int)
186
  for i in range(len(text) - length + 1):
 
109
  - if use "byte", set to 256 (unicode byte range)
110
  """
111
  if tokenizer in ["pinyin", "char"]:
112
+ tokenizer_path = os.path.join(files("f5_tts").joinpath("../../data"), f"{dataset_name}_{tokenizer}/vocab.txt")
113
  with open(tokenizer_path, "r", encoding="utf-8") as f:
114
  vocab_char_map = {}
115
  for i, char in enumerate(f):
 
133
 
134
  # convert char to pinyin
135
 
136
+ jieba.initialize()
137
+ print("Word segmentation module jieba initialized.\n")
138
 
139
+ # def convert_char_to_pinyin(text_list, polyphone=True):
140
+ # final_text_list = []
141
+ # for text in text_list:
142
+ # char_list = [char for char in text if char not in "。,、;:?!《》【】—…:;\"()[]{}"]
143
+ # final_text_list.append(char_list)
144
+ # # print(final_text_list)
145
+ # return final_text_list
146
+
147
+ # def convert_char_to_pinyin(text_list, polyphone=True):
148
+ # final_text_list = [char for char in text_list if char not in "。,、;:?!《》【】—…:;?!\"()[]{}"]
149
+ # return final_text_list
150
 
151
+ def convert_char_to_pinyin(text_list, polyphone=True):
152
  final_text_list = []
153
  custom_trans = str.maketrans(
154
  {";": ",", "“": '"', "”": '"', "‘": "'", "’": "'"}
 
184
  else:
185
  char_list.append(c)
186
  final_text_list.append(char_list)
187
+ # print(final_text_list)
188
  return final_text_list
189
 
 
190
  # filter func for dirty data with many repetitions
191
 
 
192
  def repetition_found(text, length=2, tolerance=10):
193
  pattern_count = defaultdict(int)
194
  for i in range(len(text) - length + 1):
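The repetition filter is only partially visible in this hunk; here is a self-contained sketch of the same idea (the body past the loop header is reconstructed, so treat the details as an assumption):

```python
from collections import defaultdict


def repetition_found(text: str, length: int = 2, tolerance: int = 10) -> bool:
    # Flag dirty samples where any substring of `length` characters repeats
    # more than `tolerance` times.
    pattern_count = defaultdict(int)
    for i in range(len(text) - length + 1):
        pattern_count[text[i : i + length]] += 1
    return any(count > tolerance for count in pattern_count.values())


print(repetition_found("ha" * 30))      # True  -> likely dirty transcript
print(repetition_found("hello world"))  # False
```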
f5_tts/scripts/count_max_epoch.py CHANGED
@@ -9,7 +9,7 @@ mel_hop_length = 256
9
  mel_sampling_rate = 24000
10
 
11
  # target
12
- wanted_max_updates = 1200000
13
 
14
  # train params
15
  gpus = 8
 
9
  mel_sampling_rate = 24000
10
 
11
  # target
12
+ wanted_max_updates = 1000000
13
 
14
  # train params
15
  gpus = 8
f5_tts/socket_client.py DELETED
@@ -1,61 +0,0 @@
1
- import socket
2
- import asyncio
3
- import pyaudio
4
- import numpy as np
5
- import logging
6
- import time
7
-
8
- logging.basicConfig(level=logging.INFO)
9
- logger = logging.getLogger(__name__)
10
-
11
-
12
- async def listen_to_F5TTS(text, server_ip="localhost", server_port=9998):
13
- client_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
14
- await asyncio.get_event_loop().run_in_executor(None, client_socket.connect, (server_ip, int(server_port)))
15
-
16
- start_time = time.time()
17
- first_chunk_time = None
18
-
19
- async def play_audio_stream():
20
- nonlocal first_chunk_time
21
- p = pyaudio.PyAudio()
22
- stream = p.open(format=pyaudio.paFloat32, channels=1, rate=24000, output=True, frames_per_buffer=2048)
23
-
24
- try:
25
- while True:
26
- data = await asyncio.get_event_loop().run_in_executor(None, client_socket.recv, 8192)
27
- if not data:
28
- break
29
- if data == b"END":
30
- logger.info("End of audio received.")
31
- break
32
-
33
- audio_array = np.frombuffer(data, dtype=np.float32)
34
- stream.write(audio_array.tobytes())
35
-
36
- if first_chunk_time is None:
37
- first_chunk_time = time.time()
38
-
39
- finally:
40
- stream.stop_stream()
41
- stream.close()
42
- p.terminate()
43
-
44
- logger.info(f"Total time taken: {time.time() - start_time:.4f} seconds")
45
-
46
- try:
47
- data_to_send = f"{text}".encode("utf-8")
48
- await asyncio.get_event_loop().run_in_executor(None, client_socket.sendall, data_to_send)
49
- await play_audio_stream()
50
-
51
- except Exception as e:
52
- logger.error(f"Error in listen_to_F5TTS: {e}")
53
-
54
- finally:
55
- client_socket.close()
56
-
57
-
58
- if __name__ == "__main__":
59
- text_to_send = "As a Reader assistant, I'm familiar with new technology. which are key to its improved performance in terms of both training speed and inference efficiency. Let's break down the components"
60
-
61
- asyncio.run(listen_to_F5TTS(text_to_send))
f5_tts/socket_server.py CHANGED
@@ -1,75 +1,21 @@
1
  import argparse
2
  import gc
3
- import logging
4
- import numpy as np
5
- import queue
6
  import socket
7
  import struct
8
- import threading
 
9
  import traceback
10
- import wave
11
  from importlib.resources import files
 
12
 
13
- import torch
14
- import torchaudio
15
- from huggingface_hub import hf_hub_download
16
- from omegaconf import OmegaConf
17
-
18
- from f5_tts.model.backbones.dit import DiT # noqa: F401. used for config
19
- from f5_tts.infer.utils_infer import (
20
- chunk_text,
21
- preprocess_ref_audio_text,
22
- load_vocoder,
23
- load_model,
24
- infer_batch_process,
25
- )
26
-
27
- logging.basicConfig(level=logging.INFO)
28
- logger = logging.getLogger(__name__)
29
-
30
-
31
- class AudioFileWriterThread(threading.Thread):
32
- """Threaded file writer to avoid blocking the TTS streaming process."""
33
-
34
- def __init__(self, output_file, sampling_rate):
35
- super().__init__()
36
- self.output_file = output_file
37
- self.sampling_rate = sampling_rate
38
- self.queue = queue.Queue()
39
- self.stop_event = threading.Event()
40
- self.audio_data = []
41
-
42
- def run(self):
43
- """Process queued audio data and write it to a file."""
44
- logger.info("AudioFileWriterThread started.")
45
- with wave.open(self.output_file, "wb") as wf:
46
- wf.setnchannels(1)
47
- wf.setsampwidth(2)
48
- wf.setframerate(self.sampling_rate)
49
-
50
- while not self.stop_event.is_set() or not self.queue.empty():
51
- try:
52
- chunk = self.queue.get(timeout=0.1)
53
- if chunk is not None:
54
- chunk = np.int16(chunk * 32767)
55
- self.audio_data.append(chunk)
56
- wf.writeframes(chunk.tobytes())
57
- except queue.Empty:
58
- continue
59
-
60
- def add_chunk(self, chunk):
61
- """Add a new chunk to the queue."""
62
- self.queue.put(chunk)
63
-
64
- def stop(self):
65
- """Stop writing and ensure all queued data is written."""
66
- self.stop_event.set()
67
- self.join()
68
- logger.info("Audio writing completed.")
69
 
70
 
71
  class TTSStreamingProcessor:
72
- def __init__(self, model, ckpt_file, vocab_file, ref_audio, ref_text, device=None, dtype=torch.float32):
73
  self.device = device or (
74
  "cuda"
75
  if torch.cuda.is_available()
@@ -79,135 +25,124 @@ class TTSStreamingProcessor:
79
  if torch.backends.mps.is_available()
80
  else "cpu"
81
  )
82
- model_cfg = OmegaConf.load(str(files("f5_tts").joinpath(f"configs/{model}.yaml")))
83
- self.model_cls = globals()[model_cfg.model.backbone]
84
- self.model_arc = model_cfg.model.arch
85
- self.mel_spec_type = model_cfg.model.mel_spec.mel_spec_type
86
- self.sampling_rate = model_cfg.model.mel_spec.target_sample_rate
87
-
88
- self.model = self.load_ema_model(ckpt_file, vocab_file, dtype)
89
- self.vocoder = self.load_vocoder_model()
90
-
91
- self.update_reference(ref_audio, ref_text)
92
- self._warm_up()
93
- self.file_writer_thread = None
94
- self.first_package = True
95
 
96
- def load_ema_model(self, ckpt_file, vocab_file, dtype):
97
- return load_model(
98
- self.model_cls,
99
- self.model_arc,
100
  ckpt_path=ckpt_file,
101
- mel_spec_type=self.mel_spec_type,
102
  vocab_file=vocab_file,
103
  ode_method="euler",
104
  use_ema=True,
105
  device=self.device,
106
  ).to(self.device, dtype=dtype)
107
 
108
- def load_vocoder_model(self):
109
- return load_vocoder(vocoder_name=self.mel_spec_type, is_local=False, local_path=None, device=self.device)
110
 
111
- def update_reference(self, ref_audio, ref_text):
112
- self.ref_audio, self.ref_text = preprocess_ref_audio_text(ref_audio, ref_text)
113
- self.audio, self.sr = torchaudio.load(self.ref_audio)
114
 
115
- ref_audio_duration = self.audio.shape[-1] / self.sr
116
- ref_text_byte_len = len(self.ref_text.encode("utf-8"))
117
- self.max_chars = int(ref_text_byte_len / (ref_audio_duration) * (25 - ref_audio_duration))
118
- self.few_chars = int(ref_text_byte_len / (ref_audio_duration) * (25 - ref_audio_duration) / 2)
119
- self.min_chars = int(ref_text_byte_len / (ref_audio_duration) * (25 - ref_audio_duration) / 4)
 
120
 
121
  def _warm_up(self):
122
- logger.info("Warming up the model...")
 
 
 
123
  gen_text = "Warm-up text for the model."
124
- for _ in infer_batch_process(
125
- (self.audio, self.sr),
126
- self.ref_text,
127
- [gen_text],
128
- self.model,
129
- self.vocoder,
130
- progress=None,
131
- device=self.device,
132
- streaming=True,
133
- ):
134
- pass
135
- logger.info("Warm-up completed.")
136
-
137
- def generate_stream(self, text, conn):
138
- text_batches = chunk_text(text, max_chars=self.max_chars)
139
- if self.first_package:
140
- text_batches = chunk_text(text_batches[0], max_chars=self.few_chars) + text_batches[1:]
141
- text_batches = chunk_text(text_batches[0], max_chars=self.min_chars) + text_batches[1:]
142
- self.first_package = False
143
-
144
- audio_stream = infer_batch_process(
145
- (self.audio, self.sr),
146
- self.ref_text,
147
- text_batches,
148
  self.model,
149
  self.vocoder,
150
- progress=None,
151
- device=self.device,
152
- streaming=True,
153
- chunk_size=2048,
154
  )
155
 
156
- # Reset the file writer thread
157
- if self.file_writer_thread is not None:
158
- self.file_writer_thread.stop()
159
- self.file_writer_thread = AudioFileWriterThread("output.wav", self.sampling_rate)
160
- self.file_writer_thread.start()
161
-
162
- for audio_chunk, _ in audio_stream:
163
- if len(audio_chunk) > 0:
164
- logger.info(f"Generated audio chunk of size: {len(audio_chunk)}")
165
 
166
- # Send audio chunk via socket
167
- conn.sendall(struct.pack(f"{len(audio_chunk)}f", *audio_chunk))
 
 
168
 
169
- # Write to file asynchronously
170
- self.file_writer_thread.add_chunk(audio_chunk)
171
 
172
- logger.info("Finished sending audio stream.")
173
- conn.sendall(b"END") # Send end signal
 
174
 
175
- # Ensure all audio data is written before exiting
176
- self.file_writer_thread.stop()
 
 
177
 
178
 
179
- def handle_client(conn, processor):
180
  try:
181
- with conn:
182
- conn.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
183
- while True:
184
- data = conn.recv(1024)
185
- if not data:
186
- processor.first_package = True
187
- break
188
- data_str = data.decode("utf-8").strip()
189
- logger.info(f"Received text: {data_str}")
190
-
191
- try:
192
- processor.generate_stream(data_str, conn)
193
- except Exception as inner_e:
194
- logger.error(f"Error during processing: {inner_e}")
195
- traceback.print_exc()
196
- break
 
 
 
 
 
 
197
  except Exception as e:
198
- logger.error(f"Error handling client: {e}")
199
  traceback.print_exc()
 
 
200
 
201
 
202
  def start_server(host, port, processor):
203
- with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
204
- s.bind((host, port))
205
- s.listen()
206
- logger.info(f"Server started on {host}:{port}")
207
- while True:
208
- conn, addr = s.accept()
209
- logger.info(f"Connected by {addr}")
210
- handle_client(conn, processor)
 
 
211
 
212
 
213
  if __name__ == "__main__":
@@ -216,14 +151,9 @@ if __name__ == "__main__":
216
  parser.add_argument("--host", default="0.0.0.0")
217
  parser.add_argument("--port", default=9998)
218
 
219
- parser.add_argument(
220
- "--model",
221
- default="F5TTS_v1_Base",
222
- help="The model name, e.g. F5TTS_v1_Base",
223
- )
224
  parser.add_argument(
225
  "--ckpt_file",
226
- default=str(hf_hub_download(repo_id="SWivid/F5-TTS", filename="F5TTS_v1_Base/model_1250000.safetensors")),
227
  help="Path to the model checkpoint file",
228
  )
229
  parser.add_argument(
@@ -251,7 +181,6 @@ if __name__ == "__main__":
251
  try:
252
  # Initialize the processor with the model and vocoder
253
  processor = TTSStreamingProcessor(
254
- model=args.model,
255
  ckpt_file=args.ckpt_file,
256
  vocab_file=args.vocab_file,
257
  ref_audio=args.ref_audio,
 
1
  import argparse
2
  import gc
 
 
 
3
  import socket
4
  import struct
5
+ import torch
6
+ import torchaudio
7
  import traceback
 
8
  from importlib.resources import files
9
+ from threading import Thread
10
 
11
+ from cached_path import cached_path
12
+
13
+ from infer.utils_infer import infer_batch_process, preprocess_ref_audio_text, load_vocoder, load_model
14
+ from model.backbones.dit import DiT
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
 
17
  class TTSStreamingProcessor:
18
+ def __init__(self, ckpt_file, vocab_file, ref_audio, ref_text, device=None, dtype=torch.float32):
19
  self.device = device or (
20
  "cuda"
21
  if torch.cuda.is_available()
 
25
  if torch.backends.mps.is_available()
26
  else "cpu"
27
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
+ # Load the model using the provided checkpoint and vocab files
30
+ self.model = load_model(
31
+ model_cls=DiT,
32
+ model_cfg=dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4),
33
  ckpt_path=ckpt_file,
34
+ mel_spec_type="vocos", # or "bigvgan" depending on vocoder
35
  vocab_file=vocab_file,
36
  ode_method="euler",
37
  use_ema=True,
38
  device=self.device,
39
  ).to(self.device, dtype=dtype)
40
 
41
+ # Load the vocoder
42
+ self.vocoder = load_vocoder(is_local=False)
43
 
44
+ # Set sampling rate for streaming
45
+ self.sampling_rate = 24000 # Consistency with client
 
46
 
47
+ # Set reference audio and text
48
+ self.ref_audio = ref_audio
49
+ self.ref_text = ref_text
50
+
51
+ # Warm up the model
52
+ self._warm_up()
53
 
54
  def _warm_up(self):
55
+ """Warm up the model with a dummy input to ensure it's ready for real-time processing."""
56
+ print("Warming up the model...")
57
+ ref_audio, ref_text = preprocess_ref_audio_text(self.ref_audio, self.ref_text)
58
+ audio, sr = torchaudio.load(ref_audio)
59
  gen_text = "Warm-up text for the model."
60
+
61
+ # Pass the vocoder as an argument here
62
+ infer_batch_process((audio, sr), ref_text, [gen_text], self.model, self.vocoder, device=self.device)
63
+ print("Warm-up completed.")
64
+
65
+ def generate_stream(self, text, play_steps_in_s=0.5):
66
+ """Generate audio in chunks and yield them in real-time."""
67
+ # Preprocess the reference audio and text
68
+ ref_audio, ref_text = preprocess_ref_audio_text(self.ref_audio, self.ref_text)
69
+
70
+ # Load reference audio
71
+ audio, sr = torchaudio.load(ref_audio)
72
+
73
+ # Run inference for the input text
74
+ audio_chunk, final_sample_rate, _ = infer_batch_process(
75
+ (audio, sr),
76
+ ref_text,
77
+ [text],
 
 
 
 
 
 
78
  self.model,
79
  self.vocoder,
80
+ device=self.device, # Pass vocoder here
 
 
 
81
  )
82
 
83
+ # Break the generated audio into chunks and send them
84
+ chunk_size = int(final_sample_rate * play_steps_in_s)
 
 
 
 
 
 
 
85
 
86
+ if len(audio_chunk) < chunk_size:
87
+ packed_audio = struct.pack(f"{len(audio_chunk)}f", *audio_chunk)
88
+ yield packed_audio
89
+ return
90
 
91
+ for i in range(0, len(audio_chunk), chunk_size):
92
+ chunk = audio_chunk[i : i + chunk_size]
93
 
94
+ # Check if it's the final chunk
95
+ if i + chunk_size >= len(audio_chunk):
96
+ chunk = audio_chunk[i:]
97
 
98
+ # Send the chunk if it is not empty
99
+ if len(chunk) > 0:
100
+ packed_audio = struct.pack(f"{len(chunk)}f", *chunk)
101
+ yield packed_audio
102
 
103
 
104
+ def handle_client(client_socket, processor):
105
  try:
106
+ while True:
107
+ # Receive data from the client
108
+ data = client_socket.recv(1024).decode("utf-8")
109
+ if not data:
110
+ break
111
+
112
+ try:
113
+ # The client sends the text input
114
+ text = data.strip()
115
+
116
+ # Generate and stream audio chunks
117
+ for audio_chunk in processor.generate_stream(text):
118
+ client_socket.sendall(audio_chunk)
119
+
120
+ # Send end-of-audio signal
121
+ client_socket.sendall(b"END_OF_AUDIO")
122
+
123
+ except Exception as inner_e:
124
+ print(f"Error during processing: {inner_e}")
125
+ traceback.print_exc() # Print the full traceback to diagnose the issue
126
+ break
127
+
128
  except Exception as e:
129
+ print(f"Error handling client: {e}")
130
  traceback.print_exc()
131
+ finally:
132
+ client_socket.close()
133
 
134
 
135
  def start_server(host, port, processor):
136
+ server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
137
+ server.bind((host, port))
138
+ server.listen(5)
139
+ print(f"Server listening on {host}:{port}")
140
+
141
+ while True:
142
+ client_socket, addr = server.accept()
143
+ print(f"Accepted connection from {addr}")
144
+ client_handler = Thread(target=handle_client, args=(client_socket, processor))
145
+ client_handler.start()
146
 
147
 
148
  if __name__ == "__main__":
 
151
  parser.add_argument("--host", default="0.0.0.0")
152
  parser.add_argument("--port", default=9998)
153
 
 
 
 
 
 
154
  parser.add_argument(
155
  "--ckpt_file",
156
+ default=str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors")),
157
  help="Path to the model checkpoint file",
158
  )
159
  parser.add_argument(
 
181
  try:
182
  # Initialize the processor with the model and vocoder
183
  processor = TTSStreamingProcessor(
 
184
  ckpt_file=args.ckpt_file,
185
  vocab_file=args.vocab_file,
186
  ref_audio=args.ref_audio,
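Since `socket_client.py` is deleted in this commit, a rough client sketch for the new server protocol (float32-packed chunks terminated by `END_OF_AUDIO`) may be useful; it assumes the sentinel arrives at the end of a `recv` call, which a production client should not rely on:

```python
import socket

import numpy as np


def request_tts(text: str, host: str = "localhost", port: int = 9998) -> np.ndarray:
    # Send text to the streaming server above and collect float32 audio chunks
    # until the END_OF_AUDIO sentinel; returns the 24 kHz waveform as a numpy array.
    sentinel = b"END_OF_AUDIO"
    buf = bytearray()
    with socket.create_connection((host, port)) as s:
        s.sendall(text.encode("utf-8"))
        while True:
            data = s.recv(8192)
            if not data:
                break
            if data.endswith(sentinel):
                buf.extend(data[: -len(sentinel)])
                break
            buf.extend(data)
    return np.frombuffer(bytes(buf), dtype=np.float32)


if __name__ == "__main__":
    audio = request_tts("Hello from the streaming client.")
    print(audio.shape, audio.dtype)
```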
f5_tts/train/README.md CHANGED
@@ -40,10 +40,10 @@ Once your datasets are prepared, you can start the training process.
40
  accelerate config
41
 
42
  # .yaml files are under src/f5_tts/configs directory
43
- accelerate launch src/f5_tts/train/train.py --config-name F5TTS_v1_Base.yaml
44
 
45
  # possible to overwrite accelerate and hydra config
46
- accelerate launch --mixed_precision=fp16 src/f5_tts/train/train.py --config-name F5TTS_v1_Base.yaml ++datasets.batch_size_per_gpu=19200
47
  ```
48
 
49
  ### 2. Finetuning practice
@@ -53,7 +53,7 @@ Gradio UI training/finetuning with `src/f5_tts/train/finetune_gradio.py` see [#1
53
 
54
  Setting `use_ema = True` is harmful for early-stage finetuned checkpoints (which have only gone through a few updates, so the EMA weights are still dominated by the pretrained ones); try turning it off and see if it gives better results.
55
 
56
- ### 3. W&B Logging
57
 
58
  The `wandb/` dir will be created under the path where you run the training/finetuning scripts.
59
 
@@ -62,7 +62,7 @@ By default, the training script does NOT use logging (assuming you didn't manual
62
  To turn on wandb logging, you can either:
63
 
64
  1. Manually login with `wandb login`: Learn more [here](https://docs.wandb.ai/ref/cli/wandb-login)
65
- 2. Automatically login programmatically by setting an environment variable: Get an API KEY at https://wandb.ai/authorize and set the environment variable as follows:
66
 
67
  On Mac & Linux:
68
 
@@ -75,7 +75,7 @@ On Windows:
75
  ```
76
  set WANDB_API_KEY=<YOUR WANDB API KEY>
77
  ```
78
- Moreover, if you couldn't access W&B and want to log metrics offline, you can set the environment variable as follows:
79
 
80
  ```
81
  export WANDB_MODE=offline
 
40
  accelerate config
41
 
42
  # .yaml files are under src/f5_tts/configs directory
43
+ accelerate launch src/f5_tts/train/train.py --config-name F5TTS_Base_train.yaml
44
 
45
  # possible to overwrite accelerate and hydra config
46
+ accelerate launch --mixed_precision=fp16 src/f5_tts/train/train.py --config-name F5TTS_Small_train.yaml ++datasets.batch_size_per_gpu=19200
47
  ```
48
 
49
  ### 2. Finetuning practice
 
53
 
54
  Setting `use_ema = True` is harmful for early-stage finetuned checkpoints (which have only gone through a few updates, so the EMA weights are still dominated by the pretrained ones); try turning it off and see if it gives better results.
55
 
56
+ ### 3. Wandb Logging
57
 
58
  The `wandb/` dir will be created under the path where you run the training/finetuning scripts.
59
 
 
62
  To turn on wandb logging, you can either:
63
 
64
  1. Manually login with `wandb login`: Learn more [here](https://docs.wandb.ai/ref/cli/wandb-login)
65
+ 2. Automatically login programmatically by setting an environment variable: Get an API KEY at https://wandb.ai/site/ and set the environment variable as follows:
66
 
67
  On Mac & Linux:
68
 
 
75
  ```
76
  set WANDB_API_KEY=<YOUR WANDB API KEY>
77
  ```
78
+ Moreover, if you couldn't access Wandb and want to log metrics offline, you can set the environment variable as follows:
79
 
80
  ```
81
  export WANDB_MODE=offline
f5_tts/train/__pycache__/finetune_gradio.cpython-310.pyc CHANGED
Binary files a/f5_tts/train/__pycache__/finetune_gradio.cpython-310.pyc and b/f5_tts/train/__pycache__/finetune_gradio.cpython-310.pyc differ
 
f5_tts/train/datasets/prepare_csv_wavs.py CHANGED
@@ -24,7 +24,7 @@ from f5_tts.model.utils import (
24
  )
25
 
26
 
27
- PRETRAINED_VOCAB_PATH = files("f5_tts").joinpath("../../data/your_training_dataset/vocab.txt")
28
 
29
 
30
  def is_csv_wavs_format(input_dataset_dir):
@@ -224,7 +224,7 @@ def save_prepped_dataset(out_dir, result, duration_list, text_vocab_set, is_fine
224
  voca_out_path = out_dir / "vocab.txt"
225
  if is_finetune:
226
  file_vocab_finetune = PRETRAINED_VOCAB_PATH.as_posix()
227
- # shutil.copy2(file_vocab_finetune, voca_out_path) # Không cần copy lại vocab, do đã thực hiện ở bước chuẩn bị dữ liệu
228
  else:
229
  with open(voca_out_path.as_posix(), "w") as f:
230
  for vocab in sorted(text_vocab_set):
 
24
  )
25
 
26
 
27
+ PRETRAINED_VOCAB_PATH = files("f5_tts").joinpath("../../data/Emilia_ZH_EN_pinyin/vocab.txt")
28
 
29
 
30
  def is_csv_wavs_format(input_dataset_dir):
 
224
  voca_out_path = out_dir / "vocab.txt"
225
  if is_finetune:
226
  file_vocab_finetune = PRETRAINED_VOCAB_PATH.as_posix()
227
+ shutil.copy2(file_vocab_finetune, voca_out_path)
228
  else:
229
  with open(voca_out_path.as_posix(), "w") as f:
230
  for vocab in sorted(text_vocab_set):
f5_tts/train/datasets/prepare_emilia.py CHANGED
@@ -206,14 +206,14 @@ def main():
206
 
207
 
208
  if __name__ == "__main__":
209
- max_workers = 32
210
 
211
  tokenizer = "pinyin" # "pinyin" | "char"
212
  polyphone = True
213
 
214
- langs = ["ZH", "EN"]
215
- dataset_dir = "<SOME_PATH>/Emilia_Dataset/raw"
216
- dataset_name = f"Emilia_{'_'.join(langs)}_{tokenizer}"
217
  save_dir = str(files("f5_tts").joinpath("../../")) + f"/data/{dataset_name}"
218
  print(f"\nPrepare for {dataset_name}, will save to {save_dir}\n")
219
 
 
206
 
207
 
208
  if __name__ == "__main__":
209
+ max_workers = 16
210
 
211
  tokenizer = "pinyin" # "pinyin" | "char"
212
  polyphone = True
213
 
214
+ langs = ["EN"]
215
+ dataset_dir = "data/datasetVN"
216
+ dataset_name = f"vnTTS_{'_'.join(langs)}_{tokenizer}"
217
  save_dir = str(files("f5_tts").joinpath("../../")) + f"/data/{dataset_name}"
218
  print(f"\nPrepare for {dataset_name}, will save to {save_dir}\n")
219
 
f5_tts/train/datasets/prepare_libritts.py CHANGED
@@ -11,6 +11,11 @@ from tqdm import tqdm
11
  import soundfile as sf
12
  from datasets.arrow_writer import ArrowWriter
13
 
 
 
 
 
 
14
 
15
  def deal_with_audio_dir(audio_dir):
16
  sub_result, durations = [], []
@@ -18,7 +23,7 @@ def deal_with_audio_dir(audio_dir):
18
  audio_lists = list(audio_dir.rglob("*.wav"))
19
 
20
  for line in audio_lists:
21
- text_path = line.with_suffix(".normalized.txt")
22
  text = open(text_path, "r").read().strip()
23
  duration = sf.info(line).duration
24
  if duration < 0.4 or duration > 30:
@@ -76,13 +81,13 @@ def main():
76
 
77
 
78
  if __name__ == "__main__":
79
- max_workers = 36
80
 
81
  tokenizer = "char" # "pinyin" | "char"
82
 
83
- SUB_SET = ["train-clean-100", "train-clean-360", "train-other-500"]
84
- dataset_dir = "<SOME_PATH>/LibriTTS"
85
- dataset_name = f"LibriTTS_{'_'.join(SUB_SET)}_{tokenizer}".replace("train-clean-", "").replace("train-other-", "")
86
  save_dir = str(files("f5_tts").joinpath("../../")) + f"/data/{dataset_name}"
87
  print(f"\nPrepare for {dataset_name}, will save to {save_dir}\n")
88
  main()
 
11
  import soundfile as sf
12
  from datasets.arrow_writer import ArrowWriter
13
 
14
+ from f5_tts.model.utils import (
15
+ repetition_found,
16
+ convert_char_to_pinyin,
17
+ )
18
+
19
 
20
  def deal_with_audio_dir(audio_dir):
21
  sub_result, durations = [], []
 
23
  audio_lists = list(audio_dir.rglob("*.wav"))
24
 
25
  for line in audio_lists:
26
+ text_path = line.with_suffix(".lab")
27
  text = open(text_path, "r").read().strip()
28
  duration = sf.info(line).duration
29
  if duration < 0.4 or duration > 30:
 
81
 
82
 
83
  if __name__ == "__main__":
84
+ max_workers = 16
85
 
86
  tokenizer = "char" # "pinyin" | "char"
87
 
88
+ SUB_SET = ["mc"]
89
+ dataset_dir = "data/datasetVN"
90
+ dataset_name = f"vnTTS_{'_'.join(SUB_SET)}_{tokenizer}"
91
  save_dir = str(files("f5_tts").joinpath("../../")) + f"/data/{dataset_name}"
92
  print(f"\nPrepare for {dataset_name}, will save to {save_dir}\n")
93
  main()
f5_tts/train/datasets/prepare_metadata.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import glob
2
+ from tqdm import tqdm
3
+
4
+ wavs_path = glob.glob("data/datasetVN/mc/mc1/*.wav")
5
+
6
+ with open("data/vnTTS__char/metadata.csv", "w", encoding="utf8") as fw:
7
+ fw.write("audio_file|text\n")
8
+ for wav_path in tqdm(wavs_path):
9
+ wav_name = wav_path.split("/")[-1]
10
+ with open(wav_path.replace(".wav", ".lab"), "r", encoding="utf8") as fr:
11
+ text = fr.readlines()[0].replace("\n", "")
12
+ fw.write("wavs/" + wav_name + "|" + text + "\n")
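A quick, hypothetical way to sanity-check the pipe-delimited `metadata.csv` written by this new script (assumes the file has already been generated at the path above):

```python
import csv

# Preview the first few rows; the path and the "|" delimiter follow the script above.
with open("data/vnTTS__char/metadata.csv", encoding="utf8") as f:
    reader = csv.DictReader(f, delimiter="|")
    for i, row in enumerate(reader):
        print(row["audio_file"], "->", row["text"][:60])
        if i == 2:
            break
```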
f5_tts/train/finetune_cli.py CHANGED
@@ -1,13 +1,12 @@
1
  import argparse
2
  import os
3
  import shutil
4
-from importlib.resources import files
 
 from cached_path import cached_path
-
 from f5_tts.model import CFM, UNetT, DiT, Trainer
 from f5_tts.model.utils import get_tokenizer
 from f5_tts.model.dataset import load_dataset
+from importlib.resources import files
 
 
 # -------------------------- Dataset Settings --------------------------- #
@@ -21,16 +20,21 @@ mel_spec_type = "vocos" # 'vocos' or 'bigvgan'
 
 # -------------------------- Argument Parsing --------------------------- #
 def parse_args():
+    # batch_size_per_gpu = 1000 setting for an 8GB GPU
+    # batch_size_per_gpu = 1600 setting for a 12GB GPU
+    # batch_size_per_gpu = 2000 setting for a 16GB GPU
+    # batch_size_per_gpu = 3200 setting for a 24GB GPU
+
+    # num_warmup_updates = 300 for about 5000 samples (roughly 10 hours of audio)
+
+    # adjust save_per_updates and last_per_updates to whatever values you need
+
     parser = argparse.ArgumentParser(description="Train CFM Model")
 
     parser.add_argument(
-        "--exp_name",
-        type=str,
-        default="F5TTS_v1_Base",
-        choices=["F5TTS_v1_Base", "F5TTS_Base", "E2TTS_Base"],
-        help="Experiment name",
+        "--exp_name", type=str, default="F5TTS_Base", choices=["F5TTS_Base", "E2TTS_Base"], help="Experiment name"
     )
-    parser.add_argument("--dataset_name", type=str, default="Emilia_ZH_EN", help="Name of the dataset to use")
+    parser.add_argument("--dataset_name", type=str, default="vnTTS_mc", help="Name of the dataset to use")
     parser.add_argument("--learning_rate", type=float, default=1e-5, help="Learning rate for training")
     parser.add_argument("--batch_size_per_gpu", type=int, default=3200, help="Batch size per GPU")
     parser.add_argument(
@@ -39,7 +43,7 @@ def parse_args():
     parser.add_argument("--max_samples", type=int, default=64, help="Max sequences per batch")
     parser.add_argument("--grad_accumulation_steps", type=int, default=1, help="Gradient accumulation steps")
     parser.add_argument("--max_grad_norm", type=float, default=1.0, help="Max gradient norm for clipping")
-    parser.add_argument("--epochs", type=int, default=1000, help="Number of training epochs")
+    parser.add_argument("--epochs", type=int, default=100, help="Number of training epochs")
     parser.add_argument("--num_warmup_updates", type=int, default=300, help="Warmup updates")
     parser.add_argument("--save_per_updates", type=int, default=10000, help="Save checkpoint every X updates")
     parser.add_argument(
@@ -50,7 +54,7 @@ def parse_args():
     )
     parser.add_argument("--last_per_updates", type=int, default=50000, help="Save last checkpoint every X updates")
     parser.add_argument("--finetune", action="store_true", help="Use Finetune")
-    parser.add_argument("--pretrain", type=str, default=None, help="the path to the checkpoint")
+    parser.add_argument("--pretrain", type=str, default="/mnt/d/ckpts/vn_tts_mc_vlog/pretrained_model_1200000.pt", help="the path to the checkpoint")
     parser.add_argument(
         "--tokenizer", type=str, default="char", choices=["pinyin", "char", "custom"], help="Tokenizer type"
     )
@@ -84,54 +88,19 @@ def main():
     checkpoint_path = str(files("f5_tts").joinpath(f"../../ckpts/{args.dataset_name}"))
 
     # Model parameters based on experiment name
-
-    if args.exp_name == "F5TTS_v1_Base":
+    if args.exp_name == "F5TTS_Base":
         wandb_resume_id = None
         model_cls = DiT
-        model_cfg = dict(
-            dim=1024,
-            depth=22,
-            heads=16,
-            ff_mult=2,
-            text_dim=512,
-            conv_layers=4,
-        )
-        if args.finetune:
-            if args.pretrain is None:
-                ckpt_path = str(cached_path("hf://SWivid/F5-TTS/F5TTS_v1_Base/model_1250000.safetensors"))
-            else:
-                ckpt_path = args.pretrain
-
-    elif args.exp_name == "F5TTS_Base":
-        wandb_resume_id = None
-        model_cls = DiT
-        model_cfg = dict(
-            dim=1024,
-            depth=22,
-            heads=16,
-            ff_mult=2,
-            text_dim=512,
-            text_mask_padding=False,
-            conv_layers=4,
-            pe_attn_head=1,
-        )
+        model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
         if args.finetune:
             if args.pretrain is None:
                 ckpt_path = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.pt"))
             else:
                 ckpt_path = args.pretrain
-
     elif args.exp_name == "E2TTS_Base":
         wandb_resume_id = None
         model_cls = UNetT
-        model_cfg = dict(
-            dim=1024,
-            depth=24,
-            heads=16,
-            ff_mult=4,
-            text_mask_padding=False,
-            pe_attn_head=1,
-        )
+        model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
         if args.finetune:
             if args.pretrain is None:
                 ckpt_path = str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.pt"))
@@ -149,10 +118,8 @@ def main():
         if not os.path.isfile(file_checkpoint):
             shutil.copy2(ckpt_path, file_checkpoint)
             print("copy checkpoint for finetune")
-            print("Pretrained checkpoint used: " + file_checkpoint)
 
     # Use the tokenizer and tokenizer_path provided in the command line arguments
-
     tokenizer = args.tokenizer
     if tokenizer == "custom":
         if not args.tokenizer_path:
@@ -163,8 +130,8 @@ def main():
 
     vocab_char_map, vocab_size = get_tokenizer(tokenizer_path, tokenizer)
 
-    print("vocab : ", vocab_size)
-    print("vocoder : ", mel_spec_type)
+    print("\nvocab : ", vocab_size)
+    print("\nvocoder : ", mel_spec_type)
 
     mel_spec_kwargs = dict(
         n_fft=n_fft,
@@ -189,7 +156,7 @@ def main():
         save_per_updates=args.save_per_updates,
         keep_last_n_checkpoints=args.keep_last_n_checkpoints,
         checkpoint_path=checkpoint_path,
-        batch_size_per_gpu=args.batch_size_per_gpu,
+        batch_size=args.batch_size_per_gpu,
         batch_size_type=args.batch_size_type,
         max_samples=args.max_samples,
         grad_accumulation_steps=args.grad_accumulation_steps,
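
The comment table added at the top of parse_args() ties --batch_size_per_gpu to available GPU memory (1000 for 8GB up to the default 3200 for 24GB). Below is a minimal sketch of that guideline as a lookup; the helper name and structure are illustrative only and are not part of the script changed in this diff — only the numeric values come from the comments above.

# Illustrative helper only -- not part of the diffed script.
# Suggested --batch_size_per_gpu values per GPU VRAM size (GB), copied
# from the comments added in parse_args() above.
SUGGESTED_BATCH_SIZE_PER_GPU = {8: 1000, 12: 1600, 16: 2000, 24: 3200}

def suggest_batch_size_per_gpu(vram_gb: int) -> int:
    """Return the largest suggested batch size that fits the given VRAM."""
    fitting = [bs for gb, bs in SUGGESTED_BATCH_SIZE_PER_GPU.items() if gb <= vram_gb]
    return max(fitting) if fitting else min(SUGGESTED_BATCH_SIZE_PER_GPU.values())

# e.g. suggest_batch_size_per_gpu(16) -> 2000; pass it as --batch_size_per_gpu 2000,
# and it reaches the trainer via batch_size=args.batch_size_per_gpu (last hunk above).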