diff --git a/.gitattributes b/.gitattributes
index 5b2c2cefa8481bbfc91b2d73fe574301a8f16d78..fbe3882f375b9158c51a9ea099d6c0bfbf5979e1 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -1,43 +1,47 @@
-*.7z filter=lfs diff=lfs merge=lfs -text
-*.arrow filter=lfs diff=lfs merge=lfs -text
-*.bin filter=lfs diff=lfs merge=lfs -text
-*.bz2 filter=lfs diff=lfs merge=lfs -text
-*.ckpt filter=lfs diff=lfs merge=lfs -text
-*.ftz filter=lfs diff=lfs merge=lfs -text
-*.gz filter=lfs diff=lfs merge=lfs -text
-*.h5 filter=lfs diff=lfs merge=lfs -text
-*.joblib filter=lfs diff=lfs merge=lfs -text
-*.lfs.* filter=lfs diff=lfs merge=lfs -text
-*.mlmodel filter=lfs diff=lfs merge=lfs -text
-*.model filter=lfs diff=lfs merge=lfs -text
-*.msgpack filter=lfs diff=lfs merge=lfs -text
-*.npy filter=lfs diff=lfs merge=lfs -text
-*.npz filter=lfs diff=lfs merge=lfs -text
-*.onnx filter=lfs diff=lfs merge=lfs -text
-*.ot filter=lfs diff=lfs merge=lfs -text
-*.parquet filter=lfs diff=lfs merge=lfs -text
-*.pb filter=lfs diff=lfs merge=lfs -text
-*.pickle filter=lfs diff=lfs merge=lfs -text
-*.pkl filter=lfs diff=lfs merge=lfs -text
-*.pt filter=lfs diff=lfs merge=lfs -text
-*.pth filter=lfs diff=lfs merge=lfs -text
-*.rar filter=lfs diff=lfs merge=lfs -text
-*.safetensors filter=lfs diff=lfs merge=lfs -text
-saved_model/**/* filter=lfs diff=lfs merge=lfs -text
-*.tar.* filter=lfs diff=lfs merge=lfs -text
-*.tar filter=lfs diff=lfs merge=lfs -text
-*.tflite filter=lfs diff=lfs merge=lfs -text
-*.tgz filter=lfs diff=lfs merge=lfs -text
-*.wasm filter=lfs diff=lfs merge=lfs -text
-*.xz filter=lfs diff=lfs merge=lfs -text
-*.zip filter=lfs diff=lfs merge=lfs -text
-*.zst filter=lfs diff=lfs merge=lfs -text
-*tfevents* filter=lfs diff=lfs merge=lfs -text
-examples/reference/dingzhen_0.wav filter=lfs diff=lfs merge=lfs -text
-examples/reference/s3p2.wav filter=lfs diff=lfs merge=lfs -text
-examples/source/source_s3.wav filter=lfs diff=lfs merge=lfs -text
-examples/source/source_s4.wav filter=lfs diff=lfs merge=lfs -text
-examples/source/Wiz[[:space:]]Khalifa,Charlie[[:space:]]Puth[[:space:]]-[[:space:]]See[[:space:]]You[[:space:]]Again[[:space:]]\[vocals\]_\[cut_28sec\].wav filter=lfs diff=lfs merge=lfs -text
-examples/reference/trump_0.wav filter=lfs diff=lfs merge=lfs -text
-examples/source/jay_0.wav filter=lfs diff=lfs merge=lfs -text
-examples/source/TECHNOPOLIS[[:space:]]-[[:space:]]2085[[:space:]]\[vocals\]_\[cut_14sec\].wav filter=lfs diff=lfs merge=lfs -text
+*.7z filter=lfs diff=lfs merge=lfs -text
+*.arrow filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.bz2 filter=lfs diff=lfs merge=lfs -text
+*.ckpt filter=lfs diff=lfs merge=lfs -text
+*.ftz filter=lfs diff=lfs merge=lfs -text
+*.gz filter=lfs diff=lfs merge=lfs -text
+*.h5 filter=lfs diff=lfs merge=lfs -text
+*.joblib filter=lfs diff=lfs merge=lfs -text
+*.lfs.* filter=lfs diff=lfs merge=lfs -text
+*.mlmodel filter=lfs diff=lfs merge=lfs -text
+*.model filter=lfs diff=lfs merge=lfs -text
+*.msgpack filter=lfs diff=lfs merge=lfs -text
+*.npy filter=lfs diff=lfs merge=lfs -text
+*.npz filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.ot filter=lfs diff=lfs merge=lfs -text
+*.parquet filter=lfs diff=lfs merge=lfs -text
+*.pb filter=lfs diff=lfs merge=lfs -text
+*.pickle filter=lfs diff=lfs merge=lfs -text
+*.pkl filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
+*.rar filter=lfs diff=lfs merge=lfs -text
+*.safetensors filter=lfs diff=lfs merge=lfs -text
+saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+*.tar.* filter=lfs diff=lfs merge=lfs -text
+*.tar filter=lfs diff=lfs merge=lfs -text
+*.tflite filter=lfs diff=lfs merge=lfs -text
+*.tgz filter=lfs diff=lfs merge=lfs -text
+*.wasm filter=lfs diff=lfs merge=lfs -text
+*.xz filter=lfs diff=lfs merge=lfs -text
+*.zip filter=lfs diff=lfs merge=lfs -text
+*.zst filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text
+examples/reference/dingzhen_0.wav filter=lfs diff=lfs merge=lfs -text
+examples/reference/s3p2.wav filter=lfs diff=lfs merge=lfs -text
+examples/source/source_s3.wav filter=lfs diff=lfs merge=lfs -text
+examples/source/source_s4.wav filter=lfs diff=lfs merge=lfs -text
+examples/source/Wiz[[:space:]]Khalifa,Charlie[[:space:]]Puth[[:space:]]-[[:space:]]See[[:space:]]You[[:space:]]Again[[:space:]]\[vocals\]_\[cut_28sec\].wav filter=lfs diff=lfs merge=lfs -text
+examples/reference/trump_0.wav filter=lfs diff=lfs merge=lfs -text
+examples/source/jay_0.wav filter=lfs diff=lfs merge=lfs -text
+examples/source/TECHNOPOLIS[[:space:]]-[[:space:]]2085[[:space:]]\[vocals\]_\[cut_14sec\].wav filter=lfs diff=lfs merge=lfs -text
+modules/bigvgan/alias_free_activation/cuda/build/.ninja_deps filter=lfs diff=lfs merge=lfs -text
+modules/bigvgan/alias_free_activation/cuda/build/anti_alias_activation_cuda.cuda.o filter=lfs diff=lfs merge=lfs -text
+modules/bigvgan/alias_free_activation/cuda/build/anti_alias_activation_cuda.pyd filter=lfs diff=lfs merge=lfs -text
+modules/bigvgan/alias_free_activation/cuda/build/anti_alias_activation.o filter=lfs diff=lfs merge=lfs -text
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..28c115f84cdbae6dd450777a7bbe5c23f130a3ec
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,28 @@
+# general things to ignore
+.DS_Store
+build/
+build_contrib/
+dist/
+.cache/
+*.egg-info/
+*.egg
+*.py[cod]
+__pycache__/
+*.so
+*~
+
+# IDE
+.vscode/
+.idea/
+
+# misc
+checkpoints/
+test_waves/
+reconstructed/
+.python-version
+ruff.log
+/configs/inuse/
+runs/
+/garbages/
+/flagged/
+/experimental/
diff --git a/README.md b/README.md
index eb7e29cb828e7162313500060b2137ae6137e8dd..887bc8c01db2950bd731f3f2aaa2ec1fdcdfebe3 100644
--- a/README.md
+++ b/README.md
@@ -1,13 +1,13 @@
----
-title: Seed Voice Conversion
-emoji: 🎤🔄
-colorFrom: green
-colorTo: green
-sdk: gradio
-sdk_version: 4.42.0
-app_file: app.py
-pinned: false
-license: gpl-3.0
----
-
+---
+title: Seed Voice Conversion
+emoji: 🎤🔄
+colorFrom: green
+colorTo: green
+sdk: gradio
+sdk_version: 5.23.0
+app_file: app_v1v2.py
+pinned: false
+license: gpl-3.0
+---
+
 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
\ No newline at end of file
diff --git a/app_v1v2.py b/app_v1v2.py
new file mode 100644
index 0000000000000000000000000000000000000000..f5e63efa4da6092483cf663e81eacd0a56886c4f
--- /dev/null
+++ b/app_v1v2.py
@@ -0,0 +1,175 @@
+import spaces
+import gradio as gr
+import torch
+import yaml
+import argparse
+from seed_vc_wrapper import SeedVCWrapper
+
+# Set up device and torch configurations
+if torch.cuda.is_available():
+    device = torch.device("cuda")
+elif torch.backends.mps.is_available():
+    device = torch.device("mps")
+else:
+    device = torch.device("cpu")
+
+torch._inductor.config.coordinate_descent_tuning = True
+torch._inductor.config.triton.unique_kernel_names = True
+
+if hasattr(torch._inductor.config, "fx_graph_cache"):
+    # Experimental feature to reduce compilation times, will be on by default in future
+    torch._inductor.config.fx_graph_cache = True
+
+dtype = torch.float16
+
+def load_v2_models(args):
+    from hydra.utils import instantiate
+    from omegaconf import DictConfig
+    cfg = DictConfig(yaml.safe_load(open("configs/v2/vc_wrapper.yaml", "r")))
+    vc_wrapper = instantiate(cfg)
+    vc_wrapper.load_checkpoints()
+    vc_wrapper.to(device)
+    vc_wrapper.eval()
+
+    vc_wrapper.setup_ar_caches(max_batch_size=1, max_seq_len=4096, dtype=dtype, device=device)
+
+    if args.compile:
+        vc_wrapper.compile_ar()
+        # vc_wrapper.compile_cfm()
+
+    return vc_wrapper
+
+def create_v1_interface():
+    # Initialize the V1 wrapper
+    vc_wrapper = SeedVCWrapper()
+    
+    # Set up Gradio interface
+    description = ("Zero-shot voice conversion with in-context learning. For local deployment please check [GitHub repository](https://github.com/Plachtaa/seed-vc) "
+                   "for details and updates.<br>Note that any reference audio will be forcefully clipped to 25s if beyond this length.<br> "
+                   "If total duration of source and reference audio exceeds 30s, source audio will be processed in chunks.<br> "
+                   "无需训练的 zero-shot 语音/歌声转换模型,若需本地部署查看[GitHub页面](https://github.com/Plachtaa/seed-vc)<br>"
+                   "请注意,参考音频若超过 25 秒,则会被自动裁剪至此长度。<br>若源音频和参考音频的总时长超过 30 秒,源音频将被分段处理。")
+    
+    inputs = [
+        gr.Audio(type="filepath", label="Source Audio / 源音频"),
+        gr.Audio(type="filepath", label="Reference Audio / 参考音频"),
+        gr.Slider(minimum=1, maximum=200, value=10, step=1, label="Diffusion Steps / 扩散步数", 
+                 info="10 by default, 50~100 for best quality / 默认为 10,50~100 为最佳质量"),
+        gr.Slider(minimum=0.5, maximum=2.0, step=0.1, value=1.0, label="Length Adjust / 长度调整", 
+                 info="<1.0 for speed-up speech, >1.0 for slow-down speech / <1.0 加速语速,>1.0 减慢语速"),
+        gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=0.7, label="Inference CFG Rate", 
+                 info="has subtle influence / 有微小影响"),
+        gr.Checkbox(label="Use F0 conditioned model / 启用F0输入", value=False, 
+                   info="Must set to true for singing voice conversion / 歌声转换时必须勾选"),
+        gr.Checkbox(label="Auto F0 adjust / 自动F0调整", value=True,
+                   info="Roughly adjust F0 to match target voice. Only works when F0 conditioned model is used. / 粗略调整 F0 以匹配目标音色,仅在勾选 '启用F0输入' 时生效"),
+        gr.Slider(label='Pitch shift / 音调变换', minimum=-24, maximum=24, step=1, value=0, 
+                 info="Pitch shift in semitones, only works when F0 conditioned model is used / 半音数的音高变换,仅在勾选 '启用F0输入' 时生效"),
+    ]
+    
+    examples = [
+        ["examples/source/yae_0.wav", "examples/reference/dingzhen_0.wav", 25, 1.0, 0.7, False, True, 0],
+        ["examples/source/jay_0.wav", "examples/reference/azuma_0.wav", 25, 1.0, 0.7, True, True, 0],
+        ["examples/source/Wiz Khalifa,Charlie Puth - See You Again [vocals]_[cut_28sec].wav",
+         "examples/reference/teio_0.wav", 100, 1.0, 0.7, True, False, 0],
+        ["examples/source/TECHNOPOLIS - 2085 [vocals]_[cut_14sec].wav",
+         "examples/reference/trump_0.wav", 50, 1.0, 0.7, True, False, -12],
+    ]
+    
+    outputs = [
+        gr.Audio(label="Stream Output Audio / 流式输出", streaming=True, format='mp3'),
+        gr.Audio(label="Full Output Audio / 完整输出", streaming=False, format='wav')
+    ]
+    
+    return gr.Interface(
+        fn=vc_wrapper.convert_voice,
+        description=description,
+        inputs=inputs,
+        outputs=outputs,
+        title="Seed Voice Conversion V1 (Voice & Singing Voice Conversion)",
+        examples=examples,
+        cache_examples=False,
+    )
+
+def create_v2_interface(vc_wrapper):
+    # Set up Gradio interface
+    description = ("Zero-shot voice/style conversion with in-context learning. For local deployment please check [GitHub repository](https://github.com/Plachtaa/seed-vc) "
+                   "for details and updates.<br>Note that any reference audio will be forcefully clipped to 25s if beyond this length.<br> "
+                   "If total duration of source and reference audio exceeds 30s, source audio will be processed in chunks.<br> "
+                   "Please click the 'convert style/emotion/accent' checkbox to convert the style, emotion, or accent of the source audio, or else only timbre conversion will be performed.<br> "
+                   "Click the 'anonymization only' checkbox will ignore reference audio but convert source to an 'average voice' determined by model itself.<br> "
+                   "无需训练的 zero-shot 语音/口音转换模型,若需本地部署查看[GitHub页面](https://github.com/Plachtaa/seed-vc)<br>"
+                   "请注意,参考音频若超过 25 秒,则会被自动裁剪至此长度。<br>若源音频和参考音频的总时长超过 30 秒,源音频将被分段处理。"
+                   "<br>请勾选 'convert style/emotion/accent' 以转换源音频的风格、情感或口音,否则仅执行音色转换。<br>"
+                   "勾选 'anonymization only' 会无视参考音频而将源音频转换为某种由模型自身决定的 '平均音色'。<br>"
+                   
+                   "Credits to [Vevo](https://github.com/open-mmlab/Amphion/tree/main/models/vc/vevo)"
+                   )
+    inputs = [
+        gr.Audio(type="filepath", label="Source Audio / 源音频"),
+        gr.Audio(type="filepath", label="Reference Audio / 参考音频"),
+        gr.Slider(minimum=1, maximum=200, value=30, step=1, label="Diffusion Steps / 扩散步数", 
+                 info="30 by default, 50~100 for best quality / 默认为 30,50~100 为最佳质量"),
+        gr.Slider(minimum=0.5, maximum=2.0, step=0.1, value=1.0, label="Length Adjust / 长度调整", 
+                 info="<1.0 for speed-up speech, >1.0 for slow-down speech / <1.0 加速语速,>1.0 减慢语速"),
+        gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=0.0, label="Intelligibility CFG Rate",
+                 info="controls pronunciation intelligibility / 控制发音清晰度"),
+        gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=0.7, label="Similarity CFG Rate",
+                  info="controls similarity to reference audio / 控制与参考音频的相似度"),
+        gr.Slider(minimum=0.1, maximum=1.0, step=0.1, value=0.9, label="Top-p",
+                 info="AR model sampling top P"),
+        gr.Slider(minimum=0.1, maximum=2.0, step=0.1, value=1.0, label="Temperature",
+                 info="AR model sampling temperature"),
+        gr.Slider(minimum=1.0, maximum=3.0, step=0.1, value=1.0, label="Repetition Penalty",
+                 info="AR model sampling repetition penalty"),
+        gr.Checkbox(label="convert style/emotion/accent", value=False),
+        gr.Checkbox(label="anonymization only", value=False),
+    ]
+    
+    examples = [
+        ["examples/source/yae_0.wav", "examples/reference/dingzhen_0.wav", 50, 1.0, 0.0, 0.7, 0.9, 1.0, 1.0, False, False],
+        ["examples/source/jay_0.wav", "examples/reference/azuma_0.wav", 50, 1.0, 0.0, 0.7, 0.9, 1.0, 1.0, False, False],
+    ]
+    
+    outputs = [
+        gr.Audio(label="Stream Output Audio / 流式输出", streaming=True, format='mp3'),
+        gr.Audio(label="Full Output Audio / 完整输出", streaming=False, format='wav')
+    ]
+    
+    return gr.Interface(
+        fn=vc_wrapper.convert_voice_with_streaming,
+        description=description,
+        inputs=inputs,
+        outputs=outputs,
+        title="Seed Voice Conversion V2 (Voice & Style Conversion)",
+        examples=examples,
+        cache_examples=False,
+    )
+
+def main(args):
+    # Load V2 models
+    vc_wrapper_v2 = load_v2_models(args)
+    
+    # Create interfaces
+    v1_interface = create_v1_interface()
+    v2_interface = create_v2_interface(vc_wrapper_v2)
+    
+    # Create tabs
+    with gr.Blocks(title="Seed Voice Conversion") as demo:
+        gr.Markdown("# Seed Voice Conversion")
+        gr.Markdown("Choose between V1 (Voice & Singing Voice Conversion) or V2 (Voice & Style Conversion)")
+        
+        with gr.Tabs():
+            with gr.TabItem("V2 - Voice & Style Conversion"):
+                v2_interface.render()
+            with gr.TabItem("V1 - Voice & Singing Voice Conversion"):
+                v1_interface.render()
+    
+    # Launch the combined interface
+    demo.launch()
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--compile", type=bool, default=True)
+    args = parser.parse_args()
+    main(args) 
\ No newline at end of file
diff --git a/configs/astral_quantization/default_2048.yml b/configs/astral_quantization/default_2048.yml
new file mode 100644
index 0000000000000000000000000000000000000000..54f91e7cea7722a8cdb85c9855bffeeedb84e5db
--- /dev/null
+++ b/configs/astral_quantization/default_2048.yml
@@ -0,0 +1,40 @@
+_target_: modules.astral_quantization.default_model.AstralQuantizer
+tokenizer_name: "openai/whisper-small"
+ssl_model_name: "facebook/hubert-large-ll60k"
+ssl_output_layer: 18
+encoder:
+  _target_: modules.astral_quantization.convnext.ConvNeXtV2Stage
+  dim: 512
+  num_blocks: 12
+  intermediate_dim: 1536
+  dilation: 1
+  input_dim: 1024
+quantizer:
+  _target_: modules.astral_quantization.bsq.BinarySphericalQuantize
+  codebook_size: 2048  # codebook size, must be a power of 2
+  dim: 512
+  entropy_loss_weight: 0.1
+  diversity_gamma: 1.0
+  spherical: True
+  enable_entropy_loss: True
+  soft_entropy_loss: True
+decoder:
+  _target_: modules.astral_quantization.convnext.ConvNeXtV2Stage
+  dim: 512
+  num_blocks: 12
+  intermediate_dim: 1536
+  dilation: 1
+  output_dim: 1024
+  gin_channels: 192
+asr_decoder:
+  _target_: modules.astral_quantization.asr_decoder.ASRDecoder
+  hidden_dim: 768
+  num_heads: 12
+  depth: 12
+  block_size: 4096
+  in_channels: 512
+  n_vocab: 51866
+  bos_id: 50528
+  eos_id: 50527
+  dropout_rate: 0.0
+  attn_dropout_rate: 0.0
\ No newline at end of file
diff --git a/configs/astral_quantization/default_32.yml b/configs/astral_quantization/default_32.yml
new file mode 100644
index 0000000000000000000000000000000000000000..bf160129893fb893eed26eabcb8a2da9c42a7159
--- /dev/null
+++ b/configs/astral_quantization/default_32.yml
@@ -0,0 +1,40 @@
+_target_: default_model.AstralQuantizer
+tokenizer_name: "openai/whisper-small"
+ssl_model_name: "facebook/hubert-large-ll60k"
+ssl_output_layer: 18
+encoder:
+  _target_: modules.convnext.ConvNeXtV2Stage
+  dim: 512
+  num_blocks: 12
+  intermediate_dim: 1536
+  dilation: 1
+  input_dim: 1024
+quantizer:
+  _target_: modules.bsq.BinarySphericalQuantize
+  codebook_size: 32  # codebook size, must be a power of 2
+  dim: 512
+  entropy_loss_weight: 0.1
+  diversity_gamma: 1.0
+  spherical: True
+  enable_entropy_loss: True
+  soft_entropy_loss: True
+decoder:
+  _target_: modules.convnext.ConvNeXtV2Stage
+  dim: 512
+  num_blocks: 12
+  intermediate_dim: 1536
+  dilation: 1
+  output_dim: 1024
+  gin_channels: 192
+asr_decoder:
+  _target_: modules.asr_decoder.ASRDecoder
+  hidden_dim: 768
+  num_heads: 12
+  depth: 12
+  block_size: 4096
+  in_channels: 512
+  n_vocab: 51866
+  bos_id: 50528
+  eos_id: 50527
+  dropout_rate: 0.0
+  attn_dropout_rate: 0.0
\ No newline at end of file
diff --git a/configs/config.json b/configs/config.json
new file mode 100644
index 0000000000000000000000000000000000000000..e74f0b4898f6e47e1d198b62cdba989784ce2bb0
--- /dev/null
+++ b/configs/config.json
@@ -0,0 +1 @@
+{"reference_audio_path": "D:/FAcodec/test_waves/kobe_0.wav", "sg_hostapi": "MME", "sg_wasapi_exclusive": false, "sg_input_device": "\u9ea6\u514b\u98ce (Razer BlackShark V2 HS 2.4", "sg_output_device": "\u626c\u58f0\u5668 (Razer BlackShark V2 HS 2.4", "sr_type": "sr_model", "diffusion_steps": 10.0, "inference_cfg_rate": 0.0, "max_prompt_length": 3.0, "block_time": 0.7, "crossfade_length": 0.04, "extra_time": 0.5, "extra_time_right": 0.02}
\ No newline at end of file
diff --git a/configs/inuse/.gitignore b/configs/inuse/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/configs/inuse/config.json b/configs/inuse/config.json
new file mode 100644
index 0000000000000000000000000000000000000000..17d2df311cd2b19c1888c8fbb82effbfbc6edef3
--- /dev/null
+++ b/configs/inuse/config.json
@@ -0,0 +1 @@
+{"reference_audio_path": "D:/seed-vc/examples/reference/trump_0.wav", "sg_hostapi": "MME", "sg_wasapi_exclusive": false, "sg_input_device": "\u9ea6\u514b\u98ce (Razer BlackShark V2 HS USB", "sg_output_device": "\u626c\u58f0\u5668 (Razer BlackShark V2 HS USB", "sr_type": "sr_model", "diffusion_steps": 8.0, "inference_cfg_rate": 0.7, "max_prompt_length": 3.0, "block_time": 0.58, "crossfade_length": 0.04, "extra_time_ce": 2.5, "extra_time": 0.5, "extra_time_right": 0.02}
\ No newline at end of file
diff --git a/configs/presets/config_dit_mel_seed_uvit_whisper_base_f0_44k.yml b/configs/presets/config_dit_mel_seed_uvit_whisper_base_f0_44k.yml
new file mode 100644
index 0000000000000000000000000000000000000000..0ec7ef4f6d8c7cf160687747ae01e0d39f6c128d
--- /dev/null
+++ b/configs/presets/config_dit_mel_seed_uvit_whisper_base_f0_44k.yml
@@ -0,0 +1,98 @@
+log_dir: "./runs"
+save_freq: 1
+log_interval: 10
+save_interval: 1000
+device: "cuda"
+epochs: 1000 # number of epochs for first stage training (pre-training)
+batch_size: 1
+batch_length: 100 # maximum duration of audio in a batch (in seconds)
+max_len: 80 # maximum number of frames
+pretrained_model: "DiT_seed_v2_uvit_whisper_base_f0_44k_bigvgan_pruned_ft_ema.pth"
+pretrained_encoder: ""
+load_only_params: False # set to true if do not want to load epoch numbers and optimizer parameters
+
+preprocess_params:
+  sr: 44100
+  spect_params:
+    n_fft: 2048
+    win_length: 2048
+    hop_length: 512
+    n_mels: 128
+    fmin: 0
+    fmax: "None"
+
+model_params:
+  dit_type: "DiT" # uDiT or DiT
+  reg_loss_type: "l1" # l1 or l2
+
+  timbre_shifter:
+    se_db_path: "./modules/openvoice/checkpoints_v2/converter/se_db.pt"
+    ckpt_path: './modules/openvoice/checkpoints_v2/converter'
+
+  vocoder:
+    type: "bigvgan"
+    name: "nvidia/bigvgan_v2_44khz_128band_512x"
+
+  speech_tokenizer:
+    type: 'whisper'
+    name: "openai/whisper-small"
+
+  style_encoder:
+    dim: 192
+    campplus_path: "campplus_cn_common.bin"
+
+  DAC:
+    encoder_dim: 64
+    encoder_rates: [2, 5, 5, 6]
+    decoder_dim: 1536
+    decoder_rates: [ 6, 5, 5, 2 ]
+    sr: 24000
+
+  length_regulator:
+    channels: 768
+    is_discrete: false
+    in_channels: 768
+    content_codebook_size: 2048
+    sampling_ratios: [1, 1, 1, 1]
+    vector_quantize: false
+    n_codebooks: 1
+    quantizer_dropout: 0.0
+    f0_condition: true
+    n_f0_bins: 256
+
+  DiT:
+    hidden_dim: 768
+    num_heads: 12
+    depth: 17
+    class_dropout_prob: 0.1
+    block_size: 8192
+    in_channels: 128
+    style_condition: true
+    final_layer_type: 'mlp'
+    target: 'mel' # mel or codec
+    content_dim: 768
+    content_codebook_size: 1024
+    content_type: 'discrete'
+    f0_condition: true
+    n_f0_bins: 256
+    content_codebooks: 1
+    is_causal: false
+    long_skip_connection: false
+    zero_prompt_speech_token: false # for prompt component, do not input corresponding speech token
+    time_as_token: false
+    style_as_token: false
+    uvit_skip_connection: true
+    add_resblock_in_transformer: false
+
+  wavenet:
+    hidden_dim: 768
+    num_layers: 8
+    kernel_size: 5
+    dilation_rate: 1
+    p_dropout: 0.2
+    style_condition: true
+
+loss_params:
+  base_lr: 0.0001
+  lambda_mel: 45
+  lambda_kl: 1.0
\ No newline at end of file
diff --git a/configs/presets/config_dit_mel_seed_uvit_whisper_small_wavenet.yml b/configs/presets/config_dit_mel_seed_uvit_whisper_small_wavenet.yml
new file mode 100644
index 0000000000000000000000000000000000000000..492910d163c7773c64f846ee55384e3e8b81ac00
--- /dev/null
+++ b/configs/presets/config_dit_mel_seed_uvit_whisper_small_wavenet.yml
@@ -0,0 +1,91 @@
+log_dir: "./runs"
+save_freq: 1
+log_interval: 10
+save_interval: 1000
+device: "cuda"
+epochs: 1000 # number of epochs for first stage training (pre-training)
+batch_size: 2
+batch_length: 100 # maximum duration of audio in a batch (in seconds)
+max_len: 80 # maximum number of frames
+pretrained_model: "DiT_seed_v2_uvit_whisper_small_wavenet_bigvgan_pruned.pth"
+pretrained_encoder: ""
+load_only_params: False # set to true if do not want to load epoch numbers and optimizer parameters
+
+preprocess_params:
+  sr: 22050
+  spect_params:
+    n_fft: 1024
+    win_length: 1024
+    hop_length: 256
+    n_mels: 80
+    fmin: 0
+    fmax: "None"
+
+model_params:
+  dit_type: "DiT" # uDiT or DiT
+  reg_loss_type: "l1" # l1 or l2
+
+  timbre_shifter:
+    se_db_path: "./modules/openvoice/checkpoints_v2/converter/se_db.pt"
+    ckpt_path: './modules/openvoice/checkpoints_v2/converter'
+
+  speech_tokenizer:
+    type: 'whisper'
+    name: "openai/whisper-small"
+
+  style_encoder:
+    dim: 192
+    campplus_path: "campplus_cn_common.bin"
+
+  vocoder:
+    type: "bigvgan"
+    name: "nvidia/bigvgan_v2_22khz_80band_256x"
+
+  length_regulator:
+    channels: 512
+    is_discrete: false
+    in_channels: 768
+    content_codebook_size: 2048
+    sampling_ratios: [1, 1, 1, 1]
+    vector_quantize: false
+    n_codebooks: 1
+    quantizer_dropout: 0.0
+    f0_condition: false
+    n_f0_bins: 512
+
+  DiT:
+    hidden_dim: 512
+    num_heads: 8
+    depth: 13
+    class_dropout_prob: 0.1
+    block_size: 8192
+    in_channels: 80
+    style_condition: true
+    final_layer_type: 'wavenet'
+    target: 'mel' # mel or codec
+    content_dim: 512
+    content_codebook_size: 1024
+    content_type: 'discrete'
+    f0_condition: false
+    n_f0_bins: 512
+    content_codebooks: 1
+    is_causal: false
+    long_skip_connection: true
+    zero_prompt_speech_token: false # for prompt component, do not input corresponding speech token
+    time_as_token: false
+    style_as_token: false
+    uvit_skip_connection: true
+    add_resblock_in_transformer: false
+
+  wavenet:
+    hidden_dim: 512
+    num_layers: 8
+    kernel_size: 5
+    dilation_rate: 1
+    p_dropout: 0.2
+    style_condition: true
+
+loss_params:
+  base_lr: 0.0001
+  lambda_mel: 45
+  lambda_kl: 1.0
\ No newline at end of file
diff --git a/configs/presets/config_dit_mel_seed_uvit_xlsr_tiny.yml b/configs/presets/config_dit_mel_seed_uvit_xlsr_tiny.yml
new file mode 100644
index 0000000000000000000000000000000000000000..e0677397377158dd30ffdf905946fbd297b36bd5
--- /dev/null
+++ b/configs/presets/config_dit_mel_seed_uvit_xlsr_tiny.yml
@@ -0,0 +1,82 @@
+log_dir: "./runs/"
+save_freq: 1
+log_interval: 10
+save_interval: 500
+device: "cuda"
+epochs: 1000 # number of epochs for first stage training (pre-training)
+batch_size: 2
+batch_length: 100 # maximum duration of audio in a batch (in seconds)
+max_len: 80 # maximum number of frames
+pretrained_model: "DiT_uvit_tat_xlsr_ema.pth"
+pretrained_encoder: ""
+load_only_params: False # set to true if do not want to load epoch numbers and optimizer parameters
+
+preprocess_params:
+  sr: 22050
+  spect_params:
+    n_fft: 1024
+    win_length: 1024
+    hop_length: 256
+    n_mels: 80
+    fmin: 0
+    fmax: 8000
+
+model_params:
+  dit_type: "DiT" # uDiT or DiT
+  reg_loss_type: "l1" # l1 or l2
+  diffusion_type: "flow"
+
+  timbre_shifter:
+    se_db_path: "./modules/openvoice/checkpoints_v2/converter/se_db.pt"
+    ckpt_path: './modules/openvoice/checkpoints_v2/converter'
+
+  vocoder:
+    type: "hifigan"
+
+  speech_tokenizer:
+    type: 'xlsr'
+    output_layer: 12
+    name: 'facebook/wav2vec2-xls-r-300m'
+
+  style_encoder:
+    dim: 192
+    campplus_path: "campplus_cn_common.bin"
+
+  length_regulator:
+    channels: 384
+    is_discrete: false
+    in_channels: 1024
+    content_codebook_size: 1024
+    sampling_ratios: [1, 1, 1, 1]
+    vector_quantize: false
+    n_codebooks: 2
+    quantizer_dropout: 0.0
+    f0_condition: false
+    n_f0_bins: 512
+
+  DiT:
+    hidden_dim: 384
+    num_heads: 6
+    depth: 9
+    class_dropout_prob: 0.1
+    block_size: 8192
+    in_channels: 80
+    style_condition: true
+    final_layer_type: 'mlp'
+    target: 'mel' # mel or betavae
+    content_dim: 384
+    content_codebook_size: 1024
+    content_type: 'discrete'
+    f0_condition: false
+    n_f0_bins: 512
+    content_codebooks: 1
+    is_causal: false
+    long_skip_connection: false
+    zero_prompt_speech_token: false # for prompt component, do not input corresponding speech token
+    time_as_token: true
+    style_as_token: true
+    uvit_skip_connection: true
+    add_resblock_in_transformer: false
+
+loss_params:
+  base_lr: 0.0001
\ No newline at end of file
diff --git a/configs/v2/ar_base.yaml b/configs/v2/ar_base.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/configs/v2/dit_small.yaml b/configs/v2/dit_small.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..60b93e953d7ec19039af147b7a52002b757a822a
--- /dev/null
+++ b/configs/v2/dit_small.yaml
@@ -0,0 +1,17 @@
+_target_: modules.v2.cfm.CFM
+estimator:
+  _target_: modules.v2.dit_wrapper.DiT
+  time_as_token: true
+  style_as_token: true
+  uvit_skip_connection: false
+  block_size: 8192
+  depth: 13
+  num_heads: 8
+  hidden_dim: 512
+  in_channels: 80
+  content_dim: 512
+  style_encoder_dim: 192
+  class_dropout_prob: 0.1
+  dropout_rate: 0.0
+  attn_dropout_rate: 0.0
+
diff --git a/configs/v2/vc_wrapper.yaml b/configs/v2/vc_wrapper.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..c3fe5b84431f53ebd12ec60663f40f61f4a8c231
--- /dev/null
+++ b/configs/v2/vc_wrapper.yaml
@@ -0,0 +1,105 @@
+_target_: modules.v2.vc_wrapper.VoiceConversionWrapper
+sr: 22050
+hop_size: 256
+mel_fn:
+  _target_: modules.audio.mel_spectrogram
+  _partial_: true
+  n_fft: 1024
+  win_size: 1024
+  hop_size: 256
+  num_mels: 80
+  sampling_rate: 22050
+  fmin: 0
+  fmax: null
+  center: False
+cfm:
+  _target_: modules.v2.cfm.CFM
+  estimator:
+    _target_: modules.v2.dit_wrapper.DiT
+    time_as_token: true
+    style_as_token: true
+    uvit_skip_connection: false
+    block_size: 8192
+    depth: 13
+    num_heads: 8
+    hidden_dim: 512
+    in_channels: 80
+    content_dim: 512
+    style_encoder_dim: 192
+    class_dropout_prob: 0.1
+    dropout_rate: 0.0
+    attn_dropout_rate: 0.0
+cfm_length_regulator:
+  _target_: modules.v2.length_regulator.InterpolateRegulator
+  channels: 512
+  is_discrete: true
+  codebook_size: 2048
+  sampling_ratios: [ 1, 1, 1, 1 ]
+  f0_condition: false
+ar:
+  _target_: modules.v2.ar.NaiveWrapper
+  model:
+    _target_: modules.v2.ar.NaiveTransformer
+    config:
+      _target_: modules.v2.ar.NaiveModelArgs
+      dropout: 0.0
+      rope_base: 10000.0
+      dim: 768
+      head_dim: 64
+      n_local_heads: 2
+      intermediate_size: 2304
+      n_head: 12
+      n_layer: 12
+      vocab_size: 2049  # 1 + 1 for eos
+ar_length_regulator:
+  _target_: modules.v2.length_regulator.InterpolateRegulator
+  channels: 768
+  is_discrete: true
+  codebook_size: 32
+  sampling_ratios: [ ]
+  f0_condition: false
+style_encoder:
+  _target_: modules.campplus.DTDNN.CAMPPlus
+  feat_dim: 80
+  embedding_size: 192
+content_extractor_narrow:
+  _target_: modules.astral_quantization.default_model.AstralQuantizer
+  tokenizer_name: "openai/whisper-small"
+  ssl_model_name: "facebook/hubert-large-ll60k"
+  ssl_output_layer: 18
+  skip_ssl: true
+  encoder: &bottleneck_encoder
+    _target_: modules.astral_quantization.convnext.ConvNeXtV2Stage
+    dim: 512
+    num_blocks: 12
+    intermediate_dim: 1536
+    dilation: 1
+    input_dim: 1024
+  quantizer:
+    _target_: modules.astral_quantization.bsq.BinarySphericalQuantize
+    codebook_size: 32  # codebook size, must be a power of 2
+    dim: 512
+    entropy_loss_weight: 0.1
+    diversity_gamma: 1.0
+    spherical: True
+    enable_entropy_loss: True
+    soft_entropy_loss: True
+content_extractor_wide:
+  _target_: modules.astral_quantization.default_model.AstralQuantizer
+  tokenizer_name: "openai/whisper-small"
+  ssl_model_name: "facebook/hubert-large-ll60k"
+  ssl_output_layer: 18
+  encoder: *bottleneck_encoder
+  quantizer:
+    _target_: modules.astral_quantization.bsq.BinarySphericalQuantize
+    codebook_size: 2048  # codebook size, must be a power of 2
+    dim: 512
+    entropy_loss_weight: 0.1
+    diversity_gamma: 1.0
+    spherical: True
+    enable_entropy_loss: True
+    soft_entropy_loss: True
+vocoder:
+  _target_: modules.bigvgan.bigvgan.BigVGAN.from_pretrained
+  pretrained_model_name_or_path: "nvidia/bigvgan_v2_22khz_80band_256x"
+  use_cuda_kernel: false
diff --git a/hf_utils.py b/hf_utils.py
index 9f8c7f7d5f1b82efbd788c7327f76c0dc6a9355a..4ae986c1d88d5b1214c9d3101249e90c5c370ca9 100644
--- a/hf_utils.py
+++ b/hf_utils.py
@@ -2,7 +2,7 @@ import os
 from huggingface_hub import hf_hub_download
 
 
-def load_custom_model_from_hf(repo_id, model_filename="pytorch_model.bin", config_filename="config.yml"):
+def load_custom_model_from_hf(repo_id, model_filename="pytorch_model.bin", config_filename=None):
     os.makedirs("./checkpoints", exist_ok=True)
     model_path = hf_hub_download(repo_id=repo_id, filename=model_filename, cache_dir="./checkpoints")
     if config_filename is None:
diff --git a/modules/__pycache__/audio.cpython-310.pyc b/modules/__pycache__/audio.cpython-310.pyc
index 79bf4bb4261f2d91fe5d3efb024339a88274162c..651e7ad6c3e297013d527e4a9218ae35f2b92c41 100644
Binary files a/modules/__pycache__/audio.cpython-310.pyc and b/modules/__pycache__/audio.cpython-310.pyc differ
diff --git a/modules/__pycache__/commons.cpython-310.pyc b/modules/__pycache__/commons.cpython-310.pyc
index 9289cfe3ac9362aaeece6546f5b78bbfea6ca40b..5adfe7b95903d4c3134f74bdf68b5b413b848c97 100644
Binary files a/modules/__pycache__/commons.cpython-310.pyc and b/modules/__pycache__/commons.cpython-310.pyc differ
diff --git a/modules/__pycache__/commons.cpython-38.pyc b/modules/__pycache__/commons.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cfe6a34f14ee453f4a79de9b39d80fc62b4d7fda
Binary files /dev/null and b/modules/__pycache__/commons.cpython-38.pyc differ
diff --git a/modules/__pycache__/diffusion_transformer.cpython-310.pyc b/modules/__pycache__/diffusion_transformer.cpython-310.pyc
index c721982be6feb37d7ae333799842c197225c9697..4bcbf046526a3c8e13cf180ebea709e89305d848 100644
Binary files a/modules/__pycache__/diffusion_transformer.cpython-310.pyc and b/modules/__pycache__/diffusion_transformer.cpython-310.pyc differ
diff --git a/modules/__pycache__/flow_matching.cpython-310.pyc b/modules/__pycache__/flow_matching.cpython-310.pyc
index 4f42f2602f27a9f430f3daf9ca3825ebe86172b5..4a6f71cb0ef1cd096cdeed48330dbce1aed55734 100644
Binary files a/modules/__pycache__/flow_matching.cpython-310.pyc and b/modules/__pycache__/flow_matching.cpython-310.pyc differ
diff --git a/modules/__pycache__/length_regulator.cpython-310.pyc b/modules/__pycache__/length_regulator.cpython-310.pyc
index 301c8f57e7713f62bf83b9c4fa712fe7680f26be..c2ada28a9ce6e0b8a61901f913b5ddde33b0c9ac 100644
Binary files a/modules/__pycache__/length_regulator.cpython-310.pyc and b/modules/__pycache__/length_regulator.cpython-310.pyc differ
diff --git a/modules/__pycache__/rmvpe.cpython-310.pyc b/modules/__pycache__/rmvpe.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..18f08923f7c01ba34a8719dd56cd1c05fdf4e33d
Binary files /dev/null and b/modules/__pycache__/rmvpe.cpython-310.pyc differ
diff --git a/modules/astral_quantization/__pycache__/bsq.cpython-310.pyc b/modules/astral_quantization/__pycache__/bsq.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..918dc5ae503d2a7da3860b7637f9af7aa81ea466
Binary files /dev/null and b/modules/astral_quantization/__pycache__/bsq.cpython-310.pyc differ
diff --git a/modules/astral_quantization/__pycache__/convnext.cpython-310.pyc b/modules/astral_quantization/__pycache__/convnext.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..508b91e576abbb86db68531cda662c8e9daa0fe6
Binary files /dev/null and b/modules/astral_quantization/__pycache__/convnext.cpython-310.pyc differ
diff --git a/modules/astral_quantization/__pycache__/default_model.cpython-310.pyc b/modules/astral_quantization/__pycache__/default_model.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..381bd7f361d729a1640b7eed29bbfdc0515cbcc0
Binary files /dev/null and b/modules/astral_quantization/__pycache__/default_model.cpython-310.pyc differ
diff --git a/modules/astral_quantization/bsq.py b/modules/astral_quantization/bsq.py
new file mode 100644
index 0000000000000000000000000000000000000000..1b70f3401dafa187fc19666d31a0877f533abe33
--- /dev/null
+++ b/modules/astral_quantization/bsq.py
@@ -0,0 +1,569 @@
+"""
+Lookup Free Quantization
+Proposed in https://arxiv.org/abs/2310.05737
+
+In the simplest setup, each dimension is quantized into {-1, 1}.
+An entropy penalty is used to encourage utilization.
+"""
+
+from math import log2, ceil
+from functools import partial, cache
+from collections import namedtuple
+from contextlib import nullcontext
+
+import torch.distributed as dist
+from torch.distributed import nn as dist_nn
+
+import torch
+from torch import nn, einsum
+import torch.nn.functional as F
+from torch.nn import Module
+from torch.amp import autocast
+
+from einops import rearrange, reduce, pack, unpack
+
+# constants
+
+Return = namedtuple('Return', ['quantized', 'indices', 'entropy_aux_loss'])
+
+LossBreakdown = namedtuple('LossBreakdown', ['per_sample_entropy', 'batch_entropy', 'commitment'])
+
+# distributed helpers
+
+@cache
+def is_distributed():
+    return dist.is_initialized() and dist.get_world_size() > 1
+
+def maybe_distributed_mean(t):
+    if not is_distributed():
+        return t
+
+    dist_nn.all_reduce(t)
+    t = t / dist.get_world_size()
+    return t
+
+# helper functions
+
+def exists(v):
+    return v is not None
+
+def identity(t):
+    return t
+
+def default(*args):
+    for arg in args:
+        if exists(arg):
+            return arg() if callable(arg) else arg
+    return None
+
+def pack_one(t, pattern):
+    return pack([t], pattern)
+
+def unpack_one(t, ps, pattern):
+    return unpack(t, ps, pattern)[0]
+
+def l2norm(t):
+    return F.normalize(t, dim = -1)
+
+# entropy
+
+def log(t, eps = 1e-5):
+    return t.clamp(min = eps).log()
+
+def entropy(prob):
+    return (-prob * log(prob)).sum(dim=-1)
+
+# cosine sim linear
+
+class CosineSimLinear(Module):
+    def __init__(
+        self,
+        dim_in,
+        dim_out,
+        scale = 1.
+    ):
+        super().__init__()
+        self.scale = scale
+        self.weight = nn.Parameter(torch.randn(dim_in, dim_out))
+
+    def forward(self, x):
+        x = F.normalize(x, dim = -1)
+        w = F.normalize(self.weight, dim = 0)
+        return (x @ w) * self.scale
+
+def soft_entropy_loss(u, tau=1.0, gamma=1.0):
+    """
+    Compute the soft entropy loss for Binary Spherical Quantization (BSQ).
+
+    Args:
+        u (torch.Tensor): Input latent embeddings of shape (batch_size, L).
+        tau (float): Temperature scaling factor.
+        gamma (float): Weight for the second entropy term.
+
+    Returns:
+        torch.Tensor: Soft entropy loss.
+    """
+    # Binary quantization: Generate implicit codebook corners
+    L = u.size(1)  # Dimensionality of codebook
+    corners = torch.tensor([-1.0, 1.0], device=u.device) / (L**0.5)
+
+    # Compute soft quantization probabilities for all dimensions
+    # q_hat(c|u) for each dimension
+    prob_matrix = torch.sigmoid(2 * tau * corners.unsqueeze(1) * u.unsqueeze(2))  # Shape: (batch_size, L, 2)
+
+    # Entropy of q_hat(c|u) (independent along each dimension)
+    entropy_per_dim = -torch.sum(prob_matrix * prob_matrix.log(), dim=-1)  # Shape: (batch_size, L)
+    entropy_term1 = entropy_per_dim.mean()
+
+    # Expected probabilities for dataset entropy (approximation)
+    expected_probs = prob_matrix.mean(dim=0)  # Mean across batch, shape: (L, 2)
+    entropy_term2 = -torch.sum(expected_probs * expected_probs.log(), dim=-1).mean()
+
+    # Final entropy loss
+    loss = entropy_term1 - gamma * entropy_term2
+    return loss
+
+# class
+
+class BinarySphericalQuantize(Module):
+    def __init__(
+        self,
+        *,
+        dim = None,
+        codebook_size = None,
+        entropy_loss_weight = 0.1,
+        commitment_loss_weight = 0.,
+        diversity_gamma = 1.,
+        straight_through_activation = nn.Identity(),
+        num_codebooks = 1,
+        keep_num_codebooks_dim = None,
+        codebook_scale = 1.,                        # for residual LFQ, codebook scaled down by 2x at each layer
+        frac_per_sample_entropy = 0.25,               # make less than 1. to only use a random fraction of the probs for per sample entropy
+        has_projections = None,
+        projection_has_bias = True,
+        soft_clamp_input_value = None,
+        cosine_sim_project_in = False,
+        cosine_sim_project_in_scale = None,
+        channel_first = None,
+        experimental_softplus_entropy_loss = False,
+        entropy_loss_offset = 5.,                   # how much to shift the loss before softplus
+        spherical = True,                          # from https://arxiv.org/abs/2406.07548
+        force_quantization_f32 = True,               # will force the quantization step to be full precision
+        enable_entropy_loss = True,
+        soft_entropy_loss = True,
+    ):
+        super().__init__()
+
+        # some assert validations
+
+        assert exists(dim) or exists(codebook_size), 'either dim or codebook_size must be specified for LFQ'
+        assert not exists(codebook_size) or log2(codebook_size).is_integer(), f'your codebook size must be a power of 2 for lookup free quantization (suggested {2 ** ceil(log2(codebook_size))})'
+
+        codebook_size = default(codebook_size, lambda: 2 ** dim)
+        self.codebook_size = codebook_size
+
+        codebook_dim = int(log2(codebook_size))
+        codebook_dims = codebook_dim * num_codebooks
+        dim = default(dim, codebook_dims)
+
+        has_projections = default(has_projections, dim != codebook_dims)
+
+        if cosine_sim_project_in:
+            cosine_sim_project_in = default(cosine_sim_project_in_scale, codebook_scale)
+            project_in_klass = partial(CosineSimLinear, scale = cosine_sim_project_in)
+        else:
+            project_in_klass = partial(nn.Linear, bias = projection_has_bias)
+
+        self.project_in = project_in_klass(dim, codebook_dims) if has_projections else nn.Identity()
+        self.project_out = nn.Linear(codebook_dims, dim, bias = projection_has_bias) if has_projections else nn.Identity()
+        self.has_projections = has_projections
+
+        self.dim = dim
+        self.codebook_dim = codebook_dim
+        self.num_codebooks = num_codebooks
+
+        keep_num_codebooks_dim = default(keep_num_codebooks_dim, num_codebooks > 1)
+        assert not (num_codebooks > 1 and not keep_num_codebooks_dim)
+        self.keep_num_codebooks_dim = keep_num_codebooks_dim
+
+        # channel first
+
+        self.channel_first = channel_first
+
+        # straight through activation
+
+        self.activation = straight_through_activation
+
+        # whether to use BSQ (binary spherical quantization)
+
+        self.spherical = spherical
+        self.maybe_l2norm = (lambda t: l2norm(t) * self.codebook_scale) if spherical else identity
+
+        # entropy aux loss related weights
+
+        assert 0 < frac_per_sample_entropy <= 1.
+        self.frac_per_sample_entropy = frac_per_sample_entropy
+
+        self.diversity_gamma = diversity_gamma
+        self.entropy_loss_weight = entropy_loss_weight
+
+        # codebook scale
+
+        self.codebook_scale = codebook_scale
+
+        # commitment loss
+
+        self.commitment_loss_weight = commitment_loss_weight
+
+        # whether to soft clamp the input value from -value to value
+
+        self.soft_clamp_input_value = soft_clamp_input_value
+        assert not exists(soft_clamp_input_value) or soft_clamp_input_value >= codebook_scale
+
+        # whether to make the entropy loss positive through a softplus (experimental, please report if this worked or not in discussions)
+
+        self.entropy_loss_offset = entropy_loss_offset
+        self.experimental_softplus_entropy_loss = experimental_softplus_entropy_loss
+
+        # for no auxiliary loss, during inference
+
+        self.register_buffer('mask', 2 ** torch.arange(codebook_dim - 1, -1, -1))
+        self.register_buffer('zero', torch.tensor(0.), persistent = False)
+
+        # whether to force quantization step to be f32
+
+        self.force_quantization_f32 = force_quantization_f32
+
+        # codes
+        self.enable_entropy_loss = enable_entropy_loss
+        self.soft_entropy_loss = soft_entropy_loss
+        if codebook_size <= 100000:
+            all_codes = torch.arange(codebook_size)
+            bits = ((all_codes[..., None].int() & self.mask) != 0).float()
+            codebook = self.bits_to_codes(bits)
+
+            self.register_buffer('codebook', codebook.float(), persistent = False)
+        else:
+            all_codes = torch.arange(pow(2, 16))
+            mask = 2 ** torch.arange(16 - 1, -1, -1)
+            bits = ((all_codes[..., None].int() & mask) != 0).float()
+            codebook = self.bits_to_codes(bits)
+
+            self.register_buffer('codebook', codebook.float(), persistent = False)
+
+    def bits_to_codes(self, bits):
+        return bits * self.codebook_scale * 2 - self.codebook_scale
+
+    @property
+    def dtype(self):
+        return self.codebook.dtype
+
+    def indices_to_codes(
+        self,
+        indices,
+        project_out = True
+    ):
+        is_img_or_video = indices.ndim >= (3 + int(self.keep_num_codebooks_dim))
+        should_transpose = default(self.channel_first, is_img_or_video)
+
+        if not self.keep_num_codebooks_dim:
+            indices = rearrange(indices, '... -> ... 1')
+
+        # indices to codes, which are bits of either -1 or 1
+
+        bits = ((indices[..., None].int() & self.mask) != 0).to(self.dtype)
+
+        codes = self.bits_to_codes(bits)
+
+        codes = self.maybe_l2norm(codes)
+
+        codes = rearrange(codes, '... c d -> ... (c d)')
+
+        # whether to project codes out to original dimensions
+        # if the input feature dimensions were not log2(codebook size)
+
+        if project_out:
+            codes = self.project_out(codes)
+
+        # rearrange codes back to original shape
+
+        if should_transpose:
+            codes = rearrange(codes, 'b ... d -> b d ...')
+
+        return codes
+
+    def bits_to_z(self, bits):
+        # assert bits must contain only -1 and 1
+        assert torch.all(bits.abs() == 1)
+        quantized = bits.float()
+        quantized = self.maybe_l2norm(quantized)
+        z = self.project_out(quantized)
+        return z
+
+    def forward(
+        self,
+        x,
+        inv_temperature = 100.,
+        return_loss_breakdown = False,
+        mask = None,
+        return_bits = False
+    ):
+        """
+        einstein notation
+        b - batch
+        n - sequence (or flattened spatial dimensions)
+        d - feature dimension, which is also log2(codebook size)
+        c - number of codebook dim
+        """
+
+        is_img_or_video = x.ndim >= 4
+        should_transpose = default(self.channel_first, is_img_or_video)
+
+        # standardize image or video into (batch, seq, dimension)
+
+        if should_transpose:
+            x = rearrange(x, 'b d ... -> b ... d')
+            x, ps = pack_one(x, 'b * d')
+
+        assert x.shape[-1] == self.dim, f'expected dimension of {self.dim} but received {x.shape[-1]}'
+
+        x = self.project_in(x)
+
+        # maybe soft clamp
+
+        if exists(self.soft_clamp_input_value):
+            clamp_value = self.soft_clamp_input_value
+            x = (x / clamp_value).tanh() * clamp_value
+
+        # split out number of codebooks
+
+        x = rearrange(x, 'b n (c d) -> b n c d', c = self.num_codebooks)
+
+        # maybe l2norm
+
+        x = self.maybe_l2norm(x)
+
+        # whether to force quantization step to be full precision or not
+
+        force_f32 = self.force_quantization_f32
+
+        quantization_context = partial(autocast, 'cuda', enabled = False) if force_f32 else nullcontext
+
+        with quantization_context():
+
+            if force_f32:
+                orig_dtype = x.dtype
+                x = x.float()
+
+            # quantize by eq 3.
+
+            original_input = x
+
+            codebook_value = torch.ones_like(x) * self.codebook_scale
+            quantized = torch.where(x > 0, codebook_value, -codebook_value)
+            if return_bits:
+                return quantized
+
+            # calculate indices
+
+            indices = reduce((quantized > 0).int() * self.mask.int(), 'b n c d -> b n c', 'sum')
+
+            # maybe l2norm
+
+            quantized = self.maybe_l2norm(quantized)
+
+            # use straight-through gradients (optionally with custom activation fn) if training
+
+            if self.training:
+                x = self.activation(x)
+                x = x + (quantized - x).detach()
+            else:
+                x = quantized
+
+            # entropy aux loss
+            if self.soft_entropy_loss:
+                entropy_aux_loss = soft_entropy_loss(x, tau=1.0, gamma=1.0)
+            elif self.training and self.enable_entropy_loss:
+
+                if force_f32:
+                    codebook = self.codebook.float()
+
+                codebook = self.maybe_l2norm(codebook)
+
+                # whether to only use a fraction of probs, for reducing memory
+
+                if self.frac_per_sample_entropy < 1.:
+                    # account for mask
+                    if exists(mask):
+                        original_input = original_input[mask]
+                    original_input = rearrange(original_input, 'b n ... -> (b n) ...')
+
+                    rand_mask = torch.randn(self.codebook_dim).argsort(dim = -1) < 16
+
+                    sampled_input = original_input[..., rand_mask]
+
+                    sampled_distance = -2 * einsum('... i d, j d -> ... i j', sampled_input, codebook)
+
+                    sampled_prob = (-sampled_distance * inv_temperature).softmax(dim = -1)
+
+                    per_sample_probs = sampled_prob
+                else:
+                    if exists(mask):
+                        original_input = original_input[mask]
+                    original_input = rearrange(original_input, 'b n ... -> (b n) ...')
+                    # the same as euclidean distance up to a constant
+                    distance = -2 * einsum('... i d, j d -> ... i j', original_input, codebook)
+
+                    prob = (-distance * inv_temperature).softmax(dim = -1)
+
+                    per_sample_probs = prob
+
+                # calculate per sample entropy
+
+                per_sample_entropy = entropy(per_sample_probs).mean()
+
+                # distribution over all available tokens in the batch
+
+                avg_prob = reduce(per_sample_probs, '... c d -> c d', 'mean')
+
+                avg_prob = maybe_distributed_mean(avg_prob)
+
+                codebook_entropy = entropy(avg_prob).mean()
+
+                # 1. entropy will be nudged to be low for each code, to encourage the network to output confident predictions
+                # 2. codebook entropy will be nudged to be high, to encourage all codes to be uniformly used within the batch
+
+                entropy_aux_loss = per_sample_entropy - self.diversity_gamma * codebook_entropy
+            else:
+                # if not training, just return dummy 0
+                entropy_aux_loss = per_sample_entropy = codebook_entropy = self.zero
+
+            # whether to make the entropy loss positive or not through a (shifted) softplus
+
+            if self.training and self.experimental_softplus_entropy_loss:
+                entropy_aux_loss = F.softplus(entropy_aux_loss + self.entropy_loss_offset)
+
+            # commit loss
+
+            if self.training and self.commitment_loss_weight > 0.:
+
+                commit_loss = F.mse_loss(original_input, quantized.detach(), reduction = 'none')
+
+                if exists(mask):
+                    commit_loss = commit_loss[mask]
+
+                commit_loss = commit_loss.mean()
+            else:
+                commit_loss = self.zero
+
+            # input back to original dtype if needed
+
+            if force_f32:
+                x = x.type(orig_dtype)
+
+        # merge back codebook dim
+
+        x = rearrange(x, 'b n c d -> b n (c d)')
+
+        # project out to feature dimension if needed
+
+        x = self.project_out(x)
+
+        # reconstitute image or video dimensions
+
+        if should_transpose:
+            x = unpack_one(x, ps, 'b * d')
+            x = rearrange(x, 'b ... d -> b d ...')
+
+            indices = unpack_one(indices, ps, 'b * c')
+
+        # whether to remove single codebook dim
+
+        if not self.keep_num_codebooks_dim:
+            indices = rearrange(indices, '... 1 -> ...')
+
+        # complete aux loss
+
+        aux_loss = entropy_aux_loss * self.entropy_loss_weight + commit_loss * self.commitment_loss_weight
+
+        # returns
+
+        ret = Return(x, indices, aux_loss)
+
+        if not return_loss_breakdown:
+            return ret
+
+        return ret, LossBreakdown(per_sample_entropy, codebook_entropy, commit_loss)
+
+class GroupedResidualBSQ(Module):
+    def __init__(
+        self,
+        *,
+        dim,
+        groups = 1,
+        accept_image_fmap = False,
+        **kwargs
+    ):
+        super().__init__()
+        self.dim = dim
+        self.groups = groups
+        assert (dim % groups) == 0
+        dim_per_group = dim // groups
+
+        self.accept_image_fmap = accept_image_fmap
+
+        self.rvqs = nn.ModuleList([])
+
+        for _ in range(groups):
+            self.rvqs.append(LFQ(
+                dim = dim_per_group,
+                **kwargs
+            ))
+
+        self.codebook_size = self.rvqs[0].codebook_size
+
+    @property
+    def codebooks(self):
+        return torch.stack(tuple(rvq.codebooks for rvq in self.rvqs))
+
+    @property
+    def split_dim(self):
+        return 1 if self.accept_image_fmap else -1
+
+    def get_codes_from_indices(self, indices):
+        codes = tuple(rvq.get_codes_from_indices(chunk_indices) for rvq, chunk_indices in zip(self.rvqs, indices))
+        return torch.stack(codes)
+
+    def get_output_from_indices(self, indices):
+        outputs = tuple(rvq.get_output_from_indices(chunk_indices) for rvq, chunk_indices in zip(self.rvqs, indices))
+        return torch.cat(outputs, dim = self.split_dim)
+
+    def forward(
+        self,
+        x,
+        return_all_codes = False
+    ):
+        shape, split_dim = x.shape, self.split_dim
+        assert shape[split_dim] == self.dim
+
+        # split the feature dimension into groups
+
+        x = x.chunk(self.groups, dim = split_dim)
+
+        forward_kwargs = dict(
+        )
+
+        # invoke residual vq on each group
+
+        out = tuple(rvq(chunk, **forward_kwargs) for rvq, chunk in zip(self.rvqs, x))
+        out = tuple(zip(*out))
+
+        # otherwise, get all the zipped outputs and combine them
+
+        quantized, all_indices, *maybe_aux_loss = out
+
+        quantized = torch.cat(quantized, dim = split_dim)
+        all_indices = torch.stack(all_indices)
+
+        ret = (quantized, all_indices, *maybe_aux_loss)
+        return ret
diff --git a/modules/astral_quantization/convnext.py b/modules/astral_quantization/convnext.py
new file mode 100644
index 0000000000000000000000000000000000000000..7bef9e282ac332b7169339eeda082d0581d739d1
--- /dev/null
+++ b/modules/astral_quantization/convnext.py
@@ -0,0 +1,209 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from typing import List
+
+
+class ConvNextV2LayerNorm(nn.Module):
+    r"""LayerNorm that supports two data formats: channels_last (default) or channels_first.
+    The ordering of the dimensions in the inputs. channels_last corresponds to inputs with shape (batch_size, height,
+    width, channels) while channels_first corresponds to inputs with shape (batch_size, channels, height, width).
+    """
+
+    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(normalized_shape))
+        self.bias = nn.Parameter(torch.zeros(normalized_shape))
+        self.eps = eps
+        self.data_format = data_format
+        if self.data_format not in ["channels_last", "channels_first"]:
+            raise NotImplementedError(f"Unsupported data format: {self.data_format}")
+        self.normalized_shape = (normalized_shape,)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if self.data_format == "channels_last":
+            x = torch.nn.functional.layer_norm(
+                x, self.normalized_shape, self.weight, self.bias, self.eps
+            )
+        elif self.data_format == "channels_first":
+            input_dtype = x.dtype
+            x = x.float()
+            u = x.mean(1, keepdim=True)
+            s = (x - u).pow(2).mean(1, keepdim=True)
+            x = (x - u) / torch.sqrt(s + self.eps)
+            x = x.to(dtype=input_dtype)
+            x = self.weight[None, :, None] * x + self.bias[None, :, None]
+        return x
+
+
+class GRN(nn.Module):
+    def __init__(self, dim):
+        super().__init__()
+        self.gamma = nn.Parameter(torch.zeros(1, 1, dim))
+        self.beta = nn.Parameter(torch.zeros(1, 1, dim))
+
+    def forward(self, x):
+        Gx = torch.norm(x, p=2, dim=1, keepdim=True)
+        Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
+        return self.gamma * (x * Nx) + self.beta + x
+
+class InterpolationLayer(nn.Module):
+    def __init__(self, ):  # this is a default of 1 / 50 * (44100 / 512) / 4
+        super().__init__()
+        pass
+
+    def forward(self, x: torch.Tensor, target_len: torch.Tensor, *args, **kwargs) -> torch.Tensor:
+        x = F.interpolate(x, size=target_len, mode='linear')
+        return x
+
+class ConvNeXtV2Stage(nn.Module):
+    def __init__(
+        self,
+        dim: int = 512,
+        intermediate_dim: int = 2048,
+        num_blocks: int = 1,
+        dilation: int = 1,
+        downsample_layer_indices: List[int] = None,
+        downsample_factors: List[int] = None,
+        upsample_layer_indices: List[int] = None,
+        upsample_factors: List[int] = None,
+        interpolation_layer_indices: List[int] = None,
+        input_dim: int = None,
+        output_dim: int = None,
+        gin_channels: int = 0,
+    ):
+        super().__init__()
+        # maybe downsample layers
+        if downsample_layer_indices is not None:
+            assert downsample_factors is not None
+            self.downsample_blocks = nn.ModuleList(
+                [
+                    nn.Sequential(
+                        ConvNextV2LayerNorm(dim, data_format="channels_first"),
+                        nn.Conv1d(
+                            dim, dim, kernel_size=downsample_factor, stride=downsample_factor
+                        ),
+                    ) for _, downsample_factor in zip(downsample_layer_indices, downsample_factors)
+                ]
+            )
+            self.downsample_layer_indices = downsample_layer_indices
+        else:
+            self.downsample_blocks = nn.ModuleList()
+            self.downsample_layer_indices = []
+
+        # maybe upsample layers
+        if upsample_layer_indices is not None:
+            assert upsample_factors is not None
+            self.upsample_blocks = nn.ModuleList(
+                [
+                    nn.Sequential(
+                        ConvNextV2LayerNorm(dim, data_format="channels_first"),
+                        nn.ConvTranspose1d(
+                            dim, dim, kernel_size=upsample_factor, stride=upsample_factor
+                        ),
+                    ) for _, upsample_factor in zip(upsample_layer_indices, upsample_factors)
+                ]
+            )
+            self.upsample_layer_indices = upsample_layer_indices
+        else:
+            self.upsample_blocks = nn.ModuleList()
+            self.upsample_layer_indices = []
+
+        # maybe interpolation layers
+        if interpolation_layer_indices is not None:
+            self.interpolation_blocks = nn.ModuleList(
+                [
+                    InterpolationLayer()
+                    for _ in interpolation_layer_indices
+                ]
+            )
+            self.interpolation_layer_indices = interpolation_layer_indices
+        else:
+            self.interpolation_blocks = nn.ModuleList()
+            self.interpolation_layer_indices = []
+
+        # main blocks
+        self.blocks = nn.ModuleList(
+            [
+                ConvNeXtV2Block(
+                    dim=dim,
+                    intermediate_dim=intermediate_dim,
+                    dilation=dilation,
+                )
+                for _ in range(num_blocks)
+            ]
+        )
+        # maybe input and output projections
+        if input_dim is not None and input_dim != dim:
+            self.input_projection = nn.Conv1d(input_dim, dim, kernel_size=1)
+        else:
+            self.input_projection = nn.Identity()
+        if output_dim is not None and output_dim != dim:
+            self.output_projection = nn.Conv1d(dim, output_dim, kernel_size=1)
+        else:
+            self.output_projection = nn.Identity()
+
+        if gin_channels > 0:
+            self.gin = nn.Conv1d(gin_channels, dim, kernel_size=1)
+
+    def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
+        x = self.input_projection(x)  # B, D, T
+        if hasattr(self, 'gin'):
+            g = kwargs['g']
+            x = x + self.gin(g)
+        # pad to a multiple of cumprod(downsample_factors)
+        if len(self.downsample_blocks) > 0:
+            downsample_factor = 1
+            for factor in self.downsample_blocks:
+                downsample_factor *= factor[1].stride[0]
+            pad_len = downsample_factor - x.size(-1) % downsample_factor
+            if pad_len > 0:
+                x = torch.cat([x, torch.zeros_like(x[:, :, :pad_len])], dim=-1)
+
+        # main blocks
+        for layer_idx, block in enumerate(self.blocks):
+            if layer_idx in self.downsample_layer_indices:
+                x = self.downsample_blocks[self.downsample_layer_indices.index(layer_idx)](x)
+            if layer_idx in self.upsample_layer_indices:
+                x = self.upsample_blocks[self.upsample_layer_indices.index(layer_idx)](x)
+            if layer_idx in self.interpolation_layer_indices:
+                x = self.interpolation_blocks[self.interpolation_layer_indices.index(layer_idx)](x, target_len=kwargs['target_len'])
+            x = block(x)
+        x = self.output_projection(x)
+        return x
+
+    def setup_caches(self, *args, **kwargs):
+        pass
+
+
+class ConvNeXtV2Block(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        intermediate_dim: int,
+        dilation: int = 1,
+    ):
+        super().__init__()
+        padding = (dilation * (7 - 1)) // 2
+        self.dwconv = nn.Conv1d(
+            dim, dim, kernel_size=7, padding=padding, groups=dim, dilation=dilation
+        )  # depthwise conv
+        self.norm = ConvNextV2LayerNorm(dim, data_format="channels_first")
+        self.pwconv1 = nn.Linear(
+            dim, intermediate_dim
+        )  # pointwise/1x1 convs, implemented with linear layers
+        self.act = nn.GELU()
+        self.grn = GRN(intermediate_dim)
+        self.pwconv2 = nn.Linear(intermediate_dim, dim)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        residual = x
+        x = self.dwconv(x)
+        x = self.norm(x)
+        x = x.transpose(1, 2)  # b d n -> b n d
+        x = self.pwconv1(x)
+        x = self.act(x)
+        x = self.grn(x)
+        x = self.pwconv2(x)
+        x = x.transpose(1, 2)  # b n d -> b d n
+        return residual + x
\ No newline at end of file
diff --git a/modules/astral_quantization/default_model.py b/modules/astral_quantization/default_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..ca50d1cb1b177a89955c9f13451e5ed159b753b5
--- /dev/null
+++ b/modules/astral_quantization/default_model.py
@@ -0,0 +1,73 @@
+import torch
+from transformers import AutoTokenizer, AutoModel, Wav2Vec2FeatureExtractor
+
+class AstralQuantizer(torch.nn.Module):
+    def __init__(
+            self,
+            tokenizer_name: str,
+            ssl_model_name: str,
+            ssl_output_layer: int,
+            encoder: torch.nn.Module,
+            quantizer: torch.nn.Module,
+            skip_ssl: bool = False,
+    ):
+        super().__init__()
+        self.encoder = encoder
+        self.quantizer = quantizer
+        self.tokenizer_name = tokenizer_name
+        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
+
+        # Load SSL model from Huggingface
+        self.ssl_model_name = ssl_model_name
+        self.ssl_output_layer = ssl_output_layer
+        self.ssl_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(ssl_model_name)
+
+        if skip_ssl:  # in case the same SSL model has been loaded somewhere else
+            self.ssl_model = None
+        else:
+            self.ssl_model = AutoModel.from_pretrained(ssl_model_name).eval()
+            self.ssl_model.encoder.layers = self.ssl_model.encoder.layers[:ssl_output_layer]
+            self.ssl_model.encoder.layer_norm = torch.nn.Identity()
+
+    def load_separate_checkpoint(self, checkpoint_path):
+        params = torch.load(checkpoint_path, map_location='cpu')['net']
+        for key in params.keys():
+            for k in list(params[key].keys()):
+                if k.startswith("module."):
+                    params[key][k[len("module."):]] = params[key][k]
+                    del params[key][k]
+        self.encoder.load_state_dict(params['encoder'])
+        self.quantizer.load_state_dict(params['vq'])
+        if self.decoder is not None:
+            self.decoder.load_state_dict(params['decoder'])
+        if self.asr_decoder is not None:
+            self.asr_decoder.load_state_dict(params['predictor'], strict=False)
+
+    def forward(self, waves_16k, wave_16k_lens, ssl_model=None):
+        ssl_fn = self.ssl_model if self.ssl_model else ssl_model
+        assert ssl_fn is not None, "In case in-class SSL model loading is skipped, external ssl_model must be provided"
+        waves_16k_input_list = [
+            waves_16k[bib, :wave_16k_lens[bib]].cpu().numpy()
+            for bib in range(len(waves_16k))
+        ]
+        alt_inputs = self.ssl_feature_extractor(
+            waves_16k_input_list,
+            return_tensors='pt',
+            return_attention_mask=True,
+            padding=True,
+            sampling_rate=16000
+        ).to(waves_16k.device)
+        feature_lens = alt_inputs.data['attention_mask'].sum(-1) // 320  # frame rate of hubert is 50 Hz
+
+        outputs = ssl_fn(
+            alt_inputs.input_values,
+            attention_mask=alt_inputs.attention_mask,
+        )
+        last_hidden_states = outputs.last_hidden_state
+        last_hidden_states = last_hidden_states[:, :feature_lens.max(), :]
+        feature_lens = feature_lens.clamp(max=last_hidden_states.size(1))
+        last_hidden_states = last_hidden_states.transpose(1, 2)
+        x_hidden = self.encoder(last_hidden_states, feature_lens)
+        x_hidden = x_hidden.transpose(1, 2)
+        x_quantized, indices = self.quantizer(x_hidden)[:2]
+        return x_quantized, indices, feature_lens
\ No newline at end of file
diff --git a/modules/astral_quantization/transformer.py b/modules/astral_quantization/transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..015ec417d02e3442d8ea120c3802807c73e89bb4
--- /dev/null
+++ b/modules/astral_quantization/transformer.py
@@ -0,0 +1,254 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+from dataclasses import dataclass
+from typing import Optional
+
+import torch
+import torch.nn as nn
+from torch import Tensor
+from torch.nn import functional as F
+import time
+
+def find_multiple(n: int, k: int) -> int:
+    if n % k == 0:
+        return n
+    return n + k - (n % k)
+
+class AdaptiveLayerNorm(nn.Module):
+    r"""Adaptive Layer Normalization"""
+
+    def __init__(self, d_model, norm) -> None:
+        super(AdaptiveLayerNorm, self).__init__()
+        self.project_layer = nn.Linear(d_model, 2 * d_model)
+        self.norm = norm
+        self.d_model = d_model
+        self.eps = self.norm.eps
+
+    def forward(self, input: Tensor, embedding: Tensor = None) -> Tensor:
+        if embedding is None:
+            return self.norm(input)
+        weight, bias = torch.split(
+            self.project_layer(embedding),
+            split_size_or_sections=self.d_model,
+            dim=-1,
+        )
+        return weight * self.norm(input) + bias
+
+
+@dataclass
+class ModelArgs:
+    block_size: int = 2048
+    vocab_size: int = 32000
+    n_layer: int = 32
+    n_head: int = 32
+    dim: int = 4096
+    intermediate_size: int = None
+    n_local_heads: int = -1
+    head_dim: int = 64
+    rope_base: float = 10000
+    norm_eps: float = 1e-5
+    has_cross_attention: bool = False
+    context_dim: int = 0
+    is_causal: bool = False
+    dropout_rate: float = 0.1
+    attn_dropout_rate: float = 0.1
+
+    def __post_init__(self):
+        if self.n_local_heads == -1:
+            self.n_local_heads = self.n_head
+        if self.intermediate_size is None:
+            hidden_dim = 4 * self.dim
+            n_hidden = int(2 * hidden_dim / 3)
+            self.intermediate_size = find_multiple(n_hidden, 256)
+        # self.head_dim = self.dim // self.n_head
+
+class Transformer(nn.Module):
+    def __init__(self, config: ModelArgs) -> None:
+        super().__init__()
+        self.config = config
+
+        self.layers = nn.ModuleList(TransformerBlock(config) for _ in range(config.n_layer))
+        self.norm = AdaptiveLayerNorm(config.dim, RMSNorm(config.dim, eps=config.norm_eps))
+
+        self.max_batch_size = -1
+        self.max_seq_length = config.block_size
+        freqs_cis = precompute_freqs_cis(self.config.block_size, self.config.head_dim,
+                                              self.config.rope_base)
+        self.register_buffer("freqs_cis", freqs_cis)
+
+        causal_mask = torch.tril(
+            torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool)
+        )
+        self.register_buffer("causal_mask", causal_mask)
+
+    def forward(self,
+                x: Tensor,
+                c: Tensor,
+                input_pos: Optional[Tensor] = None,
+                mask: Optional[Tensor] = None,
+                context: Optional[Tensor] = None,
+                context_input_pos: Optional[Tensor] = None,
+                cross_attention_mask: Optional[Tensor] = None,
+                ) -> Tensor:
+        if mask is None:
+            mask = self.causal_mask[:x.size(1), :x.size(1)]
+        else:
+            mask = mask[..., input_pos]
+        freqs_cis = self.freqs_cis[input_pos]
+        if context is not None:
+            context_freqs_cis = self.freqs_cis[context_input_pos]
+        else:
+            context_freqs_cis = None
+        skip_in_x_list = []
+        for i, layer in enumerate(self.layers):
+            x = layer(x, c, freqs_cis, mask, context, context_freqs_cis, cross_attention_mask)
+        x = self.norm(x, c)
+        return x
+
+
+class TransformerBlock(nn.Module):
+    def __init__(self, config: ModelArgs) -> None:
+        super().__init__()
+        self.attention = Attention(config)
+        self.feed_forward = FeedForward(config)
+        self.ffn_norm = AdaptiveLayerNorm(config.dim, RMSNorm(config.dim, eps=config.norm_eps))
+        self.attention_norm = AdaptiveLayerNorm(config.dim, RMSNorm(config.dim, eps=config.norm_eps))
+
+        if config.has_cross_attention:
+            self.has_cross_attention = True
+            self.cross_attention = Attention(config, is_cross_attention=True)
+            self.cross_attention_norm = AdaptiveLayerNorm(config.dim, RMSNorm(config.dim, eps=config.norm_eps))
+        else:
+            self.has_cross_attention = False
+
+    def forward(self,
+                x: Tensor,
+                c: Tensor,
+                freqs_cis: Tensor,
+                mask: Tensor,
+                context: Optional[Tensor] = None,
+                context_freqs_cis: Optional[Tensor] = None,
+                cross_attention_mask: Optional[Tensor] = None,
+                ) -> Tensor:
+        #time_attn_start = time.time()
+        h = x + self.attention(self.attention_norm(x, c), freqs_cis, mask)
+        #print(f"time take for attention of sequence length {x.shape[1]} is {time.time() - time_attn_start}")
+        if self.has_cross_attention:
+            h = h + self.cross_attention(self.cross_attention_norm(h, c), freqs_cis, cross_attention_mask, context, context_freqs_cis)
+        out = h + self.feed_forward(self.ffn_norm(h, c))
+        return out
+
+
+class Attention(nn.Module):
+    def __init__(self, config: ModelArgs, is_cross_attention: bool = False):
+        super().__init__()
+        assert config.dim % config.n_head == 0
+
+        total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
+        # key, query, value projections for all heads, but in a batch
+        if is_cross_attention:
+            self.wq = nn.Linear(config.dim, config.n_head * config.head_dim, bias=False)
+            self.wkv = nn.Linear(config.context_dim, 2 * config.n_local_heads * config.head_dim, bias=False)
+        else:
+            self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False)
+        self.wo = nn.Linear(config.head_dim * config.n_head, config.dim, bias=False)
+        self.kv_cache = None
+
+        self.n_head = config.n_head
+        self.head_dim = config.head_dim
+        self.n_local_heads = config.n_local_heads
+        self.dim = config.dim
+        self.attn_dropout_rate = config.attn_dropout_rate
+
+    def forward(self,
+                x: Tensor,
+                freqs_cis: Tensor,
+                mask: Tensor,
+                context: Optional[Tensor] = None,
+                context_freqs_cis: Optional[Tensor] = None,
+                ) -> Tensor:
+        bsz, seqlen, _ = x.shape
+
+        kv_size = self.n_local_heads * self.head_dim
+        if context is None:
+            q, k, v = self.wqkv(x).split([kv_size, kv_size, kv_size], dim=-1)
+            context_seqlen = seqlen
+        else:
+            q = self.wq(x)
+            k, v = self.wkv(context).split([kv_size, kv_size], dim=-1)
+            context_seqlen = context.shape[1]
+
+        q = q.view(bsz, seqlen, self.n_head, self.head_dim)
+        k = k.view(bsz, context_seqlen, self.n_local_heads, self.head_dim)
+        v = v.view(bsz, context_seqlen, self.n_local_heads, self.head_dim)
+
+        q = apply_rotary_emb(q, freqs_cis)
+        k = apply_rotary_emb(k, context_freqs_cis if context_freqs_cis is not None else freqs_cis)
+
+        q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
+
+        k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
+        v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
+        y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=self.attn_dropout_rate if self.training else 0.0)
+
+        y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.head_dim * self.n_head)
+
+        y = self.wo(y)
+        return y
+
+
+class FeedForward(nn.Module):
+    def __init__(self, config: ModelArgs) -> None:
+        super().__init__()
+        self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False)
+        self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False)
+        self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False)
+        self.dropout = nn.Dropout(config.dropout_rate)
+
+    def forward(self, x: Tensor) -> Tensor:
+        return self.w2(self.dropout(F.silu(self.w1(x)) * self.w3(x)))
+
+
+class RMSNorm(nn.Module):
+    def __init__(self, dim: int, eps: float = 1e-5):
+        super().__init__()
+        self.eps = eps
+        self.weight = nn.Parameter(torch.ones(dim))
+
+    def _norm(self, x):
+        return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)
+
+    def forward(self, x: Tensor) -> Tensor:
+        output = self._norm(x.float()).type_as(x)
+        return output * self.weight
+
+
+def precompute_freqs_cis(
+        seq_len: int, n_elem: int, base: int = 10000,
+        dtype: torch.dtype = torch.bfloat16
+) -> Tensor:
+    freqs = 1.0 / (base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem))
+    t = torch.arange(seq_len, device=freqs.device)
+    freqs = torch.outer(t, freqs)
+    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
+    cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
+    return cache.to(dtype=dtype)
+
+
+def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
+    xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
+    freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2)
+    x_out2 = torch.stack(
+        [
+            xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
+            xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
+        ],
+        -1,
+    )
+
+    x_out2 = x_out2.flatten(3)
+    return x_out2.type_as(x)
+
diff --git a/modules/audio.py b/modules/audio.py
index abe783b0e0af630319700c931eb51d2ce375282b..ae677ffb1c124b557b3dbe0343ae415f5281cddb 100644
--- a/modules/audio.py
+++ b/modules/audio.py
@@ -1,82 +1,82 @@
-import numpy as np
-import torch
-import torch.utils.data
-from librosa.filters import mel as librosa_mel_fn
-from scipy.io.wavfile import read
-
-MAX_WAV_VALUE = 32768.0
-
-
-def load_wav(full_path):
-    sampling_rate, data = read(full_path)
-    return data, sampling_rate
-
-
-def dynamic_range_compression(x, C=1, clip_val=1e-5):
-    return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
-
-
-def dynamic_range_decompression(x, C=1):
-    return np.exp(x) / C
-
-
-def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
-    return torch.log(torch.clamp(x, min=clip_val) * C)
-
-
-def dynamic_range_decompression_torch(x, C=1):
-    return torch.exp(x) / C
-
-
-def spectral_normalize_torch(magnitudes):
-    output = dynamic_range_compression_torch(magnitudes)
-    return output
-
-
-def spectral_de_normalize_torch(magnitudes):
-    output = dynamic_range_decompression_torch(magnitudes)
-    return output
-
-
-mel_basis = {}
-hann_window = {}
-
-
-def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
-    if torch.min(y) < -1.0:
-        print("min value is ", torch.min(y))
-    if torch.max(y) > 1.0:
-        print("max value is ", torch.max(y))
-
-    global mel_basis, hann_window  # pylint: disable=global-statement
-    if f"{str(sampling_rate)}_{str(fmax)}_{str(y.device)}" not in mel_basis:
-        mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
-        mel_basis[str(sampling_rate) + "_" + str(fmax) + "_" + str(y.device)] = torch.from_numpy(mel).float().to(y.device)
-        hann_window[str(sampling_rate) + "_" + str(y.device)] = torch.hann_window(win_size).to(y.device)
-
-    y = torch.nn.functional.pad(
-        y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect"
-    )
-    y = y.squeeze(1)
-
-    spec = torch.view_as_real(
-        torch.stft(
-            y,
-            n_fft,
-            hop_length=hop_size,
-            win_length=win_size,
-            window=hann_window[str(sampling_rate) + "_" + str(y.device)],
-            center=center,
-            pad_mode="reflect",
-            normalized=False,
-            onesided=True,
-            return_complex=True,
-        )
-    )
-
-    spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))
-
-    spec = torch.matmul(mel_basis[str(sampling_rate) + "_" + str(fmax) + "_" + str(y.device)], spec)
-    spec = spectral_normalize_torch(spec)
-
-    return spec
+import numpy as np
+import torch
+import torch.utils.data
+from librosa.filters import mel as librosa_mel_fn
+from scipy.io.wavfile import read
+
+MAX_WAV_VALUE = 32768.0
+
+
+def load_wav(full_path):
+    sampling_rate, data = read(full_path)
+    return data, sampling_rate
+
+
+def dynamic_range_compression(x, C=1, clip_val=1e-5):
+    return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
+
+
+def dynamic_range_decompression(x, C=1):
+    return np.exp(x) / C
+
+
+def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
+    return torch.log(torch.clamp(x, min=clip_val) * C)
+
+
+def dynamic_range_decompression_torch(x, C=1):
+    return torch.exp(x) / C
+
+
+def spectral_normalize_torch(magnitudes):
+    output = dynamic_range_compression_torch(magnitudes)
+    return output
+
+
+def spectral_de_normalize_torch(magnitudes):
+    output = dynamic_range_decompression_torch(magnitudes)
+    return output
+
+
+mel_basis = {}
+hann_window = {}
+
+
+def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
+    if torch.min(y) < -1.0:
+        print("min value is ", torch.min(y))
+    if torch.max(y) > 1.0:
+        print("max value is ", torch.max(y))
+
+    global mel_basis, hann_window  # pylint: disable=global-statement
+    if f"{str(sampling_rate)}_{str(fmax)}_{str(y.device)}" not in mel_basis:
+        mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
+        mel_basis[str(sampling_rate) + "_" + str(fmax) + "_" + str(y.device)] = torch.from_numpy(mel).float().to(y.device)
+        hann_window[str(sampling_rate) + "_" + str(y.device)] = torch.hann_window(win_size).to(y.device)
+
+    y = torch.nn.functional.pad(
+        y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect"
+    )
+    y = y.squeeze(1)
+
+    spec = torch.view_as_real(
+        torch.stft(
+            y,
+            n_fft,
+            hop_length=hop_size,
+            win_length=win_size,
+            window=hann_window[str(sampling_rate) + "_" + str(y.device)],
+            center=center,
+            pad_mode="reflect",
+            normalized=False,
+            onesided=True,
+            return_complex=True,
+        )
+    )
+
+    spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))
+
+    spec = torch.matmul(mel_basis[str(sampling_rate) + "_" + str(fmax) + "_" + str(y.device)], spec)
+    spec = spectral_normalize_torch(spec)
+
+    return spec
diff --git a/modules/bigvgan/__pycache__/activations.cpython-310.pyc b/modules/bigvgan/__pycache__/activations.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bedc597b24c34a77805b188e40b71f1d5f221118
Binary files /dev/null and b/modules/bigvgan/__pycache__/activations.cpython-310.pyc differ
diff --git a/modules/bigvgan/__pycache__/bigvgan.cpython-310.pyc b/modules/bigvgan/__pycache__/bigvgan.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..dba1f3aac51089857f5a8226f45af8441fbf65a3
Binary files /dev/null and b/modules/bigvgan/__pycache__/bigvgan.cpython-310.pyc differ
diff --git a/modules/bigvgan/__pycache__/env.cpython-310.pyc b/modules/bigvgan/__pycache__/env.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9c5f5045d1715338beb0335539cb76f229124f54
Binary files /dev/null and b/modules/bigvgan/__pycache__/env.cpython-310.pyc differ
diff --git a/modules/bigvgan/__pycache__/meldataset.cpython-310.pyc b/modules/bigvgan/__pycache__/meldataset.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..58995f652192d5e524e02445b22451a1d8ea87b2
Binary files /dev/null and b/modules/bigvgan/__pycache__/meldataset.cpython-310.pyc differ
diff --git a/modules/bigvgan/__pycache__/utils.cpython-310.pyc b/modules/bigvgan/__pycache__/utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..763d6d7ead834b85c9ea9aa27f306f8d59041fd5
Binary files /dev/null and b/modules/bigvgan/__pycache__/utils.cpython-310.pyc differ
diff --git a/modules/bigvgan/alias_free_activation/cuda/__pycache__/__init__.cpython-310.pyc b/modules/bigvgan/alias_free_activation/cuda/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7692a9a3df20d2fe2b7923398589e05741607fd2
Binary files /dev/null and b/modules/bigvgan/alias_free_activation/cuda/__pycache__/__init__.cpython-310.pyc differ
diff --git a/modules/bigvgan/alias_free_activation/cuda/__pycache__/activation1d.cpython-310.pyc b/modules/bigvgan/alias_free_activation/cuda/__pycache__/activation1d.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1e13bedd060a666d59013e80b43573fad05e3516
Binary files /dev/null and b/modules/bigvgan/alias_free_activation/cuda/__pycache__/activation1d.cpython-310.pyc differ
diff --git a/modules/bigvgan/alias_free_activation/cuda/__pycache__/load.cpython-310.pyc b/modules/bigvgan/alias_free_activation/cuda/__pycache__/load.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..089a26db09f702d6a6984743f4befb52b1c80fbf
Binary files /dev/null and b/modules/bigvgan/alias_free_activation/cuda/__pycache__/load.cpython-310.pyc differ
diff --git a/modules/bigvgan/alias_free_activation/cuda/activation1d.py b/modules/bigvgan/alias_free_activation/cuda/activation1d.py
index fc0d313cb265170943fb7cb16742b031038f7859..76797ef1424b99a160803d53cb6c24fe20599bd2 100644
--- a/modules/bigvgan/alias_free_activation/cuda/activation1d.py
+++ b/modules/bigvgan/alias_free_activation/cuda/activation1d.py
@@ -3,10 +3,10 @@
 
 import torch
 import torch.nn as nn
-from alias_free_activation.torch.resample import UpSample1d, DownSample1d
+from ..torch.resample import UpSample1d, DownSample1d
 
 # load fused CUDA kernel: this enables importing anti_alias_activation_cuda
-from alias_free_activation.cuda import load
+from ..cuda import load
 
 anti_alias_activation_cuda = load.load()
 
diff --git a/modules/bigvgan/alias_free_activation/cuda/build/.ninja_deps b/modules/bigvgan/alias_free_activation/cuda/build/.ninja_deps
new file mode 100644
index 0000000000000000000000000000000000000000..ce9bc6ec886c4bda48c73677998abe0b2b73bfcc
--- /dev/null
+++ b/modules/bigvgan/alias_free_activation/cuda/build/.ninja_deps
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e233713716a5778577f244b0f310944ff26d3079ce0e42491791da7d42e363c1
+size 522068
diff --git a/modules/bigvgan/alias_free_activation/cuda/build/.ninja_log b/modules/bigvgan/alias_free_activation/cuda/build/.ninja_log
new file mode 100644
index 0000000000000000000000000000000000000000..bd3c097a5622dde1a5c17fd152e04750a1dedded
--- /dev/null
+++ b/modules/bigvgan/alias_free_activation/cuda/build/.ninja_log
@@ -0,0 +1,7 @@
+# ninja log v5
+9	39554	7516864785377831	anti_alias_activation.o	3a177f31dd72c43c
+13	152601	7516865914203767	anti_alias_activation_cuda.cuda.o	2d613e7382d803fd
+152628	153062	7516865920541751	anti_alias_activation_cuda.pyd	f6366e9bdfb27f7
+128	50503	7654004565901584	anti_alias_activation.o	9ed3213f2e0d0858
+133	176837	7654005827401976	anti_alias_activation_cuda.cuda.o	a679b6661c609136
+176839	177401	7654005835005523	anti_alias_activation_cuda.pyd	f6366e9bdfb27f7
diff --git a/modules/bigvgan/alias_free_activation/cuda/build/anti_alias_activation.o b/modules/bigvgan/alias_free_activation/cuda/build/anti_alias_activation.o
new file mode 100644
index 0000000000000000000000000000000000000000..812f06975323c9e1937fa01c943e31ae02322145
--- /dev/null
+++ b/modules/bigvgan/alias_free_activation/cuda/build/anti_alias_activation.o
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:74c2824b05582070b69f51ec588aadb268c4fddf18fbb4590f901d1cdf32185c
+size 3246655
diff --git a/modules/bigvgan/alias_free_activation/cuda/build/anti_alias_activation_cuda.cuda.o b/modules/bigvgan/alias_free_activation/cuda/build/anti_alias_activation_cuda.cuda.o
new file mode 100644
index 0000000000000000000000000000000000000000..329fb7a9b147a0af665ff7f7686bd0cc915ecc84
--- /dev/null
+++ b/modules/bigvgan/alias_free_activation/cuda/build/anti_alias_activation_cuda.cuda.o
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:86c48de557041de7ebaff7926b5f346cc5e4e2dddc6cf5b88409f6cb161db0f4
+size 4724513
diff --git a/modules/bigvgan/alias_free_activation/cuda/build/anti_alias_activation_cuda.exp b/modules/bigvgan/alias_free_activation/cuda/build/anti_alias_activation_cuda.exp
new file mode 100644
index 0000000000000000000000000000000000000000..3093a741ef126748042cafaef5c368f3ec5e2d3f
Binary files /dev/null and b/modules/bigvgan/alias_free_activation/cuda/build/anti_alias_activation_cuda.exp differ
diff --git a/modules/bigvgan/alias_free_activation/cuda/build/anti_alias_activation_cuda.lib b/modules/bigvgan/alias_free_activation/cuda/build/anti_alias_activation_cuda.lib
new file mode 100644
index 0000000000000000000000000000000000000000..1be22a5e2a68606c56c961333b67a251bf40d8ea
Binary files /dev/null and b/modules/bigvgan/alias_free_activation/cuda/build/anti_alias_activation_cuda.lib differ
diff --git a/modules/bigvgan/alias_free_activation/cuda/build/anti_alias_activation_cuda.pyd b/modules/bigvgan/alias_free_activation/cuda/build/anti_alias_activation_cuda.pyd
new file mode 100644
index 0000000000000000000000000000000000000000..dc51b91fc3d147e24ad0155ee31809556bb7208a
--- /dev/null
+++ b/modules/bigvgan/alias_free_activation/cuda/build/anti_alias_activation_cuda.pyd
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:db37ea2dd31dfe67e68ee6019877d14638c41724ff9342c55f638f4d2cda3d03
+size 2454528
diff --git a/modules/bigvgan/alias_free_activation/cuda/build/build.ninja b/modules/bigvgan/alias_free_activation/cuda/build/build.ninja
new file mode 100644
index 0000000000000000000000000000000000000000..8c41cf88948be2657b26c226365a86b99278764a
--- /dev/null
+++ b/modules/bigvgan/alias_free_activation/cuda/build/build.ninja
@@ -0,0 +1,38 @@
+ninja_required_version = 1.3
+cxx = cl
+nvcc = C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.8\bin\nvcc
+
+cflags = -DTORCH_EXTENSION_NAME=anti_alias_activation_cuda -DTORCH_API_INCLUDE_EXTENSION_H -ID:\Anaconda\envs\vocos\lib\site-packages\torch\include -ID:\Anaconda\envs\vocos\lib\site-packages\torch\include\torch\csrc\api\include "-IC:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.8\include" -ID:\Anaconda\envs\vocos\Include /std:c++17 -O3 /MD /wd4819 /wd4251 /wd4244 /wd4267 /wd4275 /wd4018 /wd4190 /wd4624 /wd4067 /wd4068 /EHsc
+post_cflags = 
+cuda_cflags = -Xcudafe --diag_suppress=dll_interface_conflict_dllexport_assumed -Xcudafe --diag_suppress=dll_interface_conflict_none_assumed -Xcudafe --diag_suppress=field_without_dll_interface -Xcudafe --diag_suppress=base_class_has_different_dll_interface -Xcompiler /EHsc -Xcompiler /wd4068 -Xcompiler /wd4067 -Xcompiler /wd4624 -Xcompiler /wd4190 -Xcompiler /wd4018 -Xcompiler /wd4275 -Xcompiler /wd4267 -Xcompiler /wd4244 -Xcompiler /wd4251 -Xcompiler /wd4819 -Xcompiler /MD -DTORCH_EXTENSION_NAME=anti_alias_activation_cuda -DTORCH_API_INCLUDE_EXTENSION_H -ID:\Anaconda\envs\vocos\lib\site-packages\torch\include -ID:\Anaconda\envs\vocos\lib\site-packages\torch\include\torch\csrc\api\include "-IC:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.8\include" -ID:\Anaconda\envs\vocos\Include -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr -gencode=arch=compute_86,code=compute_86 -gencode=arch=compute_86,code=sm_86 -std=c++17 -O3 -gencode arch=compute_70,code=sm_70 --use_fast_math -U__CUDA_NO_HALF_OPERATORS__ -U__CUDA_NO_HALF_CONVERSIONS__ --expt-relaxed-constexpr --expt-extended-lambda -gencode arch=compute_80,code=sm_80
+cuda_post_cflags = 
+cuda_dlink_post_cflags = 
+sycl_dlink_post_cflags = 
+ldflags = /DLL c10.lib c10_cuda.lib torch_cpu.lib torch_cuda.lib -INCLUDE:?warp_size@cuda@at@@YAHXZ torch.lib /LIBPATH:D:\Anaconda\envs\vocos\lib\site-packages\torch\lib torch_python.lib /LIBPATH:D:\Anaconda\envs\vocos\libs "/LIBPATH:C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.8\lib\x64" cudart.lib
+
+rule compile
+  command = cl /showIncludes $cflags -c $in /Fo$out $post_cflags
+  deps = msvc
+
+rule cuda_compile
+  depfile = $out.d
+  deps = gcc
+  command = $nvcc --generate-dependencies-with-compile --dependency-output $out.d $cuda_cflags -c $in -o $out $cuda_post_cflags
+
+
+
+
+
+rule link
+  command = "D$:\Visual Studio\VC\Tools\MSVC\14.29.30133\bin\Hostx86\x64/link.exe" $in /nologo $ldflags /out:$out
+
+build anti_alias_activation.o: compile D$:\seed-vc\modules\bigvgan\alias_free_activation\cuda\anti_alias_activation.cpp
+build anti_alias_activation_cuda.cuda.o: cuda_compile D$:\seed-vc\modules\bigvgan\alias_free_activation\cuda\anti_alias_activation_cuda.cu
+
+
+
+
+
+build anti_alias_activation_cuda.pyd: link anti_alias_activation.o anti_alias_activation_cuda.cuda.o
+
+default anti_alias_activation_cuda.pyd
diff --git a/modules/bigvgan/alias_free_activation/torch/__pycache__/__init__.cpython-310.pyc b/modules/bigvgan/alias_free_activation/torch/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..76ee62bc7f4fd61a4da3faf3b5b608082eecd92a
Binary files /dev/null and b/modules/bigvgan/alias_free_activation/torch/__pycache__/__init__.cpython-310.pyc differ
diff --git a/modules/bigvgan/alias_free_activation/torch/__pycache__/act.cpython-310.pyc b/modules/bigvgan/alias_free_activation/torch/__pycache__/act.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6235f3b7476dd85812d5100a2d6ea941bca1280b
Binary files /dev/null and b/modules/bigvgan/alias_free_activation/torch/__pycache__/act.cpython-310.pyc differ
diff --git a/modules/bigvgan/alias_free_activation/torch/__pycache__/filter.cpython-310.pyc b/modules/bigvgan/alias_free_activation/torch/__pycache__/filter.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bb021ff5f7292a9def47b079d56dbfcc5a18f150
Binary files /dev/null and b/modules/bigvgan/alias_free_activation/torch/__pycache__/filter.cpython-310.pyc differ
diff --git a/modules/bigvgan/alias_free_activation/torch/__pycache__/resample.cpython-310.pyc b/modules/bigvgan/alias_free_activation/torch/__pycache__/resample.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..17524739fecc6a6ffecbc9ec37007645cd6b422d
Binary files /dev/null and b/modules/bigvgan/alias_free_activation/torch/__pycache__/resample.cpython-310.pyc differ
diff --git a/modules/bigvgan/bigvgan.py b/modules/bigvgan/bigvgan.py
index 5a1196fa9fc6bca4276e23d5fe659e3f5af9b04a..41d6e44a2cb59a39d51cb4994d05804dc497dda7 100644
--- a/modules/bigvgan/bigvgan.py
+++ b/modules/bigvgan/bigvgan.py
@@ -42,15 +42,15 @@ class AMPBlock1(torch.nn.Module):
     """
 
     def __init__(
-        self,
-        h: AttrDict,
-        channels: int,
-        kernel_size: int = 3,
-        dilation: tuple = (1, 3, 5),
-        activation: str = None,
+            self,
+            h: AttrDict,
+            channels: int,
+            kernel_size: int = 3,
+            dilation: tuple = (1, 3, 5),
+            activation: str = None,
     ):
         super().__init__()
-        
+
         self.h = h
 
         self.convs1 = nn.ModuleList(
@@ -93,7 +93,7 @@ class AMPBlock1(torch.nn.Module):
 
         # Select which Activation1d, lazy-load cuda version to ensure backward compatibility
         if self.h.get("use_cuda_kernel", False):
-            from alias_free_activation.cuda.activation1d import (
+            from .alias_free_activation.cuda.activation1d import (
                 Activation1d as CudaActivation1d,
             )
 
@@ -161,15 +161,15 @@ class AMPBlock2(torch.nn.Module):
     """
 
     def __init__(
-        self,
-        h: AttrDict,
-        channels: int,
-        kernel_size: int = 3,
-        dilation: tuple = (1, 3, 5),
-        activation: str = None,
+            self,
+            h: AttrDict,
+            channels: int,
+            kernel_size: int = 3,
+            dilation: tuple = (1, 3, 5),
+            activation: str = None,
     ):
         super().__init__()
-        
+
         self.h = h
 
         self.convs = nn.ModuleList(
@@ -193,7 +193,7 @@ class AMPBlock2(torch.nn.Module):
 
         # Select which Activation1d, lazy-load cuda version to ensure backward compatibility
         if self.h.get("use_cuda_kernel", False):
-            from alias_free_activation.cuda.activation1d import (
+            from .alias_free_activation.cuda.activation1d import (
                 Activation1d as CudaActivation1d,
             )
 
@@ -270,7 +270,7 @@ class BigVGAN(
 
         # Select which Activation1d, lazy-load cuda version to ensure backward compatibility
         if self.h.get("use_cuda_kernel", False):
-            from alias_free_activation.cuda.activation1d import (
+            from .alias_free_activation.cuda.activation1d import (
                 Activation1d as CudaActivation1d,
             )
 
@@ -304,7 +304,7 @@ class BigVGAN(
                     [
                         weight_norm(
                             ConvTranspose1d(
-                                h.upsample_initial_channel // (2**i),
+                                h.upsample_initial_channel // (2 ** i),
                                 h.upsample_initial_channel // (2 ** (i + 1)),
                                 k,
                                 u,
@@ -320,7 +320,7 @@ class BigVGAN(
         for i in range(len(self.ups)):
             ch = h.upsample_initial_channel // (2 ** (i + 1))
             for j, (k, d) in enumerate(
-                zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)
+                    zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)
             ):
                 self.resblocks.append(
                     resblock_class(h, ch, k, d, activation=h.activation)
@@ -412,20 +412,20 @@ class BigVGAN(
 
     @classmethod
     def _from_pretrained(
-        cls,
-        *,
-        model_id: str,
-        revision: str,
-        cache_dir: str,
-        force_download: bool,
-        proxies: Optional[Dict],
-        resume_download: bool,
-        local_files_only: bool,
-        token: Union[str, bool, None],
-        map_location: str = "cpu",  # Additional argument
-        strict: bool = False,  # Additional argument
-        use_cuda_kernel: bool = False,
-        **model_kwargs,
+            cls,
+            *,
+            model_id: str,
+            revision: str,
+            cache_dir: str,
+            force_download: bool,
+            proxies: Optional[Dict],
+            resume_download: bool,
+            local_files_only: bool,
+            token: Union[str, bool, None],
+            map_location: str = "cpu",  # Additional argument
+            strict: bool = False,  # Additional argument
+            use_cuda_kernel: bool = False,
+            **model_kwargs,
     ):
         """Load Pytorch pretrained weights and return the loaded model."""
 
@@ -489,4 +489,4 @@ class BigVGAN(
             model.remove_weight_norm()
             model.load_state_dict(checkpoint_dict["generator"])
 
-        return model
+        return model
\ No newline at end of file
diff --git a/modules/commons.py b/modules/commons.py
index 350208e50ba8630e53c30847db345cc3ace77473..fb0f6f89ef550e7570beffdef1d438e7dc51259f 100644
--- a/modules/commons.py
+++ b/modules/commons.py
@@ -1,490 +1,476 @@
-import math
-import numpy as np
-import torch
-from torch import nn
-from torch.nn import functional as F
-from munch import Munch
-import json
-
-
-class AttrDict(dict):
-    def __init__(self, *args, **kwargs):
-        super(AttrDict, self).__init__(*args, **kwargs)
-        self.__dict__ = self
-
-
-def init_weights(m, mean=0.0, std=0.01):
-    classname = m.__class__.__name__
-    if classname.find("Conv") != -1:
-        m.weight.data.normal_(mean, std)
-
-
-def get_padding(kernel_size, dilation=1):
-    return int((kernel_size * dilation - dilation) / 2)
-
-
-def convert_pad_shape(pad_shape):
-    l = pad_shape[::-1]
-    pad_shape = [item for sublist in l for item in sublist]
-    return pad_shape
-
-
-def intersperse(lst, item):
-    result = [item] * (len(lst) * 2 + 1)
-    result[1::2] = lst
-    return result
-
-
-def kl_divergence(m_p, logs_p, m_q, logs_q):
-    """KL(P||Q)"""
-    kl = (logs_q - logs_p) - 0.5
-    kl += (
-        0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q)
-    )
-    return kl
-
-
-def rand_gumbel(shape):
-    """Sample from the Gumbel distribution, protect from overflows."""
-    uniform_samples = torch.rand(shape) * 0.99998 + 0.00001
-    return -torch.log(-torch.log(uniform_samples))
-
-
-def rand_gumbel_like(x):
-    g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
-    return g
-
-
-def slice_segments(x, ids_str, segment_size=4):
-    ret = torch.zeros_like(x[:, :, :segment_size])
-    for i in range(x.size(0)):
-        idx_str = ids_str[i]
-        idx_end = idx_str + segment_size
-        ret[i] = x[i, :, idx_str:idx_end]
-    return ret
-
-
-def slice_segments_audio(x, ids_str, segment_size=4):
-    ret = torch.zeros_like(x[:, :segment_size])
-    for i in range(x.size(0)):
-        idx_str = ids_str[i]
-        idx_end = idx_str + segment_size
-        ret[i] = x[i, idx_str:idx_end]
-    return ret
-
-
-def rand_slice_segments(x, x_lengths=None, segment_size=4):
-    b, d, t = x.size()
-    if x_lengths is None:
-        x_lengths = t
-    ids_str_max = x_lengths - segment_size + 1
-    ids_str = ((torch.rand([b]).to(device=x.device) * ids_str_max).clip(0)).to(
-        dtype=torch.long
-    )
-    ret = slice_segments(x, ids_str, segment_size)
-    return ret, ids_str
-
-
-def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
-    position = torch.arange(length, dtype=torch.float)
-    num_timescales = channels // 2
-    log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / (
-        num_timescales - 1
-    )
-    inv_timescales = min_timescale * torch.exp(
-        torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment
-    )
-    scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
-    signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
-    signal = F.pad(signal, [0, 0, 0, channels % 2])
-    signal = signal.view(1, channels, length)
-    return signal
-
-
-def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
-    b, channels, length = x.size()
-    signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
-    return x + signal.to(dtype=x.dtype, device=x.device)
-
-
-def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
-    b, channels, length = x.size()
-    signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
-    return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis)
-
-
-def subsequent_mask(length):
-    mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
-    return mask
-
-
-@torch.jit.script
-def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
-    n_channels_int = n_channels[0]
-    in_act = input_a + input_b
-    t_act = torch.tanh(in_act[:, :n_channels_int, :])
-    s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
-    acts = t_act * s_act
-    return acts
-
-
-def convert_pad_shape(pad_shape):
-    l = pad_shape[::-1]
-    pad_shape = [item for sublist in l for item in sublist]
-    return pad_shape
-
-
-def shift_1d(x):
-    x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
-    return x
-
-
-def sequence_mask(length, max_length=None):
-    if max_length is None:
-        max_length = length.max()
-    x = torch.arange(max_length, dtype=length.dtype, device=length.device)
-    return x.unsqueeze(0) < length.unsqueeze(1)
-
-
-def avg_with_mask(x, mask):
-    assert mask.dtype == torch.float, "Mask should be float"
-
-    if mask.ndim == 2:
-        mask = mask.unsqueeze(1)
-
-    if mask.shape[1] == 1:
-        mask = mask.expand_as(x)
-
-    return (x * mask).sum() / mask.sum()
-
-
-def generate_path(duration, mask):
-    """
-    duration: [b, 1, t_x]
-    mask: [b, 1, t_y, t_x]
-    """
-    device = duration.device
-
-    b, _, t_y, t_x = mask.shape
-    cum_duration = torch.cumsum(duration, -1)
-
-    cum_duration_flat = cum_duration.view(b * t_x)
-    path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
-    path = path.view(b, t_x, t_y)
-    path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
-    path = path.unsqueeze(1).transpose(2, 3) * mask
-    return path
-
-
-def clip_grad_value_(parameters, clip_value, norm_type=2):
-    if isinstance(parameters, torch.Tensor):
-        parameters = [parameters]
-    parameters = list(filter(lambda p: p.grad is not None, parameters))
-    norm_type = float(norm_type)
-    if clip_value is not None:
-        clip_value = float(clip_value)
-
-    total_norm = 0
-    for p in parameters:
-        param_norm = p.grad.data.norm(norm_type)
-        total_norm += param_norm.item() ** norm_type
-        if clip_value is not None:
-            p.grad.data.clamp_(min=-clip_value, max=clip_value)
-    total_norm = total_norm ** (1.0 / norm_type)
-    return total_norm
-
-
-def log_norm(x, mean=-4, std=4, dim=2):
-    """
-    normalized log mel -> mel -> norm -> log(norm)
-    """
-    x = torch.log(torch.exp(x * std + mean).norm(dim=dim))
-    return x
-
-
-def load_F0_models(path):
-    # load F0 model
-    from .JDC.model import JDCNet
-
-    F0_model = JDCNet(num_class=1, seq_len=192)
-    params = torch.load(path, map_location="cpu")["net"]
-    F0_model.load_state_dict(params)
-    _ = F0_model.train()
-
-    return F0_model
-
-
-def modify_w2v_forward(self, output_layer=15):
-    """
-    change forward method of w2v encoder to get its intermediate layer output
-    :param self:
-    :param layer:
-    :return:
-    """
-    from transformers.modeling_outputs import BaseModelOutput
-
-    def forward(
-        hidden_states,
-        attention_mask=None,
-        output_attentions=False,
-        output_hidden_states=False,
-        return_dict=True,
-    ):
-        all_hidden_states = () if output_hidden_states else None
-        all_self_attentions = () if output_attentions else None
-
-        conv_attention_mask = attention_mask
-        if attention_mask is not None:
-            # make sure padded tokens output 0
-            hidden_states = hidden_states.masked_fill(
-                ~attention_mask.bool().unsqueeze(-1), 0.0
-            )
-
-            # extend attention_mask
-            attention_mask = 1.0 - attention_mask[:, None, None, :].to(
-                dtype=hidden_states.dtype
-            )
-            attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min
-            attention_mask = attention_mask.expand(
-                attention_mask.shape[0],
-                1,
-                attention_mask.shape[-1],
-                attention_mask.shape[-1],
-            )
-
-        hidden_states = self.dropout(hidden_states)
-
-        if self.embed_positions is not None:
-            relative_position_embeddings = self.embed_positions(hidden_states)
-        else:
-            relative_position_embeddings = None
-
-        deepspeed_zero3_is_enabled = False
-
-        for i, layer in enumerate(self.layers):
-            if output_hidden_states:
-                all_hidden_states = all_hidden_states + (hidden_states,)
-
-            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
-            dropout_probability = torch.rand([])
-
-            skip_the_layer = (
-                True
-                if self.training and (dropout_probability < self.config.layerdrop)
-                else False
-            )
-            if not skip_the_layer or deepspeed_zero3_is_enabled:
-                # under deepspeed zero3 all gpus must run in sync
-                if self.gradient_checkpointing and self.training:
-                    layer_outputs = self._gradient_checkpointing_func(
-                        layer.__call__,
-                        hidden_states,
-                        attention_mask,
-                        relative_position_embeddings,
-                        output_attentions,
-                        conv_attention_mask,
-                    )
-                else:
-                    layer_outputs = layer(
-                        hidden_states,
-                        attention_mask=attention_mask,
-                        relative_position_embeddings=relative_position_embeddings,
-                        output_attentions=output_attentions,
-                        conv_attention_mask=conv_attention_mask,
-                    )
-                hidden_states = layer_outputs[0]
-
-            if skip_the_layer:
-                layer_outputs = (None, None)
-
-            if output_attentions:
-                all_self_attentions = all_self_attentions + (layer_outputs[1],)
-
-            if i == output_layer - 1:
-                break
-
-        if output_hidden_states:
-            all_hidden_states = all_hidden_states + (hidden_states,)
-
-        if not return_dict:
-            return tuple(
-                v
-                for v in [hidden_states, all_hidden_states, all_self_attentions]
-                if v is not None
-            )
-        return BaseModelOutput(
-            last_hidden_state=hidden_states,
-            hidden_states=all_hidden_states,
-            attentions=all_self_attentions,
-        )
-
-    return forward
-
-
-MATPLOTLIB_FLAG = False
-
-
-def plot_spectrogram_to_numpy(spectrogram):
-    global MATPLOTLIB_FLAG
-    if not MATPLOTLIB_FLAG:
-        import matplotlib
-        import logging
-
-        matplotlib.use("Agg")
-        MATPLOTLIB_FLAG = True
-        mpl_logger = logging.getLogger("matplotlib")
-        mpl_logger.setLevel(logging.WARNING)
-    import matplotlib.pylab as plt
-    import numpy as np
-
-    fig, ax = plt.subplots(figsize=(10, 2))
-    im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
-    plt.colorbar(im, ax=ax)
-    plt.xlabel("Frames")
-    plt.ylabel("Channels")
-    plt.tight_layout()
-
-    fig.canvas.draw()
-    data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="")
-    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
-    plt.close()
-    return data
-
-
-def normalize_f0(f0_sequence):
-    # Remove unvoiced frames (replace with -1)
-    voiced_indices = np.where(f0_sequence > 0)[0]
-    f0_voiced = f0_sequence[voiced_indices]
-
-    # Convert to log scale
-    log_f0 = np.log2(f0_voiced)
-
-    # Calculate mean and standard deviation
-    mean_f0 = np.mean(log_f0)
-    std_f0 = np.std(log_f0)
-
-    # Normalize the F0 sequence
-    normalized_f0 = (log_f0 - mean_f0) / std_f0
-
-    # Create the normalized F0 sequence with unvoiced frames
-    normalized_sequence = np.zeros_like(f0_sequence)
-    normalized_sequence[voiced_indices] = normalized_f0
-    normalized_sequence[f0_sequence <= 0] = -1  # Assign -1 to unvoiced frames
-
-    return normalized_sequence
-
-
-def build_model(args, stage="DiT"):
-    if stage == "DiT":
-        from modules.flow_matching import CFM
-        from modules.length_regulator import InterpolateRegulator
-
-        length_regulator = InterpolateRegulator(
-            channels=args.length_regulator.channels,
-            sampling_ratios=args.length_regulator.sampling_ratios,
-            is_discrete=args.length_regulator.is_discrete,
-            in_channels=args.length_regulator.in_channels if hasattr(args.length_regulator, "in_channels") else None,
-            vector_quantize=args.length_regulator.vector_quantize if hasattr(args.length_regulator, "vector_quantize") else False,
-            codebook_size=args.length_regulator.content_codebook_size,
-            n_codebooks=args.length_regulator.n_codebooks if hasattr(args.length_regulator, "n_codebooks") else 1,
-            quantizer_dropout=args.length_regulator.quantizer_dropout if hasattr(args.length_regulator, "quantizer_dropout") else 0.0,
-            f0_condition=args.length_regulator.f0_condition if hasattr(args.length_regulator, "f0_condition") else False,
-            n_f0_bins=args.length_regulator.n_f0_bins if hasattr(args.length_regulator, "n_f0_bins") else 512,
-        )
-        cfm = CFM(args)
-        nets = Munch(
-            cfm=cfm,
-            length_regulator=length_regulator,
-        )
-    elif stage == 'codec':
-        from dac.model.dac import Encoder
-        from modules.quantize import (
-            FAquantizer,
-        )
-
-        encoder = Encoder(
-            d_model=args.DAC.encoder_dim,
-            strides=args.DAC.encoder_rates,
-            d_latent=1024,
-            causal=args.causal,
-            lstm=args.lstm,
-        )
-
-        quantizer = FAquantizer(
-            in_dim=1024,
-            n_p_codebooks=1,
-            n_c_codebooks=args.n_c_codebooks,
-            n_t_codebooks=2,
-            n_r_codebooks=3,
-            codebook_size=1024,
-            codebook_dim=8,
-            quantizer_dropout=0.5,
-            causal=args.causal,
-            separate_prosody_encoder=args.separate_prosody_encoder,
-            timbre_norm=args.timbre_norm,
-        )
-
-        nets = Munch(
-            encoder=encoder,
-            quantizer=quantizer,
-        )
-    else:
-        raise ValueError(f"Unknown stage: {stage}")
-
-    return nets
-
-
-def load_checkpoint(
-    model,
-    optimizer,
-    path,
-    load_only_params=True,
-    ignore_modules=[],
-    is_distributed=False,
-):
-    state = torch.load(path, map_location="cpu")
-    params = state["net"]
-    for key in model:
-        if key in params and key not in ignore_modules:
-            if not is_distributed:
-                # strip prefix of DDP (module.), create a new OrderedDict that does not contain the prefix
-                for k in list(params[key].keys()):
-                    if k.startswith("module."):
-                        params[key][k[len("module.") :]] = params[key][k]
-                        del params[key][k]
-            model_state_dict = model[key].state_dict()
-            # 过滤出形状匹配的键值对
-            filtered_state_dict = {
-                k: v
-                for k, v in params[key].items()
-                if k in model_state_dict and v.shape == model_state_dict[k].shape
-            }
-            skipped_keys = set(params[key].keys()) - set(filtered_state_dict.keys())
-            if skipped_keys:
-                print(
-                    f"Warning: Skipped loading some keys due to shape mismatch: {skipped_keys}"
-                )
-            print("%s loaded" % key)
-            model[key].load_state_dict(filtered_state_dict, strict=False)
-    _ = [model[key].eval() for key in model]
-
-    if not load_only_params:
-        epoch = state["epoch"] + 1
-        iters = state["iters"]
-        optimizer.load_state_dict(state["optimizer"])
-        optimizer.load_scheduler_state_dict(state["scheduler"])
-
-    else:
-        epoch = 0
-        iters = 0
-
-    return model, optimizer, epoch, iters
-
-
-def recursive_munch(d):
-    if isinstance(d, dict):
-        return Munch((k, recursive_munch(v)) for k, v in d.items())
-    elif isinstance(d, list):
-        return [recursive_munch(v) for v in d]
-    else:
-        return d
+import math
+import numpy as np
+import torch
+from torch import nn
+from torch.nn import functional as F
+from munch import Munch
+import json
+import argparse
+
+def str2bool(v):
+    if isinstance(v, bool):
+        return v
+    if v.lower() in ("yes", "true", "t", "y", "1"):
+        return True
+    elif v.lower() in ("no", "false", "f", "n", "0"):
+        return False
+    else:
+        raise argparse.ArgumentTypeError("Boolean value expected.")
+
+class AttrDict(dict):
+    def __init__(self, *args, **kwargs):
+        super(AttrDict, self).__init__(*args, **kwargs)
+        self.__dict__ = self
+
+
+def init_weights(m, mean=0.0, std=0.01):
+    classname = m.__class__.__name__
+    if classname.find("Conv") != -1:
+        m.weight.data.normal_(mean, std)
+
+
+def get_padding(kernel_size, dilation=1):
+    return int((kernel_size * dilation - dilation) / 2)
+
+
+def convert_pad_shape(pad_shape):
+    l = pad_shape[::-1]
+    pad_shape = [item for sublist in l for item in sublist]
+    return pad_shape
+
+
+def intersperse(lst, item):
+    result = [item] * (len(lst) * 2 + 1)
+    result[1::2] = lst
+    return result
+
+
+def kl_divergence(m_p, logs_p, m_q, logs_q):
+    """KL(P||Q)"""
+    kl = (logs_q - logs_p) - 0.5
+    kl += (
+        0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q)
+    )
+    return kl
+
+
+def rand_gumbel(shape):
+    """Sample from the Gumbel distribution, protect from overflows."""
+    uniform_samples = torch.rand(shape) * 0.99998 + 0.00001
+    return -torch.log(-torch.log(uniform_samples))
+
+
+def rand_gumbel_like(x):
+    g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
+    return g
+
+
+def slice_segments(x, ids_str, segment_size=4):
+    ret = torch.zeros_like(x[:, :, :segment_size])
+    for i in range(x.size(0)):
+        idx_str = ids_str[i]
+        idx_end = idx_str + segment_size
+        ret[i] = x[i, :, idx_str:idx_end]
+    return ret
+
+
+def slice_segments_audio(x, ids_str, segment_size=4):
+    ret = torch.zeros_like(x[:, :segment_size])
+    for i in range(x.size(0)):
+        idx_str = ids_str[i]
+        idx_end = idx_str + segment_size
+        ret[i] = x[i, idx_str:idx_end]
+    return ret
+
+
+def rand_slice_segments(x, x_lengths=None, segment_size=4):
+    b, d, t = x.size()
+    if x_lengths is None:
+        x_lengths = t
+    ids_str_max = x_lengths - segment_size + 1
+    ids_str = ((torch.rand([b]).to(device=x.device) * ids_str_max).clip(0)).to(
+        dtype=torch.long
+    )
+    ret = slice_segments(x, ids_str, segment_size)
+    return ret, ids_str
+
+
+def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
+    position = torch.arange(length, dtype=torch.float)
+    num_timescales = channels // 2
+    log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / (
+        num_timescales - 1
+    )
+    inv_timescales = min_timescale * torch.exp(
+        torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment
+    )
+    scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
+    signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
+    signal = F.pad(signal, [0, 0, 0, channels % 2])
+    signal = signal.view(1, channels, length)
+    return signal
+
+
+def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
+    b, channels, length = x.size()
+    signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
+    return x + signal.to(dtype=x.dtype, device=x.device)
+
+
+def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
+    b, channels, length = x.size()
+    signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
+    return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis)
+
+
+def subsequent_mask(length):
+    mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
+    return mask
+
+
+@torch.jit.script
+def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
+    n_channels_int = n_channels[0]
+    in_act = input_a + input_b
+    t_act = torch.tanh(in_act[:, :n_channels_int, :])
+    s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
+    acts = t_act * s_act
+    return acts
+
+
+def convert_pad_shape(pad_shape):
+    l = pad_shape[::-1]
+    pad_shape = [item for sublist in l for item in sublist]
+    return pad_shape
+
+
+def shift_1d(x):
+    x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
+    return x
+
+
+def sequence_mask(length, max_length=None):
+    if max_length is None:
+        max_length = length.max()
+    x = torch.arange(max_length, dtype=length.dtype, device=length.device)
+    return x.unsqueeze(0) < length.unsqueeze(1)
+
+
+def avg_with_mask(x, mask):
+    assert mask.dtype == torch.float, "Mask should be float"
+
+    if mask.ndim == 2:
+        mask = mask.unsqueeze(1)
+
+    if mask.shape[1] == 1:
+        mask = mask.expand_as(x)
+
+    return (x * mask).sum() / mask.sum()
+
+
+def generate_path(duration, mask):
+    """
+    duration: [b, 1, t_x]
+    mask: [b, 1, t_y, t_x]
+    """
+    device = duration.device
+
+    b, _, t_y, t_x = mask.shape
+    cum_duration = torch.cumsum(duration, -1)
+
+    cum_duration_flat = cum_duration.view(b * t_x)
+    path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
+    path = path.view(b, t_x, t_y)
+    path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
+    path = path.unsqueeze(1).transpose(2, 3) * mask
+    return path
+
+
+def clip_grad_value_(parameters, clip_value, norm_type=2):
+    if isinstance(parameters, torch.Tensor):
+        parameters = [parameters]
+    parameters = list(filter(lambda p: p.grad is not None, parameters))
+    norm_type = float(norm_type)
+    if clip_value is not None:
+        clip_value = float(clip_value)
+
+    total_norm = 0
+    for p in parameters:
+        param_norm = p.grad.data.norm(norm_type)
+        total_norm += param_norm.item() ** norm_type
+        if clip_value is not None:
+            p.grad.data.clamp_(min=-clip_value, max=clip_value)
+    total_norm = total_norm ** (1.0 / norm_type)
+    return total_norm
+
+
+def log_norm(x, mean=-4, std=4, dim=2):
+    """
+    normalized log mel -> mel -> norm -> log(norm)
+    """
+    x = torch.log(torch.exp(x * std + mean).norm(dim=dim))
+    return x
+
+
+def load_F0_models(path):
+    # load F0 model
+    from .JDC.model import JDCNet
+
+    F0_model = JDCNet(num_class=1, seq_len=192)
+    params = torch.load(path, map_location="cpu")["net"]
+    F0_model.load_state_dict(params)
+    _ = F0_model.train()
+
+    return F0_model
+
+
+def modify_w2v_forward(self, output_layer=15):
+    """
+    change forward method of w2v encoder to get its intermediate layer output
+    :param self:
+    :param layer:
+    :return:
+    """
+    from transformers.modeling_outputs import BaseModelOutput
+
+    def forward(
+        hidden_states,
+        attention_mask=None,
+        output_attentions=False,
+        output_hidden_states=False,
+        return_dict=True,
+    ):
+        all_hidden_states = () if output_hidden_states else None
+        all_self_attentions = () if output_attentions else None
+
+        conv_attention_mask = attention_mask
+        if attention_mask is not None:
+            # make sure padded tokens output 0
+            hidden_states = hidden_states.masked_fill(
+                ~attention_mask.bool().unsqueeze(-1), 0.0
+            )
+
+            # extend attention_mask
+            attention_mask = 1.0 - attention_mask[:, None, None, :].to(
+                dtype=hidden_states.dtype
+            )
+            attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min
+            attention_mask = attention_mask.expand(
+                attention_mask.shape[0],
+                1,
+                attention_mask.shape[-1],
+                attention_mask.shape[-1],
+            )
+
+        hidden_states = self.dropout(hidden_states)
+
+        if self.embed_positions is not None:
+            relative_position_embeddings = self.embed_positions(hidden_states)
+        else:
+            relative_position_embeddings = None
+
+        deepspeed_zero3_is_enabled = False
+
+        for i, layer in enumerate(self.layers):
+            if output_hidden_states:
+                all_hidden_states = all_hidden_states + (hidden_states,)
+
+            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
+            dropout_probability = torch.rand([])
+
+            skip_the_layer = (
+                True
+                if self.training and (dropout_probability < self.config.layerdrop)
+                else False
+            )
+            if not skip_the_layer or deepspeed_zero3_is_enabled:
+                # under deepspeed zero3 all gpus must run in sync
+                if self.gradient_checkpointing and self.training:
+                    layer_outputs = self._gradient_checkpointing_func(
+                        layer.__call__,
+                        hidden_states,
+                        attention_mask,
+                        relative_position_embeddings,
+                        output_attentions,
+                        conv_attention_mask,
+                    )
+                else:
+                    layer_outputs = layer(
+                        hidden_states,
+                        attention_mask=attention_mask,
+                        relative_position_embeddings=relative_position_embeddings,
+                        output_attentions=output_attentions,
+                        conv_attention_mask=conv_attention_mask,
+                    )
+                hidden_states = layer_outputs[0]
+
+            if skip_the_layer:
+                layer_outputs = (None, None)
+
+            if output_attentions:
+                all_self_attentions = all_self_attentions + (layer_outputs[1],)
+
+            if i == output_layer - 1:
+                break
+
+        if output_hidden_states:
+            all_hidden_states = all_hidden_states + (hidden_states,)
+
+        if not return_dict:
+            return tuple(
+                v
+                for v in [hidden_states, all_hidden_states, all_self_attentions]
+                if v is not None
+            )
+        return BaseModelOutput(
+            last_hidden_state=hidden_states,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attentions,
+        )
+
+    return forward
+
+
+MATPLOTLIB_FLAG = False
+
+
+def plot_spectrogram_to_numpy(spectrogram):
+    global MATPLOTLIB_FLAG
+    if not MATPLOTLIB_FLAG:
+        import matplotlib
+        import logging
+
+        matplotlib.use("Agg")
+        MATPLOTLIB_FLAG = True
+        mpl_logger = logging.getLogger("matplotlib")
+        mpl_logger.setLevel(logging.WARNING)
+    import matplotlib.pylab as plt
+    import numpy as np
+
+    fig, ax = plt.subplots(figsize=(10, 2))
+    im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
+    plt.colorbar(im, ax=ax)
+    plt.xlabel("Frames")
+    plt.ylabel("Channels")
+    plt.tight_layout()
+
+    fig.canvas.draw()
+    data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="")
+    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
+    plt.close()
+    return data
+
+
+def normalize_f0(f0_sequence):
+    # Remove unvoiced frames (replace with -1)
+    voiced_indices = np.where(f0_sequence > 0)[0]
+    f0_voiced = f0_sequence[voiced_indices]
+
+    # Convert to log scale
+    log_f0 = np.log2(f0_voiced)
+
+    # Calculate mean and standard deviation
+    mean_f0 = np.mean(log_f0)
+    std_f0 = np.std(log_f0)
+
+    # Normalize the F0 sequence
+    normalized_f0 = (log_f0 - mean_f0) / std_f0
+
+    # Create the normalized F0 sequence with unvoiced frames
+    normalized_sequence = np.zeros_like(f0_sequence)
+    normalized_sequence[voiced_indices] = normalized_f0
+    normalized_sequence[f0_sequence <= 0] = -1  # Assign -1 to unvoiced frames
+
+    return normalized_sequence
+
+
+def build_model(args, stage="DiT"):
+    if stage == "DiT":
+        from modules.flow_matching import CFM
+        from modules.length_regulator import InterpolateRegulator
+
+        length_regulator = InterpolateRegulator(
+            channels=args.length_regulator.channels,
+            sampling_ratios=args.length_regulator.sampling_ratios,
+            is_discrete=args.length_regulator.is_discrete,
+            in_channels=args.length_regulator.in_channels if hasattr(args.length_regulator, "in_channels") else None,
+            codebook_size=args.length_regulator.content_codebook_size,
+            f0_condition=args.length_regulator.f0_condition if hasattr(args.length_regulator, "f0_condition") else False,
+            n_f0_bins=args.length_regulator.n_f0_bins if hasattr(args.length_regulator, "n_f0_bins") else 512,
+        )
+        cfm = CFM(args)
+        nets = Munch(
+            cfm=cfm,
+            length_regulator=length_regulator,
+        )
+    else:
+        raise ValueError(f"Unknown stage: {stage}")
+
+    return nets
+
+
+def load_checkpoint(
+    model,
+    optimizer,
+    path,
+    load_only_params=True,
+    ignore_modules=[],
+    is_distributed=False,
+    load_ema=False,
+):
+    state = torch.load(path, map_location="cpu")
+    params = state["net"]
+    if load_ema and "ema" in state:
+        print("Loading EMA")
+        for key in model:
+            i = 0
+            for param_name in params[key]:
+                if "input_pos" in param_name:
+                    continue
+                assert params[key][param_name].shape == state["ema"][key][0][i].shape
+                params[key][param_name] = state["ema"][key][0][i].clone()
+                i += 1
+    for key in model:
+        if key in params and key not in ignore_modules:
+            if not is_distributed:
+                # strip prefix of DDP (module.), create a new OrderedDict that does not contain the prefix
+                for k in list(params[key].keys()):
+                    if k.startswith("module."):
+                        params[key][k[len("module.") :]] = params[key][k]
+                        del params[key][k]
+            model_state_dict = model[key].state_dict()
+            # 过滤出形状匹配的键值对
+            filtered_state_dict = {
+                k: v
+                for k, v in params[key].items()
+                if k in model_state_dict and v.shape == model_state_dict[k].shape
+            }
+            skipped_keys = set(params[key].keys()) - set(filtered_state_dict.keys())
+            if skipped_keys:
+                print(
+                    f"Warning: Skipped loading some keys due to shape mismatch: {skipped_keys}"
+                )
+            print("%s loaded" % key)
+            model[key].load_state_dict(filtered_state_dict, strict=False)
+    _ = [model[key].eval() for key in model]
+
+    if not load_only_params:
+        epoch = state["epoch"] + 1
+        iters = state["iters"]
+        optimizer.load_state_dict(state["optimizer"])
+        optimizer.load_scheduler_state_dict(state["scheduler"])
+
+    else:
+        epoch = 0
+        iters = 0
+
+    return model, optimizer, epoch, iters
+
+
+def recursive_munch(d):
+    if isinstance(d, dict):
+        return Munch((k, recursive_munch(v)) for k, v in d.items())
+    elif isinstance(d, list):
+        return [recursive_munch(v) for v in d]
+    else:
+        return d
diff --git a/modules/diffusion_transformer.py b/modules/diffusion_transformer.py
index b7f40975e52d1cc7944192bff30e2e7341e4fedb..f9b468fa6701a72fab3e55e31dadc814e10c78f1 100644
--- a/modules/diffusion_transformer.py
+++ b/modules/diffusion_transformer.py
@@ -1,240 +1,537 @@
-import torch
-from torch import nn
-import math
-
-from modules.gpt_fast.model import ModelArgs, Transformer
-# from modules.torchscript_modules.gpt_fast_model import ModelArgs, Transformer
-from modules.wavenet import WN
-from modules.commons import sequence_mask
-
-from torch.nn.utils import weight_norm
-
-def modulate(x, shift, scale):
-    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
-
-
-#################################################################################
-#               Embedding Layers for Timesteps and Class Labels                 #
-#################################################################################
-
-class TimestepEmbedder(nn.Module):
-    """
-    Embeds scalar timesteps into vector representations.
-    """
-    def __init__(self, hidden_size, frequency_embedding_size=256):
-        super().__init__()
-        self.mlp = nn.Sequential(
-            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
-            nn.SiLU(),
-            nn.Linear(hidden_size, hidden_size, bias=True),
-        )
-        self.frequency_embedding_size = frequency_embedding_size
-        self.max_period = 10000
-        self.scale = 1000
-
-        half = frequency_embedding_size // 2
-        freqs = torch.exp(
-            -math.log(self.max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
-        )
-        self.register_buffer("freqs", freqs)
-
-    def timestep_embedding(self, t):
-        """
-        Create sinusoidal timestep embeddings.
-        :param t: a 1-D Tensor of N indices, one per batch element.
-                          These may be fractional.
-        :param dim: the dimension of the output.
-        :param max_period: controls the minimum frequency of the embeddings.
-        :return: an (N, D) Tensor of positional embeddings.
-        """
-        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
-
-        args = self.scale * t[:, None].float() * self.freqs[None]
-        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
-        if self.frequency_embedding_size % 2:
-            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
-        return embedding
-
-    def forward(self, t):
-        t_freq = self.timestep_embedding(t)
-        t_emb = self.mlp(t_freq)
-        return t_emb
-
-
-class StyleEmbedder(nn.Module):
-    """
-    Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
-    """
-    def __init__(self, input_size, hidden_size, dropout_prob):
-        super().__init__()
-        use_cfg_embedding = dropout_prob > 0
-        self.embedding_table = nn.Embedding(int(use_cfg_embedding), hidden_size)
-        self.style_in = weight_norm(nn.Linear(input_size, hidden_size, bias=True))
-        self.input_size = input_size
-        self.dropout_prob = dropout_prob
-
-    def forward(self, labels, train, force_drop_ids=None):
-        use_dropout = self.dropout_prob > 0
-        if (train and use_dropout) or (force_drop_ids is not None):
-            labels = self.token_drop(labels, force_drop_ids)
-        else:
-            labels = self.style_in(labels)
-        embeddings = labels
-        return embeddings
-
-class FinalLayer(nn.Module):
-    """
-    The final layer of DiT.
-    """
-    def __init__(self, hidden_size, patch_size, out_channels):
-        super().__init__()
-        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
-        self.linear = weight_norm(nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True))
-        self.adaLN_modulation = nn.Sequential(
-            nn.SiLU(),
-            nn.Linear(hidden_size, 2 * hidden_size, bias=True)
-        )
-
-    def forward(self, x, c):
-        shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
-        x = modulate(self.norm_final(x), shift, scale)
-        x = self.linear(x)
-        return x
-
-class DiT(torch.nn.Module):
-    def __init__(
-        self,
-        args
-    ):
-        super(DiT, self).__init__()
-        self.time_as_token = args.DiT.time_as_token if hasattr(args.DiT, 'time_as_token') else False
-        self.style_as_token = args.DiT.style_as_token if hasattr(args.DiT, 'style_as_token') else False
-        self.uvit_skip_connection = args.DiT.uvit_skip_connection if hasattr(args.DiT, 'uvit_skip_connection') else False
-        model_args = ModelArgs(
-            block_size=16384,#args.DiT.block_size,
-            n_layer=args.DiT.depth,
-            n_head=args.DiT.num_heads,
-            dim=args.DiT.hidden_dim,
-            head_dim=args.DiT.hidden_dim // args.DiT.num_heads,
-            vocab_size=1024,
-            uvit_skip_connection=self.uvit_skip_connection,
-        )
-        self.transformer = Transformer(model_args)
-        self.in_channels = args.DiT.in_channels
-        self.out_channels = args.DiT.in_channels
-        self.num_heads = args.DiT.num_heads
-
-        self.x_embedder = weight_norm(nn.Linear(args.DiT.in_channels, args.DiT.hidden_dim, bias=True))
-
-        self.content_type = args.DiT.content_type  # 'discrete' or 'continuous'
-        self.content_codebook_size = args.DiT.content_codebook_size # for discrete content
-        self.content_dim = args.DiT.content_dim # for continuous content
-        self.cond_embedder = nn.Embedding(args.DiT.content_codebook_size, args.DiT.hidden_dim)  # discrete content
-        self.cond_projection = nn.Linear(args.DiT.content_dim, args.DiT.hidden_dim, bias=True) # continuous content
-
-        self.is_causal = args.DiT.is_causal
-
-        self.n_f0_bins = args.DiT.n_f0_bins
-        self.f0_bins = torch.arange(2, 1024, 1024 // args.DiT.n_f0_bins)
-        self.f0_embedder = nn.Embedding(args.DiT.n_f0_bins, args.DiT.hidden_dim)
-        self.f0_condition = args.DiT.f0_condition
-
-        self.t_embedder = TimestepEmbedder(args.DiT.hidden_dim)
-        self.t_embedder2 = TimestepEmbedder(args.wavenet.hidden_dim)
-        # self.style_embedder1 = weight_norm(nn.Linear(1024, args.DiT.hidden_dim, bias=True))
-        # self.style_embedder2 = weight_norm(nn.Linear(1024, args.style_encoder.dim, bias=True))
-
-        input_pos = torch.arange(16384)
-        self.register_buffer("input_pos", input_pos)
-
-        self.conv1 = nn.Linear(args.DiT.hidden_dim, args.wavenet.hidden_dim)
-        self.conv2 = nn.Conv1d(args.wavenet.hidden_dim, args.DiT.in_channels, 1)
-        self.final_layer_type = args.DiT.final_layer_type  # mlp or wavenet
-        if self.final_layer_type == 'wavenet':
-            self.wavenet = WN(hidden_channels=args.wavenet.hidden_dim,
-                              kernel_size=args.wavenet.kernel_size,
-                              dilation_rate=args.wavenet.dilation_rate,
-                              n_layers=args.wavenet.num_layers,
-                              gin_channels=args.wavenet.hidden_dim,
-                              p_dropout=args.wavenet.p_dropout,
-                              causal=False)
-            self.final_layer = FinalLayer(args.wavenet.hidden_dim, 1, args.wavenet.hidden_dim)
-        else:
-            self.final_mlp = nn.Sequential(
-                    nn.Linear(args.DiT.hidden_dim, args.DiT.hidden_dim),
-                    nn.SiLU(),
-                    nn.Linear(args.DiT.hidden_dim, args.DiT.in_channels),
-            )
-        self.transformer_style_condition = args.DiT.style_condition
-        self.wavenet_style_condition = args.wavenet.style_condition
-        assert args.DiT.style_condition == args.wavenet.style_condition
-
-        self.class_dropout_prob = args.DiT.class_dropout_prob
-        self.content_mask_embedder = nn.Embedding(1, args.DiT.hidden_dim)
-        self.res_projection = nn.Linear(args.DiT.hidden_dim, args.wavenet.hidden_dim)  # residual connection from tranformer output to final output
-        self.long_skip_connection = args.DiT.long_skip_connection
-        self.skip_linear = nn.Linear(args.DiT.hidden_dim + args.DiT.in_channels, args.DiT.hidden_dim)
-
-        self.cond_x_merge_linear = nn.Linear(args.DiT.hidden_dim + args.DiT.in_channels * 2 +
-                                             args.style_encoder.dim * self.transformer_style_condition * (not self.style_as_token),
-                                             args.DiT.hidden_dim)
-        if self.style_as_token:
-            self.style_in = nn.Linear(args.style_encoder.dim, args.DiT.hidden_dim)
-
-    def setup_caches(self, max_batch_size, max_seq_length):
-        self.transformer.setup_caches(max_batch_size, max_seq_length, use_kv_cache=False)
-    def forward(self, x, prompt_x, x_lens, t, style, cond, f0=None, mask_content=False):
-        class_dropout = False
-        if self.training and torch.rand(1) < self.class_dropout_prob:
-            class_dropout = True
-        if not self.training and mask_content:
-            class_dropout = True
-        # cond_in_module = self.cond_embedder if self.content_type == 'discrete' else self.cond_projection
-        cond_in_module = self.cond_projection
-
-        B, _, T = x.size()
-
-
-        t1 = self.t_embedder(t)  # (N, D)
-
-        cond = cond_in_module(cond)
-        if self.f0_condition and f0 is not None:
-            quantized_f0 = torch.bucketize(f0, self.f0_bins.to(f0.device))  # (N, T)
-            cond = cond + self.f0_embedder(quantized_f0)
-
-        x = x.transpose(1, 2)
-        prompt_x = prompt_x.transpose(1, 2)
-
-        x_in = torch.cat([x, prompt_x, cond], dim=-1)
-        if self.transformer_style_condition and not self.style_as_token:
-            x_in = torch.cat([x_in, style[:, None, :].repeat(1, T, 1)], dim=-1)
-        if class_dropout:
-            x_in[..., self.in_channels:] = x_in[..., self.in_channels:] * 0
-        x_in = self.cond_x_merge_linear(x_in)  # (N, T, D)
-
-        if self.style_as_token:
-            style = self.style_in(style)
-            style = torch.zeros_like(style) if class_dropout else style
-            x_in = torch.cat([style.unsqueeze(1), x_in], dim=1)
-        if self.time_as_token:
-            x_in = torch.cat([t1.unsqueeze(1), x_in], dim=1)
-        x_mask = sequence_mask(x_lens + self.style_as_token + self.time_as_token).to(x.device).unsqueeze(1)
-        input_pos = self.input_pos[:x_in.size(1)]  # (T,)
-        x_mask_expanded = x_mask[:, None, :].repeat(1, 1, x_in.size(1), 1) if not self.is_causal else None
-        x_res = self.transformer(x_in, None if self.time_as_token else t1.unsqueeze(1), input_pos, x_mask_expanded)
-        x_res = x_res[:, 1:] if self.time_as_token else x_res
-        x_res = x_res[:, 1:] if self.style_as_token else x_res
-        if self.long_skip_connection:
-            x_res = self.skip_linear(torch.cat([x_res, x], dim=-1))
-        if self.final_layer_type == 'wavenet':
-            x = self.conv1(x_res)
-            x = x.transpose(1, 2)
-            t2 = self.t_embedder2(t)
-            x = self.wavenet(x, x_mask, g=t2.unsqueeze(2)).transpose(1, 2) + self.res_projection(
-                x_res)  # long residual connection
-            x = self.final_layer(x, t1).transpose(1, 2)
-            x = self.conv2(x)
-        else:
-            x = self.final_mlp(x_res)
-            x = x.transpose(1, 2)
-        return x
+import torch
+from torch import nn
+import math
+
+# from modules.torchscript_modules.gpt_fast_model import ModelArgs, Transformer
+from modules.wavenet import WN
+from modules.commons import sequence_mask
+
+from torch.nn.utils import weight_norm
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+from dataclasses import dataclass
+from typing import Optional
+
+import torch
+import torch.nn as nn
+from torch import Tensor
+from torch.nn import functional as F
+
+
+def find_multiple(n: int, k: int) -> int:
+    if n % k == 0:
+        return n
+    return n + k - (n % k)
+
+class AdaptiveLayerNorm(nn.Module):
+    r"""Adaptive Layer Normalization"""
+
+    def __init__(self, d_model, norm) -> None:
+        super(AdaptiveLayerNorm, self).__init__()
+        self.project_layer = nn.Linear(d_model, 2 * d_model)
+        self.norm = norm
+        self.d_model = d_model
+        self.eps = self.norm.eps
+
+    def forward(self, input: Tensor, embedding: Tensor = None) -> Tensor:
+        if embedding is None:
+            return self.norm(input)
+        weight, bias = torch.split(
+            self.project_layer(embedding),
+            split_size_or_sections=self.d_model,
+            dim=-1,
+        )
+        return weight * self.norm(input) + bias
+
+
+@dataclass
+class ModelArgs:
+    block_size: int = 2048
+    vocab_size: int = 32000
+    n_layer: int = 32
+    n_head: int = 32
+    dim: int = 4096
+    intermediate_size: int = None
+    n_local_heads: int = -1
+    head_dim: int = 64
+    rope_base: float = 10000
+    norm_eps: float = 1e-5
+    has_cross_attention: bool = False
+    context_dim: int = 0
+    uvit_skip_connection: bool = False
+    time_as_token: bool = False
+
+    def __post_init__(self):
+        if self.n_local_heads == -1:
+            self.n_local_heads = self.n_head
+        if self.intermediate_size is None:
+            hidden_dim = 4 * self.dim
+            n_hidden = int(2 * hidden_dim / 3)
+            self.intermediate_size = find_multiple(n_hidden, 256)
+        # self.head_dim = self.dim // self.n_head
+
+class Transformer(nn.Module):
+    def __init__(self, config: ModelArgs) -> None:
+        super().__init__()
+        self.config = config
+
+        self.layers = nn.ModuleList(TransformerBlock(config) for _ in range(config.n_layer))
+        self.norm = AdaptiveLayerNorm(config.dim, RMSNorm(config.dim, eps=config.norm_eps))
+
+        self.freqs_cis: Optional[Tensor] = None
+        self.mask_cache: Optional[Tensor] = None
+        self.max_batch_size = -1
+        self.max_seq_length = -1
+
+    def setup_caches(self, max_batch_size, max_seq_length, use_kv_cache=False):
+        if self.max_seq_length >= max_seq_length and self.max_batch_size >= max_batch_size:
+            return
+        head_dim = self.config.dim // self.config.n_head
+        max_seq_length = find_multiple(max_seq_length, 8)
+        self.max_seq_length = max_seq_length
+        self.max_batch_size = max_batch_size
+        dtype = self.norm.project_layer.weight.dtype
+        device = self.norm.project_layer.weight.device
+
+        self.freqs_cis = precompute_freqs_cis(self.config.block_size, self.config.head_dim,
+                                              self.config.rope_base, dtype).to(device)
+        self.causal_mask = torch.tril(torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool)).to(device)
+        self.use_kv_cache = use_kv_cache
+        self.uvit_skip_connection = self.config.uvit_skip_connection
+        if self.uvit_skip_connection:
+            self.layers_emit_skip = [i for i in range(self.config.n_layer) if i < self.config.n_layer // 2]
+            self.layers_receive_skip = [i for i in range(self.config.n_layer) if i > self.config.n_layer // 2]
+        else:
+            self.layers_emit_skip = []
+            self.layers_receive_skip = []
+
+    def forward(self,
+                x: Tensor,
+                c: Tensor,
+                input_pos: Optional[Tensor] = None,
+                mask: Optional[Tensor] = None,
+                context: Optional[Tensor] = None,
+                context_input_pos: Optional[Tensor] = None,
+                cross_attention_mask: Optional[Tensor] = None,
+                ) -> Tensor:
+        assert self.freqs_cis is not None, "Caches must be initialized first"
+        if mask is None: # in case of non-causal model
+            if not self.training and self.use_kv_cache:
+                mask = self.causal_mask[None, None, input_pos]
+            else:
+                mask = self.causal_mask[None, None, input_pos]
+                mask = mask[..., input_pos]
+        freqs_cis = self.freqs_cis[input_pos]
+        if context is not None:
+            context_freqs_cis = self.freqs_cis[context_input_pos]
+        else:
+            context_freqs_cis = None
+        skip_in_x_list = []
+        for i, layer in enumerate(self.layers):
+            if self.uvit_skip_connection and i in self.layers_receive_skip:
+                skip_in_x = skip_in_x_list.pop(-1)
+            else:
+                skip_in_x = None
+            x = layer(x, c, input_pos, freqs_cis, mask, context, context_freqs_cis, cross_attention_mask, skip_in_x)
+            if self.uvit_skip_connection and i in self.layers_emit_skip:
+                skip_in_x_list.append(x)
+        x = self.norm(x, c)
+        return x
+
+    @classmethod
+    def from_name(cls, name: str):
+        return cls(ModelArgs.from_name(name))
+
+
+class TransformerBlock(nn.Module):
+    def __init__(self, config: ModelArgs) -> None:
+        super().__init__()
+        self.attention = Attention(config)
+        self.feed_forward = FeedForward(config)
+        self.ffn_norm = AdaptiveLayerNorm(config.dim, RMSNorm(config.dim, eps=config.norm_eps))
+        self.attention_norm = AdaptiveLayerNorm(config.dim, RMSNorm(config.dim, eps=config.norm_eps))
+
+        if config.has_cross_attention:
+            self.has_cross_attention = True
+            self.cross_attention = Attention(config, is_cross_attention=True)
+            self.cross_attention_norm = AdaptiveLayerNorm(config.dim, RMSNorm(config.dim, eps=config.norm_eps))
+        else:
+            self.has_cross_attention = False
+
+        if config.uvit_skip_connection:
+            self.skip_in_linear = nn.Linear(config.dim * 2, config.dim)
+            self.uvit_skip_connection = True
+        else:
+            self.uvit_skip_connection = False
+
+        self.time_as_token = config.time_as_token
+
+    def forward(self,
+                x: Tensor,
+                c: Tensor,
+                input_pos: Tensor,
+                freqs_cis: Tensor,
+                mask: Tensor,
+                context: Optional[Tensor] = None,
+                context_freqs_cis: Optional[Tensor] = None,
+                cross_attention_mask: Optional[Tensor] = None,
+                skip_in_x: Optional[Tensor] = None,
+                ) -> Tensor:
+        c = None if self.time_as_token else c
+        if self.uvit_skip_connection and skip_in_x is not None:
+            x = self.skip_in_linear(torch.cat([x, skip_in_x], dim=-1))
+        h = x + self.attention(self.attention_norm(x, c), freqs_cis, mask, input_pos)
+        if self.has_cross_attention:
+            h = h + self.cross_attention(self.cross_attention_norm(h, c), freqs_cis, cross_attention_mask, input_pos, context, context_freqs_cis)
+        out = h + self.feed_forward(self.ffn_norm(h, c))
+        return out
+
+
+class Attention(nn.Module):
+    def __init__(self, config: ModelArgs, is_cross_attention: bool = False):
+        super().__init__()
+        assert config.dim % config.n_head == 0
+
+        total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
+        # key, query, value projections for all heads, but in a batch
+        if is_cross_attention:
+            self.wq = nn.Linear(config.dim, config.n_head * config.head_dim, bias=False)
+            self.wkv = nn.Linear(config.context_dim, 2 * config.n_local_heads * config.head_dim, bias=False)
+        else:
+            self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False)
+        self.wo = nn.Linear(config.head_dim * config.n_head, config.dim, bias=False)
+        self.kv_cache = None
+
+        self.n_head = config.n_head
+        self.head_dim = config.head_dim
+        self.n_local_heads = config.n_local_heads
+        self.dim = config.dim
+        # self._register_load_state_dict_pre_hook(self.load_hook)
+
+    # def load_hook(self, state_dict, prefix, *args):
+    #     if prefix + "wq.weight" in state_dict:
+    #         wq = state_dict.pop(prefix + "wq.weight")
+    #         wk = state_dict.pop(prefix + "wk.weight")
+    #         wv = state_dict.pop(prefix + "wv.weight")
+    #         state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])
+
+    def forward(self,
+                x: Tensor,
+                freqs_cis: Tensor,
+                mask: Tensor,
+                input_pos: Optional[Tensor] = None,
+                context: Optional[Tensor] = None,
+                context_freqs_cis: Optional[Tensor] = None,
+                ) -> Tensor:
+        bsz, seqlen, _ = x.shape
+
+        kv_size = self.n_local_heads * self.head_dim
+        if context is None:
+            q, k, v = self.wqkv(x).split([kv_size, kv_size, kv_size], dim=-1)
+            context_seqlen = seqlen
+        else:
+            q = self.wq(x)
+            k, v = self.wkv(context).split([kv_size, kv_size], dim=-1)
+            context_seqlen = context.shape[1]
+
+        q = q.view(bsz, seqlen, self.n_head, self.head_dim)
+        k = k.view(bsz, context_seqlen, self.n_local_heads, self.head_dim)
+        v = v.view(bsz, context_seqlen, self.n_local_heads, self.head_dim)
+
+        q = apply_rotary_emb(q, freqs_cis)
+        k = apply_rotary_emb(k, context_freqs_cis if context_freqs_cis is not None else freqs_cis)
+
+        q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
+
+        if self.kv_cache is not None:
+            k, v = self.kv_cache.update(input_pos, k, v)
+
+        k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
+        v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
+        y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
+
+        y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.head_dim * self.n_head)
+
+        y = self.wo(y)
+        return y
+
+
+class FeedForward(nn.Module):
+    def __init__(self, config: ModelArgs) -> None:
+        super().__init__()
+        self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False)
+        self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False)
+        self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False)
+
+    def forward(self, x: Tensor) -> Tensor:
+        return self.w2(F.silu(self.w1(x)) * self.w3(x))
+
+
+class RMSNorm(nn.Module):
+    def __init__(self, dim: int, eps: float = 1e-5):
+        super().__init__()
+        self.eps = eps
+        self.weight = nn.Parameter(torch.ones(dim))
+
+    def _norm(self, x):
+        return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)
+
+    def forward(self, x: Tensor) -> Tensor:
+        output = self._norm(x.float()).type_as(x)
+        return output * self.weight
+
+
+def precompute_freqs_cis(
+        seq_len: int, n_elem: int, base: int = 10000,
+        dtype: torch.dtype = torch.bfloat16
+) -> Tensor:
+    freqs = 1.0 / (base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem))
+    t = torch.arange(seq_len, device=freqs.device)
+    freqs = torch.outer(t, freqs)
+    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
+    cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
+    return cache.to(dtype=dtype)
+
+
+def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
+    xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
+    freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2)
+    x_out2 = torch.stack(
+        [
+            xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
+            xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
+        ],
+        -1,
+    )
+
+    x_out2 = x_out2.flatten(3)
+    return x_out2.type_as(x)
+
+
+def modulate(x, shift, scale):
+    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
+
+
+#################################################################################
+#               Embedding Layers for Timesteps and Class Labels                 #
+#################################################################################
+
+class TimestepEmbedder(nn.Module):
+    """
+    Embeds scalar timesteps into vector representations.
+    """
+    def __init__(self, hidden_size, frequency_embedding_size=256):
+        super().__init__()
+        self.mlp = nn.Sequential(
+            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
+            nn.SiLU(),
+            nn.Linear(hidden_size, hidden_size, bias=True),
+        )
+        self.frequency_embedding_size = frequency_embedding_size
+        self.max_period = 10000
+        self.scale = 1000
+
+        half = frequency_embedding_size // 2
+        freqs = torch.exp(
+            -math.log(self.max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
+        )
+        self.register_buffer("freqs", freqs)
+
+    def timestep_embedding(self, t):
+        """
+        Create sinusoidal timestep embeddings.
+        :param t: a 1-D Tensor of N indices, one per batch element.
+                          These may be fractional.
+        :param dim: the dimension of the output.
+        :param max_period: controls the minimum frequency of the embeddings.
+        :return: an (N, D) Tensor of positional embeddings.
+        """
+        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
+
+        args = self.scale * t[:, None].float() * self.freqs[None]
+        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+        if self.frequency_embedding_size % 2:
+            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+        return embedding
+
+    def forward(self, t):
+        t_freq = self.timestep_embedding(t)
+        t_emb = self.mlp(t_freq)
+        return t_emb
+
+
+class StyleEmbedder(nn.Module):
+    """
+    Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
+    """
+    def __init__(self, input_size, hidden_size, dropout_prob):
+        super().__init__()
+        use_cfg_embedding = dropout_prob > 0
+        self.embedding_table = nn.Embedding(int(use_cfg_embedding), hidden_size)
+        self.style_in = weight_norm(nn.Linear(input_size, hidden_size, bias=True))
+        self.input_size = input_size
+        self.dropout_prob = dropout_prob
+
+    def forward(self, labels, train, force_drop_ids=None):
+        use_dropout = self.dropout_prob > 0
+        if (train and use_dropout) or (force_drop_ids is not None):
+            labels = self.token_drop(labels, force_drop_ids)
+        else:
+            labels = self.style_in(labels)
+        embeddings = labels
+        return embeddings
+
+class FinalLayer(nn.Module):
+    """
+    The final layer of DiT.
+    """
+    def __init__(self, hidden_size, patch_size, out_channels):
+        super().__init__()
+        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+        self.linear = weight_norm(nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True))
+        self.adaLN_modulation = nn.Sequential(
+            nn.SiLU(),
+            nn.Linear(hidden_size, 2 * hidden_size, bias=True)
+        )
+
+    def forward(self, x, c):
+        shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
+        x = modulate(self.norm_final(x), shift, scale)
+        x = self.linear(x)
+        return x
+
+class DiT(torch.nn.Module):
+    def __init__(
+        self,
+        args
+    ):
+        super(DiT, self).__init__()
+        self.time_as_token = args.DiT.time_as_token if hasattr(args.DiT, 'time_as_token') else False
+        self.style_as_token = args.DiT.style_as_token if hasattr(args.DiT, 'style_as_token') else False
+        self.uvit_skip_connection = args.DiT.uvit_skip_connection if hasattr(args.DiT, 'uvit_skip_connection') else False
+        model_args = ModelArgs(
+            block_size=16384,#args.DiT.block_size,
+            n_layer=args.DiT.depth,
+            n_head=args.DiT.num_heads,
+            dim=args.DiT.hidden_dim,
+            head_dim=args.DiT.hidden_dim // args.DiT.num_heads,
+            vocab_size=1024,
+            uvit_skip_connection=self.uvit_skip_connection,
+            time_as_token=self.time_as_token,
+        )
+        self.transformer = Transformer(model_args)
+        self.in_channels = args.DiT.in_channels
+        self.out_channels = args.DiT.in_channels
+        self.num_heads = args.DiT.num_heads
+
+        self.x_embedder = weight_norm(nn.Linear(args.DiT.in_channels, args.DiT.hidden_dim, bias=True))
+
+        self.content_type = args.DiT.content_type  # 'discrete' or 'continuous'
+        self.content_codebook_size = args.DiT.content_codebook_size # for discrete content
+        self.content_dim = args.DiT.content_dim # for continuous content
+        self.cond_embedder = nn.Embedding(args.DiT.content_codebook_size, args.DiT.hidden_dim)  # discrete content
+        self.cond_projection = nn.Linear(args.DiT.content_dim, args.DiT.hidden_dim, bias=True) # continuous content
+
+        self.is_causal = args.DiT.is_causal
+
+        self.t_embedder = TimestepEmbedder(args.DiT.hidden_dim)
+
+        input_pos = torch.arange(16384)
+        self.register_buffer("input_pos", input_pos)
+
+        self.final_layer_type = args.DiT.final_layer_type  # mlp or wavenet
+        if self.final_layer_type == 'wavenet':
+            self.t_embedder2 = TimestepEmbedder(args.wavenet.hidden_dim)
+            self.conv1 = nn.Linear(args.DiT.hidden_dim, args.wavenet.hidden_dim)
+            self.conv2 = nn.Conv1d(args.wavenet.hidden_dim, args.DiT.in_channels, 1)
+            self.wavenet = WN(hidden_channels=args.wavenet.hidden_dim,
+                              kernel_size=args.wavenet.kernel_size,
+                              dilation_rate=args.wavenet.dilation_rate,
+                              n_layers=args.wavenet.num_layers,
+                              gin_channels=args.wavenet.hidden_dim,
+                              p_dropout=args.wavenet.p_dropout,
+                              causal=False)
+            self.final_layer = FinalLayer(args.wavenet.hidden_dim, 1, args.wavenet.hidden_dim)
+            self.res_projection = nn.Linear(args.DiT.hidden_dim,
+                                            args.wavenet.hidden_dim)  # residual connection from tranformer output to final output
+            self.wavenet_style_condition = args.wavenet.style_condition
+            assert args.DiT.style_condition == args.wavenet.style_condition
+        else:
+            self.final_mlp = nn.Sequential(
+                    nn.Linear(args.DiT.hidden_dim, args.DiT.hidden_dim),
+                    nn.SiLU(),
+                    nn.Linear(args.DiT.hidden_dim, args.DiT.in_channels),
+            )
+        self.transformer_style_condition = args.DiT.style_condition
+
+
+        self.class_dropout_prob = args.DiT.class_dropout_prob
+        self.content_mask_embedder = nn.Embedding(1, args.DiT.hidden_dim)
+
+        self.long_skip_connection = args.DiT.long_skip_connection
+        self.skip_linear = nn.Linear(args.DiT.hidden_dim + args.DiT.in_channels, args.DiT.hidden_dim)
+
+        self.cond_x_merge_linear = nn.Linear(args.DiT.hidden_dim + args.DiT.in_channels * 2 +
+                                             args.style_encoder.dim * self.transformer_style_condition * (not self.style_as_token),
+                                             args.DiT.hidden_dim)
+        if self.style_as_token:
+            self.style_in = nn.Linear(args.style_encoder.dim, args.DiT.hidden_dim)
+
+    def setup_caches(self, max_batch_size, max_seq_length):
+        self.transformer.setup_caches(max_batch_size, max_seq_length, use_kv_cache=False)
+    def forward(self, x, prompt_x, x_lens, t, style, cond, mask_content=False):
+        class_dropout = False
+        if self.training and torch.rand(1) < self.class_dropout_prob:
+            class_dropout = True
+        if not self.training and mask_content:
+            class_dropout = True
+        # cond_in_module = self.cond_embedder if self.content_type == 'discrete' else self.cond_projection
+        cond_in_module = self.cond_projection
+
+        B, _, T = x.size()
+
+
+        t1 = self.t_embedder(t)  # (N, D)
+
+        cond = cond_in_module(cond)
+
+        x = x.transpose(1, 2)
+        prompt_x = prompt_x.transpose(1, 2)
+
+        x_in = torch.cat([x, prompt_x, cond], dim=-1)
+        if self.transformer_style_condition and not self.style_as_token:
+            x_in = torch.cat([x_in, style[:, None, :].repeat(1, T, 1)], dim=-1)
+        if class_dropout:
+            x_in[..., self.in_channels:] = x_in[..., self.in_channels:] * 0
+        x_in = self.cond_x_merge_linear(x_in)  # (N, T, D)
+
+        if self.style_as_token:
+            style = self.style_in(style)
+            style = torch.zeros_like(style) if class_dropout else style
+            x_in = torch.cat([style.unsqueeze(1), x_in], dim=1)
+        if self.time_as_token:
+            x_in = torch.cat([t1.unsqueeze(1), x_in], dim=1)
+        x_mask = sequence_mask(x_lens + self.style_as_token + self.time_as_token).to(x.device).unsqueeze(1)
+        input_pos = self.input_pos[:x_in.size(1)]  # (T,)
+        x_mask_expanded = x_mask[:, None, :].repeat(1, 1, x_in.size(1), 1) if not self.is_causal else None
+        x_res = self.transformer(x_in, t1.unsqueeze(1), input_pos, x_mask_expanded)
+        x_res = x_res[:, 1:] if self.time_as_token else x_res
+        x_res = x_res[:, 1:] if self.style_as_token else x_res
+        if self.long_skip_connection:
+            x_res = self.skip_linear(torch.cat([x_res, x], dim=-1))
+        if self.final_layer_type == 'wavenet':
+            x = self.conv1(x_res)
+            x = x.transpose(1, 2)
+            t2 = self.t_embedder2(t)
+            x = self.wavenet(x, x_mask, g=t2.unsqueeze(2)).transpose(1, 2) + self.res_projection(
+                x_res)  # long residual connection
+            x = self.final_layer(x, t1).transpose(1, 2)
+            x = self.conv2(x)
+        else:
+            x = self.final_mlp(x_res)
+            x = x.transpose(1, 2)
+        return x
\ No newline at end of file
diff --git a/modules/flow_matching.py b/modules/flow_matching.py
index c2581c620f884b7c4b60164729240c310198d74a..61389183c6604edf80a9517ead197dd6aa097740 100644
--- a/modules/flow_matching.py
+++ b/modules/flow_matching.py
@@ -49,6 +49,7 @@ class BASECFM(torch.nn.Module, ABC):
         B, T = mu.size(0), mu.size(1)
         z = torch.randn([B, self.in_channels, T], device=mu.device) * temperature
         t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
+        # t_span = t_span + (-1) * (torch.cos(torch.pi / 2 * t_span) - 1 + t_span)
         return self.solve_euler(z, x_lens, prompt, mu, style, f0, t_span, inference_cfg_rate)
 
     def solve_euler(self, x, x_lens, prompt, mu, style, f0, t_span, inference_cfg_rate=0.5):
@@ -66,7 +67,7 @@ class BASECFM(torch.nn.Module, ABC):
                 shape: (batch_size, spk_emb_dim)
             cond: Not used but kept for future purposes
         """
-        t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
+        t, _, _ = t_span[0], t_span[-1], t_span[1] - t_span[0]
 
         # I am storing this because I can later plot it by putting a debugger here and saving it to a file
         # Or in future might add like a return_all_steps flag
@@ -79,16 +80,28 @@ class BASECFM(torch.nn.Module, ABC):
         if self.zero_prompt_speech_token:
             mu[..., :prompt_len] = 0
         for step in tqdm(range(1, len(t_span))):
-            dphi_dt = self.estimator(x, prompt_x, x_lens, t.unsqueeze(0), style, mu, f0)
-            # Classifier-Free Guidance inference introduced in VoiceBox
+            dt = t_span[step] - t_span[step - 1]
             if inference_cfg_rate > 0:
-                cfg_dphi_dt = self.estimator(
-                    x, torch.zeros_like(prompt_x), x_lens, t.unsqueeze(0),
-                    torch.zeros_like(style),
-                    torch.zeros_like(mu), None
+                # Stack original and CFG (null) inputs for batched processing
+                stacked_prompt_x = torch.cat([prompt_x, torch.zeros_like(prompt_x)], dim=0)
+                stacked_style = torch.cat([style, torch.zeros_like(style)], dim=0)
+                stacked_mu = torch.cat([mu, torch.zeros_like(mu)], dim=0)
+                stacked_x = torch.cat([x, x], dim=0)
+                stacked_t = torch.cat([t.unsqueeze(0), t.unsqueeze(0)], dim=0)
+
+                # Perform a single forward pass for both original and CFG inputs
+                stacked_dphi_dt = self.estimator(
+                    stacked_x, stacked_prompt_x, x_lens, stacked_t, stacked_style, stacked_mu,
                 )
-                dphi_dt = ((1.0 + inference_cfg_rate) * dphi_dt -
-                           inference_cfg_rate * cfg_dphi_dt)
+
+                # Split the output back into the original and CFG components
+                dphi_dt, cfg_dphi_dt = stacked_dphi_dt.chunk(2, dim=0)
+
+                # Apply CFG formula
+                dphi_dt = (1.0 + inference_cfg_rate) * dphi_dt - inference_cfg_rate * cfg_dphi_dt
+            else:
+                dphi_dt = self.estimator(x, prompt_x, x_lens, t.unsqueeze(0), style, mu)
+
             x = x + dt * dphi_dt
             t = t + dt
             sol.append(x)
@@ -97,8 +110,7 @@ class BASECFM(torch.nn.Module, ABC):
             x[:, :, :prompt_len] = 0
 
         return sol[-1]
-
-    def forward(self, x1, x_lens, prompt_lens, mu, style, f0=None):
+    def forward(self, x1, x_lens, prompt_lens, mu, style):
         """Computes diffusion loss
 
         Args:
@@ -134,13 +146,13 @@ class BASECFM(torch.nn.Module, ABC):
             if self.zero_prompt_speech_token:
                 mu[bib, :, :prompt_lens[bib]] = 0
 
-        estimator_out = self.estimator(y, prompt, x_lens, t.squeeze(), style, mu, f0)
+        estimator_out = self.estimator(y, prompt, x_lens, t.squeeze(1).squeeze(1), style, mu, prompt_lens)
         loss = 0
         for bib in range(b):
             loss += self.criterion(estimator_out[bib, :, prompt_lens[bib]:x_lens[bib]], u[bib, :, prompt_lens[bib]:x_lens[bib]])
         loss /= b
 
-        return loss, y
+        return loss, estimator_out + (1 - self.sigma_min) * z
 
 
 
diff --git a/modules/length_regulator.py b/modules/length_regulator.py
index a896c6ced97e409ba657f60af59a2f82e1688e65..8bc875326f8b846a09fbb9602d3ebf3ba6cc3b0f 100644
--- a/modules/length_regulator.py
+++ b/modules/length_regulator.py
@@ -1,141 +1,141 @@
-from typing import Tuple
-import torch
-import torch.nn as nn
-from torch.nn import functional as F
-from modules.commons import sequence_mask
-import numpy as np
-from dac.nn.quantize import VectorQuantize
-
-# f0_bin = 256
-f0_max = 1100.0
-f0_min = 50.0
-f0_mel_min = 1127 * np.log(1 + f0_min / 700)
-f0_mel_max = 1127 * np.log(1 + f0_max / 700)
-
-def f0_to_coarse(f0, f0_bin):
-  f0_mel = 1127 * (1 + f0 / 700).log()
-  a = (f0_bin - 2) / (f0_mel_max - f0_mel_min)
-  b = f0_mel_min * a - 1.
-  f0_mel = torch.where(f0_mel > 0, f0_mel * a - b, f0_mel)
-  # torch.clip_(f0_mel, min=1., max=float(f0_bin - 1))
-  f0_coarse = torch.round(f0_mel).long()
-  f0_coarse = f0_coarse * (f0_coarse > 0)
-  f0_coarse = f0_coarse + ((f0_coarse < 1) * 1)
-  f0_coarse = f0_coarse * (f0_coarse < f0_bin)
-  f0_coarse = f0_coarse + ((f0_coarse >= f0_bin) * (f0_bin - 1))
-  return f0_coarse
-
-class InterpolateRegulator(nn.Module):
-    def __init__(
-            self,
-            channels: int,
-            sampling_ratios: Tuple,
-            is_discrete: bool = False,
-            in_channels: int = None,  # only applies to continuous input
-            vector_quantize: bool = False,  # whether to use vector quantization, only applies to continuous input
-            codebook_size: int = 1024, # for discrete only
-            out_channels: int = None,
-            groups: int = 1,
-            n_codebooks: int = 1,  # number of codebooks
-            quantizer_dropout: float = 0.0,  # dropout for quantizer
-            f0_condition: bool = False,
-            n_f0_bins: int = 512,
-    ):
-        super().__init__()
-        self.sampling_ratios = sampling_ratios
-        out_channels = out_channels or channels
-        model = nn.ModuleList([])
-        if len(sampling_ratios) > 0:
-            self.interpolate = True
-            for _ in sampling_ratios:
-                module = nn.Conv1d(channels, channels, 3, 1, 1)
-                norm = nn.GroupNorm(groups, channels)
-                act = nn.Mish()
-                model.extend([module, norm, act])
-        else:
-            self.interpolate = False
-        model.append(
-            nn.Conv1d(channels, out_channels, 1, 1)
-        )
-        self.model = nn.Sequential(*model)
-        self.embedding = nn.Embedding(codebook_size, channels)
-        self.is_discrete = is_discrete
-
-        self.mask_token = nn.Parameter(torch.zeros(1, channels))
-
-        self.n_codebooks = n_codebooks
-        if n_codebooks > 1:
-            self.extra_codebooks = nn.ModuleList([
-                nn.Embedding(codebook_size, channels) for _ in range(n_codebooks - 1)
-            ])
-            self.extra_codebook_mask_tokens = nn.ParameterList([
-                nn.Parameter(torch.zeros(1, channels)) for _ in range(n_codebooks - 1)
-            ])
-        self.quantizer_dropout = quantizer_dropout
-
-        if f0_condition:
-            self.f0_embedding = nn.Embedding(n_f0_bins, channels)
-            self.f0_condition = f0_condition
-            self.n_f0_bins = n_f0_bins
-            self.f0_bins = torch.arange(2, 1024, 1024 // n_f0_bins)
-            self.f0_mask = nn.Parameter(torch.zeros(1, channels))
-        else:
-            self.f0_condition = False
-
-        if not is_discrete:
-            self.content_in_proj = nn.Linear(in_channels, channels)
-            if vector_quantize:
-                self.vq = VectorQuantize(channels, codebook_size, 8)
-
-    def forward(self, x, ylens=None, n_quantizers=None, f0=None):
-        # apply token drop
-        if self.training:
-            n_quantizers = torch.ones((x.shape[0],)) * self.n_codebooks
-            dropout = torch.randint(1, self.n_codebooks + 1, (x.shape[0],))
-            n_dropout = int(x.shape[0] * self.quantizer_dropout)
-            n_quantizers[:n_dropout] = dropout[:n_dropout]
-            n_quantizers = n_quantizers.to(x.device)
-            # decide whether to drop for each sample in batch
-        else:
-            n_quantizers = torch.ones((x.shape[0],), device=x.device) * (self.n_codebooks if n_quantizers is None else n_quantizers)
-        if self.is_discrete:
-            if self.n_codebooks > 1:
-                assert len(x.size()) == 3
-                x_emb = self.embedding(x[:, 0])
-                for i, emb in enumerate(self.extra_codebooks):
-                    x_emb = x_emb + (n_quantizers > i+1)[..., None, None] * emb(x[:, i+1])
-                    # add mask token if not using this codebook
-                    # x_emb = x_emb + (n_quantizers <= i+1)[..., None, None] * self.extra_codebook_mask_tokens[i]
-                x = x_emb
-            elif self.n_codebooks == 1:
-                if len(x.size()) == 2:
-                    x = self.embedding(x)
-                else:
-                    x = self.embedding(x[:, 0])
-        else:
-            x = self.content_in_proj(x)
-        # x in (B, T, D)
-        mask = sequence_mask(ylens).unsqueeze(-1)
-        if self.interpolate:
-            x = F.interpolate(x.transpose(1, 2).contiguous(), size=ylens.max(), mode='nearest')
-        else:
-            x = x.transpose(1, 2).contiguous()
-            mask = mask[:, :x.size(2), :]
-            ylens = ylens.clamp(max=x.size(2)).long()
-        if self.f0_condition:
-            if f0 is None:
-                x = x + self.f0_mask.unsqueeze(-1)
-            else:
-                #quantized_f0 = torch.bucketize(f0, self.f0_bins.to(f0.device))  # (N, T)
-                quantized_f0 = f0_to_coarse(f0, self.n_f0_bins)
-                quantized_f0 = quantized_f0.clamp(0, self.n_f0_bins - 1).long()
-                f0_emb = self.f0_embedding(quantized_f0)
-                f0_emb = F.interpolate(f0_emb.transpose(1, 2).contiguous(), size=ylens.max(), mode='nearest')
-                x = x + f0_emb
-        out = self.model(x).transpose(1, 2).contiguous()
-        if hasattr(self, 'vq'):
-            out_q, commitment_loss, codebook_loss, codes, out,  = self.vq(out.transpose(1, 2))
-            out_q = out_q.transpose(1, 2)
-            return out_q * mask, ylens, codes, commitment_loss, codebook_loss
-        olens = ylens
-        return out * mask, olens, None, None, None
+from typing import Tuple
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+from modules.commons import sequence_mask
+import numpy as np
+from dac.nn.quantize import VectorQuantize
+
+# f0_bin = 256
+f0_max = 1100.0
+f0_min = 50.0
+f0_mel_min = 1127 * np.log(1 + f0_min / 700)
+f0_mel_max = 1127 * np.log(1 + f0_max / 700)
+
+def f0_to_coarse(f0, f0_bin):
+  f0_mel = 1127 * (1 + f0 / 700).log()
+  a = (f0_bin - 2) / (f0_mel_max - f0_mel_min)
+  b = f0_mel_min * a - 1.
+  f0_mel = torch.where(f0_mel > 0, f0_mel * a - b, f0_mel)
+  # torch.clip_(f0_mel, min=1., max=float(f0_bin - 1))
+  f0_coarse = torch.round(f0_mel).long()
+  f0_coarse = f0_coarse * (f0_coarse > 0)
+  f0_coarse = f0_coarse + ((f0_coarse < 1) * 1)
+  f0_coarse = f0_coarse * (f0_coarse < f0_bin)
+  f0_coarse = f0_coarse + ((f0_coarse >= f0_bin) * (f0_bin - 1))
+  return f0_coarse
+
+class InterpolateRegulator(nn.Module):
+    def __init__(
+            self,
+            channels: int,
+            sampling_ratios: Tuple,
+            is_discrete: bool = False,
+            in_channels: int = None,  # only applies to continuous input
+            vector_quantize: bool = False,  # whether to use vector quantization, only applies to continuous input
+            codebook_size: int = 1024, # for discrete only
+            out_channels: int = None,
+            groups: int = 1,
+            n_codebooks: int = 1,  # number of codebooks
+            quantizer_dropout: float = 0.0,  # dropout for quantizer
+            f0_condition: bool = False,
+            n_f0_bins: int = 512,
+    ):
+        super().__init__()
+        self.sampling_ratios = sampling_ratios
+        out_channels = out_channels or channels
+        model = nn.ModuleList([])
+        if len(sampling_ratios) > 0:
+            self.interpolate = True
+            for _ in sampling_ratios:
+                module = nn.Conv1d(channels, channels, 3, 1, 1)
+                norm = nn.GroupNorm(groups, channels)
+                act = nn.Mish()
+                model.extend([module, norm, act])
+        else:
+            self.interpolate = False
+        model.append(
+            nn.Conv1d(channels, out_channels, 1, 1)
+        )
+        self.model = nn.Sequential(*model)
+        self.embedding = nn.Embedding(codebook_size, channels)
+        self.is_discrete = is_discrete
+
+        self.mask_token = nn.Parameter(torch.zeros(1, channels))
+
+        self.n_codebooks = n_codebooks
+        if n_codebooks > 1:
+            self.extra_codebooks = nn.ModuleList([
+                nn.Embedding(codebook_size, channels) for _ in range(n_codebooks - 1)
+            ])
+            self.extra_codebook_mask_tokens = nn.ParameterList([
+                nn.Parameter(torch.zeros(1, channels)) for _ in range(n_codebooks - 1)
+            ])
+        self.quantizer_dropout = quantizer_dropout
+
+        if f0_condition:
+            self.f0_embedding = nn.Embedding(n_f0_bins, channels)
+            self.f0_condition = f0_condition
+            self.n_f0_bins = n_f0_bins
+            self.f0_bins = torch.arange(2, 1024, 1024 // n_f0_bins)
+            self.f0_mask = nn.Parameter(torch.zeros(1, channels))
+        else:
+            self.f0_condition = False
+
+        if not is_discrete:
+            self.content_in_proj = nn.Linear(in_channels, channels)
+            if vector_quantize:
+                self.vq = VectorQuantize(channels, codebook_size, 8)
+
+    def forward(self, x, ylens=None, n_quantizers=None, f0=None):
+        # apply token drop
+        if self.training:
+            n_quantizers = torch.ones((x.shape[0],)) * self.n_codebooks
+            dropout = torch.randint(1, self.n_codebooks + 1, (x.shape[0],))
+            n_dropout = int(x.shape[0] * self.quantizer_dropout)
+            n_quantizers[:n_dropout] = dropout[:n_dropout]
+            n_quantizers = n_quantizers.to(x.device)
+            # decide whether to drop for each sample in batch
+        else:
+            n_quantizers = torch.ones((x.shape[0],), device=x.device) * (self.n_codebooks if n_quantizers is None else n_quantizers)
+        if self.is_discrete:
+            if self.n_codebooks > 1:
+                assert len(x.size()) == 3
+                x_emb = self.embedding(x[:, 0])
+                for i, emb in enumerate(self.extra_codebooks):
+                    x_emb = x_emb + (n_quantizers > i+1)[..., None, None] * emb(x[:, i+1])
+                    # add mask token if not using this codebook
+                    # x_emb = x_emb + (n_quantizers <= i+1)[..., None, None] * self.extra_codebook_mask_tokens[i]
+                x = x_emb
+            elif self.n_codebooks == 1:
+                if len(x.size()) == 2:
+                    x = self.embedding(x)
+                else:
+                    x = self.embedding(x[:, 0])
+        else:
+            x = self.content_in_proj(x)
+        # x in (B, T, D)
+        mask = sequence_mask(ylens).unsqueeze(-1)
+        if self.interpolate:
+            x = F.interpolate(x.transpose(1, 2).contiguous(), size=ylens.max(), mode='nearest')
+        else:
+            x = x.transpose(1, 2).contiguous()
+            mask = mask[:, :x.size(2), :]
+            ylens = ylens.clamp(max=x.size(2)).long()
+        if self.f0_condition:
+            if f0 is None:
+                x = x + self.f0_mask.unsqueeze(-1)
+            else:
+                #quantized_f0 = torch.bucketize(f0, self.f0_bins.to(f0.device))  # (N, T)
+                quantized_f0 = f0_to_coarse(f0, self.n_f0_bins)
+                quantized_f0 = quantized_f0.clamp(0, self.n_f0_bins - 1).long()
+                f0_emb = self.f0_embedding(quantized_f0)
+                f0_emb = F.interpolate(f0_emb.transpose(1, 2).contiguous(), size=ylens.max(), mode='nearest')
+                x = x + f0_emb
+        out = self.model(x).transpose(1, 2).contiguous()
+        if hasattr(self, 'vq'):
+            out_q, commitment_loss, codebook_loss, codes, out,  = self.vq(out.transpose(1, 2))
+            out_q = out_q.transpose(1, 2)
+            return out_q * mask, ylens, codes, commitment_loss, codebook_loss
+        olens = ylens
+        return out * mask, olens, None, None, None
diff --git a/modules/rmvpe.py b/modules/rmvpe.py
index 066a9eebdbcb16ab2d9de0e4738ad3575405907f..44ae2e0ec9fde661dd8360d3fd731fe66b5ab51c 100644
--- a/modules/rmvpe.py
+++ b/modules/rmvpe.py
@@ -486,7 +486,13 @@ class RMVPE:
         self.resample_kernel = {}
         self.is_half = is_half
         if device is None:
-            device = "cuda:0" if torch.cuda.is_available() else "cpu"
+            #device = "cuda:0" if torch.cuda.is_available() else "cpu"
+            if torch.cuda.is_available():
+                device = "cuda:0"
+            elif torch.backends.mps.is_available():
+                device = "mps"
+            else:
+                device = "cpu"
         self.device = device
         self.mel_extractor = MelSpectrogram(
             is_half, 128, 16000, 1024, 160, None, 30, 8000
@@ -572,6 +578,37 @@ class RMVPE:
         # t3 = ttime()
         # print("hmvpe:%s\t%s\t%s\t%s"%(t1-t0,t2-t1,t3-t2,t3-t0))
         return f0
+    def infer_from_audio_batch(self, audio, thred=0.03):
+        # torch.cuda.synchronize()
+        # t0 = ttime()
+        if not torch.is_tensor(audio):
+            audio = torch.from_numpy(audio)
+        mel = self.mel_extractor(
+            audio.float().to(self.device), center=True
+        )
+        # print(123123123,mel.device.type)
+        # torch.cuda.synchronize()
+        # t1 = ttime()
+        hidden = self.mel2hidden(mel)
+        # torch.cuda.synchronize()
+        # t2 = ttime()
+        # print(234234,hidden.device.type)
+        if "privateuseone" not in str(self.device):
+            hidden = hidden.cpu().numpy()
+        else:
+            pass
+        if self.is_half == True:
+            hidden = hidden.astype("float32")
+
+        f0s = []
+        for bib in range(hidden.shape[0]):
+            f0s.append(self.decode(hidden[bib], thred=thred))
+        f0s = np.stack(f0s)
+        f0s = torch.from_numpy(f0s).to(self.device)
+        # torch.cuda.synchronize()
+        # t3 = ttime()
+        # print("hmvpe:%s\t%s\t%s\t%s"%(t1-t0,t2-t1,t3-t2,t3-t0))
+        return f0s
 
     def to_local_average_cents(self, salience, thred=0.05):
         # t0 = ttime()
diff --git a/modules/v2/__pycache__/ar.cpython-310.pyc b/modules/v2/__pycache__/ar.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4ef2289697c46cebf33dec7e22ad3832e22768d5
Binary files /dev/null and b/modules/v2/__pycache__/ar.cpython-310.pyc differ
diff --git a/modules/v2/__pycache__/cfm.cpython-310.pyc b/modules/v2/__pycache__/cfm.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fff9c2cda4430b1f3849eb7e0941491130030e32
Binary files /dev/null and b/modules/v2/__pycache__/cfm.cpython-310.pyc differ
diff --git a/modules/v2/__pycache__/dit_model.cpython-310.pyc b/modules/v2/__pycache__/dit_model.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..30550b7c8fa982077b5988cb11ecebcffa45ea8a
Binary files /dev/null and b/modules/v2/__pycache__/dit_model.cpython-310.pyc differ
diff --git a/modules/v2/__pycache__/dit_wrapper.cpython-310.pyc b/modules/v2/__pycache__/dit_wrapper.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7c16d4408efe4500b7caf352370c03562b0a6b43
Binary files /dev/null and b/modules/v2/__pycache__/dit_wrapper.cpython-310.pyc differ
diff --git a/modules/v2/__pycache__/length_regulator.cpython-310.pyc b/modules/v2/__pycache__/length_regulator.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d20e70c5fa10e54c6dd8abba283d448e34b66b96
Binary files /dev/null and b/modules/v2/__pycache__/length_regulator.cpython-310.pyc differ
diff --git a/modules/v2/__pycache__/model.cpython-310.pyc b/modules/v2/__pycache__/model.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..202b549ed99cfcf71724feae469c648534d13b30
Binary files /dev/null and b/modules/v2/__pycache__/model.cpython-310.pyc differ
diff --git a/modules/v2/__pycache__/vc_wrapper.cpython-310.pyc b/modules/v2/__pycache__/vc_wrapper.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a5ceaa1e09ce09776cb8e26c81e94ccf4850512d
Binary files /dev/null and b/modules/v2/__pycache__/vc_wrapper.cpython-310.pyc differ
diff --git a/modules/v2/ar.py b/modules/v2/ar.py
new file mode 100644
index 0000000000000000000000000000000000000000..1c38b99e14c0aa93bf4a908c231437904886ad8e
--- /dev/null
+++ b/modules/v2/ar.py
@@ -0,0 +1,763 @@
+import dataclasses
+import json
+import math
+from collections import OrderedDict
+from functools import partial, wraps
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Optional, Tuple, List
+from tqdm import tqdm
+
+import torch
+import torch.nn as nn
+from einops import rearrange
+from torch import Tensor
+from torch.nn import functional as F
+from torch.utils.checkpoint import checkpoint
+
+
+def find_multiple(n: int, k: int) -> int:
+    if n % k == 0:
+        return n
+    return n + k - (n % k)
+
+def l2norm(t, groups = 1):
+    t = rearrange(t, '... (g d) -> ... g d', g = groups)
+    t = F.normalize(t, p = 2, dim = -1)
+    return rearrange(t, '... g d -> ... (g d)')
+
+@dataclass
+class BaseModelArgs:
+    model_type: str = "base"
+
+    vocab_size: int = 32000
+    n_layer: int = 32
+    n_head: int = 32
+    dim: int = 4096
+    intermediate_size: int = None
+    n_local_heads: int = -1
+    head_dim: int = 64
+    rope_base: float = 10000
+    norm_eps: float = 1e-5
+    max_seq_len: int = 4096
+    dropout: float = 0.0
+    tie_word_embeddings: bool = True
+    attention_qkv_bias: bool = False
+
+    # Gradient checkpointing
+    use_gradient_checkpointing: bool = False
+
+    # Initialize the model
+    initializer_range: float = 0.02
+
+    qk_norm: bool = False
+    layerscale: bool = False
+
+    def __post_init__(self):
+        if self.n_local_heads == -1:
+            self.n_local_heads = self.n_head
+        if self.intermediate_size is None:
+            hidden_dim = 4 * self.dim
+            n_hidden = int(2 * hidden_dim / 3)
+            self.intermediate_size = find_multiple(n_hidden, 256)
+        self.head_dim = self.dim // self.n_head
+
+    def save(self, path: str):
+        with open(path, "w") as f:
+            json.dump(self.__dict__, f, indent=4, sort_keys=True, ensure_ascii=False)
+
+
+@dataclass
+class NaiveModelArgs(BaseModelArgs):
+    model_type: str = "naive"
+
+
+class KVCache(nn.Module):
+    def __init__(
+        self, max_batch_size, max_seq_len, n_heads, head_dim, dtype=torch.bfloat16
+    ):
+        super().__init__()
+        cache_shape = (max_batch_size, n_heads, max_seq_len, head_dim)
+        self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype))
+        self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype))
+
+    def update(self, input_pos, k_val, v_val):
+        # input_pos: [S], k_val: [B, H, S, D]
+        assert input_pos.shape[0] == k_val.shape[2]
+
+        k_out = self.k_cache
+        v_out = self.v_cache
+        k_out[:, :, input_pos] = k_val
+        v_out[:, :, input_pos] = v_val
+
+        return k_out, v_out
+
+
+@dataclass
+class TransformerForwardResult:
+    token_logits: Tensor
+    token_targets: Tensor
+
+
+@dataclass
+class BaseTransformerForwardResult:
+    logits: Tensor
+    hidden_states: Tensor
+
+
+class BaseTransformer(nn.Module):
+    def __init__(
+        self,
+        config: BaseModelArgs,
+        init_weights: bool = True,
+    ) -> None:
+        super().__init__()
+        self.config = config
+
+        # Slow transformer
+        self.embeddings = nn.Embedding(
+            config.vocab_size,
+            config.dim,
+        )
+        self.layers = nn.ModuleList(
+            TransformerBlock(config, use_sdpa=True) for _ in range(config.n_layer)
+        )
+        self.norm = RMSNorm(config.dim, eps=config.norm_eps)
+
+        if self.config.tie_word_embeddings is False:
+            self.output = nn.Linear(
+                config.dim,
+                config.vocab_size,
+                bias=False,
+            )
+
+        self.register_buffer(
+            "freqs_cis",
+            precompute_freqs_cis(
+                config.max_seq_len,
+                config.dim // config.n_head,
+                config.rope_base,
+            ),
+            persistent=False,
+        )
+        self.register_buffer(
+            "causal_mask",
+            torch.tril(
+                torch.ones(
+                    config.max_seq_len,
+                    config.max_seq_len,
+                    dtype=torch.bool,
+                )
+            ),
+            persistent=False,
+        )
+
+        self.output = nn.Linear(
+            config.dim,
+            config.vocab_size,
+            bias=False,
+        )
+
+        # For kv cache
+        self.max_batch_size = -1
+        self.max_seq_len = -1
+
+        if init_weights:
+            self.apply(self._init_weights)
+
+    def setup_caches(
+        self, max_batch_size: int, max_seq_len: int, dtype: torch.dtype = torch.bfloat16, device: torch.device = "cuda"
+    ):
+        if self.max_seq_len >= max_seq_len and self.max_batch_size >= max_batch_size:
+            return
+
+        head_dim = self.config.dim // self.config.n_head
+        max_seq_len = find_multiple(max_seq_len, 8)
+        self.max_seq_len = max_seq_len
+        self.max_batch_size = max_batch_size
+
+        for b in self.layers:
+            b.attention.kv_cache = KVCache(
+                max_batch_size,
+                max_seq_len,
+                self.config.n_local_heads,
+                head_dim,
+                dtype=dtype,
+            ).to(device)
+
+    def embed_base(self, x: Tensor, x_lens: Tensor) -> Tensor:
+        for bib in range(x.size(0)):
+            x[bib, x_lens[bib]:] = self.config.vocab_size - 1
+
+        x_emb = self.embeddings(x)
+        return x, x_emb
+
+    def forward(
+        self,
+        inp: Tensor,
+        key_padding_mask: Optional[Tensor] = None,
+        input_pos: Optional[Tensor] = None,
+    ) -> BaseTransformerForwardResult:
+        seq_len = inp.size(1)
+
+        # Here we want to merge the embeddings of the codebooks
+        # x = self.embed(inp)
+        x = inp.clone()
+
+        if input_pos is None:
+            freqs_cis = self.freqs_cis[:seq_len].repeat(inp.size(0), 1, 1, 1)
+        else:
+            freqs_cis = self.freqs_cis[input_pos]
+
+        # Not that the causal mask here follows the definition of scaled_dot_product_attention
+        # That is, FALSE means masked out
+        # To maintain consistency, key_padding_mask use TRUE to mask out
+        mask = None
+        if key_padding_mask is not None:
+            mask = self.causal_mask[None, None, :seq_len, :seq_len]  # (B, N, Q, K)
+            mask = mask & key_padding_mask[:, None, None, :].logical_not()
+
+        for layer in self.layers:
+            if self.config.use_gradient_checkpointing and self.training:
+                x = checkpoint(layer, x, freqs_cis, mask, use_reentrant=True)
+            else:
+                x = layer(x, freqs_cis, mask)
+
+        # We got slow_out here
+        slow_out = self.norm(x)
+
+        if self.config.tie_word_embeddings:
+            token_logits = F.linear(slow_out, self.embeddings.weight)
+        else:
+            token_logits = self.output(slow_out)
+
+        return BaseTransformerForwardResult(
+            logits=token_logits,
+            hidden_states=x,
+        )
+
+    def forward_generate(
+        self,
+        inp: Tensor,
+        input_pos: Optional[Tensor] = None,
+        kv_pos: Optional[Tensor] = None,
+        return_all: bool = False,
+    ) -> BaseTransformerForwardResult:
+        # This is used for generation, optimized for torch compile
+
+        x = inp
+        max_seq_len = self.max_seq_len
+
+        mask = self.causal_mask[None, None, kv_pos, :max_seq_len]  # (B, N, Q, K)
+        freqs_cis = self.freqs_cis[input_pos]
+
+        for layer in self.layers:
+            x = layer(x, freqs_cis, mask, input_pos=kv_pos)
+
+        x = x[:, -1:]
+
+        # We got slow_out here
+        slow_out = self.norm(x)
+
+        token_logits = self.output(slow_out)
+
+        return BaseTransformerForwardResult(
+            logits=token_logits,
+            hidden_states=x,
+        )
+
+    def _init_weights(self, module):
+        std = self.config.initializer_range
+        if isinstance(module, nn.Linear):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
+
+class NaiveTransformer(BaseTransformer):
+    def __init__(self, config: NaiveModelArgs) -> None:
+        super().__init__(config, init_weights=False)
+        self.apply(self._init_weights)
+
+    def forward(
+        self,
+        inp: Tensor,
+        cond_lens: Tensor,
+        target: Tensor,
+        target_lens: Tensor,
+        key_padding_mask: Optional[Tensor] = None,
+        input_pos: Optional[Tensor] = None,
+    ) -> TransformerForwardResult:
+        parent_result = super().forward(
+            inp=inp,
+            key_padding_mask=key_padding_mask,
+            input_pos=input_pos,
+        )
+        token_logits = parent_result.logits
+
+        # construct targets for token_logits
+        token_targets = torch.zeros(token_logits.size(0), token_logits.size(1), dtype=torch.long,
+                                    device=target.device) - 100
+        for bib in range(token_targets.size(0)):
+            token_targets[bib, cond_lens[bib] + 1:cond_lens[bib] + target_lens[bib] + 1] = target[bib, :target_lens[bib]]
+            token_targets[bib, cond_lens[bib] + target_lens[bib] + 1] = self.config.vocab_size - 1
+        return TransformerForwardResult(
+            token_logits=token_logits,
+            token_targets=token_targets,
+        )
+
+    def infer_slow(self, inp: Tensor, input_pos: Optional[Tensor] = None):
+        # no kv cache used
+        parent_result = super().forward(inp, input_pos=input_pos)
+        latent = parent_result.hidden_states[:, -1]
+        base_logits = parent_result.logits[:, -1]
+        base_sampled, _ = topk_sampling(base_logits, top_k=-1, top_p=1.0)
+        return base_sampled
+
+    def forward_generate(
+        self,
+        x: Tensor,
+        input_pos: Optional[Tensor] = None,
+        kv_pos: Optional[Tensor] = None,
+        vq_masks: Optional[Tensor] = None,
+    ) -> TransformerForwardResult:
+        x = super().forward_generate(x, input_pos, kv_pos, vq_masks)
+        return x
+
+class NaiveWrapper(nn.Module):
+    def __init__(self, model: NaiveTransformer) -> None:
+        super().__init__()
+        self.model = model
+        self.sep_token_emb = nn.Parameter(torch.randn(model.config.dim))
+
+    def setup_caches(self, max_batch_size: int, max_seq_len: int, dtype: torch.dtype = torch.bfloat16, device: torch.device = "cuda"):
+        self.model.setup_caches(max_batch_size, max_seq_len, dtype, device)
+
+    def forward(self, cond: Tensor, cond_lens: Tensor, x: Tensor, x_lens: Tensor) -> torch.Tensor:
+        # style_emb = self.style_in(style).unsqueeze(1)  #  [B, 1, D]
+        sep_token_emb = self.sep_token_emb.expand(x.size(0), 1, -1)
+        _, x_emb = self.model.embed_base(x, x_lens)
+        emb_seq_list = []
+        for i in range(x.size(0)):
+            emb_seq = torch.cat([
+                sep_token_emb[i:i + 1],
+                cond[i:i+1, :cond_lens[i]],
+                sep_token_emb[i:i+1],
+                x_emb[i:i+1, :x_lens[i]]], dim=1)
+            emb_seq_list.append(emb_seq)
+        max_len = max([emb_seq.size(1) for emb_seq in emb_seq_list])
+        emb_seq = torch.cat([
+            F.pad(emb_seq, (0, 0, 0, max_len - emb_seq.size(1)), value=0)
+            for emb_seq in emb_seq_list
+        ], dim=0)
+        # input_pos = torch.arange(emb_seq.size(1), device=emb_seq.device).repeat(emb_seq.size(0), 1)
+        input_pos = torch.zeros(emb_seq.size(0), emb_seq.size(1), device=emb_seq.device, dtype=torch.long)
+        for i in range(x.size(0)):
+            input_pos[i, :cond_lens[i] + 1] = torch.arange(cond_lens[i] + 1, device=emb_seq.device)
+            input_pos[i, cond_lens[i] + 1: cond_lens[i] + x_lens[i] + 2] = torch.arange(x_lens[i] + 1, device=emb_seq.device)
+        out = self.model(emb_seq, cond_lens, x, x_lens, input_pos=input_pos)
+        loss = F.cross_entropy(out.token_logits.transpose(1, 2), out.token_targets.long(), ignore_index=-100)
+        return loss
+
+    @torch.no_grad()
+    def infer(self, cond: Tensor) -> torch.Tensor:
+        sep_token_emb = self.sep_token_emb.expand(1, 1, -1)
+        emb_seq = torch.cat([sep_token_emb, cond, sep_token_emb], dim=1)
+        pred_codes = []
+        input_pos = torch.arange(cond.size(1) + 1, device=cond.device)
+        for i in tqdm(range(4000)):
+            input_pos = torch.cat([input_pos, torch.LongTensor([i]).to(cond.device)], dim=0)
+            base = self.model.infer_slow(emb_seq, input_pos)
+            if base == self.model.config.vocab_size - 1:
+                break
+            new_emb = self.model.embed_base(base, torch.LongTensor([1]).to(base.device))[1]
+            emb_seq = torch.cat([emb_seq, new_emb], dim=1)
+            pred_codes.append(base)
+        return torch.cat(pred_codes, dim=-1)
+
+    @torch.no_grad()
+    def generate(
+            self,
+            prompt_text,
+            prompt_target,
+            compiled_decode_fn = None,
+            **sampling_kwargs,
+    ):
+        sep_token_emb = self.sep_token_emb.expand(1, 1, -1)
+        emb_seq = torch.cat([sep_token_emb, prompt_text, sep_token_emb], dim=1)
+        input_pos = torch.arange(prompt_text.size(1) + 1, device=emb_seq.device)
+        input_pos = torch.cat([input_pos, torch.LongTensor([0]).to(emb_seq.device)])
+        prompt_target_emb = self.model.embed_base(prompt_target,torch.LongTensor([prompt_target.size(1)]).to(prompt_target.device))[1]
+        emb_seq = torch.cat([emb_seq, prompt_target_emb], dim=1)
+        input_pos = torch.cat([input_pos, torch.arange(prompt_target_emb.size(1)).to(input_pos.device) + 1])
+
+        pred_codes = []
+        kv_pos = torch.arange(emb_seq.size(1), device=emb_seq.device)
+        next_tokens = self.decode_one_token_ar(emb_seq, input_pos, kv_pos, suppress_tokens=[self.model.config.vocab_size - 1], **sampling_kwargs)
+        pred_base = next_tokens[0]
+        pred_codes.append(pred_base)
+        new_emb = self.model.embed_base(pred_base.unsqueeze(0), torch.LongTensor([1]).to(pred_base.device))[1]
+        emb_seq = torch.cat([emb_seq, new_emb], dim=1)
+        for _ in tqdm(range(4000)):
+            suppress_eos = len(pred_codes) < 10
+            input_pos = input_pos[-1:] + 1
+            kv_pos = kv_pos[-1:] + 1
+            next_tokens = self.decode_one_token_ar(
+                emb_seq[:, -1:].reshape(1, 1, -1),
+                input_pos.reshape(1),
+                kv_pos.reshape(1),
+                previous_tokens=torch.cat(pred_codes),
+                suppress_tokens=[self.model.config.vocab_size - 1] if suppress_eos else None,
+                compiled_decode_fn=compiled_decode_fn,
+                **sampling_kwargs)
+            pred_base = next_tokens[0]
+            if pred_base == self.model.config.vocab_size - 1:
+                break
+            pred_codes.append(pred_base.clone())
+            new_emb = self.model.embed_base(pred_base.unsqueeze(0), torch.LongTensor([1]).to(pred_base.device))[1]
+            emb_seq = torch.cat([emb_seq, new_emb], dim=1)
+        return torch.stack(pred_codes, dim=-1)
+
+    def decode_one_token_ar(
+            self,
+            x: torch.Tensor,
+            input_pos: torch.Tensor,
+            kv_pos: torch.Tensor,
+            previous_tokens: torch.Tensor = None,
+            compiled_decode_fn = None,
+            **sampling_kwargs,
+    ) -> torch.Tensor:
+        if compiled_decode_fn is not None:
+            x = compiled_decode_fn(x, input_pos, kv_pos)
+        else:
+            x = self.model.forward_generate(x, input_pos, kv_pos)
+
+        sampling_kwargs_main = sampling_kwargs.copy()
+        codebooks = [
+            sample(
+                x.logits,
+                previous_tokens=(
+                    previous_tokens[0] if previous_tokens is not None else None
+                ),
+                **sampling_kwargs_main,
+            )[0]
+        ]
+        codebooks = torch.stack(codebooks, dim=0)
+        return codebooks
+
+class TransformerBlock(nn.Module):
+    def __init__(self, config: BaseModelArgs, use_sdpa: bool = True) -> None:
+        super().__init__()
+        self.attention = Attention(config, use_sdpa=use_sdpa)
+        self.feed_forward = FeedForward(config)
+        self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
+        self.attention_norm = RMSNorm(config.dim, config.norm_eps)
+
+    def forward(
+        self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Tensor = None
+    ) -> Tensor:
+        h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos)
+        out = h + self.feed_forward(self.ffn_norm(h))
+        return out
+
+
+class Attention(nn.Module):
+    def __init__(self, config: BaseModelArgs, use_sdpa: bool = True):
+        super().__init__()
+        assert config.dim % config.n_head == 0
+
+        total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
+        # key, query, value projections for all heads, but in a batch
+        self.wqkv = nn.Linear(
+            config.dim, total_head_dim, bias=config.attention_qkv_bias
+        )
+        self.wo = nn.Linear(config.dim, config.dim, bias=False)
+        self.kv_cache = None
+
+        self.dropout = config.dropout
+        self.n_head = config.n_head
+        self.head_dim = config.head_dim
+        self.n_local_heads = config.n_local_heads
+        self.dim = config.dim
+        self.use_sdpa = use_sdpa
+        self._register_load_state_dict_pre_hook(self.load_hook)
+        self.qk_norm = config.qk_norm
+        self.qk_norm_groups = 1
+        self.qk_norm_scale = 10
+        self.qk_norm_dim_scale = False
+        self.qk_norm_q_scale = self.qk_norm_k_scale = 1
+
+        if self.qk_norm and self.qk_norm_dim_scale:
+            self.qk_norm_q_scale = nn.Parameter(torch.ones(self.n_head, 1, self.head_dim))
+            self.qk_norm_k_scale = nn.Parameter(torch.ones(self.n_head, 1, self.head_dim))
+    def load_hook(self, state_dict, prefix, *args):
+        if prefix + "wq.weight" in state_dict:
+            wq = state_dict.pop(prefix + "wq.weight")
+            wk = state_dict.pop(prefix + "wk.weight")
+            wv = state_dict.pop(prefix + "wv.weight")
+            state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])
+
+    def forward(
+        self,
+        x: Tensor,
+        freqs_cis: Tensor,
+        mask: Tensor,
+        input_pos: Optional[Tensor] = None,
+    ) -> Tensor:
+        bsz, seqlen, _ = x.shape
+
+        kv_size = self.n_local_heads * self.head_dim
+        q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)
+
+        q = q.view(bsz, seqlen, self.n_head, self.head_dim)
+        k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim)
+        v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim)
+
+        if self.qk_norm:
+            qk_l2norm = partial(l2norm, groups = self.qk_norm_groups)
+            q, k = map(qk_l2norm, (q, k))
+            scale = self.qk_norm_scale
+
+            q = q * self.qk_norm_q_scale
+            k = k * self.qk_norm_k_scale
+
+        q = apply_rotary_emb(q, freqs_cis)
+        k = apply_rotary_emb(k, freqs_cis)
+
+        q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
+
+        if self.kv_cache is not None:
+            k, v = self.kv_cache.update(input_pos, k, v)
+
+        k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
+        v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
+
+        if self.use_sdpa:
+            if mask is None:
+                y = F.scaled_dot_product_attention(
+                    q,
+                    k,
+                    v,
+                    dropout_p=self.dropout if self.training else 0.0,
+                    is_causal=True,
+                    # No third party attn_mask here to use flash_attention
+                )
+            else:
+                y = F.scaled_dot_product_attention(
+                    q,
+                    k,
+                    v,
+                    attn_mask=mask,
+                    dropout_p=self.dropout if self.training else 0.0,
+                )
+        else:
+            y = self.eq_scaled_dot_product_attention(
+                q,
+                k,
+                v,
+                attn_mask=mask,
+                dropout_p=self.dropout if self.training else 0.0,
+            )
+
+        y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
+
+        return self.wo(y)
+
+    def eq_scaled_dot_product_attention(
+        self,
+        query,
+        key,
+        value,
+        attn_mask=None,
+        dropout_p=0.0,
+    ) -> torch.Tensor:
+        # This is a standard scaled dot product attention
+        # It's low efficient, but it doesn't raise cuda error
+
+        L, S = query.size(-2), key.size(-2)
+        scale_factor = 1 / math.sqrt(query.size(-1))
+        attn_bias = torch.zeros(1, 1, L, S, dtype=query.dtype, device=query.device)
+
+        if attn_mask is not None:
+            if attn_mask.dtype == torch.bool:
+                attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
+            else:
+                attn_bias += attn_mask
+
+        attn_weight = query @ key.transpose(-2, -1) * scale_factor
+        attn_weight += attn_bias
+        attn_weight = torch.softmax(attn_weight, dim=-1)
+        attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
+
+        return attn_weight @ value
+
+
+class FeedForward(nn.Module):
+    def __init__(self, config: BaseModelArgs) -> None:
+        super().__init__()
+        self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False)
+        self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False)
+        self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False)
+        self.dropout = nn.Dropout(p=config.dropout)
+
+    def forward(self, x: Tensor) -> Tensor:
+        return self.w2(self.dropout(F.silu(self.w1(x)) * self.w3(x)))
+
+
+class RMSNorm(nn.Module):
+    def __init__(self, dim: int, eps: float = 1e-5):
+        super().__init__()
+        self.eps = eps
+        self.weight = nn.Parameter(torch.ones(dim))
+
+    def _norm(self, x):
+        return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)
+
+    def forward(self, x: Tensor) -> Tensor:
+        output = self._norm(x.float()).type_as(x)
+        return output * self.weight
+
+
+def precompute_freqs_cis(seq_len: int, n_elem: int, base: int = 10000) -> Tensor:
+    freqs = 1.0 / (
+        base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)
+    )
+    t = torch.arange(seq_len, device=freqs.device)
+    freqs = torch.outer(t, freqs)
+    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
+    cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
+    return cache.to(dtype=torch.bfloat16)
+
+
+def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
+    xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
+    freqs_cis = freqs_cis.view(x.size(0), xshaped.size(1), 1, xshaped.size(3), 2)
+    x_out2 = torch.stack(
+        [
+            xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
+            xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
+        ],
+        -1,
+    )
+
+    x_out2 = x_out2.flatten(3)
+    return x_out2.type_as(x)
+
+def top_k_top_p_filtering(
+    logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1
+):
+    """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
+    Args:
+        logits: logits distribution shape (batch size, vocabulary size)
+        if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
+        if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
+            Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
+        Make sure we keep at least min_tokens_to_keep per batch example in the output
+    From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
+    """
+    if top_k > 0:
+        top_k = min(
+            max(top_k, min_tokens_to_keep), logits.size(-1)
+        )  # Safety check
+        # Remove all tokens with a probability less than the last token of the top-k
+        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
+        logits[indices_to_remove] = filter_value
+
+    if top_p < 1.0:
+        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+        cumulative_probs = torch.cumsum(
+            F.softmax(sorted_logits, dim=-1), dim=-1
+        )
+
+        # Remove tokens with cumulative probability above the threshold (token with 0 are kept)
+        sorted_indices_to_remove = cumulative_probs > top_p
+        if min_tokens_to_keep > 1:
+            # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
+            sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
+        # Shift the indices to the right to keep also the first token above the threshold
+        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
+            ..., :-1
+        ].clone()
+        sorted_indices_to_remove[..., 0] = 0
+
+        # scatter sorted tensors to original indexing
+        indices_to_remove = sorted_indices_to_remove.scatter(
+            1, sorted_indices, sorted_indices_to_remove
+        )
+        logits[indices_to_remove] = filter_value
+    return logits
+
+def topk_sampling(logits, top_k=10, top_p=1.0, temperature=1.0):
+    # temperature: (`optional`) float
+    #     The value used to module the next token probabilities. Must be strictly positive. Default to 1.0.
+    # top_k: (`optional`) int
+    #     The number of highest probability vocabulary tokens to keep for top-k-filtering. Between 1 and infinity. Default to 50.
+    # top_p: (`optional`) float
+    #     The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Must be between 0 and 1. Default to 1.
+
+    # Temperature (higher temperature => more likely to sample low probability tokens)
+    if temperature != 1.0:
+        logits = logits / temperature
+    # Top-p/top-k filtering
+    logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
+    # Sample
+    token = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
+    logprobs = F.log_softmax(logits.float(), dim=-1)
+    current_logprobs = logprobs[torch.arange(logprobs.shape[0]), token.squeeze(1)]
+    return token, current_logprobs
+
+def sample(
+    logits,
+    previous_tokens: Optional[torch.Tensor] = None,
+    **sampling_kwargs,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    probs = logits_to_probs(
+        logits=logits[0, -1], previous_tokens=previous_tokens, **sampling_kwargs
+    )
+    idx_next = multinomial_sample_one_no_sync(probs)
+    return idx_next, probs
+
+def multinomial_sample_one_no_sync(
+    probs_sort,
+):  # Does multinomial sampling without a cuda synchronization
+    q = torch.empty_like(probs_sort).exponential_(1)
+    return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
+
+
+def logits_to_probs(
+    logits,
+    previous_tokens: Optional[torch.Tensor] = None,
+    suppress_tokens: Optional[List[int]] = None,
+    temperature: torch.Tensor = 0.7,
+    top_p: torch.Tensor = 0.7,
+    repetition_penalty: torch.Tensor = 1.5,
+) -> torch.Tensor:
+    # Apply repetition penalty
+    if previous_tokens is not None:
+        previous_tokens = previous_tokens.long()
+        score = torch.gather(logits, dim=0, index=previous_tokens)
+        score = torch.where(
+            score < 0, score * repetition_penalty, score / repetition_penalty
+        )
+        logits.scatter_(dim=0, index=previous_tokens, src=score)
+    if suppress_tokens is not None:
+        for token in suppress_tokens:
+            logits[token] = -float("Inf")
+
+    # Apply top-p sampling
+    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+    cum_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
+    sorted_indices_to_remove = cum_probs > top_p
+    sorted_indices_to_remove[0] = False  # keep at least one option
+    indices_to_remove = sorted_indices_to_remove.scatter(
+        dim=0, index=sorted_indices, src=sorted_indices_to_remove
+    )
+    logits = logits.masked_fill(indices_to_remove, -float("Inf"))
+
+    logits = logits / max(temperature, 1e-5)
+
+    probs = torch.nn.functional.softmax(logits, dim=-1)
+    return probs
diff --git a/modules/v2/cfm.py b/modules/v2/cfm.py
new file mode 100644
index 0000000000000000000000000000000000000000..b0ea58ef15f31c324bbcc061f15be8650824790e
--- /dev/null
+++ b/modules/v2/cfm.py
@@ -0,0 +1,173 @@
+import torch
+from tqdm import tqdm
+
+class CFM(torch.nn.Module):
+    def __init__(
+        self,
+        estimator: torch.nn.Module,
+    ):
+        super().__init__()
+        self.sigma_min = 1e-6
+        self.estimator = estimator
+        self.in_channels = estimator.in_channels
+        self.criterion = torch.nn.L1Loss()
+
+    @torch.inference_mode()
+    def inference(self,
+                  mu: torch.Tensor,
+                  x_lens: torch.Tensor,
+                  prompt: torch.Tensor,
+                  style: torch.Tensor,
+                  n_timesteps=10,
+                  temperature=1.0,
+                  inference_cfg_rate=[0.5, 0.5],
+                  random_voice=False,
+                  ):
+        """Forward diffusion
+
+        Args:
+            mu (torch.Tensor): output of encoder
+                shape: (batch_size, n_feats, mel_timesteps)
+            x_lens (torch.Tensor): length of each mel-spectrogram
+                shape: (batch_size,)
+            prompt (torch.Tensor): prompt
+                shape: (batch_size, n_feats, prompt_len)
+            style (torch.Tensor): style
+                shape: (batch_size, style_dim)
+            n_timesteps (int): number of diffusion steps
+            temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
+            inference_cfg_rate (float, optional): Classifier-Free Guidance inference introduced in VoiceBox. Defaults to 0.5.
+
+        Returns:
+            sample: generated mel-spectrogram
+                shape: (batch_size, n_feats, mel_timesteps)
+        """
+        B, T = mu.size(0), mu.size(1)
+        z = torch.randn([B, self.in_channels, T], device=mu.device) * temperature
+        t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
+        t_span = t_span + (-1) * (torch.cos(torch.pi / 2 * t_span) - 1 + t_span)
+        return self.solve_euler(z, x_lens, prompt, mu, style, t_span, inference_cfg_rate, random_voice)
+    def solve_euler(self, x, x_lens, prompt, mu, style, t_span, inference_cfg_rate=[0.5, 0.5], random_voice=False,):
+        """
+        Fixed euler solver for ODEs.
+        Args:
+            x (torch.Tensor): random noise
+            t_span (torch.Tensor): n_timesteps interpolated
+                shape: (n_timesteps + 1,)
+            mu (torch.Tensor): output of encoder
+                shape: (batch_size, n_feats, mel_timesteps)
+            x_lens (torch.Tensor): length of each mel-spectrogram
+                shape: (batch_size,)
+            prompt (torch.Tensor): prompt
+                shape: (batch_size, n_feats, prompt_len)
+            style (torch.Tensor): style
+                shape: (batch_size, style_dim)
+            inference_cfg_rate (float, optional): Classifier-Free Guidance inference introduced in VoiceBox. Defaults to 0.5.
+            sway_sampling (bool, optional): Sway sampling. Defaults to False.
+            amo_sampling (bool, optional): AMO sampling. Defaults to False.
+        """
+        t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
+
+        # apply prompt
+        prompt_len = prompt.size(-1)
+        prompt_x = torch.zeros_like(x)
+        prompt_x[..., :prompt_len] = prompt[..., :prompt_len]
+        x[..., :prompt_len] = 0
+        for step in tqdm(range(1, len(t_span))):
+            if random_voice:
+                cfg_dphi_dt = self.estimator(
+                    torch.cat([x, x], dim=0),
+                    torch.cat([torch.zeros_like(prompt_x), torch.zeros_like(prompt_x)], dim=0),
+                    torch.cat([x_lens, x_lens], dim=0),
+                    torch.cat([t.unsqueeze(0), t.unsqueeze(0)], dim=0),
+                    torch.cat([torch.zeros_like(style), torch.zeros_like(style)], dim=0),
+                    torch.cat([mu, torch.zeros_like(mu)], dim=0),
+                )
+                cond_txt, uncond = cfg_dphi_dt[0:1], cfg_dphi_dt[1:2]
+                dphi_dt = ((1.0 + inference_cfg_rate[0]) * cond_txt - inference_cfg_rate[0] * uncond)
+            elif all(i == 0 for i in inference_cfg_rate):
+                dphi_dt = self.estimator(x, prompt_x, x_lens, t.unsqueeze(0), style, mu)
+            elif inference_cfg_rate[0] == 0:
+                # Classifier-Free Guidance inference introduced in VoiceBox
+                cfg_dphi_dt = self.estimator(
+                    torch.cat([x, x], dim=0),
+                    torch.cat([prompt_x, torch.zeros_like(prompt_x)], dim=0),
+                    torch.cat([x_lens, x_lens], dim=0),
+                    torch.cat([t.unsqueeze(0), t.unsqueeze(0)], dim=0),
+                    torch.cat([style, torch.zeros_like(style)], dim=0),
+                    torch.cat([mu, mu], dim=0),
+                )
+                cond_txt_spk, cond_txt = cfg_dphi_dt[0:1], cfg_dphi_dt[1:2]
+                dphi_dt = ((1.0 + inference_cfg_rate[1]) * cond_txt_spk - inference_cfg_rate[1] * cond_txt)
+            elif inference_cfg_rate[1] == 0:
+                cfg_dphi_dt = self.estimator(
+                    torch.cat([x, x], dim=0),
+                    torch.cat([prompt_x, torch.zeros_like(prompt_x)], dim=0),
+                    torch.cat([x_lens, x_lens], dim=0),
+                    torch.cat([t.unsqueeze(0), t.unsqueeze(0)], dim=0),
+                    torch.cat([style, torch.zeros_like(style)], dim=0),
+                    torch.cat([mu, torch.zeros_like(mu)], dim=0),
+                )
+                cond_txt_spk, uncond = cfg_dphi_dt[0:1], cfg_dphi_dt[1:2]
+                dphi_dt = ((1.0 + inference_cfg_rate[0]) * cond_txt_spk - inference_cfg_rate[0] * uncond)
+            else:
+                # Multi-condition Classifier-Free Guidance inference introduced in MegaTTS3
+                cfg_dphi_dt = self.estimator(
+                    torch.cat([x, x, x], dim=0),
+                    torch.cat([prompt_x, torch.zeros_like(prompt_x), torch.zeros_like(prompt_x)], dim=0),
+                    torch.cat([x_lens, x_lens, x_lens], dim=0),
+                    torch.cat([t.unsqueeze(0), t.unsqueeze(0), t.unsqueeze(0)], dim=0),
+                    torch.cat([style, torch.zeros_like(style), torch.zeros_like(style)], dim=0),
+                    torch.cat([mu, mu, torch.zeros_like(mu)], dim=0),
+                )
+                cond_txt_spk, cond_txt, uncond = cfg_dphi_dt[0:1], cfg_dphi_dt[1:2], cfg_dphi_dt[2:3]
+                dphi_dt = (1.0 + inference_cfg_rate[0] + inference_cfg_rate[1]) * cond_txt_spk - \
+                    inference_cfg_rate[0] * uncond - inference_cfg_rate[1] * cond_txt
+            x = x + dt * dphi_dt
+            t = t + dt
+            if step < len(t_span) - 1:
+                dt = t_span[step + 1] - t
+            x[:, :, :prompt_len] = 0
+
+        return x
+
+    def forward(self, x1, x_lens, prompt_lens, mu, style):
+        """Computes diffusion loss
+
+        Args:
+            x1 (torch.Tensor): Target
+                shape: (batch_size, n_feats, mel_timesteps)
+            mask (torch.Tensor): target mask
+                shape: (batch_size, 1, mel_timesteps)
+            mu (torch.Tensor): output of encoder
+                shape: (batch_size, n_feats, mel_timesteps)
+            spks (torch.Tensor, optional): speaker embedding. Defaults to None.
+                shape: (batch_size, spk_emb_dim)
+
+        Returns:
+            loss: conditional flow matching loss
+            y: conditional flow
+                shape: (batch_size, n_feats, mel_timesteps)
+        """
+        b, _, t = x1.shape
+
+        # random timestep
+        t = torch.rand([b, 1, 1], device=mu.device, dtype=x1.dtype)
+        # sample noise p(x_0)
+        z = torch.randn_like(x1)
+
+        y = (1 - (1 - self.sigma_min) * t) * z + t * x1
+        u = x1 - (1 - self.sigma_min) * z
+        prompt = torch.zeros_like(x1)
+        for bib in range(b):
+            prompt[bib, :, :prompt_lens[bib]] = x1[bib, :, :prompt_lens[bib]]
+            # range covered by prompt are set to 0
+            y[bib, :, :prompt_lens[bib]] = 0
+
+        estimator_out = self.estimator(y, prompt, x_lens, t.squeeze(), style, mu)
+        loss = 0
+        for bib in range(b):
+            loss += self.criterion(estimator_out[bib, :, prompt_lens[bib]:x_lens[bib]], u[bib, :, prompt_lens[bib]:x_lens[bib]])
+        loss /= b
+
+        return loss
diff --git a/modules/v2/dit_model.py b/modules/v2/dit_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..4374ac86a4d4d0869788cdd16087115c4418ba5f
--- /dev/null
+++ b/modules/v2/dit_model.py
@@ -0,0 +1,250 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+from dataclasses import dataclass
+from typing import Optional, Union, Tuple, List
+
+import torch
+import torch.nn as nn
+from torch import Tensor
+from torch.nn import functional as F
+import time
+
+def find_multiple(n: int, k: int) -> int:
+    if n % k == 0:
+        return n
+    return n + k - (n % k)
+
+class AdaptiveLayerNorm(nn.Module):
+    r"""Adaptive Layer Normalization"""
+
+    def __init__(self, d_model, norm) -> None:
+        super(AdaptiveLayerNorm, self).__init__()
+        self.linear = nn.Linear(d_model, 6 * d_model)
+        self.act = nn.SiLU()
+        self.norm = norm
+        self.d_model = d_model
+        self.eps = self.norm.eps
+
+    def forward(self, x: Tensor, emb: Tensor) -> Tuple[Tensor]:
+        emb = self.linear(self.act(emb))
+        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = torch.chunk(emb, 6, dim=-1)
+
+        x = self.norm(x) * (1 + scale_msa) + shift_msa
+        return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
+
+class AdaptiveLayerNormFinal(nn.Module):
+    r"""Adaptive Layer Normalization"""
+
+    def __init__(self, d_model, norm) -> None:
+        super(AdaptiveLayerNormFinal, self).__init__()
+        self.linear = nn.Linear(d_model, 2 * d_model)
+        self.act = nn.SiLU()
+        self.norm = norm
+        self.d_model = d_model
+        self.eps = self.norm.eps
+
+    def forward(self, x: Tensor, emb: Tensor) -> Tuple[Tensor]:
+        emb = self.linear(self.act(emb))
+        scale, shift = torch.chunk(emb, 2, dim=-1)
+
+        x = self.norm(x) * (1 + scale) + shift
+        return x
+
+@dataclass
+class ModelArgs:
+    block_size: int = 2048
+    vocab_size: int = 32000
+    n_layer: int = 32
+    n_head: int = 32
+    dim: int = 4096
+    intermediate_size: int = None
+    n_local_heads: int = -1
+    head_dim: int = 64
+    rope_base: float = 10000
+    norm_eps: float = 1e-5
+    uvit_skip_connection: bool = False
+    time_as_token: bool = False
+    dropout_rate: float = 0.1
+    attn_dropout_rate: float = 0.1
+
+    def __post_init__(self):
+        if self.n_local_heads == -1:
+            self.n_local_heads = self.n_head
+        if self.intermediate_size is None:
+            hidden_dim = 4 * self.dim
+            n_hidden = int(2 * hidden_dim / 3)
+            self.intermediate_size = find_multiple(n_hidden, 256)
+        # self.head_dim = self.dim // self.n_head
+
+class Transformer(nn.Module):
+    def __init__(self, config: ModelArgs) -> None:
+        super().__init__()
+        self.config = config
+
+        self.layers = nn.ModuleList(TransformerBlock(config) for _ in range(config.n_layer))
+        self.norm = AdaptiveLayerNormFinal(config.dim, RMSNorm(config.dim, eps=config.norm_eps))
+
+        self.max_batch_size = -1
+        self.max_seq_length = config.block_size
+
+        self.uvit_skip_connection = self.config.uvit_skip_connection
+        if self.uvit_skip_connection:
+            self.layers_emit_skip = [i for i in range(self.config.n_layer) if i < self.config.n_layer // 2]
+            self.layers_receive_skip = [i for i in range(self.config.n_layer) if i > self.config.n_layer // 2]
+        else:
+            self.layers_emit_skip = []
+            self.layers_receive_skip = []
+        freqs_cis = precompute_freqs_cis(self.config.block_size, self.config.head_dim,
+                                              self.config.rope_base)
+        self.register_buffer("freqs_cis", freqs_cis)
+
+        causal_mask = torch.tril(
+            torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool)
+        )
+        self.register_buffer("causal_mask", causal_mask)
+
+    def forward(self,
+                x: Tensor,
+                c: Tensor,
+                input_pos: Optional[Tensor] = None,
+                mask: Optional[Tensor] = None,
+                ) -> Tensor:
+        mask = mask[..., input_pos]
+        freqs_cis = self.freqs_cis[input_pos]
+        for i, layer in enumerate(self.layers):
+            x = layer(x, c, freqs_cis, mask)
+        x = self.norm(x, c)
+        return x
+
+
+class TransformerBlock(nn.Module):
+    def __init__(self, config: ModelArgs) -> None:
+        super().__init__()
+        self.attention = Attention(config)
+        self.feed_forward = FeedForward(config)
+        self.ffn_norm = RMSNorm(config.dim, eps=config.norm_eps)
+        self.attention_norm = AdaptiveLayerNorm(config.dim, RMSNorm(config.dim, eps=config.norm_eps))
+    def forward(self,
+                x: Tensor,
+                c: Tensor,
+                freqs_cis: Tensor,
+                mask: Tensor,
+                ) -> Tensor:
+        normed_x, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attention_norm(x, emb=c)
+        # attention
+        attn_output = self.attention(normed_x, freqs_cis, mask)
+        x = x + gate_msa * attn_output
+        normed_x = self.ffn_norm(x) * (1 + scale_mlp) + shift_mlp
+        ff_output = self.feed_forward(normed_x)
+        x = x + gate_mlp * ff_output
+        return x
+
+
+class Attention(nn.Module):
+    def __init__(self, config: ModelArgs, is_cross_attention: bool = False):
+        super().__init__()
+        assert config.dim % config.n_head == 0
+
+        total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
+        # key, query, value projections for all heads, but in a batch
+        if is_cross_attention:
+            self.wq = nn.Linear(config.dim, config.n_head * config.head_dim, bias=False)
+            self.wkv = nn.Linear(config.context_dim, 2 * config.n_local_heads * config.head_dim, bias=False)
+        else:
+            self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False)
+        self.wo = nn.Linear(config.head_dim * config.n_head, config.dim, bias=False)
+        self.kv_cache = None
+
+        self.n_head = config.n_head
+        self.head_dim = config.head_dim
+        self.n_local_heads = config.n_local_heads
+        self.dim = config.dim
+        self.attn_dropout_rate = config.attn_dropout_rate
+
+    def forward(self,
+                x: Tensor,
+                freqs_cis: Tensor,
+                mask: Tensor,
+                context: Optional[Tensor] = None,
+                context_freqs_cis: Optional[Tensor] = None,
+                ) -> Tensor:
+        bsz, seqlen, _ = x.shape
+
+        kv_size = self.n_local_heads * self.head_dim
+        q, k, v = self.wqkv(x).split([kv_size, kv_size, kv_size], dim=-1)
+        context_seqlen = seqlen
+
+        q = q.view(bsz, seqlen, self.n_head, self.head_dim)
+        k = k.view(bsz, context_seqlen, self.n_local_heads, self.head_dim)
+        v = v.view(bsz, context_seqlen, self.n_local_heads, self.head_dim)
+
+        q = apply_rotary_emb(q, freqs_cis)
+        k = apply_rotary_emb(k, freqs_cis)
+
+        q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
+
+        k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
+        v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
+        y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
+
+        y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.head_dim * self.n_head)
+
+        y = self.wo(y)
+        return y
+
+
+class FeedForward(nn.Module):
+    def __init__(self, config: ModelArgs) -> None:
+        super().__init__()
+        self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False)
+        self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False)
+        self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False)
+        self.dropout = nn.Dropout(config.dropout_rate)
+
+    def forward(self, x: Tensor) -> Tensor:
+        return self.w2(self.dropout(F.silu(self.w1(x)) * self.w3(x)))
+
+
+class RMSNorm(nn.Module):
+    def __init__(self, dim: int, eps: float = 1e-5):
+        super().__init__()
+        self.eps = eps
+        self.weight = nn.Parameter(torch.ones(dim))
+
+    def _norm(self, x):
+        return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)
+
+    def forward(self, x: Tensor) -> Tensor:
+        output = self._norm(x.float()).type_as(x)
+        return output * self.weight
+
+
+def precompute_freqs_cis(
+        seq_len: int, n_elem: int, base: int = 10000,
+        dtype: torch.dtype = torch.bfloat16
+) -> Tensor:
+    freqs = 1.0 / (base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem))
+    t = torch.arange(seq_len, device=freqs.device)
+    freqs = torch.outer(t, freqs)
+    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
+    cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
+    return cache.to(dtype=dtype)
+
+
+def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
+    xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
+    freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2)
+    x_out2 = torch.stack(
+        [
+            xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
+            xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
+        ],
+        -1,
+    )
+
+    x_out2 = x_out2.flatten(3)
+    return x_out2.type_as(x)
+
diff --git a/modules/v2/dit_wrapper.py b/modules/v2/dit_wrapper.py
new file mode 100644
index 0000000000000000000000000000000000000000..f653239de475fba17794d6b8d7e28a8edd1b65a0
--- /dev/null
+++ b/modules/v2/dit_wrapper.py
@@ -0,0 +1,152 @@
+import torch
+from torch import nn
+import math
+
+from modules.v2.dit_model import ModelArgs, Transformer
+from modules.commons import sequence_mask
+
+from torch.nn.utils import weight_norm
+
+def modulate(x, shift, scale):
+    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
+
+
+#################################################################################
+#               Embedding Layers for Timesteps and Class Labels                 #
+#################################################################################
+
+class TimestepEmbedder(nn.Module):
+    """
+    Embeds scalar timesteps into vector representations.
+    """
+    def __init__(self, hidden_size, frequency_embedding_size=256):
+        super().__init__()
+        self.mlp = nn.Sequential(
+            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
+            nn.SiLU(),
+            nn.Linear(hidden_size, hidden_size, bias=True),
+        )
+        self.frequency_embedding_size = frequency_embedding_size
+
+    @staticmethod
+    def timestep_embedding(t, dim, max_period=10000, scale=1000):
+        """
+        Create sinusoidal timestep embeddings.
+        :param t: a 1-D Tensor of N indices, one per batch element.
+                          These may be fractional.
+        :param dim: the dimension of the output.
+        :param max_period: controls the minimum frequency of the embeddings.
+        :return: an (N, D) Tensor of positional embeddings.
+        """
+        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
+        half = dim // 2
+        freqs = torch.exp(
+            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
+        ).to(device=t.device)
+        args = scale * t[:, None].float() * freqs[None]
+        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+        if dim % 2:
+            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+        return embedding
+
+    def forward(self, t):
+        t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
+        t_emb = self.mlp(t_freq)
+        return t_emb
+
+
+class DiT(torch.nn.Module):
+    def __init__(
+        self,
+        time_as_token,
+        style_as_token,
+        uvit_skip_connection,
+        block_size,
+        depth,
+        num_heads,
+        hidden_dim,
+        in_channels,
+        content_dim,
+        style_encoder_dim,
+        class_dropout_prob,
+        dropout_rate,
+        attn_dropout_rate,
+    ):
+        super(DiT, self).__init__()
+        self.time_as_token = time_as_token
+        self.style_as_token = style_as_token
+        self.uvit_skip_connection = uvit_skip_connection
+        model_args = ModelArgs(
+            block_size=block_size,
+            n_layer=depth,
+            n_head=num_heads,
+            dim=hidden_dim,
+            head_dim=hidden_dim // num_heads,
+            vocab_size=1, # we don't use this
+            uvit_skip_connection=self.uvit_skip_connection,
+            time_as_token=self.time_as_token,
+            dropout_rate=dropout_rate,
+            attn_dropout_rate=attn_dropout_rate,
+        )
+        self.transformer = Transformer(model_args)
+        self.in_channels = in_channels
+        self.out_channels = in_channels
+        self.num_heads = num_heads
+
+        self.x_embedder = weight_norm(nn.Linear(in_channels, hidden_dim, bias=True))
+
+        self.content_dim = content_dim # for continuous content
+        self.cond_projection = nn.Linear(content_dim, hidden_dim, bias=True) # continuous content
+
+        self.t_embedder = TimestepEmbedder(hidden_dim)
+
+        self.final_mlp = nn.Sequential(
+                nn.Linear(hidden_dim, hidden_dim),
+                nn.SiLU(),
+                nn.Linear(hidden_dim, in_channels),
+        )
+
+        self.class_dropout_prob = class_dropout_prob
+
+        self.cond_x_merge_linear = nn.Linear(hidden_dim + in_channels + in_channels, hidden_dim)
+        self.style_in = nn.Linear(style_encoder_dim, hidden_dim)
+
+    def forward(self, x, prompt_x, x_lens, t, style, cond):
+        class_dropout = False
+        content_dropout = False
+        if self.training and torch.rand(1) < self.class_dropout_prob:
+            class_dropout = True
+            if self.training and torch.rand(1) < 0.5:
+                content_dropout = True
+        cond_in_module = self.cond_projection
+
+        B, _, T = x.size()
+
+        t1 = self.t_embedder(t)  # (N, D)
+        cond = cond_in_module(cond)
+
+        x = x.transpose(1, 2)
+        prompt_x = prompt_x.transpose(1, 2)
+
+        x_in = torch.cat([x, prompt_x, cond], dim=-1)
+        if class_dropout:
+            x_in[..., self.in_channels:self.in_channels*2] = 0
+            if content_dropout:
+                x_in[..., self.in_channels*2:] = 0
+        x_in = self.cond_x_merge_linear(x_in)  # (N, T, D)
+
+        style = self.style_in(style)
+        style = torch.zeros_like(style) if class_dropout else style
+        if self.style_as_token:
+            x_in = torch.cat([style.unsqueeze(1), x_in], dim=1)
+        if self.time_as_token:
+            x_in = torch.cat([t1.unsqueeze(1), x_in], dim=1)
+        x_mask = sequence_mask(x_lens + self.style_as_token + self.time_as_token, max_length=x_in.size(1)).to(x.device).unsqueeze(1)
+        input_pos = torch.arange(x_in.size(1)).to(x.device)
+        x_mask_expanded = x_mask[:, None, :].repeat(1, 1, x_in.size(1), 1)
+        x_res = self.transformer(x_in, t1.unsqueeze(1), input_pos, x_mask_expanded)
+        x_res = x_res[:, 1:] if self.time_as_token else x_res
+        x_res = x_res[:, 1:] if self.style_as_token else x_res
+        x = self.final_mlp(x_res)
+        x = x.transpose(1, 2)
+        return x
diff --git a/modules/v2/length_regulator.py b/modules/v2/length_regulator.py
new file mode 100644
index 0000000000000000000000000000000000000000..7efe5a62bc5afba06a8abe5051aace9ad97dbf3e
--- /dev/null
+++ b/modules/v2/length_regulator.py
@@ -0,0 +1,105 @@
+from typing import Tuple
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+from modules.commons import sequence_mask
+import numpy as np
+
+# f0_bin = 256
+f0_max = 1100.0
+f0_min = 50.0
+f0_mel_min = 1127 * np.log(1 + f0_min / 700)
+f0_mel_max = 1127 * np.log(1 + f0_max / 700)
+
+def f0_to_coarse(f0, f0_bin):
+  f0_mel = 1127 * (1 + f0 / 700).log()
+  a = (f0_bin - 2) / (f0_mel_max - f0_mel_min)
+  b = f0_mel_min * a - 1.
+  f0_mel = torch.where(f0_mel > 0, f0_mel * a - b, f0_mel)
+  # torch.clip_(f0_mel, min=1., max=float(f0_bin - 1))
+  f0_coarse = torch.round(f0_mel).long()
+  f0_coarse = f0_coarse * (f0_coarse > 0)
+  f0_coarse = f0_coarse + ((f0_coarse < 1) * 1)
+  f0_coarse = f0_coarse * (f0_coarse < f0_bin)
+  f0_coarse = f0_coarse + ((f0_coarse >= f0_bin) * (f0_bin - 1))
+  return f0_coarse
+
+class InterpolateRegulator(nn.Module):
+    def __init__(
+            self,
+            channels: int,
+            sampling_ratios: Tuple,
+            is_discrete: bool = False,
+            in_channels: int = None,  # only applies to continuous input
+            codebook_size: int = 1024, # for discrete only
+            out_channels: int = None,
+            groups: int = 1,
+            f0_condition: bool = False,
+            n_f0_bins: int = 512,
+    ):
+        super().__init__()
+        self.sampling_ratios = sampling_ratios
+        out_channels = out_channels or channels
+        model = nn.ModuleList([])
+        if len(sampling_ratios) > 0:
+            self.interpolate = True
+            for _ in sampling_ratios:
+                module = nn.Conv1d(channels, channels, 3, 1, 1)
+                norm = nn.GroupNorm(groups, channels)
+                act = nn.Mish()
+                model.extend([module, norm, act])
+        else:
+            self.interpolate = False
+        model.append(
+            nn.Conv1d(channels, out_channels, 1, 1) if channels != out_channels else nn.Identity()
+        )
+        self.model = nn.Sequential(*model)
+        self.embedding = nn.Embedding(codebook_size, channels)
+        self.is_discrete = is_discrete
+
+        self.mask_token = nn.Parameter(torch.zeros(1, channels))
+
+        if f0_condition:
+            self.f0_embedding = nn.Embedding(n_f0_bins, channels)
+            self.f0_condition = f0_condition
+            self.n_f0_bins = n_f0_bins
+            self.f0_bins = torch.arange(2, 1024, 1024 // n_f0_bins)
+            self.f0_mask = nn.Parameter(torch.zeros(1, channels))
+        else:
+            self.f0_condition = False
+
+        if not is_discrete:
+            self.content_in_proj = nn.Linear(in_channels, channels)
+
+    def forward(self, x, ylens=None, f0=None):
+        if self.is_discrete:
+            if len(x.size()) == 2:
+                x = self.embedding(x)
+            else:
+                x = self.embedding(x[:, 0])
+        else:
+            x = self.content_in_proj(x)
+        # x in (B, T, D)
+
+        if self.interpolate:
+            mask = sequence_mask(ylens).unsqueeze(-1)
+            x = F.interpolate(x.transpose(1, 2).contiguous(), size=ylens.max(), mode='nearest')
+        else:
+            x = x.transpose(1, 2).contiguous()
+            mask = None
+            # mask = mask[:, :x.size(2), :]
+            # ylens = ylens.clamp(max=x.size(2)).long()
+        if self.f0_condition:
+            if f0 is None:
+                x = x + self.f0_mask.unsqueeze(-1)
+            else:
+                # quantized_f0 = torch.bucketize(f0, self.f0_bins.to(f0.device))  # (N, T)
+                quantized_f0 = f0_to_coarse(f0, self.n_f0_bins)
+                quantized_f0 = quantized_f0.clamp(0, self.n_f0_bins - 1).long()
+                f0_emb = self.f0_embedding(quantized_f0)
+                f0_emb = F.interpolate(f0_emb.transpose(1, 2).contiguous(), size=ylens.max(), mode='nearest')
+                x = x + f0_emb
+        out = self.model(x).transpose(1, 2).contiguous()
+        out = out * mask if mask is not None else out
+        olens = ylens
+        return out, olens
diff --git a/modules/v2/model.py b/modules/v2/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..a96dd0b6c58991ca3e203ca6c5247dc0413e48b4
--- /dev/null
+++ b/modules/v2/model.py
@@ -0,0 +1,302 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+from dataclasses import dataclass
+from typing import Optional
+
+import torch
+import torch.nn as nn
+from torch import Tensor
+from torch.nn import functional as F
+
+
+def find_multiple(n: int, k: int) -> int:
+    if n % k == 0:
+        return n
+    return n + k - (n % k)
+
+class AdaptiveLayerNorm(nn.Module):
+    r"""Adaptive Layer Normalization"""
+
+    def __init__(self, d_model, norm) -> None:
+        super(AdaptiveLayerNorm, self).__init__()
+        self.project_layer = nn.Linear(d_model, 2 * d_model)
+        self.norm = norm
+        self.d_model = d_model
+        self.eps = self.norm.eps
+
+    def forward(self, input: Tensor, embedding: Tensor = None) -> Tensor:
+        if embedding is None:
+            return self.norm(input)
+        weight, bias = torch.split(
+            self.project_layer(embedding),
+            split_size_or_sections=self.d_model,
+            dim=-1,
+        )
+        return weight * self.norm(input) + bias
+
+
+@dataclass
+class ModelArgs:
+    block_size: int = 2048
+    vocab_size: int = 32000
+    n_layer: int = 32
+    n_head: int = 32
+    dim: int = 4096
+    intermediate_size: int = None
+    n_local_heads: int = -1
+    head_dim: int = 64
+    rope_base: float = 10000
+    norm_eps: float = 1e-5
+    has_cross_attention: bool = False
+    context_dim: int = 0
+    uvit_skip_connection: bool = False
+    time_as_token: bool = False
+
+    def __post_init__(self):
+        if self.n_local_heads == -1:
+            self.n_local_heads = self.n_head
+        if self.intermediate_size is None:
+            hidden_dim = 4 * self.dim
+            n_hidden = int(2 * hidden_dim / 3)
+            self.intermediate_size = find_multiple(n_hidden, 256)
+        # self.head_dim = self.dim // self.n_head
+
+class Transformer(nn.Module):
+    def __init__(self, config: ModelArgs) -> None:
+        super().__init__()
+        self.config = config
+
+        self.layers = nn.ModuleList(TransformerBlock(config) for _ in range(config.n_layer))
+        self.norm = AdaptiveLayerNorm(config.dim, RMSNorm(config.dim, eps=config.norm_eps))
+
+        self.freqs_cis: Optional[Tensor] = None
+        self.mask_cache: Optional[Tensor] = None
+        self.max_batch_size = -1
+        self.max_seq_length = -1
+
+    def setup_caches(self, max_batch_size, max_seq_length, use_kv_cache=False):
+        if self.max_seq_length >= max_seq_length and self.max_batch_size >= max_batch_size:
+            return
+        head_dim = self.config.dim // self.config.n_head
+        max_seq_length = find_multiple(max_seq_length, 8)
+        self.max_seq_length = max_seq_length
+        self.max_batch_size = max_batch_size
+        dtype = self.norm.project_layer.weight.dtype
+        device = self.norm.project_layer.weight.device
+
+        self.freqs_cis = precompute_freqs_cis(self.config.block_size, self.config.head_dim,
+                                              self.config.rope_base, dtype).to(device)
+        self.causal_mask = torch.tril(torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool)).to(device)
+        self.use_kv_cache = use_kv_cache
+        self.uvit_skip_connection = self.config.uvit_skip_connection
+        if self.uvit_skip_connection:
+            self.layers_emit_skip = [i for i in range(self.config.n_layer) if i < self.config.n_layer // 2]
+            self.layers_receive_skip = [i for i in range(self.config.n_layer) if i > self.config.n_layer // 2]
+        else:
+            self.layers_emit_skip = []
+            self.layers_receive_skip = []
+
+    def forward(self,
+                x: Tensor,
+                c: Tensor,
+                input_pos: Optional[Tensor] = None,
+                mask: Optional[Tensor] = None,
+                context: Optional[Tensor] = None,
+                context_input_pos: Optional[Tensor] = None,
+                cross_attention_mask: Optional[Tensor] = None,
+                ) -> Tensor:
+        assert self.freqs_cis is not None, "Caches must be initialized first"
+        if mask is None: # in case of non-causal model
+            if not self.training and self.use_kv_cache:
+                mask = self.causal_mask[None, None, input_pos]
+            else:
+                mask = self.causal_mask[None, None, input_pos]
+                mask = mask[..., input_pos]
+        freqs_cis = self.freqs_cis[input_pos]
+        if context is not None:
+            context_freqs_cis = self.freqs_cis[context_input_pos]
+        else:
+            context_freqs_cis = None
+        skip_in_x_list = []
+        for i, layer in enumerate(self.layers):
+            if self.uvit_skip_connection and i in self.layers_receive_skip:
+                skip_in_x = skip_in_x_list.pop(-1)
+            else:
+                skip_in_x = None
+            x = layer(x, c, input_pos, freqs_cis, mask, context, context_freqs_cis, cross_attention_mask, skip_in_x)
+            if self.uvit_skip_connection and i in self.layers_emit_skip:
+                skip_in_x_list.append(x)
+        x = self.norm(x, c)
+        return x
+
+    @classmethod
+    def from_name(cls, name: str):
+        return cls(ModelArgs.from_name(name))
+
+
+class TransformerBlock(nn.Module):
+    def __init__(self, config: ModelArgs) -> None:
+        super().__init__()
+        self.attention = Attention(config)
+        self.feed_forward = FeedForward(config)
+        self.ffn_norm = AdaptiveLayerNorm(config.dim, RMSNorm(config.dim, eps=config.norm_eps))
+        self.attention_norm = AdaptiveLayerNorm(config.dim, RMSNorm(config.dim, eps=config.norm_eps))
+
+        if config.has_cross_attention:
+            self.has_cross_attention = True
+            self.cross_attention = Attention(config, is_cross_attention=True)
+            self.cross_attention_norm = AdaptiveLayerNorm(config.dim, RMSNorm(config.dim, eps=config.norm_eps))
+        else:
+            self.has_cross_attention = False
+
+        if config.uvit_skip_connection:
+            self.skip_in_linear = nn.Linear(config.dim * 2, config.dim)
+            self.uvit_skip_connection = True
+        else:
+            self.uvit_skip_connection = False
+
+        self.time_as_token = config.time_as_token
+
+    def forward(self,
+                x: Tensor,
+                c: Tensor,
+                input_pos: Tensor,
+                freqs_cis: Tensor,
+                mask: Tensor,
+                context: Optional[Tensor] = None,
+                context_freqs_cis: Optional[Tensor] = None,
+                cross_attention_mask: Optional[Tensor] = None,
+                skip_in_x: Optional[Tensor] = None,
+                ) -> Tensor:
+        c = None if self.time_as_token else c
+        if self.uvit_skip_connection and skip_in_x is not None:
+            x = self.skip_in_linear(torch.cat([x, skip_in_x], dim=-1))
+        h = x + self.attention(self.attention_norm(x, c), freqs_cis, mask, input_pos)
+        if self.has_cross_attention:
+            h = h + self.cross_attention(self.cross_attention_norm(h, c), freqs_cis, cross_attention_mask, input_pos, context, context_freqs_cis)
+        out = h + self.feed_forward(self.ffn_norm(h, c))
+        return out
+
+
+class Attention(nn.Module):
+    def __init__(self, config: ModelArgs, is_cross_attention: bool = False):
+        super().__init__()
+        assert config.dim % config.n_head == 0
+
+        total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
+        # key, query, value projections for all heads, but in a batch
+        if is_cross_attention:
+            self.wq = nn.Linear(config.dim, config.n_head * config.head_dim, bias=False)
+            self.wkv = nn.Linear(config.context_dim, 2 * config.n_local_heads * config.head_dim, bias=False)
+        else:
+            self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False)
+        self.wo = nn.Linear(config.head_dim * config.n_head, config.dim, bias=False)
+        self.kv_cache = None
+
+        self.n_head = config.n_head
+        self.head_dim = config.head_dim
+        self.n_local_heads = config.n_local_heads
+        self.dim = config.dim
+        # self._register_load_state_dict_pre_hook(self.load_hook)
+
+    # def load_hook(self, state_dict, prefix, *args):
+    #     if prefix + "wq.weight" in state_dict:
+    #         wq = state_dict.pop(prefix + "wq.weight")
+    #         wk = state_dict.pop(prefix + "wk.weight")
+    #         wv = state_dict.pop(prefix + "wv.weight")
+    #         state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])
+
+    def forward(self,
+                x: Tensor,
+                freqs_cis: Tensor,
+                mask: Tensor,
+                input_pos: Optional[Tensor] = None,
+                context: Optional[Tensor] = None,
+                context_freqs_cis: Optional[Tensor] = None,
+                ) -> Tensor:
+        bsz, seqlen, _ = x.shape
+
+        kv_size = self.n_local_heads * self.head_dim
+        if context is None:
+            q, k, v = self.wqkv(x).split([kv_size, kv_size, kv_size], dim=-1)
+            context_seqlen = seqlen
+        else:
+            q = self.wq(x)
+            k, v = self.wkv(context).split([kv_size, kv_size], dim=-1)
+            context_seqlen = context.shape[1]
+
+        q = q.view(bsz, seqlen, self.n_head, self.head_dim)
+        k = k.view(bsz, context_seqlen, self.n_local_heads, self.head_dim)
+        v = v.view(bsz, context_seqlen, self.n_local_heads, self.head_dim)
+
+        q = apply_rotary_emb(q, freqs_cis)
+        k = apply_rotary_emb(k, context_freqs_cis if context_freqs_cis is not None else freqs_cis)
+
+        q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
+
+        if self.kv_cache is not None:
+            k, v = self.kv_cache.update(input_pos, k, v)
+
+        k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
+        v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
+        y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
+
+        y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.head_dim * self.n_head)
+
+        y = self.wo(y)
+        return y
+
+
+class FeedForward(nn.Module):
+    def __init__(self, config: ModelArgs) -> None:
+        super().__init__()
+        self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False)
+        self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False)
+        self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False)
+
+    def forward(self, x: Tensor) -> Tensor:
+        return self.w2(F.silu(self.w1(x)) * self.w3(x))
+
+
+class RMSNorm(nn.Module):
+    def __init__(self, dim: int, eps: float = 1e-5):
+        super().__init__()
+        self.eps = eps
+        self.weight = nn.Parameter(torch.ones(dim))
+
+    def _norm(self, x):
+        return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)
+
+    def forward(self, x: Tensor) -> Tensor:
+        output = self._norm(x.float()).type_as(x)
+        return output * self.weight
+
+
+def precompute_freqs_cis(
+        seq_len: int, n_elem: int, base: int = 10000,
+        dtype: torch.dtype = torch.bfloat16
+) -> Tensor:
+    freqs = 1.0 / (base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem))
+    t = torch.arange(seq_len, device=freqs.device)
+    freqs = torch.outer(t, freqs)
+    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
+    cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
+    return cache.to(dtype=dtype)
+
+
+def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
+    xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
+    freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2)
+    x_out2 = torch.stack(
+        [
+            xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
+            xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
+        ],
+        -1,
+    )
+
+    x_out2 = x_out2.flatten(3)
+    return x_out2.type_as(x)
diff --git a/modules/v2/vc_wrapper.py b/modules/v2/vc_wrapper.py
new file mode 100644
index 0000000000000000000000000000000000000000..359e805001c831bbc257f2da7d377a32967a27df
--- /dev/null
+++ b/modules/v2/vc_wrapper.py
@@ -0,0 +1,606 @@
+import spaces
+import torch
+import librosa
+import torchaudio
+import numpy as np
+from pydub import AudioSegment
+from hf_utils import load_custom_model_from_hf
+
+DEFAULT_REPO_ID = "Plachta/Seed-VC"
+DEFAULT_CFM_CHECKPOINT = "v2/cfm_small.pth"
+DEFAULT_AR_CHECKPOINT = "v2/ar_base.pth"
+
+DEFAULT_CE_REPO_ID = "Plachta/ASTRAL-quantization"
+DEFAULT_CE_NARROW_CHECKPOINT = "bsq32/bsq32_light.pth"
+DEFAULT_CE_WIDE_CHECKPOINT = "bsq2048/bsq2048_light.pth"
+
+DEFAULT_SE_REPO_ID = "funasr/campplus"
+DEFAULT_SE_CHECKPOINT = "campplus_cn_common.bin"
+
+class VoiceConversionWrapper(torch.nn.Module):
+    def __init__(
+            self,
+            sr: int,
+            hop_size: int,
+            mel_fn: callable,
+            cfm: torch.nn.Module,
+            cfm_length_regulator: torch.nn.Module,
+            content_extractor_narrow: torch.nn.Module,
+            content_extractor_wide: torch.nn.Module,
+            ar_length_regulator: torch.nn.Module,
+            ar: torch.nn.Module,
+            style_encoder: torch.nn.Module,
+            vocoder: torch.nn.Module,
+            ):
+        super(VoiceConversionWrapper, self).__init__()
+        self.sr = sr
+        self.hop_size = hop_size
+        self.mel_fn = mel_fn
+        self.cfm = cfm
+        self.cfm_length_regulator = cfm_length_regulator
+        self.content_extractor_narrow = content_extractor_narrow
+        self.content_extractor_wide = content_extractor_wide
+        self.vocoder = vocoder
+        self.ar_length_regulator = ar_length_regulator
+        self.ar = ar
+        self.style_encoder = style_encoder
+        # Set streaming parameters
+        self.overlap_frame_len = 16
+        self.bitrate = "320k"
+        self.compiled_decode_fn = None
+        self.dit_compiled = False
+        self.dit_max_context_len = 30  # in seconds
+        self.compile_len = 87 * self.dit_max_context_len
+
+    def compile_ar(self):
+        """
+        Compile the AR model for inference.
+        """
+        self.compiled_decode_fn = torch.compile(
+            self.ar.model.forward_generate,
+            fullgraph=True,
+            backend="inductor" if torch.cuda.is_available() else "aot_eager",
+            mode="reduce-overhead" if torch.cuda.is_available() else None,
+        )
+
+    def compile_cfm(self):
+        self.cfm.estimator.transformer = torch.compile(
+            self.cfm.estimator.transformer,
+            fullgraph=True,
+            backend="inductor" if torch.cuda.is_available() else "aot_eager",
+            mode="reduce-overhead" if torch.cuda.is_available() else None,
+        )
+        self.dit_compiled = True
+
+    @staticmethod
+    def strip_prefix(state_dict: dict, prefix: str = "module.") -> dict:
+        """
+        Strip the prefix from the state_dict keys.
+        """
+        new_state_dict = {}
+        for k, v in state_dict.items():
+            if k.startswith(prefix):
+                new_key = k[len(prefix):]
+            else:
+                new_key = k
+            new_state_dict[new_key] = v
+        return new_state_dict
+
+    @staticmethod
+    def duration_reduction_func(token_seq, n_gram=1):
+        """
+        Args:
+            token_seq: (T,)
+        Returns:
+            reduced_token_seq: (T')
+            reduced_token_seq_len: T'
+        """
+        n_gram_seq = token_seq.unfold(0, n_gram, 1)
+        mask = torch.all(n_gram_seq[1:] != n_gram_seq[:-1], dim=1)
+        reduced_token_seq = torch.cat(
+            (n_gram_seq[0, :n_gram], n_gram_seq[1:, -1][mask])
+        )
+        return reduced_token_seq, len(reduced_token_seq)
+        
+    @staticmethod
+    def crossfade(chunk1, chunk2, overlap):
+        """Apply crossfade between two audio chunks."""
+        fade_out = np.cos(np.linspace(0, np.pi / 2, overlap)) ** 2
+        fade_in = np.cos(np.linspace(np.pi / 2, 0, overlap)) ** 2
+        if len(chunk2) < overlap:
+            chunk2[:overlap] = chunk2[:overlap] * fade_in[:len(chunk2)] + (chunk1[-overlap:] * fade_out)[:len(chunk2)]
+        else:
+            chunk2[:overlap] = chunk2[:overlap] * fade_in + chunk1[-overlap:] * fade_out
+        return chunk2
+
+    def _stream_wave_chunks(self, vc_wave, processed_frames, vc_mel, overlap_wave_len, 
+                           generated_wave_chunks, previous_chunk, is_last_chunk, stream_output):
+        """
+        Helper method to handle streaming wave chunks.
+        
+        Args:
+            vc_wave: The current wave chunk
+            processed_frames: Number of frames processed so far
+            vc_mel: The mel spectrogram
+            overlap_wave_len: Length of overlap between chunks
+            generated_wave_chunks: List of generated wave chunks
+            previous_chunk: Previous wave chunk for crossfading
+            is_last_chunk: Whether this is the last chunk
+            stream_output: Whether to stream the output
+            
+        Returns:
+            Tuple of (processed_frames, previous_chunk, should_break, mp3_bytes, full_audio)
+            where should_break indicates if processing should stop
+            mp3_bytes is the MP3 bytes if streaming, None otherwise
+            full_audio is the full audio if this is the last chunk, None otherwise
+        """
+        mp3_bytes = None
+        full_audio = None
+        
+        if processed_frames == 0:
+            if is_last_chunk:
+                output_wave = vc_wave[0].cpu().numpy()
+                generated_wave_chunks.append(output_wave)
+
+                if stream_output:
+                    output_wave_int16 = (output_wave * 32768.0).astype(np.int16)
+                    mp3_bytes = AudioSegment(
+                        output_wave_int16.tobytes(), frame_rate=self.sr,
+                        sample_width=output_wave_int16.dtype.itemsize, channels=1
+                    ).export(format="mp3", bitrate=self.bitrate).read()
+                    full_audio = (self.sr, np.concatenate(generated_wave_chunks))
+                else:
+                    return processed_frames, previous_chunk, True, None, np.concatenate(generated_wave_chunks)
+
+                return processed_frames, previous_chunk, True, mp3_bytes, full_audio
+
+            output_wave = vc_wave[0, :-overlap_wave_len].cpu().numpy()
+            generated_wave_chunks.append(output_wave)
+            previous_chunk = vc_wave[0, -overlap_wave_len:]
+            processed_frames += vc_mel.size(2) - self.overlap_frame_len
+
+            if stream_output:
+                output_wave_int16 = (output_wave * 32768.0).astype(np.int16)
+                mp3_bytes = AudioSegment(
+                    output_wave_int16.tobytes(), frame_rate=self.sr,
+                    sample_width=output_wave_int16.dtype.itemsize, channels=1
+                ).export(format="mp3", bitrate=self.bitrate).read()
+
+        elif is_last_chunk:
+            output_wave = self.crossfade(previous_chunk.cpu().numpy(), vc_wave[0].cpu().numpy(), overlap_wave_len)
+            generated_wave_chunks.append(output_wave)
+            processed_frames += vc_mel.size(2) - self.overlap_frame_len
+
+            if stream_output:
+                output_wave_int16 = (output_wave * 32768.0).astype(np.int16)
+                mp3_bytes = AudioSegment(
+                    output_wave_int16.tobytes(), frame_rate=self.sr,
+                    sample_width=output_wave_int16.dtype.itemsize, channels=1
+                ).export(format="mp3", bitrate=self.bitrate).read()
+                full_audio = (self.sr, np.concatenate(generated_wave_chunks))
+            else:
+                return processed_frames, previous_chunk, True, None, np.concatenate(generated_wave_chunks)
+
+            return processed_frames, previous_chunk, True, mp3_bytes, full_audio
+
+        else:
+            output_wave = self.crossfade(previous_chunk.cpu().numpy(), vc_wave[0, :-overlap_wave_len].cpu().numpy(), overlap_wave_len)
+            generated_wave_chunks.append(output_wave)
+            previous_chunk = vc_wave[0, -overlap_wave_len:]
+            processed_frames += vc_mel.size(2) - self.overlap_frame_len
+
+            if stream_output:
+                output_wave_int16 = (output_wave * 32768.0).astype(np.int16)
+                mp3_bytes = AudioSegment(
+                    output_wave_int16.tobytes(), frame_rate=self.sr,
+                    sample_width=output_wave_int16.dtype.itemsize, channels=1
+                ).export(format="mp3", bitrate=self.bitrate).read()
+                
+        return processed_frames, previous_chunk, False, mp3_bytes, full_audio
+
+    def load_checkpoints(
+            self,
+            cfm_checkpoint_path = None,
+            ar_checkpoint_path = None,
+    ):
+        if cfm_checkpoint_path is None:
+            cfm_checkpoint_path = load_custom_model_from_hf(
+                repo_id=DEFAULT_REPO_ID,
+                model_filename=DEFAULT_CFM_CHECKPOINT,
+            )
+        if ar_checkpoint_path is None:
+            ar_checkpoint_path = load_custom_model_from_hf(
+                repo_id=DEFAULT_REPO_ID,
+                model_filename=DEFAULT_AR_CHECKPOINT,
+            )
+        # cfm
+        cfm_checkpoint = torch.load(cfm_checkpoint_path, map_location="cpu")
+        cfm_length_regulator_state_dict = self.strip_prefix(cfm_checkpoint["net"]['length_regulator'], "module.")
+        cfm_state_dict = self.strip_prefix(cfm_checkpoint["net"]['cfm'], "module.")
+        self.cfm.load_state_dict(cfm_state_dict, strict=False)
+        self.cfm_length_regulator.load_state_dict(cfm_length_regulator_state_dict, strict=False)
+
+        # ar
+        ar_checkpoint = torch.load(ar_checkpoint_path, map_location="cpu")
+        ar_length_regulator_state_dict = self.strip_prefix(ar_checkpoint["net"]['length_regulator'], "module.")
+        ar_state_dict = self.strip_prefix(ar_checkpoint["net"]['ar'], "module.")
+        self.ar.load_state_dict(ar_state_dict, strict=False)
+        self.ar_length_regulator.load_state_dict(ar_length_regulator_state_dict, strict=False)
+
+        # content extractor
+        content_extractor_narrow_checkpoint_path = load_custom_model_from_hf(
+            repo_id=DEFAULT_CE_REPO_ID,
+            model_filename=DEFAULT_CE_NARROW_CHECKPOINT,
+        )
+        content_extractor_narrow_checkpoint = torch.load(content_extractor_narrow_checkpoint_path, map_location="cpu")
+        self.content_extractor_narrow.load_state_dict(
+            content_extractor_narrow_checkpoint, strict=False
+        )
+
+        content_extractor_wide_checkpoint_path = load_custom_model_from_hf(
+            repo_id=DEFAULT_CE_REPO_ID,
+            model_filename=DEFAULT_CE_WIDE_CHECKPOINT,
+        )
+        content_extractor_wide_checkpoint = torch.load(content_extractor_wide_checkpoint_path, map_location="cpu")
+        self.content_extractor_wide.load_state_dict(
+            content_extractor_wide_checkpoint, strict=False
+        )
+
+        # style encoder
+        style_encoder_checkpoint_path = load_custom_model_from_hf(DEFAULT_SE_REPO_ID, DEFAULT_SE_CHECKPOINT, config_filename=None)
+        style_encoder_checkpoint = torch.load(style_encoder_checkpoint_path, map_location="cpu")
+        self.style_encoder.load_state_dict(style_encoder_checkpoint, strict=False)
+
+    def setup_ar_caches(self, max_batch_size=1, max_seq_len=4096, dtype=torch.float32, device=torch.device("cpu")):
+        self.ar.setup_caches(max_batch_size=max_batch_size, max_seq_len=max_seq_len, dtype=dtype, device=device)
+
+    def compute_style(self, waves_16k: torch.Tensor):
+        feat = torchaudio.compliance.kaldi.fbank(waves_16k,
+                                                  num_mel_bins=80,
+                                                  dither=0,
+                                                  sample_frequency=16000)
+        feat = feat - feat.mean(dim=0, keepdim=True)
+        style = self.style_encoder(feat.unsqueeze(0))
+        return style
+
+    @torch.no_grad()
+    @torch.inference_mode()
+    def convert_timbre(
+            self,
+            source_audio_path: str,
+            target_audio_path: str,
+            diffusion_steps: int = 30,
+            length_adjust: float = 1.0,
+            inference_cfg_rate: float = 0.5,
+            use_sway_sampling: bool = False,
+            use_amo_sampling: bool = False,
+            device: torch.device = torch.device("cpu"),
+            dtype: torch.dtype = torch.float32,
+    ):
+        source_wave = librosa.load(source_audio_path, sr=self.sr)[0]
+        target_wave = librosa.load(target_audio_path, sr=self.sr)[0]
+        source_wave_tensor = torch.tensor(source_wave).unsqueeze(0).to(device)
+        target_wave_tensor = torch.tensor(target_wave).unsqueeze(0).to(device)
+
+        # get 16khz audio
+        source_wave_16k = librosa.resample(source_wave, orig_sr=self.sr, target_sr=16000)
+        target_wave_16k = librosa.resample(target_wave, orig_sr=self.sr, target_sr=16000)
+        source_wave_16k_tensor = torch.tensor(source_wave_16k).unsqueeze(0).to(device)
+        target_wave_16k_tensor = torch.tensor(target_wave_16k).unsqueeze(0).to(device)
+
+        # compute mel spectrogram
+        source_mel = self.mel_fn(source_wave_tensor)
+        target_mel = self.mel_fn(target_wave_tensor)
+        source_mel_len = source_mel.size(2)
+        target_mel_len = target_mel.size(2)
+
+        with torch.autocast(device_type=device.type, dtype=dtype):
+            # compute content features
+            _, source_content_indices, _ = self.content_extractor_wide(source_wave_16k_tensor, [source_wave_16k.size])
+            _, target_content_indices, _ = self.content_extractor_wide(target_wave_16k_tensor, [target_wave_16k.size])
+
+            # compute style features
+            target_style = self.compute_style(target_wave_16k_tensor)
+
+            # Length regulation
+            cond, _ = self.cfm_length_regulator(source_content_indices, ylens=torch.LongTensor([source_mel_len]).to(device))
+            prompt_condition, _, = self.cfm_length_regulator(target_content_indices, ylens=torch.LongTensor([target_mel_len]).to(device))
+
+            cat_condition = torch.cat([prompt_condition, cond], dim=1)
+            # generate mel spectrogram
+            vc_mel = self.cfm.inference(
+                cat_condition,
+                torch.LongTensor([cat_condition.size(1)]).to(device),
+                target_mel, target_style, diffusion_steps,
+                inference_cfg_rate=inference_cfg_rate,
+                sway_sampling=use_sway_sampling,
+                amo_sampling=use_amo_sampling,
+            )
+        vc_mel = vc_mel[:, :, target_mel_len:]
+        vc_wave = self.vocoder(vc_mel.float()).squeeze()[None]
+        return vc_wave.cpu().numpy()
+
+    @torch.no_grad()
+    @torch.inference_mode()
+    def convert_voice(
+            self,
+            source_audio_path: str,
+            target_audio_path: str,
+            diffusion_steps: int = 30,
+            length_adjust: float = 1.0,
+            inference_cfg_rate: float = 0.5,
+            top_p: float = 0.7,
+            temperature: float = 0.7,
+            repetition_penalty: float = 1.5,
+            use_sway_sampling: bool = False,
+            use_amo_sampling: bool = False,
+            device: torch.device = torch.device("cpu"),
+            dtype: torch.dtype = torch.float32,
+    ):
+        source_wave = librosa.load(source_audio_path, sr=self.sr)[0]
+        target_wave = librosa.load(target_audio_path, sr=self.sr)[0]
+        source_wave_tensor = torch.tensor(source_wave).unsqueeze(0).to(device)
+        target_wave_tensor = torch.tensor(target_wave).unsqueeze(0).to(device)
+
+        # get 16khz audio
+        source_wave_16k = librosa.resample(source_wave, orig_sr=self.sr, target_sr=16000)
+        target_wave_16k = librosa.resample(target_wave, orig_sr=self.sr, target_sr=16000)
+        source_wave_16k_tensor = torch.tensor(source_wave_16k).unsqueeze(0).to(device)
+        target_wave_16k_tensor = torch.tensor(target_wave_16k).unsqueeze(0).to(device)
+
+        # compute mel spectrogram
+        source_mel = self.mel_fn(source_wave_tensor)
+        target_mel = self.mel_fn(target_wave_tensor)
+        source_mel_len = source_mel.size(2)
+        target_mel_len = target_mel.size(2)
+
+        with torch.autocast(device_type=device.type, dtype=dtype):
+            # compute content features
+            _, source_content_indices, _ = self.content_extractor_wide(source_wave_16k_tensor, [source_wave_16k.size])
+            _, target_content_indices, _ = self.content_extractor_wide(target_wave_16k_tensor, [target_wave_16k.size])
+
+            _, source_narrow_indices, _ = self.content_extractor_narrow(source_wave_16k_tensor,
+                                                                         [source_wave_16k.size], ssl_model=self.content_extractor_wide.ssl_model)
+            _, target_narrow_indices, _ = self.content_extractor_narrow(target_wave_16k_tensor,
+                                                                         [target_wave_16k.size], ssl_model=self.content_extractor_wide.ssl_model)
+
+            src_narrow_reduced, src_narrow_len = self.duration_reduction_func(source_narrow_indices[0], 1)
+            tgt_narrow_reduced, tgt_narrow_len = self.duration_reduction_func(target_narrow_indices[0], 1)
+
+            ar_cond = self.ar_length_regulator(torch.cat([tgt_narrow_reduced, src_narrow_reduced], dim=0)[None])[0]
+
+            ar_out = self.ar.generate(ar_cond, target_content_indices, top_p=top_p, temperature=temperature, repetition_penalty=repetition_penalty)
+            ar_out_mel_len = torch.LongTensor([int(source_mel_len / source_content_indices.size(-1) * ar_out.size(-1) * length_adjust)]).to(device)
+            # compute style features
+            target_style = self.compute_style(target_wave_16k_tensor)
+
+            # Length regulation
+            cond, _ = self.cfm_length_regulator(ar_out, ylens=torch.LongTensor([ar_out_mel_len]).to(device))
+            prompt_condition, _, = self.cfm_length_regulator(target_content_indices, ylens=torch.LongTensor([target_mel_len]).to(device))
+
+            cat_condition = torch.cat([prompt_condition, cond], dim=1)
+            # generate mel spectrogram
+            vc_mel = self.cfm.inference(
+                cat_condition,
+                torch.LongTensor([cat_condition.size(1)]).to(device),
+                target_mel, target_style, diffusion_steps,
+                inference_cfg_rate=inference_cfg_rate,
+                sway_sampling=use_sway_sampling,
+                amo_sampling=use_amo_sampling,
+            )
+        vc_mel = vc_mel[:, :, target_mel_len:]
+        vc_wave = self.vocoder(vc_mel.float()).squeeze()[None]
+        return vc_wave.cpu().numpy()
+
+    def _process_content_features(self, audio_16k_tensor, is_narrow=False):
+        """Process audio through Whisper model to extract features."""
+        content_extractor_fn = self.content_extractor_narrow if is_narrow else self.content_extractor_wide
+        if audio_16k_tensor.size(-1) <= 16000 * 30:
+            # Compute content features
+            _, content_indices, _ = content_extractor_fn(audio_16k_tensor, [audio_16k_tensor.size(-1)], ssl_model=self.content_extractor_wide.ssl_model)
+        else:
+            # Process long audio in chunks
+            overlapping_time = 5  # 5 seconds
+            features_list = []
+            buffer = None
+            traversed_time = 0
+            while traversed_time < audio_16k_tensor.size(-1):
+                if buffer is None:  # first chunk
+                    chunk = audio_16k_tensor[:, traversed_time:traversed_time + 16000 * 30]
+                else:
+                    chunk = torch.cat([
+                        buffer,
+                        audio_16k_tensor[:, traversed_time:traversed_time + 16000 * (30 - overlapping_time)]
+                    ], dim=-1)
+                _, chunk_content_indices, _ = content_extractor_fn(chunk, [chunk.size(-1)], ssl_model=self.content_extractor_wide.ssl_model)
+                if traversed_time == 0:
+                    features_list.append(chunk_content_indices)
+                else:
+                    features_list.append(chunk_content_indices[:, 50 * overlapping_time:])
+                buffer = chunk[:, -16000 * overlapping_time:]
+                traversed_time += 30 * 16000 if traversed_time == 0 else chunk.size(-1) - 16000 * overlapping_time
+            content_indices = torch.cat(features_list, dim=1)
+
+        return content_indices
+
+    @spaces.GPU
+    @torch.no_grad()
+    @torch.inference_mode()
+    def convert_voice_with_streaming(
+            self,
+            source_audio_path: str,
+            target_audio_path: str,
+            diffusion_steps: int = 30,
+            length_adjust: float = 1.0,
+            intelligebility_cfg_rate: float = 0.7,
+            similarity_cfg_rate: float = 0.7,
+            top_p: float = 0.7,
+            temperature: float = 0.7,
+            repetition_penalty: float = 1.5,
+            convert_style: bool = False,
+            anonymization_only: bool = False,
+            device: torch.device = torch.device("cuda"),
+            dtype: torch.dtype = torch.float16,
+            stream_output: bool = True,
+    ):
+        """
+        Convert voice with streaming support for long audio files.
+        
+        Args:
+            source_audio_path: Path to source audio file
+            target_audio_path: Path to target audio file
+            diffusion_steps: Number of diffusion steps (default: 30)
+            length_adjust: Length adjustment factor (default: 1.0)
+            intelligebility_cfg_rate: CFG rate for intelligibility (default: 0.7)
+            similarity_cfg_rate: CFG rate for similarity (default: 0.7)
+            top_p: Top-p sampling parameter (default: 0.7)
+            temperature: Temperature for sampling (default: 0.7)
+            repetition_penalty: Repetition penalty (default: 1.5)
+            device: Device to use (default: cpu)
+            dtype: Data type to use (default: float32)
+            stream_output: Whether to stream the output (default: True)
+            
+        Returns:
+            If stream_output is True, yields (mp3_bytes, full_audio) tuples
+            If stream_output is False, returns the full audio as a numpy array
+        """
+        # Load audio
+        source_wave = librosa.load(source_audio_path, sr=self.sr)[0]
+        target_wave = librosa.load(target_audio_path, sr=self.sr)[0]
+        
+        # Limit target audio to 25 seconds
+        target_wave = target_wave[:self.sr * (self.dit_max_context_len - 5)]
+        
+        source_wave_tensor = torch.tensor(source_wave).unsqueeze(0).float().to(device)
+        target_wave_tensor = torch.tensor(target_wave).unsqueeze(0).float().to(device)
+
+        # Resample to 16kHz for feature extraction
+        source_wave_16k = librosa.resample(source_wave, orig_sr=self.sr, target_sr=16000)
+        target_wave_16k = librosa.resample(target_wave, orig_sr=self.sr, target_sr=16000)
+        source_wave_16k_tensor = torch.tensor(source_wave_16k).unsqueeze(0).to(device)
+        target_wave_16k_tensor = torch.tensor(target_wave_16k).unsqueeze(0).to(device)
+
+        # Compute mel spectrograms
+        source_mel = self.mel_fn(source_wave_tensor)
+        target_mel = self.mel_fn(target_wave_tensor)
+        source_mel_len = source_mel.size(2)
+        target_mel_len = target_mel.size(2)
+        
+        # Set up chunk processing parameters
+        max_context_window = self.sr // self.hop_size * self.dit_max_context_len
+        overlap_wave_len = self.overlap_frame_len * self.hop_size
+        
+        with torch.autocast(device_type=device.type, dtype=dtype):
+            # Compute content features
+            source_content_indices = self._process_content_features(source_wave_16k_tensor, is_narrow=False)
+            target_content_indices = self._process_content_features(target_wave_16k_tensor, is_narrow=False)
+            # Compute style features
+            target_style = self.compute_style(target_wave_16k_tensor)
+            prompt_condition, _, = self.cfm_length_regulator(target_content_indices,
+                                                             ylens=torch.LongTensor([target_mel_len]).to(device))
+
+        # prepare for streaming
+        generated_wave_chunks = []
+        processed_frames = 0
+        previous_chunk = None
+        if convert_style:
+            with torch.autocast(device_type=device.type, dtype=dtype):
+                source_narrow_indices = self._process_content_features(source_wave_16k_tensor, is_narrow=True)
+                target_narrow_indices = self._process_content_features(target_wave_16k_tensor, is_narrow=True)
+            src_narrow_reduced, src_narrow_len = self.duration_reduction_func(source_narrow_indices[0], 1)
+            tgt_narrow_reduced, tgt_narrow_len = self.duration_reduction_func(target_narrow_indices[0], 1)
+            # Process src_narrow_reduced in chunks of max 1000 tokens
+            max_chunk_size = 1000
+
+            # Process src_narrow_reduced in chunks
+            for i in range(0, len(src_narrow_reduced), max_chunk_size):
+                is_last_chunk = i + max_chunk_size >= len(src_narrow_reduced)
+                with torch.autocast(device_type=device.type, dtype=dtype):
+                    chunk = src_narrow_reduced[i:i + max_chunk_size]
+                    if anonymization_only:
+                        chunk_ar_cond = self.ar_length_regulator(chunk[None])[0]
+                        chunk_ar_out = self.ar.generate(chunk_ar_cond, torch.zeros([1, 0]).long().to(device),
+                                                        compiled_decode_fn=self.compiled_decode_fn,
+                                                      top_p=top_p, temperature=temperature,
+                                                      repetition_penalty=repetition_penalty)
+                    else:
+                        # For each chunk, we need to include tgt_narrow_reduced as context
+                        chunk_ar_cond = self.ar_length_regulator(torch.cat([tgt_narrow_reduced, chunk], dim=0)[None])[0]
+                        chunk_ar_out = self.ar.generate(chunk_ar_cond, target_content_indices, compiled_decode_fn=self.compiled_decode_fn,
+                                                      top_p=top_p, temperature=temperature,
+                                                      repetition_penalty=repetition_penalty)
+                    chunkar_out_mel_len = torch.LongTensor([int(source_mel_len / source_content_indices.size(
+                        -1) * chunk_ar_out.size(-1) * length_adjust)]).to(device)
+                    # Length regulation
+                    chunk_cond, _ = self.cfm_length_regulator(chunk_ar_out, ylens=torch.LongTensor([chunkar_out_mel_len]).to(device))
+                    cat_condition = torch.cat([prompt_condition, chunk_cond], dim=1)
+                    original_len = cat_condition.size(1)
+                    # pad cat_condition to compile_len
+                    if self.dit_compiled:
+                        cat_condition = torch.nn.functional.pad(cat_condition,
+                                                                (0, 0, 0, self.compile_len - cat_condition.size(1),),
+                                                                value=0)
+                    # Voice Conversion
+                    vc_mel = self.cfm.inference(
+                        cat_condition,
+                        torch.LongTensor([original_len]).to(device),
+                        target_mel, target_style, diffusion_steps,
+                        inference_cfg_rate=[intelligebility_cfg_rate, similarity_cfg_rate],
+                        random_voice=anonymization_only,
+                    )
+                    vc_mel = vc_mel[:, :, target_mel_len:original_len]
+                vc_wave = self.vocoder(vc_mel).squeeze()[None]
+                processed_frames, previous_chunk, should_break, mp3_bytes, full_audio = self._stream_wave_chunks(
+                    vc_wave, processed_frames, vc_mel, overlap_wave_len,
+                    generated_wave_chunks, previous_chunk, is_last_chunk, stream_output
+                )
+
+                if stream_output and mp3_bytes is not None:
+                    yield mp3_bytes, full_audio
+
+                if should_break:
+                    if not stream_output:
+                        return full_audio
+                    break
+        else:
+            cond, _ = self.cfm_length_regulator(source_content_indices, ylens=torch.LongTensor([source_mel_len]).to(device))
+
+            # Process in chunks for streaming
+            max_source_window = max_context_window - target_mel.size(2)
+
+            # Generate chunk by chunk and stream the output
+            while processed_frames < cond.size(1):
+                chunk_cond = cond[:, processed_frames:processed_frames + max_source_window]
+                is_last_chunk = processed_frames + max_source_window >= cond.size(1)
+                cat_condition = torch.cat([prompt_condition, chunk_cond], dim=1)
+                original_len = cat_condition.size(1)
+                # pad cat_condition to compile_len
+                if self.dit_compiled:
+                    cat_condition = torch.nn.functional.pad(cat_condition,
+                                                            (0, 0, 0, self.compile_len - cat_condition.size(1),), value=0)
+                with torch.autocast(device_type=device.type, dtype=dtype):
+                    # Voice Conversion
+                    vc_mel = self.cfm.inference(
+                        cat_condition,
+                        torch.LongTensor([original_len]).to(device),
+                        target_mel, target_style, diffusion_steps,
+                        inference_cfg_rate=[intelligebility_cfg_rate, similarity_cfg_rate],
+                        random_voice=anonymization_only,
+                    )
+                vc_mel = vc_mel[:, :, target_mel_len:original_len]
+                vc_wave = self.vocoder(vc_mel).squeeze()[None]
+
+                processed_frames, previous_chunk, should_break, mp3_bytes, full_audio = self._stream_wave_chunks(
+                    vc_wave, processed_frames, vc_mel, overlap_wave_len,
+                    generated_wave_chunks, previous_chunk, is_last_chunk, stream_output
+                )
+                
+                if stream_output and mp3_bytes is not None:
+                    yield mp3_bytes, full_audio
+                    
+                if should_break:
+                    if not stream_output:
+                        return full_audio
+                    break
+
+
diff --git a/requirements.txt b/requirements.txt
index e608c6a787058579c732680c3337566a0e36b259..4fa463bd62265380f5549538f043fa358200ef01 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,14 +1,24 @@
---extra-index-url https://download.pytorch.org/whl/cu113
-torch
-torchvision
-torchaudio
-scipy==1.13.1
-onnxruntime-gpu==1.19.0
-librosa==0.10.2
-huggingface-hub
-munch
-einops
-descript-audio-codec
-git+https://github.com/openai/whisper.git 
-pydub
-transformers
\ No newline at end of file
+--extra-index-url https://download.pytorch.org/whl/cu121
+torch==2.4.0
+torchvision==0.19.0
+torchaudio==2.4.0
+scipy==1.13.1
+librosa==0.10.2
+huggingface-hub==0.23.4
+munch==4.0.0
+einops==0.8.0
+descript-audio-codec==1.0.0
+gradio==5.23.0
+pydub==0.25.1
+resemblyzer
+jiwer==3.0.3
+transformers==4.46.3
+FreeSimpleGUI==5.1.1
+soundfile==0.12.1
+sounddevice==0.5.0
+modelscope==1.18.1
+funasr==1.1.5
+numpy==1.26.4
+hydra-core==1.3.2
+pyyaml
+python-dotenv
diff --git a/seed_vc_wrapper.py b/seed_vc_wrapper.py
new file mode 100644
index 0000000000000000000000000000000000000000..14caf0cfd3a18861e1a2b2c5fc3774aaa244e353
--- /dev/null
+++ b/seed_vc_wrapper.py
@@ -0,0 +1,463 @@
+import spaces
+import torch
+import torchaudio
+import librosa
+import numpy as np
+from pydub import AudioSegment
+import yaml
+from modules.commons import build_model, load_checkpoint, recursive_munch
+from hf_utils import load_custom_model_from_hf
+from modules.campplus.DTDNN import CAMPPlus
+from modules.bigvgan import bigvgan
+from modules.audio import mel_spectrogram
+from modules.rmvpe import RMVPE
+from transformers import AutoFeatureExtractor, WhisperModel
+
+class SeedVCWrapper:
+    def __init__(self, device=None):
+        """
+        Initialize the Seed-VC wrapper with all necessary models and configurations.
+        
+        Args:
+            device: torch device to use. If None, will be automatically determined.
+        """
+        # Set device
+        if device is None:
+            if torch.cuda.is_available():
+                self.device = torch.device("cuda")
+            elif torch.backends.mps.is_available():
+                self.device = torch.device("mps")
+            else:
+                self.device = torch.device("cpu")
+        else:
+            self.device = device
+            
+        # Load base model and configuration
+        self._load_base_model()
+        
+        # Load F0 conditioned model
+        self._load_f0_model()
+        
+        # Load additional modules
+        self._load_additional_modules()
+        
+        # Set streaming parameters
+        self.overlap_frame_len = 16
+        self.bitrate = "320k"
+        
+    def _load_base_model(self):
+        """Load the base DiT model for voice conversion."""
+        dit_checkpoint_path, dit_config_path = load_custom_model_from_hf(
+            "Plachta/Seed-VC",
+            "DiT_seed_v2_uvit_whisper_small_wavenet_bigvgan_pruned.pth",
+            "config_dit_mel_seed_uvit_whisper_small_wavenet.yml"
+        )
+        config = yaml.safe_load(open(dit_config_path, 'r'))
+        model_params = recursive_munch(config['model_params'])
+        self.model = build_model(model_params, stage='DiT')
+        self.hop_length = config['preprocess_params']['spect_params']['hop_length']
+        self.sr = config['preprocess_params']['sr']
+        
+        # Load checkpoints
+        self.model, _, _, _ = load_checkpoint(
+            self.model, None, dit_checkpoint_path,
+            load_only_params=True, ignore_modules=[], is_distributed=False
+        )
+        for key in self.model:
+            self.model[key].eval()
+            self.model[key].to(self.device)
+        self.model.cfm.estimator.setup_caches(max_batch_size=1, max_seq_length=8192)
+        
+        # Set up mel spectrogram function
+        mel_fn_args = {
+            "n_fft": config['preprocess_params']['spect_params']['n_fft'],
+            "win_size": config['preprocess_params']['spect_params']['win_length'],
+            "hop_size": config['preprocess_params']['spect_params']['hop_length'],
+            "num_mels": config['preprocess_params']['spect_params']['n_mels'],
+            "sampling_rate": self.sr,
+            "fmin": 0,
+            "fmax": None,
+            "center": False
+        }
+        self.to_mel = lambda x: mel_spectrogram(x, **mel_fn_args)
+        
+        # Load whisper model
+        whisper_name = model_params.speech_tokenizer.whisper_name if hasattr(model_params.speech_tokenizer, 'whisper_name') else "openai/whisper-small"
+        self.whisper_model = WhisperModel.from_pretrained(whisper_name, torch_dtype=torch.float16).to(self.device)
+        del self.whisper_model.decoder
+        self.whisper_feature_extractor = AutoFeatureExtractor.from_pretrained(whisper_name)
+        
+    def _load_f0_model(self):
+        """Load the F0 conditioned model for voice conversion."""
+        dit_checkpoint_path, dit_config_path = load_custom_model_from_hf(
+            "Plachta/Seed-VC",
+            "DiT_seed_v2_uvit_whisper_base_f0_44k_bigvgan_pruned_ft_ema.pth",
+            "config_dit_mel_seed_uvit_whisper_base_f0_44k.yml"
+        )
+        config = yaml.safe_load(open(dit_config_path, 'r'))
+        model_params = recursive_munch(config['model_params'])
+        self.model_f0 = build_model(model_params, stage='DiT')
+        self.hop_length_f0 = config['preprocess_params']['spect_params']['hop_length']
+        self.sr_f0 = config['preprocess_params']['sr']
+        
+        # Load checkpoints
+        self.model_f0, _, _, _ = load_checkpoint(
+            self.model_f0, None, dit_checkpoint_path,
+            load_only_params=True, ignore_modules=[], is_distributed=False
+        )
+        for key in self.model_f0:
+            self.model_f0[key].eval()
+            self.model_f0[key].to(self.device)
+        self.model_f0.cfm.estimator.setup_caches(max_batch_size=1, max_seq_length=8192)
+        
+        # Set up mel spectrogram function for F0 model
+        mel_fn_args_f0 = {
+            "n_fft": config['preprocess_params']['spect_params']['n_fft'],
+            "win_size": config['preprocess_params']['spect_params']['win_length'],
+            "hop_size": config['preprocess_params']['spect_params']['hop_length'],
+            "num_mels": config['preprocess_params']['spect_params']['n_mels'],
+            "sampling_rate": self.sr_f0,
+            "fmin": 0,
+            "fmax": None,
+            "center": False
+        }
+        self.to_mel_f0 = lambda x: mel_spectrogram(x, **mel_fn_args_f0)
+        
+    def _load_additional_modules(self):
+        """Load additional modules like CAMPPlus, BigVGAN, and RMVPE."""
+        # Load CAMPPlus
+        campplus_ckpt_path = load_custom_model_from_hf("funasr/campplus", "campplus_cn_common.bin", config_filename=None)
+        self.campplus_model = CAMPPlus(feat_dim=80, embedding_size=192)
+        self.campplus_model.load_state_dict(torch.load(campplus_ckpt_path, map_location="cpu"))
+        self.campplus_model.eval()
+        self.campplus_model.to(self.device)
+        
+        # Load BigVGAN models
+        self.bigvgan_model = bigvgan.BigVGAN.from_pretrained('nvidia/bigvgan_v2_22khz_80band_256x', use_cuda_kernel=False)
+        self.bigvgan_model.remove_weight_norm()
+        self.bigvgan_model = self.bigvgan_model.eval().to(self.device)
+        
+        self.bigvgan_44k_model = bigvgan.BigVGAN.from_pretrained('nvidia/bigvgan_v2_44khz_128band_512x', use_cuda_kernel=False)
+        self.bigvgan_44k_model.remove_weight_norm()
+        self.bigvgan_44k_model = self.bigvgan_44k_model.eval().to(self.device)
+        
+        # Load RMVPE for F0 extraction
+        model_path = load_custom_model_from_hf("lj1995/VoiceConversionWebUI", "rmvpe.pt", None)
+        self.rmvpe = RMVPE(model_path, is_half=False, device=self.device)
+        
+    @staticmethod
+    def adjust_f0_semitones(f0_sequence, n_semitones):
+        """Adjust F0 values by a number of semitones."""
+        factor = 2 ** (n_semitones / 12)
+        return f0_sequence * factor
+    
+    @staticmethod
+    def crossfade(chunk1, chunk2, overlap):
+        """Apply crossfade between two audio chunks."""
+        fade_out = np.cos(np.linspace(0, np.pi / 2, overlap)) ** 2
+        fade_in = np.cos(np.linspace(np.pi / 2, 0, overlap)) ** 2
+        if len(chunk2) < overlap:
+            chunk2[:overlap] = chunk2[:overlap] * fade_in[:len(chunk2)] + (chunk1[-overlap:] * fade_out)[:len(chunk2)]
+        else:
+            chunk2[:overlap] = chunk2[:overlap] * fade_in + chunk1[-overlap:] * fade_out
+        return chunk2
+    
+    def _stream_wave_chunks(self, vc_wave, processed_frames, vc_target, overlap_wave_len, 
+                           generated_wave_chunks, previous_chunk, is_last_chunk, stream_output, sr):
+        """
+        Helper method to handle streaming wave chunks.
+        
+        Args:
+            vc_wave: The current wave chunk
+            processed_frames: Number of frames processed so far
+            vc_target: The target mel spectrogram
+            overlap_wave_len: Length of overlap between chunks
+            generated_wave_chunks: List of generated wave chunks
+            previous_chunk: Previous wave chunk for crossfading
+            is_last_chunk: Whether this is the last chunk
+            stream_output: Whether to stream the output
+            sr: Sample rate
+            
+        Returns:
+            Tuple of (processed_frames, previous_chunk, should_break, mp3_bytes, full_audio)
+            where should_break indicates if processing should stop
+            mp3_bytes is the MP3 bytes if streaming, None otherwise
+            full_audio is the full audio if this is the last chunk, None otherwise
+        """
+        mp3_bytes = None
+        full_audio = None
+        
+        if processed_frames == 0:
+            if is_last_chunk:
+                output_wave = vc_wave[0].cpu().numpy()
+                generated_wave_chunks.append(output_wave)
+                
+                if stream_output:
+                    output_wave_int16 = (output_wave * 32768.0).astype(np.int16)
+                    mp3_bytes = AudioSegment(
+                        output_wave_int16.tobytes(), frame_rate=sr,
+                        sample_width=output_wave_int16.dtype.itemsize, channels=1
+                    ).export(format="mp3", bitrate=self.bitrate).read()
+                    full_audio = (sr, np.concatenate(generated_wave_chunks))
+                else:
+                    return processed_frames, previous_chunk, True, None, np.concatenate(generated_wave_chunks)
+                
+                return processed_frames, previous_chunk, True, mp3_bytes, full_audio
+            
+            output_wave = vc_wave[0, :-overlap_wave_len].cpu().numpy()
+            generated_wave_chunks.append(output_wave)
+            previous_chunk = vc_wave[0, -overlap_wave_len:]
+            processed_frames += vc_target.size(2) - self.overlap_frame_len
+            
+            if stream_output:
+                output_wave_int16 = (output_wave * 32768.0).astype(np.int16)
+                mp3_bytes = AudioSegment(
+                    output_wave_int16.tobytes(), frame_rate=sr,
+                    sample_width=output_wave_int16.dtype.itemsize, channels=1
+                ).export(format="mp3", bitrate=self.bitrate).read()
+            
+        elif is_last_chunk:
+            output_wave = self.crossfade(previous_chunk.cpu().numpy(), vc_wave[0].cpu().numpy(), overlap_wave_len)
+            generated_wave_chunks.append(output_wave)
+            processed_frames += vc_target.size(2) - self.overlap_frame_len
+            
+            if stream_output:
+                output_wave_int16 = (output_wave * 32768.0).astype(np.int16)
+                mp3_bytes = AudioSegment(
+                    output_wave_int16.tobytes(), frame_rate=sr,
+                    sample_width=output_wave_int16.dtype.itemsize, channels=1
+                ).export(format="mp3", bitrate=self.bitrate).read()
+                full_audio = (sr, np.concatenate(generated_wave_chunks))
+            else:
+                return processed_frames, previous_chunk, True, None, np.concatenate(generated_wave_chunks)
+            
+            return processed_frames, previous_chunk, True, mp3_bytes, full_audio
+            
+        else:
+            output_wave = self.crossfade(previous_chunk.cpu().numpy(), vc_wave[0, :-overlap_wave_len].cpu().numpy(), overlap_wave_len)
+            generated_wave_chunks.append(output_wave)
+            previous_chunk = vc_wave[0, -overlap_wave_len:]
+            processed_frames += vc_target.size(2) - self.overlap_frame_len
+            
+            if stream_output:
+                output_wave_int16 = (output_wave * 32768.0).astype(np.int16)
+                mp3_bytes = AudioSegment(
+                    output_wave_int16.tobytes(), frame_rate=sr,
+                    sample_width=output_wave_int16.dtype.itemsize, channels=1
+                ).export(format="mp3", bitrate=self.bitrate).read()
+                
+        return processed_frames, previous_chunk, False, mp3_bytes, full_audio
+
+    def _process_whisper_features(self, audio_16k, is_source=True):
+        """Process audio through Whisper model to extract features."""
+        if audio_16k.size(-1) <= 16000 * 30:
+            # If audio is short enough, process in one go
+            inputs = self.whisper_feature_extractor(
+                [audio_16k.squeeze(0).cpu().numpy()],
+                return_tensors="pt",
+                return_attention_mask=True,
+                sampling_rate=16000
+            )
+            input_features = self.whisper_model._mask_input_features(
+                inputs.input_features, attention_mask=inputs.attention_mask
+            ).to(self.device)
+            outputs = self.whisper_model.encoder(
+                input_features.to(self.whisper_model.encoder.dtype),
+                head_mask=None,
+                output_attentions=False,
+                output_hidden_states=False,
+                return_dict=True,
+            )
+            features = outputs.last_hidden_state.to(torch.float32)
+            features = features[:, :audio_16k.size(-1) // 320 + 1]
+        else:
+            # Process long audio in chunks
+            overlapping_time = 5  # 5 seconds
+            features_list = []
+            buffer = None
+            traversed_time = 0
+            while traversed_time < audio_16k.size(-1):
+                if buffer is None:  # first chunk
+                    chunk = audio_16k[:, traversed_time:traversed_time + 16000 * 30]
+                else:
+                    chunk = torch.cat([
+                        buffer, 
+                        audio_16k[:, traversed_time:traversed_time + 16000 * (30 - overlapping_time)]
+                    ], dim=-1)
+                inputs = self.whisper_feature_extractor(
+                    [chunk.squeeze(0).cpu().numpy()],
+                    return_tensors="pt",
+                    return_attention_mask=True,
+                    sampling_rate=16000
+                )
+                input_features = self.whisper_model._mask_input_features(
+                    inputs.input_features, attention_mask=inputs.attention_mask
+                ).to(self.device)
+                outputs = self.whisper_model.encoder(
+                    input_features.to(self.whisper_model.encoder.dtype),
+                    head_mask=None,
+                    output_attentions=False,
+                    output_hidden_states=False,
+                    return_dict=True,
+                )
+                chunk_features = outputs.last_hidden_state.to(torch.float32)
+                chunk_features = chunk_features[:, :chunk.size(-1) // 320 + 1]
+                if traversed_time == 0:
+                    features_list.append(chunk_features)
+                else:
+                    features_list.append(chunk_features[:, 50 * overlapping_time:])
+                buffer = chunk[:, -16000 * overlapping_time:]
+                traversed_time += 30 * 16000 if traversed_time == 0 else chunk.size(-1) - 16000 * overlapping_time
+            features = torch.cat(features_list, dim=1)
+        
+        return features
+
+    @spaces.GPU
+    @torch.no_grad()
+    @torch.inference_mode()
+    def convert_voice(self, source, target, diffusion_steps=10, length_adjust=1.0,
+                     inference_cfg_rate=0.7, f0_condition=False, auto_f0_adjust=True, 
+                     pitch_shift=0, stream_output=True):
+        """
+        Convert both timbre and voice from source to target.
+        
+        Args:
+            source: Path to source audio file
+            target: Path to target audio file
+            diffusion_steps: Number of diffusion steps (default: 10)
+            length_adjust: Length adjustment factor (default: 1.0)
+            inference_cfg_rate: Inference CFG rate (default: 0.7)
+            f0_condition: Whether to use F0 conditioning (default: False)
+            auto_f0_adjust: Whether to automatically adjust F0 (default: True)
+            pitch_shift: Pitch shift in semitones (default: 0)
+            stream_output: Whether to stream the output (default: True)
+            
+        Returns:
+            If stream_output is True, yields (mp3_bytes, full_audio) tuples
+            If stream_output is False, returns the full audio as a numpy array
+        """
+        # Select appropriate models based on F0 condition
+        inference_module = self.model if not f0_condition else self.model_f0
+        mel_fn = self.to_mel if not f0_condition else self.to_mel_f0
+        bigvgan_fn = self.bigvgan_model if not f0_condition else self.bigvgan_44k_model
+        sr = 22050 if not f0_condition else 44100
+        hop_length = 256 if not f0_condition else 512
+        max_context_window = sr // hop_length * 30
+        overlap_wave_len = self.overlap_frame_len * hop_length
+        
+        # Load audio
+        source_audio = librosa.load(source, sr=sr)[0]
+        ref_audio = librosa.load(target, sr=sr)[0]
+        
+        # Process audio
+        source_audio = torch.tensor(source_audio).unsqueeze(0).float().to(self.device)
+        ref_audio = torch.tensor(ref_audio[:sr * 25]).unsqueeze(0).float().to(self.device)
+        
+        # Resample to 16kHz for feature extraction
+        ref_waves_16k = torchaudio.functional.resample(ref_audio, sr, 16000)
+        converted_waves_16k = torchaudio.functional.resample(source_audio, sr, 16000)
+        
+        # Extract Whisper features
+        S_alt = self._process_whisper_features(converted_waves_16k, is_source=True)
+        S_ori = self._process_whisper_features(ref_waves_16k, is_source=False)
+        
+        # Compute mel spectrograms
+        mel = mel_fn(source_audio.to(self.device).float())
+        mel2 = mel_fn(ref_audio.to(self.device).float())
+        
+        # Set target lengths
+        target_lengths = torch.LongTensor([int(mel.size(2) * length_adjust)]).to(mel.device)
+        target2_lengths = torch.LongTensor([mel2.size(2)]).to(mel2.device)
+        
+        # Compute style features
+        feat2 = torchaudio.compliance.kaldi.fbank(
+            ref_waves_16k,
+            num_mel_bins=80,
+            dither=0,
+            sample_frequency=16000
+        )
+        feat2 = feat2 - feat2.mean(dim=0, keepdim=True)
+        style2 = self.campplus_model(feat2.unsqueeze(0))
+        
+        # Process F0 if needed
+        if f0_condition:
+            F0_ori = self.rmvpe.infer_from_audio(ref_waves_16k[0], thred=0.03)
+            F0_alt = self.rmvpe.infer_from_audio(converted_waves_16k[0], thred=0.03)
+            
+            if self.device == "mps":
+                F0_ori = torch.from_numpy(F0_ori).float().to(self.device)[None]
+                F0_alt = torch.from_numpy(F0_alt).float().to(self.device)[None]
+            else:
+                F0_ori = torch.from_numpy(F0_ori).to(self.device)[None]
+                F0_alt = torch.from_numpy(F0_alt).to(self.device)[None]
+            
+            voiced_F0_ori = F0_ori[F0_ori > 1]
+            voiced_F0_alt = F0_alt[F0_alt > 1]
+            
+            log_f0_alt = torch.log(F0_alt + 1e-5)
+            voiced_log_f0_ori = torch.log(voiced_F0_ori + 1e-5)
+            voiced_log_f0_alt = torch.log(voiced_F0_alt + 1e-5)
+            median_log_f0_ori = torch.median(voiced_log_f0_ori)
+            median_log_f0_alt = torch.median(voiced_log_f0_alt)
+            
+            # Shift alt log f0 level to ori log f0 level
+            shifted_log_f0_alt = log_f0_alt.clone()
+            if auto_f0_adjust:
+                shifted_log_f0_alt[F0_alt > 1] = log_f0_alt[F0_alt > 1] - median_log_f0_alt + median_log_f0_ori
+            shifted_f0_alt = torch.exp(shifted_log_f0_alt)
+            if pitch_shift != 0:
+                shifted_f0_alt[F0_alt > 1] = self.adjust_f0_semitones(shifted_f0_alt[F0_alt > 1], pitch_shift)
+        else:
+            F0_ori = None
+            F0_alt = None
+            shifted_f0_alt = None
+        
+        # Length regulation
+        cond, _, codes, commitment_loss, codebook_loss = inference_module.length_regulator(
+            S_alt, ylens=target_lengths, n_quantizers=3, f0=shifted_f0_alt
+        )
+        prompt_condition, _, codes, commitment_loss, codebook_loss = inference_module.length_regulator(
+            S_ori, ylens=target2_lengths, n_quantizers=3, f0=F0_ori
+        )
+        
+        # Process in chunks for streaming
+        max_source_window = max_context_window - mel2.size(2)
+        processed_frames = 0
+        generated_wave_chunks = []
+        previous_chunk = None
+        
+        # Generate chunk by chunk and stream the output
+        while processed_frames < cond.size(1):
+            chunk_cond = cond[:, processed_frames:processed_frames + max_source_window]
+            is_last_chunk = processed_frames + max_source_window >= cond.size(1)
+            cat_condition = torch.cat([prompt_condition, chunk_cond], dim=1)
+            
+            with torch.autocast(device_type=self.device.type, dtype=torch.float16):
+                # Voice Conversion
+                vc_target = inference_module.cfm.inference(
+                    cat_condition,
+                    torch.LongTensor([cat_condition.size(1)]).to(mel2.device),
+                    mel2, style2, None, diffusion_steps,
+                    inference_cfg_rate=inference_cfg_rate
+                )
+                vc_target = vc_target[:, :, mel2.size(-1):]
+            
+            vc_wave = bigvgan_fn(vc_target.float())[0]
+            
+            processed_frames, previous_chunk, should_break, mp3_bytes, full_audio = self._stream_wave_chunks(
+                vc_wave, processed_frames, vc_target, overlap_wave_len, 
+                generated_wave_chunks, previous_chunk, is_last_chunk, stream_output, sr
+            )
+            
+            if stream_output and mp3_bytes is not None:
+                yield mp3_bytes, full_audio
+                
+            if should_break:
+                if not stream_output:
+                    return full_audio
+                break
+        
+        if not stream_output:
+            return np.concatenate(generated_wave_chunks)
+        
+        return None, None 
\ No newline at end of file