diff --git a/audiotools/__init__.py b/audiotools/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..573ffd06100ad72614df9363b12cda6672f1b70e
--- /dev/null
+++ b/audiotools/__init__.py
@@ -0,0 +1,10 @@
+__version__ = "0.7.3"
+from .core import AudioSignal
+from .core import STFTParams
+from .core import Meter
+from .core import util
+from . import metrics
+from . import data
+from . import ml
+from .data import datasets
+from .data import transforms
diff --git a/audiotools/core/__init__.py b/audiotools/core/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8660c4e67f43d0ded584a38939425e2c28d95cd3
--- /dev/null
+++ b/audiotools/core/__init__.py
@@ -0,0 +1,4 @@
+from . import util
+from .audio_signal import AudioSignal
+from .audio_signal import STFTParams
+from .loudness import Meter
diff --git a/audiotools/core/audio_signal.py b/audiotools/core/audio_signal.py
new file mode 100644
index 0000000000000000000000000000000000000000..fb6d751cb968a003656e3e7874c487b83d94c82e
--- /dev/null
+++ b/audiotools/core/audio_signal.py
@@ -0,0 +1,1682 @@
+import copy
+import functools
+import hashlib
+import math
+import pathlib
+import tempfile
+import typing
+import warnings
+from collections import namedtuple
+from pathlib import Path
+
+import julius
+import numpy as np
+import soundfile
+import torch
+
+from . import util
+from .display import DisplayMixin
+from .dsp import DSPMixin
+from .effects import EffectMixin
+from .effects import ImpulseResponseMixin
+from .ffmpeg import FFMPEGMixin
+from .loudness import LoudnessMixin
+from .playback import PlayMixin
+from .whisper import WhisperMixin
+
+
+STFTParams = namedtuple(
+    "STFTParams",
+    ["window_length", "hop_length", "window_type", "match_stride", "padding_type"],
+)
+"""
+STFTParams is a container that holds STFT parameters: window_length,
+hop_length, window_type, match_stride, and padding_type. Not all parameters
+need to be specified. Ones that are not specified will be inferred from the
+AudioSignal parameters.
+
+Parameters
+----------
+window_length : int, optional
+    Window length of STFT, by default ``0.032 * self.sample_rate``.
+hop_length : int, optional
+    Hop length of STFT, by default ``window_length // 4``.
+window_type : str, optional
+    Type of window to use, by default ``hann``.
+match_stride : bool, optional
+    Whether to match the stride of convolutional layers, by default False
+padding_type : str, optional
+    Type of padding to use, by default 'reflect'
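+
+Examples
+--------
+A minimal sketch, overriding only the window length and hop length; the
+remaining fields are left as None and inferred from the AudioSignal:
+
+>>> params = STFTParams(window_length=512, hop_length=128)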
+"""
+STFTParams.__new__.__defaults__ = (None, None, None, None, None)
+
+
+class AudioSignal(
+    EffectMixin,
+    LoudnessMixin,
+    PlayMixin,
+    ImpulseResponseMixin,
+    DSPMixin,
+    DisplayMixin,
+    FFMPEGMixin,
+    WhisperMixin,
+):
+    """This is the core object of this library. Audio is always
+    loaded into an AudioSignal, which then enables all the features
+    of this library, including audio augmentations, I/O, playback,
+    and more.
+
+    The structure of this object is that the base functionality
+    is defined in ``core/audio_signal.py``, while extensions to
+    that functionality are defined in the other ``core/*.py``
+    files. For example, all the display-based functionality
+    (e.g. plotting spectrograms and waveforms, writing to tensorboard)
+    is in ``core/display.py``.
+
+    Parameters
+    ----------
+    audio_path_or_array : typing.Union[torch.Tensor, str, Path, np.ndarray]
+        Object to create AudioSignal from. Can be a tensor, numpy array,
+        or a path to a file. The audio is always reshaped to
+        (batch_size, num_channels, num_samples).
+    sample_rate : int, optional
+        Sample rate of the audio. If different from underlying file, resampling is
+        performed. If passing in an array or tensor, this must be defined,
+        by default None
+    stft_params : STFTParams, optional
+        Parameters of STFT to use, by default None
+    offset : float, optional
+        Offset in seconds to read from file, by default 0
+    duration : float, optional
+        Duration in seconds to read from file, by default None
+    device : str, optional
+        Device to load audio onto, by default None
+
+    Examples
+    --------
+    Loading an AudioSignal from an array, at a sample rate of
+    44100.
+
+    >>> signal = AudioSignal(torch.randn(5*44100), 44100)
+
+    Note, the signal is reshaped to have a batch size, and one
+    audio channel:
+
+    >>> print(signal.shape)
+    (1, 1, 220500)
+
+    You can treat AudioSignals like tensors, and many of the same
+    functions you might use on tensors are defined for AudioSignals
+    as well:
+
+    >>> signal.to("cuda")
+    >>> signal.cuda()
+    >>> signal.clone()
+    >>> signal.detach()
+
+    Indexing AudioSignals returns an AudioSignal:
+
+    >>> signal[..., 3*44100:4*44100]
+
+    The above signal is 1 second long, and is also an AudioSignal.
+    """
+
+    def __init__(
+        self,
+        audio_path_or_array: typing.Union[torch.Tensor, str, Path, np.ndarray],
+        sample_rate: int = None,
+        stft_params: STFTParams = None,
+        offset: float = 0,
+        duration: float = None,
+        device: str = None,
+    ):
+        audio_path = None
+        audio_array = None
+
+        if isinstance(audio_path_or_array, str):
+            audio_path = audio_path_or_array
+        elif isinstance(audio_path_or_array, pathlib.Path):
+            audio_path = audio_path_or_array
+        elif isinstance(audio_path_or_array, np.ndarray):
+            audio_array = audio_path_or_array
+        elif torch.is_tensor(audio_path_or_array):
+            audio_array = audio_path_or_array
+        else:
+            raise ValueError(
+                "audio_path_or_array must be either a Path, "
+                "string, numpy array, or torch Tensor!"
+            )
+
+        self.path_to_file = None
+
+        self.audio_data = None
+        self.sources = None  # List of AudioSignal objects.
+        self.stft_data = None
+        if audio_path is not None:
+            self.load_from_file(
+                audio_path, offset=offset, duration=duration, device=device
+            )
+        elif audio_array is not None:
+            assert sample_rate is not None, "Must set sample rate!"
+            self.load_from_array(audio_array, sample_rate, device=device)
+
+        self.window = None
+        self.stft_params = stft_params
+
+        self.metadata = {
+            "offset": offset,
+            "duration": duration,
+        }
+
+    @property
+    def path_to_input_file(
+        self,
+    ):
+        """
+        Path to input file, if it exists.
+        Alias to ``path_to_file`` for backwards compatibility
+        """
+        return self.path_to_file
+
+    @classmethod
+    def excerpt(
+        cls,
+        audio_path: typing.Union[str, Path],
+        offset: float = None,
+        duration: float = None,
+        state: typing.Union[np.random.RandomState, int] = None,
+        **kwargs,
+    ):
+        """Randomly draw an excerpt of ``duration`` seconds from an
+        audio file specified at ``audio_path``, between ``offset`` seconds
+        and end of file. ``state`` can be used to seed the random draw.
+
+        Parameters
+        ----------
+        audio_path : typing.Union[str, Path]
+            Path to audio file to grab excerpt from.
+        offset : float, optional
+            Lower bound for the excerpt start time, in seconds from the
+            beginning of the file, by default None (treated as 0).
+        duration : float, optional
+            Duration of excerpt, in seconds, by default None
+        state : typing.Union[np.random.RandomState, int], optional
+            RandomState or seed of random state, by default None
+
+        Returns
+        -------
+        AudioSignal
+            AudioSignal containing excerpt.
+
+        Examples
+        --------
+        >>> signal = AudioSignal.excerpt("path/to/audio", duration=5)
+        """
+        info = util.info(audio_path)
+        total_duration = info.duration
+
+        state = util.random_state(state)
+        lower_bound = 0 if offset is None else offset
+        upper_bound = max(total_duration - duration, 0)
+        offset = state.uniform(lower_bound, upper_bound)
+
+        signal = cls(audio_path, offset=offset, duration=duration, **kwargs)
+        signal.metadata["offset"] = offset
+        signal.metadata["duration"] = duration
+
+        return signal
+
+    @classmethod
+    def salient_excerpt(
+        cls,
+        audio_path: typing.Union[str, Path],
+        loudness_cutoff: float = None,
+        num_tries: int = 8,
+        state: typing.Union[np.random.RandomState, int] = None,
+        **kwargs,
+    ):
+        """Similar to AudioSignal.excerpt, except it extracts excerpts only
+        if they are above a specified loudness threshold, which is computed via
+        a fast LUFS routine.
+
+        Parameters
+        ----------
+        audio_path : typing.Union[str, Path]
+            Path to audio file to grab excerpt from.
+        loudness_cutoff : float, optional
+            Loudness threshold in dB. Typical values are ``-40``, ``-60``,
+            etc., by default None
+        num_tries : int, optional
+            Number of tries to grab an excerpt above the threshold
+            before giving up, by default 8.
+        state : typing.Union[np.random.RandomState, int], optional
+            RandomState or seed of random state, by default None
+        kwargs : dict
+            Keyword arguments to AudioSignal.excerpt
+
+        Returns
+        -------
+        AudioSignal
+            AudioSignal containing excerpt.
+
+
+        .. warning::
+            If ``num_tries`` is set to None, ``salient_excerpt`` may loop
+            forever if ``audio_path`` does not contain any excerpts above
+            the loudness threshold.
+
+        Examples
+        --------
+        >>> signal = AudioSignal.salient_excerpt(
+                "path/to/audio",
+                loudness_cutoff=-40,
+                duration=5
+            )
+        """
+        state = util.random_state(state)
+        if loudness_cutoff is None:
+            excerpt = cls.excerpt(audio_path, state=state, **kwargs)
+        else:
+            loudness = -np.inf
+            num_try = 0
+            while loudness <= loudness_cutoff:
+                excerpt = cls.excerpt(audio_path, state=state, **kwargs)
+                loudness = excerpt.loudness()
+                num_try += 1
+                if num_tries is not None and num_try >= num_tries:
+                    break
+        return excerpt
+
+    @classmethod
+    def zeros(
+        cls,
+        duration: float,
+        sample_rate: int,
+        num_channels: int = 1,
+        batch_size: int = 1,
+        **kwargs,
+    ):
+        """Helper function create an AudioSignal of all zeros.
+
+        Parameters
+        ----------
+        duration : float
+            Duration of AudioSignal
+        sample_rate : int
+            Sample rate of AudioSignal
+        num_channels : int, optional
+            Number of channels, by default 1
+        batch_size : int, optional
+            Batch size, by default 1
+
+        Returns
+        -------
+        AudioSignal
+            AudioSignal containing all zeros.
+
+        Examples
+        --------
+        Generate 5 seconds of all zeros at a sample rate of 44100.
+
+        >>> signal = AudioSignal.zeros(5.0, 44100)
+        """
+        n_samples = int(duration * sample_rate)
+        return cls(
+            torch.zeros(batch_size, num_channels, n_samples), sample_rate, **kwargs
+        )
+
+    @classmethod
+    def wave(
+        cls,
+        frequency: float,
+        duration: float,
+        sample_rate: int,
+        num_channels: int = 1,
+        shape: str = "sine",
+        **kwargs,
+    ):
+        """
+        Generate a waveform of a given frequency and shape.
+
+        Parameters
+        ----------
+        frequency : float
+            Frequency of the waveform
+        duration : float
+            Duration of the waveform
+        sample_rate : int
+            Sample rate of the waveform
+        num_channels : int, optional
+            Number of channels, by default 1
+        shape : str, optional
+            Shape of the waveform, by default "sine".
+            One of "sawtooth", "square", "sine", "triangle"
+        kwargs : dict
+            Keyword arguments to AudioSignal
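+
+        Examples
+        --------
+        A minimal sketch, generating one second of a 440 Hz sine tone
+        (values here are illustrative):
+
+        >>> signal = AudioSignal.wave(440, 1.0, 44100, shape="sine")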
+        """
+        n_samples = int(duration * sample_rate)
+        t = torch.linspace(0, duration, n_samples)
+        if shape == "sawtooth":
+            from scipy.signal import sawtooth
+
+            wave_data = sawtooth(2 * np.pi * frequency * t, 0.5)
+        elif shape == "square":
+            from scipy.signal import square
+
+            wave_data = square(2 * np.pi * frequency * t)
+        elif shape == "sine":
+            wave_data = np.sin(2 * np.pi * frequency * t)
+        elif shape == "triangle":
+            from scipy.signal import sawtooth
+
+            # frequency is doubled by the abs call, so omit the 2 in 2pi
+            wave_data = sawtooth(np.pi * frequency * t, 0.5)
+            wave_data = -np.abs(wave_data) * 2 + 1
+        else:
+            raise ValueError(f"Invalid shape {shape}")
+
+        wave_data = torch.tensor(wave_data, dtype=torch.float32)
+        wave_data = wave_data.unsqueeze(0).unsqueeze(0).repeat(1, num_channels, 1)
+        return cls(wave_data, sample_rate, **kwargs)
+
+    @classmethod
+    def batch(
+        cls,
+        audio_signals: list,
+        pad_signals: bool = False,
+        truncate_signals: bool = False,
+        resample: bool = False,
+        dim: int = 0,
+    ):
+        """Creates a batched AudioSignal from a list of AudioSignals.
+
+        Parameters
+        ----------
+        audio_signals : list[AudioSignal]
+            List of AudioSignal objects
+        pad_signals : bool, optional
+            Whether to pad signals to length of the maximum length
+            AudioSignal in the list, by default False
+        truncate_signals : bool, optional
+            Whether to truncate signals to length of shortest length
+            AudioSignal in the list, by default False
+        resample : bool, optional
+            Whether to resample AudioSignal to the sample rate of
+            the first AudioSignal in the list, by default False
+        dim : int, optional
+            Dimension along which to batch the signals, by default 0.
+
+        Returns
+        -------
+        AudioSignal
+            Batched AudioSignal.
+
+        Raises
+        ------
+        RuntimeError
+            If not all AudioSignals are the same sample rate, and
+            ``resample=False``, an error is raised.
+        RuntimeError
+            If not all AudioSignals are the same length, and
+            both ``pad_signals=False`` and ``truncate_signals=False``,
+            an error is raised.
+
+        Examples
+        --------
+        Batching a bunch of random signals:
+
+        >>> signal_list = [AudioSignal(torch.randn(44100), 44100) for _ in range(10)]
+        >>> signal = AudioSignal.batch(signal_list)
+        >>> print(signal.shape)
+        (10, 1, 44100)
+
+        """
+        signal_lengths = [x.signal_length for x in audio_signals]
+        sample_rates = [x.sample_rate for x in audio_signals]
+
+        if len(set(sample_rates)) != 1:
+            if resample:
+                for x in audio_signals:
+                    x.resample(sample_rates[0])
+            else:
+                raise RuntimeError(
+                    f"Not all signals had the same sample rate! Got {sample_rates}. "
+                    f"All signals must have the same sample rate, or resample must be True. "
+                )
+
+        if len(set(signal_lengths)) != 1:
+            if pad_signals:
+                max_length = max(signal_lengths)
+                for x in audio_signals:
+                    pad_len = max_length - x.signal_length
+                    x.zero_pad(0, pad_len)
+            elif truncate_signals:
+                min_length = min(signal_lengths)
+                for x in audio_signals:
+                    x.truncate_samples(min_length)
+            else:
+                raise RuntimeError(
+                    f"Not all signals had the same length! Got {signal_lengths}. "
+                    f"All signals must be the same length, or pad_signals/truncate_signals "
+                    f"must be True. "
+                )
+        # Concatenate along the specified dimension (default 0)
+        audio_data = torch.cat([x.audio_data for x in audio_signals], dim=dim)
+        audio_paths = [x.path_to_file for x in audio_signals]
+
+        batched_signal = cls(
+            audio_data,
+            sample_rate=audio_signals[0].sample_rate,
+        )
+        batched_signal.path_to_file = audio_paths
+        return batched_signal
+
+    # I/O
+    def load_from_file(
+        self,
+        audio_path: typing.Union[str, Path],
+        offset: float,
+        duration: float,
+        device: str = "cpu",
+    ):
+        """Loads data from file. Used internally when AudioSignal
+        is instantiated with a path to a file.
+
+        Parameters
+        ----------
+        audio_path : typing.Union[str, Path]
+            Path to file
+        offset : float
+            Offset in seconds
+        duration : float
+            Duration in seconds
+        device : str, optional
+            Device to put AudioSignal on, by default "cpu"
+
+        Returns
+        -------
+        AudioSignal
+            AudioSignal loaded from file
+        """
+        import librosa
+
+        data, sample_rate = librosa.load(
+            audio_path,
+            offset=offset,
+            duration=duration,
+            sr=None,
+            mono=False,
+        )
+        data = util.ensure_tensor(data)
+        if data.shape[-1] == 0:
+            raise RuntimeError(
+                f"Audio file {audio_path} with offset {offset} and duration {duration} is empty!"
+            )
+
+        if data.ndim < 2:
+            data = data.unsqueeze(0)
+        if data.ndim < 3:
+            data = data.unsqueeze(0)
+        self.audio_data = data
+
+        self.original_signal_length = self.signal_length
+
+        self.sample_rate = sample_rate
+        self.path_to_file = audio_path
+        return self.to(device)
+
+    def load_from_array(
+        self,
+        audio_array: typing.Union[torch.Tensor, np.ndarray],
+        sample_rate: int,
+        device: str = "cpu",
+    ):
+        """Loads data from array, reshaping it to be exactly 3
+        dimensions. Used internally when AudioSignal is called
+        with a tensor or an array.
+
+        Parameters
+        ----------
+        audio_array : typing.Union[torch.Tensor, np.ndarray]
+            Array/tensor of audio of samples.
+        sample_rate : int
+            Sample rate of audio
+        device : str, optional
+            Device to move audio onto, by default "cpu"
+
+        Returns
+        -------
+        AudioSignal
+            AudioSignal loaded from array
+        """
+        audio_data = util.ensure_tensor(audio_array)
+
+        if audio_data.dtype == torch.double:
+            audio_data = audio_data.float()
+
+        if audio_data.ndim < 2:
+            audio_data = audio_data.unsqueeze(0)
+        if audio_data.ndim < 3:
+            audio_data = audio_data.unsqueeze(0)
+        self.audio_data = audio_data
+
+        self.original_signal_length = self.signal_length
+
+        self.sample_rate = sample_rate
+        return self.to(device)
+
+    def write(self, audio_path: typing.Union[str, Path]):
+        """Writes audio to a file. Only writes the audio
+        that is in the very first item of the batch. To write other items
+        in the batch, index the signal along the batch dimension
+        before writing. After writing, the signal's ``path_to_file``
+        attribute is updated to the new path.
+
+        Parameters
+        ----------
+        audio_path : typing.Union[str, Path]
+            Path to write audio to.
+
+        Returns
+        -------
+        AudioSignal
+            Returns original AudioSignal, so you can use this in a fluent
+            interface.
+
+        Examples
+        --------
+        Creating and writing a signal to disk:
+
+        >>> signal = AudioSignal(torch.randn(10, 1, 44100), 44100)
+        >>> signal.write("/tmp/out.wav")
+
+        Writing a different element of the batch:
+
+        >>> signal[5].write("/tmp/out.wav")
+
+        Using this in a fluent interface:
+
+        >>> signal.write("/tmp/original.wav").low_pass(4000).write("/tmp/lowpass.wav")
+
+        """
+        if self.audio_data[0].abs().max() > 1:
+            warnings.warn("Audio amplitude > 1 clipped when saving")
+        soundfile.write(str(audio_path), self.audio_data[0].numpy().T, self.sample_rate)
+
+        self.path_to_file = audio_path
+        return self
+
+    def deepcopy(self):
+        """Copies the signal and all of its attributes.
+
+        Returns
+        -------
+        AudioSignal
+            Deep copy of the audio signal.
+        """
+        return copy.deepcopy(self)
+
+    def copy(self):
+        """Shallow copy of signal.
+
+        Returns
+        -------
+        AudioSignal
+            Shallow copy of the audio signal.
+        """
+        return copy.copy(self)
+
+    def clone(self):
+        """Clones all tensors contained in the AudioSignal,
+        and returns a copy of the signal with everything
+        cloned. Useful when using AudioSignal within autograd
+        computation graphs.
+
+        Relevant attributes are the stft data, the audio data,
+        and the loudness of the file.
+
+        Returns
+        -------
+        AudioSignal
+            Clone of AudioSignal.
+        """
+        clone = type(self)(
+            self.audio_data.clone(),
+            self.sample_rate,
+            stft_params=self.stft_params,
+        )
+        if self.stft_data is not None:
+            clone.stft_data = self.stft_data.clone()
+        if self._loudness is not None:
+            clone._loudness = self._loudness.clone()
+        clone.path_to_file = copy.deepcopy(self.path_to_file)
+        clone.metadata = copy.deepcopy(self.metadata)
+        return clone
+
+    def detach(self):
+        """Detaches tensors contained in AudioSignal.
+
+        Relevant attributes are the stft data, the audio data,
+        and the loudness of the file.
+
+        Returns
+        -------
+        AudioSignal
+            Same signal, but with all tensors detached.
+        """
+        if self._loudness is not None:
+            self._loudness = self._loudness.detach()
+        if self.stft_data is not None:
+            self.stft_data = self.stft_data.detach()
+
+        self.audio_data = self.audio_data.detach()
+        return self
+
+    def hash(self):
+        """Writes the audio data to a temporary file, and then
+        hashes it using hashlib. Useful for creating a file
+        name based on the audio content.
+
+        Returns
+        -------
+        str
+            Hash of audio data.
+
+        Examples
+        --------
+        Creating a signal, and writing it to a unique file name:
+
+        >>> signal = AudioSignal(torch.randn(44100), 44100)
+        >>> hash = signal.hash()
+        >>> signal.write(f"{hash}.wav")
+
+        """
+        with tempfile.NamedTemporaryFile(suffix=".wav") as f:
+            self.write(f.name)
+            h = hashlib.sha256()
+            b = bytearray(128 * 1024)
+            mv = memoryview(b)
+            with open(f.name, "rb", buffering=0) as f:
+                for n in iter(lambda: f.readinto(mv), 0):
+                    h.update(mv[:n])
+            file_hash = h.hexdigest()
+        return file_hash
+
+    # Signal operations
+    def to_mono(self):
+        """Converts audio data to mono audio, by taking the mean
+        along the channels dimension.
+
+        Returns
+        -------
+        AudioSignal
+            AudioSignal with mean of channels.
+        """
+        self.audio_data = self.audio_data.mean(1, keepdim=True)
+        return self
+
+    def resample(self, sample_rate: int):
+        """Resamples the audio, using sinc interpolation. This works on both
+        cpu and gpu, and is much faster on gpu.
+
+        Parameters
+        ----------
+        sample_rate : int
+            Sample rate to resample to.
+
+        Returns
+        -------
+        AudioSignal
+            Resampled AudioSignal
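+
+        Examples
+        --------
+        A minimal sketch, resampling a signal down to 16 kHz:
+
+        >>> signal = AudioSignal(torch.randn(44100), 44100)
+        >>> signal.resample(16000)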
+        """
+        if sample_rate == self.sample_rate:
+            return self
+        self.audio_data = julius.resample_frac(
+            self.audio_data, self.sample_rate, sample_rate
+        )
+        self.sample_rate = sample_rate
+        return self
+
+    # Tensor operations
+    def to(self, device: str):
+        """Moves all tensors contained in signal to the specified device.
+
+        Parameters
+        ----------
+        device : str
+            Device to move AudioSignal onto. Typical values are
+            "cuda", "cpu", or "cuda:n" to specify the nth gpu.
+
+        Returns
+        -------
+        AudioSignal
+            AudioSignal with all tensors moved to specified device.
+        """
+        if self._loudness is not None:
+            self._loudness = self._loudness.to(device)
+        if self.stft_data is not None:
+            self.stft_data = self.stft_data.to(device)
+        if self.audio_data is not None:
+            self.audio_data = self.audio_data.to(device)
+        return self
+
+    def float(self):
+        """Calls ``.float()`` on ``self.audio_data``.
+
+        Returns
+        -------
+        AudioSignal
+        """
+        self.audio_data = self.audio_data.float()
+        return self
+
+    def cpu(self):
+        """Moves AudioSignal to cpu.
+
+        Returns
+        -------
+        AudioSignal
+        """
+        return self.to("cpu")
+
+    def cuda(self):  # pragma: no cover
+        """Moves AudioSignal to cuda.
+
+        Returns
+        -------
+        AudioSignal
+        """
+        return self.to("cuda")
+
+    def numpy(self):
+        """Detaches ``self.audio_data``, moves to cpu, and converts to numpy.
+
+        Returns
+        -------
+        np.ndarray
+            Audio data as a numpy array.
+        """
+        return self.audio_data.detach().cpu().numpy()
+
+    def zero_pad(self, before: int, after: int):
+        """Zero pads the audio_data tensor before and after.
+
+        Parameters
+        ----------
+        before : int
+            How many zeros to prepend to audio.
+        after : int
+            How many zeros to append to audio.
+
+        Returns
+        -------
+        AudioSignal
+            AudioSignal with padding applied.
+        """
+        self.audio_data = torch.nn.functional.pad(self.audio_data, (before, after))
+        return self
+
+    def zero_pad_to(self, length: int, mode: str = "after"):
+        """Pad with zeros to a specified length, either before or after
+        the audio data.
+
+        Parameters
+        ----------
+        length : int
+            Length to pad to
+        mode : str, optional
+            Whether to prepend or append zeros to signal, by default "after"
+
+        Returns
+        -------
+        AudioSignal
+            AudioSignal with padding applied.
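+
+        Examples
+        --------
+        A minimal sketch, padding a half-second signal out to one second:
+
+        >>> signal = AudioSignal(torch.randn(22050), 44100)
+        >>> signal.zero_pad_to(44100)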
+        """
+        if mode == "before":
+            self.zero_pad(max(length - self.signal_length, 0), 0)
+        elif mode == "after":
+            self.zero_pad(0, max(length - self.signal_length, 0))
+        return self
+
+    def trim(self, before: int, after: int):
+        """Trims the audio_data tensor before and after.
+
+        Parameters
+        ----------
+        before : int
+            How many samples to trim from beginning.
+        after : int
+            How many samples to trim from end.
+
+        Returns
+        -------
+        AudioSignal
+            AudioSignal with trimming applied.
+        """
+        if after == 0:
+            self.audio_data = self.audio_data[..., before:]
+        else:
+            self.audio_data = self.audio_data[..., before:-after]
+        return self
+
+    def truncate_samples(self, length_in_samples: int):
+        """Truncate signal to specified length.
+
+        Parameters
+        ----------
+        length_in_samples : int
+            Truncate to this many samples.
+
+        Returns
+        -------
+        AudioSignal
+            AudioSignal with truncation applied.
+        """
+        self.audio_data = self.audio_data[..., :length_in_samples]
+        return self
+
+    @property
+    def device(self):
+        """Get device that AudioSignal is on.
+
+        Returns
+        -------
+        torch.device
+            Device that AudioSignal is on.
+        """
+        if self.audio_data is not None:
+            device = self.audio_data.device
+        elif self.stft_data is not None:
+            device = self.stft_data.device
+        return device
+
+    # Properties
+    @property
+    def audio_data(self):
+        """Returns the audio data tensor in the object.
+
+        Audio data is always of the shape
+        (batch_size, num_channels, num_samples). If value has less
+        than 3 dims (e.g. is (num_channels, num_samples)), then it will
+        be reshaped to (1, num_channels, num_samples) - a batch size of 1.
+
+        Parameters
+        ----------
+        data : typing.Union[torch.Tensor, np.ndarray]
+            Audio data to set.
+
+        Returns
+        -------
+        torch.Tensor
+            Audio samples.
+        """
+        return self._audio_data
+
+    @audio_data.setter
+    def audio_data(self, data: typing.Union[torch.Tensor, np.ndarray]):
+        if data is not None:
+            assert torch.is_tensor(data), "audio_data should be torch.Tensor"
+            assert data.ndim == 3, "audio_data should be 3-dim (B, C, T)"
+        self._audio_data = data
+        # Old loudness value not guaranteed to be right, reset it.
+        self._loudness = None
+        return
+
+    # alias for audio_data
+    samples = audio_data
+
+    @property
+    def stft_data(self):
+        """Returns the STFT data inside the signal. Shape is
+        (batch, channels, frequencies, time).
+
+        Returns
+        -------
+        torch.Tensor
+            Complex spectrogram data.
+        """
+        return self._stft_data
+
+    @stft_data.setter
+    def stft_data(self, data: typing.Union[torch.Tensor, np.ndarray]):
+        if data is not None:
+            assert torch.is_tensor(data) and torch.is_complex(data)
+            if self.stft_data is not None and self.stft_data.shape != data.shape:
+                warnings.warn("stft_data changed shape")
+        self._stft_data = data
+        return
+
+    @property
+    def batch_size(self):
+        """Batch size of audio signal.
+
+        Returns
+        -------
+        int
+            Batch size of signal.
+        """
+        return self.audio_data.shape[0]
+
+    @property
+    def signal_length(self):
+        """Length of audio signal.
+
+        Returns
+        -------
+        int
+            Length of signal in samples.
+        """
+        return self.audio_data.shape[-1]
+
+    # alias for signal_length
+    length = signal_length
+
+    @property
+    def shape(self):
+        """Shape of audio data.
+
+        Returns
+        -------
+        tuple
+            Shape of audio data.
+        """
+        return self.audio_data.shape
+
+    @property
+    def signal_duration(self):
+        """Length of audio signal in seconds.
+
+        Returns
+        -------
+        float
+            Length of signal in seconds.
+        """
+        return self.signal_length / self.sample_rate
+
+    # alias for signal_duration
+    duration = signal_duration
+
+    @property
+    def num_channels(self):
+        """Number of audio channels.
+
+        Returns
+        -------
+        int
+            Number of audio channels.
+        """
+        return self.audio_data.shape[1]
+
+    # STFT
+    @staticmethod
+    @functools.lru_cache(None)
+    def get_window(window_type: str, window_length: int, device: str):
+        """Wrapper around scipy.signal.get_window so one can also get the
+        popular sqrt-hann window. This function caches for efficiency
+        using functools.lru\_cache.
+
+        Parameters
+        ----------
+        window_type : str
+            Type of window to get
+        window_length : int
+            Length of the window
+        device : str
+            Device to put window onto.
+
+        Returns
+        -------
+        torch.Tensor
+            Window returned by scipy.signal.get_window, as a tensor.
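+
+        Examples
+        --------
+        A minimal sketch, fetching a 2048-sample sqrt-hann window on cpu:
+
+        >>> window = AudioSignal.get_window("sqrt_hann", 2048, "cpu")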
+        """
+        from scipy import signal
+
+        if window_type == "average":
+            window = np.ones(window_length) / window_length
+        elif window_type == "sqrt_hann":
+            window = np.sqrt(signal.get_window("hann", window_length))
+        else:
+            window = signal.get_window(window_type, window_length)
+        window = torch.from_numpy(window).to(device).float()
+        return window
+
+    @property
+    def stft_params(self):
+        """Returns STFTParams object, which can be re-used to other
+        AudioSignals.
+
+        This property can be set as well. If values are not defined in STFTParams,
+        they are inferred automatically from the signal properties. The default is to use
+        32ms windows, with 8ms hop length, and the square root of the hann window.
+
+        Returns
+        -------
+        STFTParams
+            STFT parameters for the AudioSignal.
+
+        Examples
+        --------
+        >>> stft_params = STFTParams(128, 32)
+        >>> signal1 = AudioSignal(torch.randn(44100), 44100, stft_params=stft_params)
+        >>> signal2 = AudioSignal(torch.randn(44100), 44100, stft_params=signal1.stft_params)
+        >>> signal1.stft_params = STFTParams() # Defaults
+        """
+        return self._stft_params
+
+    @stft_params.setter
+    def stft_params(self, value: STFTParams):
+        default_win_len = int(2 ** (np.ceil(np.log2(0.032 * self.sample_rate))))
+        default_hop_len = default_win_len // 4
+        default_win_type = "hann"
+        default_match_stride = False
+        default_padding_type = "reflect"
+
+        default_stft_params = STFTParams(
+            window_length=default_win_len,
+            hop_length=default_hop_len,
+            window_type=default_win_type,
+            match_stride=default_match_stride,
+            padding_type=default_padding_type,
+        )._asdict()
+
+        value = value._asdict() if value else default_stft_params
+
+        for key in default_stft_params:
+            if value[key] is None:
+                value[key] = default_stft_params[key]
+
+        self._stft_params = STFTParams(**value)
+        self.stft_data = None
+
+    def compute_stft_padding(
+        self, window_length: int, hop_length: int, match_stride: bool
+    ):
+        """Compute how the STFT should be padded, based on match\_stride.
+
+        Parameters
+        ----------
+        window_length : int
+            Window length of STFT.
+        hop_length : int
+            Hop length of STFT.
+        match_stride : bool
+            Whether or not to match stride, making the STFT have the same alignment as
+            convolutional layers.
+
+        Returns
+        -------
+        tuple
+            Amount to pad on either side of audio.
+        """
+        length = self.signal_length
+
+        if match_stride:
+            assert (
+                hop_length == window_length // 4
+            ), "For match_stride, hop must equal n_fft // 4"
+            right_pad = math.ceil(length / hop_length) * hop_length - length
+            pad = (window_length - hop_length) // 2
+        else:
+            right_pad = 0
+            pad = 0
+
+        return right_pad, pad
+
+    def stft(
+        self,
+        window_length: int = None,
+        hop_length: int = None,
+        window_type: str = None,
+        match_stride: bool = None,
+        padding_type: str = None,
+    ):
+        """Computes the short-time Fourier transform of the audio data,
+        with specified STFT parameters.
+
+        Parameters
+        ----------
+        window_length : int, optional
+            Window length of STFT, by default ``0.032 * self.sample_rate``.
+        hop_length : int, optional
+            Hop length of STFT, by default ``window_length // 4``.
+        window_type : str, optional
+            Type of window to use, by default ``hann``.
+        match_stride : bool, optional
+            Whether to match the stride of convolutional layers, by default False
+        padding_type : str, optional
+            Type of padding to use, by default 'reflect'
+
+        Returns
+        -------
+        torch.Tensor
+            STFT of audio data.
+
+        Examples
+        --------
+        Compute the STFT of an AudioSignal:
+
+        >>> signal = AudioSignal(torch.randn(44100), 44100)
+        >>> signal.stft()
+
+        Vary the window and hop length:
+
+        >>> stft_params = [STFTParams(128, 32), STFTParams(512, 128)]
+        >>> for stft_param in stft_params:
+        >>>     signal.stft_params = stft_param
+        >>>     signal.stft()
+
+        """
+        window_length = (
+            self.stft_params.window_length
+            if window_length is None
+            else int(window_length)
+        )
+        hop_length = (
+            self.stft_params.hop_length if hop_length is None else int(hop_length)
+        )
+        window_type = (
+            self.stft_params.window_type if window_type is None else window_type
+        )
+        match_stride = (
+            self.stft_params.match_stride if match_stride is None else match_stride
+        )
+        padding_type = (
+            self.stft_params.padding_type if padding_type is None else padding_type
+        )
+
+        window = self.get_window(window_type, window_length, self.audio_data.device)
+        window = window.to(self.audio_data.device)
+
+        audio_data = self.audio_data
+        right_pad, pad = self.compute_stft_padding(
+            window_length, hop_length, match_stride
+        )
+        audio_data = torch.nn.functional.pad(
+            audio_data, (pad, pad + right_pad), padding_type
+        )
+        stft_data = torch.stft(
+            audio_data.reshape(-1, audio_data.shape[-1]),
+            n_fft=window_length,
+            hop_length=hop_length,
+            window=window,
+            return_complex=True,
+            center=True,
+        )
+        _, nf, nt = stft_data.shape
+        stft_data = stft_data.reshape(self.batch_size, self.num_channels, nf, nt)
+
+        if match_stride:
+            # Drop first two and last two frames, which are added
+            # because of padding. Now num_frames * hop_length = num_samples.
+            stft_data = stft_data[..., 2:-2]
+        self.stft_data = stft_data
+
+        return stft_data
+
+    def istft(
+        self,
+        window_length: int = None,
+        hop_length: int = None,
+        window_type: str = None,
+        match_stride: bool = None,
+        length: int = None,
+    ):
+        """Computes inverse STFT and sets it to audio\_data.
+
+        Parameters
+        ----------
+        window_length : int, optional
+            Window length of STFT, by default ``0.032 * self.sample_rate``.
+        hop_length : int, optional
+            Hop length of STFT, by default ``window_length // 4``.
+        window_type : str, optional
+            Type of window to use, by default ``hann``.
+        match_stride : bool, optional
+            Whether to match the stride of convolutional layers, by default False
+        length : int, optional
+            Original length of signal, by default None
+
+        Returns
+        -------
+        AudioSignal
+            AudioSignal with istft applied.
+
+        Raises
+        ------
+        RuntimeError
+            Raises an error if stft was not called prior to istft on the signal,
+            or if stft_data is not set.
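+
+        Examples
+        --------
+        A round-trip sketch: compute the STFT, then invert it back to audio:
+
+        >>> signal = AudioSignal(torch.randn(44100), 44100)
+        >>> signal.stft()
+        >>> signal.istft()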
+        """
+        if self.stft_data is None:
+            raise RuntimeError("Cannot do inverse STFT without self.stft_data!")
+
+        window_length = (
+            self.stft_params.window_length
+            if window_length is None
+            else int(window_length)
+        )
+        hop_length = (
+            self.stft_params.hop_length if hop_length is None else int(hop_length)
+        )
+        window_type = (
+            self.stft_params.window_type if window_type is None else window_type
+        )
+        match_stride = (
+            self.stft_params.match_stride if match_stride is None else match_stride
+        )
+
+        window = self.get_window(window_type, window_length, self.stft_data.device)
+
+        nb, nch, nf, nt = self.stft_data.shape
+        stft_data = self.stft_data.reshape(nb * nch, nf, nt)
+        right_pad, pad = self.compute_stft_padding(
+            window_length, hop_length, match_stride
+        )
+
+        if length is None:
+            length = self.original_signal_length
+            length = length + 2 * pad + right_pad
+
+        if match_stride:
+            # Zero-pad the STFT on either side, putting back the frames that were
+            # dropped in stft().
+            stft_data = torch.nn.functional.pad(stft_data, (2, 2))
+
+        audio_data = torch.istft(
+            stft_data,
+            n_fft=window_length,
+            hop_length=hop_length,
+            window=window,
+            length=length,
+            center=True,
+        )
+        audio_data = audio_data.reshape(nb, nch, -1)
+        if match_stride:
+            audio_data = audio_data[..., pad : -(pad + right_pad)]
+        self.audio_data = audio_data
+
+        return self
+
+    @staticmethod
+    @functools.lru_cache(None)
+    def get_mel_filters(
+        sr: int, n_fft: int, n_mels: int, fmin: float = 0.0, fmax: float = None
+    ):
+        """Create a Filterbank matrix to combine FFT bins into Mel-frequency bins.
+
+        Parameters
+        ----------
+        sr : int
+            Sample rate of audio
+        n_fft : int
+            Number of FFT bins
+        n_mels : int
+            Number of mels
+        fmin : float, optional
+            Lowest frequency, in Hz, by default 0.0
+        fmax : float, optional
+            Highest frequency, by default None
+
+        Returns
+        -------
+        np.ndarray [shape=(n_mels, 1 + n_fft/2)]
+            Mel transform matrix
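+
+        Examples
+        --------
+        A minimal sketch, building an 80-band mel filterbank for a
+        2048-point FFT (values here are illustrative):
+
+        >>> filters = AudioSignal.get_mel_filters(sr=44100, n_fft=2048, n_mels=80)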
+        """
+        from librosa.filters import mel as librosa_mel_fn
+
+        return librosa_mel_fn(
+            sr=sr,
+            n_fft=n_fft,
+            n_mels=n_mels,
+            fmin=fmin,
+            fmax=fmax,
+        )
+
+    def mel_spectrogram(
+        self, n_mels: int = 80, mel_fmin: float = 0.0, mel_fmax: float = None, **kwargs
+    ):
+        """Computes a Mel spectrogram.
+
+        Parameters
+        ----------
+        n_mels : int, optional
+            Number of mels, by default 80
+        mel_fmin : float, optional
+            Lowest frequency, in Hz, by default 0.0
+        mel_fmax : float, optional
+            Highest frequency, by default None
+        kwargs : dict, optional
+            Keyword arguments to self.stft().
+
+        Returns
+        -------
+        torch.Tensor [shape=(batch, channels, mels, time)]
+            Mel spectrogram.
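+
+        Examples
+        --------
+        A minimal sketch:
+
+        >>> signal = AudioSignal(torch.randn(44100), 44100)
+        >>> mels = signal.mel_spectrogram(n_mels=80)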
+        """
+        stft = self.stft(**kwargs)
+        magnitude = torch.abs(stft)
+
+        nf = magnitude.shape[2]
+        mel_basis = self.get_mel_filters(
+            sr=self.sample_rate,
+            n_fft=2 * (nf - 1),
+            n_mels=n_mels,
+            fmin=mel_fmin,
+            fmax=mel_fmax,
+        )
+        mel_basis = torch.from_numpy(mel_basis).to(self.device)
+
+        mel_spectrogram = magnitude.transpose(2, -1) @ mel_basis.T
+        mel_spectrogram = mel_spectrogram.transpose(-1, 2)
+        return mel_spectrogram
+
+    @staticmethod
+    @functools.lru_cache(None)
+    def get_dct(n_mfcc: int, n_mels: int, norm: str = "ortho", device: str = None):
+        """Create a discrete cosine transform (DCT) transformation matrix with shape (``n_mels``, ``n_mfcc``),
+        it can be normalized depending on norm. For more information about dct:
+        http://en.wikipedia.org/wiki/Discrete_cosine_transform#DCT-II
+
+        Parameters
+        ----------
+        n_mfcc : int
+            Number of mfccs
+        n_mels : int
+            Number of mels
+        norm   : str
+            Use "ortho" to get a orthogonal matrix or None, by default "ortho"
+        device : str, optional
+            Device to load the transformation matrix on, by default None
+
+        Returns
+        -------
+        torch.Tensor [shape=(n_mels, n_mfcc)]
+            The dct transformation matrix.
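+
+        Examples
+        --------
+        A minimal sketch, building a DCT matrix for 40 MFCCs from 80 mel bands:
+
+        >>> dct = AudioSignal.get_dct(n_mfcc=40, n_mels=80, device="cpu")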
+        """
+        from torchaudio.functional import create_dct
+
+        return create_dct(n_mfcc, n_mels, norm).to(device)
+
+    def mfcc(
+        self, n_mfcc: int = 40, n_mels: int = 80, log_offset: float = 1e-6, **kwargs
+    ):
+        """Computes mel-frequency cepstral coefficients (MFCCs).
+
+        Parameters
+        ----------
+        n_mfcc : int, optional
+            Number of MFCCs to compute, by default 40
+        n_mels : int, optional
+            Number of mels, by default 80
+        log_offset: float, optional
+            Small value to prevent numerical issues when trying to compute log(0), by default 1e-6
+        kwargs : dict, optional
+            Keyword arguments to self.mel_spectrogram(); note that some of them are passed on to self.stft().
+
+        Returns
+        -------
+        torch.Tensor [shape=(batch, channels, mfccs, time)]
+            MFCCs.
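+
+        Examples
+        --------
+        A minimal sketch:
+
+        >>> signal = AudioSignal(torch.randn(44100), 44100)
+        >>> mfccs = signal.mfcc(n_mfcc=40, n_mels=80)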
+        """
+
+        mel_spectrogram = self.mel_spectrogram(n_mels, **kwargs)
+        mel_spectrogram = torch.log(mel_spectrogram + log_offset)
+        dct_mat = self.get_dct(n_mfcc, n_mels, "ortho", self.device)
+
+        mfcc = mel_spectrogram.transpose(-1, -2) @ dct_mat
+        mfcc = mfcc.transpose(-1, -2)
+        return mfcc
+
+    @property
+    def magnitude(self):
+        """Computes and returns the absolute value of the STFT, which
+        is the magnitude. This value can also be set to some tensor.
+        When set, ``self.stft_data`` is manipulated so that its magnitude
+        matches what this is set to, while the original phase is preserved.
+
+        Returns
+        -------
+        torch.Tensor
+            Magnitude of STFT.
+
+        Examples
+        --------
+        >>> signal = AudioSignal(torch.randn(44100), 44100)
+        >>> magnitude = signal.magnitude # Computes stft if not computed
+        >>> magnitude[magnitude < magnitude.mean()] = 0
+        >>> signal.magnitude = magnitude
+        >>> signal.istft()
+        """
+        if self.stft_data is None:
+            self.stft()
+        return torch.abs(self.stft_data)
+
+    @magnitude.setter
+    def magnitude(self, value):
+        self.stft_data = value * torch.exp(1j * self.phase)
+        return
+
+    def log_magnitude(
+        self, ref_value: float = 1.0, amin: float = 1e-5, top_db: float = 80.0
+    ):
+        """Computes the log-magnitude of the spectrogram.
+
+        Parameters
+        ----------
+        ref_value : float, optional
+            The magnitude is scaled relative to ``ref``: ``20 * log10(S / ref)``.
+            Zeros in the output correspond to positions where ``S == ref``,
+            by default 1.0
+        amin : float, optional
+            Minimum threshold for ``S`` and ``ref``, by default 1e-5
+        top_db : float, optional
+            Threshold the output at ``top_db`` below the peak:
+            ``max(20 * log10(S/ref)) - top_db``, by default 80.0
+
+        Returns
+        -------
+        torch.Tensor
+            Log-magnitude spectrogram
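+
+        Examples
+        --------
+        A minimal sketch, scaling relative to the peak magnitude as in
+        ``specshow``:
+
+        >>> signal = AudioSignal(torch.randn(44100), 44100)
+        >>> log_mag = signal.log_magnitude(ref_value=signal.magnitude.max())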
+        """
+        magnitude = self.magnitude
+
+        amin = amin**2
+        log_spec = 10.0 * torch.log10(magnitude.pow(2).clamp(min=amin))
+        log_spec -= 10.0 * np.log10(np.maximum(amin, ref_value))
+
+        if top_db is not None:
+            log_spec = torch.maximum(log_spec, log_spec.max() - top_db)
+        return log_spec
+
+    @property
+    def phase(self):
+        """Computes and returns the phase of the STFT.
+        This value can also be set to some tensor.
+        When set, ``self.stft_data`` is manipulated so that its phase
+        matches what this is set to, with the original magnitude preserved.
+
+        Returns
+        -------
+        torch.Tensor
+            Phase of STFT.
+
+        Examples
+        --------
+        >>> signal = AudioSignal(torch.randn(44100), 44100)
+        >>> phase = signal.phase # Computes stft if not computed
+        >>> phase[phase < phase.mean()] = 0
+        >>> signal.phase = phase
+        >>> signal.istft()
+        """
+        if self.stft_data is None:
+            self.stft()
+        return torch.angle(self.stft_data)
+
+    @phase.setter
+    def phase(self, value):
+        self.stft_data = self.magnitude * torch.exp(1j * value)
+        return
+
+    # Operator overloading
+    def __add__(self, other):
+        new_signal = self.clone()
+        new_signal.audio_data += util._get_value(other)
+        return new_signal
+
+    def __iadd__(self, other):
+        self.audio_data += util._get_value(other)
+        return self
+
+    def __radd__(self, other):
+        return self + other
+
+    def __sub__(self, other):
+        new_signal = self.clone()
+        new_signal.audio_data -= util._get_value(other)
+        return new_signal
+
+    def __isub__(self, other):
+        self.audio_data -= util._get_value(other)
+        return self
+
+    def __mul__(self, other):
+        new_signal = self.clone()
+        new_signal.audio_data *= util._get_value(other)
+        return new_signal
+
+    def __imul__(self, other):
+        self.audio_data *= util._get_value(other)
+        return self
+
+    def __rmul__(self, other):
+        return self * other
+
+    # Representation
+    def _info(self):
+        dur = f"{self.signal_duration:0.3f}" if self.signal_duration else "[unknown]"
+        info = {
+            "duration": f"{dur} seconds",
+            "batch_size": self.batch_size,
+            "path": self.path_to_file if self.path_to_file else "path unknown",
+            "sample_rate": self.sample_rate,
+            "num_channels": self.num_channels if self.num_channels else "[unknown]",
+            "audio_data.shape": self.audio_data.shape,
+            "stft_params": self.stft_params,
+            "device": self.device,
+        }
+
+        return info
+
+    def markdown(self):
+        """Produces a markdown representation of AudioSignal, in a markdown table.
+
+        Returns
+        -------
+        str
+            Markdown representation of AudioSignal.
+
+        Examples
+        --------
+        >>> signal = AudioSignal(torch.randn(44100), 44100)
+        >>> print(signal.markdown())
+        | Key | Value
+        |---|---
+        | duration | 1.000 seconds |
+        | batch_size | 1 |
+        | path | path unknown |
+        | sample_rate | 44100 |
+        | num_channels | 1 |
+        | audio_data.shape | torch.Size([1, 1, 44100]) |
+        | stft_params | STFTParams(window_length=2048, hop_length=512, window_type='hann', match_stride=False, padding_type='reflect') |
+        | device | cpu |
+        """
+        info = self._info()
+
+        FORMAT = "| Key | Value \n" "|---|--- \n"
+        for k, v in info.items():
+            row = f"| {k} | {v} |\n"
+            FORMAT += row
+        return FORMAT
+
+    def __str__(self):
+        info = self._info()
+
+        desc = ""
+        for k, v in info.items():
+            desc += f"{k}: {v}\n"
+        return desc
+
+    def __rich__(self):
+        from rich.table import Table
+
+        info = self._info()
+
+        table = Table(title=f"{self.__class__.__name__}")
+        table.add_column("Key", style="green")
+        table.add_column("Value", style="cyan")
+
+        for k, v in info.items():
+            table.add_row(k, str(v))
+        return table
+
+    # Comparison
+    def __eq__(self, other):
+        for k, v in list(self.__dict__.items()):
+            if torch.is_tensor(v):
+                if not torch.allclose(v, other.__dict__[k], atol=1e-6):
+                    max_error = (v - other.__dict__[k]).abs().max()
+                    print(f"Max abs error for {k}: {max_error}")
+                    return False
+        return True
+
+    # Indexing
+    def __getitem__(self, key):
+        if torch.is_tensor(key) and key.ndim == 0 and key.item() is True:
+            assert self.batch_size == 1
+            audio_data = self.audio_data
+            _loudness = self._loudness
+            stft_data = self.stft_data
+
+        elif isinstance(key, (bool, int, list, slice, tuple)) or (
+            torch.is_tensor(key) and key.ndim <= 1
+        ):
+            # Indexing only on the batch dimension.
+            # Then let's copy over relevant stuff.
+            # Future work: make this work for time-indexing
+            # as well, using the hop length.
+            audio_data = self.audio_data[key]
+            _loudness = self._loudness[key] if self._loudness is not None else None
+            stft_data = self.stft_data[key] if self.stft_data is not None else None
+
+        sources = None
+
+        copy = type(self)(audio_data, self.sample_rate, stft_params=self.stft_params)
+        copy._loudness = _loudness
+        copy._stft_data = stft_data
+        copy.sources = sources
+
+        return copy
+
+    def __setitem__(self, key, value):
+        if not isinstance(value, type(self)):
+            self.audio_data[key] = value
+            return
+
+        if torch.is_tensor(key) and key.ndim == 0 and key.item() is True:
+            assert self.batch_size == 1
+            self.audio_data = value.audio_data
+            self._loudness = value._loudness
+            self.stft_data = value.stft_data
+            return
+
+        elif isinstance(key, (bool, int, list, slice, tuple)) or (
+            torch.is_tensor(key) and key.ndim <= 1
+        ):
+            if self.audio_data is not None and value.audio_data is not None:
+                self.audio_data[key] = value.audio_data
+            if self._loudness is not None and value._loudness is not None:
+                self._loudness[key] = value._loudness
+            if self.stft_data is not None and value.stft_data is not None:
+                self.stft_data[key] = value.stft_data
+            return
+
+    def __ne__(self, other):
+        return not self == other
diff --git a/audiotools/core/display.py b/audiotools/core/display.py
new file mode 100644
index 0000000000000000000000000000000000000000..66cbcf34cb2cf9fdf8d67ec4418a887eba73f184
--- /dev/null
+++ b/audiotools/core/display.py
@@ -0,0 +1,194 @@
+import inspect
+import typing
+from functools import wraps
+
+from . import util
+
+
+def format_figure(func):
+    """Decorator for formatting figures produced by the code below.
+    See :py:func:`audiotools.core.util.format_figure` for more.
+
+    Parameters
+    ----------
+    func : Callable
+        Plotting function that is decorated by this function.
+
+    """
+
+    @wraps(func)
+    def wrapper(*args, **kwargs):
+        f_keys = inspect.signature(util.format_figure).parameters.keys()
+        f_kwargs = {}
+        for k, v in list(kwargs.items()):
+            if k in f_keys:
+                kwargs.pop(k)
+                f_kwargs[k] = v
+        func(*args, **kwargs)
+        util.format_figure(**f_kwargs)
+
+    return wrapper
+
+
+class DisplayMixin:
+    @format_figure
+    def specshow(
+        self,
+        preemphasis: bool = False,
+        x_axis: str = "time",
+        y_axis: str = "linear",
+        n_mels: int = 128,
+        **kwargs,
+    ):
+        """Displays a spectrogram, using ``librosa.display.specshow``.
+
+        Parameters
+        ----------
+        preemphasis : bool, optional
+            Whether or not to apply preemphasis, which makes high
+            frequency detail easier to see, by default False
+        x_axis : str, optional
+            How to label the x axis, by default "time"
+        y_axis : str, optional
+            How to label the y axis, by default "linear"
+        n_mels : int, optional
+            If displaying a mel spectrogram with ``y_axis = "mel"``,
+            this controls the number of mels, by default 128.
+        kwargs : dict, optional
+            Keyword arguments to :py:func:`audiotools.core.util.format_figure`.
+        """
+        import librosa
+        import librosa.display
+
+        # Always re-compute the STFT data before showing it, in case
+        # it changed.
+        signal = self.clone()
+        signal.stft_data = None
+
+        if preemphasis:
+            signal.preemphasis()
+
+        ref = signal.magnitude.max()
+        log_mag = signal.log_magnitude(ref_value=ref)
+
+        if y_axis == "mel":
+            log_mag = 20 * signal.mel_spectrogram(n_mels).clamp(1e-5).log10()
+            log_mag -= log_mag.max()
+
+        librosa.display.specshow(
+            log_mag.numpy()[0].mean(axis=0),
+            x_axis=x_axis,
+            y_axis=y_axis,
+            sr=signal.sample_rate,
+            **kwargs,
+        )
+
+    @format_figure
+    def waveplot(self, x_axis: str = "time", **kwargs):
+        """Displays a waveform plot, using ``librosa.display.waveshow``.
+
+        Parameters
+        ----------
+        x_axis : str, optional
+            How to label the x axis, by default "time"
+        kwargs : dict, optional
+            Keyword arguments to :py:func:`audiotools.core.util.format_figure`.
+        """
+        import librosa
+        import librosa.display
+
+        audio_data = self.audio_data[0].mean(dim=0)
+        audio_data = audio_data.cpu().numpy()
+
+        plot_fn = "waveshow" if hasattr(librosa.display, "waveshow") else "waveplot"
+        wave_plot_fn = getattr(librosa.display, plot_fn)
+        wave_plot_fn(audio_data, x_axis=x_axis, sr=self.sample_rate, **kwargs)
+
+    @format_figure
+    def wavespec(self, x_axis: str = "time", **kwargs):
+        """Displays a waveform plot, using ``librosa.display.waveshow``.
+
+        Parameters
+        ----------
+        x_axis : str, optional
+            How to label the x axis, by default "time"
+        kwargs : dict, optional
+            Keyword arguments to :py:func:`audiotools.core.display.DisplayMixin.specshow`.
+        """
+        import matplotlib.pyplot as plt
+        from matplotlib.gridspec import GridSpec
+
+        gs = GridSpec(6, 1)
+        plt.subplot(gs[0, :])
+        self.waveplot(x_axis=x_axis)
+        plt.subplot(gs[1:, :])
+        self.specshow(x_axis=x_axis, **kwargs)
+
+    def write_audio_to_tb(
+        self,
+        tag: str,
+        writer,
+        step: int = None,
+        plot_fn: typing.Union[typing.Callable, str] = "specshow",
+        **kwargs,
+    ):
+        """Writes a signal and its spectrogram to Tensorboard. Will show up
+        under the Audio and Images tab in Tensorboard.
+
+        Parameters
+        ----------
+        tag : str
+            Tag to write signal to (e.g. ``clean/sample_0.wav``). The image will be
+            written to the corresponding ``.png`` file (e.g. ``clean/sample_0.png``).
+        writer : SummaryWriter
+            A SummaryWriter object from PyTorch library.
+        step : int, optional
+            The step to write the signal to, by default None
+        plot_fn : typing.Union[typing.Callable, str], optional
+            How to create the image. Set to ``None`` to avoid plotting, by default "specshow"
+        kwargs : dict, optional
+            Keyword arguments to :py:func:`audiotools.core.display.DisplayMixin.specshow` or
+            whatever ``plot_fn`` is set to.
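+
+        Examples
+        --------
+        An illustrative sketch; the log directory, tag, and dummy one-second
+        signal below are placeholders, not values taken from the library:
+
+        >>> import torch
+        >>> from torch.utils.tensorboard import SummaryWriter
+        >>> from audiotools import AudioSignal
+        >>> writer = SummaryWriter("runs/example")
+        >>> signal = AudioSignal(torch.randn(1, 1, 44100), 44100)
+        >>> signal.write_audio_to_tb("clean/sample_0.wav", writer, step=0)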
+        """
+        import matplotlib.pyplot as plt
+
+        audio_data = self.audio_data[0, 0].detach().cpu()
+        sample_rate = self.sample_rate
+        writer.add_audio(tag, audio_data, step, sample_rate)
+
+        if plot_fn is not None:
+            if isinstance(plot_fn, str):
+                plot_fn = getattr(self, plot_fn)
+            fig = plt.figure()
+            plt.clf()
+            plot_fn(**kwargs)
+            writer.add_figure(tag.replace("wav", "png"), fig, step)
+
+    def save_image(
+        self,
+        image_path: str,
+        plot_fn: typing.Union[typing.Callable, str] = "specshow",
+        **kwargs,
+    ):
+        """Save AudioSignal spectrogram (or whatever ``plot_fn`` is set to) to
+        a specified file.
+
+        Parameters
+        ----------
+        image_path : str
+            Where to save the file to.
+        plot_fn : typing.Union[typing.Callable, str], optional
+            How to create the image. Set to ``None`` to avoid plotting, by default "specshow"
+        kwargs : dict, optional
+            Keyword arguments to :py:func:`audiotools.core.display.DisplayMixin.specshow` or
+            whatever ``plot_fn`` is set to.
+        """
+        import matplotlib.pyplot as plt
+
+        if isinstance(plot_fn, str):
+            plot_fn = getattr(self, plot_fn)
+
+        plt.clf()
+        plot_fn(**kwargs)
+        plt.savefig(image_path, bbox_inches="tight", pad_inches=0)
+        plt.close()
diff --git a/audiotools/core/dsp.py b/audiotools/core/dsp.py
new file mode 100644
index 0000000000000000000000000000000000000000..f9be51a119537b77e497ddc2dac126d569533d7c
--- /dev/null
+++ b/audiotools/core/dsp.py
@@ -0,0 +1,390 @@
+import typing
+
+import julius
+import numpy as np
+import torch
+
+from . import util
+
+
+class DSPMixin:
+    _original_batch_size = None
+    _original_num_channels = None
+    _padded_signal_length = None
+
+    def _preprocess_signal_for_windowing(self, window_duration, hop_duration):
+        self._original_batch_size = self.batch_size
+        self._original_num_channels = self.num_channels
+
+        window_length = int(window_duration * self.sample_rate)
+        hop_length = int(hop_duration * self.sample_rate)
+
+        if window_length % hop_length != 0:
+            factor = window_length // hop_length
+            window_length = factor * hop_length
+
+        self.zero_pad(hop_length, hop_length)
+        self._padded_signal_length = self.signal_length
+
+        return window_length, hop_length
+
+    def windows(
+        self, window_duration: float, hop_duration: float, preprocess: bool = True
+    ):
+        """Generator which yields windows of specified duration from signal with a specified
+        hop length.
+
+        Parameters
+        ----------
+        window_duration : float
+            Duration of every window in seconds.
+        hop_duration : float
+            Hop between windows in seconds.
+        preprocess : bool, optional
+            Whether to preprocess the signal, so that the first sample is in
+            the middle of the first window, by default True
+
+        Yields
+        ------
+        AudioSignal
+            Each window is returned as an AudioSignal.
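+
+        Examples
+        --------
+        An illustrative sketch; the dummy two-second signal and the window/hop
+        durations below are placeholders:
+
+        >>> import torch
+        >>> from audiotools import AudioSignal
+        >>> signal = AudioSignal(torch.randn(1, 1, 2 * 16000), 16000)
+        >>> for window in signal.windows(window_duration=0.5, hop_duration=0.25):
+        ...     peak = window.audio_data.abs().max()  # process each window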
+        """
+        if preprocess:
+            window_length, hop_length = self._preprocess_signal_for_windowing(
+                window_duration, hop_duration
+            )
+
+        self.audio_data = self.audio_data.reshape(-1, 1, self.signal_length)
+
+        for b in range(self.batch_size):
+            i = 0
+            while True:
+                start_idx = i * hop_length
+                i += 1
+                end_idx = start_idx + window_length
+                if end_idx > self.signal_length:
+                    break
+                yield self[b, ..., start_idx:end_idx]
+
+    def collect_windows(
+        self, window_duration: float, hop_duration: float, preprocess: bool = True
+    ):
+        """Reshapes signal into windows of specified duration from signal with a specified
+        hop length. Windows are placed along the batch dimension. Use with
+        :py:func:`audiotools.core.dsp.DSPMixin.overlap_and_add` to reconstruct the
+        original signal.
+
+        Parameters
+        ----------
+        window_duration : float
+            Duration of every window in seconds.
+        hop_duration : float
+            Hop between windows in seconds.
+        preprocess : bool, optional
+            Whether to preprocess the signal, so that the first sample is in
+            the middle of the first window, by default True
+
+        Returns
+        -------
+        AudioSignal
+            AudioSignal unfolded with shape ``(nb * nch * num_windows, 1, window_length)``
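+
+        Examples
+        --------
+        An illustrative round-trip sketch; the dummy signal and durations below
+        are placeholders:
+
+        >>> import torch
+        >>> from audiotools import AudioSignal
+        >>> signal = AudioSignal(torch.randn(1, 1, 4 * 16000), 16000)
+        >>> windowed = signal.collect_windows(window_duration=1.0, hop_duration=0.5)
+        >>> # ... process windowed.audio_data window-by-window here ...
+        >>> reconstructed = windowed.overlap_and_add(hop_duration=0.5)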
+        """
+        if preprocess:
+            window_length, hop_length = self._preprocess_signal_for_windowing(
+                window_duration, hop_duration
+            )
+
+        # self.audio_data: (nb, nch, nt).
+        unfolded = torch.nn.functional.unfold(
+            self.audio_data.reshape(-1, 1, 1, self.signal_length),
+            kernel_size=(1, window_length),
+            stride=(1, hop_length),
+        )
+        # unfolded: (nb * nch, window_length, num_windows).
+        # -> (nb * nch * num_windows, 1, window_length)
+        unfolded = unfolded.permute(0, 2, 1).reshape(-1, 1, window_length)
+        self.audio_data = unfolded
+        return self
+
+    def overlap_and_add(self, hop_duration: float):
+        """Function which takes a list of windows and overlap adds them into a
+        signal the same length as ``audio_signal``.
+
+        Parameters
+        ----------
+        hop_duration : float
+            How much to shift for each window
+            (overlap is window_duration - hop_duration) in seconds.
+
+        Returns
+        -------
+        AudioSignal
+            overlap-and-added signal.
+        """
+        hop_length = int(hop_duration * self.sample_rate)
+        window_length = self.signal_length
+
+        nb, nch = self._original_batch_size, self._original_num_channels
+
+        unfolded = self.audio_data.reshape(nb * nch, -1, window_length).permute(0, 2, 1)
+        folded = torch.nn.functional.fold(
+            unfolded,
+            output_size=(1, self._padded_signal_length),
+            kernel_size=(1, window_length),
+            stride=(1, hop_length),
+        )
+
+        norm = torch.ones_like(unfolded, device=unfolded.device)
+        norm = torch.nn.functional.fold(
+            norm,
+            output_size=(1, self._padded_signal_length),
+            kernel_size=(1, window_length),
+            stride=(1, hop_length),
+        )
+
+        folded = folded / norm
+
+        folded = folded.reshape(nb, nch, -1)
+        self.audio_data = folded
+        self.trim(hop_length, hop_length)
+        return self
+
+    def low_pass(
+        self, cutoffs: typing.Union[torch.Tensor, np.ndarray, float], zeros: int = 51
+    ):
+        """Low-passes the signal in-place. Each item in the batch
+        can have a different low-pass cutoff, if the input
+        to this signal is an array or tensor. If a float, all
+        items are given the same low-pass filter.
+
+        Parameters
+        ----------
+        cutoffs : typing.Union[torch.Tensor, np.ndarray, float]
+            Cutoff in Hz of low-pass filter.
+        zeros : int, optional
+            Number of taps to use in low-pass filter, by default 51
+
+        Returns
+        -------
+        AudioSignal
+            Low-passed AudioSignal.
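+
+        Examples
+        --------
+        An illustrative sketch with a different cutoff per batch item; the dummy
+        two-item batch and the cutoff values are placeholders:
+
+        >>> import torch
+        >>> from audiotools import AudioSignal
+        >>> batch = AudioSignal(torch.randn(2, 1, 44100), 44100)
+        >>> batch = batch.low_pass(torch.tensor([1000.0, 4000.0]))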
+        """
+        cutoffs = util.ensure_tensor(cutoffs, 2, self.batch_size)
+        cutoffs = cutoffs / self.sample_rate
+        filtered = torch.empty_like(self.audio_data)
+
+        for i, cutoff in enumerate(cutoffs):
+            lp_filter = julius.LowPassFilter(cutoff.cpu(), zeros=zeros).to(self.device)
+            filtered[i] = lp_filter(self.audio_data[i])
+
+        self.audio_data = filtered
+        self.stft_data = None
+        return self
+
+    def high_pass(
+        self, cutoffs: typing.Union[torch.Tensor, np.ndarray, float], zeros: int = 51
+    ):
+        """High-passes the signal in-place. Each item in the batch
+        can have a different high-pass cutoff, if the input
+        to this signal is an array or tensor. If a float, all
+        items are given the same high-pass filter.
+
+        Parameters
+        ----------
+        cutoffs : typing.Union[torch.Tensor, np.ndarray, float]
+            Cutoff in Hz of high-pass filter.
+        zeros : int, optional
+            Number of taps to use in high-pass filter, by default 51
+
+        Returns
+        -------
+        AudioSignal
+            High-passed AudioSignal.
+        """
+        cutoffs = util.ensure_tensor(cutoffs, 2, self.batch_size)
+        cutoffs = cutoffs / self.sample_rate
+        filtered = torch.empty_like(self.audio_data)
+
+        for i, cutoff in enumerate(cutoffs):
+            hp_filter = julius.HighPassFilter(cutoff.cpu(), zeros=zeros).to(self.device)
+            filtered[i] = hp_filter(self.audio_data[i])
+
+        self.audio_data = filtered
+        self.stft_data = None
+        return self
+
+    def mask_frequencies(
+        self,
+        fmin_hz: typing.Union[torch.Tensor, np.ndarray, float],
+        fmax_hz: typing.Union[torch.Tensor, np.ndarray, float],
+        val: float = 0.0,
+    ):
+        """Masks frequencies between ``fmin_hz`` and ``fmax_hz``, and fills them
+        with the value specified by ``val``. Useful for implementing SpecAug.
+        The min and max can be different for every item in the batch.
+
+        Parameters
+        ----------
+        fmin_hz : typing.Union[torch.Tensor, np.ndarray, float]
+            Lower end of band to mask out.
+        fmax_hz : typing.Union[torch.Tensor, np.ndarray, float]
+            Upper end of band to mask out.
+        val : float, optional
+            Value to fill in, by default 0.0
+
+        Returns
+        -------
+        AudioSignal
+            Signal with ``stft_data`` manipulated. Apply ``.istft()`` to get the
+            masked audio data.
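+
+        Examples
+        --------
+        An illustrative SpecAug-style sketch; the dummy signal and band edges
+        are placeholders:
+
+        >>> import torch
+        >>> from audiotools import AudioSignal
+        >>> signal = AudioSignal(torch.randn(1, 1, 44100), 44100)
+        >>> signal = signal.mask_frequencies(fmin_hz=500, fmax_hz=2000)
+        >>> _ = signal.istft()  # back to (masked) audio samples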
+        """
+        # SpecAug
+        mag, phase = self.magnitude, self.phase
+        fmin_hz = util.ensure_tensor(fmin_hz, ndim=mag.ndim)
+        fmax_hz = util.ensure_tensor(fmax_hz, ndim=mag.ndim)
+        assert torch.all(fmin_hz < fmax_hz)
+
+        # build mask
+        nbins = mag.shape[-2]
+        bins_hz = torch.linspace(0, self.sample_rate / 2, nbins, device=self.device)
+        bins_hz = bins_hz[None, None, :, None].repeat(
+            self.batch_size, 1, 1, mag.shape[-1]
+        )
+        mask = (fmin_hz <= bins_hz) & (bins_hz < fmax_hz)
+        mask = mask.to(self.device)
+
+        mag = mag.masked_fill(mask, val)
+        phase = phase.masked_fill(mask, val)
+        self.stft_data = mag * torch.exp(1j * phase)
+        return self
+
+    def mask_timesteps(
+        self,
+        tmin_s: typing.Union[torch.Tensor, np.ndarray, float],
+        tmax_s: typing.Union[torch.Tensor, np.ndarray, float],
+        val: float = 0.0,
+    ):
+        """Masks timesteps between ``tmin_s`` and ``tmax_s``, and fills them
+        with the value specified by ``val``. Useful for implementing SpecAug.
+        The min and max can be different for every item in the batch.
+
+        Parameters
+        ----------
+        tmin_s : typing.Union[torch.Tensor, np.ndarray, float]
+            Lower end of timesteps to mask out.
+        tmax_s : typing.Union[torch.Tensor, np.ndarray, float]
+            Upper end of timesteps to mask out.
+        val : float, optional
+            Value to fill in, by default 0.0
+
+        Returns
+        -------
+        AudioSignal
+            Signal with ``stft_data`` manipulated. Apply ``.istft()`` to get the
+            masked audio data.
+        """
+        # SpecAug
+        mag, phase = self.magnitude, self.phase
+        tmin_s = util.ensure_tensor(tmin_s, ndim=mag.ndim)
+        tmax_s = util.ensure_tensor(tmax_s, ndim=mag.ndim)
+
+        assert torch.all(tmin_s < tmax_s)
+
+        # build mask
+        nt = mag.shape[-1]
+        bins_t = torch.linspace(0, self.signal_duration, nt, device=self.device)
+        bins_t = bins_t[None, None, None, :].repeat(
+            self.batch_size, 1, mag.shape[-2], 1
+        )
+        mask = (tmin_s <= bins_t) & (bins_t < tmax_s)
+
+        mag = mag.masked_fill(mask, val)
+        phase = phase.masked_fill(mask, val)
+        self.stft_data = mag * torch.exp(1j * phase)
+        return self
+
+    def mask_low_magnitudes(
+        self, db_cutoff: typing.Union[torch.Tensor, np.ndarray, float], val: float = 0.0
+    ):
+        """Mask away magnitudes below a specified threshold, which
+        can be different for every item in the batch.
+
+        Parameters
+        ----------
+        db_cutoff : typing.Union[torch.Tensor, np.ndarray, float]
+            Decibel value for which things below it will be masked away.
+        val : float, optional
+            Value to fill in for masked portions, by default 0.0
+
+        Returns
+        -------
+        AudioSignal
+            Signal with ``stft_data`` manipulated. Apply ``.istft()`` to get the
+            masked audio data.
+        """
+        mag = self.magnitude
+        log_mag = self.log_magnitude()
+
+        db_cutoff = util.ensure_tensor(db_cutoff, ndim=mag.ndim)
+        mask = log_mag < db_cutoff
+        mag = mag.masked_fill(mask, val)
+
+        self.magnitude = mag
+        return self
+
+    def shift_phase(self, shift: typing.Union[torch.Tensor, np.ndarray, float]):
+        """Shifts the phase by a constant value.
+
+        Parameters
+        ----------
+        shift : typing.Union[torch.Tensor, np.ndarray, float]
+            What to shift the phase by.
+
+        Returns
+        -------
+        AudioSignal
+            Signal with ``stft_data`` manipulated. Apply ``.istft()`` to get the
+            phase-shifted audio data.
+        """
+        shift = util.ensure_tensor(shift, ndim=self.phase.ndim)
+        self.phase = self.phase + shift
+        return self
+
+    def corrupt_phase(self, scale: typing.Union[torch.Tensor, np.ndarray, float]):
+        """Corrupts the phase randomly by some scaled value.
+
+        Parameters
+        ----------
+        scale : typing.Union[torch.Tensor, np.ndarray, float]
+            Standard deviation of noise to add to the phase.
+
+        Returns
+        -------
+        AudioSignal
+            Signal with ``stft_data`` manipulated. Apply ``.istft()`` to get the
+            audio data with corrupted phase.
+        """
+        scale = util.ensure_tensor(scale, ndim=self.phase.ndim)
+        self.phase = self.phase + scale * torch.randn_like(self.phase)
+        return self
+
+    def preemphasis(self, coef: float = 0.85):
+        """Applies pre-emphasis to audio signal.
+
+        Parameters
+        ----------
+        coef : float, optional
+            How much pre-emphasis to apply; lower values do less and 0 does
+            nothing, by default 0.85
+
+        Returns
+        -------
+        AudioSignal
+            Pre-emphasized signal.
+        """
+        kernel = torch.tensor([1, -coef, 0]).view(1, 1, -1).to(self.device)
+        x = self.audio_data.reshape(-1, 1, self.signal_length)
+        x = torch.nn.functional.conv1d(x, kernel, padding=1)
+        self.audio_data = x.reshape(*self.audio_data.shape)
+        return self
diff --git a/audiotools/core/effects.py b/audiotools/core/effects.py
new file mode 100644
index 0000000000000000000000000000000000000000..fb534cbcb2d457575de685fc9248d1716879145b
--- /dev/null
+++ b/audiotools/core/effects.py
@@ -0,0 +1,647 @@
+import typing
+
+import julius
+import numpy as np
+import torch
+import torchaudio
+
+from . import util
+
+
+class EffectMixin:
+    GAIN_FACTOR = np.log(10) / 20
+    """Gain factor for converting between amplitude and decibels."""
+    CODEC_PRESETS = {
+        "8-bit": {"format": "wav", "encoding": "ULAW", "bits_per_sample": 8},
+        "GSM-FR": {"format": "gsm"},
+        "MP3": {"format": "mp3", "compression": -9},
+        "Vorbis": {"format": "vorbis", "compression": -1},
+        "Ogg": {
+            "format": "ogg",
+            "compression": -1,
+        },
+        "Amr-nb": {"format": "amr-nb"},
+    }
+    """Presets for applying codecs via torchaudio."""
+
+    def mix(
+        self,
+        other,
+        snr: typing.Union[torch.Tensor, np.ndarray, float] = 10,
+        other_eq: typing.Union[torch.Tensor, np.ndarray] = None,
+    ):
+        """Mixes noise with signal at specified
+        signal-to-noise ratio. Optionally, the
+        other signal can be equalized in-place.
+
+
+        Parameters
+        ----------
+        other : AudioSignal
+            AudioSignal object to mix with.
+        snr : typing.Union[torch.Tensor, np.ndarray, float], optional
+            Signal to noise ratio, by default 10
+        other_eq : typing.Union[torch.Tensor, np.ndarray], optional
+            EQ curve to apply to other signal, if any, by default None
+
+        Returns
+        -------
+        AudioSignal
+            In-place modification of AudioSignal.
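+
+        Examples
+        --------
+        An illustrative sketch; the dummy speech/noise tensors and the SNR value
+        are placeholders:
+
+        >>> import torch
+        >>> from audiotools import AudioSignal
+        >>> speech = AudioSignal(torch.randn(1, 1, 44100), 44100)
+        >>> noise = AudioSignal(torch.randn(1, 1, 44100), 44100)
+        >>> speech = speech.mix(noise, snr=10)  # noise sits 10 dB below speech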
+        """
+        snr = util.ensure_tensor(snr).to(self.device)
+
+        pad_len = max(0, self.signal_length - other.signal_length)
+        other.zero_pad(0, pad_len)
+        other.truncate_samples(self.signal_length)
+        if other_eq is not None:
+            other = other.equalizer(other_eq)
+
+        tgt_loudness = self.loudness() - snr
+        other = other.normalize(tgt_loudness)
+
+        self.audio_data = self.audio_data + other.audio_data
+        return self
+
+    def convolve(self, other, start_at_max: bool = True):
+        """Convolves self with other.
+        This function uses FFTs to do the convolution.
+
+        Parameters
+        ----------
+        other : AudioSignal
+            Signal to convolve with.
+        start_at_max : bool, optional
+            Whether to start at the max value of other signal, to
+            avoid inducing delays, by default True
+
+        Returns
+        -------
+        AudioSignal
+            Convolved signal, in-place.
+        """
+        from . import AudioSignal
+
+        pad_len = self.signal_length - other.signal_length
+
+        if pad_len > 0:
+            other.zero_pad(0, pad_len)
+        else:
+            other.truncate_samples(self.signal_length)
+
+        if start_at_max:
+            # Use roll to rotate over the max for every item
+            # so that the impulse responses don't induce any
+            # delay.
+            idx = other.audio_data.abs().argmax(axis=-1)
+            irs = torch.zeros_like(other.audio_data)
+            for i in range(other.batch_size):
+                irs[i] = torch.roll(other.audio_data[i], -idx[i].item(), -1)
+            other = AudioSignal(irs, other.sample_rate)
+
+        delta = torch.zeros_like(other.audio_data)
+        delta[..., 0] = 1
+
+        length = self.signal_length
+        delta_fft = torch.fft.rfft(delta, length)
+        other_fft = torch.fft.rfft(other.audio_data, length)
+        self_fft = torch.fft.rfft(self.audio_data, length)
+
+        convolved_fft = other_fft * self_fft
+        convolved_audio = torch.fft.irfft(convolved_fft, length)
+
+        delta_convolved_fft = other_fft * delta_fft
+        delta_audio = torch.fft.irfft(delta_convolved_fft, length)
+
+        # Use the delta to rescale the audio exactly as needed.
+        delta_max = delta_audio.abs().max(dim=-1, keepdims=True)[0]
+        scale = 1 / delta_max.clamp(1e-5)
+        convolved_audio = convolved_audio * scale
+
+        self.audio_data = convolved_audio
+
+        return self
+
+    def apply_ir(
+        self,
+        ir,
+        drr: typing.Union[torch.Tensor, np.ndarray, float] = None,
+        ir_eq: typing.Union[torch.Tensor, np.ndarray] = None,
+        use_original_phase: bool = False,
+    ):
+        """Applies an impulse response to the signal. If ` is`ir_eq``
+        is specified, the impulse response is equalized before
+        it is applied, using the given curve.
+
+        Parameters
+        ----------
+        ir : AudioSignal
+            Impulse response to convolve with.
+        drr : typing.Union[torch.Tensor, np.ndarray, float], optional
+            Direct-to-reverberant ratio that impulse response will be
+            altered to, if specified, by default None
+        ir_eq : typing.Union[torch.Tensor, np.ndarray], optional
+            Equalization that will be applied to impulse response
+            if specified, by default None
+        use_original_phase : bool, optional
+            Whether to use the original phase, instead of the convolved
+            phase, by default False
+
+        Returns
+        -------
+        AudioSignal
+            Signal with impulse response applied to it.
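+
+        Examples
+        --------
+        An illustrative sketch; the random "impulse response", the target DRR,
+        and the flat EQ curve below are placeholders:
+
+        >>> import torch
+        >>> from audiotools import AudioSignal
+        >>> signal = AudioSignal(torch.randn(1, 1, 44100), 44100)
+        >>> ir = AudioSignal(torch.randn(1, 1, 44100), 44100)
+        >>> signal = signal.apply_ir(ir, drr=15, ir_eq=torch.zeros(6))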
+        """
+        if ir_eq is not None:
+            ir = ir.equalizer(ir_eq)
+        if drr is not None:
+            ir = ir.alter_drr(drr)
+
+        # Save the peak before
+        max_spk = self.audio_data.abs().max(dim=-1, keepdims=True).values
+
+        # Augment the impulse response to simulate microphone effects
+        # and with varying direct-to-reverberant ratio.
+        phase = self.phase
+        self.convolve(ir)
+
+        # Use the input phase
+        if use_original_phase:
+            self.stft()
+            self.stft_data = self.magnitude * torch.exp(1j * phase)
+            self.istft()
+
+        # Rescale to the input's amplitude
+        max_transformed = self.audio_data.abs().max(dim=-1, keepdims=True).values
+        scale_factor = max_spk.clamp(1e-8) / max_transformed.clamp(1e-8)
+        self = self * scale_factor
+
+        return self
+
+    def ensure_max_of_audio(self, max: float = 1.0):
+        """Ensures that ``abs(audio_data) <= max``.
+
+        Parameters
+        ----------
+        max : float, optional
+            Max absolute value of signal, by default 1.0
+
+        Returns
+        -------
+        AudioSignal
+            Signal with values scaled between -max and max.
+        """
+        peak = self.audio_data.abs().max(dim=-1, keepdims=True)[0]
+        peak_gain = torch.ones_like(peak)
+        peak_gain[peak > max] = max / peak[peak > max]
+        self.audio_data = self.audio_data * peak_gain
+        return self
+
+    def normalize(self, db: typing.Union[torch.Tensor, np.ndarray, float] = -24.0):
+        """Normalizes the signal's volume to the specified db, in LUFS.
+        This is GPU-compatible, making for very fast loudness normalization.
+
+        Parameters
+        ----------
+        db : typing.Union[torch.Tensor, np.ndarray, float], optional
+            Loudness to normalize to, by default -24.0
+
+        Returns
+        -------
+        AudioSignal
+            Normalized audio signal.
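+
+        Examples
+        --------
+        An illustrative sketch; the dummy signal and target loudness are
+        placeholders:
+
+        >>> import torch
+        >>> from audiotools import AudioSignal
+        >>> signal = AudioSignal(torch.randn(1, 1, 44100), 44100)
+        >>> signal = signal.normalize(-24)  # target -24 LUFS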
+        """
+        db = util.ensure_tensor(db).to(self.device)
+        ref_db = self.loudness()
+        gain = db - ref_db
+        gain = torch.exp(gain * self.GAIN_FACTOR)
+
+        self.audio_data = self.audio_data * gain[:, None, None]
+        return self
+
+    def volume_change(self, db: typing.Union[torch.Tensor, np.ndarray, float]):
+        """Change volume of signal by some amount, in dB.
+
+        Parameters
+        ----------
+        db : typing.Union[torch.Tensor, np.ndarray, float]
+            Amount to change volume by.
+
+        Returns
+        -------
+        AudioSignal
+            Signal at new volume.
+        """
+        db = util.ensure_tensor(db, ndim=1).to(self.device)
+        gain = torch.exp(db * self.GAIN_FACTOR)
+        self.audio_data = self.audio_data * gain[:, None, None]
+        return self
+
+    def _to_2d(self):
+        waveform = self.audio_data.reshape(-1, self.signal_length)
+        return waveform
+
+    def _to_3d(self, waveform):
+        return waveform.reshape(self.batch_size, self.num_channels, -1)
+
+    def pitch_shift(self, n_semitones: int, quick: bool = True):
+        """Pitch shift the signal. All items in the batch
+        get the same pitch shift.
+
+        Parameters
+        ----------
+        n_semitones : int
+            How many semitones to shift the signal by.
+        quick : bool, optional
+            Whether to use quick pitch shifting, by default True
+
+        Returns
+        -------
+        AudioSignal
+            Pitch shifted audio signal.
+        """
+        device = self.device
+        effects = [
+            ["pitch", str(n_semitones * 100)],
+            ["rate", str(self.sample_rate)],
+        ]
+        if quick:
+            effects[0].insert(1, "-q")
+
+        waveform = self._to_2d().cpu()
+        waveform, sample_rate = torchaudio.sox_effects.apply_effects_tensor(
+            waveform, self.sample_rate, effects, channels_first=True
+        )
+        self.sample_rate = sample_rate
+        self.audio_data = self._to_3d(waveform)
+        return self.to(device)
+
+    def time_stretch(self, factor: float, quick: bool = True):
+        """Time stretch the audio signal.
+
+        Parameters
+        ----------
+        factor : float
+            Factor by which to stretch the AudioSignal. Typically
+            between 0.8 and 1.2.
+        quick : bool, optional
+            Whether to use quick time stretching, by default True
+
+        Returns
+        -------
+        AudioSignal
+            Time-stretched AudioSignal.
+        """
+        device = self.device
+        effects = [
+            ["tempo", str(factor)],
+            ["rate", str(self.sample_rate)],
+        ]
+        if quick:
+            effects[0].insert(1, "-q")
+
+        waveform = self._to_2d().cpu()
+        waveform, sample_rate = torchaudio.sox_effects.apply_effects_tensor(
+            waveform, self.sample_rate, effects, channels_first=True
+        )
+        self.sample_rate = sample_rate
+        self.audio_data = self._to_3d(waveform)
+        return self.to(device)
+
+    def apply_codec(
+        self,
+        preset: str = None,
+        format: str = "wav",
+        encoding: str = None,
+        bits_per_sample: int = None,
+        compression: int = None,
+    ):  # pragma: no cover
+        """Applies an audio codec to the signal.
+
+        Parameters
+        ----------
+        preset : str, optional
+            One of the keys in ``self.CODEC_PRESETS``, by default None
+        format : str, optional
+            Format for audio codec, by default "wav"
+        encoding : str, optional
+            Encoding to use, by default None
+        bits_per_sample : int, optional
+            How many bits per sample, by default None
+        compression : int, optional
+            Compression amount of codec, by default None
+
+        Returns
+        -------
+        AudioSignal
+            AudioSignal with codec applied.
+
+        Raises
+        ------
+        ValueError
+            If preset is not in ``self.CODEC_PRESETS``, an error
+            is thrown.
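+
+        Examples
+        --------
+        An illustrative sketch using one of the built-in presets; the dummy
+        signal is a placeholder and the codec must be supported by the local
+        torchaudio build:
+
+        >>> import torch
+        >>> from audiotools import AudioSignal
+        >>> signal = AudioSignal(torch.randn(1, 1, 44100), 44100)
+        >>> signal = signal.apply_codec(preset="MP3")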
+        """
+        torchaudio_version_070 = "0.7" in torchaudio.__version__
+        if torchaudio_version_070:
+            return self
+
+        kwargs = {
+            "format": format,
+            "encoding": encoding,
+            "bits_per_sample": bits_per_sample,
+            "compression": compression,
+        }
+
+        if preset is not None:
+            if preset in self.CODEC_PRESETS:
+                kwargs = self.CODEC_PRESETS[preset]
+            else:
+                raise ValueError(
+                    f"Unknown preset: {preset}. "
+                    f"Known presets: {list(self.CODEC_PRESETS.keys())}"
+                )
+
+        waveform = self._to_2d()
+        if kwargs["format"] in ["vorbis", "mp3", "ogg", "amr-nb"]:
+            # Apply it in a for loop
+            augmented = torch.cat(
+                [
+                    torchaudio.functional.apply_codec(
+                        waveform[i][None, :], self.sample_rate, **kwargs
+                    )
+                    for i in range(waveform.shape[0])
+                ],
+                dim=0,
+            )
+        else:
+            augmented = torchaudio.functional.apply_codec(
+                waveform, self.sample_rate, **kwargs
+            )
+        augmented = self._to_3d(augmented)
+
+        self.audio_data = augmented
+        return self
+
+    def mel_filterbank(self, n_bands: int):
+        """Breaks signal into mel bands.
+
+        Parameters
+        ----------
+        n_bands : int
+            Number of mel bands to use.
+
+        Returns
+        -------
+        torch.Tensor
+            Mel-filtered bands, with last axis being the band index.
+        """
+        filterbank = (
+            julius.SplitBands(self.sample_rate, n_bands).float().to(self.device)
+        )
+        filtered = filterbank(self.audio_data)
+        return filtered.permute(1, 2, 3, 0)
+
+    def equalizer(self, db: typing.Union[torch.Tensor, np.ndarray]):
+        """Applies a mel-spaced equalizer to the audio signal.
+
+        Parameters
+        ----------
+        db : typing.Union[torch.Tensor, np.ndarray]
+            EQ curve to apply.
+
+        Returns
+        -------
+        AudioSignal
+            AudioSignal with equalization applied.
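+
+        Examples
+        --------
+        An illustrative sketch; the dummy signal and the random 6-band curve
+        are placeholders:
+
+        >>> import torch
+        >>> from audiotools import AudioSignal
+        >>> signal = AudioSignal(torch.randn(1, 1, 44100), 44100)
+        >>> db = -3 * torch.rand(6)  # one value per mel band
+        >>> signal = signal.equalizer(db)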
+        """
+        db = util.ensure_tensor(db)
+        n_bands = db.shape[-1]
+        fbank = self.mel_filterbank(n_bands)
+
+        # If there's a batch dimension, make sure it's the same.
+        if db.ndim == 2:
+            if db.shape[0] != 1:
+                assert db.shape[0] == fbank.shape[0]
+        else:
+            db = db.unsqueeze(0)
+
+        weights = (10**db).to(self.device).float()
+        fbank = fbank * weights[:, None, None, :]
+        eq_audio_data = fbank.sum(-1)
+        self.audio_data = eq_audio_data
+        return self
+
+    def clip_distortion(
+        self, clip_percentile: typing.Union[torch.Tensor, np.ndarray, float]
+    ):
+        """Clips the signal at a given percentile. The higher it is,
+        the lower the threshold for clipping.
+
+        Parameters
+        ----------
+        clip_percentile : typing.Union[torch.Tensor, np.ndarray, float]
+            Values are between 0.0 to 1.0. Typical values are 0.1 or below.
+
+        Returns
+        -------
+        AudioSignal
+            Audio signal with clipped audio data.
+        """
+        clip_percentile = util.ensure_tensor(clip_percentile, ndim=1)
+        min_thresh = torch.quantile(self.audio_data, clip_percentile / 2, dim=-1)
+        max_thresh = torch.quantile(self.audio_data, 1 - (clip_percentile / 2), dim=-1)
+
+        nc = self.audio_data.shape[1]
+        min_thresh = min_thresh[:, :nc, :]
+        max_thresh = max_thresh[:, :nc, :]
+
+        self.audio_data = self.audio_data.clamp(min_thresh, max_thresh)
+
+        return self
+
+    def quantization(
+        self, quantization_channels: typing.Union[torch.Tensor, np.ndarray, int]
+    ):
+        """Applies quantization to the input waveform.
+
+        Parameters
+        ----------
+        quantization_channels : typing.Union[torch.Tensor, np.ndarray, int]
+            Number of evenly spaced quantization channels to quantize
+            to.
+
+        Returns
+        -------
+        AudioSignal
+            Quantized AudioSignal.
+        """
+        quantization_channels = util.ensure_tensor(quantization_channels, ndim=3)
+
+        x = self.audio_data
+        x = (x + 1) / 2
+        x = x * quantization_channels
+        x = x.floor()
+        x = x / quantization_channels
+        x = 2 * x - 1
+
+        # Straight-through estimator: the output is quantized, but gradients
+        # flow through the original audio_data.
+        residual = (self.audio_data - x).detach()
+        self.audio_data = self.audio_data - residual
+        return self
+
+    def mulaw_quantization(
+        self, quantization_channels: typing.Union[torch.Tensor, np.ndarray, int]
+    ):
+        """Applies mu-law quantization to the input waveform.
+
+        Parameters
+        ----------
+        quantization_channels : typing.Union[torch.Tensor, np.ndarray, int]
+            Number of mu-law spaced quantization channels to quantize
+            to.
+
+        Returns
+        -------
+        AudioSignal
+            Quantized AudioSignal.
+        """
+        mu = quantization_channels - 1.0
+        mu = util.ensure_tensor(mu, ndim=3)
+
+        x = self.audio_data
+
+        # quantize
+        x = torch.sign(x) * torch.log1p(mu * torch.abs(x)) / torch.log1p(mu)
+        x = ((x + 1) / 2 * mu + 0.5).to(torch.int64)
+
+        # unquantize
+        x = (x / mu) * 2 - 1.0
+        x = torch.sign(x) * (torch.exp(torch.abs(x) * torch.log1p(mu)) - 1.0) / mu
+
+        residual = (self.audio_data - x).detach()
+        self.audio_data = self.audio_data - residual
+        return self
+
+    def __matmul__(self, other):
+        return self.convolve(other)
+
+
+class ImpulseResponseMixin:
+    """These functions are generally only used with AudioSignals that are derived
+    from impulse responses, not other sources like music or speech. These methods
+    are used to replicate the data augmentation described in [1].
+
+    1.  Bryan, Nicholas J. "Impulse response data augmentation and deep
+        neural networks for blind room acoustic parameter estimation."
+        ICASSP 2020-2020 IEEE International Conference on Acoustics,
+        Speech and Signal Processing (ICASSP). IEEE, 2020.
+    """
+
+    def decompose_ir(self):
+        """Decomposes an impulse response into early and late
+        field responses.
+        """
+        # Equations 1 and 2
+        # -----------------
+        # Breaking up into early
+        # response + late field response.
+
+        td = torch.argmax(self.audio_data, dim=-1, keepdim=True)
+        t0 = int(self.sample_rate * 0.0025)
+
+        idx = torch.arange(self.audio_data.shape[-1], device=self.device)[None, None, :]
+        idx = idx.expand(self.batch_size, -1, -1)
+        early_idx = (idx >= td - t0) * (idx <= td + t0)
+
+        early_response = torch.zeros_like(self.audio_data, device=self.device)
+        early_response[early_idx] = self.audio_data[early_idx]
+
+        late_idx = ~early_idx
+        late_field = torch.zeros_like(self.audio_data, device=self.device)
+        late_field[late_idx] = self.audio_data[late_idx]
+
+        # Equation 4
+        # ----------
+        # Decompose early response into windowed
+        # direct path and windowed residual.
+
+        window = torch.zeros_like(self.audio_data, device=self.device)
+        for idx in range(self.batch_size):
+            window_idx = early_idx[idx, 0].nonzero()
+            window[idx, ..., window_idx] = self.get_window(
+                "hann", window_idx.shape[-1], self.device
+            )
+        return early_response, late_field, window
+
+    def measure_drr(self):
+        """Measures the direct-to-reverberant ratio of the impulse
+        response.
+
+        Returns
+        -------
+        float
+            Direct-to-reverberant ratio
+        """
+        early_response, late_field, _ = self.decompose_ir()
+        num = (early_response**2).sum(dim=-1)
+        den = (late_field**2).sum(dim=-1)
+        drr = 10 * torch.log10(num / den)
+        return drr
+
+    @staticmethod
+    def solve_alpha(early_response, late_field, wd, target_drr):
+        """Used to solve for the alpha value, which is used
+        to alter the drr.
+        """
+        # Equation 5
+        # ----------
+        # Apply the good ol' quadratic formula.
+
+        wd_sq = wd**2
+        wd_sq_1 = (1 - wd) ** 2
+        e_sq = early_response**2
+        l_sq = late_field**2
+        a = (wd_sq * e_sq).sum(dim=-1)
+        b = (2 * (1 - wd) * wd * e_sq).sum(dim=-1)
+        c = (wd_sq_1 * e_sq).sum(dim=-1) - torch.pow(10, target_drr / 10) * l_sq.sum(
+            dim=-1
+        )
+
+        expr = ((b**2) - 4 * a * c).sqrt()
+        alpha = torch.maximum(
+            (-b - expr) / (2 * a),
+            (-b + expr) / (2 * a),
+        )
+        return alpha
+
+    def alter_drr(self, drr: typing.Union[torch.Tensor, np.ndarray, float]):
+        """Alters the direct-to-reverberant ratio of the impulse response.
+
+        Parameters
+        ----------
+        drr : typing.Union[torch.Tensor, np.ndarray, float]
+            Direct-to-reverberant ratio that the impulse response will be
+            altered to.
+
+        Returns
+        -------
+        AudioSignal
+            Altered impulse response.
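+
+        Examples
+        --------
+        An illustrative sketch; the random "impulse response" and the target
+        DRR value are placeholders:
+
+        >>> import torch
+        >>> from audiotools import AudioSignal
+        >>> ir = AudioSignal(torch.randn(1, 1, 44100), 44100)
+        >>> ir = ir.alter_drr(15)  # push the DRR towards 15 dB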
+        """
+        drr = util.ensure_tensor(drr, 2, self.batch_size).to(self.device)
+
+        early_response, late_field, window = self.decompose_ir()
+        alpha = self.solve_alpha(early_response, late_field, window, drr)
+        min_alpha = (
+            late_field.abs().max(dim=-1)[0] / early_response.abs().max(dim=-1)[0]
+        )
+        alpha = torch.maximum(alpha, min_alpha)[..., None]
+
+        aug_ir_data = (
+            alpha * window * early_response
+            + ((1 - window) * early_response)
+            + late_field
+        )
+        self.audio_data = aug_ir_data
+        self.ensure_max_of_audio()
+        return self
diff --git a/audiotools/core/ffmpeg.py b/audiotools/core/ffmpeg.py
new file mode 100644
index 0000000000000000000000000000000000000000..baf27ccca25ffbf9e915aa870ca8797c37187cdd
--- /dev/null
+++ b/audiotools/core/ffmpeg.py
@@ -0,0 +1,204 @@
+import json
+import shlex
+import subprocess
+import tempfile
+from pathlib import Path
+from typing import Tuple
+
+import ffmpy
+import numpy as np
+import torch
+
+
+def r128stats(filepath: str, quiet: bool):
+    """Takes a path to an audio file, returns a dict with the loudness
+    stats computed by the ffmpeg ebur128 filter.
+
+    Parameters
+    ----------
+    filepath : str
+        Path to compute loudness stats on.
+    quiet : bool
+        Whether to show FFMPEG output during computation.
+
+    Returns
+    -------
+    dict
+        Dictionary containing loudness stats.
+    """
+    ffargs = [
+        "ffmpeg",
+        "-nostats",
+        "-i",
+        filepath,
+        "-filter_complex",
+        "ebur128",
+        "-f",
+        "null",
+        "-",
+    ]
+    if quiet:
+        ffargs += ["-hide_banner"]
+    proc = subprocess.Popen(ffargs, stderr=subprocess.PIPE, universal_newlines=True)
+    stats = proc.communicate()[1]
+    summary_index = stats.rfind("Summary:")
+
+    summary_list = stats[summary_index:].split()
+    i_lufs = float(summary_list[summary_list.index("I:") + 1])
+    i_thresh = float(summary_list[summary_list.index("I:") + 4])
+    lra = float(summary_list[summary_list.index("LRA:") + 1])
+    lra_thresh = float(summary_list[summary_list.index("LRA:") + 4])
+    lra_low = float(summary_list[summary_list.index("low:") + 1])
+    lra_high = float(summary_list[summary_list.index("high:") + 1])
+    stats_dict = {
+        "I": i_lufs,
+        "I Threshold": i_thresh,
+        "LRA": lra,
+        "LRA Threshold": lra_thresh,
+        "LRA Low": lra_low,
+        "LRA High": lra_high,
+    }
+
+    return stats_dict
+
+
+def ffprobe_offset_and_codec(path: str) -> Tuple[float, str]:
+    """Given a path to a file, returns the start time offset and codec of
+    the first audio stream.
+    """
+    ff = ffmpy.FFprobe(
+        inputs={path: None},
+        global_options="-show_entries format=start_time:stream=duration,start_time,codec_type,codec_name,start_pts,time_base -of json -v quiet",
+    )
+    streams = json.loads(ff.run(stdout=subprocess.PIPE)[0])["streams"]
+    seconds_offset = 0.0
+    codec = None
+
+    # Get the offset and codec of the first audio stream we find
+    # and return its start time, if it has one.
+    for stream in streams:
+        if stream["codec_type"] == "audio":
+            seconds_offset = stream.get("start_time", 0.0)
+            codec = stream.get("codec_name")
+            break
+    return float(seconds_offset), codec
+
+
+class FFMPEGMixin:
+    _loudness = None
+
+    def ffmpeg_loudness(self, quiet: bool = True):
+        """Computes loudness of audio file using FFMPEG.
+
+        Parameters
+        ----------
+        quiet : bool, optional
+            Whether to show FFMPEG output during computation,
+            by default True
+
+        Returns
+        -------
+        torch.Tensor
+            Loudness of every item in the batch, computed via
+            FFMPEG.
+        """
+        loudness = []
+
+        with tempfile.NamedTemporaryFile(suffix=".wav") as f:
+            for i in range(self.batch_size):
+                self[i].write(f.name)
+                loudness_stats = r128stats(f.name, quiet=quiet)
+                loudness.append(loudness_stats["I"])
+
+        self._loudness = torch.from_numpy(np.array(loudness)).float()
+        return self.loudness()
+
+    def ffmpeg_resample(self, sample_rate: int, quiet: bool = True):
+        """Resamples AudioSignal using FFMPEG. More memory-efficient
+        than using julius.resample for long audio files.
+
+        Parameters
+        ----------
+        sample_rate : int
+            Sample rate to resample to.
+        quiet : bool, optional
+            Whether to show FFMPEG output during computation,
+            by default True
+
+        Returns
+        -------
+        AudioSignal
+            Resampled AudioSignal.
+        """
+        from audiotools import AudioSignal
+
+        if sample_rate == self.sample_rate:
+            return self
+
+        with tempfile.NamedTemporaryFile(suffix=".wav") as f:
+            self.write(f.name)
+            f_out = f.name.replace("wav", "rs.wav")
+            command = f"ffmpeg -i {f.name} -ar {sample_rate} {f_out}"
+            if quiet:
+                command += " -hide_banner -loglevel error"
+            subprocess.check_call(shlex.split(command))
+            resampled = AudioSignal(f_out)
+            Path.unlink(Path(f_out))
+        return resampled
+
+    @classmethod
+    def load_from_file_with_ffmpeg(cls, audio_path: str, quiet: bool = True, **kwargs):
+        """Loads AudioSignal object after decoding it to a wav file using FFMPEG.
+        Useful for loading audio that isn't covered by librosa's loading mechanism. Also
+        useful for loading mp3 files, without any offset.
+
+        Parameters
+        ----------
+        audio_path : str
+            Path to load AudioSignal from.
+        quiet : bool, optional
+            Whether to show FFMPEG output during computation,
+            by default True
+
+        Returns
+        -------
+        AudioSignal
+            AudioSignal loaded from file with FFMPEG.
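+
+        Examples
+        --------
+        An illustrative sketch; ``clip.mp4`` is a placeholder path and FFMPEG
+        must be available on the system path:
+
+        >>> from audiotools import AudioSignal
+        >>> signal = AudioSignal.load_from_file_with_ffmpeg("clip.mp4")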
+        """
+        audio_path = str(audio_path)
+        with tempfile.TemporaryDirectory() as d:
+            wav_file = str(Path(d) / "extracted.wav")
+            padded_wav = str(Path(d) / "padded.wav")
+
+            global_options = "-y"
+            if quiet:
+                global_options += " -loglevel error"
+
+            ff = ffmpy.FFmpeg(
+                inputs={audio_path: None},
+                outputs={wav_file: None},
+                global_options=global_options,
+            )
+            ff.run()
+
+            # We pad the file using the start time offset in case it's an audio
+            # stream starting at some offset in a video container.
+            pad, codec = ffprobe_offset_and_codec(audio_path)
+
+            # For mp3s, don't pad files with discrepancies less than 0.027s -
+            # it's likely due to codec latency. The amount of latency introduced
+            # by mp3 is 1152 samples, which is about 0.026 s at 44.1 kHz. So we
+            # set the threshold here slightly above that.
+            # Source: https://lame.sourceforge.io/tech-FAQ.txt.
+            if codec == "mp3" and pad < 0.027:
+                pad = 0.0
+            ff = ffmpy.FFmpeg(
+                inputs={wav_file: None},
+                outputs={padded_wav: f"-af 'adelay={pad*1000}:all=true'"},
+                global_options=global_options,
+            )
+            ff.run()
+
+            signal = cls(padded_wav, **kwargs)
+
+        return signal
diff --git a/audiotools/core/loudness.py b/audiotools/core/loudness.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb3ee2675d7cb71f4c00106b0c1e901b8e51b842
--- /dev/null
+++ b/audiotools/core/loudness.py
@@ -0,0 +1,320 @@
+import copy
+
+import julius
+import numpy as np
+import scipy
+import torch
+import torch.nn.functional as F
+import torchaudio
+
+
+class Meter(torch.nn.Module):
+    """Tensorized version of pyloudnorm.Meter. Works with batched audio tensors.
+
+    Parameters
+    ----------
+    rate : int
+        Sample rate of audio.
+    filter_class : str, optional
+        Class of weighting filter used. One of 'K-weighting' (default),
+        'Fenton/Lee 1', 'Fenton/Lee 2', or 'Dash et al.',
+        by default "K-weighting"
+    block_size : float, optional
+        Gating block size in seconds, by default 0.400
+    zeros : int, optional
+        Number of zeros to use in the FIR approximation of the
+        IIR filters, by default 512
+    use_fir : bool, optional
+        Whether to use the FIR approximation or the exact IIR formulation.
+        When computing on GPU, the FIR approximation is always used, as it is
+        much faster, by default False
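+
+    Examples
+    --------
+    An illustrative sketch; the dummy one-second stereo batch below is a
+    placeholder:
+
+    >>> import torch
+    >>> from audiotools import Meter
+    >>> meter = Meter(44100)
+    >>> audio = torch.randn(1, 44100, 2)  # (nb, nt, nch)
+    >>> lufs = meter.integrated_loudness(audio)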
+    """
+
+    def __init__(
+        self,
+        rate: int,
+        filter_class: str = "K-weighting",
+        block_size: float = 0.400,
+        zeros: int = 512,
+        use_fir: bool = False,
+    ):
+        super().__init__()
+
+        self.rate = rate
+        self.filter_class = filter_class
+        self.block_size = block_size
+        self.use_fir = use_fir
+
+        G = torch.from_numpy(np.array([1.0, 1.0, 1.0, 1.41, 1.41]))
+        self.register_buffer("G", G)
+
+        # Compute impulse responses so that filtering is fast via
+        # a convolution at runtime, on GPU, unlike lfilter.
+        impulse = np.zeros((zeros,))
+        impulse[..., 0] = 1.0
+
+        firs = np.zeros((len(self._filters), 1, zeros))
+        passband_gain = torch.zeros(len(self._filters))
+
+        for i, (_, filter_stage) in enumerate(self._filters.items()):
+            firs[i] = scipy.signal.lfilter(filter_stage.b, filter_stage.a, impulse)
+            passband_gain[i] = filter_stage.passband_gain
+
+        firs = torch.from_numpy(firs[..., ::-1].copy()).float()
+
+        self.register_buffer("firs", firs)
+        self.register_buffer("passband_gain", passband_gain)
+
+    def apply_filter_gpu(self, data: torch.Tensor):
+        """Performs FIR approximation of loudness computation.
+
+        Parameters
+        ----------
+        data : torch.Tensor
+            Audio data of shape (nb, nt, nch).
+
+        Returns
+        -------
+        torch.Tensor
+            Filtered audio data.
+        """
+        # Data is of shape (nb, nt, nch).
+        # Permute and reshape to (nb * nch, 1, nt).
+        nb, nt, nch = data.shape
+        data = data.permute(0, 2, 1)
+        data = data.reshape(nb * nch, 1, nt)
+
+        # Apply padding
+        pad_length = self.firs.shape[-1]
+
+        # Apply filtering in sequence
+        for i in range(self.firs.shape[0]):
+            data = F.pad(data, (pad_length, pad_length))
+            data = julius.fftconv.fft_conv1d(data, self.firs[i, None, ...])
+            data = self.passband_gain[i] * data
+            data = data[..., 1 : nt + 1]
+
+        data = data.permute(0, 2, 1)
+        data = data[:, :nt, :]
+        return data
+
+    def apply_filter_cpu(self, data: torch.Tensor):
+        """Performs IIR formulation of loudness computation.
+
+        Parameters
+        ----------
+        data : torch.Tensor
+            Audio data of shape (nb, nt, nch).
+
+        Returns
+        -------
+        torch.Tensor
+            Filtered audio data.
+        """
+        for _, filter_stage in self._filters.items():
+            passband_gain = filter_stage.passband_gain
+
+            a_coeffs = torch.from_numpy(filter_stage.a).float().to(data.device)
+            b_coeffs = torch.from_numpy(filter_stage.b).float().to(data.device)
+
+            _data = data.permute(0, 2, 1)
+            filtered = torchaudio.functional.lfilter(
+                _data, a_coeffs, b_coeffs, clamp=False
+            )
+            data = passband_gain * filtered.permute(0, 2, 1)
+        return data
+
+    def apply_filter(self, data: torch.Tensor):
+        """Applies filter on either CPU or GPU, depending
+        on if the audio is on GPU or is on CPU, or if
+        ``self.use_fir`` is True.
+
+        Parameters
+        ----------
+        data : torch.Tensor
+            Audio data of shape (nb, nt, nch).
+
+        Returns
+        -------
+        torch.Tensor
+            Filtered audio data.
+        """
+        if data.is_cuda or self.use_fir:
+            data = self.apply_filter_gpu(data)
+        else:
+            data = self.apply_filter_cpu(data)
+        return data
+
+    def forward(self, data: torch.Tensor):
+        """Computes integrated loudness of data.
+
+        Parameters
+        ----------
+        data : torch.Tensor
+            Audio data of shape (nb, nt, nch).
+
+        Returns
+        -------
+        torch.Tensor
+            Integrated loudness (LUFS) of each item in the batch.
+        """
+        return self.integrated_loudness(data)
+
+    def _unfold(self, input_data):
+        T_g = self.block_size
+        overlap = 0.75  # overlap of 75% of the block duration
+        step = 1.0 - overlap  # step size by percentage
+
+        kernel_size = int(T_g * self.rate)
+        stride = int(T_g * self.rate * step)
+        unfolded = julius.core.unfold(input_data.permute(0, 2, 1), kernel_size, stride)
+        unfolded = unfolded.transpose(-1, -2)
+
+        return unfolded
+
+    def integrated_loudness(self, data: torch.Tensor):
+        """Computes integrated loudness of data.
+
+        Parameters
+        ----------
+        data : torch.Tensor
+            Audio data of shape (nb, nt, nch).
+
+        Returns
+        -------
+        torch.Tensor
+            Integrated loudness (LUFS) of each item in the batch.
+        """
+        if not torch.is_tensor(data):
+            data = torch.from_numpy(data).float()
+        else:
+            data = data.float()
+
+        input_data = copy.copy(data)
+        # Data always has a batch and channel dimension.
+        # Is of shape (nb, nt, nch)
+        if input_data.ndim < 2:
+            input_data = input_data.unsqueeze(-1)
+        if input_data.ndim < 3:
+            input_data = input_data.unsqueeze(0)
+
+        nb, nt, nch = input_data.shape
+
+        # Apply frequency weighting filters - account
+        # for the acoustic response of the head and auditory system
+        input_data = self.apply_filter(input_data)
+
+        G = self.G  # channel gains
+        T_g = self.block_size  # 400 ms gating block standard
+        Gamma_a = -70.0  # -70 LKFS = absolute loudness threshold
+
+        unfolded = self._unfold(input_data)
+
+        z = (1.0 / (T_g * self.rate)) * unfolded.square().sum(2)
+        l = -0.691 + 10.0 * torch.log10((G[None, :nch, None] * z).sum(1, keepdim=True))
+        l = l.expand_as(z)
+
+        # find gating block indices above absolute threshold
+        z_avg_gated = z
+        z_avg_gated[l <= Gamma_a] = 0
+        masked = l > Gamma_a
+        z_avg_gated = z_avg_gated.sum(2) / masked.sum(2)
+
+        # calculate the relative threshold value (see eq. 6)
+        Gamma_r = (
+            -0.691 + 10.0 * torch.log10((z_avg_gated * G[None, :nch]).sum(-1)) - 10.0
+        )
+        Gamma_r = Gamma_r[:, None, None]
+        Gamma_r = Gamma_r.expand(nb, nch, l.shape[-1])
+
+        # find gating block indices above relative and absolute thresholds  (end of eq. 7)
+        z_avg_gated = z
+        z_avg_gated[l <= Gamma_a] = 0
+        z_avg_gated[l <= Gamma_r] = 0
+        masked = (l > Gamma_a) * (l > Gamma_r)
+        z_avg_gated = z_avg_gated.sum(2) / masked.sum(2)
+
+        # # Cannot use nan_to_num (pytorch 1.8 does not come with GCP-supported cuda version)
+        # z_avg_gated = torch.nan_to_num(z_avg_gated)
+        z_avg_gated = torch.where(
+            z_avg_gated.isnan(), torch.zeros_like(z_avg_gated), z_avg_gated
+        )
+        z_avg_gated[z_avg_gated == float("inf")] = float(np.finfo(np.float32).max)
+        z_avg_gated[z_avg_gated == -float("inf")] = float(np.finfo(np.float32).min)
+
+        LUFS = -0.691 + 10.0 * torch.log10((G[None, :nch] * z_avg_gated).sum(1))
+        return LUFS.float()
+
+    @property
+    def filter_class(self):
+        return self._filter_class
+
+    @filter_class.setter
+    def filter_class(self, value):
+        from pyloudnorm import Meter
+
+        meter = Meter(self.rate)
+        meter.filter_class = value
+        self._filter_class = value
+        self._filters = meter._filters
+
+
+class LoudnessMixin:
+    _loudness = None
+    MIN_LOUDNESS = -70
+    """Minimum loudness possible."""
+
+    def loudness(
+        self, filter_class: str = "K-weighting", block_size: float = 0.400, **kwargs
+    ):
+        """Calculates loudness using an implementation of ITU-R BS.1770-4.
+        Allows control over gating block size and frequency weighting filters for
+        additional control. Measure the integrated gated loudness of a signal.
+
+        API is derived from PyLoudnorm, but this implementation is ported to PyTorch
+        and is tensorized across batches. When on GPU, an FIR approximation of the IIR
+        filters is used to compute loudness for speed.
+
+        Uses the weighting filters and block size defined by the meter.
+        The integrated loudness is measured based upon the gating algorithm
+        defined in the ITU-R BS.1770-4 specification.
+
+        Parameters
+        ----------
+        filter_class : str, optional
+            Class of weighting filter used: 'K-weighting',
+            'Fenton/Lee 1', 'Fenton/Lee 2', or 'Dash et al.',
+            by default "K-weighting"
+        block_size : float, optional
+            Gating block size in seconds, by default 0.400
+        kwargs : dict, optional
+            Keyword arguments to :py:class:`audiotools.core.loudness.Meter`.
+
+        Returns
+        -------
+        torch.Tensor
+            Loudness of audio data.
+        """
+        if self._loudness is not None:
+            return self._loudness.to(self.device)
+        original_length = self.signal_length
+        if self.signal_duration < 0.5:
+            pad_len = int((0.5 - self.signal_duration) * self.sample_rate)
+            self.zero_pad(0, pad_len)
+
+        # create BS.1770 meter
+        meter = Meter(
+            self.sample_rate, filter_class=filter_class, block_size=block_size, **kwargs
+        )
+        meter = meter.to(self.device)
+        # measure loudness
+        loudness = meter.integrated_loudness(self.audio_data.permute(0, 2, 1))
+        self.truncate_samples(original_length)
+        min_loudness = (
+            torch.ones_like(loudness, device=loudness.device) * self.MIN_LOUDNESS
+        )
+        self._loudness = torch.maximum(loudness, min_loudness)
+
+        return self._loudness.to(self.device)
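+
+
+if __name__ == "__main__":  # pragma: no cover
+    # Illustrative sketch (not part of the library API): compute integrated
+    # loudness through the AudioSignal convenience method and through a
+    # Meter directly. The audio path below is the test file used elsewhere
+    # in this repository.
+    from audiotools import AudioSignal
+
+    signal = AudioSignal("tests/audio/spk/f10_script4_produced.mp3")
+
+    # Convenience path: pads short signals to at least 0.5 s and caches the result.
+    print(signal.loudness())
+
+    # Direct path: build a BS.1770 meter and hand it (nb, nt, nch) audio.
+    meter = Meter(signal.sample_rate, filter_class="K-weighting", block_size=0.400)
+    print(meter.integrated_loudness(signal.audio_data.permute(0, 2, 1)))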
diff --git a/audiotools/core/playback.py b/audiotools/core/playback.py
new file mode 100644
index 0000000000000000000000000000000000000000..5d0f21aaa392494f35305c0084c05b87667ea14d
--- /dev/null
+++ b/audiotools/core/playback.py
@@ -0,0 +1,252 @@
+"""
+These are utilities that allow one to embed an AudioSignal
+as a playable object in a Jupyter notebook, or to play audio from
+the terminal, etc.
+"""  # fmt: skip
+import base64
+import io
+import random
+import string
+import subprocess
+from tempfile import NamedTemporaryFile
+
+import importlib_resources as pkg_resources
+
+from . import templates
+from .util import _close_temp_files
+from .util import format_figure
+
+headers = pkg_resources.files(templates).joinpath("headers.html").read_text()
+widget = pkg_resources.files(templates).joinpath("widget.html").read_text()
+
+DEFAULT_EXTENSION = ".wav"
+
+
+def _check_imports():  # pragma: no cover
+    try:
+        import ffmpy
+    except ImportError:
+        ffmpy = False
+
+    try:
+        import IPython
+    except ImportError:
+        raise ImportError("IPython must be installed in order to use this function!")
+    return ffmpy, IPython
+
+
+class PlayMixin:
+    def embed(self, ext: str = None, display: bool = True, return_html: bool = False):
+        """Embeds audio as a playable audio embed in a notebook, or HTML
+        document, etc.
+
+        Parameters
+        ----------
+        ext : str, optional
+            Extension to use when saving the audio, by default ".wav"
+        display : bool, optional
+            This controls whether or not to display the audio when called. This
+            is used when the embed is the last line in a Jupyter cell, to prevent
+            the audio from being embedded twice, by default True
+        return_html : bool, optional
+            Whether to return the data wrapped in an HTML audio element, by default False
+
+        Returns
+        -------
+        IPython.display.Audio or str
+            Either the audio element for display, or the HTML string wrapping it
+            when ``return_html`` is True.
+        """
+        if ext is None:
+            ext = DEFAULT_EXTENSION
+        ext = f".{ext}" if not ext.startswith(".") else ext
+        ffmpy, IPython = _check_imports()
+        sr = self.sample_rate
+        tmpfiles = []
+
+        with _close_temp_files(tmpfiles):
+            tmp_wav = NamedTemporaryFile(mode="w+", suffix=".wav", delete=False)
+            tmpfiles.append(tmp_wav)
+            self.write(tmp_wav.name)
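+            # If a non-wav extension was requested and ffmpy is available,
+            # transcode the temporary wav file (the flags below target mp3);
+            # otherwise the wav file is embedded directly.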
+            if ext != ".wav" and ffmpy:
+                tmp_converted = NamedTemporaryFile(mode="w+", suffix=ext, delete=False)
+                tmpfiles.append(tmp_converted)
+                ff = ffmpy.FFmpeg(
+                    inputs={tmp_wav.name: None},
+                    outputs={
+                        tmp_converted.name: "-write_xing 0 -codec:a libmp3lame -b:a 128k -y -hide_banner -loglevel error"
+                    },
+                )
+                ff.run()
+            else:
+                tmp_converted = tmp_wav
+
+            audio_element = IPython.display.Audio(data=tmp_converted.name, rate=sr)
+            if display:
+                IPython.display.display(audio_element)
+
+        if return_html:
+            audio_element = (
+                f"<audio "
+                f"  controls "
+                f"  src='{audio_element.src_attr()}'> "
+                f"</audio> "
+            )
+        return audio_element
+
+    def widget(
+        self,
+        title: str = None,
+        ext: str = ".wav",
+        add_headers: bool = True,
+        player_width: str = "100%",
+        margin: str = "10px",
+        plot_fn: str = "specshow",
+        return_html: bool = False,
+        **kwargs,
+    ):
+        """Creates a playable widget with spectrogram. Inspired (heavily) by
+        https://sjvasquez.github.io/blog/melnet/.
+
+        Parameters
+        ----------
+        title : str, optional
+            Title of plot, placed in upper right of top-most axis.
+        ext : str, optional
+            Extension to use for the embedded audio, by default ".wav"
+        add_headers : bool, optional
+            Whether or not to add headers (True for the first embed on a page,
+            False for subsequent embeds), by default True
+        player_width : str, optional
+            Width of the player, as a string in a CSS rule, by default "100%"
+        margin : str, optional
+            Margin on all sides of player, by default "10px"
+        plot_fn : str or callable, optional
+            Plotting function to use, or the name of a plotting method on this
+            signal, by default "specshow" (i.e. ``self.specshow``).
+        return_html : bool, optional
+            Whether to return the data wrapped in an HTML audio element, by default False
+        kwargs : dict, optional
+            Keyword arguments to plot_fn (by default self.specshow).
+
+        Returns
+        -------
+        HTML
+            HTML object.
+        """
+        import matplotlib.pyplot as plt
+
+        def _save_fig_to_tag():
+            buffer = io.BytesIO()
+
+            plt.savefig(buffer, bbox_inches="tight", pad_inches=0)
+            plt.close()
+
+            buffer.seek(0)
+            data_uri = base64.b64encode(buffer.read()).decode("ascii")
+            tag = "data:image/png;base64,{0}".format(data_uri)
+
+            return tag
+
+        _, IPython = _check_imports()
+
+        header_html = ""
+
+        if add_headers:
+            header_html = headers.replace("PLAYER_WIDTH", str(player_width))
+            header_html = header_html.replace("MARGIN", str(margin))
+            IPython.display.display(IPython.display.HTML(header_html))
+
+        widget_html = widget
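+        # plot_fn may be passed as the name of a plotting method on this
+        # signal (e.g. "specshow", "waveplot"); resolve strings to the
+        # bound method and forward the title to it.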
+        if isinstance(plot_fn, str):
+            plot_fn = getattr(self, plot_fn)
+            kwargs["title"] = title
+        plot_fn(**kwargs)
+
+        fig = plt.gcf()
+        pixels = fig.get_size_inches() * fig.dpi
+
+        tag = _save_fig_to_tag()
+
+        # Make the source image for the levels
+        self.specshow()
+        format_figure((12, 1.5))
+        levels_tag = _save_fig_to_tag()
+
+        player_id = "".join(random.choice(string.ascii_uppercase) for _ in range(10))
+
+        audio_elem = self.embed(ext=ext, display=False)
+        widget_html = widget_html.replace("AUDIO_SRC", audio_elem.src_attr())
+        widget_html = widget_html.replace("IMAGE_SRC", tag)
+        widget_html = widget_html.replace("LEVELS_SRC", levels_tag)
+        widget_html = widget_html.replace("PLAYER_ID", player_id)
+
+        # Calculate width/height of figure based on figure size.
+        widget_html = widget_html.replace("PADDING_AMOUNT", f"{int(pixels[1])}px")
+        widget_html = widget_html.replace("MAX_WIDTH", f"{int(pixels[0])}px")
+
+        IPython.display.display(IPython.display.HTML(widget_html))
+
+        if return_html:
+            html = header_html if add_headers else ""
+            html += widget_html
+            return html
+
+    def play(self):
+        """
+        Plays an audio signal if ffplay from the ffmpeg suite of tools is installed.
+        Otherwise, will fail. The audio signal is written to a temporary file
+        and then played with ffplay.
+        """
+        tmpfiles = []
+        with _close_temp_files(tmpfiles):
+            tmp_wav = NamedTemporaryFile(suffix=".wav", delete=False)
+            tmpfiles.append(tmp_wav)
+            self.write(tmp_wav.name)
+            print(self)
+            subprocess.call(
+                [
+                    "ffplay",
+                    "-nodisp",
+                    "-autoexit",
+                    "-hide_banner",
+                    "-loglevel",
+                    "error",
+                    tmp_wav.name,
+                ]
+            )
+        return self
+
+
+if __name__ == "__main__":  # pragma: no cover
+    from audiotools import AudioSignal
+
+    signal = AudioSignal(
+        "tests/audio/spk/f10_script4_produced.mp3", offset=5, duration=5
+    )
+
+    wave_html = signal.widget(
+        "Waveform",
+        plot_fn="waveplot",
+        return_html=True,
+    )
+
+    spec_html = signal.widget("Spectrogram", return_html=True, add_headers=False)
+
+    combined_html = signal.widget(
+        "Waveform + spectrogram",
+        plot_fn="wavespec",
+        return_html=True,
+        add_headers=False,
+    )
+
+    signal.low_pass(8000)
+    lowpass_html = signal.widget(
+        "Lowpassed audio",
+        plot_fn="wavespec",
+        return_html=True,
+        add_headers=False,
+    )
+
+    with open("/tmp/index.html", "w") as f:
+        f.write(wave_html)
+        f.write(spec_html)
+        f.write(combined_html)
+        f.write(lowpass_html)
diff --git a/audiotools/core/templates/__init__.py b/audiotools/core/templates/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/audiotools/core/templates/headers.html b/audiotools/core/templates/headers.html
new file mode 100644
index 0000000000000000000000000000000000000000..9eaef4a94d575f7826608ad63dcc77fab13b7b19
--- /dev/null
+++ b/audiotools/core/templates/headers.html
@@ -0,0 +1,322 @@
+<style>
+    .player {
+        width: 100%;
+        /*border: 1px solid black;*/
+        margin: 10px;
+    }
+
+    .underlay img {
+        width: 100%;
+        height: 100%;
+    }
+
+    .spectrogram {
+        height: 0;
+        width: 100%;
+        position: relative;
+    }
+
+    .audio-controls {
+        width: 100%;
+        height: 54px;
+        display: flex;
+        /*border-top: 1px solid black;*/
+        /*background-color: rgb(241, 243, 244);*/
+        background-color: rgb(248, 248, 248);
+        background-color: rgb(253, 253, 254);
+        border: 1px solid rgb(205, 208, 211);
+        margin-top: 20px;
+        /*border: 1px solid black;*/
+        border-radius: 30px;
+
+    }
+
+    .play-img {
+        margin: auto;
+        height: 45%;
+        width: 45%;
+        display: block;
+    }
+
+    .download-img {
+        margin: auto;
+        height: 100%;
+        width: 100%;
+        display: block;
+    }
+
+    .pause-img {
+        margin: auto;
+        height: 45%;
+        width: 45%;
+        display: none
+    }
+
+    .playpause {
+        margin:11px 11px 11px 11px;
+        width: 32px;
+        min-width: 32px;
+        height: 32px;
+        /*background-color: rgb(241, 243, 244);*/
+        background-color: rgba(0, 0, 0, 0.0);
+        /*border-right: 1px solid black;*/
+        /*border: 1px solid red;*/
+        border-radius: 16px;
+        color: black;
+        transition: 0.25s;
+        box-sizing: border-box !important;
+    }
+
+    .download {
+        margin:11px 11px 11px 11px;
+        width: 32px;
+        min-width: 32px;
+        height: 32px;
+        /*background-color: rgb(241, 243, 244);*/
+        background-color: rgba(0, 0, 0, 0.0);
+        /*border-right: 1px solid black;*/
+        /*border: 1px solid red;*/
+        border-radius: 16px;
+        color: black;
+        transition: 0.25s;
+        box-sizing: border-box !important;
+    }
+
+    /*.playpause:disabled {
+        background-color: red;
+    }*/
+
+    .playpause:hover {
+        background-color: rgba(10, 20, 30, 0.03);
+    }
+
+    .playpause:focus {
+        outline:none;
+    }
+
+    .response {
+        padding:0px 20px 0px 0px;
+        width: calc(100% - 132px);
+        height: 100%;
+
+        /*border: 1px solid red;*/
+        /*border-bottom: 1px solid rgb(89, 89, 89);*/
+    }
+
+    .response-canvas {
+        height: 100%;
+        width: 100%;
+    }
+
+
+    .underlay {
+        height: 100%;
+        width: 100%;
+        position: absolute;
+        top: 0;
+        left: 0;
+    }
+
+    .overlay{
+        width: 0%;
+        height:100%;
+        top: 0;
+        right: 0px;
+
+        background:rgba(255, 255, 255, 0.15);
+        overflow:hidden;
+        position: absolute;
+        z-index: 10;
+        border-left: solid 1px rgba(0, 0, 0, 0.664);
+
+        position: absolute;
+        pointer-events: none;
+    }
+</style>
+
+<script>
+    !function(t){if("object"==typeof exports&&"undefined"!=typeof module)module.exports=t();else if("function"==typeof define&&define.amd)define([],t);else{("undefined"!=typeof window?window:"undefined"!=typeof global?global:"undefined"!=typeof self?self:this).pako=t()}}(function(){return function(){return function t(e,a,i){function n(s,o){if(!a[s]){if(!e[s]){var l="function"==typeof require&&require;if(!o&&l)return l(s,!0);if(r)return r(s,!0);var h=new Error("Cannot find module '"+s+"'");throw h.code="MODULE_NOT_FOUND",h}var d=a[s]={exports:{}};e[s][0].call(d.exports,function(t){return n(e[s][1][t]||t)},d,d.exports,t,e,a,i)}return a[s].exports}for(var r="function"==typeof require&&require,s=0;s<i.length;s++)n(i[s]);return n}}()({1:[function(t,e,a){"use strict";var i=t("./zlib/deflate"),n=t("./utils/common"),r=t("./utils/strings"),s=t("./zlib/messages"),o=t("./zlib/zstream"),l=Object.prototype.toString,h=0,d=-1,f=0,_=8;function u(t){if(!(this instanceof u))return new u(t);this.options=n.assign({level:d,method:_,chunkSize:16384,windowBits:15,memLevel:8,strategy:f,to:""},t||{});var e=this.options;e.raw&&e.windowBits>0?e.windowBits=-e.windowBits:e.gzip&&e.windowBits>0&&e.windowBits<16&&(e.windowBits+=16),this.err=0,this.msg="",this.ended=!1,this.chunks=[],this.strm=new o,this.strm.avail_out=0;var a=i.deflateInit2(this.strm,e.level,e.method,e.windowBits,e.memLevel,e.strategy);if(a!==h)throw new Error(s[a]);if(e.header&&i.deflateSetHeader(this.strm,e.header),e.dictionary){var c;if(c="string"==typeof e.dictionary?r.string2buf(e.dictionary):"[object ArrayBuffer]"===l.call(e.dictionary)?new Uint8Array(e.dictionary):e.dictionary,(a=i.deflateSetDictionary(this.strm,c))!==h)throw new Error(s[a]);this._dict_set=!0}}function c(t,e){var a=new u(e);if(a.push(t,!0),a.err)throw a.msg||s[a.err];return a.result}u.prototype.push=function(t,e){var a,s,o=this.strm,d=this.options.chunkSize;if(this.ended)return!1;s=e===~~e?e:!0===e?4:0,"string"==typeof t?o.input=r.string2buf(t):"[object ArrayBuffer]"===l.call(t)?o.input=new Uint8Array(t):o.input=t,o.next_in=0,o.avail_in=o.input.length;do{if(0===o.avail_out&&(o.output=new n.Buf8(d),o.next_out=0,o.avail_out=d),1!==(a=i.deflate(o,s))&&a!==h)return this.onEnd(a),this.ended=!0,!1;0!==o.avail_out&&(0!==o.avail_in||4!==s&&2!==s)||("string"===this.options.to?this.onData(r.buf2binstring(n.shrinkBuf(o.output,o.next_out))):this.onData(n.shrinkBuf(o.output,o.next_out)))}while((o.avail_in>0||0===o.avail_out)&&1!==a);return 4===s?(a=i.deflateEnd(this.strm),this.onEnd(a),this.ended=!0,a===h):2!==s||(this.onEnd(h),o.avail_out=0,!0)},u.prototype.onData=function(t){this.chunks.push(t)},u.prototype.onEnd=function(t){t===h&&("string"===this.options.to?this.result=this.chunks.join(""):this.result=n.flattenChunks(this.chunks)),this.chunks=[],this.err=t,this.msg=this.strm.msg},a.Deflate=u,a.deflate=c,a.deflateRaw=function(t,e){return(e=e||{}).raw=!0,c(t,e)},a.gzip=function(t,e){return(e=e||{}).gzip=!0,c(t,e)}},{"./utils/common":3,"./utils/strings":4,"./zlib/deflate":8,"./zlib/messages":13,"./zlib/zstream":15}],2:[function(t,e,a){"use strict";var i=t("./zlib/inflate"),n=t("./utils/common"),r=t("./utils/strings"),s=t("./zlib/constants"),o=t("./zlib/messages"),l=t("./zlib/zstream"),h=t("./zlib/gzheader"),d=Object.prototype.toString;function f(t){if(!(this instanceof f))return new f(t);this.options=n.assign({chunkSize:16384,windowBits:0,to:""},t||{});var 
e=this.options;e.raw&&e.windowBits>=0&&e.windowBits<16&&(e.windowBits=-e.windowBits,0===e.windowBits&&(e.windowBits=-15)),!(e.windowBits>=0&&e.windowBits<16)||t&&t.windowBits||(e.windowBits+=32),e.windowBits>15&&e.windowBits<48&&0==(15&e.windowBits)&&(e.windowBits|=15),this.err=0,this.msg="",this.ended=!1,this.chunks=[],this.strm=new l,this.strm.avail_out=0;var a=i.inflateInit2(this.strm,e.windowBits);if(a!==s.Z_OK)throw new Error(o[a]);if(this.header=new h,i.inflateGetHeader(this.strm,this.header),e.dictionary&&("string"==typeof e.dictionary?e.dictionary=r.string2buf(e.dictionary):"[object ArrayBuffer]"===d.call(e.dictionary)&&(e.dictionary=new Uint8Array(e.dictionary)),e.raw&&(a=i.inflateSetDictionary(this.strm,e.dictionary))!==s.Z_OK))throw new Error(o[a])}function _(t,e){var a=new f(e);if(a.push(t,!0),a.err)throw a.msg||o[a.err];return a.result}f.prototype.push=function(t,e){var a,o,l,h,f,_=this.strm,u=this.options.chunkSize,c=this.options.dictionary,b=!1;if(this.ended)return!1;o=e===~~e?e:!0===e?s.Z_FINISH:s.Z_NO_FLUSH,"string"==typeof t?_.input=r.binstring2buf(t):"[object ArrayBuffer]"===d.call(t)?_.input=new Uint8Array(t):_.input=t,_.next_in=0,_.avail_in=_.input.length;do{if(0===_.avail_out&&(_.output=new n.Buf8(u),_.next_out=0,_.avail_out=u),(a=i.inflate(_,s.Z_NO_FLUSH))===s.Z_NEED_DICT&&c&&(a=i.inflateSetDictionary(this.strm,c)),a===s.Z_BUF_ERROR&&!0===b&&(a=s.Z_OK,b=!1),a!==s.Z_STREAM_END&&a!==s.Z_OK)return this.onEnd(a),this.ended=!0,!1;_.next_out&&(0!==_.avail_out&&a!==s.Z_STREAM_END&&(0!==_.avail_in||o!==s.Z_FINISH&&o!==s.Z_SYNC_FLUSH)||("string"===this.options.to?(l=r.utf8border(_.output,_.next_out),h=_.next_out-l,f=r.buf2string(_.output,l),_.next_out=h,_.avail_out=u-h,h&&n.arraySet(_.output,_.output,l,h,0),this.onData(f)):this.onData(n.shrinkBuf(_.output,_.next_out)))),0===_.avail_in&&0===_.avail_out&&(b=!0)}while((_.avail_in>0||0===_.avail_out)&&a!==s.Z_STREAM_END);return a===s.Z_STREAM_END&&(o=s.Z_FINISH),o===s.Z_FINISH?(a=i.inflateEnd(this.strm),this.onEnd(a),this.ended=!0,a===s.Z_OK):o!==s.Z_SYNC_FLUSH||(this.onEnd(s.Z_OK),_.avail_out=0,!0)},f.prototype.onData=function(t){this.chunks.push(t)},f.prototype.onEnd=function(t){t===s.Z_OK&&("string"===this.options.to?this.result=this.chunks.join(""):this.result=n.flattenChunks(this.chunks)),this.chunks=[],this.err=t,this.msg=this.strm.msg},a.Inflate=f,a.inflate=_,a.inflateRaw=function(t,e){return(e=e||{}).raw=!0,_(t,e)},a.ungzip=_},{"./utils/common":3,"./utils/strings":4,"./zlib/constants":6,"./zlib/gzheader":9,"./zlib/inflate":11,"./zlib/messages":13,"./zlib/zstream":15}],3:[function(t,e,a){"use strict";var i="undefined"!=typeof Uint8Array&&"undefined"!=typeof Uint16Array&&"undefined"!=typeof Int32Array;function n(t,e){return Object.prototype.hasOwnProperty.call(t,e)}a.assign=function(t){for(var e=Array.prototype.slice.call(arguments,1);e.length;){var a=e.shift();if(a){if("object"!=typeof a)throw new TypeError(a+"must be non-object");for(var i in a)n(a,i)&&(t[i]=a[i])}}return t},a.shrinkBuf=function(t,e){return t.length===e?t:t.subarray?t.subarray(0,e):(t.length=e,t)};var r={arraySet:function(t,e,a,i,n){if(e.subarray&&t.subarray)t.set(e.subarray(a,a+i),n);else for(var r=0;r<i;r++)t[n+r]=e[a+r]},flattenChunks:function(t){var e,a,i,n,r,s;for(i=0,e=0,a=t.length;e<a;e++)i+=t[e].length;for(s=new Uint8Array(i),n=0,e=0,a=t.length;e<a;e++)r=t[e],s.set(r,n),n+=r.length;return s}},s={arraySet:function(t,e,a,i,n){for(var 
r=0;r<i;r++)t[n+r]=e[a+r]},flattenChunks:function(t){return[].concat.apply([],t)}};a.setTyped=function(t){t?(a.Buf8=Uint8Array,a.Buf16=Uint16Array,a.Buf32=Int32Array,a.assign(a,r)):(a.Buf8=Array,a.Buf16=Array,a.Buf32=Array,a.assign(a,s))},a.setTyped(i)},{}],4:[function(t,e,a){"use strict";var i=t("./common"),n=!0,r=!0;try{String.fromCharCode.apply(null,[0])}catch(t){n=!1}try{String.fromCharCode.apply(null,new Uint8Array(1))}catch(t){r=!1}for(var s=new i.Buf8(256),o=0;o<256;o++)s[o]=o>=252?6:o>=248?5:o>=240?4:o>=224?3:o>=192?2:1;function l(t,e){if(e<65534&&(t.subarray&&r||!t.subarray&&n))return String.fromCharCode.apply(null,i.shrinkBuf(t,e));for(var a="",s=0;s<e;s++)a+=String.fromCharCode(t[s]);return a}s[254]=s[254]=1,a.string2buf=function(t){var e,a,n,r,s,o=t.length,l=0;for(r=0;r<o;r++)55296==(64512&(a=t.charCodeAt(r)))&&r+1<o&&56320==(64512&(n=t.charCodeAt(r+1)))&&(a=65536+(a-55296<<10)+(n-56320),r++),l+=a<128?1:a<2048?2:a<65536?3:4;for(e=new i.Buf8(l),s=0,r=0;s<l;r++)55296==(64512&(a=t.charCodeAt(r)))&&r+1<o&&56320==(64512&(n=t.charCodeAt(r+1)))&&(a=65536+(a-55296<<10)+(n-56320),r++),a<128?e[s++]=a:a<2048?(e[s++]=192|a>>>6,e[s++]=128|63&a):a<65536?(e[s++]=224|a>>>12,e[s++]=128|a>>>6&63,e[s++]=128|63&a):(e[s++]=240|a>>>18,e[s++]=128|a>>>12&63,e[s++]=128|a>>>6&63,e[s++]=128|63&a);return e},a.buf2binstring=function(t){return l(t,t.length)},a.binstring2buf=function(t){for(var e=new i.Buf8(t.length),a=0,n=e.length;a<n;a++)e[a]=t.charCodeAt(a);return e},a.buf2string=function(t,e){var a,i,n,r,o=e||t.length,h=new Array(2*o);for(i=0,a=0;a<o;)if((n=t[a++])<128)h[i++]=n;else if((r=s[n])>4)h[i++]=65533,a+=r-1;else{for(n&=2===r?31:3===r?15:7;r>1&&a<o;)n=n<<6|63&t[a++],r--;r>1?h[i++]=65533:n<65536?h[i++]=n:(n-=65536,h[i++]=55296|n>>10&1023,h[i++]=56320|1023&n)}return l(h,i)},a.utf8border=function(t,e){var a;for((e=e||t.length)>t.length&&(e=t.length),a=e-1;a>=0&&128==(192&t[a]);)a--;return a<0?e:0===a?e:a+s[t[a]]>e?a:e}},{"./common":3}],5:[function(t,e,a){"use strict";e.exports=function(t,e,a,i){for(var n=65535&t|0,r=t>>>16&65535|0,s=0;0!==a;){a-=s=a>2e3?2e3:a;do{r=r+(n=n+e[i++]|0)|0}while(--s);n%=65521,r%=65521}return n|r<<16|0}},{}],6:[function(t,e,a){"use strict";e.exports={Z_NO_FLUSH:0,Z_PARTIAL_FLUSH:1,Z_SYNC_FLUSH:2,Z_FULL_FLUSH:3,Z_FINISH:4,Z_BLOCK:5,Z_TREES:6,Z_OK:0,Z_STREAM_END:1,Z_NEED_DICT:2,Z_ERRNO:-1,Z_STREAM_ERROR:-2,Z_DATA_ERROR:-3,Z_BUF_ERROR:-5,Z_NO_COMPRESSION:0,Z_BEST_SPEED:1,Z_BEST_COMPRESSION:9,Z_DEFAULT_COMPRESSION:-1,Z_FILTERED:1,Z_HUFFMAN_ONLY:2,Z_RLE:3,Z_FIXED:4,Z_DEFAULT_STRATEGY:0,Z_BINARY:0,Z_TEXT:1,Z_UNKNOWN:2,Z_DEFLATED:8}},{}],7:[function(t,e,a){"use strict";var i=function(){for(var t,e=[],a=0;a<256;a++){t=a;for(var i=0;i<8;i++)t=1&t?3988292384^t>>>1:t>>>1;e[a]=t}return e}();e.exports=function(t,e,a,n){var r=i,s=n+a;t^=-1;for(var o=n;o<s;o++)t=t>>>8^r[255&(t^e[o])];return-1^t}},{}],8:[function(t,e,a){"use strict";var i,n=t("../utils/common"),r=t("./trees"),s=t("./adler32"),o=t("./crc32"),l=t("./messages"),h=0,d=1,f=3,_=4,u=5,c=0,b=1,g=-2,m=-3,w=-5,p=-1,v=1,k=2,y=3,x=4,z=0,B=2,S=8,E=9,A=15,Z=8,R=286,C=30,N=19,O=2*R+1,D=15,I=3,U=258,T=U+I+1,F=32,L=42,H=69,j=73,K=91,M=103,P=113,Y=666,q=1,G=2,X=3,W=4,J=3;function Q(t,e){return t.msg=l[e],e}function V(t){return(t<<1)-(t>4?9:0)}function $(t){for(var e=t.length;--e>=0;)t[e]=0}function tt(t){var 
e=t.state,a=e.pending;a>t.avail_out&&(a=t.avail_out),0!==a&&(n.arraySet(t.output,e.pending_buf,e.pending_out,a,t.next_out),t.next_out+=a,e.pending_out+=a,t.total_out+=a,t.avail_out-=a,e.pending-=a,0===e.pending&&(e.pending_out=0))}function et(t,e){r._tr_flush_block(t,t.block_start>=0?t.block_start:-1,t.strstart-t.block_start,e),t.block_start=t.strstart,tt(t.strm)}function at(t,e){t.pending_buf[t.pending++]=e}function it(t,e){t.pending_buf[t.pending++]=e>>>8&255,t.pending_buf[t.pending++]=255&e}function nt(t,e){var a,i,n=t.max_chain_length,r=t.strstart,s=t.prev_length,o=t.nice_match,l=t.strstart>t.w_size-T?t.strstart-(t.w_size-T):0,h=t.window,d=t.w_mask,f=t.prev,_=t.strstart+U,u=h[r+s-1],c=h[r+s];t.prev_length>=t.good_match&&(n>>=2),o>t.lookahead&&(o=t.lookahead);do{if(h[(a=e)+s]===c&&h[a+s-1]===u&&h[a]===h[r]&&h[++a]===h[r+1]){r+=2,a++;do{}while(h[++r]===h[++a]&&h[++r]===h[++a]&&h[++r]===h[++a]&&h[++r]===h[++a]&&h[++r]===h[++a]&&h[++r]===h[++a]&&h[++r]===h[++a]&&h[++r]===h[++a]&&r<_);if(i=U-(_-r),r=_-U,i>s){if(t.match_start=e,s=i,i>=o)break;u=h[r+s-1],c=h[r+s]}}}while((e=f[e&d])>l&&0!=--n);return s<=t.lookahead?s:t.lookahead}function rt(t){var e,a,i,r,l,h,d,f,_,u,c=t.w_size;do{if(r=t.window_size-t.lookahead-t.strstart,t.strstart>=c+(c-T)){n.arraySet(t.window,t.window,c,c,0),t.match_start-=c,t.strstart-=c,t.block_start-=c,e=a=t.hash_size;do{i=t.head[--e],t.head[e]=i>=c?i-c:0}while(--a);e=a=c;do{i=t.prev[--e],t.prev[e]=i>=c?i-c:0}while(--a);r+=c}if(0===t.strm.avail_in)break;if(h=t.strm,d=t.window,f=t.strstart+t.lookahead,_=r,u=void 0,(u=h.avail_in)>_&&(u=_),a=0===u?0:(h.avail_in-=u,n.arraySet(d,h.input,h.next_in,u,f),1===h.state.wrap?h.adler=s(h.adler,d,u,f):2===h.state.wrap&&(h.adler=o(h.adler,d,u,f)),h.next_in+=u,h.total_in+=u,u),t.lookahead+=a,t.lookahead+t.insert>=I)for(l=t.strstart-t.insert,t.ins_h=t.window[l],t.ins_h=(t.ins_h<<t.hash_shift^t.window[l+1])&t.hash_mask;t.insert&&(t.ins_h=(t.ins_h<<t.hash_shift^t.window[l+I-1])&t.hash_mask,t.prev[l&t.w_mask]=t.head[t.ins_h],t.head[t.ins_h]=l,l++,t.insert--,!(t.lookahead+t.insert<I)););}while(t.lookahead<T&&0!==t.strm.avail_in)}function st(t,e){for(var a,i;;){if(t.lookahead<T){if(rt(t),t.lookahead<T&&e===h)return q;if(0===t.lookahead)break}if(a=0,t.lookahead>=I&&(t.ins_h=(t.ins_h<<t.hash_shift^t.window[t.strstart+I-1])&t.hash_mask,a=t.prev[t.strstart&t.w_mask]=t.head[t.ins_h],t.head[t.ins_h]=t.strstart),0!==a&&t.strstart-a<=t.w_size-T&&(t.match_length=nt(t,a)),t.match_length>=I)if(i=r._tr_tally(t,t.strstart-t.match_start,t.match_length-I),t.lookahead-=t.match_length,t.match_length<=t.max_lazy_match&&t.lookahead>=I){t.match_length--;do{t.strstart++,t.ins_h=(t.ins_h<<t.hash_shift^t.window[t.strstart+I-1])&t.hash_mask,a=t.prev[t.strstart&t.w_mask]=t.head[t.ins_h],t.head[t.ins_h]=t.strstart}while(0!=--t.match_length);t.strstart++}else t.strstart+=t.match_length,t.match_length=0,t.ins_h=t.window[t.strstart],t.ins_h=(t.ins_h<<t.hash_shift^t.window[t.strstart+1])&t.hash_mask;else i=r._tr_tally(t,0,t.window[t.strstart]),t.lookahead--,t.strstart++;if(i&&(et(t,!1),0===t.strm.avail_out))return q}return t.insert=t.strstart<I-1?t.strstart:I-1,e===_?(et(t,!0),0===t.strm.avail_out?X:W):t.last_lit&&(et(t,!1),0===t.strm.avail_out)?q:G}function ot(t,e){for(var a,i,n;;){if(t.lookahead<T){if(rt(t),t.lookahead<T&&e===h)return 
q;if(0===t.lookahead)break}if(a=0,t.lookahead>=I&&(t.ins_h=(t.ins_h<<t.hash_shift^t.window[t.strstart+I-1])&t.hash_mask,a=t.prev[t.strstart&t.w_mask]=t.head[t.ins_h],t.head[t.ins_h]=t.strstart),t.prev_length=t.match_length,t.prev_match=t.match_start,t.match_length=I-1,0!==a&&t.prev_length<t.max_lazy_match&&t.strstart-a<=t.w_size-T&&(t.match_length=nt(t,a),t.match_length<=5&&(t.strategy===v||t.match_length===I&&t.strstart-t.match_start>4096)&&(t.match_length=I-1)),t.prev_length>=I&&t.match_length<=t.prev_length){n=t.strstart+t.lookahead-I,i=r._tr_tally(t,t.strstart-1-t.prev_match,t.prev_length-I),t.lookahead-=t.prev_length-1,t.prev_length-=2;do{++t.strstart<=n&&(t.ins_h=(t.ins_h<<t.hash_shift^t.window[t.strstart+I-1])&t.hash_mask,a=t.prev[t.strstart&t.w_mask]=t.head[t.ins_h],t.head[t.ins_h]=t.strstart)}while(0!=--t.prev_length);if(t.match_available=0,t.match_length=I-1,t.strstart++,i&&(et(t,!1),0===t.strm.avail_out))return q}else if(t.match_available){if((i=r._tr_tally(t,0,t.window[t.strstart-1]))&&et(t,!1),t.strstart++,t.lookahead--,0===t.strm.avail_out)return q}else t.match_available=1,t.strstart++,t.lookahead--}return t.match_available&&(i=r._tr_tally(t,0,t.window[t.strstart-1]),t.match_available=0),t.insert=t.strstart<I-1?t.strstart:I-1,e===_?(et(t,!0),0===t.strm.avail_out?X:W):t.last_lit&&(et(t,!1),0===t.strm.avail_out)?q:G}function lt(t,e,a,i,n){this.good_length=t,this.max_lazy=e,this.nice_length=a,this.max_chain=i,this.func=n}function ht(){this.strm=null,this.status=0,this.pending_buf=null,this.pending_buf_size=0,this.pending_out=0,this.pending=0,this.wrap=0,this.gzhead=null,this.gzindex=0,this.method=S,this.last_flush=-1,this.w_size=0,this.w_bits=0,this.w_mask=0,this.window=null,this.window_size=0,this.prev=null,this.head=null,this.ins_h=0,this.hash_size=0,this.hash_bits=0,this.hash_mask=0,this.hash_shift=0,this.block_start=0,this.match_length=0,this.prev_match=0,this.match_available=0,this.strstart=0,this.match_start=0,this.lookahead=0,this.prev_length=0,this.max_chain_length=0,this.max_lazy_match=0,this.level=0,this.strategy=0,this.good_match=0,this.nice_match=0,this.dyn_ltree=new n.Buf16(2*O),this.dyn_dtree=new n.Buf16(2*(2*C+1)),this.bl_tree=new n.Buf16(2*(2*N+1)),$(this.dyn_ltree),$(this.dyn_dtree),$(this.bl_tree),this.l_desc=null,this.d_desc=null,this.bl_desc=null,this.bl_count=new n.Buf16(D+1),this.heap=new n.Buf16(2*R+1),$(this.heap),this.heap_len=0,this.heap_max=0,this.depth=new n.Buf16(2*R+1),$(this.depth),this.l_buf=0,this.lit_bufsize=0,this.last_lit=0,this.d_buf=0,this.opt_len=0,this.static_len=0,this.matches=0,this.insert=0,this.bi_buf=0,this.bi_valid=0}function dt(t){var e;return t&&t.state?(t.total_in=t.total_out=0,t.data_type=B,(e=t.state).pending=0,e.pending_out=0,e.wrap<0&&(e.wrap=-e.wrap),e.status=e.wrap?L:P,t.adler=2===e.wrap?0:1,e.last_flush=h,r._tr_init(e),c):Q(t,g)}function ft(t){var e,a=dt(t);return a===c&&((e=t.state).window_size=2*e.w_size,$(e.head),e.max_lazy_match=i[e.level].max_lazy,e.good_match=i[e.level].good_length,e.nice_match=i[e.level].nice_length,e.max_chain_length=i[e.level].max_chain,e.strstart=0,e.block_start=0,e.lookahead=0,e.insert=0,e.match_length=e.prev_length=I-1,e.match_available=0,e.ins_h=0),a}function _t(t,e,a,i,r,s){if(!t)return g;var o=1;if(e===p&&(e=6),i<0?(o=0,i=-i):i>15&&(o=2,i-=16),r<1||r>E||a!==S||i<8||i>15||e<0||e>9||s<0||s>x)return Q(t,g);8===i&&(i=9);var l=new ht;return 
t.state=l,l.strm=t,l.wrap=o,l.gzhead=null,l.w_bits=i,l.w_size=1<<l.w_bits,l.w_mask=l.w_size-1,l.hash_bits=r+7,l.hash_size=1<<l.hash_bits,l.hash_mask=l.hash_size-1,l.hash_shift=~~((l.hash_bits+I-1)/I),l.window=new n.Buf8(2*l.w_size),l.head=new n.Buf16(l.hash_size),l.prev=new n.Buf16(l.w_size),l.lit_bufsize=1<<r+6,l.pending_buf_size=4*l.lit_bufsize,l.pending_buf=new n.Buf8(l.pending_buf_size),l.d_buf=1*l.lit_bufsize,l.l_buf=3*l.lit_bufsize,l.level=e,l.strategy=s,l.method=a,ft(t)}i=[new lt(0,0,0,0,function(t,e){var a=65535;for(a>t.pending_buf_size-5&&(a=t.pending_buf_size-5);;){if(t.lookahead<=1){if(rt(t),0===t.lookahead&&e===h)return q;if(0===t.lookahead)break}t.strstart+=t.lookahead,t.lookahead=0;var i=t.block_start+a;if((0===t.strstart||t.strstart>=i)&&(t.lookahead=t.strstart-i,t.strstart=i,et(t,!1),0===t.strm.avail_out))return q;if(t.strstart-t.block_start>=t.w_size-T&&(et(t,!1),0===t.strm.avail_out))return q}return t.insert=0,e===_?(et(t,!0),0===t.strm.avail_out?X:W):(t.strstart>t.block_start&&(et(t,!1),t.strm.avail_out),q)}),new lt(4,4,8,4,st),new lt(4,5,16,8,st),new lt(4,6,32,32,st),new lt(4,4,16,16,ot),new lt(8,16,32,32,ot),new lt(8,16,128,128,ot),new lt(8,32,128,256,ot),new lt(32,128,258,1024,ot),new lt(32,258,258,4096,ot)],a.deflateInit=function(t,e){return _t(t,e,S,A,Z,z)},a.deflateInit2=_t,a.deflateReset=ft,a.deflateResetKeep=dt,a.deflateSetHeader=function(t,e){return t&&t.state?2!==t.state.wrap?g:(t.state.gzhead=e,c):g},a.deflate=function(t,e){var a,n,s,l;if(!t||!t.state||e>u||e<0)return t?Q(t,g):g;if(n=t.state,!t.output||!t.input&&0!==t.avail_in||n.status===Y&&e!==_)return Q(t,0===t.avail_out?w:g);if(n.strm=t,a=n.last_flush,n.last_flush=e,n.status===L)if(2===n.wrap)t.adler=0,at(n,31),at(n,139),at(n,8),n.gzhead?(at(n,(n.gzhead.text?1:0)+(n.gzhead.hcrc?2:0)+(n.gzhead.extra?4:0)+(n.gzhead.name?8:0)+(n.gzhead.comment?16:0)),at(n,255&n.gzhead.time),at(n,n.gzhead.time>>8&255),at(n,n.gzhead.time>>16&255),at(n,n.gzhead.time>>24&255),at(n,9===n.level?2:n.strategy>=k||n.level<2?4:0),at(n,255&n.gzhead.os),n.gzhead.extra&&n.gzhead.extra.length&&(at(n,255&n.gzhead.extra.length),at(n,n.gzhead.extra.length>>8&255)),n.gzhead.hcrc&&(t.adler=o(t.adler,n.pending_buf,n.pending,0)),n.gzindex=0,n.status=H):(at(n,0),at(n,0),at(n,0),at(n,0),at(n,0),at(n,9===n.level?2:n.strategy>=k||n.level<2?4:0),at(n,J),n.status=P);else{var m=S+(n.w_bits-8<<4)<<8;m|=(n.strategy>=k||n.level<2?0:n.level<6?1:6===n.level?2:3)<<6,0!==n.strstart&&(m|=F),m+=31-m%31,n.status=P,it(n,m),0!==n.strstart&&(it(n,t.adler>>>16),it(n,65535&t.adler)),t.adler=1}if(n.status===H)if(n.gzhead.extra){for(s=n.pending;n.gzindex<(65535&n.gzhead.extra.length)&&(n.pending!==n.pending_buf_size||(n.gzhead.hcrc&&n.pending>s&&(t.adler=o(t.adler,n.pending_buf,n.pending-s,s)),tt(t),s=n.pending,n.pending!==n.pending_buf_size));)at(n,255&n.gzhead.extra[n.gzindex]),n.gzindex++;n.gzhead.hcrc&&n.pending>s&&(t.adler=o(t.adler,n.pending_buf,n.pending-s,s)),n.gzindex===n.gzhead.extra.length&&(n.gzindex=0,n.status=j)}else n.status=j;if(n.status===j)if(n.gzhead.name){s=n.pending;do{if(n.pending===n.pending_buf_size&&(n.gzhead.hcrc&&n.pending>s&&(t.adler=o(t.adler,n.pending_buf,n.pending-s,s)),tt(t),s=n.pending,n.pending===n.pending_buf_size)){l=1;break}l=n.gzindex<n.gzhead.name.length?255&n.gzhead.name.charCodeAt(n.gzindex++):0,at(n,l)}while(0!==l);n.gzhead.hcrc&&n.pending>s&&(t.adler=o(t.adler,n.pending_buf,n.pending-s,s)),0===l&&(n.gzindex=0,n.status=K)}else 
n.status=K;if(n.status===K)if(n.gzhead.comment){s=n.pending;do{if(n.pending===n.pending_buf_size&&(n.gzhead.hcrc&&n.pending>s&&(t.adler=o(t.adler,n.pending_buf,n.pending-s,s)),tt(t),s=n.pending,n.pending===n.pending_buf_size)){l=1;break}l=n.gzindex<n.gzhead.comment.length?255&n.gzhead.comment.charCodeAt(n.gzindex++):0,at(n,l)}while(0!==l);n.gzhead.hcrc&&n.pending>s&&(t.adler=o(t.adler,n.pending_buf,n.pending-s,s)),0===l&&(n.status=M)}else n.status=M;if(n.status===M&&(n.gzhead.hcrc?(n.pending+2>n.pending_buf_size&&tt(t),n.pending+2<=n.pending_buf_size&&(at(n,255&t.adler),at(n,t.adler>>8&255),t.adler=0,n.status=P)):n.status=P),0!==n.pending){if(tt(t),0===t.avail_out)return n.last_flush=-1,c}else if(0===t.avail_in&&V(e)<=V(a)&&e!==_)return Q(t,w);if(n.status===Y&&0!==t.avail_in)return Q(t,w);if(0!==t.avail_in||0!==n.lookahead||e!==h&&n.status!==Y){var p=n.strategy===k?function(t,e){for(var a;;){if(0===t.lookahead&&(rt(t),0===t.lookahead)){if(e===h)return q;break}if(t.match_length=0,a=r._tr_tally(t,0,t.window[t.strstart]),t.lookahead--,t.strstart++,a&&(et(t,!1),0===t.strm.avail_out))return q}return t.insert=0,e===_?(et(t,!0),0===t.strm.avail_out?X:W):t.last_lit&&(et(t,!1),0===t.strm.avail_out)?q:G}(n,e):n.strategy===y?function(t,e){for(var a,i,n,s,o=t.window;;){if(t.lookahead<=U){if(rt(t),t.lookahead<=U&&e===h)return q;if(0===t.lookahead)break}if(t.match_length=0,t.lookahead>=I&&t.strstart>0&&(i=o[n=t.strstart-1])===o[++n]&&i===o[++n]&&i===o[++n]){s=t.strstart+U;do{}while(i===o[++n]&&i===o[++n]&&i===o[++n]&&i===o[++n]&&i===o[++n]&&i===o[++n]&&i===o[++n]&&i===o[++n]&&n<s);t.match_length=U-(s-n),t.match_length>t.lookahead&&(t.match_length=t.lookahead)}if(t.match_length>=I?(a=r._tr_tally(t,1,t.match_length-I),t.lookahead-=t.match_length,t.strstart+=t.match_length,t.match_length=0):(a=r._tr_tally(t,0,t.window[t.strstart]),t.lookahead--,t.strstart++),a&&(et(t,!1),0===t.strm.avail_out))return q}return t.insert=0,e===_?(et(t,!0),0===t.strm.avail_out?X:W):t.last_lit&&(et(t,!1),0===t.strm.avail_out)?q:G}(n,e):i[n.level].func(n,e);if(p!==X&&p!==W||(n.status=Y),p===q||p===X)return 0===t.avail_out&&(n.last_flush=-1),c;if(p===G&&(e===d?r._tr_align(n):e!==u&&(r._tr_stored_block(n,0,0,!1),e===f&&($(n.head),0===n.lookahead&&(n.strstart=0,n.block_start=0,n.insert=0))),tt(t),0===t.avail_out))return n.last_flush=-1,c}return e!==_?c:n.wrap<=0?b:(2===n.wrap?(at(n,255&t.adler),at(n,t.adler>>8&255),at(n,t.adler>>16&255),at(n,t.adler>>24&255),at(n,255&t.total_in),at(n,t.total_in>>8&255),at(n,t.total_in>>16&255),at(n,t.total_in>>24&255)):(it(n,t.adler>>>16),it(n,65535&t.adler)),tt(t),n.wrap>0&&(n.wrap=-n.wrap),0!==n.pending?c:b)},a.deflateEnd=function(t){var e;return t&&t.state?(e=t.state.status)!==L&&e!==H&&e!==j&&e!==K&&e!==M&&e!==P&&e!==Y?Q(t,g):(t.state=null,e===P?Q(t,m):c):g},a.deflateSetDictionary=function(t,e){var a,i,r,o,l,h,d,f,_=e.length;if(!t||!t.state)return g;if(2===(o=(a=t.state).wrap)||1===o&&a.status!==L||a.lookahead)return g;for(1===o&&(t.adler=s(t.adler,e,_,0)),a.wrap=0,_>=a.w_size&&(0===o&&($(a.head),a.strstart=0,a.block_start=0,a.insert=0),f=new n.Buf8(a.w_size),n.arraySet(f,e,_-a.w_size,a.w_size,0),e=f,_=a.w_size),l=t.avail_in,h=t.next_in,d=t.input,t.avail_in=_,t.next_in=0,t.input=e,rt(a);a.lookahead>=I;){i=a.strstart,r=a.lookahead-(I-1);do{a.ins_h=(a.ins_h<<a.hash_shift^a.window[i+I-1])&a.hash_mask,a.prev[i&a.w_mask]=a.head[a.ins_h],a.head[a.ins_h]=i,i++}while(--r);a.strstart=i,a.lookahead=I-1,rt(a)}return 
a.strstart+=a.lookahead,a.block_start=a.strstart,a.insert=a.lookahead,a.lookahead=0,a.match_length=a.prev_length=I-1,a.match_available=0,t.next_in=h,t.input=d,t.avail_in=l,a.wrap=o,c},a.deflateInfo="pako deflate (from Nodeca project)"},{"../utils/common":3,"./adler32":5,"./crc32":7,"./messages":13,"./trees":14}],9:[function(t,e,a){"use strict";e.exports=function(){this.text=0,this.time=0,this.xflags=0,this.os=0,this.extra=null,this.extra_len=0,this.name="",this.comment="",this.hcrc=0,this.done=!1}},{}],10:[function(t,e,a){"use strict";e.exports=function(t,e){var a,i,n,r,s,o,l,h,d,f,_,u,c,b,g,m,w,p,v,k,y,x,z,B,S;a=t.state,i=t.next_in,B=t.input,n=i+(t.avail_in-5),r=t.next_out,S=t.output,s=r-(e-t.avail_out),o=r+(t.avail_out-257),l=a.dmax,h=a.wsize,d=a.whave,f=a.wnext,_=a.window,u=a.hold,c=a.bits,b=a.lencode,g=a.distcode,m=(1<<a.lenbits)-1,w=(1<<a.distbits)-1;t:do{c<15&&(u+=B[i++]<<c,c+=8,u+=B[i++]<<c,c+=8),p=b[u&m];e:for(;;){if(u>>>=v=p>>>24,c-=v,0===(v=p>>>16&255))S[r++]=65535&p;else{if(!(16&v)){if(0==(64&v)){p=b[(65535&p)+(u&(1<<v)-1)];continue e}if(32&v){a.mode=12;break t}t.msg="invalid literal/length code",a.mode=30;break t}k=65535&p,(v&=15)&&(c<v&&(u+=B[i++]<<c,c+=8),k+=u&(1<<v)-1,u>>>=v,c-=v),c<15&&(u+=B[i++]<<c,c+=8,u+=B[i++]<<c,c+=8),p=g[u&w];a:for(;;){if(u>>>=v=p>>>24,c-=v,!(16&(v=p>>>16&255))){if(0==(64&v)){p=g[(65535&p)+(u&(1<<v)-1)];continue a}t.msg="invalid distance code",a.mode=30;break t}if(y=65535&p,c<(v&=15)&&(u+=B[i++]<<c,(c+=8)<v&&(u+=B[i++]<<c,c+=8)),(y+=u&(1<<v)-1)>l){t.msg="invalid distance too far back",a.mode=30;break t}if(u>>>=v,c-=v,y>(v=r-s)){if((v=y-v)>d&&a.sane){t.msg="invalid distance too far back",a.mode=30;break t}if(x=0,z=_,0===f){if(x+=h-v,v<k){k-=v;do{S[r++]=_[x++]}while(--v);x=r-y,z=S}}else if(f<v){if(x+=h+f-v,(v-=f)<k){k-=v;do{S[r++]=_[x++]}while(--v);if(x=0,f<k){k-=v=f;do{S[r++]=_[x++]}while(--v);x=r-y,z=S}}}else if(x+=f-v,v<k){k-=v;do{S[r++]=_[x++]}while(--v);x=r-y,z=S}for(;k>2;)S[r++]=z[x++],S[r++]=z[x++],S[r++]=z[x++],k-=3;k&&(S[r++]=z[x++],k>1&&(S[r++]=z[x++]))}else{x=r-y;do{S[r++]=S[x++],S[r++]=S[x++],S[r++]=S[x++],k-=3}while(k>2);k&&(S[r++]=S[x++],k>1&&(S[r++]=S[x++]))}break}}break}}while(i<n&&r<o);i-=k=c>>3,u&=(1<<(c-=k<<3))-1,t.next_in=i,t.next_out=r,t.avail_in=i<n?n-i+5:5-(i-n),t.avail_out=r<o?o-r+257:257-(r-o),a.hold=u,a.bits=c}},{}],11:[function(t,e,a){"use strict";var i=t("../utils/common"),n=t("./adler32"),r=t("./crc32"),s=t("./inffast"),o=t("./inftrees"),l=0,h=1,d=2,f=4,_=5,u=6,c=0,b=1,g=2,m=-2,w=-3,p=-4,v=-5,k=8,y=1,x=2,z=3,B=4,S=5,E=6,A=7,Z=8,R=9,C=10,N=11,O=12,D=13,I=14,U=15,T=16,F=17,L=18,H=19,j=20,K=21,M=22,P=23,Y=24,q=25,G=26,X=27,W=28,J=29,Q=30,V=31,$=32,tt=852,et=592,at=15;function it(t){return(t>>>24&255)+(t>>>8&65280)+((65280&t)<<8)+((255&t)<<24)}function nt(){this.mode=0,this.last=!1,this.wrap=0,this.havedict=!1,this.flags=0,this.dmax=0,this.check=0,this.total=0,this.head=null,this.wbits=0,this.wsize=0,this.whave=0,this.wnext=0,this.window=null,this.hold=0,this.bits=0,this.length=0,this.offset=0,this.extra=0,this.lencode=null,this.distcode=null,this.lenbits=0,this.distbits=0,this.ncode=0,this.nlen=0,this.ndist=0,this.have=0,this.next=null,this.lens=new i.Buf16(320),this.work=new i.Buf16(288),this.lendyn=null,this.distdyn=null,this.sane=0,this.back=0,this.was=0}function rt(t){var e;return t&&t.state?(e=t.state,t.total_in=t.total_out=e.total=0,t.msg="",e.wrap&&(t.adler=1&e.wrap),e.mode=y,e.last=0,e.havedict=0,e.dmax=32768,e.head=null,e.hold=0,e.bits=0,e.lencode=e.lendyn=new i.Buf32(tt),e.distcode=e.distdyn=new 
i.Buf32(et),e.sane=1,e.back=-1,c):m}function st(t){var e;return t&&t.state?((e=t.state).wsize=0,e.whave=0,e.wnext=0,rt(t)):m}function ot(t,e){var a,i;return t&&t.state?(i=t.state,e<0?(a=0,e=-e):(a=1+(e>>4),e<48&&(e&=15)),e&&(e<8||e>15)?m:(null!==i.window&&i.wbits!==e&&(i.window=null),i.wrap=a,i.wbits=e,st(t))):m}function lt(t,e){var a,i;return t?(i=new nt,t.state=i,i.window=null,(a=ot(t,e))!==c&&(t.state=null),a):m}var ht,dt,ft=!0;function _t(t){if(ft){var e;for(ht=new i.Buf32(512),dt=new i.Buf32(32),e=0;e<144;)t.lens[e++]=8;for(;e<256;)t.lens[e++]=9;for(;e<280;)t.lens[e++]=7;for(;e<288;)t.lens[e++]=8;for(o(h,t.lens,0,288,ht,0,t.work,{bits:9}),e=0;e<32;)t.lens[e++]=5;o(d,t.lens,0,32,dt,0,t.work,{bits:5}),ft=!1}t.lencode=ht,t.lenbits=9,t.distcode=dt,t.distbits=5}function ut(t,e,a,n){var r,s=t.state;return null===s.window&&(s.wsize=1<<s.wbits,s.wnext=0,s.whave=0,s.window=new i.Buf8(s.wsize)),n>=s.wsize?(i.arraySet(s.window,e,a-s.wsize,s.wsize,0),s.wnext=0,s.whave=s.wsize):((r=s.wsize-s.wnext)>n&&(r=n),i.arraySet(s.window,e,a-n,r,s.wnext),(n-=r)?(i.arraySet(s.window,e,a-n,n,0),s.wnext=n,s.whave=s.wsize):(s.wnext+=r,s.wnext===s.wsize&&(s.wnext=0),s.whave<s.wsize&&(s.whave+=r))),0}a.inflateReset=st,a.inflateReset2=ot,a.inflateResetKeep=rt,a.inflateInit=function(t){return lt(t,at)},a.inflateInit2=lt,a.inflate=function(t,e){var a,tt,et,at,nt,rt,st,ot,lt,ht,dt,ft,ct,bt,gt,mt,wt,pt,vt,kt,yt,xt,zt,Bt,St=0,Et=new i.Buf8(4),At=[16,17,18,0,8,7,9,6,10,5,11,4,12,3,13,2,14,1,15];if(!t||!t.state||!t.output||!t.input&&0!==t.avail_in)return m;(a=t.state).mode===O&&(a.mode=D),nt=t.next_out,et=t.output,st=t.avail_out,at=t.next_in,tt=t.input,rt=t.avail_in,ot=a.hold,lt=a.bits,ht=rt,dt=st,xt=c;t:for(;;)switch(a.mode){case y:if(0===a.wrap){a.mode=D;break}for(;lt<16;){if(0===rt)break t;rt--,ot+=tt[at++]<<lt,lt+=8}if(2&a.wrap&&35615===ot){a.check=0,Et[0]=255&ot,Et[1]=ot>>>8&255,a.check=r(a.check,Et,2,0),ot=0,lt=0,a.mode=x;break}if(a.flags=0,a.head&&(a.head.done=!1),!(1&a.wrap)||(((255&ot)<<8)+(ot>>8))%31){t.msg="incorrect header check",a.mode=Q;break}if((15&ot)!==k){t.msg="unknown compression method",a.mode=Q;break}if(lt-=4,yt=8+(15&(ot>>>=4)),0===a.wbits)a.wbits=yt;else if(yt>a.wbits){t.msg="invalid window size",a.mode=Q;break}a.dmax=1<<yt,t.adler=a.check=1,a.mode=512&ot?C:O,ot=0,lt=0;break;case x:for(;lt<16;){if(0===rt)break t;rt--,ot+=tt[at++]<<lt,lt+=8}if(a.flags=ot,(255&a.flags)!==k){t.msg="unknown compression method",a.mode=Q;break}if(57344&a.flags){t.msg="unknown header flags set",a.mode=Q;break}a.head&&(a.head.text=ot>>8&1),512&a.flags&&(Et[0]=255&ot,Et[1]=ot>>>8&255,a.check=r(a.check,Et,2,0)),ot=0,lt=0,a.mode=z;case z:for(;lt<32;){if(0===rt)break t;rt--,ot+=tt[at++]<<lt,lt+=8}a.head&&(a.head.time=ot),512&a.flags&&(Et[0]=255&ot,Et[1]=ot>>>8&255,Et[2]=ot>>>16&255,Et[3]=ot>>>24&255,a.check=r(a.check,Et,4,0)),ot=0,lt=0,a.mode=B;case B:for(;lt<16;){if(0===rt)break t;rt--,ot+=tt[at++]<<lt,lt+=8}a.head&&(a.head.xflags=255&ot,a.head.os=ot>>8),512&a.flags&&(Et[0]=255&ot,Et[1]=ot>>>8&255,a.check=r(a.check,Et,2,0)),ot=0,lt=0,a.mode=S;case S:if(1024&a.flags){for(;lt<16;){if(0===rt)break t;rt--,ot+=tt[at++]<<lt,lt+=8}a.length=ot,a.head&&(a.head.extra_len=ot),512&a.flags&&(Et[0]=255&ot,Et[1]=ot>>>8&255,a.check=r(a.check,Et,2,0)),ot=0,lt=0}else a.head&&(a.head.extra=null);a.mode=E;case E:if(1024&a.flags&&((ft=a.length)>rt&&(ft=rt),ft&&(a.head&&(yt=a.head.extra_len-a.length,a.head.extra||(a.head.extra=new 
Array(a.head.extra_len)),i.arraySet(a.head.extra,tt,at,ft,yt)),512&a.flags&&(a.check=r(a.check,tt,ft,at)),rt-=ft,at+=ft,a.length-=ft),a.length))break t;a.length=0,a.mode=A;case A:if(2048&a.flags){if(0===rt)break t;ft=0;do{yt=tt[at+ft++],a.head&&yt&&a.length<65536&&(a.head.name+=String.fromCharCode(yt))}while(yt&&ft<rt);if(512&a.flags&&(a.check=r(a.check,tt,ft,at)),rt-=ft,at+=ft,yt)break t}else a.head&&(a.head.name=null);a.length=0,a.mode=Z;case Z:if(4096&a.flags){if(0===rt)break t;ft=0;do{yt=tt[at+ft++],a.head&&yt&&a.length<65536&&(a.head.comment+=String.fromCharCode(yt))}while(yt&&ft<rt);if(512&a.flags&&(a.check=r(a.check,tt,ft,at)),rt-=ft,at+=ft,yt)break t}else a.head&&(a.head.comment=null);a.mode=R;case R:if(512&a.flags){for(;lt<16;){if(0===rt)break t;rt--,ot+=tt[at++]<<lt,lt+=8}if(ot!==(65535&a.check)){t.msg="header crc mismatch",a.mode=Q;break}ot=0,lt=0}a.head&&(a.head.hcrc=a.flags>>9&1,a.head.done=!0),t.adler=a.check=0,a.mode=O;break;case C:for(;lt<32;){if(0===rt)break t;rt--,ot+=tt[at++]<<lt,lt+=8}t.adler=a.check=it(ot),ot=0,lt=0,a.mode=N;case N:if(0===a.havedict)return t.next_out=nt,t.avail_out=st,t.next_in=at,t.avail_in=rt,a.hold=ot,a.bits=lt,g;t.adler=a.check=1,a.mode=O;case O:if(e===_||e===u)break t;case D:if(a.last){ot>>>=7&lt,lt-=7&lt,a.mode=X;break}for(;lt<3;){if(0===rt)break t;rt--,ot+=tt[at++]<<lt,lt+=8}switch(a.last=1&ot,lt-=1,3&(ot>>>=1)){case 0:a.mode=I;break;case 1:if(_t(a),a.mode=j,e===u){ot>>>=2,lt-=2;break t}break;case 2:a.mode=F;break;case 3:t.msg="invalid block type",a.mode=Q}ot>>>=2,lt-=2;break;case I:for(ot>>>=7&lt,lt-=7&lt;lt<32;){if(0===rt)break t;rt--,ot+=tt[at++]<<lt,lt+=8}if((65535&ot)!=(ot>>>16^65535)){t.msg="invalid stored block lengths",a.mode=Q;break}if(a.length=65535&ot,ot=0,lt=0,a.mode=U,e===u)break t;case U:a.mode=T;case T:if(ft=a.length){if(ft>rt&&(ft=rt),ft>st&&(ft=st),0===ft)break t;i.arraySet(et,tt,at,ft,nt),rt-=ft,at+=ft,st-=ft,nt+=ft,a.length-=ft;break}a.mode=O;break;case F:for(;lt<14;){if(0===rt)break t;rt--,ot+=tt[at++]<<lt,lt+=8}if(a.nlen=257+(31&ot),ot>>>=5,lt-=5,a.ndist=1+(31&ot),ot>>>=5,lt-=5,a.ncode=4+(15&ot),ot>>>=4,lt-=4,a.nlen>286||a.ndist>30){t.msg="too many length or distance symbols",a.mode=Q;break}a.have=0,a.mode=L;case L:for(;a.have<a.ncode;){for(;lt<3;){if(0===rt)break t;rt--,ot+=tt[at++]<<lt,lt+=8}a.lens[At[a.have++]]=7&ot,ot>>>=3,lt-=3}for(;a.have<19;)a.lens[At[a.have++]]=0;if(a.lencode=a.lendyn,a.lenbits=7,zt={bits:a.lenbits},xt=o(l,a.lens,0,19,a.lencode,0,a.work,zt),a.lenbits=zt.bits,xt){t.msg="invalid code lengths set",a.mode=Q;break}a.have=0,a.mode=H;case H:for(;a.have<a.nlen+a.ndist;){for(;mt=(St=a.lencode[ot&(1<<a.lenbits)-1])>>>16&255,wt=65535&St,!((gt=St>>>24)<=lt);){if(0===rt)break t;rt--,ot+=tt[at++]<<lt,lt+=8}if(wt<16)ot>>>=gt,lt-=gt,a.lens[a.have++]=wt;else{if(16===wt){for(Bt=gt+2;lt<Bt;){if(0===rt)break t;rt--,ot+=tt[at++]<<lt,lt+=8}if(ot>>>=gt,lt-=gt,0===a.have){t.msg="invalid bit length repeat",a.mode=Q;break}yt=a.lens[a.have-1],ft=3+(3&ot),ot>>>=2,lt-=2}else if(17===wt){for(Bt=gt+3;lt<Bt;){if(0===rt)break t;rt--,ot+=tt[at++]<<lt,lt+=8}lt-=gt,yt=0,ft=3+(7&(ot>>>=gt)),ot>>>=3,lt-=3}else{for(Bt=gt+7;lt<Bt;){if(0===rt)break t;rt--,ot+=tt[at++]<<lt,lt+=8}lt-=gt,yt=0,ft=11+(127&(ot>>>=gt)),ot>>>=7,lt-=7}if(a.have+ft>a.nlen+a.ndist){t.msg="invalid bit length repeat",a.mode=Q;break}for(;ft--;)a.lens[a.have++]=yt}}if(a.mode===Q)break;if(0===a.lens[256]){t.msg="invalid code -- missing 
end-of-block",a.mode=Q;break}if(a.lenbits=9,zt={bits:a.lenbits},xt=o(h,a.lens,0,a.nlen,a.lencode,0,a.work,zt),a.lenbits=zt.bits,xt){t.msg="invalid literal/lengths set",a.mode=Q;break}if(a.distbits=6,a.distcode=a.distdyn,zt={bits:a.distbits},xt=o(d,a.lens,a.nlen,a.ndist,a.distcode,0,a.work,zt),a.distbits=zt.bits,xt){t.msg="invalid distances set",a.mode=Q;break}if(a.mode=j,e===u)break t;case j:a.mode=K;case K:if(rt>=6&&st>=258){t.next_out=nt,t.avail_out=st,t.next_in=at,t.avail_in=rt,a.hold=ot,a.bits=lt,s(t,dt),nt=t.next_out,et=t.output,st=t.avail_out,at=t.next_in,tt=t.input,rt=t.avail_in,ot=a.hold,lt=a.bits,a.mode===O&&(a.back=-1);break}for(a.back=0;mt=(St=a.lencode[ot&(1<<a.lenbits)-1])>>>16&255,wt=65535&St,!((gt=St>>>24)<=lt);){if(0===rt)break t;rt--,ot+=tt[at++]<<lt,lt+=8}if(mt&&0==(240&mt)){for(pt=gt,vt=mt,kt=wt;mt=(St=a.lencode[kt+((ot&(1<<pt+vt)-1)>>pt)])>>>16&255,wt=65535&St,!(pt+(gt=St>>>24)<=lt);){if(0===rt)break t;rt--,ot+=tt[at++]<<lt,lt+=8}ot>>>=pt,lt-=pt,a.back+=pt}if(ot>>>=gt,lt-=gt,a.back+=gt,a.length=wt,0===mt){a.mode=G;break}if(32&mt){a.back=-1,a.mode=O;break}if(64&mt){t.msg="invalid literal/length code",a.mode=Q;break}a.extra=15&mt,a.mode=M;case M:if(a.extra){for(Bt=a.extra;lt<Bt;){if(0===rt)break t;rt--,ot+=tt[at++]<<lt,lt+=8}a.length+=ot&(1<<a.extra)-1,ot>>>=a.extra,lt-=a.extra,a.back+=a.extra}a.was=a.length,a.mode=P;case P:for(;mt=(St=a.distcode[ot&(1<<a.distbits)-1])>>>16&255,wt=65535&St,!((gt=St>>>24)<=lt);){if(0===rt)break t;rt--,ot+=tt[at++]<<lt,lt+=8}if(0==(240&mt)){for(pt=gt,vt=mt,kt=wt;mt=(St=a.distcode[kt+((ot&(1<<pt+vt)-1)>>pt)])>>>16&255,wt=65535&St,!(pt+(gt=St>>>24)<=lt);){if(0===rt)break t;rt--,ot+=tt[at++]<<lt,lt+=8}ot>>>=pt,lt-=pt,a.back+=pt}if(ot>>>=gt,lt-=gt,a.back+=gt,64&mt){t.msg="invalid distance code",a.mode=Q;break}a.offset=wt,a.extra=15&mt,a.mode=Y;case Y:if(a.extra){for(Bt=a.extra;lt<Bt;){if(0===rt)break t;rt--,ot+=tt[at++]<<lt,lt+=8}a.offset+=ot&(1<<a.extra)-1,ot>>>=a.extra,lt-=a.extra,a.back+=a.extra}if(a.offset>a.dmax){t.msg="invalid distance too far back",a.mode=Q;break}a.mode=q;case q:if(0===st)break t;if(ft=dt-st,a.offset>ft){if((ft=a.offset-ft)>a.whave&&a.sane){t.msg="invalid distance too far back",a.mode=Q;break}ft>a.wnext?(ft-=a.wnext,ct=a.wsize-ft):ct=a.wnext-ft,ft>a.length&&(ft=a.length),bt=a.window}else bt=et,ct=nt-a.offset,ft=a.length;ft>st&&(ft=st),st-=ft,a.length-=ft;do{et[nt++]=bt[ct++]}while(--ft);0===a.length&&(a.mode=K);break;case G:if(0===st)break t;et[nt++]=a.length,st--,a.mode=K;break;case X:if(a.wrap){for(;lt<32;){if(0===rt)break t;rt--,ot|=tt[at++]<<lt,lt+=8}if(dt-=st,t.total_out+=dt,a.total+=dt,dt&&(t.adler=a.check=a.flags?r(a.check,et,dt,nt-dt):n(a.check,et,dt,nt-dt)),dt=st,(a.flags?ot:it(ot))!==a.check){t.msg="incorrect data check",a.mode=Q;break}ot=0,lt=0}a.mode=W;case W:if(a.wrap&&a.flags){for(;lt<32;){if(0===rt)break t;rt--,ot+=tt[at++]<<lt,lt+=8}if(ot!==(4294967295&a.total)){t.msg="incorrect length check",a.mode=Q;break}ot=0,lt=0}a.mode=J;case J:xt=b;break t;case Q:xt=w;break t;case V:return p;case $:default:return m}return 
t.next_out=nt,t.avail_out=st,t.next_in=at,t.avail_in=rt,a.hold=ot,a.bits=lt,(a.wsize||dt!==t.avail_out&&a.mode<Q&&(a.mode<X||e!==f))&&ut(t,t.output,t.next_out,dt-t.avail_out)?(a.mode=V,p):(ht-=t.avail_in,dt-=t.avail_out,t.total_in+=ht,t.total_out+=dt,a.total+=dt,a.wrap&&dt&&(t.adler=a.check=a.flags?r(a.check,et,dt,t.next_out-dt):n(a.check,et,dt,t.next_out-dt)),t.data_type=a.bits+(a.last?64:0)+(a.mode===O?128:0)+(a.mode===j||a.mode===U?256:0),(0===ht&&0===dt||e===f)&&xt===c&&(xt=v),xt)},a.inflateEnd=function(t){if(!t||!t.state)return m;var e=t.state;return e.window&&(e.window=null),t.state=null,c},a.inflateGetHeader=function(t,e){var a;return t&&t.state?0==(2&(a=t.state).wrap)?m:(a.head=e,e.done=!1,c):m},a.inflateSetDictionary=function(t,e){var a,i=e.length;return t&&t.state?0!==(a=t.state).wrap&&a.mode!==N?m:a.mode===N&&n(1,e,i,0)!==a.check?w:ut(t,e,i,i)?(a.mode=V,p):(a.havedict=1,c):m},a.inflateInfo="pako inflate (from Nodeca project)"},{"../utils/common":3,"./adler32":5,"./crc32":7,"./inffast":10,"./inftrees":12}],12:[function(t,e,a){"use strict";var i=t("../utils/common"),n=[3,4,5,6,7,8,9,10,11,13,15,17,19,23,27,31,35,43,51,59,67,83,99,115,131,163,195,227,258,0,0],r=[16,16,16,16,16,16,16,16,17,17,17,17,18,18,18,18,19,19,19,19,20,20,20,20,21,21,21,21,16,72,78],s=[1,2,3,4,5,7,9,13,17,25,33,49,65,97,129,193,257,385,513,769,1025,1537,2049,3073,4097,6145,8193,12289,16385,24577,0,0],o=[16,16,16,16,17,17,18,18,19,19,20,20,21,21,22,22,23,23,24,24,25,25,26,26,27,27,28,28,29,29,64,64];e.exports=function(t,e,a,l,h,d,f,_){var u,c,b,g,m,w,p,v,k,y=_.bits,x=0,z=0,B=0,S=0,E=0,A=0,Z=0,R=0,C=0,N=0,O=null,D=0,I=new i.Buf16(16),U=new i.Buf16(16),T=null,F=0;for(x=0;x<=15;x++)I[x]=0;for(z=0;z<l;z++)I[e[a+z]]++;for(E=y,S=15;S>=1&&0===I[S];S--);if(E>S&&(E=S),0===S)return h[d++]=20971520,h[d++]=20971520,_.bits=1,0;for(B=1;B<S&&0===I[B];B++);for(E<B&&(E=B),R=1,x=1;x<=15;x++)if(R<<=1,(R-=I[x])<0)return-1;if(R>0&&(0===t||1!==S))return-1;for(U[1]=0,x=1;x<15;x++)U[x+1]=U[x]+I[x];for(z=0;z<l;z++)0!==e[a+z]&&(f[U[e[a+z]]++]=z);if(0===t?(O=T=f,w=19):1===t?(O=n,D-=257,T=r,F-=257,w=256):(O=s,T=o,w=-1),N=0,z=0,x=B,m=d,A=E,Z=0,b=-1,g=(C=1<<E)-1,1===t&&C>852||2===t&&C>592)return 1;for(;;){p=x-Z,f[z]<w?(v=0,k=f[z]):f[z]>w?(v=T[F+f[z]],k=O[D+f[z]]):(v=96,k=0),u=1<<x-Z,B=c=1<<A;do{h[m+(N>>Z)+(c-=u)]=p<<24|v<<16|k|0}while(0!==c);for(u=1<<x-1;N&u;)u>>=1;if(0!==u?(N&=u-1,N+=u):N=0,z++,0==--I[x]){if(x===S)break;x=e[a+f[z]]}if(x>E&&(N&g)!==b){for(0===Z&&(Z=E),m+=B,R=1<<(A=x-Z);A+Z<S&&!((R-=I[A+Z])<=0);)A++,R<<=1;if(C+=1<<A,1===t&&C>852||2===t&&C>592)return 1;h[b=N&g]=E<<24|A<<16|m-d|0}}return 0!==N&&(h[m+N]=x-Z<<24|64<<16|0),_.bits=E,0}},{"../utils/common":3}],13:[function(t,e,a){"use strict";e.exports={2:"need dictionary",1:"stream end",0:"","-1":"file error","-2":"stream error","-3":"data error","-4":"insufficient memory","-5":"buffer error","-6":"incompatible version"}},{}],14:[function(t,e,a){"use strict";var i=t("../utils/common"),n=4,r=0,s=1,o=2;function l(t){for(var e=t.length;--e>=0;)t[e]=0}var h=0,d=1,f=2,_=29,u=256,c=u+1+_,b=30,g=19,m=2*c+1,w=15,p=16,v=7,k=256,y=16,x=17,z=18,B=[0,0,0,0,0,0,0,0,1,1,1,1,2,2,2,2,3,3,3,3,4,4,4,4,5,5,5,5,0],S=[0,0,0,0,1,1,2,2,3,3,4,4,5,5,6,6,7,7,8,8,9,9,10,10,11,11,12,12,13,13],E=[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,2,3,7],A=[16,17,18,0,8,7,9,6,10,5,11,4,12,3,13,2,14,1,15],Z=new Array(2*(c+2));l(Z);var R=new Array(2*b);l(R);var C=new Array(512);l(C);var N=new Array(256);l(N);var O=new Array(_);l(O);var D,I,U,T=new Array(b);function 
F(t,e,a,i,n){this.static_tree=t,this.extra_bits=e,this.extra_base=a,this.elems=i,this.max_length=n,this.has_stree=t&&t.length}function L(t,e){this.dyn_tree=t,this.max_code=0,this.stat_desc=e}function H(t){return t<256?C[t]:C[256+(t>>>7)]}function j(t,e){t.pending_buf[t.pending++]=255&e,t.pending_buf[t.pending++]=e>>>8&255}function K(t,e,a){t.bi_valid>p-a?(t.bi_buf|=e<<t.bi_valid&65535,j(t,t.bi_buf),t.bi_buf=e>>p-t.bi_valid,t.bi_valid+=a-p):(t.bi_buf|=e<<t.bi_valid&65535,t.bi_valid+=a)}function M(t,e,a){K(t,a[2*e],a[2*e+1])}function P(t,e){var a=0;do{a|=1&t,t>>>=1,a<<=1}while(--e>0);return a>>>1}function Y(t,e,a){var i,n,r=new Array(w+1),s=0;for(i=1;i<=w;i++)r[i]=s=s+a[i-1]<<1;for(n=0;n<=e;n++){var o=t[2*n+1];0!==o&&(t[2*n]=P(r[o]++,o))}}function q(t){var e;for(e=0;e<c;e++)t.dyn_ltree[2*e]=0;for(e=0;e<b;e++)t.dyn_dtree[2*e]=0;for(e=0;e<g;e++)t.bl_tree[2*e]=0;t.dyn_ltree[2*k]=1,t.opt_len=t.static_len=0,t.last_lit=t.matches=0}function G(t){t.bi_valid>8?j(t,t.bi_buf):t.bi_valid>0&&(t.pending_buf[t.pending++]=t.bi_buf),t.bi_buf=0,t.bi_valid=0}function X(t,e,a,i){var n=2*e,r=2*a;return t[n]<t[r]||t[n]===t[r]&&i[e]<=i[a]}function W(t,e,a){for(var i=t.heap[a],n=a<<1;n<=t.heap_len&&(n<t.heap_len&&X(e,t.heap[n+1],t.heap[n],t.depth)&&n++,!X(e,i,t.heap[n],t.depth));)t.heap[a]=t.heap[n],a=n,n<<=1;t.heap[a]=i}function J(t,e,a){var i,n,r,s,o=0;if(0!==t.last_lit)do{i=t.pending_buf[t.d_buf+2*o]<<8|t.pending_buf[t.d_buf+2*o+1],n=t.pending_buf[t.l_buf+o],o++,0===i?M(t,n,e):(M(t,(r=N[n])+u+1,e),0!==(s=B[r])&&K(t,n-=O[r],s),M(t,r=H(--i),a),0!==(s=S[r])&&K(t,i-=T[r],s))}while(o<t.last_lit);M(t,k,e)}function Q(t,e){var a,i,n,r=e.dyn_tree,s=e.stat_desc.static_tree,o=e.stat_desc.has_stree,l=e.stat_desc.elems,h=-1;for(t.heap_len=0,t.heap_max=m,a=0;a<l;a++)0!==r[2*a]?(t.heap[++t.heap_len]=h=a,t.depth[a]=0):r[2*a+1]=0;for(;t.heap_len<2;)r[2*(n=t.heap[++t.heap_len]=h<2?++h:0)]=1,t.depth[n]=0,t.opt_len--,o&&(t.static_len-=s[2*n+1]);for(e.max_code=h,a=t.heap_len>>1;a>=1;a--)W(t,r,a);n=l;do{a=t.heap[1],t.heap[1]=t.heap[t.heap_len--],W(t,r,1),i=t.heap[1],t.heap[--t.heap_max]=a,t.heap[--t.heap_max]=i,r[2*n]=r[2*a]+r[2*i],t.depth[n]=(t.depth[a]>=t.depth[i]?t.depth[a]:t.depth[i])+1,r[2*a+1]=r[2*i+1]=n,t.heap[1]=n++,W(t,r,1)}while(t.heap_len>=2);t.heap[--t.heap_max]=t.heap[1],function(t,e){var a,i,n,r,s,o,l=e.dyn_tree,h=e.max_code,d=e.stat_desc.static_tree,f=e.stat_desc.has_stree,_=e.stat_desc.extra_bits,u=e.stat_desc.extra_base,c=e.stat_desc.max_length,b=0;for(r=0;r<=w;r++)t.bl_count[r]=0;for(l[2*t.heap[t.heap_max]+1]=0,a=t.heap_max+1;a<m;a++)(r=l[2*l[2*(i=t.heap[a])+1]+1]+1)>c&&(r=c,b++),l[2*i+1]=r,i>h||(t.bl_count[r]++,s=0,i>=u&&(s=_[i-u]),o=l[2*i],t.opt_len+=o*(r+s),f&&(t.static_len+=o*(d[2*i+1]+s)));if(0!==b){do{for(r=c-1;0===t.bl_count[r];)r--;t.bl_count[r]--,t.bl_count[r+1]+=2,t.bl_count[c]--,b-=2}while(b>0);for(r=c;0!==r;r--)for(i=t.bl_count[r];0!==i;)(n=t.heap[--a])>h||(l[2*n+1]!==r&&(t.opt_len+=(r-l[2*n+1])*l[2*n],l[2*n+1]=r),i--)}}(t,e),Y(r,h,t.bl_count)}function V(t,e,a){var i,n,r=-1,s=e[1],o=0,l=7,h=4;for(0===s&&(l=138,h=3),e[2*(a+1)+1]=65535,i=0;i<=a;i++)n=s,s=e[2*(i+1)+1],++o<l&&n===s||(o<h?t.bl_tree[2*n]+=o:0!==n?(n!==r&&t.bl_tree[2*n]++,t.bl_tree[2*y]++):o<=10?t.bl_tree[2*x]++:t.bl_tree[2*z]++,o=0,r=n,0===s?(l=138,h=3):n===s?(l=6,h=3):(l=7,h=4))}function $(t,e,a){var i,n,r=-1,s=e[1],o=0,l=7,h=4;for(0===s&&(l=138,h=3),i=0;i<=a;i++)if(n=s,s=e[2*(i+1)+1],!(++o<l&&n===s)){if(o<h)do{M(t,n,t.bl_tree)}while(0!=--o);else 
0!==n?(n!==r&&(M(t,n,t.bl_tree),o--),M(t,y,t.bl_tree),K(t,o-3,2)):o<=10?(M(t,x,t.bl_tree),K(t,o-3,3)):(M(t,z,t.bl_tree),K(t,o-11,7));o=0,r=n,0===s?(l=138,h=3):n===s?(l=6,h=3):(l=7,h=4)}}l(T);var tt=!1;function et(t,e,a,n){K(t,(h<<1)+(n?1:0),3),function(t,e,a,n){G(t),n&&(j(t,a),j(t,~a)),i.arraySet(t.pending_buf,t.window,e,a,t.pending),t.pending+=a}(t,e,a,!0)}a._tr_init=function(t){tt||(function(){var t,e,a,i,n,r=new Array(w+1);for(a=0,i=0;i<_-1;i++)for(O[i]=a,t=0;t<1<<B[i];t++)N[a++]=i;for(N[a-1]=i,n=0,i=0;i<16;i++)for(T[i]=n,t=0;t<1<<S[i];t++)C[n++]=i;for(n>>=7;i<b;i++)for(T[i]=n<<7,t=0;t<1<<S[i]-7;t++)C[256+n++]=i;for(e=0;e<=w;e++)r[e]=0;for(t=0;t<=143;)Z[2*t+1]=8,t++,r[8]++;for(;t<=255;)Z[2*t+1]=9,t++,r[9]++;for(;t<=279;)Z[2*t+1]=7,t++,r[7]++;for(;t<=287;)Z[2*t+1]=8,t++,r[8]++;for(Y(Z,c+1,r),t=0;t<b;t++)R[2*t+1]=5,R[2*t]=P(t,5);D=new F(Z,B,u+1,c,w),I=new F(R,S,0,b,w),U=new F(new Array(0),E,0,g,v)}(),tt=!0),t.l_desc=new L(t.dyn_ltree,D),t.d_desc=new L(t.dyn_dtree,I),t.bl_desc=new L(t.bl_tree,U),t.bi_buf=0,t.bi_valid=0,q(t)},a._tr_stored_block=et,a._tr_flush_block=function(t,e,a,i){var l,h,_=0;t.level>0?(t.strm.data_type===o&&(t.strm.data_type=function(t){var e,a=4093624447;for(e=0;e<=31;e++,a>>>=1)if(1&a&&0!==t.dyn_ltree[2*e])return r;if(0!==t.dyn_ltree[18]||0!==t.dyn_ltree[20]||0!==t.dyn_ltree[26])return s;for(e=32;e<u;e++)if(0!==t.dyn_ltree[2*e])return s;return r}(t)),Q(t,t.l_desc),Q(t,t.d_desc),_=function(t){var e;for(V(t,t.dyn_ltree,t.l_desc.max_code),V(t,t.dyn_dtree,t.d_desc.max_code),Q(t,t.bl_desc),e=g-1;e>=3&&0===t.bl_tree[2*A[e]+1];e--);return t.opt_len+=3*(e+1)+5+5+4,e}(t),l=t.opt_len+3+7>>>3,(h=t.static_len+3+7>>>3)<=l&&(l=h)):l=h=a+5,a+4<=l&&-1!==e?et(t,e,a,i):t.strategy===n||h===l?(K(t,(d<<1)+(i?1:0),3),J(t,Z,R)):(K(t,(f<<1)+(i?1:0),3),function(t,e,a,i){var n;for(K(t,e-257,5),K(t,a-1,5),K(t,i-4,4),n=0;n<i;n++)K(t,t.bl_tree[2*A[n]+1],3);$(t,t.dyn_ltree,e-1),$(t,t.dyn_dtree,a-1)}(t,t.l_desc.max_code+1,t.d_desc.max_code+1,_+1),J(t,t.dyn_ltree,t.dyn_dtree)),q(t),i&&G(t)},a._tr_tally=function(t,e,a){return t.pending_buf[t.d_buf+2*t.last_lit]=e>>>8&255,t.pending_buf[t.d_buf+2*t.last_lit+1]=255&e,t.pending_buf[t.l_buf+t.last_lit]=255&a,t.last_lit++,0===e?t.dyn_ltree[2*a]++:(t.matches++,e--,t.dyn_ltree[2*(N[a]+u+1)]++,t.dyn_dtree[2*H(e)]++),t.last_lit===t.lit_bufsize-1},a._tr_align=function(t){K(t,d<<1,3),M(t,k,Z),function(t){16===t.bi_valid?(j(t,t.bi_buf),t.bi_buf=0,t.bi_valid=0):t.bi_valid>=8&&(t.pending_buf[t.pending++]=255&t.bi_buf,t.bi_buf>>=8,t.bi_valid-=8)}(t)}},{"../utils/common":3}],15:[function(t,e,a){"use strict";e.exports=function(){this.input=null,this.next_in=0,this.avail_in=0,this.total_in=0,this.output=null,this.next_out=0,this.avail_out=0,this.total_out=0,this.msg="",this.state=null,this.data_type=2,this.adler=0}},{}],"/":[function(t,e,a){"use strict";var i={};(0,t("./lib/utils/common").assign)(i,t("./lib/deflate"),t("./lib/inflate"),t("./lib/zlib/constants")),e.exports=i},{"./lib/deflate":1,"./lib/inflate":2,"./lib/utils/common":3,"./lib/zlib/constants":6}]},{},[])("/")});
+</script>
+<script>
+    !function(){var e={};"object"==typeof module?module.exports=e:window.UPNG=e,function(e,r){e.toRGBA8=function(r){var t=r.width,n=r.height;if(null==r.tabs.acTL)return[e.toRGBA8.decodeImage(r.data,t,n,r).buffer];var i=[];null==r.frames[0].data&&(r.frames[0].data=r.data);for(var a,f=new Uint8Array(t*n*4),o=0;o<r.frames.length;o++){var s=r.frames[o],l=s.rect.x,c=s.rect.y,u=s.rect.width,d=s.rect.height,h=e.toRGBA8.decodeImage(s.data,u,d,r);if(0==o?a=h:0==s.blend?e._copyTile(h,u,d,a,t,n,l,c,0):1==s.blend&&e._copyTile(h,u,d,a,t,n,l,c,1),i.push(a.buffer),a=a.slice(0),0==s.dispose);else if(1==s.dispose)e._copyTile(f,u,d,a,t,n,l,c,0);else if(2==s.dispose){for(var v=o-1;2==r.frames[v].dispose;)v--;a=new Uint8Array(i[v]).slice(0)}}return i},e.toRGBA8.decodeImage=function(r,t,n,i){var a=t*n,f=e.decode._getBPP(i),o=Math.ceil(t*f/8),s=new Uint8Array(4*a),l=new Uint32Array(s.buffer),c=i.ctype,u=i.depth,d=e._bin.readUshort;if(6==c){var h=a<<2;if(8==u)for(var v=0;v<h;v++)s[v]=r[v];if(16==u)for(v=0;v<h;v++)s[v]=r[v<<1]}else if(2==c){var p=i.tabs.tRNS,b=-1,g=-1,m=-1;if(p&&(b=p[0],g=p[1],m=p[2]),8==u)for(v=0;v<a;v++){var y=3*v;s[M=v<<2]=r[y],s[M+1]=r[y+1],s[M+2]=r[y+2],s[M+3]=255,-1!=b&&r[y]==b&&r[y+1]==g&&r[y+2]==m&&(s[M+3]=0)}if(16==u)for(v=0;v<a;v++){y=6*v;s[M=v<<2]=r[y],s[M+1]=r[y+2],s[M+2]=r[y+4],s[M+3]=255,-1!=b&&d(r,y)==b&&d(r,y+2)==g&&d(r,y+4)==m&&(s[M+3]=0)}}else if(3==c){var w=i.tabs.PLTE,A=i.tabs.tRNS,U=A?A.length:0;if(1==u)for(var _=0;_<n;_++){var q=_*o,I=_*t;for(v=0;v<t;v++){var M=I+v<<2,T=3*(z=r[q+(v>>3)]>>7-((7&v)<<0)&1);s[M]=w[T],s[M+1]=w[T+1],s[M+2]=w[T+2],s[M+3]=z<U?A[z]:255}}if(2==u)for(_=0;_<n;_++)for(q=_*o,I=_*t,v=0;v<t;v++){M=I+v<<2,T=3*(z=r[q+(v>>2)]>>6-((3&v)<<1)&3);s[M]=w[T],s[M+1]=w[T+1],s[M+2]=w[T+2],s[M+3]=z<U?A[z]:255}if(4==u)for(_=0;_<n;_++)for(q=_*o,I=_*t,v=0;v<t;v++){M=I+v<<2,T=3*(z=r[q+(v>>1)]>>4-((1&v)<<2)&15);s[M]=w[T],s[M+1]=w[T+1],s[M+2]=w[T+2],s[M+3]=z<U?A[z]:255}if(8==u)for(v=0;v<a;v++){var z;M=v<<2,T=3*(z=r[v]);s[M]=w[T],s[M+1]=w[T+1],s[M+2]=w[T+2],s[M+3]=z<U?A[z]:255}}else if(4==c){if(8==u)for(v=0;v<a;v++){M=v<<2;var R=r[N=v<<1];s[M]=R,s[M+1]=R,s[M+2]=R,s[M+3]=r[N+1]}if(16==u)for(v=0;v<a;v++){var N;M=v<<2,R=r[N=v<<2];s[M]=R,s[M+1]=R,s[M+2]=R,s[M+3]=r[N+2]}}else if(0==c){b=i.tabs.tRNS?i.tabs.tRNS:-1;if(1==u)for(v=0;v<a;v++){var L=(R=255*(r[v>>3]>>7-(7&v)&1))==255*b?0:255;l[v]=L<<24|R<<16|R<<8|R}if(2==u)for(v=0;v<a;v++){L=(R=85*(r[v>>2]>>6-((3&v)<<1)&3))==85*b?0:255;l[v]=L<<24|R<<16|R<<8|R}if(4==u)for(v=0;v<a;v++){L=(R=17*(r[v>>1]>>4-((1&v)<<2)&15))==17*b?0:255;l[v]=L<<24|R<<16|R<<8|R}if(8==u)for(v=0;v<a;v++){L=(R=r[v])==b?0:255;l[v]=L<<24|R<<16|R<<8|R}if(16==u)for(v=0;v<a;v++){R=r[v<<1],L=d(r,v<<1)==b?0:255;l[v]=L<<24|R<<16|R<<8|R}}return s},e.decode=function(r){for(var t,n=new Uint8Array(r),i=8,a=e._bin,f=a.readUshort,o=a.readUint,s={tabs:{},frames:[]},l=new Uint8Array(n.length),c=0,u=0,d=[137,80,78,71,13,10,26,10],h=0;h<8;h++)if(n[h]!=d[h])throw"The input is not a PNG file!";for(;i<n.length;){var v=a.readUint(n,i);i+=4;var p=a.readASCII(n,i,4);if(i+=4,"IHDR"==p)e.decode._IHDR(n,i,s);else if("IDAT"==p){for(h=0;h<v;h++)l[c+h]=n[i+h];c+=v}else if("acTL"==p)s.tabs[p]={num_frames:o(n,i),num_plays:o(n,i+4)},t=new Uint8Array(n.length);else if("fcTL"==p){var b;if(0!=u)(b=s.frames[s.frames.length-1]).data=e.decode._decompress(s,t.slice(0,u),b.rect.width,b.rect.height),u=0;var g={x:o(n,i+12),y:o(n,i+16),width:o(n,i+4),height:o(n,i+8)},m=f(n,i+22);m=f(n,i+20)/(0==m?100:m);var y={rect:g,delay:Math.round(1e3*m),dispose:n[i+24],blend:n[i+25]};s.frames.push(y)}else 
if("fdAT"==p){for(h=0;h<v-4;h++)t[u+h]=n[i+h+4];u+=v-4}else if("pHYs"==p)s.tabs[p]=[a.readUint(n,i),a.readUint(n,i+4),n[i+8]];else if("cHRM"==p){s.tabs[p]=[];for(h=0;h<8;h++)s.tabs[p].push(a.readUint(n,i+4*h))}else if("tEXt"==p){null==s.tabs[p]&&(s.tabs[p]={});var w=a.nextZero(n,i),A=a.readASCII(n,i,w-i),U=a.readASCII(n,w+1,i+v-w-1);s.tabs[p][A]=U}else if("iTXt"==p){null==s.tabs[p]&&(s.tabs[p]={});w=0;var _=i;w=a.nextZero(n,_);A=a.readASCII(n,_,w-_),n[_=w+1],n[_+1];_+=2,w=a.nextZero(n,_);a.readASCII(n,_,w-_);_=w+1,w=a.nextZero(n,_);a.readUTF8(n,_,w-_);_=w+1;U=a.readUTF8(n,_,v-(_-i));s.tabs[p][A]=U}else if("PLTE"==p)s.tabs[p]=a.readBytes(n,i,v);else if("hIST"==p){var q=s.tabs.PLTE.length/3;s.tabs[p]=[];for(h=0;h<q;h++)s.tabs[p].push(f(n,i+2*h))}else if("tRNS"==p)3==s.ctype?s.tabs[p]=a.readBytes(n,i,v):0==s.ctype?s.tabs[p]=f(n,i):2==s.ctype&&(s.tabs[p]=[f(n,i),f(n,i+2),f(n,i+4)]);else if("gAMA"==p)s.tabs[p]=a.readUint(n,i)/1e5;else if("sRGB"==p)s.tabs[p]=n[i];else if("bKGD"==p)0==s.ctype||4==s.ctype?s.tabs[p]=[f(n,i)]:2==s.ctype||6==s.ctype?s.tabs[p]=[f(n,i),f(n,i+2),f(n,i+4)]:3==s.ctype&&(s.tabs[p]=n[i]);else if("IEND"==p)break;i+=v;a.readUint(n,i);i+=4}0!=u&&((b=s.frames[s.frames.length-1]).data=e.decode._decompress(s,t.slice(0,u),b.rect.width,b.rect.height),u=0);return s.data=e.decode._decompress(s,l,s.width,s.height),delete s.compress,delete s.interlace,delete s.filter,s},e.decode._decompress=function(r,t,n,i){return 0==r.compress&&(t=e.decode._inflate(t)),0==r.interlace?t=e.decode._filterZero(t,r,0,n,i):1==r.interlace&&(t=e.decode._readInterlace(t,r)),t},e.decode._inflate=function(e){return r.inflate(e)},e.decode._readInterlace=function(r,t){for(var n=t.width,i=t.height,a=e.decode._getBPP(t),f=a>>3,o=Math.ceil(n*a/8),s=new Uint8Array(i*o),l=0,c=[0,0,4,0,2,0,1],u=[0,4,0,2,0,1,0],d=[8,8,8,4,4,2,2],h=[8,8,4,4,2,2,1],v=0;v<7;){for(var p=d[v],b=h[v],g=0,m=0,y=c[v];y<i;)y+=p,m++;for(var w=u[v];w<n;)w+=b,g++;var A=Math.ceil(g*a/8);e.decode._filterZero(r,t,l,g,m);for(var U=0,_=c[v];_<i;){for(var q=u[v],I=l+U*A<<3;q<n;){var M;if(1==a)M=(M=r[I>>3])>>7-(7&I)&1,s[_*o+(q>>3)]|=M<<7-((3&q)<<0);if(2==a)M=(M=r[I>>3])>>6-(7&I)&3,s[_*o+(q>>2)]|=M<<6-((3&q)<<1);if(4==a)M=(M=r[I>>3])>>4-(7&I)&15,s[_*o+(q>>1)]|=M<<4-((1&q)<<2);if(a>=8)for(var T=_*o+q*f,z=0;z<f;z++)s[T+z]=r[(I>>3)+z];I+=a,q+=b}U++,_+=p}g*m!=0&&(l+=m*(1+A)),v+=1}return s},e.decode._getBPP=function(e){return[1,null,3,1,2,null,4][e.ctype]*e.depth},e.decode._filterZero=function(r,t,n,i,a){var f=e.decode._getBPP(t),o=Math.ceil(i*f/8),s=e.decode._paeth;f=Math.ceil(f/8);for(var l=0;l<a;l++){var c=n+l*o,u=c+l+1,d=r[u-1];if(0==d)for(var h=0;h<o;h++)r[c+h]=r[u+h];else if(1==d){for(h=0;h<f;h++)r[c+h]=r[u+h];for(h=f;h<o;h++)r[c+h]=r[u+h]+r[c+h-f]&255}else if(0==l){for(h=0;h<f;h++)r[c+h]=r[u+h];if(2==d)for(h=f;h<o;h++)r[c+h]=255&r[u+h];if(3==d)for(h=f;h<o;h++)r[c+h]=r[u+h]+(r[c+h-f]>>1)&255;if(4==d)for(h=f;h<o;h++)r[c+h]=r[u+h]+s(r[c+h-f],0,0)&255}else{if(2==d)for(h=0;h<o;h++)r[c+h]=r[u+h]+r[c+h-o]&255;if(3==d){for(h=0;h<f;h++)r[c+h]=r[u+h]+(r[c+h-o]>>1)&255;for(h=f;h<o;h++)r[c+h]=r[u+h]+(r[c+h-o]+r[c+h-f]>>1)&255}if(4==d){for(h=0;h<f;h++)r[c+h]=r[u+h]+s(0,r[c+h-o],0)&255;for(h=f;h<o;h++)r[c+h]=r[u+h]+s(r[c+h-f],r[c+h-o],r[c+h-f-o])&255}}}return r},e.decode._paeth=function(e,r,t){var n=e+r-t,i=Math.abs(n-e),a=Math.abs(n-r),f=Math.abs(n-t);return i<=a&&i<=f?e:a<=f?r:t},e.decode._IHDR=function(r,t,n){var 
i=e._bin;n.width=i.readUint(r,t),t+=4,n.height=i.readUint(r,t),t+=4,n.depth=r[t],t++,n.ctype=r[t],t++,n.compress=r[t],t++,n.filter=r[t],t++,n.interlace=r[t],t++},e._bin={nextZero:function(e,r){for(;0!=e[r];)r++;return r},readUshort:function(e,r){return e[r]<<8|e[r+1]},writeUshort:function(e,r,t){e[r]=t>>8&255,e[r+1]=255&t},readUint:function(e,r){return 16777216*e[r]+(e[r+1]<<16|e[r+2]<<8|e[r+3])},writeUint:function(e,r,t){e[r]=t>>24&255,e[r+1]=t>>16&255,e[r+2]=t>>8&255,e[r+3]=255&t},readASCII:function(e,r,t){for(var n="",i=0;i<t;i++)n+=String.fromCharCode(e[r+i]);return n},writeASCII:function(e,r,t){for(var n=0;n<t.length;n++)e[r+n]=t.charCodeAt(n)},readBytes:function(e,r,t){for(var n=[],i=0;i<t;i++)n.push(e[r+i]);return n},pad:function(e){return e.length<2?"0"+e:e},readUTF8:function(r,t,n){for(var i,a="",f=0;f<n;f++)a+="%"+e._bin.pad(r[t+f].toString(16));try{i=decodeURIComponent(a)}catch(i){return e._bin.readASCII(r,t,n)}return i}},e._copyTile=function(e,r,t,n,i,a,f,o,s){for(var l=Math.min(r,i),c=Math.min(t,a),u=0,d=0,h=0;h<c;h++)for(var v=0;v<l;v++)if(f>=0&&o>=0?(u=h*r+v<<2,d=(o+h)*i+f+v<<2):(u=(-o+h)*r-f+v<<2,d=h*i+v<<2),0==s)n[d]=e[u],n[d+1]=e[u+1],n[d+2]=e[u+2],n[d+3]=e[u+3];else if(1==s){var p=e[u+3]*(1/255),b=e[u]*p,g=e[u+1]*p,m=e[u+2]*p,y=n[d+3]*(1/255),w=n[d]*y,A=n[d+1]*y,U=n[d+2]*y,_=1-p,q=p+y*_,I=0==q?0:1/q;n[d+3]=255*q,n[d+0]=(b+w*_)*I,n[d+1]=(g+A*_)*I,n[d+2]=(m+U*_)*I}else if(2==s){p=e[u+3],b=e[u],g=e[u+1],m=e[u+2],y=n[d+3],w=n[d],A=n[d+1],U=n[d+2];p==y&&b==w&&g==A&&m==U?(n[d]=0,n[d+1]=0,n[d+2]=0,n[d+3]=0):(n[d]=b,n[d+1]=g,n[d+2]=m,n[d+3]=p)}else if(3==s){p=e[u+3],b=e[u],g=e[u+1],m=e[u+2],y=n[d+3],w=n[d],A=n[d+1],U=n[d+2];if(p==y&&b==w&&g==A&&m==U)continue;if(p<220&&y>20)return!1}return!0},e.encode=function(r,t,n,i,a,f){null==i&&(i=0),null==f&&(f=!1);var o=e.encode.compress(r,t,n,i,!1,f);return e.encode.compressPNG(o,-1),e.encode._main(o,t,n,a)},e.encodeLL=function(r,t,n,i,a,f,o){for(var s={ctype:0+(1==i?0:2)+(0==a?0:4),depth:f,frames:[]},l=(i+a)*f,c=l*t,u=0;u<r.length;u++)s.frames.push({rect:{x:0,y:0,width:t,height:n},img:new Uint8Array(r[u]),blend:0,dispose:1,bpp:Math.ceil(l/8),bpl:Math.ceil(c/8)});return e.encode.compressPNG(s,4),e.encode._main(s,t,n,o)},e.encode._main=function(r,t,n,i){var a=e.crc.crc,f=e._bin.writeUint,o=e._bin.writeUshort,s=e._bin.writeASCII,l=8,c=r.frames.length>1,u=!1,d=46+(c?20:0);if(3==r.ctype){for(var h=r.plte.length,v=0;v<h;v++)r.plte[v]>>>24!=255&&(u=!0);d+=8+3*h+4+(u?8+1*h+4:0)}for(var p=0;p<r.frames.length;p++){c&&(d+=38),d+=(q=r.frames[p]).cimg.length+12,0!=p&&(d+=4)}d+=12;var b=new Uint8Array(d),g=[137,80,78,71,13,10,26,10];for(v=0;v<8;v++)b[v]=g[v];if(f(b,l,13),s(b,l+=4,"IHDR"),f(b,l+=4,t),f(b,l+=4,n),b[l+=4]=r.depth,b[++l]=r.ctype,b[++l]=0,b[++l]=0,b[++l]=0,f(b,++l,a(b,l-17,17)),f(b,l+=4,1),s(b,l+=4,"sRGB"),b[l+=4]=1,f(b,++l,a(b,l-5,5)),l+=4,c&&(f(b,l,8),s(b,l+=4,"acTL"),f(b,l+=4,r.frames.length),f(b,l+=4,0),f(b,l+=4,a(b,l-12,12)),l+=4),3==r.ctype){f(b,l,3*(h=r.plte.length)),s(b,l+=4,"PLTE"),l+=4;for(v=0;v<h;v++){var m=3*v,y=r.plte[v],w=255&y,A=y>>>8&255,U=y>>>16&255;b[l+m+0]=w,b[l+m+1]=A,b[l+m+2]=U}if(f(b,l+=3*h,a(b,l-3*h-4,3*h+4)),l+=4,u){f(b,l,h),s(b,l+=4,"tRNS"),l+=4;for(v=0;v<h;v++)b[l+v]=r.plte[v]>>>24&255;f(b,l+=h,a(b,l-h-4,h+4)),l+=4}}var _=0;for(p=0;p<r.frames.length;p++){var q=r.frames[p];c&&(f(b,l,26),s(b,l+=4,"fcTL"),f(b,l+=4,_++),f(b,l+=4,q.rect.width),f(b,l+=4,q.rect.height),f(b,l+=4,q.rect.x),f(b,l+=4,q.rect.y),o(b,l+=4,i[p]),o(b,l+=2,1e3),b[l+=2]=q.dispose,b[++l]=q.blend,f(b,++l,a(b,l-30,30)),l+=4);var 
I=q.cimg;f(b,l,(h=I.length)+(0==p?0:4));var M=l+=4;s(b,l,0==p?"IDAT":"fdAT"),l+=4,0!=p&&(f(b,l,_++),l+=4);for(v=0;v<h;v++)b[l+v]=I[v];f(b,l+=h,a(b,M,l-M)),l+=4}return f(b,l,0),s(b,l+=4,"IEND"),f(b,l+=4,a(b,l-4,4)),l+=4,b.buffer},e.encode.compressPNG=function(r,t){for(var n=0;n<r.frames.length;n++){var i=r.frames[n],a=(i.rect.width,i.rect.height),f=new Uint8Array(a*i.bpl+a);i.cimg=e.encode._filterZero(i.img,a,i.bpp,i.bpl,f,t)}},e.encode.compress=function(r,t,n,i,a,f){null==f&&(f=!1);for(var o=6,s=8,l=255,c=0;c<r.length;c++)for(var u=new Uint8Array(r[c]),d=u.length,h=0;h<d;h+=4)l&=u[h+3];var v=255!=l,p=v&&a,b=e.encode.framize(r,t,n,a,p),g={},m=[],y=[];if(0!=i){var w=[];for(h=0;h<b.length;h++)w.push(b[h].img.buffer);var A=e.encode.concatRGBA(w,a),U=e.quantize(A,i),_=0,q=new Uint8Array(U.abuf);for(h=0;h<b.length;h++){var I=(F=b[h].img).length;y.push(new Uint8Array(U.inds.buffer,_>>2,I>>2));for(c=0;c<I;c+=4)F[c]=q[_+c],F[c+1]=q[_+c+1],F[c+2]=q[_+c+2],F[c+3]=q[_+c+3];_+=I}for(h=0;h<U.plte.length;h++)m.push(U.plte[h].est.rgba)}else for(c=0;c<b.length;c++){var M=b[c],T=new Uint32Array(M.img.buffer),z=M.rect.width,R=(d=T.length,new Uint8Array(d));y.push(R);for(h=0;h<d;h++){var N=T[h];if(0!=h&&N==T[h-1])R[h]=R[h-1];else if(h>z&&N==T[h-z])R[h]=R[h-z];else{var L=g[N];if(null==L&&(g[N]=L=m.length,m.push(N),m.length>=300))break;R[h]=L}}}var P=m.length;P<=256&&0==f&&(s=P<=2?1:P<=4?2:P<=16?4:8,a&&(s=8));for(c=0;c<b.length;c++){(M=b[c]).rect.x,M.rect.y,z=M.rect.width;var S=M.rect.height,D=M.img,B=(new Uint32Array(D.buffer),4*z),x=4;if(P<=256&&0==f){B=Math.ceil(s*z/8);for(var C=new Uint8Array(B*S),G=y[c],Z=0;Z<S;Z++){h=Z*B;var k=Z*z;if(8==s)for(var E=0;E<z;E++)C[h+E]=G[k+E];else if(4==s)for(E=0;E<z;E++)C[h+(E>>1)]|=G[k+E]<<4-4*(1&E);else if(2==s)for(E=0;E<z;E++)C[h+(E>>2)]|=G[k+E]<<6-2*(3&E);else if(1==s)for(E=0;E<z;E++)C[h+(E>>3)]|=G[k+E]<<7-1*(7&E)}D=C,o=3,x=1}else if(0==v&&1==b.length){C=new Uint8Array(z*S*3);var H=z*S;for(h=0;h<H;h++){var F,K=4*h;C[F=3*h]=D[K],C[F+1]=D[K+1],C[F+2]=D[K+2]}D=C,o=2,x=3,B=3*z}M.img=D,M.bpl=B,M.bpp=x}return{ctype:o,depth:s,plte:m,frames:b}},e.encode.framize=function(r,t,n,i,a){for(var f=[],o=0;o<r.length;o++){var s=new Uint8Array(r[o]),l=new Uint32Array(s.buffer),c=0,u=0,d=t,h=n,v=0;if(0==o||a)s=s.slice(0);else{for(var p=i||1==o||2==f[f.length-2].dispose?1:2,b=0,g=1e9,m=0;m<p;m++){for(var y=new Uint8Array(r[o-1-m]),w=new Uint32Array(r[o-1-m]),A=t,U=n,_=-1,q=-1,I=0;I<n;I++)for(var M=0;M<t;M++){var T=I*t+M;l[T]!=w[T]&&(M<A&&(A=M),M>_&&(_=M),I<U&&(U=I),I>q&&(q=I))}var z=-1==_?1:(_-A+1)*(q-U+1);z<g&&(g=z,b=m,-1==_?(c=u=0,d=h=1):(c=A,u=U,d=_-A+1,h=q-U+1))}y=new Uint8Array(r[o-1-b]);1==b&&(f[f.length-1].dispose=2);var R=new Uint8Array(d*h*4);new Uint32Array(R.buffer);e._copyTile(y,t,n,R,d,h,-c,-u,0),e._copyTile(s,t,n,R,d,h,-c,-u,3)?(e._copyTile(s,t,n,R,d,h,-c,-u,2),v=1):(e._copyTile(s,t,n,R,d,h,-c,-u,0),v=0),s=R}f.push({rect:{x:c,y:u,width:d,height:h},img:s,blend:v,dispose:a?1:0})}return f},e.encode._filterZero=function(t,n,i,a,f,o){if(-1!=o){for(var s=0;s<n;s++)e.encode._filterLine(f,t,s,a,i,o);return r.deflate(f)}for(var l=[],c=0;c<5;c++)if(!(n*a>5e5)||2!=c&&3!=c&&4!=c){for(s=0;s<n;s++)e.encode._filterLine(f,t,s,a,i,c);if(l.push(r.deflate(f)),1==i)break}for(var u,d=1e9,h=0;h<l.length;h++)l[h].length<d&&(u=h,d=l[h].length);return l[u]},e.encode._filterLine=function(r,t,n,i,a,f){var o=n*i,s=o+n,l=e.decode._paeth;if(r[s]=f,s++,0==f)for(var c=0;c<i;c++)r[s+c]=t[o+c];else if(1==f){for(c=0;c<a;c++)r[s+c]=t[o+c];for(c=a;c<i;c++)r[s+c]=t[o+c]-t[o+c-a]+256&255}else 
if(0==n){for(c=0;c<a;c++)r[s+c]=t[o+c];if(2==f)for(c=a;c<i;c++)r[s+c]=t[o+c];if(3==f)for(c=a;c<i;c++)r[s+c]=t[o+c]-(t[o+c-a]>>1)+256&255;if(4==f)for(c=a;c<i;c++)r[s+c]=t[o+c]-l(t[o+c-a],0,0)+256&255}else{if(2==f)for(c=0;c<i;c++)r[s+c]=t[o+c]+256-t[o+c-i]&255;if(3==f){for(c=0;c<a;c++)r[s+c]=t[o+c]+256-(t[o+c-i]>>1)&255;for(c=a;c<i;c++)r[s+c]=t[o+c]+256-(t[o+c-i]+t[o+c-a]>>1)&255}if(4==f){for(c=0;c<a;c++)r[s+c]=t[o+c]+256-l(0,t[o+c-i],0)&255;for(c=a;c<i;c++)r[s+c]=t[o+c]+256-l(t[o+c-a],t[o+c-i],t[o+c-a-i])&255}}},e.crc={table:function(){for(var e=new Uint32Array(256),r=0;r<256;r++){for(var t=r,n=0;n<8;n++)1&t?t=3988292384^t>>>1:t>>>=1;e[r]=t}return e}(),update:function(r,t,n,i){for(var a=0;a<i;a++)r=e.crc.table[255&(r^t[n+a])]^r>>>8;return r},crc:function(r,t,n){return 4294967295^e.crc.update(4294967295,r,t,n)}},e.quantize=function(r,t){for(var n=new Uint8Array(r),i=n.slice(0),a=new Uint32Array(i.buffer),f=e.quantize.getKDtree(i,t),o=f[0],s=f[1],l=(e.quantize.planeDst,n),c=a,u=l.length,d=new Uint8Array(n.length>>2),h=0;h<u;h+=4){var v=l[h]*(1/255),p=l[h+1]*(1/255),b=l[h+2]*(1/255),g=l[h+3]*(1/255),m=e.quantize.getNearest(o,v,p,b,g);d[h>>2]=m.ind,c[h>>2]=m.est.rgba}return{abuf:i.buffer,inds:d,plte:s}},e.quantize.getKDtree=function(r,t,n){null==n&&(n=1e-4);var i=new Uint32Array(r.buffer),a={i0:0,i1:r.length,bst:null,est:null,tdst:0,left:null,right:null};a.bst=e.quantize.stats(r,a.i0,a.i1),a.est=e.quantize.estats(a.bst);for(var f=[a];f.length<t;){for(var o=0,s=0,l=0;l<f.length;l++)f[l].est.L>o&&(o=f[l].est.L,s=l);if(o<n)break;var c=f[s],u=e.quantize.splitPixels(r,i,c.i0,c.i1,c.est.e,c.est.eMq255);if(c.i0>=u||c.i1<=u)c.est.L=0;else{var d={i0:c.i0,i1:u,bst:null,est:null,tdst:0,left:null,right:null};d.bst=e.quantize.stats(r,d.i0,d.i1),d.est=e.quantize.estats(d.bst);var h={i0:u,i1:c.i1,bst:null,est:null,tdst:0,left:null,right:null};h.bst={R:[],m:[],N:c.bst.N-d.bst.N};for(l=0;l<16;l++)h.bst.R[l]=c.bst.R[l]-d.bst.R[l];for(l=0;l<4;l++)h.bst.m[l]=c.bst.m[l]-d.bst.m[l];h.est=e.quantize.estats(h.bst),c.left=d,c.right=h,f[s]=d,f.push(h)}}f.sort(function(e,r){return r.bst.N-e.bst.N});for(l=0;l<f.length;l++)f[l].ind=l;return[a,f]},e.quantize.getNearest=function(r,t,n,i,a){if(null==r.left)return r.tdst=e.quantize.dist(r.est.q,t,n,i,a),r;var f=e.quantize.planeDst(r.est,t,n,i,a),o=r.left,s=r.right;f>0&&(o=r.right,s=r.left);var l=e.quantize.getNearest(o,t,n,i,a);if(l.tdst<=f*f)return l;var c=e.quantize.getNearest(s,t,n,i,a);return c.tdst<l.tdst?c:l},e.quantize.planeDst=function(e,r,t,n,i){var a=e.e;return a[0]*r+a[1]*t+a[2]*n+a[3]*i-e.eMq},e.quantize.dist=function(e,r,t,n,i){var a=r-e[0],f=t-e[1],o=n-e[2],s=i-e[3];return a*a+f*f+o*o+s*s},e.quantize.splitPixels=function(r,t,n,i,a,f){var o=e.quantize.vecDot;i-=4;for(;n<i;){for(;o(r,n,a)<=f;)n+=4;for(;o(r,i,a)>f;)i-=4;if(n>=i)break;var s=t[n>>2];t[n>>2]=t[i>>2],t[i>>2]=s,n+=4,i-=4}for(;o(r,n,a)>f;)n-=4;return n+4},e.quantize.vecDot=function(e,r,t){return e[r]*t[0]+e[r+1]*t[1]+e[r+2]*t[2]+e[r+3]*t[3]},e.quantize.stats=function(e,r,t){for(var n=[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],i=[0,0,0,0],a=t-r>>2,f=r;f<t;f+=4){var o=e[f]*(1/255),s=e[f+1]*(1/255),l=e[f+2]*(1/255),c=e[f+3]*(1/255);i[0]+=o,i[1]+=s,i[2]+=l,i[3]+=c,n[0]+=o*o,n[1]+=o*s,n[2]+=o*l,n[3]+=o*c,n[5]+=s*s,n[6]+=s*l,n[7]+=s*c,n[10]+=l*l,n[11]+=l*c,n[15]+=c*c}return n[4]=n[1],n[8]=n[2],n[9]=n[6],n[12]=n[3],n[13]=n[7],n[14]=n[11],{R:n,m:i,N:a}},e.quantize.estats=function(r){var 
t=r.R,n=r.m,i=r.N,a=n[0],f=n[1],o=n[2],s=n[3],l=0==i?0:1/i,c=[t[0]-a*a*l,t[1]-a*f*l,t[2]-a*o*l,t[3]-a*s*l,t[4]-f*a*l,t[5]-f*f*l,t[6]-f*o*l,t[7]-f*s*l,t[8]-o*a*l,t[9]-o*f*l,t[10]-o*o*l,t[11]-o*s*l,t[12]-s*a*l,t[13]-s*f*l,t[14]-s*o*l,t[15]-s*s*l],u=c,d=e.M4,h=[.5,.5,.5,.5],v=0,p=0;if(0!=i)for(var b=0;b<10&&(h=d.multVec(u,h),p=Math.sqrt(d.dot(h,h)),h=d.sml(1/p,h),!(Math.abs(p-v)<1e-9));b++)v=p;var g=[a*l,f*l,o*l,s*l];return{Cov:c,q:g,e:h,L:v,eMq255:d.dot(d.sml(255,g),h),eMq:d.dot(h,g),rgba:(Math.round(255*g[3])<<24|Math.round(255*g[2])<<16|Math.round(255*g[1])<<8|Math.round(255*g[0])<<0)>>>0}},e.M4={multVec:function(e,r){return[e[0]*r[0]+e[1]*r[1]+e[2]*r[2]+e[3]*r[3],e[4]*r[0]+e[5]*r[1]+e[6]*r[2]+e[7]*r[3],e[8]*r[0]+e[9]*r[1]+e[10]*r[2]+e[11]*r[3],e[12]*r[0]+e[13]*r[1]+e[14]*r[2]+e[15]*r[3]]},dot:function(e,r){return e[0]*r[0]+e[1]*r[1]+e[2]*r[2]+e[3]*r[3]},sml:function(e,r){return[e*r[0],e*r[1],e*r[2],e*r[3]]}},e.encode.concatRGBA=function(e,r){for(var t=0,n=0;n<e.length;n++)t+=e[n].byteLength;var i=new Uint8Array(t),a=0;for(n=0;n<e.length;n++){for(var f=new Uint8Array(e[n]),o=f.length,s=0;s<o;s+=4){var l=f[s],c=f[s+1],u=f[s+2],d=f[s+3];r&&(d=0==(128&d)?0:255),0==d&&(l=c=u=0),i[a+s]=l,i[a+s+1]=c,i[a+s+2]=u,i[a+s+3]=d}a+=o}return i.buffer}}(e,"function"==typeof require?require("pako"):window.pako)}();
+</script>
+
+<script>
+    class Player {
+
+        constructor(container) {
+            this.container = document.getElementById(container)
+            this.global_frac = 0.0
+            this.progress = null;
+            this.mat = [[]]
+
+            this.player = this.container.querySelector('audio')
+            this.demo_img = this.container.querySelector('.underlay > img')
+            this.overlay = this.container.querySelector('.overlay')
+            this.playpause = this.container.querySelector(".playpause");
+            this.download = this.container.querySelector(".download");
+            this.play_img = this.container.querySelector('.play-img')
+            this.pause_img = this.container.querySelector('.pause-img')
+            this.canvas = this.container.querySelector('.response-canvas')
+            this.response_container = this.container.querySelector('.response')
+            this.context = this.canvas.getContext('2d');
+
+            // Toggle playback; ignore clicks until the audio element has a
+            // usable source (networkState === NETWORK_IDLE).
+            var togglePlayPause = () => {
+                if (this.player.networkState !== 1) {
+                    return
+                }
+                if (this.player.paused || this.player.ended) {
+                    this.play()
+                } else {
+                    this.pause()
+                }
+            }
+
+            this.update = () => {
+                this.global_frac = this.player.currentTime / this.player.duration
+                this.overlay.style.width = (100*(1.0 - this.global_frac)).toString() + '%'
+                this.redraw()
+            }
+
+            // Redraw on every animation frame while the audio is playing.
+            this.updateLoop = (timestamp) => {
+                this.update()
+                this.progress = window.requestAnimationFrame(this.updateLoop)
+            }
+
+            this.seek = (e) => {
+                this.global_frac = e.offsetX / this.demo_img.width
+                this.player.currentTime = this.global_frac * this.player.duration
+                this.overlay.style.width = (100*(1.0 - this.global_frac)).toString() + '%'
+                this.redraw()
+            }
+
+            var download_audio = () => {
+                var url = this.player.querySelector('#src').src
+                const a = document.createElement('a')
+                a.href = url
+                a.download = "download"
+                document.body.appendChild(a)
+                a.click()
+                document.body.removeChild(a)
+            }
+
+            this.demo_img.onclick = this.seek;
+            this.playpause.disabled = true
+            this.player.onplay = this.updateLoop
+            this.player.onpause = () => {
+                window.cancelAnimationFrame(this.progress)
+                this.update();
+            }
+            this.player.onended = () => {this.pause()}
+            this.playpause.onclick = togglePlayPause;
+            this.download.onclick = download_audio;
+        }
+
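+        // Point the player at a new audio file, spectrogram image, and
+        // "levels" PNG; playback is re-enabled once the levels are parsed.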
+        load(audio_fname, img_fname, levels_fname) {
+            this.pause()
+            window.cancelAnimationFrame(this.progress)
+            this.playpause.disabled = true
+
+            this.player.querySelector('#src').setAttribute("src", audio_fname)
+            this.player.load()
+            this.demo_img.setAttribute("src", img_fname)
+            this.overlay.style.width = '0%'
+
+            fetch(levels_fname)
+            .then(response => response.arrayBuffer())
+            .then(text => {
+                this.mat = this.parse(text);
+                this.playpause.disabled = false;
+                this.redraw();
+            })
+        }
+
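+        // Decode the "levels" PNG into a (time x bars) matrix of normalized
+        // RGB magnitudes, smooth it across time, and linearly upsample it so
+        // the bar animation stays smooth between frames.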
+        parse(buffer) {
+            var img = UPNG.decode(buffer)
+            var dat = UPNG.toRGBA8(img)[0]
+            var view = new DataView(dat)
+            var data = new Array(img.width).fill(0).map(() => new Array(img.height).fill(0));
+
+            var min =100
+            var max = -100
+            var idx = 0
+            for (let i=0; i < img.height*img.width*4; i+=4) {
+                var rgba = [view.getUint8(i, 1) / 255, view.getUint8(i + 1, 1) / 255, view.getUint8(i + 2, 1) / 255, view.getUint8(i + 3, 1) / 255]
+                var norm = Math.pow(Math.pow(rgba[0], 2) + Math.pow(rgba[1], 2) + Math.pow(rgba[2], 2), 0.5)
+                data[idx % img.width][img.height - Math.floor(idx / img.width) - 1] = norm
+
+                idx += 1
+                min = Math.min(min, norm)
+                max = Math.max(max, norm)
+            }
+            for (let i = 0; i < data.length; i++) {
+                for (let j = 0; j < data[i].length; j++) {
+                    data[i][j] = Math.pow((data[i][j] - min) / (max - min), 1.5)
+                }
+            }
+            var data3 = new Array(img.width).fill(0).map(() => new Array(img.height).fill(0));
+            for (let i = 0; i < data.length; i++) {
+                for (let j = 0; j < data[i].length; j++) {
+                    if (i == 0 || i == (data.length - 1)) {
+                        data3[i][j] = data[i][j]
+                    } else{
+                        data3[i][j] = 0.33*(data[i - 1][j]) + 0.33*(data[i][j]) + 0.33*(data[i + 1][j])
+                    }
+                }
+            }
+
+            var scale = 5
+            var data2 = new Array(scale*img.width).fill(0).map(() => new Array(img.height).fill(0));
+            for (let j = 0; j < data[0].length; j++) {
+                for (let i = 0; i < data.length - 1; i++) {
+                    for (let k = 0; k < scale; k++) {
+                        data2[scale*i + k][j] = (1.0 - (k/scale))*data3[i][j] + (k / scale)*data3[i + 1][j]
+                    }
+                }
+            }
+            return data2
+        }
+
+        play() {
+            this.player.play();
+            this.play_img.style.display = 'none'
+            this.pause_img.style.display = 'block'
+        }
+
+        pause() {
+            this.player.pause();
+            this.pause_img.style.display = 'none'
+            this.play_img.style.display = 'block'
+        }
+
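+        // Draw the level bars for the current playback position, scaling the
+        // canvas by devicePixelRatio so it stays crisp on high-DPI displays.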
+        redraw() {
+            this.canvas.width = window.devicePixelRatio*this.response_container.offsetWidth;
+            this.canvas.height = window.devicePixelRatio*this.response_container.offsetHeight;
+
+            this.context.clearRect(0, 0, this.canvas.width, this.canvas.height)
+            this.canvas.style.width = (this.canvas.width / window.devicePixelRatio).toString() + "px";
+            this.canvas.style.height = (this.canvas.height / window.devicePixelRatio).toString() + "px";
+
+            var f = this.global_frac*this.mat.length
+            var tstep = Math.min(Math.floor(f), this.mat.length - 2)
+            var heights = this.mat[tstep]
+            var bar_width = (this.canvas.width / heights.length) - 1
+
+            for (let k = 0; k < heights.length - 1; k++) {
+                var height = Math.max(Math.round((heights[k])*this.canvas.height), 3)
+                this.context.fillStyle = '#696f7b';
+                this.context.fillRect(k*(bar_width + 1), (this.canvas.height - height) / 2, bar_width, height);
+            }
+        }
+    }
+</script>
diff --git a/audiotools/core/templates/pandoc.css b/audiotools/core/templates/pandoc.css
new file mode 100644
index 0000000000000000000000000000000000000000..842be7be6d65580dab44c6a8013259644f38e6ee
--- /dev/null
+++ b/audiotools/core/templates/pandoc.css
@@ -0,0 +1,407 @@
+/*
+Copyright (c) 2017 Chris Patuzzo
+https://twitter.com/chrispatuzzo
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
+*/
+
+body {
+  font-family: Helvetica, arial, sans-serif;
+  font-size: 14px;
+  line-height: 1.6;
+  padding-top: 10px;
+  padding-bottom: 10px;
+  background-color: white;
+  padding: 30px;
+  color: #333;
+}
+
+body > *:first-child {
+  margin-top: 0 !important;
+}
+
+body > *:last-child {
+  margin-bottom: 0 !important;
+}
+
+a {
+  color: #4183C4;
+  text-decoration: none;
+}
+
+a.absent {
+  color: #cc0000;
+}
+
+a.anchor {
+  display: block;
+  padding-left: 30px;
+  margin-left: -30px;
+  cursor: pointer;
+  position: absolute;
+  top: 0;
+  left: 0;
+  bottom: 0;
+}
+
+h1, h2, h3, h4, h5, h6 {
+  margin: 20px 0 10px;
+  padding: 0;
+  font-weight: bold;
+  -webkit-font-smoothing: antialiased;
+  cursor: text;
+  position: relative;
+}
+
+h2:first-child, h1:first-child, h1:first-child + h2, h3:first-child, h4:first-child, h5:first-child, h6:first-child {
+  margin-top: 0;
+  padding-top: 0;
+}
+
+h1:hover a.anchor, h2:hover a.anchor, h3:hover a.anchor, h4:hover a.anchor, h5:hover a.anchor, h6:hover a.anchor {
+  text-decoration: none;
+}
+
+h1 tt, h1 code {
+  font-size: inherit;
+}
+
+h2 tt, h2 code {
+  font-size: inherit;
+}
+
+h3 tt, h3 code {
+  font-size: inherit;
+}
+
+h4 tt, h4 code {
+  font-size: inherit;
+}
+
+h5 tt, h5 code {
+  font-size: inherit;
+}
+
+h6 tt, h6 code {
+  font-size: inherit;
+}
+
+h1 {
+  font-size: 28px;
+  color: black;
+}
+
+h2 {
+  font-size: 24px;
+  border-bottom: 1px solid #cccccc;
+  color: black;
+}
+
+h3 {
+  font-size: 18px;
+}
+
+h4 {
+  font-size: 16px;
+}
+
+h5 {
+  font-size: 14px;
+}
+
+h6 {
+  color: #777777;
+  font-size: 14px;
+}
+
+p, blockquote, ul, ol, dl, li, table, pre {
+  margin: 15px 0;
+}
+
+hr {
+  border: 0 none;
+  color: #cccccc;
+  height: 4px;
+  padding: 0;
+}
+
+body > h2:first-child {
+  margin-top: 0;
+  padding-top: 0;
+}
+
+body > h1:first-child {
+  margin-top: 0;
+  padding-top: 0;
+}
+
+body > h1:first-child + h2 {
+  margin-top: 0;
+  padding-top: 0;
+}
+
+body > h3:first-child, body > h4:first-child, body > h5:first-child, body > h6:first-child {
+  margin-top: 0;
+  padding-top: 0;
+}
+
+a:first-child h1, a:first-child h2, a:first-child h3, a:first-child h4, a:first-child h5, a:first-child h6 {
+  margin-top: 0;
+  padding-top: 0;
+}
+
+h1 p, h2 p, h3 p, h4 p, h5 p, h6 p {
+  margin-top: 0;
+}
+
+li p.first {
+  display: inline-block;
+}
+
+ul, ol {
+  padding-left: 30px;
+}
+
+ul :first-child, ol :first-child {
+  margin-top: 0;
+}
+
+ul :last-child, ol :last-child {
+  margin-bottom: 0;
+}
+
+dl {
+  padding: 0;
+}
+
+dl dt {
+  font-size: 14px;
+  font-weight: bold;
+  font-style: italic;
+  padding: 0;
+  margin: 15px 0 5px;
+}
+
+dl dt:first-child {
+  padding: 0;
+}
+
+dl dt > :first-child {
+  margin-top: 0;
+}
+
+dl dt > :last-child {
+  margin-bottom: 0;
+}
+
+dl dd {
+  margin: 0 0 15px;
+  padding: 0 15px;
+}
+
+dl dd > :first-child {
+  margin-top: 0;
+}
+
+dl dd > :last-child {
+  margin-bottom: 0;
+}
+
+blockquote {
+  border-left: 4px solid #dddddd;
+  padding: 0 15px;
+  color: #777777;
+}
+
+blockquote > :first-child {
+  margin-top: 0;
+}
+
+blockquote > :last-child {
+  margin-bottom: 0;
+}
+
+table {
+  padding: 0;
+}
+table tr {
+  border-top: 1px solid #cccccc;
+  background-color: white;
+  margin: 0;
+  padding: 0;
+}
+
+table tr:nth-child(2n) {
+  background-color: #f8f8f8;
+}
+
+table tr th {
+  font-weight: bold;
+  border: 1px solid #cccccc;
+  text-align: left;
+  margin: 0;
+  padding: 6px 13px;
+}
+
+table tr td {
+  border: 1px solid #cccccc;
+  text-align: left;
+  margin: 0;
+  padding: 6px 13px;
+}
+
+table tr th :first-child, table tr td :first-child {
+  margin-top: 0;
+}
+
+table tr th :last-child, table tr td :last-child {
+  margin-bottom: 0;
+}
+
+img {
+  max-width: 100%;
+}
+
+span.frame {
+  display: block;
+  overflow: hidden;
+}
+
+span.frame > span {
+  border: 1px solid #dddddd;
+  display: block;
+  float: left;
+  overflow: hidden;
+  margin: 13px 0 0;
+  padding: 7px;
+  width: auto;
+}
+
+span.frame span img {
+  display: block;
+  float: left;
+}
+
+span.frame span span {
+  clear: both;
+  color: #333333;
+  display: block;
+  padding: 5px 0 0;
+}
+
+span.align-center {
+  display: block;
+  overflow: hidden;
+  clear: both;
+}
+
+span.align-center > span {
+  display: block;
+  overflow: hidden;
+  margin: 13px auto 0;
+  text-align: center;
+}
+
+span.align-center span img {
+  margin: 0 auto;
+  text-align: center;
+}
+
+span.align-right {
+  display: block;
+  overflow: hidden;
+  clear: both;
+}
+
+span.align-right > span {
+  display: block;
+  overflow: hidden;
+  margin: 13px 0 0;
+  text-align: right;
+}
+
+span.align-right span img {
+  margin: 0;
+  text-align: right;
+}
+
+span.float-left {
+  display: block;
+  margin-right: 13px;
+  overflow: hidden;
+  float: left;
+}
+
+span.float-left span {
+  margin: 13px 0 0;
+}
+
+span.float-right {
+  display: block;
+  margin-left: 13px;
+  overflow: hidden;
+  float: right;
+}
+
+span.float-right > span {
+  display: block;
+  overflow: hidden;
+  margin: 13px auto 0;
+  text-align: right;
+}
+
+code, tt {
+  margin: 0 2px;
+  padding: 0 5px;
+  white-space: nowrap;
+  border-radius: 3px;
+}
+
+pre code {
+  margin: 0;
+  padding: 0;
+  white-space: pre;
+  border: none;
+  background: transparent;
+}
+
+.highlight pre {
+  font-size: 13px;
+  line-height: 19px;
+  overflow: auto;
+  padding: 6px 10px;
+  border-radius: 3px;
+}
+
+pre {
+  font-size: 13px;
+  line-height: 19px;
+  overflow: auto;
+  padding: 6px 10px;
+  border-radius: 3px;
+}
+
+pre code, pre tt {
+  background-color: transparent;
+  border: none;
+}
+
+body {
+  max-width: 600px;
+}
diff --git a/audiotools/core/templates/widget.html b/audiotools/core/templates/widget.html
new file mode 100644
index 0000000000000000000000000000000000000000..0b44e8aec64fd1db929da5fa6208dee00247c967
--- /dev/null
+++ b/audiotools/core/templates/widget.html
@@ -0,0 +1,52 @@
+<div id='PLAYER_ID' class='player' style="max-width: MAX_WIDTH;">
+    <div class='spectrogram' style="padding-top: PADDING_AMOUNT;">
+        <div class='overlay'></div>
+        <div class='underlay'>
+            <img>
+        </div>
+    </div>
+
+    <div class='audio-controls'>
+        <button id="playpause" disabled class='playpause' title="play">
+            <svg class='play-img' width="14px" height="19px" viewBox="0 0 14 19">
+                <polygon id="Triangle" fill="#000000" transform="translate(9, 9.5) rotate(90) translate(-7, -9.5) " points="7 2.5 16.5 16.5 -2.5 16.5"></polygon>
+            </svg>
+            <svg class='pause-img' width="16px" height="19px" viewBox="0 0 16 19">
+                <g fill="#000000" stroke="#000000">
+                    <rect id="Rectangle" x="0.5" y="0.5" width="4" height="18"></rect>
+                    <rect id="Rectangle" x="11.5" y="0.5" width="4" height="18"></rect>
+                </g>
+            </svg>
+        </button>
+
+        <audio class='play'>
+            <source id='src'>
+        </audio>
+        <div class='response'>
+            <canvas class='response-canvas'></canvas>
+        </div>
+
+        <button id="download" class='download' title="download">
+            <svg class='download-img' x="0px" y="0px" viewBox="0 0 29.978 29.978" style="enable-background:new 0 0 29.978 29.978;" xml:space="preserve">
+            <g>
+                <path d="M25.462,19.105v6.848H4.515v-6.848H0.489v8.861c0,1.111,0.9,2.012,2.016,2.012h24.967c1.115,0,2.016-0.9,2.016-2.012
+                    v-8.861H25.462z"/>
+                <path d="M14.62,18.426l-5.764-6.965c0,0-0.877-0.828,0.074-0.828s3.248,0,3.248,0s0-0.557,0-1.416c0-2.449,0-6.906,0-8.723
+                    c0,0-0.129-0.494,0.615-0.494c0.75,0,4.035,0,4.572,0c0.536,0,0.524,0.416,0.524,0.416c0,1.762,0,6.373,0,8.742
+                    c0,0.768,0,1.266,0,1.266s1.842,0,2.998,0c1.154,0,0.285,0.867,0.285,0.867s-4.904,6.51-5.588,7.193
+                    C15.092,18.979,14.62,18.426,14.62,18.426z"/>
+            </g>
+            </svg>
+        </button>
+    </div>
+</div>
+
+<script>
+    var PLAYER_ID = new Player('PLAYER_ID')
+    PLAYER_ID.load(
+        "AUDIO_SRC",
+        "IMAGE_SRC",
+        "LEVELS_SRC"
+    )
+    window.addEventListener("resize", function() {PLAYER_ID.redraw()})
+</script>
diff --git a/audiotools/core/util.py b/audiotools/core/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..ece1344658d10836aa2eb693f275294ad8cdbb52
--- /dev/null
+++ b/audiotools/core/util.py
@@ -0,0 +1,671 @@
+import csv
+import glob
+import math
+import numbers
+import os
+import random
+import typing
+from contextlib import contextmanager
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Dict
+from typing import List
+
+import numpy as np
+import torch
+import torchaudio
+from flatten_dict import flatten
+from flatten_dict import unflatten
+
+
+@dataclass
+class Info:
+    """Shim for torchaudio.info API changes."""
+
+    sample_rate: float
+    num_frames: int
+
+    @property
+    def duration(self) -> float:
+        return self.num_frames / self.sample_rate
+
+
+def info(audio_path: str):
+    """Shim for torchaudio.info to make 0.7.2 API match 0.8.0.
+
+    Parameters
+    ----------
+    audio_path : str
+        Path to audio file.
+    """
+    # try default backend first, then fallback to soundfile
+    try:
+        info = torchaudio.info(str(audio_path))
+    except:  # pragma: no cover
+        info = torchaudio.backend.soundfile_backend.info(str(audio_path))
+
+    if isinstance(info, tuple):  # pragma: no cover
+        signal_info = info[0]
+        info = Info(sample_rate=signal_info.rate, num_frames=signal_info.length)
+    else:
+        info = Info(sample_rate=info.sample_rate, num_frames=info.num_frames)
+
+    return info
+
+
+def ensure_tensor(
+    x: typing.Union[np.ndarray, torch.Tensor, float, int],
+    ndim: int = None,
+    batch_size: int = None,
+):
+    """Ensures that the input ``x`` is a tensor of specified
+    dimensions and batch size.
+
+    Parameters
+    ----------
+    x : typing.Union[np.ndarray, torch.Tensor, float, int]
+        Data that will become a tensor on its way out.
+    ndim : int, optional
+        How many dimensions should be in the output, by default None
+    batch_size : int, optional
+        The batch size of the output, by default None
+
+    Returns
+    -------
+    torch.Tensor
+        Modified version of ``x`` as a tensor.
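+
+    Examples
+    --------
+    A scalar, for example, can be promoted to a tensor with the requested
+    number of dimensions and batch size:
+
+    >>> ensure_tensor(0.5, ndim=3, batch_size=4).shape
+    torch.Size([4, 1, 1])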
+    """
+    if not torch.is_tensor(x):
+        x = torch.as_tensor(x)
+    if ndim is not None:
+        assert x.ndim <= ndim
+        while x.ndim < ndim:
+            x = x.unsqueeze(-1)
+    if batch_size is not None:
+        if x.shape[0] != batch_size:
+            shape = list(x.shape)
+            shape[0] = batch_size
+            x = x.expand(*shape)
+    return x
+
+
+def _get_value(other):
+    from . import AudioSignal
+
+    if isinstance(other, AudioSignal):
+        return other.audio_data
+    return other
+
+
+def hz_to_bin(hz: torch.Tensor, n_fft: int, sample_rate: int):
+    """Closest frequency bin given a frequency, number
+    of bins, and a sampling rate.
+
+    Parameters
+    ----------
+    hz : torch.Tensor
+       Tensor of frequencies in Hz.
+    n_fft : int
+        Number of FFT bins.
+    sample_rate : int
+        Sample rate of audio.
+
+    Returns
+    -------
+    torch.Tensor
+        Closest bins to the data.
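+
+    Examples
+    --------
+    A minimal sketch; the exact bins depend on ``n_fft`` and ``sample_rate``:
+
+    >>> bins = hz_to_bin(torch.tensor([440.0, 1000.0]), n_fft=2048, sample_rate=44100)
+    >>> bins.shape
+    torch.Size([2])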
+    """
+    shape = hz.shape
+    hz = hz.flatten()
+    freqs = torch.linspace(0, sample_rate / 2, 2 + n_fft // 2)
+    hz[hz > sample_rate / 2] = sample_rate / 2
+
+    closest = (hz[None, :] - freqs[:, None]).abs()
+    closest_bins = closest.min(dim=0).indices
+
+    return closest_bins.reshape(*shape)
+
+
+def random_state(seed: typing.Union[int, np.random.RandomState]):
+    """
+    Turn seed into a np.random.RandomState instance.
+
+    Parameters
+    ----------
+    seed : typing.Union[int, np.random.RandomState] or None
+        If seed is None, return the RandomState singleton used by np.random.
+        If seed is an int, return a new RandomState instance seeded with seed.
+        If seed is already a RandomState instance, return it.
+        Otherwise raise ValueError.
+
+    Returns
+    -------
+    np.random.RandomState
+        Random state object.
+
+    Raises
+    ------
+    ValueError
+        If seed is not valid, an error is thrown.
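+
+    Examples
+    --------
+    An integer seed, for instance, returns a dedicated ``RandomState``:
+
+    >>> rng = random_state(0)
+    >>> isinstance(rng, np.random.RandomState)
+    True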
+    """
+    if seed is None or seed is np.random:
+        return np.random.mtrand._rand
+    elif isinstance(seed, (numbers.Integral, np.integer, int)):
+        return np.random.RandomState(seed)
+    elif isinstance(seed, np.random.RandomState):
+        return seed
+    else:
+        raise ValueError(
+            "%r cannot be used to seed a numpy.random.RandomState" " instance" % seed
+        )
+
+
+def seed(random_seed, set_cudnn=False):
+    """
+    Seeds all random states with the same random seed
+    for reproducibility. Seeds ``numpy``, ``random`` and ``torch``
+    random generators.
+    For full reproducibility, two further options must be set
+    according to the torch documentation:
+    https://pytorch.org/docs/stable/notes/randomness.html
+    To do this, ``set_cudnn`` must be True. It defaults to
+    False, since setting it to True results in a performance
+    hit.
+
+    Parameters
+    ----------
+    random_seed : int
+        Integer to seed the random generators with.
+    set_cudnn : bool, optional
+        Whether or not to put cuDNN into deterministic mode and turn off
+        benchmark mode, by default False.
+    """
+
+    torch.manual_seed(random_seed)
+    np.random.seed(random_seed)
+    random.seed(random_seed)
+
+    if set_cudnn:
+        torch.backends.cudnn.deterministic = True
+        torch.backends.cudnn.benchmark = False
+
+
+@contextmanager
+def _close_temp_files(tmpfiles: list):
+    """Utility function for creating a context and closing all temporary files
+    once the context is exited. For correct functionality, all temporary file
+    handles created inside the context must be appended to the ```tmpfiles```
+    list.
+
+    This function is taken wholesale from Scaper.
+
+    Parameters
+    ----------
+    tmpfiles : list
+        List of temporary file handles
+    """
+
+    def _close():
+        for t in tmpfiles:
+            try:
+                t.close()
+                os.unlink(t.name)
+            except:
+                pass
+
+    try:
+        yield
+    except:  # pragma: no cover
+        _close()
+        raise
+    _close()
+
+
+AUDIO_EXTENSIONS = [".wav", ".flac", ".mp3", ".mp4"]
+
+
+def find_audio(folder: str, ext: List[str] = AUDIO_EXTENSIONS):
+    """Finds all audio files in a directory recursively.
+    Returns a list.
+
+    Parameters
+    ----------
+    folder : str
+        Folder to look for audio files in, recursively.
+    ext : List[str], optional
+        Extensions to look for, including the leading ".", by default
+        ``['.wav', '.flac', '.mp3', '.mp4']``.
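+
+    Examples
+    --------
+    A minimal sketch (the folder path here is hypothetical):
+
+    >>> files = find_audio("tests/audio", ext=[".wav"])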
+    """
+    folder = Path(folder)
+    # Take care of case where user has passed in an audio file directly
+    # into one of the calling functions.
+    if str(folder).endswith(tuple(ext)):
+        # if, however, there's a glob in the path, we need to
+        # return the glob, not the file.
+        if "*" in str(folder):
+            return glob.glob(str(folder), recursive=("**" in str(folder)))
+        else:
+            return [folder]
+
+    files = []
+    for x in ext:
+        files += folder.glob(f"**/*{x}")
+    return files
+
+
+def read_sources(
+    sources: List[str],
+    remove_empty: bool = True,
+    relative_path: str = "",
+    ext: List[str] = AUDIO_EXTENSIONS,
+):
+    """Reads audio sources that can either be folders
+    full of audio files, or CSV files that contain paths
+    to audio files. CSV files that adhere to the expected
+    format can be generated by
+    :py:func:`audiotools.data.preprocess.create_csv`.
+
+    Parameters
+    ----------
+    sources : List[str]
+        List of audio sources to be converted into a
+        list of lists of audio files.
+    remove_empty : bool, optional
+        Whether or not to remove rows with an empty "path"
+        from each CSV file, by default True.
+    relative_path : str, optional
+        Path prepended to every audio path that is found, by default "".
+    ext : List[str], optional
+        Extensions to look for when a source is a folder, by default
+        ``['.wav', '.flac', '.mp3', '.mp4']``.
+
+    Returns
+    -------
+    list
+        List of lists of rows of CSV files.
+    """
+    files = []
+    relative_path = Path(relative_path)
+    for source in sources:
+        source = str(source)
+        _files = []
+        if source.endswith(".csv"):
+            with open(source, "r") as f:
+                reader = csv.DictReader(f)
+                for x in reader:
+                    if remove_empty and x["path"] == "":
+                        continue
+                    if x["path"] != "":
+                        x["path"] = str(relative_path / x["path"])
+                    _files.append(x)
+        else:
+            for x in find_audio(source, ext=ext):
+                x = str(relative_path / x)
+                _files.append({"path": x})
+        files.append(sorted(_files, key=lambda x: x["path"]))
+    return files
+
+
+def choose_from_list_of_lists(
+    state: np.random.RandomState, list_of_lists: list, p: float = None
+):
+    """Choose a single item from a list of lists.
+
+    Parameters
+    ----------
+    state : np.random.RandomState
+        Random state to use when choosing an item.
+    list_of_lists : list
+        A list of lists from which items will be drawn.
+    p : list, optional
+        Probability weights for choosing each list (passed to
+        ``np.random.RandomState.choice``), by default None.
+
+    Returns
+    -------
+    tuple
+        A tuple of (chosen item, source list index, item index).
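+
+    Examples
+    --------
+    A minimal sketch of drawing one item:
+
+    >>> state = random_state(0)
+    >>> item, source_idx, item_idx = choose_from_list_of_lists(state, [["a", "b"], ["c"]])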
+    """
+    source_idx = state.choice(list(range(len(list_of_lists))), p=p)
+    item_idx = state.randint(len(list_of_lists[source_idx]))
+    return list_of_lists[source_idx][item_idx], source_idx, item_idx
+
+
+@contextmanager
+def chdir(newdir: typing.Union[Path, str]):
+    """
+    Context manager for switching directories to run a
+    function. Useful for when you want to use relative
+    paths to different runs.
+
+    Parameters
+    ----------
+    newdir : typing.Union[Path, str]
+        Directory to switch to.
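+
+    Examples
+    --------
+    A minimal sketch; the original working directory is restored on exit:
+
+    >>> import tempfile
+    >>> with tempfile.TemporaryDirectory() as tmp:
+    ...     with chdir(tmp):
+    ...         pass  # relative paths now resolve inside ``tmp``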
+    """
+    curdir = os.getcwd()
+    try:
+        os.chdir(newdir)
+        yield
+    finally:
+        os.chdir(curdir)
+
+
+def prepare_batch(batch: typing.Union[dict, list, torch.Tensor], device: str = "cpu"):
+    """Moves items in a batch (typically generated by a DataLoader as a list
+    or a dict) to the specified device. This works even if dictionaries
+    are nested.
+
+    Parameters
+    ----------
+    batch : typing.Union[dict, list, torch.Tensor]
+        Batch, typically generated by a dataloader, that will be moved to
+        the device.
+    device : str, optional
+        Device to move batch to, by default "cpu"
+
+    Returns
+    -------
+    typing.Union[dict, list, torch.Tensor]
+        Batch with all values moved to the specified device.
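+
+    Examples
+    --------
+    A minimal sketch with a nested dictionary of tensors:
+
+    >>> batch = {"audio": {"x": torch.zeros(2, 1, 100)}, "label": torch.tensor([0, 1])}
+    >>> batch = prepare_batch(batch, device="cpu")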
+    """
+    if isinstance(batch, dict):
+        batch = flatten(batch)
+        for key, val in batch.items():
+            try:
+                batch[key] = val.to(device)
+            except:
+                pass
+        batch = unflatten(batch)
+    elif torch.is_tensor(batch):
+        batch = batch.to(device)
+    elif isinstance(batch, list):
+        for i in range(len(batch)):
+            try:
+                batch[i] = batch[i].to(device)
+            except:
+                pass
+    return batch
+
+
+def sample_from_dist(dist_tuple: tuple, state: np.random.RandomState = None):
+    """Samples from a distribution defined by a tuple. The first
+    item in the tuple is the distribution type, and the rest of the
+    items are arguments to that distribution. The distribution function
+    is gotten from the ``np.random.RandomState`` object.
+
+    Parameters
+    ----------
+    dist_tuple : tuple
+        Distribution tuple
+    state : np.random.RandomState, optional
+        Random state, or seed to use, by default None
+
+    Returns
+    -------
+    typing.Union[float, int, str]
+        Draw from the distribution.
+
+    Examples
+    --------
+    Sample from a uniform distribution:
+
+    >>> dist_tuple = ("uniform", 0, 1)
+    >>> sample_from_dist(dist_tuple)
+
+    Sample from a constant distribution:
+
+    >>> dist_tuple = ("const", 0)
+    >>> sample_from_dist(dist_tuple)
+
+    Sample from a normal distribution:
+
+    >>> dist_tuple = ("normal", 0, 0.5)
+    >>> sample_from_dist(dist_tuple)
+
+    """
+    if dist_tuple[0] == "const":
+        return dist_tuple[1]
+    state = random_state(state)
+    dist_fn = getattr(state, dist_tuple[0])
+    return dist_fn(*dist_tuple[1:])
+
+
+def collate(list_of_dicts: list, n_splits: int = None):
+    """Collates a list of dictionaries (e.g. as returned by a
+    dataloader) into a dictionary with batched values. This routine
+    uses the default torch collate function for everything
+    except AudioSignal objects, which are handled by the
+    :py:func:`audiotools.core.audio_signal.AudioSignal.batch`
+    function.
+
+    This function takes n_splits to enable splitting a batch
+    into multiple sub-batches for the purposes of gradient accumulation,
+    etc.
+
+    Parameters
+    ----------
+    list_of_dicts : list
+        List of dictionaries to be collated.
+    n_splits : int, optional
+        Number of splits to make when creating the batches (split into
+        sub-batches). Useful for things like gradient accumulation.
+
+    Returns
+    -------
+    dict
+        Dictionary containing batched data.
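+
+    Examples
+    --------
+    A minimal sketch with plain tensors (lists of AudioSignals are batched
+    with ``AudioSignal.batch`` instead):
+
+    >>> items = [{"x": torch.tensor(1.0)}, {"x": torch.tensor(2.0)}]
+    >>> collate(items)["x"]
+    tensor([1., 2.])
+    >>> len(collate(items, n_splits=2))
+    2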
+    """
+
+    from . import AudioSignal
+
+    batches = []
+    list_len = len(list_of_dicts)
+
+    return_list = False if n_splits is None else True
+    n_splits = 1 if n_splits is None else n_splits
+    n_items = int(math.ceil(list_len / n_splits))
+
+    for i in range(0, list_len, n_items):
+        # Flatten the dictionaries to avoid recursion.
+        list_of_dicts_ = [flatten(d) for d in list_of_dicts[i : i + n_items]]
+        dict_of_lists = {
+            k: [dic[k] for dic in list_of_dicts_] for k in list_of_dicts_[0]
+        }
+
+        batch = {}
+        for k, v in dict_of_lists.items():
+            if isinstance(v, list):
+                if all(isinstance(s, AudioSignal) for s in v):
+                    batch[k] = AudioSignal.batch(v, pad_signals=True)
+                else:
+                    # Borrow the default collate fn from torch.
+                    batch[k] = torch.utils.data._utils.collate.default_collate(v)
+        batches.append(unflatten(batch))
+
+    batches = batches[0] if not return_list else batches
+    return batches
+
+
+BASE_SIZE = 864
+DEFAULT_FIG_SIZE = (9, 3)
+
+
+def format_figure(
+    fig_size: tuple = None,
+    title: str = None,
+    fig=None,
+    format_axes: bool = True,
+    format: bool = True,
+    font_color: str = "white",
+):
+    """Prettifies the spectrogram and waveform plots. A title
+    can be inset into the top right corner, and the axes can be
+    inset into the figure, allowing the data to take up the entire
+    image. Used in
+
+    - :py:func:`audiotools.core.display.DisplayMixin.specshow`
+    - :py:func:`audiotools.core.display.DisplayMixin.waveplot`
+    - :py:func:`audiotools.core.display.DisplayMixin.wavespec`
+
+    Parameters
+    ----------
+    fig_size : tuple, optional
+        Size of figure, by default (9, 3)
+    title : str, optional
+        Title to inset in top right, by default None
+    fig : matplotlib.figure.Figure, optional
+        Figure object, if None ``plt.gcf()`` will be used, by default None
+    format_axes : bool, optional
+        Format the axes to be inside the figure, by default True
+    format : bool, optional
+        This formatting can be skipped entirely by passing ``format=False``
+        to any of the plotting functions that use this formatter, by default True
+    font_color : str, optional
+        Color of font of axes, by default "white"
+    """
+    import matplotlib
+    import matplotlib.pyplot as plt
+
+    if fig_size is None:
+        fig_size = DEFAULT_FIG_SIZE
+    if not format:
+        return
+    if fig is None:
+        fig = plt.gcf()
+    fig.set_size_inches(*fig_size)
+    axs = fig.axes
+
+    pixels = (fig.get_size_inches() * fig.dpi)[0]
+    font_scale = pixels / BASE_SIZE
+
+    if format_axes:
+        axs = fig.axes
+
+        for ax in axs:
+            ymin, _ = ax.get_ylim()
+            xmin, _ = ax.get_xlim()
+
+            ticks = ax.get_yticks()
+            for t in ticks[2:-1]:
+                t = axs[0].annotate(
+                    f"{(t / 1000):2.1f}k",
+                    xy=(xmin, t),
+                    xycoords="data",
+                    xytext=(5, -5),
+                    textcoords="offset points",
+                    ha="left",
+                    va="top",
+                    color=font_color,
+                    fontsize=12 * font_scale,
+                    alpha=0.75,
+                )
+
+            ticks = ax.get_xticks()[2:]
+            for t in ticks[:-1]:
+                t = axs[0].annotate(
+                    f"{t:2.1f}s",
+                    xy=(t, ymin),
+                    xycoords="data",
+                    xytext=(5, 5),
+                    textcoords="offset points",
+                    ha="center",
+                    va="bottom",
+                    color=font_color,
+                    fontsize=12 * font_scale,
+                    alpha=0.75,
+                )
+
+            ax.margins(0, 0)
+            ax.set_axis_off()
+            ax.xaxis.set_major_locator(plt.NullLocator())
+            ax.yaxis.set_major_locator(plt.NullLocator())
+
+        plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
+
+    if title is not None:
+        t = axs[0].annotate(
+            title,
+            xy=(1, 1),
+            xycoords="axes fraction",
+            fontsize=20 * font_scale,
+            xytext=(-5, -5),
+            textcoords="offset points",
+            ha="right",
+            va="top",
+            color="white",
+        )
+        t.set_bbox(dict(facecolor="black", alpha=0.5, edgecolor="black"))
+
+
+def generate_chord_dataset(
+    max_voices: int = 8,
+    sample_rate: int = 44100,
+    num_items: int = 5,
+    duration: float = 1.0,
+    min_note: str = "C2",
+    max_note: str = "C6",
+    output_dir: Path = "chords",
+):
+    """
+    Generates a toy multitrack dataset of chords, synthesized from sine waves.
+
+
+    Parameters
+    ----------
+    max_voices : int, optional
+        Maximum number of voices in a chord, by default 8
+    sample_rate : int, optional
+        Sample rate of audio, by default 44100
+    num_items : int, optional
+        Number of items to generate, by default 5
+    duration : float, optional
+        Duration of each item, by default 1.0
+    min_note : str, optional
+        Minimum note in the dataset, by default "C2"
+    max_note : str, optional
+        Maximum note in the dataset, by default "C6"
+    output_dir : Path, optional
+        Directory to save the dataset, by default "chords"
+
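+    Examples
+    --------
+    A minimal sketch (``librosa`` must be installed; the output directory
+    name below is arbitrary):
+
+    >>> output_dir = generate_chord_dataset(num_items=10, output_dir="toy_chords")
+    >>> # One CSV per voice (voice_0.csv, voice_1.csv, ...) is written there.
+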
+    """
+    import librosa
+    from . import AudioSignal
+    from ..data.preprocess import create_csv
+
+    min_midi = librosa.note_to_midi(min_note)
+    max_midi = librosa.note_to_midi(max_note)
+
+    tracks = []
+    for idx in range(num_items):
+        track = {}
+        # figure out how many voices to put in this track
+        num_voices = random.randint(1, max_voices)
+        for voice_idx in range(num_voices):
+            # choose some random params
+            midinote = random.randint(min_midi, max_midi)
+            dur = random.uniform(0.85 * duration, duration)
+
+            sig = AudioSignal.wave(
+                frequency=librosa.midi_to_hz(midinote),
+                duration=dur,
+                sample_rate=sample_rate,
+                shape="sine",
+            )
+            track[f"voice_{voice_idx}"] = sig
+        tracks.append(track)
+
+    # save the tracks to disk
+    output_dir = Path(output_dir)
+    output_dir.mkdir(exist_ok=True)
+    for idx, track in enumerate(tracks):
+        track_dir = output_dir / f"track_{idx}"
+        track_dir.mkdir(exist_ok=True)
+        for voice_name, sig in track.items():
+            sig.write(track_dir / f"{voice_name}.wav")
+
+    all_voices = list(set([k for track in tracks for k in track.keys()]))
+    voice_lists = {voice: [] for voice in all_voices}
+    for track in tracks:
+        for voice_name in all_voices:
+            if voice_name in track:
+                voice_lists[voice_name].append(track[voice_name].path_to_file)
+            else:
+                voice_lists[voice_name].append("")
+
+    for voice_name, paths in voice_lists.items():
+        create_csv(paths, output_dir / f"{voice_name}.csv", loudness=True)
+
+    return output_dir
diff --git a/audiotools/core/whisper.py b/audiotools/core/whisper.py
new file mode 100644
index 0000000000000000000000000000000000000000..46c071f934fc3e2be3138e7596b1c6d2ef79eade
--- /dev/null
+++ b/audiotools/core/whisper.py
@@ -0,0 +1,97 @@
+import torch
+
+
+class WhisperMixin:
+    is_initialized = False
+
+    def setup_whisper(
+        self,
+        pretrained_model_name_or_path: str = "openai/whisper-base.en",
+        device: str = torch.device("cuda" if torch.cuda.is_available() else "cpu"),
+    ):
+        from transformers import WhisperForConditionalGeneration
+        from transformers import WhisperProcessor
+
+        self.whisper_device = device
+        self.whisper_processor = WhisperProcessor.from_pretrained(
+            pretrained_model_name_or_path
+        )
+        self.whisper_model = WhisperForConditionalGeneration.from_pretrained(
+            pretrained_model_name_or_path
+        ).to(self.whisper_device)
+        self.is_initialized = True
+
+    def get_whisper_features(self) -> torch.Tensor:
+        """Preprocess audio signal as per the whisper model's training config.
+
+        Returns
+        -------
+        torch.Tensor
+            The preprocessed input features of the audio signal. Shape: (1, channels, seq_len)
+        """
+        import torch
+
+        if not self.is_initialized:
+            self.setup_whisper()
+
+        signal = self.to(self.device)
+        raw_speech = list(
+            (
+                signal.clone()
+                .resample(self.whisper_processor.feature_extractor.sampling_rate)
+                .audio_data[:, 0, :]
+                .numpy()
+            )
+        )
+
+        with torch.inference_mode():
+            input_features = self.whisper_processor(
+                raw_speech,
+                sampling_rate=self.whisper_processor.feature_extractor.sampling_rate,
+                return_tensors="pt",
+            ).input_features
+
+        return input_features
+
+    def get_whisper_transcript(self) -> str:
+        """Get the transcript of the audio signal using the whisper model.
+
+        Returns
+        -------
+        str
+            The transcript of the audio signal, including special tokens such as <|startoftranscript|> and <|endoftext|>.
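+
+        Examples
+        --------
+        A minimal sketch (requires the ``transformers`` package; the audio
+        path below is one of the files referenced by this repository's tests):
+
+        >>> from audiotools import AudioSignal
+        >>> signal = AudioSignal("tests/audio/spk/f10_script4_produced.wav")
+        >>> signal.get_whisper_transcript()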
+        """
+
+        if not self.is_initialized:
+            self.setup_whisper()
+
+        input_features = self.get_whisper_features()
+
+        with torch.inference_mode():
+            input_features = input_features.to(self.whisper_device)
+            generated_ids = self.whisper_model.generate(inputs=input_features)
+
+        transcription = self.whisper_processor.batch_decode(generated_ids)
+        return transcription[0]
+
+    def get_whisper_embeddings(self) -> torch.Tensor:
+        """Get the last hidden state embeddings of the audio signal using the whisper model.
+
+        Returns
+        -------
+        torch.Tensor
+            The Whisper embeddings of the audio signal. Shape: (1, seq_len, hidden_size)
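+
+        A minimal sketch (same assumptions as ``get_whisper_transcript``;
+        ``signal`` is an existing AudioSignal):
+
+        >>> embeddings = signal.get_whisper_embeddings()
+        >>> embeddings.shape  # (batch, seq_len, hidden_size)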
+        """
+        import torch
+
+        if not self.is_initialized:
+            self.setup_whisper()
+
+        input_features = self.get_whisper_features()
+        encoder = self.whisper_model.get_encoder()
+
+        with torch.inference_mode():
+            input_features = input_features.to(self.whisper_device)
+            embeddings = encoder(input_features)
+
+        return embeddings.last_hidden_state
diff --git a/audiotools/data/__init__.py b/audiotools/data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..aead269f26f3782043e68418b4c87ee323cbd015
--- /dev/null
+++ b/audiotools/data/__init__.py
@@ -0,0 +1,3 @@
+from . import datasets
+from . import preprocess
+from . import transforms
diff --git a/audiotools/data/datasets.py b/audiotools/data/datasets.py
new file mode 100644
index 0000000000000000000000000000000000000000..12e7a60963399aa15ff865de2d06537818ce18ee
--- /dev/null
+++ b/audiotools/data/datasets.py
@@ -0,0 +1,517 @@
+from pathlib import Path
+from typing import Callable
+from typing import Dict
+from typing import List
+from typing import Union
+
+import numpy as np
+from torch.utils.data import SequentialSampler
+from torch.utils.data.distributed import DistributedSampler
+
+from ..core import AudioSignal
+from ..core import util
+
+
+class AudioLoader:
+    """Loads audio endlessly from a list of audio sources
+    containing paths to audio files. Audio sources can be
+    folders full of audio files (which are found via file
+    extension) or by providing a CSV file which contains paths
+    to audio files.
+
+    Parameters
+    ----------
+    sources : List[str], optional
+        Sources containing folders, or CSVs with
+        paths to audio files, by default None
+    weights : List[float], optional
+        Weights to sample audio files from each source, by default None
+    relative_path : str, optional
+        Path audio should be loaded relative to, by default ""
+    transform : Callable, optional
+        Transform to instantiate alongside audio sample,
+        by default None
+    ext : List[str]
+        List of extensions to find audio within each source by. Can
+        also be a file name (e.g. "vocals.wav"). by default
+        ``['.wav', '.flac', '.mp3', '.mp4']``.
+    shuffle: bool
+        Whether to shuffle the files within the dataloader. Defaults to True.
+    shuffle_state: int
+        State to use to seed the shuffle of the files.
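+
+    Examples
+    --------
+    A minimal sketch (the source folder below is one referenced by this
+    repository's tests; substitute your own data):
+
+    >>> import numpy as np
+    >>> from audiotools.data.datasets import AudioLoader
+    >>> loader = AudioLoader(sources=["tests/audio/spk"])
+    >>> state = np.random.RandomState(0)
+    >>> item = loader(state, sample_rate=44100, duration=1.0)
+    >>> item["signal"]  # a 1-second, mono AudioSignal excerpt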
+    """
+
+    def __init__(
+        self,
+        sources: List[str] = None,
+        weights: List[float] = None,
+        transform: Callable = None,
+        relative_path: str = "",
+        ext: List[str] = util.AUDIO_EXTENSIONS,
+        shuffle: bool = True,
+        shuffle_state: int = 0,
+    ):
+        self.audio_lists = util.read_sources(
+            sources, relative_path=relative_path, ext=ext
+        )
+
+        self.audio_indices = [
+            (src_idx, item_idx)
+            for src_idx, src in enumerate(self.audio_lists)
+            for item_idx in range(len(src))
+        ]
+        if shuffle:
+            state = util.random_state(shuffle_state)
+            state.shuffle(self.audio_indices)
+
+        self.sources = sources
+        self.weights = weights
+        self.transform = transform
+
+    def __call__(
+        self,
+        state,
+        sample_rate: int,
+        duration: float,
+        loudness_cutoff: float = -40,
+        num_channels: int = 1,
+        offset: float = None,
+        source_idx: int = None,
+        item_idx: int = None,
+        global_idx: int = None,
+    ):
+        if source_idx is not None and item_idx is not None:
+            try:
+                audio_info = self.audio_lists[source_idx][item_idx]
+            except Exception:
+                # Fall back to an empty item if the indices are out of range.
+                audio_info = {"path": "none"}
+        elif global_idx is not None:
+            source_idx, item_idx = self.audio_indices[
+                global_idx % len(self.audio_indices)
+            ]
+            audio_info = self.audio_lists[source_idx][item_idx]
+        else:
+            audio_info, source_idx, item_idx = util.choose_from_list_of_lists(
+                state, self.audio_lists, p=self.weights
+            )
+
+        path = audio_info["path"]
+        signal = AudioSignal.zeros(duration, sample_rate, num_channels)
+
+        if path != "none":
+            if offset is None:
+                signal = AudioSignal.salient_excerpt(
+                    path,
+                    duration=duration,
+                    state=state,
+                    loudness_cutoff=loudness_cutoff,
+                )
+            else:
+                signal = AudioSignal(
+                    path,
+                    offset=offset,
+                    duration=duration,
+                )
+
+        if num_channels == 1:
+            signal = signal.to_mono()
+        signal = signal.resample(sample_rate)
+
+        if signal.duration < duration:
+            signal = signal.zero_pad_to(int(duration * sample_rate))
+
+        for k, v in audio_info.items():
+            signal.metadata[k] = v
+
+        item = {
+            "signal": signal,
+            "source_idx": source_idx,
+            "item_idx": item_idx,
+            "source": str(self.sources[source_idx]),
+            "path": str(path),
+        }
+        if self.transform is not None:
+            item["transform_args"] = self.transform.instantiate(state, signal=signal)
+        return item
+
+
+def default_matcher(x, y):
+    return Path(x).parent == Path(y).parent
+
+
+def align_lists(lists, matcher: Callable = default_matcher):
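+    """Pads and aligns lists of audio-file dicts so that index ``i`` of every
+    list refers to the same item, as judged by ``matcher`` (by default, two
+    paths match if they share a parent directory). Missing entries are filled
+    with ``{"path": "none"}``. A small illustrative sketch with hypothetical
+    paths:
+
+    >>> vocals = [{"path": "track_0/vocals.wav"}, {"path": "track_1/vocals.wav"}]
+    >>> drums = [{"path": "track_0/drums.wav"}]
+    >>> align_lists([vocals, drums])  # pads drums with {"path": "none"}
+    """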
+    longest_list = lists[np.argmax([len(l) for l in lists])]
+    for i, x in enumerate(longest_list):
+        for l in lists:
+            if i >= len(l):
+                l.append({"path": "none"})
+            elif not matcher(l[i]["path"], x["path"]):
+                l.insert(i, {"path": "none"})
+    return lists
+
+
+class AudioDataset:
+    """Loads audio from multiple loaders (with associated transforms)
+    for a specified number of samples. Excerpts of the specified duration
+    are drawn randomly, above a specified loudness threshold, and are
+    resampled on the fly to the desired sample rate (if it is different
+    from the audio source sample rate).
+
+    This takes either a single AudioLoader object,
+    a list of AudioLoader objects, or a dictionary of AudioLoader
+    objects. Each AudioLoader is called by the dataset, and the
+    result is placed in the output dictionary. A transform can also be
+    specified for the entire dataset, rather than for each specific
+    loader. This transform can be applied to the output of all the
+    loaders if desired.
+
+    AudioLoader objects can be specified as aligned, which means the
+    loaders correspond to multitrack audio (e.g. a vocals, bass,
+    drums, and other loader for multitrack music mixtures).
+
+
+    Parameters
+    ----------
+    loaders : Union[AudioLoader, List[AudioLoader], Dict[str, AudioLoader]]
+        AudioLoaders to sample audio from.
+    sample_rate : int
+        Desired sample rate.
+    n_examples : int, optional
+        Number of examples (length of dataset), by default 1000
+    duration : float, optional
+        Duration of audio samples, by default 0.5
+    loudness_cutoff : float, optional
+        Loudness cutoff threshold for audio samples, by default -40
+    num_channels : int, optional
+        Number of channels in output audio, by default 1
+    transform : Callable, optional
+        Transform to instantiate alongside each dataset item, by default None
+    aligned : bool, optional
+        Whether the loaders should be sampled in an aligned manner (e.g. same
+        offset, duration, and matched file name), by default False
+    shuffle_loaders : bool, optional
+        Whether to shuffle the loaders before sampling from them, by default False
+    matcher : Callable
+        How to match files from adjacent audio lists (e.g. for a multitrack audio loader),
+        by default uses the parent directory of each file.
+    without_replacement : bool
+        Whether to choose files with or without replacement, by default True.
+
+
+    Examples
+    --------
+    >>> from audiotools.data.datasets import AudioLoader
+    >>> from audiotools.data.datasets import AudioDataset
+    >>> from audiotools import transforms as tfm
+    >>> import numpy as np
+    >>>
+    >>> loaders = [
+    >>>     AudioLoader(
+    >>>         sources=[f"tests/audio/spk"],
+    >>>         transform=tfm.Equalizer(),
+    >>>         ext=["wav"],
+    >>>     )
+    >>>     for i in range(5)
+    >>> ]
+    >>>
+    >>> dataset = AudioDataset(
+    >>>     loaders = loaders,
+    >>>     sample_rate = 44100,
+    >>>     duration = 1.0,
+    >>>     transform = tfm.RescaleAudio(),
+    >>> )
+    >>>
+    >>> item = dataset[np.random.randint(len(dataset))]
+    >>>
+    >>> for i in range(len(loaders)):
+    >>>     item[i]["signal"] = loaders[i].transform(
+    >>>         item[i]["signal"], **item[i]["transform_args"]
+    >>>     )
+    >>>     item[i]["signal"].widget(i)
+    >>>
+    >>> mix = sum([item[i]["signal"] for i in range(len(loaders))])
+    >>> mix = dataset.transform(mix, **item["transform_args"])
+    >>> mix.widget("mix")
+
+    Below is an example of how one could load MUSDB multitrack data:
+
+    >>> import audiotools as at
+    >>> from pathlib import Path
+    >>> from audiotools import transforms as tfm
+    >>> import numpy as np
+    >>> import torch
+    >>>
+    >>> def build_dataset(
+    >>>     sample_rate: int = 44100,
+    >>>     duration: float = 5.0,
+    >>>     musdb_path: str = "~/.data/musdb/",
+    >>> ):
+    >>>     musdb_path = Path(musdb_path).expanduser()
+    >>>     loaders = {
+    >>>         src: at.datasets.AudioLoader(
+    >>>             sources=[musdb_path],
+    >>>             transform=tfm.Compose(
+    >>>                 tfm.VolumeNorm(("uniform", -20, -10)),
+    >>>                 tfm.Silence(prob=0.1),
+    >>>             ),
+    >>>             ext=[f"{src}.wav"],
+    >>>         )
+    >>>         for src in ["vocals", "bass", "drums", "other"]
+    >>>     }
+    >>>
+    >>>     dataset = at.datasets.AudioDataset(
+    >>>         loaders=loaders,
+    >>>         sample_rate=sample_rate,
+    >>>         duration=duration,
+    >>>         num_channels=1,
+    >>>         aligned=True,
+    >>>         transform=tfm.RescaleAudio(),
+    >>>         shuffle_loaders=True,
+    >>>     )
+    >>>     return dataset, list(loaders.keys())
+    >>>
+    >>> train_data, sources = build_dataset()
+    >>> dataloader = torch.utils.data.DataLoader(
+    >>>     train_data,
+    >>>     batch_size=16,
+    >>>     num_workers=0,
+    >>>     collate_fn=train_data.collate,
+    >>> )
+    >>> batch = next(iter(dataloader))
+    >>>
+    >>> for k in sources:
+    >>>     src = batch[k]
+    >>>     src["transformed"] = train_data.loaders[k].transform(
+    >>>         src["signal"].clone(), **src["transform_args"]
+    >>>     )
+    >>>
+    >>> mixture = sum(batch[k]["transformed"] for k in sources)
+    >>> mixture = train_data.transform(mixture, **batch["transform_args"])
+    >>>
+    >>> # Say a model takes the mix and gives back (n_batch, n_src, n_time).
+    >>> # Construct the targets:
+    >>> targets = at.AudioSignal.batch([batch[k]["transformed"] for k in sources], dim=1)
+
+    Similarly, here's example code for loading Slakh data:
+
+    >>> import audiotools as at
+    >>> from pathlib import Path
+    >>> from audiotools import transforms as tfm
+    >>> import numpy as np
+    >>> import torch
+    >>> import glob
+    >>>
+    >>> def build_dataset(
+    >>>     sample_rate: int = 16000,
+    >>>     duration: float = 10.0,
+    >>>     slakh_path: str = "~/.data/slakh/",
+    >>> ):
+    >>>     slakh_path = Path(slakh_path).expanduser()
+    >>>
+    >>>     # Find the max number of sources in Slakh
+    >>>     src_names = [x.name for x in list(slakh_path.glob("**/*.wav"))  if "S" in str(x.name)]
+    >>>     n_sources = len(list(set(src_names)))
+    >>>
+    >>>     loaders = {
+    >>>         f"S{i:02d}": at.datasets.AudioLoader(
+    >>>             sources=[slakh_path],
+    >>>             transform=tfm.Compose(
+    >>>                 tfm.VolumeNorm(("uniform", -20, -10)),
+    >>>                 tfm.Silence(prob=0.1),
+    >>>             ),
+    >>>             ext=[f"S{i:02d}.wav"],
+    >>>         )
+    >>>         for i in range(n_sources)
+    >>>     }
+    >>>     dataset = at.datasets.AudioDataset(
+    >>>         loaders=loaders,
+    >>>         sample_rate=sample_rate,
+    >>>         duration=duration,
+    >>>         num_channels=1,
+    >>>         aligned=True,
+    >>>         transform=tfm.RescaleAudio(),
+    >>>         shuffle_loaders=False,
+    >>>     )
+    >>>
+    >>>     return dataset, list(loaders.keys())
+    >>>
+    >>> train_data, sources = build_dataset()
+    >>> dataloader = torch.utils.data.DataLoader(
+    >>>     train_data,
+    >>>     batch_size=16,
+    >>>     num_workers=0,
+    >>>     collate_fn=train_data.collate,
+    >>> )
+    >>> batch = next(iter(dataloader))
+    >>>
+    >>> for k in sources:
+    >>>     src = batch[k]
+    >>>     src["transformed"] = train_data.loaders[k].transform(
+    >>>         src["signal"].clone(), **src["transform_args"]
+    >>>     )
+    >>>
+    >>> mixture = sum(batch[k]["transformed"] for k in sources)
+    >>> mixture = train_data.transform(mixture, **batch["transform_args"])
+
+    """
+
+    def __init__(
+        self,
+        loaders: Union[AudioLoader, List[AudioLoader], Dict[str, AudioLoader]],
+        sample_rate: int,
+        n_examples: int = 1000,
+        duration: float = 0.5,
+        offset: float = None,
+        loudness_cutoff: float = -40,
+        num_channels: int = 1,
+        transform: Callable = None,
+        aligned: bool = False,
+        shuffle_loaders: bool = False,
+        matcher: Callable = default_matcher,
+        without_replacement: bool = True,
+    ):
+        # Internally we convert loaders to a dictionary
+        if isinstance(loaders, list):
+            loaders = {i: l for i, l in enumerate(loaders)}
+        elif isinstance(loaders, AudioLoader):
+            loaders = {0: loaders}
+
+        self.loaders = loaders
+        self.loudness_cutoff = loudness_cutoff
+        self.num_channels = num_channels
+
+        self.length = n_examples
+        self.transform = transform
+        self.sample_rate = sample_rate
+        self.duration = duration
+        self.offset = offset
+        self.aligned = aligned
+        self.shuffle_loaders = shuffle_loaders
+        self.without_replacement = without_replacement
+
+        if aligned:
+            loaders_list = list(loaders.values())
+            for i in range(len(loaders_list[0].audio_lists)):
+                input_lists = [l.audio_lists[i] for l in loaders_list]
+                # Alignment happens in-place
+                align_lists(input_lists, matcher)
+
+    def __getitem__(self, idx):
+        state = util.random_state(idx)
+        offset = None if self.offset is None else self.offset
+        item = {}
+
+        keys = list(self.loaders.keys())
+        if self.shuffle_loaders:
+            state.shuffle(keys)
+
+        loader_kwargs = {
+            "state": state,
+            "sample_rate": self.sample_rate,
+            "duration": self.duration,
+            "loudness_cutoff": self.loudness_cutoff,
+            "num_channels": self.num_channels,
+            "global_idx": idx if self.without_replacement else None,
+        }
+
+        # Draw item from first loader
+        loader = self.loaders[keys[0]]
+        item[keys[0]] = loader(**loader_kwargs)
+
+        for key in keys[1:]:
+            loader = self.loaders[key]
+            if self.aligned:
+                # When aligned, reuse the offset, source index, and item
+                # index from the first loader so that all loaders line up.
+                offset = item[keys[0]]["signal"].metadata["offset"]
+                loader_kwargs.update(
+                    {
+                        "offset": offset,
+                        "source_idx": item[keys[0]]["source_idx"],
+                        "item_idx": item[keys[0]]["item_idx"],
+                    }
+                )
+            item[key] = loader(**loader_kwargs)
+
+        # Sort dictionary back into original order
+        keys = list(self.loaders.keys())
+        item = {k: item[k] for k in keys}
+
+        item["idx"] = idx
+        if self.transform is not None:
+            item["transform_args"] = self.transform.instantiate(
+                state=state, signal=item[keys[0]]["signal"]
+            )
+
+        # If there's only one loader, pop it up
+        # to the main dictionary, instead of keeping it
+        # nested.
+        if len(keys) == 1:
+            item.update(item.pop(keys[0]))
+
+        return item
+
+    def __len__(self):
+        return self.length
+
+    @staticmethod
+    def collate(list_of_dicts: Union[list, dict], n_splits: int = None):
+        """Collates items drawn from this dataset. Uses
+        :py:func:`audiotools.core.util.collate`.
+
+        Parameters
+        ----------
+        list_of_dicts : typing.Union[list, dict]
+            Data drawn from each item.
+        n_splits : int
+            Number of splits to make when creating the batches (split into
+            sub-batches). Useful for things like gradient accumulation.
+
+        Returns
+        -------
+        dict
+            Dictionary of batched data.
+        """
+        return util.collate(list_of_dicts, n_splits=n_splits)
+
+
+class ConcatDataset(AudioDataset):
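+    """Concatenates multiple datasets into a single dataset. Items are drawn
+    from the underlying datasets in a round-robin fashion: index ``idx`` maps
+    to dataset ``idx % len(datasets)`` and to item ``idx // len(datasets)``
+    within it. A minimal sketch (``dataset_a`` and ``dataset_b`` are assumed
+    to be existing ``AudioDataset`` objects):
+
+    >>> dataset = ConcatDataset([dataset_a, dataset_b])
+    >>> item = dataset[0]  # first item of dataset_a
+    """
+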
+    def __init__(self, datasets: list):
+        self.datasets = datasets
+
+    def __len__(self):
+        return sum([len(d) for d in self.datasets])
+
+    def __getitem__(self, idx):
+        dataset = self.datasets[idx % len(self.datasets)]
+        return dataset[idx // len(self.datasets)]
+
+
+class ResumableDistributedSampler(DistributedSampler):  # pragma: no cover
+    """Distributed sampler that can be resumed from a given start index."""
+
+    def __init__(self, dataset, start_idx: int = None, **kwargs):
+        super().__init__(dataset, **kwargs)
+        # Start index, allows resuming an experiment from the index where it left off.
+        self.start_idx = start_idx // self.num_replicas if start_idx is not None else 0
+
+    def __iter__(self):
+        for i, idx in enumerate(super().__iter__()):
+            if i >= self.start_idx:
+                yield idx
+        self.start_idx = 0  # reset so the next epoch starts from the beginning
+
+
+class ResumableSequentialSampler(SequentialSampler):  # pragma: no cover
+    """Sequential sampler that can be resumed from a given start index."""
+
+    def __init__(self, dataset, start_idx: int = None, **kwargs):
+        super().__init__(dataset, **kwargs)
+        # Start index, allows resuming an experiment from the index where it left off.
+        self.start_idx = start_idx if start_idx is not None else 0
+
+    def __iter__(self):
+        for i, idx in enumerate(super().__iter__()):
+            if i >= self.start_idx:
+                yield idx
+        self.start_idx = 0  # reset so the next epoch starts from the beginning
diff --git a/audiotools/data/preprocess.py b/audiotools/data/preprocess.py
new file mode 100644
index 0000000000000000000000000000000000000000..d90de210115e45838bc8d69b350f7516ba730406
--- /dev/null
+++ b/audiotools/data/preprocess.py
@@ -0,0 +1,81 @@
+import csv
+import os
+from pathlib import Path
+
+from tqdm import tqdm
+
+from ..core import AudioSignal
+
+
+def create_csv(
+    audio_files: list, output_csv: Path, loudness: bool = False, data_path: str = None
+):
+    """Converts a folder of audio files to a CSV file. If ``loudness = True``,
+    the output of this function will create a CSV file that looks something
+    like:
+
+    ..  csv-table::
+        :header: path,loudness
+
+        daps/produced/f1_script1_produced.wav,-16.299999237060547
+        daps/produced/f1_script2_produced.wav,-16.600000381469727
+        daps/produced/f1_script3_produced.wav,-17.299999237060547
+        daps/produced/f1_script4_produced.wav,-16.100000381469727
+        daps/produced/f1_script5_produced.wav,-16.700000762939453
+        daps/produced/f3_script1_produced.wav,-16.5
+
+    ..  note::
+        The paths above are written relative to the ``data_path`` argument,
+        if it is passed to this function; otherwise they are written as
+        given. When these CSVs are later loaded, paths are resolved relative
+        to the ``PATH_TO_DATA`` environment variable.
+
+    You can produce a CSV file from a directory of audio files via:
+
+    >>> import audiotools
+    >>> directory = ...
+    >>> audio_files = audiotools.util.find_audio(directory)
+    >>> output_csv = "train.csv"
+    >>> audiotools.data.preprocess.create_csv(
+    >>>     audio_files, output_csv, loudness=True
+    >>> )
+
+    Note that you can create empty rows in the CSV file by passing an empty
+    string or None in the ``audio_files`` list. This is useful if you want to
+    sync multiple CSV files in a multitrack setting. The loudness of these
+    empty rows will be set to -inf.
+
+    Parameters
+    ----------
+    audio_files : list
+        List of audio files.
+    output_csv : Path
+        Output CSV, with each row containing the relative path of every file
+        to ``data_path``, if specified (defaults to None).
+    loudness : bool
+        Compute loudness of entire file and store alongside path.
+    data_path : str, optional
+        Path that audio file paths are written relative to, by default None
+        (paths are written as given).
+    """
+
+    info = []
+    pbar = tqdm(audio_files)
+    for af in pbar:
+        af = Path(af)
+        pbar.set_description(f"Processing {af.name}")
+        _info = {}
+        if af.name == "":
+            _info["path"] = ""
+            if loudness:
+                _info["loudness"] = -float("inf")
+        else:
+            _info["path"] = af.relative_to(data_path) if data_path is not None else af
+            if loudness:
+                _info["loudness"] = AudioSignal(af).ffmpeg_loudness().item()
+
+        info.append(_info)
+
+    with open(output_csv, "w") as f:
+        writer = csv.DictWriter(f, fieldnames=list(info[0].keys()))
+        writer.writeheader()
+
+        for item in info:
+            writer.writerow(item)
diff --git a/audiotools/data/transforms.py b/audiotools/data/transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..504e87dc61777e36ba95eb794f497bed4cdc7d2c
--- /dev/null
+++ b/audiotools/data/transforms.py
@@ -0,0 +1,1592 @@
+import copy
+from contextlib import contextmanager
+from inspect import signature
+from typing import List
+
+import numpy as np
+import torch
+from flatten_dict import flatten
+from flatten_dict import unflatten
+from numpy.random import RandomState
+
+from .. import ml
+from ..core import AudioSignal
+from ..core import util
+from .datasets import AudioLoader
+
+tt = torch.tensor
+"""Shorthand for converting things to torch.tensor."""
+
+
+class BaseTransform:
+    """This is the base class for all transforms that are implemented
+    in this library. Transforms have two main operations: ``transform``
+    and ``instantiate``.
+
+    ``instantiate`` sets the parameters randomly
+    from distribution tuples for each parameter. For example, for the
+    ``BackgroundNoise`` transform, the signal-to-noise ratio (``snr``)
+    is chosen randomly by instantiate. By default, it is chosen uniformly
+    between 10.0 and 30.0 (the tuple is set to ``("uniform", 10.0, 30.0)``).
+
+    ``transform`` applies the transform using the instantiated parameters.
+    A simple example is as follows:
+
+    >>> seed = 0
+    >>> signal = ...
+    >>> transform = transforms.NoiseFloor(db = ("uniform", -50.0, -30.0))
+    >>> kwargs = transform.instantiate(seed, signal)
+    >>> output = transform(signal.clone(), **kwargs)
+
+    By breaking apart the instantiation of parameters from the actual audio
+    processing of the transform, we can make things more reproducible, while
+    also applying the transform on batches of data efficiently on GPU,
+    rather than on individual audio samples.
+
+    ..  note::
+        We call ``signal.clone()`` for the input to the ``transform`` function
+        because signals are modified in-place! If you don't clone the signal,
+        you will lose the original data.
+
+    Parameters
+    ----------
+    keys : list, optional
+        Keys that the transform looks for when
+        calling ``self.transform``, by default []. In general this is
+        set automatically, and you won't need to manipulate this argument.
+    name : str, optional
+        Name of this transform, used to identify it in the dictionary
+        produced by ``self.instantiate``, by default None
+    prob : float, optional
+        Probability of applying this transform, by default 1.0
+
+    Examples
+    --------
+
+    >>> seed = 0
+    >>>
+    >>> audio_path = "tests/audio/spk/f10_script4_produced.wav"
+    >>> signal = AudioSignal(audio_path, offset=10, duration=2)
+    >>> transform = tfm.Compose(
+    >>>     [
+    >>>         tfm.RoomImpulseResponse(sources=["tests/audio/irs.csv"]),
+    >>>         tfm.BackgroundNoise(sources=["tests/audio/noises.csv"]),
+    >>>     ],
+    >>> )
+    >>>
+    >>> kwargs = transform.instantiate(seed, signal)
+    >>> output = transform(signal, **kwargs)
+
+    """
+
+    def __init__(self, keys: list = [], name: str = None, prob: float = 1.0):
+        # Get keys from the _transform signature.
+        tfm_keys = list(signature(self._transform).parameters.keys())
+
+        # Filter out signal and kwargs keys.
+        ignore_keys = ["signal", "kwargs"]
+        tfm_keys = [k for k in tfm_keys if k not in ignore_keys]
+
+        # Combine keys specified by the child class, the keys found in
+        # _transform signature, and the mask key.
+        self.keys = keys + tfm_keys + ["mask"]
+
+        self.prob = prob
+
+        if name is None:
+            name = self.__class__.__name__
+        self.name = name
+
+    def _prepare(self, batch: dict):
+        sub_batch = batch[self.name]
+
+        for k in self.keys:
+            assert k in sub_batch.keys(), f"{k} not in batch"
+
+        return sub_batch
+
+    def _transform(self, signal):
+        return signal
+
+    def _instantiate(self, state: RandomState, signal: AudioSignal = None):
+        return {}
+
+    @staticmethod
+    def apply_mask(batch: dict, mask: torch.Tensor):
+        """Applies a mask to the batch.
+
+        Parameters
+        ----------
+        batch : dict
+            Batch whose values will be masked in the ``transform`` pass.
+        mask : torch.Tensor
+            Mask to apply to batch.
+
+        Returns
+        -------
+        dict
+            A dictionary that contains values only where ``mask = True``.
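+
+        A minimal sketch with arbitrary values:
+
+        >>> batch = {"snr": torch.tensor([10.0, 20.0, 30.0])}
+        >>> mask = torch.tensor([True, False, True])
+        >>> BaseTransform.apply_mask(batch, mask)  # {"snr": tensor([10., 30.])}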
+        """
+        masked_batch = {k: v[mask] for k, v in flatten(batch).items()}
+        return unflatten(masked_batch)
+
+    def transform(self, signal: AudioSignal, **kwargs):
+        """Apply the transform to the audio signal,
+        with given keyword arguments.
+
+        Parameters
+        ----------
+        signal : AudioSignal
+            Signal that will be modified by the transforms in-place.
+        kwargs: dict
+            Keyword arguments to the specific transforms ``self._transform``
+            function.
+
+        Returns
+        -------
+        AudioSignal
+            Transformed AudioSignal.
+
+        Examples
+        --------
+
+        >>> for seed in range(10):
+        >>>     kwargs = transform.instantiate(seed, signal)
+        >>>     output = transform(signal.clone(), **kwargs)
+
+        """
+        tfm_kwargs = self._prepare(kwargs)
+        mask = tfm_kwargs["mask"]
+
+        if torch.any(mask):
+            tfm_kwargs = self.apply_mask(tfm_kwargs, mask)
+            tfm_kwargs = {k: v for k, v in tfm_kwargs.items() if k != "mask"}
+            signal[mask] = self._transform(signal[mask], **tfm_kwargs)
+
+        return signal
+
+    def __call__(self, *args, **kwargs):
+        return self.transform(*args, **kwargs)
+
+    def instantiate(
+        self,
+        state: RandomState = None,
+        signal: AudioSignal = None,
+    ):
+        """Instantiates parameters for the transform.
+
+        Parameters
+        ----------
+        state : RandomState, optional
+            Seed or random state used to instantiate the parameters, by default None
+        signal : AudioSignal, optional
+            Signal passed to ``self._instantiate`` if the transform needs it, by default None
+
+        Returns
+        -------
+        dict
+            Dictionary containing instantiated arguments for every keyword
+            argument to ``self._transform``.
+
+        Examples
+        --------
+
+        >>> for seed in range(10):
+        >>>     kwargs = transform.instantiate(seed, signal)
+        >>>     output = transform(signal.clone(), **kwargs)
+
+        """
+        state = util.random_state(state)
+
+        # Not all instantiates need the signal. Check if signal
+        # is needed before passing it in, so that the end-user
+        # doesn't need to have variables they're not using flowing
+        # into their function.
+        needs_signal = "signal" in set(signature(self._instantiate).parameters.keys())
+        kwargs = {}
+        if needs_signal:
+            kwargs = {"signal": signal}
+
+        # Instantiate the parameters for the transform.
+        params = self._instantiate(state, **kwargs)
+        for k in list(params.keys()):
+            v = params[k]
+            if isinstance(v, (AudioSignal, torch.Tensor, dict)):
+                params[k] = v
+            else:
+                params[k] = tt(v)
+        mask = state.rand() <= self.prob
+        params[f"mask"] = tt(mask)
+
+        # Put the params into a nested dictionary that will be
+        # used later when calling the transform. This is to avoid
+        # collisions in the dictionary.
+        params = {self.name: params}
+
+        return params
+
+    def batch_instantiate(
+        self,
+        states: list = None,
+        signal: AudioSignal = None,
+    ):
+        """Instantiates arguments for every item in a batch,
+        given a list of states. Each state in the list
+        corresponds to one item in the batch.
+
+        Parameters
+        ----------
+        states : list, optional
+            List of states, by default None
+        signal : AudioSignal, optional
+            AudioSignal to pass to the ``self.instantiate`` section
+            if it is needed for this transform, by default None
+
+        Returns
+        -------
+        dict
+            Collated dictionary of arguments.
+
+        Examples
+        --------
+
+        >>> batch_size = 4
+        >>> signal = AudioSignal(audio_path, offset=10, duration=2)
+        >>> signal_batch = AudioSignal.batch([signal.clone() for _ in range(batch_size)])
+        >>>
+        >>> states = [seed + idx for idx in list(range(batch_size))]
+        >>> kwargs = transform.batch_instantiate(states, signal_batch)
+        >>> batch_output = transform(signal_batch, **kwargs)
+        """
+        kwargs = []
+        for state in states:
+            kwargs.append(self.instantiate(state, signal))
+        kwargs = util.collate(kwargs)
+        return kwargs
+
+
+class Identity(BaseTransform):
+    """This transform just returns the original signal."""
+
+    pass
+
+
+class SpectralTransform(BaseTransform):
+    """Spectral transforms require STFT data to exist, since manipulations
+    of the STFT require the spectrogram. This just calls ``stft`` before
+    the transform is called, and calls ``istft`` after the transform is
+    called, so that the result of the spectral manipulation is written
+    back to the audio data.
+    """
+
+    def transform(self, signal, **kwargs):
+        signal.stft()
+        super().transform(signal, **kwargs)
+        signal.istft()
+        return signal
+
+
+class Compose(BaseTransform):
+    """Compose applies transforms in sequence, one after the other. The
+    transforms are passed in as positional arguments or as a list like so:
+
+    >>> transform = tfm.Compose(
+    >>>     [
+    >>>         tfm.RoomImpulseResponse(sources=["tests/audio/irs.csv"]),
+    >>>         tfm.BackgroundNoise(sources=["tests/audio/noises.csv"]),
+    >>>     ],
+    >>> )
+
+    This will convolve the signal with a room impulse response, and then
+    add background noise to the signal. Instantiate instantiates
+    all the parameters for every transform in the transform list so the
+    interface for using the Compose transform is the same as everything
+    else:
+
+    >>> kwargs = transform.instantiate(seed, signal)
+    >>> output = transform(signal.clone(), **kwargs)
+
+    Under the hood, ``Compose`` maps each transform to a unique name
+    of the form ``{position}.{name}``, where ``position``
+    is the index of the transform in the list. ``Compose`` can nest
+    within other ``Compose`` transforms, like so:
+
+    >>> preprocess = transforms.Compose(
+    >>>     tfm.GlobalVolumeNorm(),
+    >>>     tfm.CrossTalk(),
+    >>>     name="preprocess",
+    >>> )
+    >>> augment = transforms.Compose(
+    >>>     tfm.RoomImpulseResponse(),
+    >>>     tfm.BackgroundNoise(),
+    >>>     name="augment",
+    >>> )
+    >>> postprocess = transforms.Compose(
+    >>>     tfm.VolumeChange(),
+    >>>     tfm.RescaleAudio(),
+    >>>     tfm.ShiftPhase(),
+    >>>     name="postprocess",
+    >>> )
+    >>> transform = transforms.Compose(preprocess, augment, postprocess)
+
+    This defines 3 composed transforms, and then composes them in sequence
+    with one another.
+
+    Parameters
+    ----------
+    *transforms : list
+        List of transforms to apply
+    name : str, optional
+        Name of this transform, used to identify it in the dictionary
+        produced by ``self.instantiate``, by default None
+    prob : float, optional
+        Probability of applying this transform, by default 1.0
+    """
+
+    def __init__(self, *transforms: list, name: str = None, prob: float = 1.0):
+        if isinstance(transforms[0], list):
+            transforms = transforms[0]
+
+        for i, tfm in enumerate(transforms):
+            tfm.name = f"{i}.{tfm.name}"
+
+        keys = [tfm.name for tfm in transforms]
+        super().__init__(keys=keys, name=name, prob=prob)
+
+        self.transforms = transforms
+        self.transforms_to_apply = keys
+
+    @contextmanager
+    def filter(self, *names: list):
+        """This can be used to skip transforms entirely when applying
+        the sequence of transforms to a signal. For example, take
+        the following transforms with the names ``preprocess, augment, postprocess``.
+
+        >>> preprocess = transforms.Compose(
+        >>>     tfm.GlobalVolumeNorm(),
+        >>>     tfm.CrossTalk(),
+        >>>     name="preprocess",
+        >>> )
+        >>> augment = transforms.Compose(
+        >>>     tfm.RoomImpulseResponse(),
+        >>>     tfm.BackgroundNoise(),
+        >>>     name="augment",
+        >>> )
+        >>> postprocess = transforms.Compose(
+        >>>     tfm.VolumeChange(),
+        >>>     tfm.RescaleAudio(),
+        >>>     tfm.ShiftPhase(),
+        >>>     name="postprocess",
+        >>> )
+        >>> transform = transforms.Compose(preprocess, augment, postprocess)
+
+        If we wanted to apply all 3 to a signal, we do:
+
+        >>> kwargs = transform.instantiate(seed, signal)
+        >>> output = transform(signal.clone(), **kwargs)
+
+        But if we only wanted to apply the ``preprocess`` and ``postprocess``
+        transforms to the signal, we do:
+
+        >>> with transform.filter("preprocess", "postprocess"):
+        >>>     output = transform(signal.clone(), **kwargs)
+
+        Parameters
+        ----------
+        *names : list
+            List of transforms, identified by name, to apply to signal.
+        """
+        old_transforms = self.transforms_to_apply
+        self.transforms_to_apply = names
+        yield
+        self.transforms_to_apply = old_transforms
+
+    def _transform(self, signal, **kwargs):
+        for transform in self.transforms:
+            if any([x in transform.name for x in self.transforms_to_apply]):
+                signal = transform(signal, **kwargs)
+        return signal
+
+    def _instantiate(self, state: RandomState, signal: AudioSignal = None):
+        parameters = {}
+        for transform in self.transforms:
+            parameters.update(transform.instantiate(state, signal=signal))
+        return parameters
+
+    def __getitem__(self, idx):
+        return self.transforms[idx]
+
+    def __len__(self):
+        return len(self.transforms)
+
+    def __iter__(self):
+        for transform in self.transforms:
+            yield transform
+
+
+class Choose(Compose):
+    """Choose logic is the same as :py:func:`audiotools.data.transforms.Compose`,
+    but instead of applying all the transforms in sequence, it applies just a single transform,
+    which is chosen for each item in the batch.
+
+    Parameters
+    ----------
+    *transforms : list
+        List of transforms to apply
+    weights : list
+        Probability of choosing any specific transform.
+    name : str, optional
+        Name of this transform, used to identify it in the dictionary
+        produced by ``self.instantiate``, by default None
+    prob : float, optional
+        Probability of applying this transform, by default 1.0
+
+    Examples
+    --------
+
+    >>> transforms.Choose(tfm.LowPass(), tfm.HighPass())
+    """
+
+    def __init__(
+        self,
+        *transforms: list,
+        weights: list = None,
+        name: str = None,
+        prob: float = 1.0,
+    ):
+        super().__init__(*transforms, name=name, prob=prob)
+
+        if weights is None:
+            _len = len(self.transforms)
+            weights = [1 / _len for _ in range(_len)]
+        self.weights = np.array(weights)
+
+    def _instantiate(self, state: RandomState, signal: AudioSignal = None):
+        kwargs = super()._instantiate(state, signal)
+        tfm_idx = list(range(len(self.transforms)))
+        tfm_idx = state.choice(tfm_idx, p=self.weights)
+        one_hot = []
+        for i, t in enumerate(self.transforms):
+            mask = kwargs[t.name]["mask"]
+            if mask.item():
+                kwargs[t.name]["mask"] = tt(i == tfm_idx)
+            one_hot.append(kwargs[t.name]["mask"])
+        kwargs["one_hot"] = one_hot
+        return kwargs
+
+
+class Repeat(Compose):
+    """Repeatedly applies a given transform ``n_repeat`` times."
+
+    Parameters
+    ----------
+    transform : BaseTransform
+        Transform to repeat.
+    n_repeat : int, optional
+        Number of times to repeat transform, by default 1
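+
+    Examples
+    --------
+    A minimal sketch (assuming ``tfm`` is ``audiotools.transforms`` and
+    ``signal`` is an existing AudioSignal):
+
+    >>> transform = tfm.Repeat(tfm.ClippingDistortion(), n_repeat=3)
+    >>> kwargs = transform.instantiate(0, signal)
+    >>> output = transform(signal.clone(), **kwargs)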
+    """
+
+    def __init__(
+        self,
+        transform,
+        n_repeat: int = 1,
+        name: str = None,
+        prob: float = 1.0,
+    ):
+        transforms = [copy.copy(transform) for _ in range(n_repeat)]
+        super().__init__(transforms, name=name, prob=prob)
+
+        self.n_repeat = n_repeat
+
+
+class RepeatUpTo(Choose):
+    """Repeatedly applies a given transform up to ``max_repeat`` times."
+
+    Parameters
+    ----------
+    transform : BaseTransform
+        Transform to repeat.
+    max_repeat : int, optional
+        Max number of times to repeat transform, by default 1
+    weights : list
+        Probability of choosing any specific number up to ``max_repeat``.
+    """
+
+    def __init__(
+        self,
+        transform,
+        max_repeat: int = 5,
+        weights: list = None,
+        name: str = None,
+        prob: float = 1.0,
+    ):
+        transforms = []
+        for n in range(1, max_repeat):
+            transforms.append(Repeat(transform, n_repeat=n))
+        super().__init__(transforms, name=name, prob=prob, weights=weights)
+
+        self.max_repeat = max_repeat
+
+
+class ClippingDistortion(BaseTransform):
+    """Adds clipping distortion to signal. Corresponds
+    to :py:func:`audiotools.core.effects.EffectMixin.clip_distortion`.
+
+    Parameters
+    ----------
+    perc : tuple, optional
+        Clipping percentile. Values are between 0.0 to 1.0.
+        Typical values are 0.1 or below, by default ("uniform", 0.0, 0.1)
+    name : str, optional
+        Name of this transform, used to identify it in the dictionary
+        produced by ``self.instantiate``, by default None
+    prob : float, optional
+        Probability of applying this transform, by default 1.0
+    """
+
+    def __init__(
+        self,
+        perc: tuple = ("uniform", 0.0, 0.1),
+        name: str = None,
+        prob: float = 1.0,
+    ):
+        super().__init__(name=name, prob=prob)
+
+        self.perc = perc
+
+    def _instantiate(self, state: RandomState):
+        return {"perc": util.sample_from_dist(self.perc, state)}
+
+    def _transform(self, signal, perc):
+        return signal.clip_distortion(perc)
+
+
+class Equalizer(BaseTransform):
+    """Applies an equalization curve to the audio signal. Corresponds
+    to :py:func:`audiotools.core.effects.EffectMixin.equalizer`.
+
+    Parameters
+    ----------
+    eq_amount : tuple, optional
+        The maximum dB cut to apply to the audio in any band,
+        by default ("const", 1.0 dB)
+    n_bands : int, optional
+        Number of bands in EQ, by default 6
+    name : str, optional
+        Name of this transform, used to identify it in the dictionary
+        produced by ``self.instantiate``, by default None
+    prob : float, optional
+        Probability of applying this transform, by default 1.0
+    """
+
+    def __init__(
+        self,
+        eq_amount: tuple = ("const", 1.0),
+        n_bands: int = 6,
+        name: str = None,
+        prob: float = 1.0,
+    ):
+        super().__init__(name=name, prob=prob)
+
+        self.eq_amount = eq_amount
+        self.n_bands = n_bands
+
+    def _instantiate(self, state: RandomState):
+        eq_amount = util.sample_from_dist(self.eq_amount, state)
+        eq = -eq_amount * state.rand(self.n_bands)
+        return {"eq": eq}
+
+    def _transform(self, signal, eq):
+        return signal.equalizer(eq)
+
+
+class Quantization(BaseTransform):
+    """Applies quantization to the input waveform. Corresponds
+    to :py:func:`audiotools.core.effects.EffectMixin.quantization`.
+
+    Parameters
+    ----------
+    channels : tuple, optional
+        Number of evenly spaced quantization channels to quantize
+        to, by default ("choice", [8, 32, 128, 256, 1024])
+    name : str, optional
+        Name of this transform, used to identify it in the dictionary
+        produced by ``self.instantiate``, by default None
+    prob : float, optional
+        Probability of applying this transform, by default 1.0
+    """
+
+    def __init__(
+        self,
+        channels: tuple = ("choice", [8, 32, 128, 256, 1024]),
+        name: str = None,
+        prob: float = 1.0,
+    ):
+        super().__init__(name=name, prob=prob)
+
+        self.channels = channels
+
+    def _instantiate(self, state: RandomState):
+        return {"channels": util.sample_from_dist(self.channels, state)}
+
+    def _transform(self, signal, channels):
+        return signal.quantization(channels)
+
+
+class MuLawQuantization(BaseTransform):
+    """Applies mu-law quantization to the input waveform. Corresponds
+    to :py:func:`audiotools.core.effects.EffectMixin.mulaw_quantization`.
+
+    Parameters
+    ----------
+    channels : tuple, optional
+        Number of mu-law spaced quantization channels to quantize
+        to, by default ("choice", [8, 32, 128, 256, 1024])
+    name : str, optional
+        Name of this transform, used to identify it in the dictionary
+        produced by ``self.instantiate``, by default None
+    prob : float, optional
+        Probability of applying this transform, by default 1.0
+    """
+
+    def __init__(
+        self,
+        channels: tuple = ("choice", [8, 32, 128, 256, 1024]),
+        name: str = None,
+        prob: float = 1.0,
+    ):
+        super().__init__(name=name, prob=prob)
+
+        self.channels = channels
+
+    def _instantiate(self, state: RandomState):
+        return {"channels": util.sample_from_dist(self.channels, state)}
+
+    def _transform(self, signal, channels):
+        return signal.mulaw_quantization(channels)
+
+
+class NoiseFloor(BaseTransform):
+    """Adds a noise floor of Gaussian noise to the signal at a specified
+    dB.
+
+    Parameters
+    ----------
+    db : tuple, optional
+        Level of noise to add to signal, by default ("const", -50.0)
+    name : str, optional
+        Name of this transform, used to identify it in the dictionary
+        produced by ``self.instantiate``, by default None
+    prob : float, optional
+        Probability of applying this transform, by default 1.0
+    """
+
+    def __init__(
+        self,
+        db: tuple = ("const", -50.0),
+        name: str = None,
+        prob: float = 1.0,
+    ):
+        super().__init__(name=name, prob=prob)
+
+        self.db = db
+
+    def _instantiate(self, state: RandomState, signal: AudioSignal):
+        db = util.sample_from_dist(self.db, state)
+        audio_data = state.randn(signal.num_channels, signal.signal_length)
+        nz_signal = AudioSignal(audio_data, signal.sample_rate)
+        nz_signal.normalize(db)
+        return {"nz_signal": nz_signal}
+
+    def _transform(self, signal, nz_signal):
+        # Add the instantiated noise floor to the signal.
+        return signal + nz_signal
+
+
+class BackgroundNoise(BaseTransform):
+    """Adds background noise from audio specified by a set of CSV files.
+    A valid CSV file looks like, and is typically generated by
+    :py:func:`audiotools.data.preprocess.create_csv`:
+
+    ..  csv-table::
+        :header: path
+
+        room_tone/m6_script2_clean.wav
+        room_tone/m6_script2_cleanraw.wav
+        room_tone/m6_script2_ipad_balcony1.wav
+        room_tone/m6_script2_ipad_bedroom1.wav
+        room_tone/m6_script2_ipad_confroom1.wav
+        room_tone/m6_script2_ipad_confroom2.wav
+        room_tone/m6_script2_ipad_livingroom1.wav
+        room_tone/m6_script2_ipad_office1.wav
+
+    ..  note::
+        All paths are relative to an environment variable called ``PATH_TO_DATA``,
+        so that CSV files are portable across machines where data may be
+        located in different places.
+
+    This transform calls :py:func:`audiotools.core.effects.EffectMixin.mix`
+    and :py:func:`audiotools.core.effects.EffectMixin.equalizer` under the
+    hood.
+
+    Parameters
+    ----------
+    snr : tuple, optional
+        Signal-to-noise ratio, by default ("uniform", 10.0, 30.0)
+    sources : List[str], optional
+        Sources containing folders, or CSVs with paths to audio files,
+        by default None
+    weights : List[float], optional
+        Weights to sample audio files from each source, by default None
+    eq_amount : tuple, optional
+        Amount of equalization to apply, by default ("const", 1.0)
+    n_bands : int, optional
+        Number of bands in equalizer, by default 3
+    name : str, optional
+        Name of this transform, used to identify it in the dictionary
+        produced by ``self.instantiate``, by default None
+    prob : float, optional
+        Probability of applying this transform, by default 1.0
+    loudness_cutoff : float, optional
+        Loudness cutoff when loading from audio files, by default None
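+
+    Examples
+    --------
+    A minimal sketch (the CSV below is the one referenced by this
+    repository's tests; ``tfm`` is ``audiotools.transforms`` and ``signal``
+    is an existing AudioSignal):
+
+    >>> transform = tfm.BackgroundNoise(sources=["tests/audio/noises.csv"])
+    >>> kwargs = transform.instantiate(0, signal)
+    >>> noisy = transform(signal.clone(), **kwargs)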
+    """
+
+    def __init__(
+        self,
+        snr: tuple = ("uniform", 10.0, 30.0),
+        sources: List[str] = None,
+        weights: List[float] = None,
+        eq_amount: tuple = ("const", 1.0),
+        n_bands: int = 3,
+        name: str = None,
+        prob: float = 1.0,
+        loudness_cutoff: float = None,
+    ):
+        super().__init__(name=name, prob=prob)
+
+        self.snr = snr
+        self.eq_amount = eq_amount
+        self.n_bands = n_bands
+        self.loader = AudioLoader(sources, weights)
+        self.loudness_cutoff = loudness_cutoff
+
+    def _instantiate(self, state: RandomState, signal: AudioSignal):
+        eq_amount = util.sample_from_dist(self.eq_amount, state)
+        eq = -eq_amount * state.rand(self.n_bands)
+        snr = util.sample_from_dist(self.snr, state)
+
+        bg_signal = self.loader(
+            state,
+            signal.sample_rate,
+            duration=signal.signal_duration,
+            loudness_cutoff=self.loudness_cutoff,
+            num_channels=signal.num_channels,
+        )["signal"]
+
+        return {"eq": eq, "bg_signal": bg_signal, "snr": snr}
+
+    def _transform(self, signal, bg_signal, snr, eq):
+        # Clone bg_signal so that transform can be repeatedly applied
+        # to different signals with the same effect.
+        return signal.mix(bg_signal.clone(), snr, eq)
+
+
+class CrossTalk(BaseTransform):
+    """Adds crosstalk between speakers, whose audio is drawn from a CSV file
+    that was produced via :py:func:`audiotools.data.preprocess.create_csv`.
+
+    This transform calls :py:func:`audiotools.core.effects.EffectMixin.mix`
+    under the hood.
+
+    Parameters
+    ----------
+    snr : tuple, optional
+        How loud cross-talk speaker is relative to original signal in dB,
+        by default ("uniform", 0.0, 10.0)
+    sources : List[str], optional
+        Sources containing folders, or CSVs with paths to audio files,
+        by default None
+    weights : List[float], optional
+        Weights to sample audio files from each source, by default None
+    name : str, optional
+        Name of this transform, used to identify it in the dictionary
+        produced by ``self.instantiate``, by default None
+    prob : float, optional
+        Probability of applying this transform, by default 1.0
+    loudness_cutoff : float, optional
+        Loudness cutoff when loading from audio files, by default -40
+    """
+
+    def __init__(
+        self,
+        snr: tuple = ("uniform", 0.0, 10.0),
+        sources: List[str] = None,
+        weights: List[float] = None,
+        name: str = None,
+        prob: float = 1.0,
+        loudness_cutoff: float = -40,
+    ):
+        super().__init__(name=name, prob=prob)
+
+        self.snr = snr
+        self.loader = AudioLoader(sources, weights)
+        self.loudness_cutoff = loudness_cutoff
+
+    def _instantiate(self, state: RandomState, signal: AudioSignal):
+        snr = util.sample_from_dist(self.snr, state)
+        crosstalk_signal = self.loader(
+            state,
+            signal.sample_rate,
+            duration=signal.signal_duration,
+            loudness_cutoff=self.loudness_cutoff,
+            num_channels=signal.num_channels,
+        )["signal"]
+
+        return {"crosstalk_signal": crosstalk_signal, "snr": snr}
+
+    def _transform(self, signal, crosstalk_signal, snr):
+        # Clone crosstalk_signal so that transform can be repeatedly applied
+        # to different signals with the same effect.
+        loudness = signal.loudness()
+        mix = signal.mix(crosstalk_signal.clone(), snr)
+        mix.normalize(loudness)
+        return mix
+
+
+class RoomImpulseResponse(BaseTransform):
+    """Convolves signal with a room impulse response, at a specified
+    direct-to-reverberant ratio, with equalization applied. Room impulse
+    response data is drawn from a CSV file that was produced via
+    :py:func:`audiotools.data.preprocess.create_csv`.
+
+    This transform calls :py:func:`audiotools.core.effects.EffectMixin.apply_ir`
+    under the hood.
+
+    Parameters
+    ----------
+    drr : tuple, optional
+        Direct-to-reverberant ratio in dB, by default ("uniform", 0.0, 30.0)
+    sources : List[str], optional
+        Sources containing folders, or CSVs with paths to audio files,
+        by default None
+    weights : List[float], optional
+        Weights to sample audio files from each source, by default None
+    eq_amount : tuple, optional
+        Amount of equalization to apply, by default ("const", 1.0)
+    n_bands : int, optional
+        Number of bands in equalizer, by default 6
+    name : str, optional
+        Name of this transform, used to identify it in the dictionary
+        produced by ``self.instantiate``, by default None
+    prob : float, optional
+        Probability of applying this transform, by default 1.0
+    use_original_phase : bool, optional
+        Whether or not to use the original phase, by default False
+    offset : float, optional
+        Offset from each impulse response file to use, by default 0.0
+    duration : float, optional
+        Duration of each impulse response, by default 1.0
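+
+    Example
+    -------
+    Rough sketch of typical usage; ``signal`` is an existing AudioSignal, the
+    ``sources`` entry is a placeholder path to a CSV of impulse responses, and
+    ``instantiate`` / ``__call__`` are assumed to follow the base transform API:
+
+    >>> transform = RoomImpulseResponse(sources=["irs.csv"])
+    >>> kwargs = transform.instantiate(state=0, signal=signal)
+    >>> output = transform(signal.clone(), **kwargs)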
+    """
+
+    def __init__(
+        self,
+        drr: tuple = ("uniform", 0.0, 30.0),
+        sources: List[str] = None,
+        weights: List[float] = None,
+        eq_amount: tuple = ("const", 1.0),
+        n_bands: int = 6,
+        name: str = None,
+        prob: float = 1.0,
+        use_original_phase: bool = False,
+        offset: float = 0.0,
+        duration: float = 1.0,
+    ):
+        super().__init__(name=name, prob=prob)
+
+        self.drr = drr
+        self.eq_amount = eq_amount
+        self.n_bands = n_bands
+        self.use_original_phase = use_original_phase
+
+        self.loader = AudioLoader(sources, weights)
+        self.offset = offset
+        self.duration = duration
+
+    def _instantiate(self, state: RandomState, signal: AudioSignal = None):
+        eq_amount = util.sample_from_dist(self.eq_amount, state)
+        eq = -eq_amount * state.rand(self.n_bands)
+        drr = util.sample_from_dist(self.drr, state)
+
+        ir_signal = self.loader(
+            state,
+            signal.sample_rate,
+            offset=self.offset,
+            duration=self.duration,
+            loudness_cutoff=None,
+            num_channels=signal.num_channels,
+        )["signal"]
+        ir_signal.zero_pad_to(signal.sample_rate)
+
+        return {"eq": eq, "ir_signal": ir_signal, "drr": drr}
+
+    def _transform(self, signal, ir_signal, drr, eq):
+        # Clone ir_signal so that transform can be repeatedly applied
+        # to different signals with the same effect.
+        return signal.apply_ir(
+            ir_signal.clone(), drr, eq, use_original_phase=self.use_original_phase
+        )
+
+
+class VolumeChange(BaseTransform):
+    """Changes the volume of the input signal.
+
+    Uses :py:func:`audiotools.core.effects.EffectMixin.volume_change`.
+
+    Parameters
+    ----------
+    db : tuple, optional
+        Change in volume in decibels, by default ("uniform", -12.0, 0.0)
+    name : str, optional
+        Name of this transform, used to identify it in the dictionary
+        produced by ``self.instantiate``, by default None
+    prob : float, optional
+        Probability of applying this transform, by default 1.0
+    """
+
+    def __init__(
+        self,
+        db: tuple = ("uniform", -12.0, 0.0),
+        name: str = None,
+        prob: float = 1.0,
+    ):
+        super().__init__(name=name, prob=prob)
+        self.db = db
+
+    def _instantiate(self, state: RandomState):
+        return {"db": util.sample_from_dist(self.db, state)}
+
+    def _transform(self, signal, db):
+        return signal.volume_change(db)
+
+
+class VolumeNorm(BaseTransform):
+    """Normalizes the volume of the excerpt to a specified decibel.
+
+    Uses :py:func:`audiotools.core.effects.EffectMixin.normalize`.
+
+    Parameters
+    ----------
+    db : tuple, optional
+        dB to normalize signal to, by default ("const", -24)
+    name : str, optional
+        Name of this transform, used to identify it in the dictionary
+        produced by ``self.instantiate``, by default None
+    prob : float, optional
+        Probability of applying this transform, by default 1.0
+    """
+
+    def __init__(
+        self,
+        db: tuple = ("const", -24),
+        name: str = None,
+        prob: float = 1.0,
+    ):
+        super().__init__(name=name, prob=prob)
+
+        self.db = db
+
+    def _instantiate(self, state: RandomState):
+        return {"db": util.sample_from_dist(self.db, state)}
+
+    def _transform(self, signal, db):
+        return signal.normalize(db)
+
+
+class GlobalVolumeNorm(BaseTransform):
+    """Similar to :py:func:`audiotools.data.transforms.VolumeNorm`, this
+    transform also normalizes the volume of a signal, but it uses
+    the volume of the entire audio file the loaded excerpt comes from,
+    rather than the volume of just the excerpt. The volume of the
+    entire audio file is expected in ``signal.metadata["loudness"]``.
+    A CSV generated by :py:func:`audiotools.data.preprocess.create_csv`
+    with ``loudness = True`` looks like the following:
+
+    ..  csv-table::
+        :header: path,loudness
+
+        daps/produced/f1_script1_produced.wav,-16.299999237060547
+        daps/produced/f1_script2_produced.wav,-16.600000381469727
+        daps/produced/f1_script3_produced.wav,-17.299999237060547
+        daps/produced/f1_script4_produced.wav,-16.100000381469727
+        daps/produced/f1_script5_produced.wav,-16.700000762939453
+        daps/produced/f3_script1_produced.wav,-16.5
+
+    The ``AudioLoader`` will automatically load the loudness column into
+    the metadata of the signal.
+
+    Uses :py:func:`audiotools.core.effects.EffectMixin.volume_change`.
+
+    Parameters
+    ----------
+    db : tuple, optional
+        dB to normalize signal to, by default ("const", -24)
+    name : str, optional
+        Name of this transform, used to identify it in the dictionary
+        produced by ``self.instantiate``, by default None
+    prob : float, optional
+        Probability of applying this transform, by default 1.0
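+
+    Example
+    -------
+    Rough sketch of typical usage; assumes ``signal`` is an AudioSignal whose
+    metadata already contains the file-level ``loudness`` (e.g. loaded by
+    ``AudioLoader`` from a CSV like the one above), and that ``instantiate``
+    and ``__call__`` follow the base transform API:
+
+    >>> transform = GlobalVolumeNorm(db=("const", -24))
+    >>> kwargs = transform.instantiate(state=0, signal=signal)
+    >>> output = transform(signal.clone(), **kwargs)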
+    """
+
+    def __init__(
+        self,
+        db: tuple = ("const", -24),
+        name: str = None,
+        prob: float = 1.0,
+    ):
+        super().__init__(name=name, prob=prob)
+
+        self.db = db
+
+    def _instantiate(self, state: RandomState, signal: AudioSignal):
+        if "loudness" not in signal.metadata:
+            db_change = 0.0
+        elif float(signal.metadata["loudness"]) == float("-inf"):
+            db_change = 0.0
+        else:
+            db = util.sample_from_dist(self.db, state)
+            db_change = db - float(signal.metadata["loudness"])
+
+        return {"db": db_change}
+
+    def _transform(self, signal, db):
+        return signal.volume_change(db)
+
+
+class Silence(BaseTransform):
+    """Zeros out the signal with some probability.
+
+    Parameters
+    ----------
+    name : str, optional
+        Name of this transform, used to identify it in the dictionary
+        produced by ``self.instantiate``, by default None
+    prob : float, optional
+        Probability of applying this transform, by default 0.1
+    """
+
+    def __init__(self, name: str = None, prob: float = 0.1):
+        super().__init__(name=name, prob=prob)
+
+    def _transform(self, signal):
+        _loudness = signal._loudness
+        signal = AudioSignal(
+            torch.zeros_like(signal.audio_data),
+            sample_rate=signal.sample_rate,
+            stft_params=signal.stft_params,
+        )
+        # So that the amount of noise added is as if it wasn't silenced.
+        # TODO: improve this hack
+        signal._loudness = _loudness
+
+        return signal
+
+
+class LowPass(BaseTransform):
+    """Applies a LowPass filter.
+
+    Uses :py:func:`audiotools.core.dsp.DSPMixin.low_pass`.
+
+    Parameters
+    ----------
+    cutoff : tuple, optional
+        Cutoff frequency distribution,
+        by default ``("choice", [4000, 8000, 16000])``
+    zeros : int, optional
+        Number of zero-crossings in filter, argument to
+        ``julius.LowPassFilters``, by default 51
+    name : str, optional
+        Name of this transform, used to identify it in the dictionary
+        produced by ``self.instantiate``, by default None
+    prob : float, optional
+        Probability of applying this transform, by default 1.0
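+
+    Example
+    -------
+    Rough sketch of typical usage; ``signal`` is an existing AudioSignal, and
+    ``instantiate`` / ``__call__`` are assumed to follow the base transform API:
+
+    >>> transform = LowPass(cutoff=("choice", [4000, 8000, 16000]))
+    >>> kwargs = transform.instantiate(state=0)
+    >>> output = transform(signal.clone(), **kwargs)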
+    """
+
+    def __init__(
+        self,
+        cutoff: tuple = ("choice", [4000, 8000, 16000]),
+        zeros: int = 51,
+        name: str = None,
+        prob: float = 1,
+    ):
+        super().__init__(name=name, prob=prob)
+
+        self.cutoff = cutoff
+        self.zeros = zeros
+
+    def _instantiate(self, state: RandomState):
+        return {"cutoff": util.sample_from_dist(self.cutoff, state)}
+
+    def _transform(self, signal, cutoff):
+        return signal.low_pass(cutoff, zeros=self.zeros)
+
+
+class HighPass(BaseTransform):
+    """Applies a HighPass filter.
+
+    Uses :py:func:`audiotools.core.dsp.DSPMixin.high_pass`.
+
+    Parameters
+    ----------
+    cutoff : tuple, optional
+        Cutoff frequency distribution,
+        by default ``("choice", [50, 100, 250, 500, 1000])``
+    zeros : int, optional
+        Number of zero-crossings in filter, argument to the underlying
+        ``julius`` high-pass filter, by default 51
+    name : str, optional
+        Name of this transform, used to identify it in the dictionary
+        produced by ``self.instantiate``, by default None
+    prob : float, optional
+        Probability of applying this transform, by default 1.0
+    """
+
+    def __init__(
+        self,
+        cutoff: tuple = ("choice", [50, 100, 250, 500, 1000]),
+        zeros: int = 51,
+        name: str = None,
+        prob: float = 1,
+    ):
+        super().__init__(name=name, prob=prob)
+
+        self.cutoff = cutoff
+        self.zeros = zeros
+
+    def _instantiate(self, state: RandomState):
+        return {"cutoff": util.sample_from_dist(self.cutoff, state)}
+
+    def _transform(self, signal, cutoff):
+        return signal.high_pass(cutoff, zeros=self.zeros)
+
+
+class RescaleAudio(BaseTransform):
+    """Rescales the audio so it is in between ``-val`` and ``val``
+    only if the original audio exceeds those bounds. Useful if
+    transforms have caused the audio to clip.
+
+    Uses :py:func:`audiotools.core.effects.EffectMixin.ensure_max_of_audio`.
+
+    Parameters
+    ----------
+    val : float, optional
+        Max absolute value of signal, by default 1.0
+    name : str, optional
+        Name of this transform, used to identify it in the dictionary
+        produced by ``self.instantiate``, by default None
+    prob : float, optional
+        Probability of applying this transform, by default 1.0
+    """
+
+    def __init__(self, val: float = 1.0, name: str = None, prob: float = 1):
+        super().__init__(name=name, prob=prob)
+
+        self.val = val
+
+    def _transform(self, signal):
+        return signal.ensure_max_of_audio(self.val)
+
+
+class ShiftPhase(SpectralTransform):
+    """Shifts the phase of the audio.
+
+    Uses :py:func:`audiotools.core.dsp.DSPMixin.shift_phase`.
+
+    Parameters
+    ----------
+    shift : tuple, optional
+        How much to shift phase by, by default ("uniform", -np.pi, np.pi)
+    name : str, optional
+        Name of this transform, used to identify it in the dictionary
+        produced by ``self.instantiate``, by default None
+    prob : float, optional
+        Probability of applying this transform, by default 1.0
+    """
+
+    def __init__(
+        self,
+        shift: tuple = ("uniform", -np.pi, np.pi),
+        name: str = None,
+        prob: float = 1,
+    ):
+        super().__init__(name=name, prob=prob)
+        self.shift = shift
+
+    def _instantiate(self, state: RandomState):
+        return {"shift": util.sample_from_dist(self.shift, state)}
+
+    def _transform(self, signal, shift):
+        return signal.shift_phase(shift)
+
+
+class InvertPhase(ShiftPhase):
+    """Inverts the phase of the audio.
+
+    Uses :py:func:`audiotools.core.dsp.DSPMixin.shift_phase`.
+
+    Parameters
+    ----------
+    name : str, optional
+        Name of this transform, used to identify it in the dictionary
+        produced by ``self.instantiate``, by default None
+    prob : float, optional
+        Probability of applying this transform, by default 1.0
+    """
+
+    def __init__(self, name: str = None, prob: float = 1):
+        super().__init__(shift=("const", np.pi), name=name, prob=prob)
+
+
+class CorruptPhase(SpectralTransform):
+    """Corrupts the phase of the audio.
+
+    Uses :py:func:`audiotools.core.dsp.DSPMixin.corrupt_phase`.
+
+    Parameters
+    ----------
+    scale : tuple, optional
+        How much to corrupt phase by, by default ("uniform", 0, np.pi)
+    name : str, optional
+        Name of this transform, used to identify it in the dictionary
+        produced by ``self.instantiate``, by default None
+    prob : float, optional
+        Probability of applying this transform, by default 1.0
+    """
+
+    def __init__(
+        self, scale: tuple = ("uniform", 0, np.pi), name: str = None, prob: float = 1
+    ):
+        super().__init__(name=name, prob=prob)
+        self.scale = scale
+
+    def _instantiate(self, state: RandomState, signal: AudioSignal = None):
+        scale = util.sample_from_dist(self.scale, state)
+        corruption = state.normal(scale=scale, size=signal.phase.shape[1:])
+        return {"corruption": corruption.astype("float32")}
+
+    def _transform(self, signal, corruption):
+        return signal.shift_phase(shift=corruption)
+
+
+class FrequencyMask(SpectralTransform):
+    """Masks a band of frequencies at a center frequency
+    from the audio.
+
+    Uses :py:func:`audiotools.core.dsp.DSPMixin.mask_frequencies`.
+
+    Parameters
+    ----------
+    f_center : tuple, optional
+        Center frequency between 0.0 and 1.0 (Nyquist), by default ("uniform", 0.0, 1.0)
+    f_width : tuple, optional
+        Width of zero'd out band, by default ("const", 0.1)
+    name : str, optional
+        Name of this transform, used to identify it in the dictionary
+        produced by ``self.instantiate``, by default None
+    prob : float, optional
+        Probability of applying this transform, by default 1.0
+    """
+
+    def __init__(
+        self,
+        f_center: tuple = ("uniform", 0.0, 1.0),
+        f_width: tuple = ("const", 0.1),
+        name: str = None,
+        prob: float = 1,
+    ):
+        super().__init__(name=name, prob=prob)
+        self.f_center = f_center
+        self.f_width = f_width
+
+    def _instantiate(self, state: RandomState, signal: AudioSignal):
+        f_center = util.sample_from_dist(self.f_center, state)
+        f_width = util.sample_from_dist(self.f_width, state)
+
+        fmin = max(f_center - (f_width / 2), 0.0)
+        fmax = min(f_center + (f_width / 2), 1.0)
+
+        fmin_hz = (signal.sample_rate / 2) * fmin
+        fmax_hz = (signal.sample_rate / 2) * fmax
+
+        return {"fmin_hz": fmin_hz, "fmax_hz": fmax_hz}
+
+    def _transform(self, signal, fmin_hz: float, fmax_hz: float):
+        return signal.mask_frequencies(fmin_hz=fmin_hz, fmax_hz=fmax_hz)
+
+
+class TimeMask(SpectralTransform):
+    """Masks out contiguous time-steps from signal.
+
+    Uses :py:func:`audiotools.core.dsp.DSPMixin.mask_timesteps`.
+
+    Parameters
+    ----------
+    t_center : tuple, optional
+        Center time in terms of 0.0 and 1.0 (duration of signal),
+        by default ("uniform", 0.0, 1.0)
+    t_width : tuple, optional
+        Width of dropped out portion, by default ("const", 0.025)
+    name : str, optional
+        Name of this transform, used to identify it in the dictionary
+        produced by ``self.instantiate``, by default None
+    prob : float, optional
+        Probability of applying this transform, by default 1.0
+    """
+
+    def __init__(
+        self,
+        t_center: tuple = ("uniform", 0.0, 1.0),
+        t_width: tuple = ("const", 0.025),
+        name: str = None,
+        prob: float = 1,
+    ):
+        super().__init__(name=name, prob=prob)
+        self.t_center = t_center
+        self.t_width = t_width
+
+    def _instantiate(self, state: RandomState, signal: AudioSignal):
+        t_center = util.sample_from_dist(self.t_center, state)
+        t_width = util.sample_from_dist(self.t_width, state)
+
+        tmin = max(t_center - (t_width / 2), 0.0)
+        tmax = min(t_center + (t_width / 2), 1.0)
+
+        tmin_s = signal.signal_duration * tmin
+        tmax_s = signal.signal_duration * tmax
+        return {"tmin_s": tmin_s, "tmax_s": tmax_s}
+
+    def _transform(self, signal, tmin_s: float, tmax_s: float):
+        return signal.mask_timesteps(tmin_s=tmin_s, tmax_s=tmax_s)
+
+
+class MaskLowMagnitudes(SpectralTransform):
+    """Masks low magnitude regions out of signal.
+
+    Uses :py:func:`audiotools.core.dsp.DSPMixin.mask_low_magnitudes`.
+
+    Parameters
+    ----------
+    db_cutoff : tuple, optional
+        Decibel threshold; time-frequency bins with magnitudes below it are
+        masked away, by default ("uniform", -10, 10)
+    name : str, optional
+        Name of this transform, used to identify it in the dictionary
+        produced by ``self.instantiate``, by default None
+    prob : float, optional
+        Probability of applying this transform, by default 1.0
+    """
+
+    def __init__(
+        self,
+        db_cutoff: tuple = ("uniform", -10, 10),
+        name: str = None,
+        prob: float = 1,
+    ):
+        super().__init__(name=name, prob=prob)
+        self.db_cutoff = db_cutoff
+
+    def _instantiate(self, state: RandomState, signal: AudioSignal = None):
+        return {"db_cutoff": util.sample_from_dist(self.db_cutoff, state)}
+
+    def _transform(self, signal, db_cutoff: float):
+        return signal.mask_low_magnitudes(db_cutoff)
+
+
+class Smoothing(BaseTransform):
+    """Convolves the signal with a smoothing window.
+
+    Uses :py:func:`audiotools.core.effects.EffectMixin.convolve`.
+
+    Parameters
+    ----------
+    window_type : tuple, optional
+        Type of window to use, by default ("const", "average")
+    window_length : tuple, optional
+        Length of smoothing window, by
+        default ("choice", [8, 16, 32, 64, 128, 256, 512])
+    name : str, optional
+        Name of this transform, used to identify it in the dictionary
+        produced by ``self.instantiate``, by default None
+    prob : float, optional
+        Probability of applying this transform, by default 1.0
+    """
+
+    def __init__(
+        self,
+        window_type: tuple = ("const", "average"),
+        window_length: tuple = ("choice", [8, 16, 32, 64, 128, 256, 512]),
+        name: str = None,
+        prob: float = 1,
+    ):
+        super().__init__(name=name, prob=prob)
+        self.window_type = window_type
+        self.window_length = window_length
+
+    def _instantiate(self, state: RandomState, signal: AudioSignal = None):
+        window_type = util.sample_from_dist(self.window_type, state)
+        window_length = util.sample_from_dist(self.window_length, state)
+        window = signal.get_window(
+            window_type=window_type, window_length=window_length, device="cpu"
+        )
+        return {"window": AudioSignal(window, signal.sample_rate)}
+
+    def _transform(self, signal, window):
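+        # Record the input's per-item peak so the smoothed output can be
+        # rescaled back to the same peak level after convolution.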
+        sscale = signal.audio_data.abs().max(dim=-1, keepdim=True).values
+        sscale[sscale == 0.0] = 1.0
+
+        out = signal.convolve(window)
+
+        oscale = out.audio_data.abs().max(dim=-1, keepdim=True).values
+        oscale[oscale == 0.0] = 1.0
+
+        out = out * (sscale / oscale)
+        return out
+
+
+class TimeNoise(TimeMask):
+    """Similar to :py:func:`audiotools.data.transforms.TimeMask`, but
+    replaces with noise instead of zeros.
+
+    Parameters
+    ----------
+    t_center : tuple, optional
+        Center time in terms of 0.0 and 1.0 (duration of signal),
+        by default ("uniform", 0.0, 1.0)
+    t_width : tuple, optional
+        Width of the portion replaced with noise, by default ("const", 0.025)
+    name : str, optional
+        Name of this transform, used to identify it in the dictionary
+        produced by ``self.instantiate``, by default None
+    prob : float, optional
+        Probability of applying this transform, by default 1.0
+    """
+
+    def __init__(
+        self,
+        t_center: tuple = ("uniform", 0.0, 1.0),
+        t_width: tuple = ("const", 0.025),
+        name: str = None,
+        prob: float = 1,
+    ):
+        super().__init__(t_center=t_center, t_width=t_width, name=name, prob=prob)
+
+    def _transform(self, signal, tmin_s: float, tmax_s: float):
+        signal = signal.mask_timesteps(tmin_s=tmin_s, tmax_s=tmax_s, val=0.0)
+        mag, phase = signal.magnitude, signal.phase
+
+        mag_r, phase_r = torch.randn_like(mag), torch.randn_like(phase)
+        mask = (mag == 0.0) * (phase == 0.0)
+
+        mag[mask] = mag_r[mask]
+        phase[mask] = phase_r[mask]
+
+        signal.magnitude = mag
+        signal.phase = phase
+        return signal
+
+
+class FrequencyNoise(FrequencyMask):
+    """Similar to :py:func:`audiotools.data.transforms.FrequencyMask`, but
+    replaces with noise instead of zeros.
+
+    Parameters
+    ----------
+    f_center : tuple, optional
+        Center frequency between 0.0 and 1.0 (Nyquist), by default ("uniform", 0.0, 1.0)
+    f_width : tuple, optional
+        Width of the band replaced with noise, by default ("const", 0.1)
+    name : str, optional
+        Name of this transform, used to identify it in the dictionary
+        produced by ``self.instantiate``, by default None
+    prob : float, optional
+        Probability of applying this transform, by default 1.0
+    """
+
+    def __init__(
+        self,
+        f_center: tuple = ("uniform", 0.0, 1.0),
+        f_width: tuple = ("const", 0.1),
+        name: str = None,
+        prob: float = 1,
+    ):
+        super().__init__(f_center=f_center, f_width=f_width, name=name, prob=prob)
+
+    def _transform(self, signal, fmin_hz: float, fmax_hz: float):
+        signal = signal.mask_frequencies(fmin_hz=fmin_hz, fmax_hz=fmax_hz)
+        mag, phase = signal.magnitude, signal.phase
+
+        mag_r, phase_r = torch.randn_like(mag), torch.randn_like(phase)
+        mask = (mag == 0.0) * (phase == 0.0)
+
+        mag[mask] = mag_r[mask]
+        phase[mask] = phase_r[mask]
+
+        signal.magnitude = mag
+        signal.phase = phase
+        return signal
+
+
+class SpectralDenoising(Equalizer):
+    """Applies denoising algorithm detailed in
+    :py:func:`audiotools.ml.layers.spectral_gate.SpectralGate`,
+    using a randomly generated noise signal for denoising.
+
+    Parameters
+    ----------
+    eq_amount : tuple, optional
+        Amount of eq to apply to noise signal, by default ("const", 1.0)
+    denoise_amount : tuple, optional
+        Amount to denoise by, by default ("uniform", 0.8, 1.0)
+    nz_volume : float, optional
+        Volume of noise to denoise with, by default -40
+    n_bands : int, optional
+        Number of bands in equalizer, by default 6
+    n_freq : int, optional
+        Number of frequency bins to smooth by, by default 3
+    n_time : int, optional
+        Number of time bins to smooth by, by default 5
+    name : str, optional
+        Name of this transform, used to identify it in the dictionary
+        produced by ``self.instantiate``, by default None
+    prob : float, optional
+        Probability of applying this transform, by default 1.0
+    """
+
+    def __init__(
+        self,
+        eq_amount: tuple = ("const", 1.0),
+        denoise_amount: tuple = ("uniform", 0.8, 1.0),
+        nz_volume: float = -40,
+        n_bands: int = 6,
+        n_freq: int = 3,
+        n_time: int = 5,
+        name: str = None,
+        prob: float = 1,
+    ):
+        super().__init__(eq_amount=eq_amount, n_bands=n_bands, name=name, prob=prob)
+
+        self.nz_volume = nz_volume
+        self.denoise_amount = denoise_amount
+        self.spectral_gate = ml.layers.SpectralGate(n_freq, n_time)
+
+    def _transform(self, signal, nz, eq, denoise_amount):
+        nz = nz.normalize(self.nz_volume).equalizer(eq)
+        self.spectral_gate = self.spectral_gate.to(signal.device)
+        signal = self.spectral_gate(signal, nz, denoise_amount)
+        return signal
+
+    def _instantiate(self, state: RandomState):
+        kwargs = super()._instantiate(state)
+        kwargs["denoise_amount"] = util.sample_from_dist(self.denoise_amount, state)
+        kwargs["nz"] = AudioSignal(state.randn(22050), 44100)
+        return kwargs
diff --git a/audiotools/metrics/__init__.py b/audiotools/metrics/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c9c8d2df61f94afae8e39e57abf156e8e4059a9e
--- /dev/null
+++ b/audiotools/metrics/__init__.py
@@ -0,0 +1,6 @@
+"""
+Functions for comparing AudioSignal objects to one another.
+"""  # fmt: skip
+from . import distance
+from . import quality
+from . import spectral
diff --git a/audiotools/metrics/distance.py b/audiotools/metrics/distance.py
new file mode 100644
index 0000000000000000000000000000000000000000..ce78739bfc29f9ddc39b23063b4243ddac10adaf
--- /dev/null
+++ b/audiotools/metrics/distance.py
@@ -0,0 +1,131 @@
+import torch
+from torch import nn
+
+from .. import AudioSignal
+
+
+class L1Loss(nn.L1Loss):
+    """L1 Loss between AudioSignals. Defaults
+    to comparing ``audio_data``, but any
+    attribute of an AudioSignal can be used.
+
+    Parameters
+    ----------
+    attribute : str, optional
+        Attribute of signal to compare, defaults to ``audio_data``.
+    weight : float, optional
+        Weight of this loss, defaults to 1.0.
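+
+    Example
+    -------
+    Rough sketch; ``x`` and ``y`` stand for two AudioSignals (raw tensors also
+    work, in which case ``attribute`` is ignored):
+
+    >>> loss_fn = L1Loss(attribute="audio_data")
+    >>> loss = loss_fn(x, y)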
+    """
+
+    def __init__(self, attribute: str = "audio_data", weight: float = 1.0, **kwargs):
+        self.attribute = attribute
+        self.weight = weight
+        super().__init__(**kwargs)
+
+    def forward(self, x: AudioSignal, y: AudioSignal):
+        """
+        Parameters
+        ----------
+        x : AudioSignal
+            Estimate AudioSignal
+        y : AudioSignal
+            Reference AudioSignal
+
+        Returns
+        -------
+        torch.Tensor
+            L1 loss between AudioSignal attributes.
+        """
+        if isinstance(x, AudioSignal):
+            x = getattr(x, self.attribute)
+            y = getattr(y, self.attribute)
+        return super().forward(x, y)
+
+
+class SISDRLoss(nn.Module):
+    """
+    Computes the Scale-Invariant Source-to-Distortion Ratio between a batch
+    of estimated and reference audio signals or aligned features.
+
+    Parameters
+    ----------
+    scaling : int, optional
+        Whether to use scale-invariant (True) or
+        signal-to-noise ratio (False), by default True
+    reduction : str, optional
+        How to reduce across the batch (either 'mean',
+        'sum', or 'none'), by default 'mean'
+    zero_mean : int, optional
+        Zero mean the references and estimates before
+        computing the loss, by default True
+    clip_min : int, optional
+        The minimum possible loss value. Helps network
+        to not focus on making already good examples better, by default None
+    weight : float, optional
+        Weight of this loss, defaults to 1.0.
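+
+    Example
+    -------
+    Rough sketch; ``x`` and ``y`` stand for two AudioSignals (or raw tensors)
+    of matching shape. Note that the returned value is the negated SI-SDR, so
+    it can be minimized directly as a loss:
+
+    >>> loss_fn = SISDRLoss()
+    >>> loss = loss_fn(x, y)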
+    """
+
+    def __init__(
+        self,
+        scaling: int = True,
+        reduction: str = "mean",
+        zero_mean: int = True,
+        clip_min: int = None,
+        weight: float = 1.0,
+    ):
+        self.scaling = scaling
+        self.reduction = reduction
+        self.zero_mean = zero_mean
+        self.clip_min = clip_min
+        self.weight = weight
+        super().__init__()
+
+    def forward(self, x: AudioSignal, y: AudioSignal):
+        eps = 1e-8
+        # nb, nc, nt
+        if isinstance(x, AudioSignal):
+            references = x.audio_data
+            estimates = y.audio_data
+        else:
+            references = x
+            estimates = y
+
+        nb = references.shape[0]
+        references = references.reshape(nb, 1, -1).permute(0, 2, 1)
+        estimates = estimates.reshape(nb, 1, -1).permute(0, 2, 1)
+
+        # samples now on axis 1
+        if self.zero_mean:
+            mean_reference = references.mean(dim=1, keepdim=True)
+            mean_estimate = estimates.mean(dim=1, keepdim=True)
+        else:
+            mean_reference = 0
+            mean_estimate = 0
+
+        _references = references - mean_reference
+        _estimates = estimates - mean_estimate
+
+        references_projection = (_references**2).sum(dim=-2) + eps
+        references_on_estimates = (_estimates * _references).sum(dim=-2) + eps
+
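+        # Optimal rescaling of the reference toward the estimate; this is
+        # what makes the metric scale-invariant when scaling is True.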
+        scale = (
+            (references_on_estimates / references_projection).unsqueeze(1)
+            if self.scaling
+            else 1
+        )
+
+        e_true = scale * _references
+        e_res = _estimates - e_true
+
+        signal = (e_true**2).sum(dim=1)
+        noise = (e_res**2).sum(dim=1)
+        sdr = -10 * torch.log10(signal / noise + eps)
+
+        if self.clip_min is not None:
+            sdr = torch.clamp(sdr, min=self.clip_min)
+
+        if self.reduction == "mean":
+            sdr = sdr.mean()
+        elif self.reduction == "sum":
+            sdr = sdr.sum()
+        return sdr
diff --git a/audiotools/metrics/quality.py b/audiotools/metrics/quality.py
new file mode 100644
index 0000000000000000000000000000000000000000..1608f25507082b49ccbf49289025a5a94a422808
--- /dev/null
+++ b/audiotools/metrics/quality.py
@@ -0,0 +1,159 @@
+import os
+
+import numpy as np
+import torch
+
+from .. import AudioSignal
+
+
+def stoi(
+    estimates: AudioSignal,
+    references: AudioSignal,
+    extended: int = False,
+):
+    """Short term objective intelligibility
+    Computes the STOI (see [1][2]) of a denoised signal compared to a clean
+    signal. The output is expected to have a monotonic relation with the
+    subjective speech intelligibility, where a higher score denotes better
+    speech intelligibility. Uses pystoi under the hood.
+
+    Parameters
+    ----------
+    estimates : AudioSignal
+        Denoised speech
+    references : AudioSignal
+        Clean original speech
+    extended : int, optional
+        Boolean, whether to use the extended STOI described in [3], by default False
+
+    Returns
+    -------
+    Tensor[float]
+        Short time objective intelligibility measure between clean and
+        denoised speech
+
+    References
+    ----------
+    1.  C.H.Taal, R.C.Hendriks, R.Heusdens, J.Jensen 'A Short-Time
+        Objective Intelligibility Measure for Time-Frequency Weighted Noisy
+        Speech', ICASSP 2010, Texas, Dallas.
+    2.  C.H.Taal, R.C.Hendriks, R.Heusdens, J.Jensen 'An Algorithm for
+        Intelligibility Prediction of Time-Frequency Weighted Noisy Speech',
+        IEEE Transactions on Audio, Speech, and Language Processing, 2011.
+    3.  Jesper Jensen and Cees H. Taal, 'An Algorithm for Predicting the
+        Intelligibility of Speech Masked by Modulated Noise Maskers',
+        IEEE Transactions on Audio, Speech and Language Processing, 2016.
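+
+    Example
+    -------
+    Rough sketch; ``estimates`` and ``references`` stand for batched
+    AudioSignals of matching length (requires ``pystoi``):
+
+    >>> scores = stoi(estimates, references)  # one score per batch item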
+    """
+    import pystoi
+
+    estimates = estimates.clone().to_mono()
+    references = references.clone().to_mono()
+
+    stois = []
+    for i in range(estimates.batch_size):
+        _stoi = pystoi.stoi(
+            references.audio_data[i, 0].detach().cpu().numpy(),
+            estimates.audio_data[i, 0].detach().cpu().numpy(),
+            references.sample_rate,
+            extended=extended,
+        )
+        stois.append(_stoi)
+    return torch.from_numpy(np.array(stois))
+
+
+def pesq(
+    estimates: AudioSignal,
+    references: AudioSignal,
+    mode: str = "wb",
+    target_sr: float = 16000,
+):
+    """_summary_
+
+    Parameters
+    ----------
+    estimates : AudioSignal
+        Degraded AudioSignal
+    references : AudioSignal
+        Reference AudioSignal
+    mode : str, optional
+        'wb' (wide-band) or 'nb' (narrow-band), by default "wb"
+    target_sr : float, optional
+        Target sample rate, by default 16000
+
+    Returns
+    -------
+    Tensor[float]
+        PESQ score: P.862.2 Prediction (MOS-LQO)
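+
+    Example
+    -------
+    Rough sketch; ``estimates`` and ``references`` stand for batched
+    AudioSignals (requires the ``pesq`` package; both signals are resampled
+    to ``target_sr`` internally):
+
+    >>> scores = pesq(estimates, references, mode="wb")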
+    """
+    from pesq import pesq as pesq_fn
+
+    estimates = estimates.clone().to_mono().resample(target_sr)
+    references = references.clone().to_mono().resample(target_sr)
+
+    pesqs = []
+    for i in range(estimates.batch_size):
+        _pesq = pesq_fn(
+            estimates.sample_rate,
+            references.audio_data[i, 0].detach().cpu().numpy(),
+            estimates.audio_data[i, 0].detach().cpu().numpy(),
+            mode,
+        )
+        pesqs.append(_pesq)
+    return torch.from_numpy(np.array(pesqs))
+
+
+def visqol(
+    estimates: AudioSignal,
+    references: AudioSignal,
+    mode: str = "audio",
+):  # pragma: no cover
+    """ViSQOL score.
+
+    Parameters
+    ----------
+    estimates : AudioSignal
+        Degraded AudioSignal
+    references : AudioSignal
+        Reference AudioSignal
+    mode : str, optional
+        'audio' or 'speech', by default 'audio'
+
+    Returns
+    -------
+    Tensor[float]
+        ViSQOL score (MOS-LQO)
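+
+    Example
+    -------
+    Rough sketch; ``estimates`` and ``references`` stand for batched
+    AudioSignals (requires the ``visqol`` package and its bundled models):
+
+    >>> scores = visqol(estimates, references, mode="audio")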
+    """
+    from visqol import visqol_lib_py
+    from visqol.pb2 import visqol_config_pb2
+    from visqol.pb2 import similarity_result_pb2
+
+    config = visqol_config_pb2.VisqolConfig()
+    if mode == "audio":
+        target_sr = 48000
+        config.options.use_speech_scoring = False
+        svr_model_path = "libsvm_nu_svr_model.txt"
+    elif mode == "speech":
+        target_sr = 16000
+        config.options.use_speech_scoring = True
+        svr_model_path = "lattice_tcditugenmeetpackhref_ls2_nl60_lr12_bs2048_learn.005_ep2400_train1_7_raw.tflite"
+    else:
+        raise ValueError(f"Unrecognized mode: {mode}")
+    config.audio.sample_rate = target_sr
+    config.options.svr_model_path = os.path.join(
+        os.path.dirname(visqol_lib_py.__file__), "model", svr_model_path
+    )
+
+    api = visqol_lib_py.VisqolApi()
+    api.Create(config)
+
+    estimates = estimates.clone().to_mono().resample(target_sr)
+    references = references.clone().to_mono().resample(target_sr)
+
+    visqols = []
+    for i in range(estimates.batch_size):
+        _visqol = api.Measure(
+            references.audio_data[i, 0].detach().cpu().numpy().astype(float),
+            estimates.audio_data[i, 0].detach().cpu().numpy().astype(float),
+        )
+        visqols.append(_visqol.moslqo)
+    return torch.from_numpy(np.array(visqols))
diff --git a/audiotools/metrics/spectral.py b/audiotools/metrics/spectral.py
new file mode 100644
index 0000000000000000000000000000000000000000..7ce953882efa4e5b777a0348bee6c1be39279a6c
--- /dev/null
+++ b/audiotools/metrics/spectral.py
@@ -0,0 +1,247 @@
+import typing
+from typing import List
+
+import numpy as np
+from torch import nn
+
+from .. import AudioSignal
+from .. import STFTParams
+
+
+class MultiScaleSTFTLoss(nn.Module):
+    """Computes the multi-scale STFT loss from [1].
+
+    Parameters
+    ----------
+    window_lengths : List[int], optional
+        Length of each window of each STFT, by default [2048, 512]
+    loss_fn : typing.Callable, optional
+        How to compare each loss, by default nn.L1Loss()
+    clamp_eps : float, optional
+        Minimum value the magnitude is clamped to before taking the log, by default 1e-5
+    mag_weight : float, optional
+        Weight of raw magnitude portion of loss, by default 1.0
+    log_weight : float, optional
+        Weight of log magnitude portion of loss, by default 1.0
+    pow : float, optional
+        Power to raise magnitude to before taking log, by default 2.0
+    weight : float, optional
+        Weight of this loss, by default 1.0
+    match_stride : bool, optional
+        Whether to match the stride of convolutional layers, by default False
+    window_type : str, optional
+        Type of window to use, by default None
+
+    References
+    ----------
+
+    1.  Engel, Jesse, Chenjie Gu, and Adam Roberts.
+        "DDSP: Differentiable Digital Signal Processing."
+        International Conference on Learning Representations, 2020.
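+
+    Example
+    -------
+    Rough sketch of typical usage; ``x`` and ``y`` stand for an estimate and
+    a reference AudioSignal of matching length:
+
+    >>> loss_fn = MultiScaleSTFTLoss(window_lengths=[2048, 512])
+    >>> loss = loss_fn(x, y)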
+    """
+
+    def __init__(
+        self,
+        window_lengths: List[int] = [2048, 512],
+        loss_fn: typing.Callable = nn.L1Loss(),
+        clamp_eps: float = 1e-5,
+        mag_weight: float = 1.0,
+        log_weight: float = 1.0,
+        pow: float = 2.0,
+        weight: float = 1.0,
+        match_stride: bool = False,
+        window_type: str = None,
+    ):
+        super().__init__()
+        self.stft_params = [
+            STFTParams(
+                window_length=w,
+                hop_length=w // 4,
+                match_stride=match_stride,
+                window_type=window_type,
+            )
+            for w in window_lengths
+        ]
+        self.loss_fn = loss_fn
+        self.log_weight = log_weight
+        self.mag_weight = mag_weight
+        self.clamp_eps = clamp_eps
+        self.weight = weight
+        self.pow = pow
+
+    def forward(self, x: AudioSignal, y: AudioSignal):
+        """Computes multi-scale STFT between an estimate and a reference
+        signal.
+
+        Parameters
+        ----------
+        x : AudioSignal
+            Estimate signal
+        y : AudioSignal
+            Reference signal
+
+        Returns
+        -------
+        torch.Tensor
+            Multi-scale STFT loss.
+        """
+        loss = 0.0
+        for s in self.stft_params:
+            x.stft(s.window_length, s.hop_length, s.window_type)
+            y.stft(s.window_length, s.hop_length, s.window_type)
+            loss += self.log_weight * self.loss_fn(
+                x.magnitude.clamp(self.clamp_eps).pow(self.pow).log10(),
+                y.magnitude.clamp(self.clamp_eps).pow(self.pow).log10(),
+            )
+            loss += self.mag_weight * self.loss_fn(x.magnitude, y.magnitude)
+        return loss
+
+
+class MelSpectrogramLoss(nn.Module):
+    """Compute distance between mel spectrograms. Can be used
+    in a multi-scale way.
+
+    Parameters
+    ----------
+    n_mels : List[int], optional
+        Number of mels per STFT, by default [150, 80]
+    window_lengths : List[int], optional
+        Length of each window of each STFT, by default [2048, 512]
+    loss_fn : typing.Callable, optional
+        How to compare each loss, by default nn.L1Loss()
+    clamp_eps : float, optional
+        Minimum value the magnitude is clamped to before taking the log, by default 1e-5
+    mag_weight : float, optional
+        Weight of raw magnitude portion of loss, by default 1.0
+    log_weight : float, optional
+        Weight of log magnitude portion of loss, by default 1.0
+    pow : float, optional
+        Power to raise magnitude to before taking log, by default 2.0
+    weight : float, optional
+        Weight of this loss, by default 1.0
+    match_stride : bool, optional
+        Whether to match the stride of convolutional layers, by default False
+    """
+
+    def __init__(
+        self,
+        n_mels: List[int] = [150, 80],
+        window_lengths: List[int] = [2048, 512],
+        loss_fn: typing.Callable = nn.L1Loss(),
+        clamp_eps: float = 1e-5,
+        mag_weight: float = 1.0,
+        log_weight: float = 1.0,
+        pow: float = 2.0,
+        weight: float = 1.0,
+        match_stride: bool = False,
+        mel_fmin: List[float] = [0.0, 0.0],
+        mel_fmax: List[float] = [None, None],
+        window_type: str = None,
+    ):
+        super().__init__()
+        self.stft_params = [
+            STFTParams(
+                window_length=w,
+                hop_length=w // 4,
+                match_stride=match_stride,
+                window_type=window_type,
+            )
+            for w in window_lengths
+        ]
+        self.n_mels = n_mels
+        self.loss_fn = loss_fn
+        self.clamp_eps = clamp_eps
+        self.log_weight = log_weight
+        self.mag_weight = mag_weight
+        self.weight = weight
+        self.mel_fmin = mel_fmin
+        self.mel_fmax = mel_fmax
+        self.pow = pow
+
+    def forward(self, x: AudioSignal, y: AudioSignal):
+        """Computes mel loss between an estimate and a reference
+        signal.
+
+        Parameters
+        ----------
+        x : AudioSignal
+            Estimate signal
+        y : AudioSignal
+            Reference signal
+
+        Returns
+        -------
+        torch.Tensor
+            Mel loss.
+        """
+        loss = 0.0
+        for n_mels, fmin, fmax, s in zip(
+            self.n_mels, self.mel_fmin, self.mel_fmax, self.stft_params
+        ):
+            kwargs = {
+                "window_length": s.window_length,
+                "hop_length": s.hop_length,
+                "window_type": s.window_type,
+            }
+            x_mels = x.mel_spectrogram(n_mels, mel_fmin=fmin, mel_fmax=fmax, **kwargs)
+            y_mels = y.mel_spectrogram(n_mels, mel_fmin=fmin, mel_fmax=fmax, **kwargs)
+
+            loss += self.log_weight * self.loss_fn(
+                x_mels.clamp(self.clamp_eps).pow(self.pow).log10(),
+                y_mels.clamp(self.clamp_eps).pow(self.pow).log10(),
+            )
+            loss += self.mag_weight * self.loss_fn(x_mels, y_mels)
+        return loss
+
+
+class PhaseLoss(nn.Module):
+    """Difference between phase spectrograms.
+
+    Parameters
+    ----------
+    window_length : int, optional
+        Length of STFT window, by default 2048
+    hop_length : int, optional
+        Hop length of STFT window, by default 512
+    weight : float, optional
+        Weight of loss, by default 1.0
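+
+    Example
+    -------
+    Rough sketch of typical usage; ``x`` and ``y`` stand for an estimate and
+    a reference AudioSignal of matching length:
+
+    >>> loss_fn = PhaseLoss(window_length=2048, hop_length=512)
+    >>> loss = loss_fn(x, y)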
+    """
+
+    def __init__(
+        self, window_length: int = 2048, hop_length: int = 512, weight: float = 1.0
+    ):
+        super().__init__()
+
+        self.weight = weight
+        self.stft_params = STFTParams(window_length, hop_length)
+
+    def forward(self, x: AudioSignal, y: AudioSignal):
+        """Computes phase loss between an estimate and a reference
+        signal.
+
+        Parameters
+        ----------
+        x : AudioSignal
+            Estimate signal
+        y : AudioSignal
+            Reference signal
+
+        Returns
+        -------
+        torch.Tensor
+            Phase loss.
+        """
+        s = self.stft_params
+        x.stft(s.window_length, s.hop_length, s.window_type)
+        y.stft(s.window_length, s.hop_length, s.window_type)
+
+        # Take circular difference
+        diff = x.phase - y.phase
+        diff[diff < -np.pi] += 2 * np.pi
+        diff[diff > np.pi] -= 2 * np.pi
+
+        # Scale true magnitude to weights in [0, 1]
+        x_min, x_max = x.magnitude.min(), x.magnitude.max()
+        weights = (x.magnitude - x_min) / (x_max - x_min)
+
+        # Take weighted mean of all phase errors
+        loss = ((weights * diff) ** 2).mean()
+        return loss
diff --git a/audiotools/ml/__init__.py b/audiotools/ml/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a9ca69977bad57e1a92b7551d601d9224ee854ab
--- /dev/null
+++ b/audiotools/ml/__init__.py
@@ -0,0 +1,5 @@
+from . import decorators
+from . import layers
+from .accelerator import Accelerator
+from .experiment import Experiment
+from .layers import BaseModel
diff --git a/audiotools/ml/accelerator.py b/audiotools/ml/accelerator.py
new file mode 100644
index 0000000000000000000000000000000000000000..37c6e8d954f112b8b0aff257894e62add8874e30
--- /dev/null
+++ b/audiotools/ml/accelerator.py
@@ -0,0 +1,184 @@
+import os
+import typing
+
+import torch
+import torch.distributed as dist
+from torch.nn.parallel import DataParallel
+from torch.nn.parallel import DistributedDataParallel
+
+from ..data.datasets import ResumableDistributedSampler as DistributedSampler
+from ..data.datasets import ResumableSequentialSampler as SequentialSampler
+
+
+class Accelerator:  # pragma: no cover
+    """This class is used to prepare models and dataloaders for
+    usage with DDP or DP. Use the functions prepare_model, prepare_dataloader to
+    prepare the respective objects. In the case of models, they are moved to
+    the appropriate GPU and SyncBatchNorm is applied to them. In the case of
+    dataloaders, a sampler is created and the dataloader is initialized with
+    that sampler.
+
+    If the world size is 1, prepare_model and prepare_dataloader are
+    no-ops. If the environment variable ``LOCAL_RANK`` is not set, then the
+    script was launched without ``torchrun``, and ``DataParallel``
+    will be used instead of ``DistributedDataParallel`` (not recommended) if
+    the world size (number of GPUs) is greater than 1.
+
+    Parameters
+    ----------
+    amp : bool, optional
+        Whether or not to enable automatic mixed precision, by default False
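+
+    Example
+    -------
+    Rough sketch of typical usage; ``model``, ``dataset``, and the batch size
+    are placeholders:
+
+    >>> accel = Accelerator(amp=False)
+    >>> model = accel.prepare_model(model)
+    >>> dataloader = accel.prepare_dataloader(dataset, batch_size=16)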
+    """
+
+    def __init__(self, amp: bool = False):
+        local_rank = os.getenv("LOCAL_RANK", None)
+        self.world_size = torch.cuda.device_count()
+
+        self.use_ddp = self.world_size > 1 and local_rank is not None
+        self.use_dp = self.world_size > 1 and local_rank is None
+        self.device = "cpu" if self.world_size == 0 else "cuda"
+
+        if self.use_ddp:
+            local_rank = int(local_rank)
+            dist.init_process_group(
+                "nccl",
+                init_method="env://",
+                world_size=self.world_size,
+                rank=local_rank,
+            )
+
+        self.local_rank = 0 if local_rank is None else local_rank
+        self.amp = amp
+
+        class DummyScaler:
+            def __init__(self):
+                pass
+
+            def step(self, optimizer):
+                optimizer.step()
+
+            def scale(self, loss):
+                return loss
+
+            def unscale_(self, optimizer):
+                return optimizer
+
+            def update(self):
+                pass
+
+        self.scaler = torch.cuda.amp.GradScaler() if amp else DummyScaler()
+        self.device_ctx = (
+            torch.cuda.device(self.local_rank) if torch.cuda.is_available() else None
+        )
+
+    def __enter__(self):
+        if self.device_ctx is not None:
+            self.device_ctx.__enter__()
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        if self.device_ctx is not None:
+            self.device_ctx.__exit__(exc_type, exc_value, traceback)
+
+    def prepare_model(self, model: torch.nn.Module, **kwargs):
+        """Prepares model for DDP or DP. The model is moved to
+        the device of the correct rank.
+
+        Parameters
+        ----------
+        model : torch.nn.Module
+            Model that is converted for DDP or DP.
+
+        Returns
+        -------
+        torch.nn.Module
+            Wrapped model, or original model if DDP and DP are turned off.
+        """
+        model = model.to(self.device)
+        if self.use_ddp:
+            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
+            model = DistributedDataParallel(
+                model, device_ids=[self.local_rank], **kwargs
+            )
+        elif self.use_dp:
+            model = DataParallel(model, **kwargs)
+        return model
+
+    # Automatic mixed-precision utilities
+    def autocast(self, *args, **kwargs):
+        """Context manager for autocasting. Arguments
+        go to ``torch.cuda.amp.autocast``.
+        """
+        return torch.cuda.amp.autocast(self.amp, *args, **kwargs)
+
+    def backward(self, loss: torch.Tensor):
+        """Backwards pass, after scaling the loss if ``amp`` is
+        enabled.
+
+        Parameters
+        ----------
+        loss : torch.Tensor
+            Loss value.
+        """
+        self.scaler.scale(loss).backward()
+
+    def step(self, optimizer: torch.optim.Optimizer):
+        """Steps the optimizer, using a ``scaler`` if ``amp`` is
+        enabled.
+
+        Parameters
+        ----------
+        optimizer : torch.optim.Optimizer
+            Optimizer to step forward.
+        """
+        self.scaler.step(optimizer)
+
+    def update(self):
+        """Updates the scale factor."""
+        self.scaler.update()
+
+    def prepare_dataloader(
+        self, dataset: typing.Iterable, start_idx: int = None, **kwargs
+    ):
+        """Wraps a dataset with a DataLoader, using the correct sampler if DDP is
+        enabled.
+
+        Parameters
+        ----------
+        dataset : typing.Iterable
+            Dataset to build Dataloader around.
+        start_idx : int, optional
+            Start index of sampler, useful if resuming from some epoch,
+            by default None
+
+        Returns
+        -------
+        torch.utils.data.DataLoader
+            DataLoader wrapped around ``dataset``, using the appropriate sampler.
+        """
+
+        if self.use_ddp:
+            sampler = DistributedSampler(
+                dataset,
+                start_idx,
+                num_replicas=self.world_size,
+                rank=self.local_rank,
+            )
+            if "num_workers" in kwargs:
+                kwargs["num_workers"] = max(kwargs["num_workers"] // self.world_size, 1)
+            kwargs["batch_size"] = max(kwargs["batch_size"] // self.world_size, 1)
+        else:
+            sampler = SequentialSampler(dataset, start_idx)
+
+        dataloader = torch.utils.data.DataLoader(dataset, sampler=sampler, **kwargs)
+        return dataloader
+
+    @staticmethod
+    def unwrap(model):
+        """Unwraps the model if it was wrapped in DDP or DP, otherwise
+        just returns the model. Use this to unwrap the model returned by
+        :py:func:`audiotools.ml.accelerator.Accelerator.prepare_model`.
+        """
+        if hasattr(model, "module"):
+            return model.module
+        return model
diff --git a/audiotools/ml/decorators.py b/audiotools/ml/decorators.py
new file mode 100644
index 0000000000000000000000000000000000000000..834ec10270ff9e8e84a5fa99e13a686516a4af41
--- /dev/null
+++ b/audiotools/ml/decorators.py
@@ -0,0 +1,440 @@
+import math
+import os
+import time
+from collections import defaultdict
+from functools import wraps
+
+import torch
+import torch.distributed as dist
+from rich import box
+from rich.console import Console
+from rich.console import Group
+from rich.live import Live
+from rich.markdown import Markdown
+from rich.padding import Padding
+from rich.panel import Panel
+from rich.progress import BarColumn
+from rich.progress import Progress
+from rich.progress import SpinnerColumn
+from rich.progress import TimeElapsedColumn
+from rich.progress import TimeRemainingColumn
+from rich.rule import Rule
+from rich.table import Table
+from torch.utils.tensorboard import SummaryWriter
+
+
+# This is here so that the history can be pickled.
+def default_list():
+    return []
+
+
+class Mean:
+    """Keeps track of the running mean, along with the latest
+    value.
+    """
+
+    def __init__(self):
+        self.reset()
+
+    def __call__(self):
+        mean = self.total / max(self.count, 1)
+        return mean
+
+    def reset(self):
+        self.count = 0
+        self.total = 0
+
+    def update(self, val):
+        if math.isfinite(val):
+            self.count += 1
+            self.total += val
+
+
+def when(condition):
+    """Runs a function only when the condition is met. The condition is
+    a function that is run.
+
+    Parameters
+    ----------
+    condition : Callable
+        Function to run to check whether or not to run the decorated
+        function.
+
+    Example
+    -------
+    Checkpoint only runs every 100 iterations, and only if the
+    local rank is 0.
+
+    >>> i = 0
+    >>> rank = 0
+    >>>
+    >>> @when(lambda: i % 100 == 0 and rank == 0)
+    >>> def checkpoint():
+    >>>     print("Saving to /runs/exp1")
+    >>>
+    >>> for i in range(1000):
+    >>>     checkpoint()
+
+    """
+
+    def decorator(fn):
+        @wraps(fn)
+        def decorated(*args, **kwargs):
+            if condition():
+                return fn(*args, **kwargs)
+
+        return decorated
+
+    return decorator
+
+
+def timer(prefix: str = "time"):
+    """Adds execution time to the output dictionary of the decorated
+    function. The function decorated by this must output a dictionary.
+    The key added will follow the form "[prefix]/[name_of_function]"
+
+    Parameters
+    ----------
+    prefix : str, optional
+        The key added will follow the form "[prefix]/[name_of_function]",
+        by default "time".
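+
+    Example
+    -------
+    Time a step function and report the elapsed time alongside its outputs
+    (``train_step`` is a placeholder; it must return a dictionary):
+
+    >>> @timer()
+    >>> def train_step():
+    >>>     return {"loss": 0.0}
+    >>>
+    >>> out = train_step()  # out now also contains "time/train_step"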
+    """
+
+    def decorator(fn):
+        @wraps(fn)
+        def decorated(*args, **kwargs):
+            s = time.perf_counter()
+            output = fn(*args, **kwargs)
+            assert isinstance(output, dict)
+            e = time.perf_counter()
+            output[f"{prefix}/{fn.__name__}"] = e - s
+            return output
+
+        return decorated
+
+    return decorator
+
+
+class Tracker:
+    """
+    A tracker class that helps to monitor the progress of training and logging the metrics.
+
+    Attributes
+    ----------
+    metrics : dict
+        A dictionary containing the metrics for each label.
+    history : dict
+        A dictionary containing the history of metrics for each label.
+    writer : SummaryWriter
+        A SummaryWriter object for logging the metrics.
+    rank : int
+        The rank of the current process.
+    step : int
+        The current step of the training.
+    tasks : dict
+        A dictionary containing the progress bars and tables for each label.
+    pbar : Progress
+        A progress bar object for displaying the progress.
+    consoles : list
+        A list of console objects for logging.
+    live : Live
+        A Live object for updating the display live.
+
+    Methods
+    -------
+    print(msg: str)
+        Prints the given message to all consoles.
+    update(label: str, fn_name: str)
+        Updates the progress bar and table for the given label.
+    done(label: str, title: str)
+        Resets the progress bar and table for the given label and prints the final result.
+    track(label: str, length: int, completed: int = 0,
+          op: dist.ReduceOp = dist.ReduceOp.AVG,
+          ddp_active: bool = "LOCAL_RANK" in os.environ)
+        A decorator for tracking the progress and metrics of a function.
+    log(label: str, value_type: str = "value", history: bool = True)
+        A decorator for logging the metrics of a function.
+    is_best(label: str, key: str) -> bool
+        Checks if the latest value of the given key in the label is the best so far.
+    state_dict() -> dict
+        Returns a dictionary containing the state of the tracker.
+    load_state_dict(state_dict: dict) -> Tracker
+        Loads the state of the tracker from the given state dictionary.
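+
+    Example
+    -------
+    Rough sketch based on the method summaries above; the label, length, and
+    the decorated loop are placeholders:
+
+    >>> tracker = Tracker(writer=None, rank=0)
+    >>>
+    >>> @tracker.track("train", length=1000)
+    >>> def train_loop():
+    >>>     return {"loss": 0.0}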
+    """
+
+    def __init__(
+        self,
+        writer: SummaryWriter = None,
+        log_file: str = None,
+        rank: int = 0,
+        console_width: int = 100,
+        step: int = 0,
+    ):
+        """
+        Initializes the Tracker object.
+
+        Parameters
+        ----------
+        writer : SummaryWriter, optional
+            A SummaryWriter object for logging the metrics, by default None.
+        log_file : str, optional
+            The path to the log file, by default None.
+        rank : int, optional
+            The rank of the current process, by default 0.
+        console_width : int, optional
+            The width of the console, by default 100.
+        step : int, optional
+            The current step of the training, by default 0.
+        """
+        self.metrics = {}
+        self.history = {}
+        self.writer = writer
+        self.rank = rank
+        self.step = step
+
+        # Create progress bars etc.
+        self.tasks = {}
+        self.pbar = Progress(
+            SpinnerColumn(),
+            "[progress.description]{task.description}",
+            "{task.completed}/{task.total}",
+            BarColumn(),
+            TimeElapsedColumn(),
+            "/",
+            TimeRemainingColumn(),
+        )
+        self.consoles = [Console(width=console_width)]
+        self.live = Live(console=self.consoles[0], refresh_per_second=10)
+        if log_file is not None:
+            self.consoles.append(Console(width=console_width, file=open(log_file, "a")))
+
+    def print(self, msg):
+        """
+        Prints the given message to all consoles (only on rank 0).
+
+        Parameters
+        ----------
+        msg : str
+            The message to be printed.
+        """
+        if self.rank == 0:
+            for c in self.consoles:
+                c.log(msg)
+
+    def update(self, label, fn_name):
+        """
+        Updates the progress bar and table for the given label.
+
+        Parameters
+        ----------
+        label : str
+            The label of the progress bar and table to be updated.
+        fn_name : str
+            The name of the function associated with the label.
+        """
+        if self.rank == 0:
+            self.pbar.advance(self.tasks[label]["pbar"])
+
+            # Create table
+            table = Table(title=label, expand=True, box=box.MINIMAL)
+            table.add_column("key", style="cyan")
+            table.add_column("value", style="bright_blue")
+            table.add_column("mean", style="bright_green")
+
+            keys = self.metrics[label]["value"].keys()
+            for k in keys:
+                value = self.metrics[label]["value"][k]
+                mean = self.metrics[label]["mean"][k]()
+                table.add_row(k, f"{value:10.6f}", f"{mean:10.6f}")
+
+            self.tasks[label]["table"] = table
+            tables = [t["table"] for t in self.tasks.values()]
+            group = Group(*tables, self.pbar)
+            self.live.update(
+                Group(
+                    Padding("", (0, 0)),
+                    Rule(f"[italic]{fn_name}()", style="white"),
+                    Padding("", (0, 0)),
+                    Panel.fit(
+                        group, padding=(0, 5), title="[b]Progress", border_style="blue"
+                    ),
+                )
+            )
+
+    def done(self, label: str, title: str):
+        """
+        Resets the progress bar and table for the given label and prints the final result.
+
+        Parameters
+        ----------
+        label : str
+            The label of the progress bar and table to be reset.
+        title : str
+            The title to be displayed when printing the final result.
+        """
+        for lbl in self.metrics:
+            for v in self.metrics[lbl]["mean"].values():
+                v.reset()
+
+        if self.rank == 0:
+            self.pbar.reset(self.tasks[label]["pbar"])
+            tables = [t["table"] for t in self.tasks.values()]
+            group = Group(Markdown(f"# {title}"), *tables, self.pbar)
+            self.print(group)
+
+    def track(
+        self,
+        label: str,
+        length: int,
+        completed: int = 0,
+        op: dist.ReduceOp = dist.ReduceOp.AVG,
+        ddp_active: bool = "LOCAL_RANK" in os.environ,
+    ):
+        """
+        A decorator for tracking the progress and metrics of a function.
+
+        Parameters
+        ----------
+        label : str
+            The label to be associated with the progress and metrics.
+        length : int
+            The total number of iterations to be completed.
+        completed : int, optional
+            The number of iterations already completed, by default 0.
+        op : dist.ReduceOp, optional
+            The reduce operation to be used, by default dist.ReduceOp.AVG.
+        ddp_active : bool, optional
+            Whether the DistributedDataParallel is active, by default "LOCAL_RANK" in os.environ.
+        """
+        self.tasks[label] = {
+            "pbar": self.pbar.add_task(
+                f"[white]Iteration ({label})", total=length, completed=completed
+            ),
+            "table": Table(),
+        }
+        self.metrics[label] = {
+            "value": defaultdict(),
+            "mean": defaultdict(lambda: Mean()),
+        }
+
+        def decorator(fn):
+            @wraps(fn)
+            def decorated(*args, **kwargs):
+                output = fn(*args, **kwargs)
+                if not isinstance(output, dict):
+                    self.update(label, fn.__name__)
+                    return output
+                # Collect across all DDP processes
+                scalar_keys = []
+                for k, v in output.items():
+                    if isinstance(v, (int, float)):
+                        v = torch.tensor([v])
+                    if not torch.is_tensor(v):
+                        continue
+                    if ddp_active and v.is_cuda:  # pragma: no cover
+                        dist.all_reduce(v, op=op)
+                    output[k] = v.detach()
+                    if torch.numel(v) == 1:
+                        scalar_keys.append(k)
+                        output[k] = v.item()
+
+                # Save the outputs to tracker
+                for k, v in output.items():
+                    if k not in scalar_keys:
+                        continue
+                    self.metrics[label]["value"][k] = v
+                    # Update the running mean
+                    self.metrics[label]["mean"][k].update(v)
+
+                self.update(label, fn.__name__)
+                return output
+
+            return decorated
+
+        return decorator
+
+    def log(self, label: str, value_type: str = "value", history: bool = True):
+        """
+        A decorator for logging the metrics of a function.
+
+        Parameters
+        ----------
+        label : str
+            The label to be associated with the logging.
+        value_type : str, optional
+            The type of value to be logged, by default "value".
+        history : bool, optional
+            Whether to save the history of the metrics, by default True.
+        """
+        assert value_type in ["mean", "value"]
+        if history:
+            if label not in self.history:
+                self.history[label] = defaultdict(default_list)
+
+        def decorator(fn):
+            @wraps(fn)
+            def decorated(*args, **kwargs):
+                output = fn(*args, **kwargs)
+                if self.rank == 0:
+                    nonlocal value_type, label
+                    metrics = self.metrics[label][value_type]
+                    for k, v in metrics.items():
+                        v = v() if isinstance(v, Mean) else v
+                        if self.writer is not None:
+                            self.writer.add_scalar(f"{k}/{label}", v, self.step)
+                        if label in self.history:
+                            self.history[label][k].append(v)
+
+                    if label in self.history:
+                        self.history[label]["step"].append(self.step)
+
+                return output
+
+            return decorated
+
+        return decorator
+
+    def is_best(self, label, key):
+        """
+        Checks if the latest value of the given key in the label is the best so far.
+
+        Parameters
+        ----------
+        label : str
+            The label of the metrics to be checked.
+        key : str
+            The key of the metric to be checked.
+
+        Returns
+        -------
+        bool
+            True if the latest value is the lowest recorded so far
+            (lower is treated as better), otherwise False.
+        """
+        return self.history[label][key][-1] == min(self.history[label][key])
+
+    def state_dict(self):
+        """
+        Returns a dictionary containing the state of the tracker.
+
+        Returns
+        -------
+        dict
+            A dictionary containing the history and step of the tracker.
+        """
+        return {"history": self.history, "step": self.step}
+
+    def load_state_dict(self, state_dict):
+        """
+        Loads the state of the tracker from the given state dictionary.
+
+        Parameters
+        ----------
+        state_dict : dict
+            A dictionary containing the history and step of the tracker.
+
+        Returns
+        -------
+        Tracker
+            The tracker object with the loaded state.
+        """
+        self.history = state_dict["history"]
+        self.step = state_dict["step"]
+        return self
diff --git a/audiotools/ml/experiment.py b/audiotools/ml/experiment.py
new file mode 100644
index 0000000000000000000000000000000000000000..62833d0f8f80dcdf496a1a5d2785ef666e0a15b6
--- /dev/null
+++ b/audiotools/ml/experiment.py
@@ -0,0 +1,90 @@
+"""
+Useful class for experiment tracking, and for ensuring code is
+saved alongside the experiment outputs.
+"""  # fmt: skip
+import datetime
+import os
+import shlex
+import shutil
+import subprocess
+import typing
+from pathlib import Path
+
+import randomname
+
+
+class Experiment:
+    """This class contains utilities for managing experiments.
+    It is a context manager, that when you enter it, changes
+    your directory to a specified experiment folder (which
+    optionally can have an automatically generated experiment
+    name, or a specified one), and changes the CUDA device used
+    to the specified device (or devices).
+
+    Parameters
+    ----------
+    exp_directory : str
+        Folder where all experiments are saved, by default "runs/".
+    exp_name : str, optional
+        Name of the experiment, by default a name generated from the
+        current date and a random adjective-noun pair.
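+
+    Examples
+    --------
+    A minimal sketch (assumes the code is launched from inside a git
+    repository, since tracked files are listed at construction time):
+
+    >>> with Experiment("runs/", exp_name=None) as exp:
+    >>>     exp.snapshot()
+    >>>     # training code here runs with the experiment folder as cwd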
+    """
+
+    def __init__(
+        self,
+        exp_directory: str = "runs/",
+        exp_name: str = None,
+    ):
+        if exp_name is None:
+            exp_name = self.generate_exp_name()
+        exp_dir = Path(exp_directory) / exp_name
+        exp_dir.mkdir(parents=True, exist_ok=True)
+
+        self.exp_dir = exp_dir
+        self.exp_name = exp_name
+        self.git_tracked_files = (
+            subprocess.check_output(
+                shlex.split("git ls-tree --full-tree --name-only -r HEAD")
+            )
+            .decode("utf-8")
+            .splitlines()
+        )
+        self.parent_directory = Path(".").absolute()
+
+    def __enter__(self):
+        self.prev_dir = os.getcwd()
+        os.chdir(self.exp_dir)
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        os.chdir(self.prev_dir)
+
+    @staticmethod
+    def generate_exp_name():
+        """Generates a random experiment name based on the date
+        and a randomly generated adjective-noun tuple.
+
+        Returns
+        -------
+        str
+            Randomly generated experiment name.
+        """
+        date = datetime.datetime.now().strftime("%y%m%d")
+        name = f"{date}-{randomname.get_name()}"
+        return name
+
+    def snapshot(self, filter_fn: typing.Callable = lambda f: True):
+        """Captures a full snapshot of all the files tracked by git at the time
+        the experiment is run. It also captures the diff against the committed
+        code as a separate file.
+
+        Parameters
+        ----------
+        filter_fn : typing.Callable, optional
+            Function that can be used to exclude some files
+            from the snapshot, by default accepts all files
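+
+        Examples
+        --------
+        A sketch of excluding a directory from the snapshot (the
+        ``tests/`` path is only an example):
+
+        >>> exp.snapshot(lambda f: not f.startswith("tests/"))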
+        """
+        for f in self.git_tracked_files:
+            if filter_fn(f):
+                Path(f).parent.mkdir(parents=True, exist_ok=True)
+                shutil.copyfile(self.parent_directory / f, f)
diff --git a/audiotools/ml/layers/__init__.py b/audiotools/ml/layers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..92a016cab2ddf06bf5dadfae241b7e5d9def4878
--- /dev/null
+++ b/audiotools/ml/layers/__init__.py
@@ -0,0 +1,2 @@
+from .base import BaseModel
+from .spectral_gate import SpectralGate
diff --git a/audiotools/ml/layers/base.py b/audiotools/ml/layers/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..b82c96cdd7336ca6b8ed6fc7f0192d69a8e998dd
--- /dev/null
+++ b/audiotools/ml/layers/base.py
@@ -0,0 +1,328 @@
+import inspect
+import shutil
+import tempfile
+import typing
+from pathlib import Path
+
+import torch
+from torch import nn
+
+
+class BaseModel(nn.Module):
+    """This is a class that adds useful save/load functionality to a
+    ``torch.nn.Module`` object. ``BaseModel`` objects can be saved
+    as ``torch.package`` easily, making them super easy to port between
+    machines without requiring a ton of dependencies. Files can also be
+    saved as just weights, in the standard way.
+
+    >>> class Model(ml.BaseModel):
+    >>>     def __init__(self, arg1: float = 1.0):
+    >>>         super().__init__()
+    >>>         self.arg1 = arg1
+    >>>         self.linear = nn.Linear(1, 1)
+    >>>
+    >>>     def forward(self, x):
+    >>>         return self.linear(x)
+    >>>
+    >>> model1 = Model()
+    >>>
+    >>> with tempfile.NamedTemporaryFile(suffix=".pth") as f:
+    >>>     model1.save(
+    >>>         f.name,
+    >>>     )
+    >>>     model2 = Model.load(f.name)
+    >>>     out2 = seed_and_run(model2, x)
+    >>>     assert torch.allclose(out1, out2)
+    >>>
+    >>>     model1.save(f.name, package=True)
+    >>>     model2 = Model.load(f.name)
+    >>>     model2.save(f.name, package=False)
+    >>>     model3 = Model.load(f.name)
+    >>>     out3 = seed_and_run(model3, x)
+    >>>
+    >>> with tempfile.TemporaryDirectory() as d:
+    >>>     model1.save_to_folder(d, {"data": 1.0})
+    >>>     Model.load_from_folder(d)
+
+    """
+
+    EXTERN = [
+        "audiotools.**",
+        "tqdm",
+        "__main__",
+        "numpy.**",
+        "julius.**",
+        "torchaudio.**",
+        "scipy.**",
+        "einops",
+    ]
+    """Names of libraries that are external to the torch.package saving mechanism.
+    Source code from these libraries will not be packaged into the model. This can
+    be edited by the user of this class by editing ``model.EXTERN``."""
+    INTERN = []
+    """Names of libraries that are internal to the torch.package saving mechanism.
+    Source code from these libraries will be saved alongside the model."""
+
+    def save(
+        self,
+        path: str,
+        metadata: dict = None,
+        package: bool = True,
+        intern: list = [],
+        extern: list = [],
+        mock: list = [],
+    ):
+        """Saves the model, either as a torch package, or just as
+        weights, alongside some specified metadata.
+
+        Parameters
+        ----------
+        path : str
+            Path to save model to.
+        metadata : dict, optional
+            Any metadata to save alongside the model,
+            by default None
+        package : bool, optional
+            Whether to use ``torch.package`` to save the model in
+            a format that is portable, by default True
+        intern : list, optional
+            List of additional libraries that are internal
+            to the model, used with torch.package, by default []
+        extern : list, optional
+            List of additional libraries that are external to
+            the model, used with torch.package, by default []
+        mock : list, optional
+            List of libraries to mock, used with torch.package,
+            by default []
+
+        Returns
+        -------
+        str
+            Path to saved model.
+        """
+        sig = inspect.signature(self.__class__)
+        args = {}
+
+        for key, val in sig.parameters.items():
+            arg_val = val.default
+            if arg_val is not inspect.Parameter.empty:
+                args[key] = arg_val
+
+        # Look up attributes in self, and if any of them are in args,
+        # overwrite them in args.
+        for attribute in dir(self):
+            if attribute in args:
+                args[attribute] = getattr(self, attribute)
+
+        metadata = {} if metadata is None else metadata
+        metadata["kwargs"] = args
+        if not hasattr(self, "metadata"):
+            self.metadata = {}
+        self.metadata.update(metadata)
+
+        if not package:
+            state_dict = {"state_dict": self.state_dict(), "metadata": metadata}
+            torch.save(state_dict, path)
+        else:
+            self._save_package(path, intern=intern, extern=extern, mock=mock)
+
+        return path
+
+    @property
+    def device(self):
+        """Gets the device the model is on by looking at the device of
+        the first parameter. May not be valid if model is split across
+        multiple devices.
+        """
+        return list(self.parameters())[0].device
+
+    @classmethod
+    def load(
+        cls,
+        location: str,
+        *args,
+        package_name: str = None,
+        strict: bool = False,
+        **kwargs,
+    ):
+        """Load model from a path. Tries first to load as a package, and if
+        that fails, tries to load as weights. The arguments to the class are
+        specified inside the model weights file.
+
+        Parameters
+        ----------
+        location : str
+            Path to file.
+        package_name : str, optional
+            Name of package, by default ``cls.__name__``.
+        strict : bool, optional
+            Whether to strictly enforce that the keys in the saved state
+            dict match the model's keys, by default False
+        kwargs : dict
+            Additional keyword arguments to the model instantiation, if
+            not loading from package.
+
+        Returns
+        -------
+        BaseModel
+            A model that inherits from BaseModel.
+        """
+        try:
+            model = cls._load_package(location, package_name=package_name)
+        except:
+            model_dict = torch.load(location, "cpu")
+            metadata = model_dict["metadata"]
+            metadata["kwargs"].update(kwargs)
+
+            sig = inspect.signature(cls)
+            class_keys = list(sig.parameters.keys())
+            for k in list(metadata["kwargs"].keys()):
+                if k not in class_keys:
+                    metadata["kwargs"].pop(k)
+
+            model = cls(*args, **metadata["kwargs"])
+            model.load_state_dict(model_dict["state_dict"], strict=strict)
+            model.metadata = metadata
+
+        return model
+
+    def _save_package(self, path, intern=[], extern=[], mock=[], **kwargs):
+        package_name = type(self).__name__
+        resource_name = f"{type(self).__name__}.pth"
+
+        # Below is for loading and re-saving a package.
+        if hasattr(self, "importer"):
+            kwargs["importer"] = (self.importer, torch.package.sys_importer)
+            del self.importer
+
+        # Why do we use a tempfile, you ask?
+        # It's so we can load a packaged model and then re-save
+        # it to the same location. torch.package throws an
+        # error if it's loading and writing to the same
+        # file (this is undocumented).
+        with tempfile.NamedTemporaryFile(suffix=".pth") as f:
+            with torch.package.PackageExporter(f.name, **kwargs) as exp:
+                exp.intern(self.INTERN + intern)
+                exp.mock(mock)
+                exp.extern(self.EXTERN + extern)
+                exp.save_pickle(package_name, resource_name, self)
+
+                if hasattr(self, "metadata"):
+                    exp.save_pickle(
+                        package_name, f"{package_name}.metadata", self.metadata
+                    )
+
+            shutil.copyfile(f.name, path)
+
+        # Must reset the importer back to `self` if it existed
+        # so that you can save the model again!
+        if "importer" in kwargs:
+            self.importer = kwargs["importer"][0]
+        return path
+
+    @classmethod
+    def _load_package(cls, path, package_name=None):
+        package_name = cls.__name__ if package_name is None else package_name
+        resource_name = f"{package_name}.pth"
+
+        imp = torch.package.PackageImporter(path)
+        model = imp.load_pickle(package_name, resource_name, "cpu")
+        try:
+            model.metadata = imp.load_pickle(package_name, f"{package_name}.metadata")
+        except:  # pragma: no cover
+            pass
+        model.importer = imp
+
+        return model
+
+    def save_to_folder(
+        self,
+        folder: typing.Union[str, Path],
+        extra_data: dict = None,
+        package: bool = True,
+    ):
+        """Dumps a model into a folder, as both a package
+        and as weights, as well as anything specified in
+        ``extra_data``. ``extra_data`` is a dictionary of other
+        pickleable files, with the keys being the paths
+        to save them in. The model is saved under a subfolder
+        specified by the name of the class (e.g. ``folder/generator/[package, weights].pth``
+        if the model name was ``Generator``).
+
+        >>> with tempfile.TemporaryDirectory() as d:
+        >>>     extra_data = {
+        >>>         "optimizer.pth": optimizer.state_dict()
+        >>>     }
+        >>>     model.save_to_folder(d, extra_data)
+        >>>     Model.load_from_folder(d)
+
+        Parameters
+        ----------
+        folder : typing.Union[str, Path]
+            Folder to save the model and any extra data into.
+        extra_data : dict, optional
+            Dictionary of additional pickleable objects to save, with keys
+            being the file names to save them under, by default None
+
+        Returns
+        -------
+        str
+            Path to folder
+        """
+        extra_data = {} if extra_data is None else extra_data
+        model_name = type(self).__name__.lower()
+        target_base = Path(f"{folder}/{model_name}/")
+        target_base.mkdir(exist_ok=True, parents=True)
+
+        if package:
+            package_path = target_base / f"package.pth"
+            self.save(package_path)
+
+        weights_path = target_base / f"weights.pth"
+        self.save(weights_path, package=False)
+
+        for path, obj in extra_data.items():
+            torch.save(obj, target_base / path)
+
+        return target_base
+
+    @classmethod
+    def load_from_folder(
+        cls,
+        folder: typing.Union[str, Path],
+        package: bool = True,
+        strict: bool = False,
+        **kwargs,
+    ):
+        """Loads the model from a folder generated by
+        :py:func:`audiotools.ml.layers.base.BaseModel.save_to_folder`.
+        Like that function, this one looks for a subfolder that has
+        the name of the class (e.g. ``folder/generator/[package, weights].pth`` if the
+        model name was ``Generator``).
+
+        Parameters
+        ----------
+        folder : typing.Union[str, Path]
+            Folder generated by :py:func:`save_to_folder`.
+        package : bool, optional
+            Whether to use ``torch.package`` to load the model,
+            loading the model from ``package.pth``, by default True.
+        strict : bool, optional
+            Whether to strictly enforce that the keys in the saved state
+            dict match the model's keys, by default False
+
+        Returns
+        -------
+        tuple
+            tuple of model and extra data as saved by
+            :py:func:`audiotools.ml.layers.base.BaseModel.save_to_folder`.
+        """
+        folder = Path(folder) / cls.__name__.lower()
+        model_pth = "package.pth" if package else "weights.pth"
+        model_pth = folder / model_pth
+
+        model = cls.load(model_pth, strict=strict)
+        extra_data = {}
+        excluded = ["package.pth", "weights.pth"]
+        files = [x for x in folder.glob("*") if x.is_file() and x.name not in excluded]
+        for f in files:
+            extra_data[f.name] = torch.load(f, **kwargs)
+
+        return model, extra_data
diff --git a/audiotools/ml/layers/spectral_gate.py b/audiotools/ml/layers/spectral_gate.py
new file mode 100644
index 0000000000000000000000000000000000000000..c4ae8b5eab2e56ce13541695f52a11a454759dae
--- /dev/null
+++ b/audiotools/ml/layers/spectral_gate.py
@@ -0,0 +1,127 @@
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from ...core import AudioSignal
+from ...core import STFTParams
+from ...core import util
+
+
+class SpectralGate(nn.Module):
+    """Spectral gating algorithm for noise reduction,
+    as in Audacity/Ocenaudio. The steps are as follows:
+
+    1.  An FFT is calculated over the noise audio clip
+    2.  Statistics are calculated over the FFT of the noise
+        (in frequency)
+    3.  A threshold is calculated based upon the statistics
+        of the noise (and the desired sensitivity of the algorithm)
+    4.  An FFT is calculated over the signal
+    5.  A mask is determined by comparing the signal FFT to the
+        threshold
+    6.  The mask is smoothed with a filter over frequency and time
+    7.  The mask is applied to the FFT of the signal, which is then inverted
+
+    Implementation inspired by Tim Sainburg's noisereduce:
+
+    https://timsainburg.com/noise-reduction-python.html
+
+    Parameters
+    ----------
+    n_freq : int, optional
+        Number of frequency bins to smooth by, by default 3
+    n_time : int, optional
+        Number of time bins to smooth by, by default 5
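+
+    Examples
+    --------
+    A minimal sketch; the file paths are placeholders:
+
+    >>> gate = SpectralGate(n_freq=3, n_time=5)
+    >>> signal = AudioSignal("noisy_speech.wav")
+    >>> nz = AudioSignal("noise_only.wav")
+    >>> denoised = gate(signal, nz, denoise_amount=0.8)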
+    """
+
+    def __init__(self, n_freq: int = 3, n_time: int = 5):
+        super().__init__()
+
+        smoothing_filter = torch.outer(
+            torch.cat(
+                [
+                    torch.linspace(0, 1, n_freq + 2)[:-1],
+                    torch.linspace(1, 0, n_freq + 2),
+                ]
+            )[..., 1:-1],
+            torch.cat(
+                [
+                    torch.linspace(0, 1, n_time + 2)[:-1],
+                    torch.linspace(1, 0, n_time + 2),
+                ]
+            )[..., 1:-1],
+        )
+        smoothing_filter = smoothing_filter / smoothing_filter.sum()
+        smoothing_filter = smoothing_filter.unsqueeze(0).unsqueeze(0)
+        self.register_buffer("smoothing_filter", smoothing_filter)
+
+    def forward(
+        self,
+        audio_signal: AudioSignal,
+        nz_signal: AudioSignal,
+        denoise_amount: float = 1.0,
+        n_std: float = 3.0,
+        win_length: int = 2048,
+        hop_length: int = 512,
+    ):
+        """Perform noise reduction.
+
+        Parameters
+        ----------
+        audio_signal : AudioSignal
+            Audio signal that noise will be removed from.
+        nz_signal : AudioSignal
+            Noise signal to compute noise statistics from.
+        denoise_amount : float, optional
+            Amount to denoise by, by default 1.0
+        n_std : float, optional
+            Number of standard deviations above which to consider
+            noise, by default 3.0
+        win_length : int, optional
+            Length of window for STFT, by default 2048
+        hop_length : int, optional
+            Hop length for STFT, by default 512
+
+        Returns
+        -------
+        AudioSignal
+            Denoised audio signal.
+        """
+        stft_params = STFTParams(win_length, hop_length, "sqrt_hann")
+
+        audio_signal = audio_signal.clone()
+        audio_signal.stft_data = None
+        audio_signal.stft_params = stft_params
+
+        nz_signal = nz_signal.clone()
+        nz_signal.stft_params = stft_params
+
+        nz_stft_db = 20 * nz_signal.magnitude.clamp(1e-4).log10()
+        nz_freq_mean = nz_stft_db.mean(keepdim=True, dim=-1)
+        nz_freq_std = nz_stft_db.std(keepdim=True, dim=-1)
+
+        nz_thresh = nz_freq_mean + nz_freq_std * n_std
+
+        stft_db = 20 * audio_signal.magnitude.clamp(1e-4).log10()
+        nb, nac, nf, nt = stft_db.shape
+        db_thresh = nz_thresh.expand(nb, nac, -1, nt)
+
+        stft_mask = (stft_db < db_thresh).float()
+        shape = stft_mask.shape
+
+        stft_mask = stft_mask.reshape(nb * nac, 1, nf, nt)
+        pad_tuple = (
+            self.smoothing_filter.shape[-2] // 2,
+            self.smoothing_filter.shape[-1] // 2,
+        )
+        stft_mask = F.conv2d(stft_mask, self.smoothing_filter, padding=pad_tuple)
+        stft_mask = stft_mask.reshape(*shape)
+        stft_mask *= util.ensure_tensor(denoise_amount, ndim=stft_mask.ndim).to(
+            audio_signal.device
+        )
+        stft_mask = 1 - stft_mask
+
+        audio_signal.stft_data *= stft_mask
+        audio_signal.istft()
+
+        return audio_signal
diff --git a/audiotools/post.py b/audiotools/post.py
new file mode 100644
index 0000000000000000000000000000000000000000..6ced2d1e66a4ffda3269685bd45593b01038739f
--- /dev/null
+++ b/audiotools/post.py
@@ -0,0 +1,140 @@
+import tempfile
+import typing
+import zipfile
+from pathlib import Path
+
+import markdown2 as md
+import matplotlib.pyplot as plt
+import torch
+from IPython.display import HTML
+
+
+def audio_table(
+    audio_dict: dict,
+    first_column: str = None,
+    format_fn: typing.Callable = None,
+    **kwargs,
+):  # pragma: no cover
+    """Embeds an audio table into HTML, or as the output cell
+    in a notebook.
+
+    Parameters
+    ----------
+    audio_dict : dict
+        Dictionary of data to embed.
+    first_column : str, optional
+        The label for the first column of the table, by default None
+    format_fn : typing.Callable, optional
+        How to format the data, by default None
+
+    Returns
+    -------
+    str
+        Table as a string
+
+    Examples
+    --------
+
+    >>> audio_dict = {}
+    >>> for i in range(signal_batch.batch_size):
+    >>>     audio_dict[i] = {
+    >>>         "input": signal_batch[i],
+    >>>         "output": output_batch[i]
+    >>>     }
+    >>> audiotools.post.audio_table(audio_dict)
+
+    """
+    from audiotools import AudioSignal
+
+    output = []
+    columns = None
+
+    def _default_format_fn(label, x, **kwargs):
+        if torch.is_tensor(x):
+            x = x.tolist()
+
+        if x is None:
+            return "."
+        elif isinstance(x, AudioSignal):
+            return x.embed(display=False, return_html=True, **kwargs)
+        else:
+            return str(x)
+
+    if format_fn is None:
+        format_fn = _default_format_fn
+
+    if first_column is None:
+        first_column = "."
+
+    for k, v in audio_dict.items():
+        if not isinstance(v, dict):
+            v = {"Audio": v}
+
+        v_keys = list(v.keys())
+        if columns is None:
+            columns = [first_column] + v_keys
+            output.append(" | ".join(columns))
+
+            layout = "|---" + len(v_keys) * "|:-:"
+            output.append(layout)
+
+        formatted_audio = []
+        for col in columns[1:]:
+            formatted_audio.append(format_fn(col, v[col], **kwargs))
+
+        row = f"| {k} | "
+        row += " | ".join(formatted_audio)
+        output.append(row)
+
+    output = "\n" + "\n".join(output)
+    return output
+
+
+def in_notebook():  # pragma: no cover
+    """Determines if code is running in a notebook.
+
+    Returns
+    -------
+    bool
+        Whether or not this is running in a notebook.
+    """
+    try:
+        from IPython import get_ipython
+
+        if "IPKernelApp" not in get_ipython().config:  # pragma: no cover
+            return False
+    except ImportError:
+        return False
+    except AttributeError:
+        return False
+    return True
+
+
+def disp(obj, **kwargs):  # pragma: no cover
+    """Displays an object, depending on if its in a notebook
+    or not.
+
+    Parameters
+    ----------
+    obj : typing.Any
+        Any object to display.
+
+    """
+    from audiotools import AudioSignal
+
+    IN_NOTEBOOK = in_notebook()
+
+    if isinstance(obj, AudioSignal):
+        audio_elem = obj.embed(display=False, return_html=True)
+        if IN_NOTEBOOK:
+            return HTML(audio_elem)
+        else:
+            print(audio_elem)
+    if isinstance(obj, dict):
+        table = audio_table(obj, **kwargs)
+        if IN_NOTEBOOK:
+            return HTML(md.markdown(table, extras=["tables"]))
+        else:
+            print(table)
+    if isinstance(obj, plt.Figure):
+        plt.show()
diff --git a/audiotools/preference.py b/audiotools/preference.py
new file mode 100644
index 0000000000000000000000000000000000000000..800a852e8119dd18ea65784cf95182de2470fbc4
--- /dev/null
+++ b/audiotools/preference.py
@@ -0,0 +1,600 @@
+##############################################################
+### Tools for creating preference tests (MUSHRA, ABX, etc) ###
+##############################################################
+import copy
+import csv
+import random
+import sys
+import traceback
+from collections import defaultdict
+from pathlib import Path
+from typing import List
+
+import gradio as gr
+
+from audiotools.core.util import find_audio
+
+################################################################
+### Logic for audio player, and adding audio / play buttons. ###
+################################################################
+
+WAVESURFER = """<div id="waveform"></div><div id="wave-timeline"></div>"""
+
+CUSTOM_CSS = """
+.gradio-container {
+    max-width: 840px !important;
+}
+region.wavesurfer-region:before {
+    content: attr(data-region-label);
+}
+
+block {
+    min-width: 0 !important;
+}
+
+#wave-timeline {
+    background-color: rgba(0, 0, 0, 0.8);
+}
+
+.head.svelte-1cl284s {
+    display: none;
+}
+"""
+
+load_wavesurfer_js = """
+function load_wavesurfer() {
+    function load_script(url) {
+        const script = document.createElement('script');
+        script.src = url;
+        document.body.appendChild(script);
+
+        return new Promise((res, rej) => {
+            script.onload = function() {
+                res();
+            }
+            script.onerror = function () {
+                rej();
+            }
+        });
+    }
+
+    function create_wavesurfer() {
+        var options = {
+            container: '#waveform',
+            waveColor: '#F2F2F2', // Light grey wave color
+            progressColor: 'white', // White progress color
+            loaderColor: 'white', // White loader color
+            cursorColor: 'black', // Black cursor color
+            backgroundColor: '#00AAFF', // Blue background color
+            barWidth: 4,
+            barRadius: 3,
+            barHeight: 1, // the height of the wave
+            plugins: [
+                WaveSurfer.regions.create({
+                    regionsMinLength: 0.0,
+                    dragSelection: {
+                        slop: 5
+                    },
+                    color: 'hsla(200, 50%, 70%, 0.4)',
+                }),
+                 WaveSurfer.timeline.create({
+                    container: "#wave-timeline",
+                    primaryLabelInterval: 5.0,
+                    secondaryLabelInterval: 1.0,
+                    primaryFontColor: '#F2F2F2',
+                    secondaryFontColor: '#F2F2F2',
+                }),
+            ]
+        };
+        wavesurfer = WaveSurfer.create(options);
+        wavesurfer.on('region-created', region => {
+            wavesurfer.regions.clear();
+        });
+        wavesurfer.on('finish', function () {
+            var loop =  document.getElementById("loop-button").textContent.includes("ON");
+            if (loop) {
+                wavesurfer.play();
+            }
+            else {
+                var button_elements = document.getElementsByClassName('playpause')
+                var buttons = Array.from(button_elements);
+
+                for (let j = 0; j < buttons.length; j++) {
+                    buttons[j].classList.remove("primary");
+                    buttons[j].classList.add("secondary");
+                    buttons[j].textContent = buttons[j].textContent.replace("Stop", "Play")
+                }
+            }
+        });
+
+        wavesurfer.on('region-out', function () {
+            var loop =  document.getElementById("loop-button").textContent.includes("ON");
+            if (!loop) {
+                var button_elements = document.getElementsByClassName('playpause')
+                var buttons = Array.from(button_elements);
+
+                for (let j = 0; j < buttons.length; j++) {
+                    buttons[j].classList.remove("primary");
+                    buttons[j].classList.add("secondary");
+                    buttons[j].textContent = buttons[j].textContent.replace("Stop", "Play")
+                }
+                wavesurfer.pause();
+            }
+        });
+
+        console.log("Created WaveSurfer object.")
+    }
+
+    load_script('https://unpkg.com/wavesurfer.js@6.6.4')
+        .then(() => {
+            load_script("https://unpkg.com/wavesurfer.js@6.6.4/dist/plugin/wavesurfer.timeline.min.js")
+                .then(() => {
+                    load_script('https://unpkg.com/wavesurfer.js@6.6.4/dist/plugin/wavesurfer.regions.min.js')
+                        .then(() => {
+                            console.log("Loaded regions");
+                            create_wavesurfer();
+                            document.getElementById("start-survey").click();
+                        })
+                })
+        });
+}
+"""
+
+play = lambda i: """
+function play() {
+    var audio_elements = document.getElementsByTagName('audio');
+    var button_elements = document.getElementsByClassName('playpause')
+
+    var audio_array = Array.from(audio_elements);
+    var buttons = Array.from(button_elements);
+
+    var src_link = audio_array[{i}].getAttribute("src");
+    console.log(src_link);
+
+    var loop = document.getElementById("loop-button").textContent.includes("ON");
+    var playing = buttons[{i}].textContent.includes("Stop");
+
+    for (let j = 0; j < buttons.length; j++) {
+        if (j != {i} || playing) {
+            buttons[j].classList.remove("primary");
+            buttons[j].classList.add("secondary");
+            buttons[j].textContent = buttons[j].textContent.replace("Stop", "Play")
+        }
+        else {
+            buttons[j].classList.remove("secondary");
+            buttons[j].classList.add("primary");
+            buttons[j].textContent = buttons[j].textContent.replace("Play", "Stop")
+        }
+    }
+
+    if (playing) {
+        wavesurfer.pause();
+        wavesurfer.seekTo(0.0);
+    }
+    else {
+        wavesurfer.load(src_link);
+        wavesurfer.on('ready', function () {
+            var region = Object.values(wavesurfer.regions.list)[0];
+
+            if (region != null) {
+                region.loop = loop;
+                region.play();
+            } else {
+                wavesurfer.play();
+            }
+        });
+    }
+}
+""".replace(
+    "{i}", str(i)
+)
+
+clear_regions = """
+function clear_regions() {
+    wavesurfer.clearRegions();
+}
+"""
+
+reset_player = """
+function reset_player() {
+    wavesurfer.clearRegions();
+    wavesurfer.pause();
+    wavesurfer.seekTo(0.0);
+
+    var button_elements = document.getElementsByClassName('playpause')
+    var buttons = Array.from(button_elements);
+
+    for (let j = 0; j < buttons.length; j++) {
+        buttons[j].classList.remove("primary");
+        buttons[j].classList.add("secondary");
+        buttons[j].textContent = buttons[j].textContent.replace("Stop", "Play")
+    }
+}
+"""
+
+loop_region = """
+function loop_region() {
+    var element = document.getElementById("loop-button");
+    var loop = element.textContent.includes("OFF");
+    console.log(loop);
+
+    try {
+        var region = Object.values(wavesurfer.regions.list)[0];
+        region.loop = loop;
+    } catch {}
+
+    if (loop) {
+        element.classList.remove("secondary");
+        element.classList.add("primary");
+        element.textContent = "Looping ON";
+    } else {
+        element.classList.remove("primary");
+        element.classList.add("secondary");
+        element.textContent = "Looping OFF";
+    }
+}
+"""
+
+
+class Player:
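+    """Manages a single shared wavesurfer.js player inside a gradio app.
+    ``create()`` renders the waveform, timeline, and region controls, and
+    ``add()`` registers a hidden ``gr.Audio`` plus a play/pause button that
+    loads that audio into the shared player when clicked.
+    """
+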
+    def __init__(self, app):
+        self.app = app
+
+        self.app.load(_js=load_wavesurfer_js)
+        self.app.css = CUSTOM_CSS
+
+        self.wavs = []
+        self.position = 0
+
+    def create(self):
+        gr.HTML(WAVESURFER)
+        gr.Markdown(
+            "Click and drag on the waveform above to select a region for playback. "
+            "Once created, the region can be moved around and resized. "
+            "Clear the regions using the button below. Hit play on one of the buttons below to start!"
+        )
+
+        with gr.Row():
+            clear = gr.Button("Clear region")
+            loop = gr.Button("Looping OFF", elem_id="loop-button")
+
+            loop.click(None, _js=loop_region)
+            clear.click(None, _js=clear_regions)
+
+        gr.HTML("<hr>")
+
+    def add(self, name: str = "Play"):
+        i = self.position
+        self.wavs.append(
+            {
+                "audio": gr.Audio(visible=False),
+                "button": gr.Button(name, elem_classes=["playpause"]),
+                "position": i,
+            }
+        )
+        self.wavs[-1]["button"].click(None, _js=play(i))
+        self.position += 1
+        return self.wavs[-1]
+
+    def to_list(self):
+        return [x["audio"] for x in self.wavs]
+
+
+############################################################
+### Keeping track of users, and CSS for the progress bar ###
+############################################################
+
+load_tracker = lambda name: """
+function load_name() {
+    function setCookie(name, value, exp_days) {
+        var d = new Date();
+        d.setTime(d.getTime() + (exp_days*24*60*60*1000));
+        var expires = "expires=" + d.toGMTString();
+        document.cookie = name + "=" + value + ";" + expires + ";path=/";
+    }
+
+    function getCookie(name) {
+        var cname = name + "=";
+        var decodedCookie = decodeURIComponent(document.cookie);
+        var ca = decodedCookie.split(';');
+        for(var i = 0; i < ca.length; i++){
+            var c = ca[i];
+            while(c.charAt(0) == ' '){
+                c = c.substring(1);
+            }
+            if(c.indexOf(cname) == 0){
+                return c.substring(cname.length, c.length);
+            }
+        }
+        return "";
+    }
+
+    name = getCookie("{name}");
+    if (name == "") {
+        name = Math.random().toString(36).slice(2);
+        console.log(name);
+        setCookie("name", name, 30);
+    }
+    name = getCookie("{name}");
+    return name;
+}
+""".replace(
+    "{name}", name
+)
+
+# Progress bar
+
+progress_template = """
+<!DOCTYPE html>
+<html>
+  <head>
+    <title>Progress Bar</title>
+    <style>
+      .progress-bar {
+        background-color: #ddd;
+        border-radius: 4px;
+        height: 30px;
+        width: 100%;
+        position: relative;
+      }
+
+      .progress {
+        background-color: #00AAFF;
+        border-radius: 4px;
+        height: 100%;
+        width: {PROGRESS}%; /* Change this value to control the progress */
+      }
+
+      .progress-text {
+        position: absolute;
+        top: 50%;
+        left: 50%;
+        transform: translate(-50%, -50%);
+        font-size: 18px;
+        font-family: Arial, sans-serif;
+        font-weight: bold;
+        color: #333 !important;
+        text-shadow: 1px 1px #fff;
+      }
+    </style>
+  </head>
+  <body>
+    <div class="progress-bar">
+      <div class="progress"></div>
+      <div class="progress-text">{TEXT}</div>
+    </div>
+  </body>
+</html>
+"""
+
+
+def create_tracker(app, cookie_name="name"):
+    user = gr.Text(label="user", interactive=True, visible=False, elem_id="user")
+    app.load(_js=load_tracker(cookie_name), outputs=user)
+    return user
+
+
+#################################################################
+### CSS and HTML for labeling sliders for both ABX and MUSHRA ###
+#################################################################
+
+slider_abx = """
+<!DOCTYPE html>
+<html>
+  <head>
+    <meta charset="UTF-8">
+    <title>Labels Example</title>
+    <style>
+      body {
+        margin: 0;
+        padding: 0;
+      }
+
+      .labels-container {
+        display: flex;
+        justify-content: space-between;
+        align-items: center;
+        width: 100%;
+        height: 40px;
+        padding: 0px 12px 0px;
+      }
+
+      .label {
+        display: flex;
+        justify-content: center;
+        align-items: center;
+        width: 33%;
+        height: 100%;
+        font-weight: bold;
+        text-transform: uppercase;
+        padding: 10px;
+        font-family: Arial, sans-serif;
+        font-size: 16px;
+        font-weight: 700;
+        letter-spacing: 1px;
+        line-height: 1.5;
+      }
+
+      .label-a {
+        background-color: #00AAFF;
+        color: #333 !important;
+      }
+
+      .label-tie {
+        background-color: #f97316;
+        color: #333 !important;
+      }
+
+      .label-b {
+        background-color: #00AAFF;
+        color: #333 !important;
+      }
+    </style>
+  </head>
+  <body>
+    <div class="labels-container">
+      <div class="label label-a">Prefer A</div>
+      <div class="label label-tie">Toss-up</div>
+      <div class="label label-b">Prefer B</div>
+    </div>
+  </body>
+</html>
+"""
+
+slider_mushra = """
+<!DOCTYPE html>
+<html>
+  <head>
+    <meta charset="UTF-8">
+    <title>Labels Example</title>
+    <style>
+      body {
+        margin: 0;
+        padding: 0;
+      }
+
+      .labels-container {
+        display: flex;
+        justify-content: space-between;
+        align-items: center;
+        width: 100%;
+        height: 30px;
+        padding: 10px;
+      }
+
+      .label {
+        display: flex;
+        justify-content: center;
+        align-items: center;
+        width: 20%;
+        height: 100%;
+        font-weight: bold;
+        text-transform: uppercase;
+        padding: 10px;
+        font-family: Arial, sans-serif;
+        font-size: 13.5px;
+        font-weight: 700;
+        line-height: 1.5;
+      }
+
+      .label-bad {
+        background-color: #ff5555;
+        color: #333 !important;
+      }
+
+      .label-poor {
+        background-color: #ffa500;
+        color: #333 !important;
+      }
+
+      .label-fair {
+        background-color: #ffd700;
+        color: #333 !important;
+      }
+
+      .label-good {
+        background-color: #97d997;
+        color: #333 !important;
+      }
+
+      .label-excellent {
+        background-color: #04c822;
+        color: #333 !important;
+      }
+    </style>
+  </head>
+  <body>
+    <div class="labels-container">
+      <div class="label label-bad">bad</div>
+      <div class="label label-poor">poor</div>
+      <div class="label label-fair">fair</div>
+      <div class="label label-good">good</div>
+      <div class="label label-excellent">excellent</div>
+    </div>
+  </body>
+</html>
+"""
+
+#########################################################
+### Handling loading audio and tracking session state ###
+#########################################################
+
+
+class Samples:
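+    """Loads audio samples for a listening test. The folder is expected to
+    be laid out as ``folder/<condition>/<sample_name>.wav``, so each sample
+    name maps to one file per condition. Samples are (optionally) shuffled
+    and served one at a time via ``get_next_sample``.
+    """
+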
+    def __init__(self, folder: str, shuffle: bool = True, n_samples: int = None):
+        files = find_audio(folder)
+        samples = defaultdict(lambda: defaultdict())
+
+        for f in files:
+            condition = f.parent.stem
+            samples[f.name][condition] = f
+
+        self.samples = samples
+        self.names = list(samples.keys())
+        self.filtered = False
+        self.current = 0
+
+        if shuffle:
+            random.shuffle(self.names)
+
+        self.n_samples = len(self.names) if n_samples is None else n_samples
+
+    def get_updates(self, idx, order):
+        key = self.names[idx]
+        return [gr.update(value=str(self.samples[key][o])) for o in order]
+
+    def progress(self):
+        try:
+            pct = self.current / len(self) * 100
+        except:  # pragma: no cover
+            pct = 100
+        text = f"On {self.current} / {len(self)} samples"
+        pbar = (
+            copy.copy(progress_template)
+            .replace("{PROGRESS}", str(pct))
+            .replace("{TEXT}", str(text))
+        )
+        return gr.update(value=pbar)
+
+    def __len__(self):
+        return self.n_samples
+
+    def filter_completed(self, user, save_path):
+        if not self.filtered:
+            done = []
+            if Path(save_path).exists():
+                with open(save_path, "r") as f:
+                    reader = csv.DictReader(f)
+                    done = [r["sample"] for r in reader if r["user"] == user]
+            self.names = [k for k in self.names if k not in done]
+            self.names = self.names[: self.n_samples]
+            self.filtered = True  # Avoid filtering more than once per session.
+
+    def get_next_sample(self, reference, conditions):
+        random.shuffle(conditions)
+        if reference is not None:
+            self.order = [reference] + conditions
+        else:
+            self.order = conditions
+
+        try:
+            updates = self.get_updates(self.current, self.order)
+            self.current += 1
+            done = gr.update(interactive=True)
+            pbar = self.progress()
+        except:
+            traceback.print_exc()
+            updates = [gr.update() for _ in range(len(self.order))]
+            done = gr.update(value="No more samples!", interactive=False)
+            self.current = len(self)
+            pbar = self.progress()
+
+        return updates, done, pbar
+
+
+def save_result(result, save_path):
+    with open(save_path, mode="a", newline="") as file:
+        writer = csv.DictWriter(file, fieldnames=sorted(list(result.keys())))
+        if file.tell() == 0:
+            writer.writeheader()
+        writer.writerow(result)
diff --git a/src/inference.py b/src/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..654090b58a9e85e740821e485428b5fb37766edb
--- /dev/null
+++ b/src/inference.py
@@ -0,0 +1,169 @@
+import os
+import random
+import pandas as pd
+import torch
+import librosa
+import numpy as np
+import soundfile as sf
+from tqdm import tqdm
+from .utils import scale_shift_re
+
+
+def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
+    """
+    Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
+    Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
+    """
+    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
+    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
+    # rescale the results from guidance (fixes overexposure)
+    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
+    # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" samples
+    noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
+    return noise_cfg
+
+
+@torch.no_grad()
+def inference(autoencoder, unet, gt, gt_mask,
+              tokenizer, text_encoder,
+              params, noise_scheduler,
+              text_raw, neg_text=None,
+              audio_frames=500,
+              guidance_scale=3, guidance_rescale=0.0,
+              ddim_steps=50, eta=1, random_seed=2024,
+              device='cuda',
+              ):
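+    """Summary (inferred from the code below): the text prompt is encoded,
+    DDIM sampling is run in the autoencoder's latent space with
+    classifier-free guidance against ``neg_text``, regions of ``gt`` where
+    ``gt_mask`` is False are kept fixed for inpainting-style generation,
+    and the final latents are decoded back to a waveform.
+    """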
+    if neg_text is None:
+        neg_text = [""]
+    if tokenizer is not None:
+        text_batch = tokenizer(text_raw,
+                               max_length=params['text_encoder']['max_length'],
+                               padding="max_length", truncation=True, return_tensors="pt")
+        text, text_mask = text_batch.input_ids.to(device), text_batch.attention_mask.to(device).bool()
+        text = text_encoder(input_ids=text, attention_mask=text_mask).last_hidden_state
+
+        uncond_text_batch = tokenizer(neg_text,
+                                      max_length=params['text_encoder']['max_length'],
+                                      padding="max_length", truncation=True, return_tensors="pt")
+        uncond_text, uncond_text_mask = uncond_text_batch.input_ids.to(device), uncond_text_batch.attention_mask.to(device).bool()
+        uncond_text = text_encoder(input_ids=uncond_text,
+                                   attention_mask=uncond_text_mask).last_hidden_state
+    else:
+        text, text_mask = None, None
+        guidance_scale = None
+
+    codec_dim = params['model']['out_chans']
+    unet.eval()
+
+    if random_seed is not None:
+        generator = torch.Generator(device=device).manual_seed(random_seed)
+    else:
+        generator = torch.Generator(device=device)
+        generator.seed()
+
+    noise_scheduler.set_timesteps(ddim_steps)
+
+    # init noise
+    noise = torch.randn((1, codec_dim, audio_frames), generator=generator, device=device)
+    latents = noise
+
+    for t in noise_scheduler.timesteps:
+        latents = noise_scheduler.scale_model_input(latents, t)
+
+        if guidance_scale:
+
+            latents_combined = torch.cat([latents, latents], dim=0)
+            text_combined = torch.cat([text, uncond_text], dim=0)
+            text_mask_combined = torch.cat([text_mask, uncond_text_mask], dim=0)
+            
+            if gt is not None:
+                gt_combined = torch.cat([gt, gt], dim=0)
+                gt_mask_combined = torch.cat([gt_mask, gt_mask], dim=0)
+            else:
+                gt_combined = None
+                gt_mask_combined = None
+            
+            output_combined, _ = unet(latents_combined, t, text_combined, context_mask=text_mask_combined, 
+                                      cls_token=None, gt=gt_combined, mae_mask_infer=gt_mask_combined)
+            output_text, output_uncond = torch.chunk(output_combined, 2, dim=0)
+
+            output_pred = output_uncond + guidance_scale * (output_text - output_uncond)
+            if guidance_rescale > 0.0:
+                output_pred = rescale_noise_cfg(output_pred, output_text,
+                                                guidance_rescale=guidance_rescale)
+        else:
+            output_pred, mae_mask = unet(latents, t, text, context_mask=text_mask,
+                                         cls_token=None, gt=gt, mae_mask_infer=gt_mask)
+
+        latents = noise_scheduler.step(model_output=output_pred, timestep=t, 
+                                       sample=latents,
+                                       eta=eta, generator=generator).prev_sample
+
+    pred = scale_shift_re(latents, params['autoencoder']['scale'],
+                          params['autoencoder']['shift'])
+    if gt is not None:
+        pred[~gt_mask] = gt[~gt_mask]
+    pred_wav = autoencoder(embedding=pred)
+    return pred_wav
+
+
+@torch.no_grad()
+def eval_udit(autoencoder, unet,
+              tokenizer, text_encoder,
+              params, noise_scheduler,
+              val_df, subset,
+              audio_frames, mae=False,
+              guidance_scale=3, guidance_rescale=0.0,
+              ddim_steps=50, eta=1, random_seed=2023,
+              device='cuda',
+              epoch=0, save_path='logs/eval/', val_num=5):
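+    """Summary (inferred from the code below): reads a CSV of captions, keeps
+    the rows whose split matches ``subset``, synthesizes up to ``val_num``
+    examples with ``inference``, and writes the results (plus ground-truth
+    clips in MAE mode) as wav files under ``save_path/<epoch>/``.
+    """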
+    val_df = pd.read_csv(val_df)
+    val_df = val_df[val_df['split'] == subset]
+    if mae:
+        val_df = val_df[val_df['audio_length'] != 0]
+
+    save_path = save_path + str(epoch) + '/'
+    os.makedirs(save_path, exist_ok=True)
+
+    for i in tqdm(range(len(val_df))):
+        row = val_df.iloc[i]
+        text = [row['caption']]
+        if mae:
+            audio_path = params['data']['val_dir'] + str(row['audio_path'])
+            gt, sr = librosa.load(audio_path, sr=params['data']['sr'])
+            gt = gt / (np.max(np.abs(gt)) + 1e-9)
+            sf.write(save_path + text[0] + '_gt.wav', gt, samplerate=params['data']['sr'])
+            num_samples = 10 * sr
+            if len(gt) < num_samples:
+                padding = num_samples - len(gt)
+                gt = np.pad(gt, (0, padding), 'constant')
+            else:
+                gt = gt[:num_samples]
+            gt = torch.tensor(gt).unsqueeze(0).unsqueeze(1).to(device)
+            gt = autoencoder(audio=gt)
+            B, D, L = gt.shape
+            mask_len = int(L * 0.2)
+            gt_mask = torch.zeros(B, D, L).to(device)
+            for _ in range(2):
+                start = random.randint(0, L - mask_len)
+                gt_mask[:, :, start:start + mask_len] = 1
+            gt_mask = gt_mask.bool()
+        else:
+            gt = None
+            gt_mask = None
+
+        pred = inference(autoencoder, unet, gt, gt_mask,
+                         tokenizer, text_encoder, 
+                         params, noise_scheduler,
+                         text, neg_text=None,
+                         audio_frames=audio_frames,
+                         guidance_scale=guidance_scale, guidance_rescale=guidance_rescale,
+                         ddim_steps=ddim_steps, eta=eta, random_seed=random_seed,
+                         device=device)
+
+        pred = pred.cpu().numpy().squeeze(0).squeeze(0)
+
+        sf.write(save_path + text[0] + '.wav', pred, samplerate=params['data']['sr'])
+
+        if i + 1 >= val_num:
+            break
diff --git a/src/inference_controlnet.py b/src/inference_controlnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..eeb5dd30228c9a597c25e523d0d94c564bbe910b
--- /dev/null
+++ b/src/inference_controlnet.py
@@ -0,0 +1,129 @@
+import os
+import random
+import pandas as pd
+import torch
+import librosa
+import numpy as np
+import soundfile as sf
+from tqdm import tqdm
+from .utils import scale_shift_re
+
+
+def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
+    """
+    Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
+    Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
+    """
+    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
+    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
+    # rescale the results from guidance (fixes overexposure)
+    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
+    # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
+    noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
+    return noise_cfg
+
+
+@torch.no_grad()
+def inference(autoencoder, unet, controlnet,
+              gt, gt_mask, condition,
+              tokenizer, text_encoder,
+              params, noise_scheduler,
+              text_raw, neg_text=None,
+              audio_frames=500,
+              guidance_scale=3, guidance_rescale=0.0,
+              ddim_steps=50, eta=1, random_seed=2024,
+              conditioning_scale=1.0,
+              device='cuda',
+              ):
+    if neg_text is None:
+        neg_text = [""]
+    if tokenizer is not None:
+        text_batch = tokenizer(text_raw,
+                               max_length=params['text_encoder']['max_length'],
+                               padding="max_length", truncation=True, return_tensors="pt")
+        text, text_mask = text_batch.input_ids.to(device), text_batch.attention_mask.to(device).bool()
+        text = text_encoder(input_ids=text, attention_mask=text_mask).last_hidden_state
+
+        uncond_text_batch = tokenizer(neg_text,
+                                      max_length=params['text_encoder']['max_length'],
+                                      padding="max_length", truncation=True, return_tensors="pt")
+        uncond_text, uncond_text_mask = uncond_text_batch.input_ids.to(device), uncond_text_batch.attention_mask.to(device).bool()
+        uncond_text = text_encoder(input_ids=uncond_text,
+                                   attention_mask=uncond_text_mask).last_hidden_state
+    else:
+        text, text_mask = None, None
+        guidance_scale = None
+
+    codec_dim = params['model']['out_chans']
+    unet.eval()
+    controlnet.eval()
+
+    if random_seed is not None:
+        generator = torch.Generator(device=device).manual_seed(random_seed)
+    else:
+        generator = torch.Generator(device=device)
+        generator.seed()
+
+    noise_scheduler.set_timesteps(ddim_steps)
+
+    # init noise
+    noise = torch.randn((1, codec_dim, audio_frames), generator=generator, device=device)
+    latents = noise
+
+    for t in noise_scheduler.timesteps:
+        latents = noise_scheduler.scale_model_input(latents, t)
+
+        if guidance_scale:
+            latents_combined = torch.cat([latents, latents], dim=0)
+            text_combined = torch.cat([text, uncond_text], dim=0)
+            text_mask_combined = torch.cat([text_mask, uncond_text_mask], dim=0)
+            condition_combined = torch.cat([condition, condition], dim=0)
+
+            if gt is not None:
+                gt_combined = torch.cat([gt, gt], dim=0)
+                gt_mask_combined = torch.cat([gt_mask, gt_mask], dim=0)
+            else:
+                gt_combined = None
+                gt_mask_combined = None
+
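+            # forward_model=False: the MaskDiT wrapper only assembles the backbone input
+            # (latent, masked gt, mask channel); unet.model runs below with the ControlNet skips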
+            x, _ = unet(latents_combined, t, text_combined, context_mask=text_mask_combined,
+                        cls_token=None, gt=gt_combined, mae_mask_infer=gt_mask_combined, 
+                        forward_model=False)
+            controlnet_skips = controlnet(x, t, text_combined,
+                                          context_mask=text_mask_combined,
+                                          cls_token=None,
+                                          condition=condition_combined,
+                                          conditioning_scale=conditioning_scale)
+            output_combined = unet.model(x, t, text_combined,
+                                         context_mask=text_mask_combined,
+                                         cls_token=None, controlnet_skips=controlnet_skips)
+
+            output_text, output_uncond = torch.chunk(output_combined, 2, dim=0)
+
+            output_pred = output_uncond + guidance_scale * (output_text - output_uncond)
+            if guidance_rescale > 0.0:
+                output_pred = rescale_noise_cfg(output_pred, output_text,
+                                                guidance_rescale=guidance_rescale)
+        else:
+            x, _ = unet(latents, t, text, context_mask=text_mask,
+                        cls_token=None, gt=gt, mae_mask_infer=gt_mask,
+                        forward_model=False)
+            controlnet_skips = controlnet(x, t, text,
+                                          context_mask=text_mask,
+                                          cls_token=None,
+                                          condition=condition,
+                                          conditioning_scale=conditioning_scale)
+            output_pred = unet.model(x, t, text,
+                                     context_mask=text_mask,
+                                     cls_token=None, controlnet_skips=controlnet_skips)
+
+        latents = noise_scheduler.step(model_output=output_pred, timestep=t,
+                                       sample=latents,
+                                       eta=eta, generator=generator).prev_sample
+
+    pred = scale_shift_re(latents, params['autoencoder']['scale'],
+                          params['autoencoder']['shift'])
+    if gt is not None:
+        pred[~gt_mask] = gt[~gt_mask]
+    pred_wav = autoencoder(embedding=pred)
+    return pred_wav
\ No newline at end of file
diff --git a/src/models/.ipynb_checkpoints/blocks-checkpoint.py b/src/models/.ipynb_checkpoints/blocks-checkpoint.py
new file mode 100644
index 0000000000000000000000000000000000000000..8ef730009cb7cf664c5d9021e551b275680d11f3
--- /dev/null
+++ b/src/models/.ipynb_checkpoints/blocks-checkpoint.py
@@ -0,0 +1,325 @@
+import torch
+import torch.nn as nn
+from torch.utils.checkpoint import checkpoint
+from .utils.attention import Attention, JointAttention
+from .utils.modules import unpatchify, FeedForward
+from .utils.modules import film_modulate
+
+
+class AdaLN(nn.Module):
+    def __init__(self, dim, ada_mode='ada', r=None, alpha=None):
+        super().__init__()
+        self.ada_mode = ada_mode
+        self.scale_shift_table = None
+        if ada_mode == 'ada':
+            # move nn.silu outside
+            self.time_ada = nn.Linear(dim, 6 * dim, bias=True)
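+            # 6 * dim = (shift, scale, gate) for the attention branch plus (shift, scale, gate) for the MLP branch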
+        elif ada_mode == 'ada_single':
+            # adaln used in pixel-art alpha
+            self.scale_shift_table = nn.Parameter(torch.zeros(6, dim))
+        elif ada_mode in ['ada_lora', 'ada_lora_bias']:
+            self.lora_a = nn.Linear(dim, r * 6, bias=False)
+            self.lora_b = nn.Linear(r * 6, dim * 6, bias=False)
+            self.scaling = alpha / r
+            if ada_mode == 'ada_lora_bias':
+                # take bias out for consistency
+                self.scale_shift_table = nn.Parameter(torch.zeros(6, dim))
+        else:
+            raise NotImplementedError
+
+    def forward(self, time_token=None, time_ada=None):
+        if self.ada_mode == 'ada':
+            assert time_ada is None
+            B = time_token.shape[0]
+            time_ada = self.time_ada(time_token).reshape(B, 6, -1)
+        elif self.ada_mode == 'ada_single':
+            B = time_ada.shape[0]
+            time_ada = time_ada.reshape(B, 6, -1)
+            time_ada = self.scale_shift_table[None] + time_ada
+        elif self.ada_mode in ['ada_lora', 'ada_lora_bias']:
+            B = time_ada.shape[0]
+            time_ada_lora = self.lora_b(self.lora_a(time_token)) * self.scaling
+            time_ada = time_ada + time_ada_lora
+            time_ada = time_ada.reshape(B, 6, -1)
+            if self.scale_shift_table is not None:
+                time_ada = self.scale_shift_table[None] + time_ada
+        else:
+            raise NotImplementedError
+        return time_ada
+
+
+class DiTBlock(nn.Module):
+    """
+    A modified PixArt block with adaptive layer norm (adaLN-single) conditioning.
+    """
+
+    def __init__(self, dim, context_dim=None,
+                 num_heads=8, mlp_ratio=4.,
+                 qkv_bias=False, qk_scale=None, qk_norm=None,
+                 act_layer='gelu', norm_layer=nn.LayerNorm,
+                 time_fusion='none',
+                 ada_lora_rank=None, ada_lora_alpha=None,
+                 skip=False, skip_norm=False,
+                 rope_mode='none',
+                 context_norm=False,
+                 use_checkpoint=False):
+
+        super().__init__()
+        self.norm1 = norm_layer(dim)
+        self.attn = Attention(dim=dim,
+                              num_heads=num_heads,
+                              qkv_bias=qkv_bias, qk_scale=qk_scale,
+                              qk_norm=qk_norm,
+                              rope_mode=rope_mode)
+
+        if context_dim is not None:
+            self.use_context = True
+            self.cross_attn = Attention(dim=dim,
+                                        num_heads=num_heads,
+                                        context_dim=context_dim,
+                                        qkv_bias=qkv_bias, qk_scale=qk_scale,
+                                        qk_norm=qk_norm,
+                                        rope_mode='none')
+            self.norm2 = norm_layer(dim)
+            if context_norm:
+                self.norm_context = norm_layer(context_dim)
+            else:
+                self.norm_context = nn.Identity()
+        else:
+            self.use_context = False
+
+        self.norm3 = norm_layer(dim)
+        self.mlp = FeedForward(dim=dim, mult=mlp_ratio,
+                               activation_fn=act_layer, dropout=0)
+
+        self.use_adanorm = True if time_fusion != 'token' else False
+        if self.use_adanorm:
+            self.adaln = AdaLN(dim, ada_mode=time_fusion,
+                               r=ada_lora_rank, alpha=ada_lora_alpha)
+        if skip:
+            self.skip_norm = norm_layer(2 * dim) if skip_norm else nn.Identity()
+            self.skip_linear = nn.Linear(2 * dim, dim)
+        else:
+            self.skip_linear = None
+            
+        self.use_checkpoint = use_checkpoint
+
+    def forward(self, x, time_token=None, time_ada=None,
+                skip=None, context=None,
+                x_mask=None, context_mask=None, extras=None):
+        if self.use_checkpoint:
+            return checkpoint(self._forward, x,
+                              time_token, time_ada, skip, context,
+                              x_mask, context_mask, extras,
+                              use_reentrant=False)
+        else:
+            return self._forward(x,
+                                 time_token, time_ada, skip, context,
+                                 x_mask, context_mask, extras)
+
+    def _forward(self, x, time_token=None, time_ada=None,
+                 skip=None, context=None,
+                 x_mask=None, context_mask=None, extras=None):
+        B, T, C = x.shape
+        if self.skip_linear is not None:
+            assert skip is not None
+            cat = torch.cat([x, skip], dim=-1)
+            cat = self.skip_norm(cat)
+            x = self.skip_linear(cat)
+
+        if self.use_adanorm:
+            time_ada = self.adaln(time_token, time_ada)
+            (shift_msa, scale_msa, gate_msa,
+             shift_mlp, scale_mlp, gate_mlp) = time_ada.chunk(6, dim=1)
+
+        # self attention
+        if self.use_adanorm:
+            x_norm = film_modulate(self.norm1(x), shift=shift_msa,
+                                   scale=scale_msa)
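+            # gating uses (1 - gate); with zero-initialized adaLN parameters the
+            # residual branch starts at full strength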
+            x = x + (1 - gate_msa) * self.attn(x_norm, context=None,
+                                               context_mask=x_mask,
+                                               extras=extras)
+        else:
+            x = x + self.attn(self.norm1(x), context=None, context_mask=x_mask,
+                              extras=extras)
+
+        # cross attention
+        if self.use_context:
+            assert context is not None
+            x = x + self.cross_attn(x=self.norm2(x),
+                                    context=self.norm_context(context),
+                                    context_mask=context_mask, extras=extras)
+
+        # mlp
+        if self.use_adanorm:
+            x_norm = film_modulate(self.norm3(x), shift=shift_mlp, scale=scale_mlp)
+            x = x + (1 - gate_mlp) * self.mlp(x_norm)
+        else:
+            x = x + self.mlp(self.norm3(x))
+
+        return x
+
+
+class JointDiTBlock(nn.Module):
+    """
+    A joint-attention PixArt-style block: context and latent tokens are processed
+    together, with adaptive layer norm (adaLN-single) conditioning.
+    """
+
+    def __init__(self, dim, context_dim=None,
+                 num_heads=8, mlp_ratio=4.,
+                 qkv_bias=False, qk_scale=None, qk_norm=None,
+                 act_layer='gelu', norm_layer=nn.LayerNorm,
+                 time_fusion='none',
+                 ada_lora_rank=None, ada_lora_alpha=None,
+                 skip=(False, False),
+                 rope_mode=False,
+                 context_norm=False,
+                 use_checkpoint=False,):
+
+        super().__init__()
+        # no cross attention
+        assert context_dim is None
+        self.attn_norm_x = norm_layer(dim)
+        self.attn_norm_c = norm_layer(dim)
+        self.attn = JointAttention(dim=dim,
+                                   num_heads=num_heads,
+                                   qkv_bias=qkv_bias, qk_scale=qk_scale,
+                                   qk_norm=qk_norm,
+                                   rope_mode=rope_mode)
+        self.ffn_norm_x = norm_layer(dim)
+        self.ffn_norm_c = norm_layer(dim)
+        self.mlp_x = FeedForward(dim=dim, mult=mlp_ratio,
+                                 activation_fn=act_layer, dropout=0)
+        self.mlp_c = FeedForward(dim=dim, mult=mlp_ratio,
+                                 activation_fn=act_layer, dropout=0)
+
+        # Zero-out the shift table
+        self.use_adanorm = True if time_fusion != 'token' else False
+        if self.use_adanorm:
+            self.adaln = AdaLN(dim, ada_mode=time_fusion,
+                               r=ada_lora_rank, alpha=ada_lora_alpha)
+
+        if skip is False:
+            skip_x, skip_c = False, False
+        else:
+            skip_x, skip_c = skip
+
+        self.skip_linear_x = nn.Linear(2 * dim, dim) if skip_x else None
+        self.skip_linear_c = nn.Linear(2 * dim, dim) if skip_c else None
+
+        self.use_checkpoint = use_checkpoint
+
+    def forward(self, x, time_token=None, time_ada=None,
+                skip=None, context=None,
+                x_mask=None, context_mask=None, extras=None):
+        if self.use_checkpoint:
+            return checkpoint(self._forward, x,
+                              time_token, time_ada, skip,
+                              context, x_mask, context_mask, extras,
+                              use_reentrant=False)
+        else:
+            return self._forward(x,
+                                 time_token, time_ada, skip,
+                                 context, x_mask, context_mask, extras)
+
+    def _forward(self, x, time_token=None, time_ada=None,
+                 skip=None, context=None,
+                 x_mask=None, context_mask=None, extras=None):
+
+        assert context is None and context_mask is None
+
+        context, x = x[:, :extras, :], x[:, extras:, :]
+        context_mask, x_mask = x_mask[:, :extras], x_mask[:, extras:]
+
+        if skip is not None:
+            skip_c, skip_x = skip[:, :extras, :], skip[:, extras:, :]
+
+        B, T, C = x.shape
+        if self.skip_linear_x is not None:
+            x = self.skip_linear_x(torch.cat([x, skip_x], dim=-1))
+
+        if self.skip_linear_c is not None:
+            context = self.skip_linear_c(torch.cat([context, skip_c], dim=-1))
+
+        if self.use_adanorm:
+            time_ada = self.adaln(time_token, time_ada)
+            (shift_msa, scale_msa, gate_msa,
+             shift_mlp, scale_mlp, gate_mlp) = time_ada.chunk(6, dim=1)
+
+        # self attention
+        x_norm = self.attn_norm_x(x)
+        c_norm = self.attn_norm_c(context)
+        if self.use_adanorm:
+            x_norm = film_modulate(x_norm, shift=shift_msa, scale=scale_msa)
+        x_out, c_out = self.attn(x_norm, context=c_norm,
+                                 x_mask=x_mask, context_mask=context_mask,
+                                 extras=extras)
+        if self.use_adanorm:
+            x = x + (1 - gate_msa) * x_out
+        else:
+            x = x + x_out
+        context = context + c_out
+
+        # mlp
+        if self.use_adanorm:
+            x_norm = film_modulate(self.ffn_norm_x(x),
+                                   shift=shift_mlp, scale=scale_mlp)
+            x = x + (1 - gate_mlp) * self.mlp_x(x_norm)
+        else:
+            x = x + self.mlp_x(self.ffn_norm_x(x))
+
+        c_norm = self.ffn_norm_c(context)
+        context = context + self.mlp_c(c_norm)
+
+        return torch.cat((context, x), dim=1)
+
+
+class FinalBlock(nn.Module):
+    def __init__(self, embed_dim, patch_size, in_chans,
+                 img_size,
+                 input_type='2d',
+                 norm_layer=nn.LayerNorm,
+                 use_conv=True,
+                 use_adanorm=True):
+        super().__init__()
+        self.in_chans = in_chans
+        self.img_size = img_size
+        self.input_type = input_type
+
+        self.norm = norm_layer(embed_dim)
+        self.use_adanorm = use_adanorm
+
+        if input_type == '2d':
+            self.patch_dim = patch_size ** 2 * in_chans
+            self.linear = nn.Linear(embed_dim, self.patch_dim, bias=True)
+            if use_conv:
+                self.final_layer = nn.Conv2d(self.in_chans, self.in_chans, 
+                                             3, padding=1)
+            else:
+                self.final_layer = nn.Identity()
+
+        elif input_type == '1d':
+            self.patch_dim = patch_size * in_chans
+            self.linear = nn.Linear(embed_dim, self.patch_dim, bias=True)
+            if use_conv:
+                self.final_layer = nn.Conv1d(self.in_chans, self.in_chans, 
+                                             3, padding=1)
+            else:
+                self.final_layer = nn.Identity()
+
+    def forward(self, x, time_ada=None, extras=0):
+        B, T, C = x.shape
+        x = x[:, extras:, :]
+        # only handle generation target
+        if self.use_adanorm:
+            shift, scale = time_ada.reshape(B, 2, -1).chunk(2, dim=1)
+            x = film_modulate(self.norm(x), shift, scale)
+        else:
+            x = self.norm(x)
+        x = self.linear(x)
+        x = unpatchify(x, self.in_chans, self.input_type, self.img_size)
+        x = self.final_layer(x)
+        return x
\ No newline at end of file
diff --git a/src/models/.ipynb_checkpoints/conditioners-checkpoint.py b/src/models/.ipynb_checkpoints/conditioners-checkpoint.py
new file mode 100644
index 0000000000000000000000000000000000000000..cade7febf61ef005f421c42cf17bb1bb2935a751
--- /dev/null
+++ b/src/models/.ipynb_checkpoints/conditioners-checkpoint.py
@@ -0,0 +1,183 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import repeat
+import math
+from .udit import UDiT
+from .utils.span_mask import compute_mask_indices
+
+
+class EmbeddingCFG(nn.Module):
+    """
+    Handles label dropout for classifier-free guidance.
+    """
+    # todo: support 2D input
+
+    def __init__(self, in_channels):
+        super().__init__()
+        self.cfg_embedding = nn.Parameter(
+            torch.randn(in_channels) / in_channels ** 0.5)
+
+    def token_drop(self, condition, condition_mask, cfg_prob):
+        """
+        Drops labels to enable classifier-free guidance.
+        """
+        b, t, device = condition.shape[0], condition.shape[1], condition.device
+        drop_ids = torch.rand(b, device=device) < cfg_prob
+        uncond = repeat(self.cfg_embedding, "c -> b t c", b=b, t=t)
+        condition = torch.where(drop_ids[:, None, None], uncond, condition)
+        if condition_mask is not None:
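+            # dropped samples keep only position 0 valid; the whole sequence has been
+            # replaced by the learned unconditional embedding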
+            condition_mask[drop_ids] = False
+            condition_mask[drop_ids, 0] = True
+
+        return condition, condition_mask
+
+    def forward(self, condition, condition_mask, cfg_prob=0.0):
+        if condition_mask is not None:
+            condition_mask = condition_mask.clone()
+        if cfg_prob > 0:
+            condition, condition_mask = self.token_drop(condition,
+                                                        condition_mask,
+                                                        cfg_prob)
+        return condition, condition_mask
+
+
+class DiscreteCFG(nn.Module):
+    def __init__(self, replace_id=2):
+        super(DiscreteCFG, self).__init__()
+        self.replace_id = replace_id
+
+    def forward(self, context, context_mask, cfg_prob):
+        context = context.clone()
+        if context_mask is not None:
+            context_mask = context_mask.clone()
+        if cfg_prob > 0:
+            cfg_mask = torch.rand(len(context)) < cfg_prob
+            if torch.any(cfg_mask):
+                context[cfg_mask] = 0
+                context[cfg_mask, 0] = self.replace_id
+                if context_mask is not None:
+                    context_mask[cfg_mask] = False
+                    context_mask[cfg_mask, 0] = True
+        return context, context_mask
+
+
+class CFGModel(nn.Module):
+    def __init__(self, context_dim, backbone):
+        super().__init__()
+        self.model = backbone
+        self.context_cfg = EmbeddingCFG(context_dim)
+
+    def forward(self, x, timesteps,
+                context, x_mask=None, context_mask=None,
+                cfg_prob=0.0):
+        context, context_mask = self.context_cfg(context, context_mask, cfg_prob)
+        x = self.model(x=x, timesteps=timesteps,
+                       context=context,
+                       x_mask=x_mask, context_mask=context_mask)
+        return x
+
+
+class ConcatModel(nn.Module):
+    def __init__(self, backbone, in_dim, stride=[]):
+        super().__init__()
+        self.model = backbone
+
+        self.downsample_layers = nn.ModuleList()
+        for i, s in enumerate(stride):
+            downsample_layer = nn.Conv1d(
+                in_dim,
+                in_dim * 2,
+                kernel_size=2 * s,
+                stride=s,
+                padding=math.ceil(s / 2),
+            )
+            self.downsample_layers.append(downsample_layer)
+            in_dim = in_dim * 2
+
+        self.context_cfg = EmbeddingCFG(in_dim)
+
+    def forward(self, x, timesteps,
+                context, x_mask=None,
+                cfg=False, cfg_prob=0.0):
+
+        # todo: support 2D input
+        # x: B, C, L
+        # context: B, C, L
+
+        for downsample_layer in self.downsample_layers:
+            context = downsample_layer(context)
+
+        context = context.transpose(1, 2)
+        context, _ = self.context_cfg(condition=context, condition_mask=None,
+                                      cfg_prob=cfg_prob if cfg else 0.0)
+        context = context.transpose(1, 2)
+
+        assert context.shape[-1] == x.shape[-1]
+        x = torch.cat([context, x], dim=1)
+        x = self.model(x=x, timesteps=timesteps,
+                       context=None, x_mask=x_mask, context_mask=None)
+        return x
+
+
+class MaskDiT(nn.Module):
+    def __init__(self, mae=False, mae_prob=0.5, mask_ratio=[0.25, 1.0], mask_span=10, **kwargs):
+        super().__init__()
+        self.model = UDiT(**kwargs)
+        self.mae = mae
+        if self.mae:
+            out_channel = kwargs.pop('out_chans', None)
+            self.mask_embed = nn.Parameter(torch.zeros((out_channel)))
+            self.mae_prob = mae_prob
+            self.mask_ratio = mask_ratio
+            self.mask_span = mask_span
+
+    def random_masking(self, gt, mask_ratios, mae_mask_infer=None):
+        B, D, L = gt.shape
+        if mae_mask_infer is None:
+            # mask = torch.rand(B, L).to(gt.device) < mask_ratios.unsqueeze(1)
+            mask_ratios = mask_ratios.cpu().numpy()
+            mask = compute_mask_indices(shape=[B, L],
+                                        padding_mask=None,
+                                        mask_prob=mask_ratios,
+                                        mask_length=self.mask_span,
+                                        mask_type="static",
+                                        mask_other=0.0,
+                                        min_masks=1,
+                                        no_overlap=False,
+                                        min_space=0,)
+            mask = mask.unsqueeze(1).expand_as(gt)
+        else:
+            mask = mae_mask_infer
+            mask = mask.expand_as(gt)
+        gt[mask] = self.mask_embed.view(1, D, 1).expand_as(gt)[mask]
+        return gt, mask.type_as(gt)
+
+    def forward(self, x, timesteps, context,
+                x_mask=None, context_mask=None, cls_token=None,
+                gt=None, mae_mask_infer=None,
+                forward_model=True):
+        # todo: handle controlnet inside
+        mae_mask = torch.ones_like(x)
+        if self.mae:
+            if gt is not None:
+                B, D, L = gt.shape
+                mask_ratios = torch.FloatTensor(B).uniform_(*self.mask_ratio).to(gt.device)
+                gt, mae_mask = self.random_masking(gt, mask_ratios, mae_mask_infer)
+                # apply mae only to the selected batches
+                if mae_mask_infer is None:
+                    # determine mae batch
+                    mae_batch = torch.rand(B) < self.mae_prob
+                    gt[~mae_batch] = self.mask_embed.view(1, D, 1).expand_as(gt)[~mae_batch]
+                    mae_mask[~mae_batch] = 1.0
+            else:
+                B, D, L = x.shape
+                gt = self.mask_embed.view(1, D, 1).expand_as(x)
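+            # channel-concatenate the noisy latent, the (partially) masked ground truth,
+            # and a 1-channel mask indicator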
+            x = torch.cat([x, gt, mae_mask[:, 0:1, :]], dim=1)
+
+        if forward_model:
+            x = self.model(x=x, timesteps=timesteps, context=context,
+                           x_mask=x_mask, context_mask=context_mask,
+                           cls_token=cls_token)
+            # print(mae_mask[:, 0, :].sum(dim=-1))
+        return x, mae_mask
diff --git a/src/models/.ipynb_checkpoints/controlnet-checkpoint.py b/src/models/.ipynb_checkpoints/controlnet-checkpoint.py
new file mode 100644
index 0000000000000000000000000000000000000000..1750621847ed116a6fbab55a50e67963699d6a5a
--- /dev/null
+++ b/src/models/.ipynb_checkpoints/controlnet-checkpoint.py
@@ -0,0 +1,318 @@
+import torch
+import torch.nn as nn
+
+from .utils.modules import PatchEmbed, TimestepEmbedder
+from .utils.modules import PE_wrapper, RMSNorm
+from .blocks import DiTBlock, JointDiTBlock
+from .utils.span_mask import compute_mask_indices
+
+
+class DiTControlNetEmbed(nn.Module):
+    def __init__(self, in_chans, out_chans, blocks,
+                 cond_mask=False, cond_mask_prob=None,
+                 cond_mask_ratio=None, cond_mask_span=None):
+        super().__init__()
+        self.conv_in = nn.Conv1d(in_chans, blocks[0], kernel_size=1)
+
+        self.cond_mask = cond_mask
+        if self.cond_mask:
+            self.mask_embed = nn.Parameter(torch.zeros((blocks[0])))
+            self.mask_prob = cond_mask_prob
+            self.mask_ratio = cond_mask_ratio
+            self.mask_span = cond_mask_span
+            blocks[0] = blocks[0] + 1
+
+        conv_blocks = []
+        for i in range(len(blocks) - 1):
+            channel_in = blocks[i]
+            channel_out = blocks[i + 1]
+            block = nn.Sequential(
+                nn.Conv1d(channel_in, channel_in, kernel_size=3, padding=1),
+                nn.SiLU(),
+                nn.Conv1d(channel_in, channel_out, kernel_size=3, padding=1, stride=2),
+                nn.SiLU(),)
+            conv_blocks.append(block)
+        self.blocks = nn.ModuleList(conv_blocks)
+
+        self.conv_out = nn.Conv1d(blocks[-1], out_chans, kernel_size=1)
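+        # zero-init the output projection so the condition branch contributes nothing
+        # at the start of training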
+        nn.init.zeros_(self.conv_out.weight)
+        nn.init.zeros_(self.conv_out.bias)
+
+    def random_masking(self, gt, mask_ratios, mae_mask_infer=None):
+        B, D, L = gt.shape
+        if mae_mask_infer is None:
+            # mask = torch.rand(B, L).to(gt.device) < mask_ratios.unsqueeze(1)
+            mask_ratios = mask_ratios.cpu().numpy()
+            mask = compute_mask_indices(shape=[B, L],
+                                        padding_mask=None,
+                                        mask_prob=mask_ratios,
+                                        mask_length=self.mask_span,
+                                        mask_type="static",
+                                        mask_other=0.0,
+                                        min_masks=1,
+                                        no_overlap=False,
+                                        min_space=0,)
+            # only apply mask to some batches
+            mask_batch = torch.rand(B) < self.mask_prob
+            mask[~mask_batch] = False
+            mask = mask.unsqueeze(1).expand_as(gt)
+        else:
+            mask = mae_mask_infer
+            mask = mask.expand_as(gt)
+        gt[mask] = self.mask_embed.view(1, D, 1).expand_as(gt)[mask].type_as(gt)
+        return gt, mask.type_as(gt)
+
+    def forward(self, conditioning, cond_mask_infer=None):
+        embedding = self.conv_in(conditioning)
+
+        if self.cond_mask:
+            B, D, L = embedding.shape
+            if not self.training and cond_mask_infer is None:
+                cond_mask_infer = torch.zeros_like(embedding).bool()
+            mask_ratios = torch.FloatTensor(B).uniform_(*self.mask_ratio).to(embedding.device)
+            embedding, cond_mask = self.random_masking(embedding, mask_ratios, cond_mask_infer)
+            embedding = torch.cat([embedding, cond_mask[:, 0:1, :]], dim=1)
+
+        for block in self.blocks:
+            embedding = block(embedding)
+
+        embedding = self.conv_out(embedding)
+
+        # B, L, C
+        embedding = embedding.transpose(1, 2).contiguous()
+
+        return embedding
+
+
+class DiTControlNet(nn.Module):
+    def __init__(self,
+                 img_size=(224, 224), patch_size=16, in_chans=3,
+                 input_type='2d', out_chans=None,
+                 embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.,
+                 qkv_bias=False, qk_scale=None, qk_norm=None,
+                 act_layer='gelu', norm_layer='layernorm',
+                 context_norm=False,
+                 use_checkpoint=False,
+                 # time fusion ada or token
+                 time_fusion='token',
+                 ada_lora_rank=None, ada_lora_alpha=None,
+                 cls_dim=None,
+                 # max length is only used for concat
+                 context_dim=768, context_fusion='concat',
+                 context_max_length=128, context_pe_method='sinu',
+                 pe_method='abs', rope_mode='none',
+                 use_conv=True,
+                 skip=True, skip_norm=True,
+                 # controlnet configs
+                 cond_in=None, cond_blocks=None, 
+                 cond_mask=False, cond_mask_prob=None,
+                 cond_mask_ratio=None, cond_mask_span=None,
+                 **kwargs):
+        super().__init__()
+        self.num_features = self.embed_dim = embed_dim
+        # input
+        self.in_chans = in_chans
+        self.input_type = input_type
+        if self.input_type == '2d':
+            num_patches = (img_size[0] // patch_size) * (img_size[1] // patch_size)
+        elif self.input_type == '1d':
+            num_patches = img_size // patch_size
+        self.patch_embed = PatchEmbed(patch_size=patch_size, in_chans=in_chans,
+                                      embed_dim=embed_dim, input_type=input_type)
+        out_chans = in_chans if out_chans is None else out_chans
+        self.out_chans = out_chans
+
+        # position embedding
+        self.rope = rope_mode
+        self.x_pe = PE_wrapper(dim=embed_dim, method=pe_method,
+                               length=num_patches)
+
+        print(f'x position embedding: {pe_method}')
+        print(f'rope mode: {self.rope}')
+
+        # time embed
+        self.time_embed = TimestepEmbedder(embed_dim)
+        self.time_fusion = time_fusion
+        self.use_adanorm = False
+
+        # cls embed
+        if cls_dim is not None:
+            self.cls_embed = nn.Sequential(
+                nn.Linear(cls_dim, embed_dim, bias=True),
+                nn.SiLU(),
+                nn.Linear(embed_dim, embed_dim, bias=True),)
+        else:
+            self.cls_embed = None
+
+        # time fusion
+        if time_fusion == 'token':
+            # put token at the beginning of sequence
+            self.extras = 2 if self.cls_embed else 1
+            self.time_pe = PE_wrapper(dim=embed_dim, method='abs', length=self.extras)
+        elif time_fusion in ['ada', 'ada_single', 'ada_lora', 'ada_lora_bias']:
+            self.use_adanorm = True
+            # avoid repetitive SiLU for each adaLN block
+            self.time_act = nn.SiLU()
+            self.extras = 0
+            if time_fusion in ['ada_single', 'ada_lora', 'ada_lora_bias']:
+                # shared adaln
+                self.time_ada = nn.Linear(embed_dim, 6 * embed_dim, bias=True)
+            else:
+                self.time_ada = None
+        else:
+            raise NotImplementedError
+        print(f'time fusion mode: {self.time_fusion}')
+
+        # context
+        # use a simple projection
+        self.use_context = False
+        self.context_cross = False
+        self.context_max_length = context_max_length
+        self.context_fusion = 'none'
+        if context_dim is not None:
+            self.use_context = True
+            self.context_embed = nn.Sequential(
+                nn.Linear(context_dim, embed_dim, bias=True),
+                nn.SiLU(),
+                nn.Linear(embed_dim, embed_dim, bias=True),)
+            self.context_fusion = context_fusion
+            if context_fusion == 'concat' or context_fusion == 'joint':
+                self.extras += context_max_length
+                self.context_pe = PE_wrapper(dim=embed_dim,
+                                             method=context_pe_method,
+                                             length=context_max_length)
+                # no cross attention layers
+                context_dim = None
+            elif context_fusion == 'cross':
+                self.context_pe = PE_wrapper(dim=embed_dim,
+                                             method=context_pe_method,
+                                             length=context_max_length)
+                self.context_cross = True
+                context_dim = embed_dim
+            else:
+                raise NotImplementedError
+        print(f'context fusion mode: {context_fusion}')
+        print(f'context position embedding: {context_pe_method}')
+
+        if self.context_fusion == 'joint':
+            Block = JointDiTBlock
+        else:
+            Block = DiTBlock
+
+        # norm layers
+        if norm_layer == 'layernorm':
+            norm_layer = nn.LayerNorm
+        elif norm_layer == 'rmsnorm':
+            norm_layer = RMSNorm
+        else:
+            raise NotImplementedError
+
+        self.in_blocks = nn.ModuleList([
+            Block(
+                dim=embed_dim, context_dim=context_dim, num_heads=num_heads,
+                mlp_ratio=mlp_ratio,
+                qkv_bias=qkv_bias, qk_scale=qk_scale, qk_norm=qk_norm,
+                act_layer=act_layer, norm_layer=norm_layer,
+                time_fusion=time_fusion,
+                ada_lora_rank=ada_lora_rank, ada_lora_alpha=ada_lora_alpha,
+                skip=False, skip_norm=False,
+                rope_mode=self.rope,
+                context_norm=context_norm,
+                use_checkpoint=use_checkpoint)
+            for _ in range(depth // 2)])
+
+        self.controlnet_pre = DiTControlNetEmbed(in_chans=cond_in, out_chans=embed_dim,
+                                                 blocks=cond_blocks,
+                                                 cond_mask=cond_mask, 
+                                                 cond_mask_prob=cond_mask_prob,
+                                                 cond_mask_ratio=cond_mask_ratio,
+                                                 cond_mask_span=cond_mask_span)
+
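+        # one zero-initialized linear per encoder block; their outputs are the skip
+        # features added inside the main UDiT decoder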
+        controlnet_zero_blocks = []
+        for i in range(depth // 2):
+            block = nn.Linear(embed_dim, embed_dim)
+            nn.init.zeros_(block.weight)
+            nn.init.zeros_(block.bias)
+            controlnet_zero_blocks.append(block)
+        self.controlnet_zero_blocks = nn.ModuleList(controlnet_zero_blocks)
+
+        print('ControlNet ready \n')
+
+    def set_trainable(self):
+        for param in self.parameters():
+            param.requires_grad = False
+
+        # only train input_proj, blocks, and output_proj
+        for module_name in ['controlnet_pre', 'in_blocks', 'controlnet_zero_blocks']:
+            module = getattr(self, module_name, None)
+            if module is not None:
+                for param in module.parameters():
+                    param.requires_grad = True
+                module.train()
+            else:
+                print(f'\n!!!warning missing trainable blocks: {module_name}!!!\n')
+
+    def _concat_x_context(self, x, context, x_mask=None, context_mask=None):
+        # same context/latent concatenation helper as UDiT; forward() below
+        # relies on it when context_fusion is 'concat' or 'joint'
+        assert context.shape[-2] == self.context_max_length
+        B = x.shape[0]
+        if x_mask is None:
+            x_mask = torch.ones(B, x.shape[-2], device=x.device).bool()
+        if context_mask is None:
+            context_mask = torch.ones(B, context.shape[-2],
+                                      device=context.device).bool()
+        x_mask = torch.cat([context_mask, x_mask], dim=1)
+        x = torch.cat((context, x), dim=1)
+        return x, x_mask
+
+    def forward(self, x, timesteps, context,
+                x_mask=None, context_mask=None,
+                cls_token=None,
+                condition=None, cond_mask_infer=None,
+                conditioning_scale=1.0):
+        # make it compatible with int time step during inference
+        if timesteps.dim() == 0:
+            timesteps = timesteps.expand(x.shape[0]).to(x.device, dtype=torch.long)
+
+        x = self.patch_embed(x)
+        # add condition to x
+        condition = self.controlnet_pre(condition)
+        x = x + condition
+        x = self.x_pe(x)
+
+        B, L, D = x.shape
+
+        if self.use_context:
+            context_token = self.context_embed(context)
+            context_token = self.context_pe(context_token)
+            if self.context_fusion == 'concat' or self.context_fusion == 'joint':
+                x, x_mask = self._concat_x_context(x=x, context=context_token,
+                                                   x_mask=x_mask,
+                                                   context_mask=context_mask)
+                context_token, context_mask = None, None
+        else:
+            context_token, context_mask = None, None
+
+        time_token = self.time_embed(timesteps)
+        if self.cls_embed:
+            cls_token = self.cls_embed(cls_token)
+        time_ada = None
+        if self.use_adanorm:
+            if self.cls_embed:
+                time_token = time_token + cls_token
+            time_token = self.time_act(time_token)
+            if self.time_ada is not None:
+                time_ada = self.time_ada(time_token)
+        else:
+            time_token = time_token.unsqueeze(dim=1)
+            if self.cls_embed:
+                cls_token = cls_token.unsqueeze(dim=1)
+                time_token = torch.cat([time_token, cls_token], dim=1)
+            time_token = self.time_pe(time_token)
+            x = torch.cat((time_token, x), dim=1)
+            if x_mask is not None:
+                x_mask = torch.cat(
+                    [torch.ones(B, time_token.shape[1], device=x_mask.device).bool(),
+                     x_mask], dim=1)
+            time_token = None
+
+        skips = []
+        for blk in self.in_blocks:
+            x = blk(x=x, time_token=time_token, time_ada=time_ada,
+                    skip=None, context=context_token,
+                    x_mask=x_mask, context_mask=context_mask,
+                    extras=self.extras)
+            skips.append(x)
+
+        controlnet_skips = []
+        for skip, controlnet_block in zip(skips, self.controlnet_zero_blocks):
+            controlnet_skips.append(controlnet_block(skip) * conditioning_scale)
+
+        return controlnet_skips
\ No newline at end of file
diff --git a/src/models/.ipynb_checkpoints/udit-checkpoint.py b/src/models/.ipynb_checkpoints/udit-checkpoint.py
new file mode 100644
index 0000000000000000000000000000000000000000..e126efd370efabbfcc4f4359194f9c95c6e9d154
--- /dev/null
+++ b/src/models/.ipynb_checkpoints/udit-checkpoint.py
@@ -0,0 +1,365 @@
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint
+import math
+from .utils.modules import PatchEmbed, TimestepEmbedder
+from .utils.modules import PE_wrapper, RMSNorm
+from .blocks import DiTBlock, JointDiTBlock, FinalBlock
+
+
+class UDiT(nn.Module):
+    def __init__(self,
+                 img_size=224, patch_size=16, in_chans=3,
+                 input_type='2d', out_chans=None,
+                 embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.,
+                 qkv_bias=False, qk_scale=None, qk_norm=None,
+                 act_layer='gelu', norm_layer='layernorm',
+                 context_norm=False,
+                 use_checkpoint=False,
+                 # time fusion ada or token
+                 time_fusion='token',
+                 ada_lora_rank=None, ada_lora_alpha=None,
+                 cls_dim=None,
+                 # max length is only used for concat
+                 context_dim=768, context_fusion='concat',
+                 context_max_length=128, context_pe_method='sinu',
+                 pe_method='abs', rope_mode='none',
+                 use_conv=True,
+                 skip=True, skip_norm=True):
+        super().__init__()
+        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
+
+        # input
+        self.in_chans = in_chans
+        self.input_type = input_type
+        if self.input_type == '2d':
+            num_patches = (img_size[0] // patch_size) * (img_size[1] // patch_size)
+        elif self.input_type == '1d':
+            num_patches = img_size // patch_size
+        self.patch_embed = PatchEmbed(patch_size=patch_size, in_chans=in_chans,
+                                      embed_dim=embed_dim, input_type=input_type)
+        out_chans = in_chans if out_chans is None else out_chans
+        self.out_chans = out_chans
+
+        # position embedding
+        self.rope = rope_mode
+        self.x_pe = PE_wrapper(dim=embed_dim, method=pe_method,
+                               length=num_patches)
+
+        print(f'x position embedding: {pe_method}')
+        print(f'rope mode: {self.rope}')
+
+        # time embed
+        self.time_embed = TimestepEmbedder(embed_dim)
+        self.time_fusion = time_fusion
+        self.use_adanorm = False
+
+        # cls embed
+        if cls_dim is not None:
+            self.cls_embed = nn.Sequential(
+                nn.Linear(cls_dim, embed_dim, bias=True),
+                nn.SiLU(),
+                nn.Linear(embed_dim, embed_dim, bias=True),)
+        else:
+            self.cls_embed = None
+
+        # time fusion
+        if time_fusion == 'token':
+            # put token at the beginning of sequence
+            self.extras = 2 if self.cls_embed else 1
+            self.time_pe = PE_wrapper(dim=embed_dim, method='abs', length=self.extras)
+        elif time_fusion in ['ada', 'ada_single', 'ada_lora', 'ada_lora_bias']:
+            self.use_adanorm = True
+            # avoid repetitive SiLU for each adaLN block
+            self.time_act = nn.SiLU()
+            self.extras = 0
+            self.time_ada_final = nn.Linear(embed_dim, 2 * embed_dim, bias=True)
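+            # the final block only needs (shift, scale), hence 2 * embed_dim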
+            if time_fusion in ['ada_single', 'ada_lora', 'ada_lora_bias']:
+                # shared adaln
+                self.time_ada = nn.Linear(embed_dim, 6 * embed_dim, bias=True)
+            else:
+                self.time_ada = None
+        else:
+            raise NotImplementedError
+        print(f'time fusion mode: {self.time_fusion}')
+
+        # context
+        # use a simple projection
+        self.use_context = False
+        self.context_cross = False
+        self.context_max_length = context_max_length
+        self.context_fusion = 'none'
+        if context_dim is not None:
+            self.use_context = True
+            self.context_embed = nn.Sequential(
+                nn.Linear(context_dim, embed_dim, bias=True),
+                nn.SiLU(),
+                nn.Linear(embed_dim, embed_dim, bias=True),)
+            self.context_fusion = context_fusion
+            if context_fusion == 'concat' or context_fusion == 'joint':
+                self.extras += context_max_length
+                self.context_pe = PE_wrapper(dim=embed_dim,
+                                             method=context_pe_method,
+                                             length=context_max_length)
+                # no cross attention layers
+                context_dim = None
+            elif context_fusion == 'cross':
+                self.context_pe = PE_wrapper(dim=embed_dim,
+                                             method=context_pe_method,
+                                             length=context_max_length)
+                self.context_cross = True
+                context_dim = embed_dim
+            else:
+                raise NotImplementedError
+        print(f'context fusion mode: {context_fusion}')
+        print(f'context position embedding: {context_pe_method}')
+
+        if self.context_fusion == 'joint':
+            Block = JointDiTBlock
+            self.use_skip = skip[0]
+        else:
+            Block = DiTBlock
+            self.use_skip = skip
+
+        # norm layers
+        if norm_layer == 'layernorm':
+            norm_layer = nn.LayerNorm
+        elif norm_layer == 'rmsnorm':
+            norm_layer = RMSNorm
+        else:
+            raise NotImplementedError
+
+        print(f'use long skip connection: {skip}')
+        self.in_blocks = nn.ModuleList([
+            Block(
+                dim=embed_dim, context_dim=context_dim, num_heads=num_heads,
+                mlp_ratio=mlp_ratio,
+                qkv_bias=qkv_bias, qk_scale=qk_scale, qk_norm=qk_norm,
+                act_layer=act_layer, norm_layer=norm_layer,
+                time_fusion=time_fusion,
+                ada_lora_rank=ada_lora_rank, ada_lora_alpha=ada_lora_alpha,
+                skip=False, skip_norm=False,
+                rope_mode=self.rope,
+                context_norm=context_norm,
+                use_checkpoint=use_checkpoint)
+            for _ in range(depth // 2)])
+
+        self.mid_block = Block(
+            dim=embed_dim, context_dim=context_dim, num_heads=num_heads, 
+            mlp_ratio=mlp_ratio,
+            qkv_bias=qkv_bias, qk_scale=qk_scale, qk_norm=qk_norm,
+            act_layer=act_layer, norm_layer=norm_layer,
+            time_fusion=time_fusion,
+            ada_lora_rank=ada_lora_rank, ada_lora_alpha=ada_lora_alpha,
+            skip=False, skip_norm=False,
+            rope_mode=self.rope,
+            context_norm=context_norm,
+            use_checkpoint=use_checkpoint)
+
+        self.out_blocks = nn.ModuleList([
+            Block(
+                dim=embed_dim, context_dim=context_dim, num_heads=num_heads, 
+                mlp_ratio=mlp_ratio,
+                qkv_bias=qkv_bias, qk_scale=qk_scale, qk_norm=qk_norm,
+                act_layer=act_layer, norm_layer=norm_layer,
+                time_fusion=time_fusion,
+                ada_lora_rank=ada_lora_rank, ada_lora_alpha=ada_lora_alpha,
+                skip=skip, skip_norm=skip_norm,
+                rope_mode=self.rope,
+                context_norm=context_norm,
+                use_checkpoint=use_checkpoint)
+            for _ in range(depth // 2)])
+
+        # FinalLayer block
+        self.use_conv = use_conv
+        self.final_block = FinalBlock(embed_dim=embed_dim,
+                                      patch_size=patch_size,
+                                      img_size=img_size,
+                                      in_chans=out_chans,
+                                      input_type=input_type,
+                                      norm_layer=norm_layer,
+                                      use_conv=use_conv,
+                                      use_adanorm=self.use_adanorm)
+        self.initialize_weights()
+
+    def _init_ada(self):
+        if self.time_fusion == 'ada':
+            nn.init.constant_(self.time_ada_final.weight, 0)
+            nn.init.constant_(self.time_ada_final.bias, 0)
+            for block in self.in_blocks:
+                nn.init.constant_(block.adaln.time_ada.weight, 0)
+                nn.init.constant_(block.adaln.time_ada.bias, 0)
+            nn.init.constant_(self.mid_block.adaln.time_ada.weight, 0)
+            nn.init.constant_(self.mid_block.adaln.time_ada.bias, 0)
+            for block in self.out_blocks:
+                nn.init.constant_(block.adaln.time_ada.weight, 0)
+                nn.init.constant_(block.adaln.time_ada.bias, 0)
+        elif self.time_fusion == 'ada_single':
+            nn.init.constant_(self.time_ada.weight, 0)
+            nn.init.constant_(self.time_ada.bias, 0)
+            nn.init.constant_(self.time_ada_final.weight, 0)
+            nn.init.constant_(self.time_ada_final.bias, 0)
+        elif self.time_fusion in ['ada_lora', 'ada_lora_bias']:
+            nn.init.constant_(self.time_ada.weight, 0)
+            nn.init.constant_(self.time_ada.bias, 0)
+            nn.init.constant_(self.time_ada_final.weight, 0)
+            nn.init.constant_(self.time_ada_final.bias, 0)
+            for block in self.in_blocks:
+                nn.init.kaiming_uniform_(block.adaln.lora_a.weight,
+                                         a=math.sqrt(5))
+                nn.init.constant_(block.adaln.lora_b.weight, 0)
+            nn.init.kaiming_uniform_(self.mid_block.adaln.lora_a.weight,
+                                     a=math.sqrt(5))
+            nn.init.constant_(self.mid_block.adaln.lora_b.weight, 0)
+            for block in self.out_blocks:
+                nn.init.kaiming_uniform_(block.adaln.lora_a.weight,
+                                         a=math.sqrt(5))
+                nn.init.constant_(block.adaln.lora_b.weight, 0)
+
+    def initialize_weights(self):
+        # Basic init for all layers
+        def _basic_init(module):
+            if isinstance(module, nn.Linear):
+                torch.nn.init.xavier_uniform_(module.weight)
+                if module.bias is not None:
+                    nn.init.constant_(module.bias, 0)
+        self.apply(_basic_init)
+
+        # init patch Conv like Linear
+        w = self.patch_embed.proj.weight.data
+        nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
+        nn.init.constant_(self.patch_embed.proj.bias, 0)
+
+        # Zero-out AdaLN
+        if self.use_adanorm:
+            self._init_ada()
+
+        # Zero-out Cross Attention
+        if self.context_cross:
+            for block in self.in_blocks:
+                nn.init.constant_(block.cross_attn.proj.weight, 0)
+                nn.init.constant_(block.cross_attn.proj.bias, 0)
+            nn.init.constant_(self.mid_block.cross_attn.proj.weight, 0)
+            nn.init.constant_(self.mid_block.cross_attn.proj.bias, 0)
+            for block in self.out_blocks:
+                nn.init.constant_(block.cross_attn.proj.weight, 0)
+                nn.init.constant_(block.cross_attn.proj.bias, 0)
+
+        # Zero-out cls embedding
+        if self.cls_embed:
+            if self.use_adanorm:
+                nn.init.constant_(self.cls_embed[-1].weight, 0)
+                nn.init.constant_(self.cls_embed[-1].bias, 0)
+
+        # Zero-out Output
+        # might not zero-out this when using v-prediction
+        # it could be good when using noise-prediction
+        # nn.init.constant_(self.final_block.linear.weight, 0)
+        # nn.init.constant_(self.final_block.linear.bias, 0)
+        # if self.use_conv:
+        #     nn.init.constant_(self.final_block.final_layer.weight.data, 0)
+        #     nn.init.constant_(self.final_block.final_layer.bias, 0)
+
+        # init out Conv
+        if self.use_conv:
+            nn.init.xavier_uniform_(self.final_block.final_layer.weight)
+            nn.init.constant_(self.final_block.final_layer.bias, 0)
+
+    def _concat_x_context(self, x, context, x_mask=None, context_mask=None):
+        assert context.shape[-2] == self.context_max_length
+        # Check if either x_mask or context_mask is provided
+        B = x.shape[0]
+        # Create default masks if they are not provided
+        if x_mask is None:
+            x_mask = torch.ones(B, x.shape[-2], device=x.device).bool()
+        if context_mask is None:
+            context_mask = torch.ones(B, context.shape[-2],
+                                      device=context.device).bool()
+        # Concatenate the masks along the second dimension (dim=1)
+        x_mask = torch.cat([context_mask, x_mask], dim=1)
+        # Concatenate context and x along the second dimension (dim=1)
+        x = torch.cat((context, x), dim=1)
+        return x, x_mask
+
+    def forward(self, x, timesteps, context,
+                x_mask=None, context_mask=None,
+                cls_token=None, controlnet_skips=None,
+               ):
+        # make it compatible with int time step during inference
+        if timesteps.dim() == 0:
+            timesteps = timesteps.expand(x.shape[0]).to(x.device, dtype=torch.long)
+
+        x = self.patch_embed(x)
+        x = self.x_pe(x)
+
+        B, L, D = x.shape
+
+        if self.use_context:
+            context_token = self.context_embed(context)
+            context_token = self.context_pe(context_token)
+            if self.context_fusion == 'concat' or self.context_fusion == 'joint':
+                x, x_mask = self._concat_x_context(x=x, context=context_token,
+                                                   x_mask=x_mask,
+                                                   context_mask=context_mask)
+                context_token, context_mask = None, None
+        else:
+            context_token, context_mask = None, None
+
+        time_token = self.time_embed(timesteps)
+        if self.cls_embed:
+            cls_token = self.cls_embed(cls_token)
+        time_ada = None
+        time_ada_final = None
+        if self.use_adanorm:
+            if self.cls_embed:
+                time_token = time_token + cls_token
+            time_token = self.time_act(time_token)
+            time_ada_final = self.time_ada_final(time_token)
+            if self.time_ada is not None:
+                time_ada = self.time_ada(time_token)
+        else:
+            time_token = time_token.unsqueeze(dim=1)
+            if self.cls_embed:
+                cls_token = cls_token.unsqueeze(dim=1)
+                time_token = torch.cat([time_token, cls_token], dim=1)
+            time_token = self.time_pe(time_token)
+            x = torch.cat((time_token, x), dim=1)
+            if x_mask is not None:
+                x_mask = torch.cat(
+                    [torch.ones(B, time_token.shape[1], device=x_mask.device).bool(),
+                     x_mask], dim=1)
+            time_token = None
+
+        skips = []
+        for blk in self.in_blocks:
+            x = blk(x=x, time_token=time_token, time_ada=time_ada,
+                    skip=None, context=context_token,
+                    x_mask=x_mask, context_mask=context_mask,
+                    extras=self.extras)
+            if self.use_skip:
+                skips.append(x)
+
+        x = self.mid_block(x=x, time_token=time_token, time_ada=time_ada,
+                           skip=None, context=context_token,
+                           x_mask=x_mask, context_mask=context_mask,
+                           extras=self.extras)
+        for blk in self.out_blocks:
+            if self.use_skip:
+                skip = skips.pop()
+                if controlnet_skips:
+                    # add to skip like u-net controlnet
+                    skip = skip + controlnet_skips.pop()
+            else:
+                skip = None
+                if controlnet_skips:
+                    # directly add to x
+                    x = x + controlnet_skips.pop()
+
+            x = blk(x=x, time_token=time_token, time_ada=time_ada,
+                    skip=skip, context=context_token,
+                    x_mask=x_mask, context_mask=context_mask,
+                    extras=self.extras)
+
+        x = self.final_block(x, time_ada=time_ada_final, extras=self.extras)
+
+        return x
\ No newline at end of file
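
For reference, a minimal standalone sketch of the sequence/mask concatenation performed by `_concat_x_context` above; shapes and mask values are illustrative assumptions, not taken from any config in this diff.

```python
# Sketch of the context/target token fusion in UDiT._concat_x_context.
import torch

B, Lx, Lc, D = 2, 100, 77, 768                      # batch, target len, context len, width
x = torch.randn(B, Lx, D)                           # patchified + position-encoded target tokens
context = torch.randn(B, Lc, D)                     # embedded context tokens

x_mask = torch.ones(B, Lx, dtype=torch.bool)        # default: all target positions valid
context_mask = torch.ones(B, Lc, dtype=torch.bool)
context_mask[:, 60:] = False                        # e.g. padded context positions

joint = torch.cat((context, x), dim=1)              # (B, Lc + Lx, D), context tokens first
joint_mask = torch.cat([context_mask, x_mask], dim=1)
print(joint.shape, joint_mask.shape)                # torch.Size([2, 177, 768]) torch.Size([2, 177])
```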
diff --git a/src/models/blocks.py b/src/models/blocks.py
new file mode 100644
index 0000000000000000000000000000000000000000..8ef730009cb7cf664c5d9021e551b275680d11f3
--- /dev/null
+++ b/src/models/blocks.py
@@ -0,0 +1,325 @@
+import torch
+import torch.nn as nn
+from torch.utils.checkpoint import checkpoint
+from .utils.attention import Attention, JointAttention
+from .utils.modules import unpatchify, FeedForward
+from .utils.modules import film_modulate
+
+
+class AdaLN(nn.Module):
+    def __init__(self, dim, ada_mode='ada', r=None, alpha=None):
+        super().__init__()
+        self.ada_mode = ada_mode
+        self.scale_shift_table = None
+        if ada_mode == 'ada':
+            # note: the SiLU nonlinearity is applied outside this module
+            self.time_ada = nn.Linear(dim, 6 * dim, bias=True)
+        elif ada_mode == 'ada_single':
+            # adaLN-single, as used in PixArt-alpha
+            self.scale_shift_table = nn.Parameter(torch.zeros(6, dim))
+        elif ada_mode in ['ada_lora', 'ada_lora_bias']:
+            self.lora_a = nn.Linear(dim, r * 6, bias=False)
+            self.lora_b = nn.Linear(r * 6, dim * 6, bias=False)
+            self.scaling = alpha / r
+            if ada_mode == 'ada_lora_bias':
+                # take bias out for consistency
+                self.scale_shift_table = nn.Parameter(torch.zeros(6, dim))
+        else:
+            raise NotImplementedError
+
+    def forward(self, time_token=None, time_ada=None):
+        if self.ada_mode == 'ada':
+            assert time_ada is None
+            B = time_token.shape[0]
+            time_ada = self.time_ada(time_token).reshape(B, 6, -1)
+        elif self.ada_mode == 'ada_single':
+            B = time_ada.shape[0]
+            time_ada = time_ada.reshape(B, 6, -1)
+            time_ada = self.scale_shift_table[None] + time_ada
+        elif self.ada_mode in ['ada_lora', 'ada_lora_bias']:
+            B = time_ada.shape[0]
+            time_ada_lora = self.lora_b(self.lora_a(time_token)) * self.scaling
+            time_ada = time_ada + time_ada_lora
+            time_ada = time_ada.reshape(B, 6, -1)
+            if self.scale_shift_table is not None:
+                time_ada = self.scale_shift_table[None] + time_ada
+        else:
+            raise NotImplementedError
+        return time_ada
+
+
+class DiTBlock(nn.Module):
+    """
+    A modified PixArt block with adaptive layer norm (adaLN-single) conditioning.
+    """
+
+    def __init__(self, dim, context_dim=None,
+                 num_heads=8, mlp_ratio=4.,
+                 qkv_bias=False, qk_scale=None, qk_norm=None,
+                 act_layer='gelu', norm_layer=nn.LayerNorm,
+                 time_fusion='none',
+                 ada_lora_rank=None, ada_lora_alpha=None,
+                 skip=False, skip_norm=False,
+                 rope_mode='none',
+                 context_norm=False,
+                 use_checkpoint=False):
+
+        super().__init__()
+        self.norm1 = norm_layer(dim)
+        self.attn = Attention(dim=dim,
+                              num_heads=num_heads,
+                              qkv_bias=qkv_bias, qk_scale=qk_scale,
+                              qk_norm=qk_norm,
+                              rope_mode=rope_mode)
+
+        if context_dim is not None:
+            self.use_context = True
+            self.cross_attn = Attention(dim=dim,
+                                        num_heads=num_heads,
+                                        context_dim=context_dim,
+                                        qkv_bias=qkv_bias, qk_scale=qk_scale,
+                                        qk_norm=qk_norm,
+                                        rope_mode='none')
+            self.norm2 = norm_layer(dim)
+            if context_norm:
+                self.norm_context = norm_layer(context_dim)
+            else:
+                self.norm_context = nn.Identity()
+        else:
+            self.use_context = False
+
+        self.norm3 = norm_layer(dim)
+        self.mlp = FeedForward(dim=dim, mult=mlp_ratio,
+                               activation_fn=act_layer, dropout=0)
+
+        self.use_adanorm = time_fusion != 'token'
+        if self.use_adanorm:
+            self.adaln = AdaLN(dim, ada_mode=time_fusion,
+                               r=ada_lora_rank, alpha=ada_lora_alpha)
+        if skip:
+            self.skip_norm = norm_layer(2 * dim) if skip_norm else nn.Identity()
+            self.skip_linear = nn.Linear(2 * dim, dim)
+        else:
+            self.skip_linear = None
+            
+        self.use_checkpoint = use_checkpoint
+
+    def forward(self, x, time_token=None, time_ada=None,
+                skip=None, context=None,
+                x_mask=None, context_mask=None, extras=None):
+        if self.use_checkpoint:
+            return checkpoint(self._forward, x,
+                              time_token, time_ada, skip, context,
+                              x_mask, context_mask, extras,
+                              use_reentrant=False)
+        else:
+            return self._forward(x,
+                                 time_token, time_ada, skip, context,
+                                 x_mask, context_mask, extras)
+
+    def _forward(self, x, time_token=None, time_ada=None,
+                 skip=None, context=None,
+                 x_mask=None, context_mask=None, extras=None):
+        B, T, C = x.shape
+        if self.skip_linear is not None:
+            assert skip is not None
+            cat = torch.cat([x, skip], dim=-1)
+            cat = self.skip_norm(cat)
+            x = self.skip_linear(cat)
+
+        if self.use_adanorm:
+            time_ada = self.adaln(time_token, time_ada)
+            (shift_msa, scale_msa, gate_msa,
+             shift_mlp, scale_mlp, gate_mlp) = time_ada.chunk(6, dim=1)
+
+        # self attention
+        if self.use_adanorm:
+            x_norm = film_modulate(self.norm1(x), shift=shift_msa,
+                                   scale=scale_msa)
+            x = x + (1 - gate_msa) * self.attn(x_norm, context=None,
+                                               context_mask=x_mask,
+                                               extras=extras)
+        else:
+            x = x + self.attn(self.norm1(x), context=None, context_mask=x_mask,
+                              extras=extras)
+
+        # cross attention
+        if self.use_context:
+            assert context is not None
+            x = x + self.cross_attn(x=self.norm2(x),
+                                    context=self.norm_context(context),
+                                    context_mask=context_mask, extras=extras)
+
+        # mlp
+        if self.use_adanorm:
+            x_norm = film_modulate(self.norm3(x), shift=shift_mlp, scale=scale_mlp)
+            x = x + (1 - gate_mlp) * self.mlp(x_norm)
+        else:
+            x = x + self.mlp(self.norm3(x))
+
+        return x
+
+
+class JointDiTBlock(nn.Module):
+    """
+    A DiT block that applies joint attention over the concatenated context and
+    target tokens, with adaptive layer norm (adaLN-single) conditioning.
+    """
+
+    def __init__(self, dim, context_dim=None,
+                 num_heads=8, mlp_ratio=4.,
+                 qkv_bias=False, qk_scale=None, qk_norm=None,
+                 act_layer='gelu', norm_layer=nn.LayerNorm,
+                 time_fusion='none',
+                 ada_lora_rank=None, ada_lora_alpha=None,
+                 skip=(False, False),
+                 rope_mode=False,
+                 context_norm=False,
+                 use_checkpoint=False,):
+
+        super().__init__()
+        # no cross attention
+        assert context_dim is None
+        self.attn_norm_x = norm_layer(dim)
+        self.attn_norm_c = norm_layer(dim)
+        self.attn = JointAttention(dim=dim,
+                                   num_heads=num_heads,
+                                   qkv_bias=qkv_bias, qk_scale=qk_scale,
+                                   qk_norm=qk_norm,
+                                   rope_mode=rope_mode)
+        self.ffn_norm_x = norm_layer(dim)
+        self.ffn_norm_c = norm_layer(dim)
+        self.mlp_x = FeedForward(dim=dim, mult=mlp_ratio,
+                                 activation_fn=act_layer, dropout=0)
+        self.mlp_c = FeedForward(dim=dim, mult=mlp_ratio,
+                                 activation_fn=act_layer, dropout=0)
+
+        # time-adaptive layer norm (the shift table inside AdaLN is zero-initialized)
+        self.use_adanorm = time_fusion != 'token'
+        if self.use_adanorm:
+            self.adaln = AdaLN(dim, ada_mode=time_fusion,
+                               r=ada_lora_rank, alpha=ada_lora_alpha)
+
+        if skip is False:
+            skip_x, skip_c = False, False
+        else:
+            skip_x, skip_c = skip
+
+        self.skip_linear_x = nn.Linear(2 * dim, dim) if skip_x else None
+        self.skip_linear_c = nn.Linear(2 * dim, dim) if skip_c else None
+
+        self.use_checkpoint = use_checkpoint
+
+    def forward(self, x, time_token=None, time_ada=None,
+                skip=None, context=None,
+                x_mask=None, context_mask=None, extras=None):
+        if self.use_checkpoint:
+            return checkpoint(self._forward, x,
+                              time_token, time_ada, skip,
+                              context, x_mask, context_mask, extras,
+                              use_reentrant=False)
+        else:
+            return self._forward(x,
+                                 time_token, time_ada, skip,
+                                 context, x_mask, context_mask, extras)
+
+    def _forward(self, x, time_token=None, time_ada=None,
+                 skip=None, context=None,
+                 x_mask=None, context_mask=None, extras=None):
+
+        assert context is None and context_mask is None
+
+        context, x = x[:, :extras, :], x[:, extras:, :]
+        context_mask, x_mask = x_mask[:, :extras], x_mask[:, extras:]
+
+        if skip is not None:
+            skip_c, skip_x = skip[:, :extras, :], skip[:, extras:, :]
+
+        B, T, C = x.shape
+        if self.skip_linear_x is not None:
+            x = self.skip_linear_x(torch.cat([x, skip_x], dim=-1))
+
+        if self.skip_linear_c is not None:
+            context = self.skip_linear_c(torch.cat([context, skip_c], dim=-1))
+
+        if self.use_adanorm:
+            time_ada = self.adaln(time_token, time_ada)
+            (shift_msa, scale_msa, gate_msa,
+             shift_mlp, scale_mlp, gate_mlp) = time_ada.chunk(6, dim=1)
+
+        # self attention
+        x_norm = self.attn_norm_x(x)
+        c_norm = self.attn_norm_c(context)
+        if self.use_adanorm:
+            x_norm = film_modulate(x_norm, shift=shift_msa, scale=scale_msa)
+        x_out, c_out = self.attn(x_norm, context=c_norm,
+                                 x_mask=x_mask, context_mask=context_mask,
+                                 extras=extras)
+        if self.use_adanorm:
+            x = x + (1 - gate_msa) * x_out
+        else:
+            x = x + x_out
+        context = context + c_out
+
+        # mlp
+        if self.use_adanorm:
+            x_norm = film_modulate(self.ffn_norm_x(x),
+                                   shift=shift_mlp, scale=scale_mlp)
+            x = x + (1 - gate_mlp) * self.mlp_x(x_norm)
+        else:
+            x = x + self.mlp_x(self.ffn_norm_x(x))
+
+        c_norm = self.ffn_norm_c(context)
+        context = context + self.mlp_c(c_norm)
+
+        return torch.cat((context, x), dim=1)
+
+
+class FinalBlock(nn.Module):
+    def __init__(self, embed_dim, patch_size, in_chans,
+                 img_size,
+                 input_type='2d',
+                 norm_layer=nn.LayerNorm,
+                 use_conv=True,
+                 use_adanorm=True):
+        super().__init__()
+        self.in_chans = in_chans
+        self.img_size = img_size
+        self.input_type = input_type
+
+        self.norm = norm_layer(embed_dim)
+        self.use_adanorm = use_adanorm
+
+        if input_type == '2d':
+            self.patch_dim = patch_size ** 2 * in_chans
+            self.linear = nn.Linear(embed_dim, self.patch_dim, bias=True)
+            if use_conv:
+                self.final_layer = nn.Conv2d(self.in_chans, self.in_chans, 
+                                             3, padding=1)
+            else:
+                self.final_layer = nn.Identity()
+
+        elif input_type == '1d':
+            self.patch_dim = patch_size * in_chans
+            self.linear = nn.Linear(embed_dim, self.patch_dim, bias=True)
+            if use_conv:
+                self.final_layer = nn.Conv1d(self.in_chans, self.in_chans, 
+                                             3, padding=1)
+            else:
+                self.final_layer = nn.Identity()
+
+    def forward(self, x, time_ada=None, extras=0):
+        B, T, C = x.shape
+        x = x[:, extras:, :]
+        # only handle generation target
+        if self.use_adanorm:
+            shift, scale = time_ada.reshape(B, 2, -1).chunk(2, dim=1)
+            x = film_modulate(self.norm(x), shift, scale)
+        else:
+            x = self.norm(x)
+        x = self.linear(x)
+        x = unpatchify(x, self.in_chans, self.input_type, self.img_size)
+        x = self.final_layer(x)
+        return x
\ No newline at end of file
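
A minimal sketch of the adaLN "film + gate" pattern that `DiTBlock` applies around its attention and MLP branches. `film_modulate` lives in `.utils.modules` and is not shown in this diff, so the sketch assumes the common `x * (1 + scale) + shift` convention; `branch` is a placeholder for the attention or MLP sub-layer.

```python
# Sketch of the adaLN modulation used in DiTBlock ('ada' mode), under the assumptions above.
import torch
import torch.nn as nn

def film_modulate(x, shift, scale):
    # assumed FiLM-style modulation
    return x * (1 + scale) + shift

B, T, D = 2, 50, 256
x = torch.randn(B, T, D)                 # token sequence
time_token = torch.randn(B, D)           # SiLU-activated time embedding

adaln = nn.Linear(D, 6 * D)              # 'ada' mode: projects to 6 * dim
norm = nn.LayerNorm(D)
branch = nn.Linear(D, D)                 # stand-in for the attention / MLP branch

(shift_msa, scale_msa, gate_msa,
 shift_mlp, scale_mlp, gate_mlp) = adaln(time_token).reshape(B, 6, -1).chunk(6, dim=1)

x_norm = film_modulate(norm(x), shift=shift_msa, scale=scale_msa)
x = x + (1 - gate_msa) * branch(x_norm)  # note the (1 - gate) residual weighting
```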
diff --git a/src/models/conditioners.py b/src/models/conditioners.py
new file mode 100644
index 0000000000000000000000000000000000000000..cade7febf61ef005f421c42cf17bb1bb2935a751
--- /dev/null
+++ b/src/models/conditioners.py
@@ -0,0 +1,183 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import repeat
+import math
+from .udit import UDiT
+from .utils.span_mask import compute_mask_indices
+
+
+class EmbeddingCFG(nn.Module):
+    """
+    Handles label dropout for classifier-free guidance.
+    """
+    # todo: support 2D input
+
+    def __init__(self, in_channels):
+        super().__init__()
+        self.cfg_embedding = nn.Parameter(
+            torch.randn(in_channels) / in_channels ** 0.5)
+
+    def token_drop(self, condition, condition_mask, cfg_prob):
+        """
+        Drops labels to enable classifier-free guidance.
+        """
+        b, t, device = condition.shape[0], condition.shape[1], condition.device
+        drop_ids = torch.rand(b, device=device) < cfg_prob
+        uncond = repeat(self.cfg_embedding, "c -> b t c", b=b, t=t)
+        condition = torch.where(drop_ids[:, None, None], uncond, condition)
+        if condition_mask is not None:
+            condition_mask[drop_ids] = False
+            condition_mask[drop_ids, 0] = True
+
+        return condition, condition_mask
+
+    def forward(self, condition, condition_mask, cfg_prob=0.0):
+        if condition_mask is not None:
+            condition_mask = condition_mask.clone()
+        if cfg_prob > 0:
+            condition, condition_mask = self.token_drop(condition,
+                                                        condition_mask,
+                                                        cfg_prob)
+        return condition, condition_mask
+
+
+class DiscreteCFG(nn.Module):
+    def __init__(self, replace_id=2):
+        super(DiscreteCFG, self).__init__()
+        self.replace_id = replace_id
+
+    def forward(self, context, context_mask, cfg_prob):
+        context = context.clone()
+        if context_mask is not None:
+            context_mask = context_mask.clone()
+        if cfg_prob > 0:
+            cfg_mask = torch.rand(len(context)) < cfg_prob
+            if torch.any(cfg_mask):
+                context[cfg_mask] = 0
+                context[cfg_mask, 0] = self.replace_id
+                if context_mask is not None:
+                    context_mask[cfg_mask] = False
+                    context_mask[cfg_mask, 0] = True
+        return context, context_mask
+
+
+class CFGModel(nn.Module):
+    def __init__(self, context_dim, backbone):
+        super().__init__()
+        self.model = backbone
+        self.context_cfg = EmbeddingCFG(context_dim)
+
+    def forward(self, x, timesteps,
+                context, x_mask=None, context_mask=None,
+                cfg_prob=0.0):
+        context, context_mask = self.context_cfg(context, context_mask,
+                                                 cfg_prob=cfg_prob)
+        x = self.model(x=x, timesteps=timesteps,
+                       context=context,
+                       x_mask=x_mask, context_mask=context_mask)
+        return x
+
+
+class ConcatModel(nn.Module):
+    def __init__(self, backbone, in_dim, stride=[]):
+        super().__init__()
+        self.model = backbone
+
+        self.downsample_layers = nn.ModuleList()
+        for i, s in enumerate(stride):
+            downsample_layer = nn.Conv1d(
+                in_dim,
+                in_dim * 2,
+                kernel_size=2 * s,
+                stride=s,
+                padding=math.ceil(s / 2),
+            )
+            self.downsample_layers.append(downsample_layer)
+            in_dim = in_dim * 2
+
+        self.context_cfg = EmbeddingCFG(in_dim)
+
+    def forward(self, x, timesteps,
+                context, x_mask=None,
+                cfg=False, cfg_prob=0.0):
+
+        # todo: support 2D input
+        # x: B, C, L
+        # context: B, C, L
+
+        for downsample_layer in self.downsample_layers:
+            context = downsample_layer(context)
+
+        context = context.transpose(1, 2)
+        context, _ = self.context_cfg(context, None,
+                                      cfg_prob=cfg_prob if cfg else 0.0)
+        context = context.transpose(1, 2)
+
+        assert context.shape[-1] == x.shape[-1]
+        x = torch.cat([context, x], dim=1)
+        x = self.model(x=x, timesteps=timesteps,
+                       context=None, x_mask=x_mask, context_mask=None)
+        return x
+
+
+class MaskDiT(nn.Module):
+    def __init__(self, mae=False, mae_prob=0.5, mask_ratio=(0.25, 1.0), mask_span=10, **kwargs):
+        super().__init__()
+        self.model = UDiT(**kwargs)
+        self.mae = mae
+        if self.mae:
+            out_channel = kwargs.pop('out_chans', None)
+            self.mask_embed = nn.Parameter(torch.zeros((out_channel)))
+            self.mae_prob = mae_prob
+            self.mask_ratio = mask_ratio
+            self.mask_span = mask_span
+
+    def random_masking(self, gt, mask_ratios, mae_mask_infer=None):
+        B, D, L = gt.shape
+        if mae_mask_infer is None:
+            # mask = torch.rand(B, L).to(gt.device) < mask_ratios.unsqueeze(1)
+            mask_ratios = mask_ratios.cpu().numpy()
+            mask = compute_mask_indices(shape=[B, L],
+                                        padding_mask=None,
+                                        mask_prob=mask_ratios,
+                                        mask_length=self.mask_span,
+                                        mask_type="static",
+                                        mask_other=0.0,
+                                        min_masks=1,
+                                        no_overlap=False,
+                                        min_space=0,)
+            mask = mask.unsqueeze(1).expand_as(gt)
+        else:
+            mask = mae_mask_infer
+            mask = mask.expand_as(gt)
+        gt[mask] = self.mask_embed.view(1, D, 1).expand_as(gt)[mask]
+        return gt, mask.type_as(gt)
+
+    def forward(self, x, timesteps, context,
+                x_mask=None, context_mask=None, cls_token=None,
+                gt=None, mae_mask_infer=None,
+                forward_model=True):
+        # todo: handle controlnet inside
+        mae_mask = torch.ones_like(x)
+        if self.mae:
+            if gt is not None:
+                B, D, L = gt.shape
+                mask_ratios = torch.FloatTensor(B).uniform_(*self.mask_ratio).to(gt.device)
+                gt, mae_mask = self.random_masking(gt, mask_ratios, mae_mask_infer)
+                # apply mae only to the selected batches
+                if mae_mask_infer is None:
+                    # determine mae batch
+                    mae_batch = torch.rand(B) < self.mae_prob
+                    gt[~mae_batch] = self.mask_embed.view(1, D, 1).expand_as(gt)[~mae_batch]
+                    mae_mask[~mae_batch] = 1.0
+            else:
+                B, D, L = x.shape
+                gt = self.mask_embed.view(1, D, 1).expand_as(x)
+            x = torch.cat([x, gt, mae_mask[:, 0:1, :]], dim=1)
+
+        if forward_model:
+            x = self.model(x=x, timesteps=timesteps, context=context,
+                           x_mask=x_mask, context_mask=context_mask,
+                           cls_token=cls_token)
+            # print(mae_mask[:, 0, :].sum(dim=-1))
+        return x, mae_mask
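
A standalone sketch of the classifier-free-guidance label dropout performed by `EmbeddingCFG.token_drop` above. `cfg_embedding` stands in for the learned unconditional embedding (a random placeholder here), and shapes are illustrative.

```python
# Sketch of EmbeddingCFG-style label dropout for classifier-free guidance.
import torch
from einops import repeat

B, T, C = 4, 77, 512
condition = torch.randn(B, T, C)
condition_mask = torch.ones(B, T, dtype=torch.bool)
cfg_embedding = torch.randn(C) / C ** 0.5      # learned nn.Parameter in the model
cfg_prob = 0.5

drop_ids = torch.rand(B) < cfg_prob            # which batch items become unconditional
uncond = repeat(cfg_embedding, "c -> b t c", b=B, t=T)
condition = torch.where(drop_ids[:, None, None], uncond, condition)
condition_mask[drop_ids] = False               # mask out dropped conditions...
condition_mask[drop_ids, 0] = True             # ...but keep one valid position each
```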
diff --git a/src/models/conditions/__init__.py b/src/models/conditions/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..0a60c5319ac77a8cdebd2835527256b547101700
--- /dev/null
+++ b/src/models/conditions/__init__.py
@@ -0,0 +1 @@
+from .condition_wrapper import Conditioner
\ No newline at end of file
diff --git a/src/models/conditions/chroma.py b/src/models/conditions/chroma.py
new file mode 100644
index 0000000000000000000000000000000000000000..a07aebb898f6d8048099a22f4f9b4d8f1a3117fd
--- /dev/null
+++ b/src/models/conditions/chroma.py
@@ -0,0 +1,80 @@
+import typing as tp
+
+from einops import rearrange
+from librosa import filters
+import torch
+from torch import nn
+import torch.nn.functional as F
+import torchaudio
+
+
+class ChromaExtractor(nn.Module):
+    """Chroma extraction and quantization.
+
+    Args:
+        sample_rate (int): Sample rate for the chroma extraction.
+        n_chroma (int): Number of chroma bins for the chroma extraction.
+        radix2_exp (int): Size of stft window for the chroma extraction (power of 2, e.g. 12 -> 2^12).
+        nfft (int, optional): Number of FFT.
+        winlen (int, optional): Window length.
+        winhop (int, optional): Window hop size.
+        argmax (bool, optional): Whether to quantize the chroma to a one-hot
+            vector at the argmax bin. Defaults to True.
+        norm (float, optional): Norm for chroma normalization. Defaults to inf.
+    """
+    
+    def __init__(self, 
+                 sample_rate: int, 
+                 n_chroma: int = 12, radix2_exp: int = 12,
+                 nfft: tp.Optional[int] = None,
+                 winlen: tp.Optional[int] = None,
+                 winhop: tp.Optional[int] = None, argmax: bool = True,
+                 norm: float = torch.inf):
+        super().__init__()
+        self.winlen = winlen or 2 ** radix2_exp
+        self.nfft = nfft or self.winlen
+        self.winhop = winhop or (self.winlen // 4)
+        self.sample_rate = sample_rate
+        self.n_chroma = n_chroma
+        self.norm = norm
+        self.argmax = argmax
+        self.register_buffer('fbanks', torch.from_numpy(filters.chroma(sr=sample_rate, n_fft=self.nfft, tuning=0,
+                                                                       n_chroma=self.n_chroma)), persistent=False)
+        self.spec = torchaudio.transforms.Spectrogram(n_fft=self.nfft, win_length=self.winlen,
+                                                      hop_length=self.winhop, power=2, center=False,
+                                                      pad=0, normalized=True)
+
+    def forward(self, wav: torch.Tensor) -> torch.Tensor:
+        T = wav.shape[-1]
+        # in case we are getting a wav that was dropped out (nullified)
+        # by the conditioner, make sure the wav length is no less than nfft
+        if T < self.nfft:
+            pad = self.nfft - T
+            r = 0 if pad % 2 == 0 else 1
+            wav = F.pad(wav, (pad // 2, pad // 2 + r), 'constant', 0)
+            assert wav.shape[-1] == self.nfft, f"expected len {self.nfft} but got {wav.shape[-1]}"
+
+        wav = F.pad(wav, (int(self.nfft // 2 - self.winhop // 2 ),
+                          int(self.nfft // 2 - self.winhop // 2 )), mode="reflect")
+
+        spec = self.spec(wav).squeeze(1)
+        raw_chroma = torch.einsum('cf,...ft->...ct', self.fbanks, spec)
+        norm_chroma = torch.nn.functional.normalize(raw_chroma, p=self.norm, dim=-2, eps=1e-6)
+        norm_chroma = rearrange(norm_chroma, 'b d t -> b t d')
+
+        if self.argmax:
+            idx = norm_chroma.argmax(-1, keepdim=True)
+            norm_chroma[:] = 0
+            norm_chroma.scatter_(dim=-1, index=idx, value=1)
+
+        return norm_chroma
+
+
+if __name__ == "__main__":
+    chroma = ChromaExtractor(sample_rate=16000,
+                             n_chroma=4,
+                             radix2_exp=None,
+                             winlen=16000,
+                             nfft=16000,
+                             winhop=4000)
+    audio = torch.rand(1, 16000)
+    c = chroma(audio)
\ No newline at end of file
diff --git a/src/models/conditions/condition_wrapper.py b/src/models/conditions/condition_wrapper.py
new file mode 100644
index 0000000000000000000000000000000000000000..20e44795abc16c3b9a9a30c8f579767879c25fe7
--- /dev/null
+++ b/src/models/conditions/condition_wrapper.py
@@ -0,0 +1,50 @@
+import torch
+import torch.nn as nn
+from .chroma import ChromaExtractor
+from .energy import EnergyExtractor
+from .voice import VoiceConversionExtractor
+from .mbenergy import MultibandEnergyExtractor
+
+
+class Conditioner(nn.Module):
+    def __init__(self,
+                 condition_type,
+                 **kwargs
+                ):
+        super().__init__()
+        if condition_type == 'energy':
+            self.conditioner = EnergyExtractor(**kwargs)
+        elif condition_type == 'chroma':
+            self.conditioner = ChromaExtractor(**kwargs)
+        elif condition_type == 'vc':
+            self.conditioner = VoiceConversionExtractor(**kwargs)
+        elif condition_type == 'mb_energy':
+            self.conditioner = MultibandEnergyExtractor(**kwargs)
+        else:
+            raise NotImplementedError
+
+    def forward(self, waveform, latent_shape):
+        # conditioner output: (B, T, C)
+        condition = self.conditioner(waveform)
+        # rearrange to (B, C, T)
+        condition = condition.permute(0, 2, 1).contiguous()
+
+        if len(latent_shape) == 4:
+            # 2d spectrogram latent: (B, C, T, F)
+            assert (condition.shape[-1] % latent_shape[-2]) == 0
+            X = latent_shape[-1] * condition.shape[-1] // latent_shape[-2]
+            # replicate the condition along the frequency (F) axis
+            condition = condition.unsqueeze(-1).expand(-1, -1, -1, X)
+        elif len(latent_shape) == 3:
+            pass  # 1d latent: use the frame-level condition as-is
+        else:
+            raise NotImplementedError
+        return condition
+
+
+if __name__ == '__main__':
+    conditioner = Conditioner(condition_type='energy',
+                              hop_size=160, window_size=1024, padding='reflect',
+                              min_db=-80, norm=True)
+    audio = torch.rand(4, 16000)  # Example audio signal
+    energy = conditioner(audio, (4, 8, 100, 64))
\ No newline at end of file
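
A shape-level sketch of how `Conditioner.forward` broadcasts a frame-level condition onto a 2-D (spectrogram-like) latent. The numbers mirror the `__main__` example above and are otherwise arbitrary.

```python
# Sketch of the condition-to-2D-latent broadcast in Conditioner.forward.
import torch

B = 4
condition = torch.rand(B, 1, 100)      # (B, C_cond, T_cond), e.g. per-frame energy
latent_shape = (B, 8, 100, 64)         # (B, C, T, F) spectrogram-like latent

assert condition.shape[-1] % latent_shape[-2] == 0
X = latent_shape[-1] * condition.shape[-1] // latent_shape[-2]   # = 64 here
condition_2d = condition.unsqueeze(-1).expand(-1, -1, -1, X)
print(condition_2d.shape)              # torch.Size([4, 1, 100, 64])
```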
diff --git a/src/models/conditions/debug.png b/src/models/conditions/debug.png
new file mode 100644
index 0000000000000000000000000000000000000000..999c9593896a2890de740ca5e3f878828c18d3df
Binary files /dev/null and b/src/models/conditions/debug.png differ
diff --git a/src/models/conditions/eg1.wav b/src/models/conditions/eg1.wav
new file mode 100644
index 0000000000000000000000000000000000000000..6eb190a33a2d35c169811fa764df280fa2e906fb
Binary files /dev/null and b/src/models/conditions/eg1.wav differ
diff --git a/src/models/conditions/eg2.wav b/src/models/conditions/eg2.wav
new file mode 100644
index 0000000000000000000000000000000000000000..281bda1f4c65d48de53f5a21aaa8e7f185cd914b
Binary files /dev/null and b/src/models/conditions/eg2.wav differ
diff --git a/src/models/conditions/energy.py b/src/models/conditions/energy.py
new file mode 100644
index 0000000000000000000000000000000000000000..f86481723850af419fa701254863789108596e88
--- /dev/null
+++ b/src/models/conditions/energy.py
@@ -0,0 +1,85 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+
+
+class EnergyExtractor(nn.Module):
+    def __init__(self, hop_size: int = 512, window_size: int = 1024,
+                 padding: str = 'reflect', min_db: float = -60,
+                 norm: bool = True, quantize_levels: int = None):
+        super().__init__()
+        self.hop_size = hop_size
+        self.window_size = window_size
+        self.padding = padding
+        self.min_db = min_db
+        self.norm = norm
+        self.quantize_levels = quantize_levels
+
+    def forward(self, audio: torch.Tensor) -> torch.Tensor:
+        # Compute number of frames
+        n_frames = int(audio.size(-1) // self.hop_size)
+
+        # Pad the audio signal
+        pad_amount = (self.window_size - self.hop_size) // 2
+        audio_padded = F.pad(audio, (pad_amount, pad_amount), mode=self.padding)
+
+        # Square the padded audio signal
+        audio_squared = audio_padded ** 2
+
+        # Compute the mean energy for each frame using unfold and mean
+        audio_squared = audio_squared[:, None, None, :]
+        energy = F.unfold(audio_squared, (1, self.window_size), stride=self.hop_size)[:, :, :n_frames]
+        energy = energy.mean(dim=1)
+
+        # Compute the square root of the mean energy to get the RMS energy
+        # energy = torch.sqrt(energy)
+
+        # Normalize the energy using the min_db value
+        gain = torch.maximum(energy, torch.tensor(np.power(10, self.min_db / 10), device=audio.device))
+        gain_db = 10 * torch.log10(gain)
+
+        if self.norm:
+            # Find the min and max of gain_db
+            # min_gain_db = torch.min(gain_db)
+            min_gain_db = self.min_db
+            max_gain_db = torch.max(gain_db, dim=-1, keepdim=True)[0]
+
+            # Avoid numerical error by adding a small epsilon to the denominator
+            epsilon = 1e-8
+            gain_db = (gain_db - min_gain_db) / (max_gain_db - min_gain_db + epsilon)
+
+        if self.quantize_levels is not None:
+            # Quantize the result to the given number of levels
+            gain_db = torch.round(gain_db * (self.quantize_levels - 1)) / (self.quantize_levels - 1)
+
+        return gain_db.unsqueeze(-1)
+
+
+if __name__ == "__main__":
+    energy_extractor = EnergyExtractor(hop_size=512, window_size=1024, padding='reflect', 
+                                       min_db=-60, norm=True)
+    audio = torch.rand(1, 16000)
+    energy = energy_extractor(audio)
+    print(energy.shape)
+    import librosa
+    import matplotlib.pyplot as plt
+    # a1, _ = librosa.load('eg1.wav', sr=16000)
+    # a2, _ = librosa.load('eg2.wav', sr=16000)
+    # audio = torch.tensor([a1[:5*16000], a2[:5*16000]])
+    a1, _ = librosa.load('eg2.wav', sr=24000)
+    audio = torch.tensor(a1[:5*16000]).unsqueeze(0)
+    energy = energy_extractor(audio)
+    print(energy.shape)
+
+    # Plot the energy for each audio sample
+    plt.figure(figsize=(12, 6))
+
+    for i in range(energy.shape[0]):
+        plt.plot(energy[i, :, 0].cpu().numpy(), label=f'Audio {i+1}')
+
+    plt.xlabel('Frame')
+    plt.ylabel('Energy (dB)')
+    plt.title('Energy over Time')
+    plt.legend()
+    plt.savefig('debug.png')
\ No newline at end of file
diff --git a/src/models/conditions/mbenergy.py b/src/models/conditions/mbenergy.py
new file mode 100644
index 0000000000000000000000000000000000000000..39dedf081f61f15a5e49a922a1863b560888aa6c
--- /dev/null
+++ b/src/models/conditions/mbenergy.py
@@ -0,0 +1,99 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+import julius
+import soundfile as sf
+
+
+class MultibandEnergyExtractor(nn.Module):
+    def __init__(self, hop_size: int = 512, window_size: int = 1024,
+                 padding: str = 'reflect', min_db: float = -60,
+                 norm: bool = True, quantize_levels: int = None,
+                 n_bands: int = 8, control_bands: int = 4,
+                 sample_rate: int = 24000,):
+        super().__init__()
+        self.hop_size = hop_size
+        self.window_size = window_size
+        self.padding = padding
+        self.min_db = min_db
+        self.norm = norm
+        self.quantize_levels = quantize_levels
+        self.n_bands = n_bands
+        self.control_bands = control_bands
+        self.sample_rate = sample_rate
+
+    def forward(self, audio: torch.Tensor) -> torch.Tensor:
+        # Split the audio into frequency bands
+        audio = julius.split_bands(audio, n_bands=self.n_bands,
+                                   sample_rate=self.sample_rate)[:self.control_bands].transpose(0, 1)
+        # debug helper (commented out): dump each band to disk for inspection
+        # B, C, _ = audio.shape
+        # for i in range(C):
+        #     sf.write(f'output_{i}.wav', audio[0][i], self.sample_rate)
+
+        # Compute number of frames
+        n_frames = int(audio.size(-1) // self.hop_size)
+
+        # Pad the audio signal
+        pad_amount = (self.window_size - self.hop_size) // 2
+        audio_padded = F.pad(audio, (pad_amount, pad_amount), mode=self.padding)
+
+        # Square the padded audio signal
+        audio_squared = audio_padded ** 2
+
+        # Compute the mean energy for each frame using unfold and mean
+        energy = audio_squared.unfold(dimension=-1, size=self.window_size, step=self.hop_size)
+        energy = energy[:, :, :n_frames]
+        energy = energy.mean(dim=-1)
+
+        # Compute the square root of the mean energy to get the RMS energy
+        # energy = torch.sqrt(energy)
+
+        # Floor the energy at the min_db threshold before converting to dB
+        gain = torch.maximum(energy, torch.tensor(np.power(10, self.min_db / 10), device=audio.device))
+        gain_db = 10 * torch.log10(gain)
+
+        if self.norm:
+            # Normalize gain_db to [0, 1], using min_db as the floor and the per-example max as the ceiling
+            min_gain_db = self.min_db
+            max_gain_db = torch.amax(gain_db, dim=(-1, -2), keepdim=True)
+
+            # Avoid numerical error by adding a small epsilon to the denominator
+            epsilon = 1e-8
+            gain_db = (gain_db - min_gain_db) / (max_gain_db - min_gain_db + epsilon)
+
+        if self.quantize_levels is not None:
+            # Quantize the result to the given number of levels
+            gain_db = torch.round(gain_db * (self.quantize_levels - 1)) / (self.quantize_levels - 1)
+
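+        # Output shape: [B, n_frames, control_bands]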
+        return gain_db.transpose(-1, -2)
+
+
+if __name__ == "__main__":
+    energy_extractor = MultibandEnergyExtractor(hop_size=320, window_size=1280,
+                                                padding='reflect',
+                                                min_db=-60, norm=True)
+    audio = torch.rand(4, 24000)
+    energy = energy_extractor(audio)
+    print(energy.shape)
+    import librosa
+    import matplotlib.pyplot as plt
+    a1, _ = librosa.load('eg2.wav', sr=24000)
+    audio = torch.tensor(a1[:5 * 24000]).unsqueeze(0)  # first 5 seconds at 24 kHz
+    energy = energy_extractor(audio)
+    print(energy.shape)
+
+    # Plot the energy for each audio sample
+    plt.figure(figsize=(12, 6))
+
+    for i in range(energy.shape[-1]):
+        plt.plot(energy[0, :, i].cpu().numpy(), label=f'Band {i+1}')
+
+    plt.xlabel('Frame')
+    plt.ylabel('Energy (dB)')
+    plt.title('Energy over Time')
+    plt.legend()
+    plt.savefig('debug.png')
diff --git a/src/models/conditions/voice.py b/src/models/conditions/voice.py
new file mode 100644
index 0000000000000000000000000000000000000000..9739d69ee93d4b5b81c1e6bd83c94e48c1e5b2b3
--- /dev/null
+++ b/src/models/conditions/voice.py
@@ -0,0 +1,46 @@
+from transformers import HubertModel
+import torch.nn as nn
+import torch
+import torch.nn.functional as F
+import torchaudio
+import librosa
+
+
+class HubertModelWithFinalProj(HubertModel):
+    def __init__(self, config):
+        super().__init__(config)
+
+        # The final projection layer exists only for checkpoint compatibility
+        # (see https://github.com/auspicious3000/contentvec/issues/6); it is declared
+        # here so the pretrained weights load, but it is not applied in forward().
+        self.final_proj = nn.Linear(config.hidden_size, config.classifier_proj_size)
+
+
+class VoiceConversionExtractor(nn.Module):
+    # training on the fly might be slow
+    def __init__(self, config, sr):
+        super().__init__()
+        self.encoder = HubertModelWithFinalProj.from_pretrained(config)
+        self.encoder.eval()
+        self.sr = sr
+        self.target_sr = 16000
+        if self.sr != self.target_sr:
+            self.resampler = torchaudio.transforms.Resample(orig_freq=self.sr,
+                                                            new_freq=self.target_sr)
+
+    def forward(self, audio):
+        if self.sr != self.target_sr:
+            audio = self.resampler(audio)
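+        # HuBERT's convolutional front end has a ~400-sample receptive field with a 320-sample hop
+        # at 16 kHz, so padding (400 - 320) // 2 samples per side keeps frames aligned to 20 ms hops.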
+        audio = F.pad(audio, ((400 - 320) // 2, (400 - 320) // 2))
+        logits = self.encoder(audio)['last_hidden_state']
+        return logits
+
+
+if __name__ == '__main__':
+    model = VoiceConversionExtractor('lengyue233/content-vec-best', 24000)
+    audio, sr = librosa.load('test.wav', sr=24000)
+    audio = audio[:round(100*320*1.5)]
+    audio = torch.tensor([audio])
+    with torch.no_grad():
+        content = model(audio)
+    print(content.shape)
\ No newline at end of file
diff --git a/src/models/controlnet.py b/src/models/controlnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..1750621847ed116a6fbab55a50e67963699d6a5a
--- /dev/null
+++ b/src/models/controlnet.py
@@ -0,0 +1,318 @@
+import torch
+import torch.nn as nn
+
+from .utils.modules import PatchEmbed, TimestepEmbedder
+from .utils.modules import PE_wrapper, RMSNorm
+from .blocks import DiTBlock, JointDiTBlock
+from .utils.span_mask import compute_mask_indices
+
+
+class DiTControlNetEmbed(nn.Module):
+    def __init__(self, in_chans, out_chans, blocks,
+                 cond_mask=False, cond_mask_prob=None,
+                 cond_mask_ratio=None, cond_mask_span=None):
+        super().__init__()
+        self.conv_in = nn.Conv1d(in_chans, blocks[0], kernel_size=1)
+
+        self.cond_mask = cond_mask
+        if self.cond_mask:
+            self.mask_embed = nn.Parameter(torch.zeros((blocks[0])))
+            self.mask_prob = cond_mask_prob
+            self.mask_ratio = cond_mask_ratio
+            self.mask_span = cond_mask_span
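+            # The first conv block gets one extra input channel: forward() concatenates a
+            # binary mask-indicator channel to the embedding when cond_mask is enabled.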
+            blocks[0] = blocks[0] + 1
+
+        conv_blocks = []
+        for i in range(len(blocks) - 1):
+            channel_in = blocks[i]
+            channel_out = blocks[i + 1]
+            block = nn.Sequential(
+                nn.Conv1d(channel_in, channel_in, kernel_size=3, padding=1),
+                nn.SiLU(),
+                nn.Conv1d(channel_in, channel_out, kernel_size=3, padding=1, stride=2),
+                nn.SiLU(),)
+            conv_blocks.append(block)
+        self.blocks = nn.ModuleList(conv_blocks)
+
+        self.conv_out = nn.Conv1d(blocks[-1], out_chans, kernel_size=1)
+        nn.init.zeros_(self.conv_out.weight)
+        nn.init.zeros_(self.conv_out.bias)
+
+    def random_masking(self, gt, mask_ratios, mae_mask_infer=None):
+        B, D, L = gt.shape
+        if mae_mask_infer is None:
+            # mask = torch.rand(B, L).to(gt.device) < mask_ratios.unsqueeze(1)
+            mask_ratios = mask_ratios.cpu().numpy()
+            mask = compute_mask_indices(shape=[B, L],
+                                        padding_mask=None,
+                                        mask_prob=mask_ratios,
+                                        mask_length=self.mask_span,
+                                        mask_type="static",
+                                        mask_other=0.0,
+                                        min_masks=1,
+                                        no_overlap=False,
+                                        min_space=0,)
+            # only apply mask to some batches
+            mask_batch = torch.rand(B) < self.mask_prob
+            mask[~mask_batch] = False
+            mask = mask.unsqueeze(1).expand_as(gt)
+        else:
+            mask = mae_mask_infer
+            mask = mask.expand_as(gt)
+        gt[mask] = self.mask_embed.view(1, D, 1).expand_as(gt)[mask].type_as(gt)
+        return gt, mask.type_as(gt)
+
+    def forward(self, conditioning, cond_mask_infer=None):
+        embedding = self.conv_in(conditioning)
+
+        if self.cond_mask:
+            B, D, L = embedding.shape
+            if not self.training and cond_mask_infer is None:
+                cond_mask_infer = torch.zeros_like(embedding).bool()
+            mask_ratios = torch.FloatTensor(B).uniform_(*self.mask_ratio).to(embedding.device)
+            embedding, cond_mask = self.random_masking(embedding, mask_ratios, cond_mask_infer)
+            embedding = torch.cat([embedding, cond_mask[:, 0:1, :]], dim=1)
+
+        for block in self.blocks:
+            embedding = block(embedding)
+
+        embedding = self.conv_out(embedding)
+
+        # B, L, C
+        embedding = embedding.transpose(1, 2).contiguous()
+
+        return embedding
+
+
+class DiTControlNet(nn.Module):
+    def __init__(self,
+                 img_size=(224, 224), patch_size=16, in_chans=3,
+                 input_type='2d', out_chans=None,
+                 embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.,
+                 qkv_bias=False, qk_scale=None, qk_norm=None,
+                 act_layer='gelu', norm_layer='layernorm',
+                 context_norm=False,
+                 use_checkpoint=False,
+                 # time fusion ada or token
+                 time_fusion='token',
+                 ada_lora_rank=None, ada_lora_alpha=None,
+                 cls_dim=None,
+                 # max length is only used for concat
+                 context_dim=768, context_fusion='concat',
+                 context_max_length=128, context_pe_method='sinu',
+                 pe_method='abs', rope_mode='none',
+                 use_conv=True,
+                 skip=True, skip_norm=True,
+                 # controlnet configs
+                 cond_in=None, cond_blocks=None, 
+                 cond_mask=False, cond_mask_prob=None,
+                 cond_mask_ratio=None, cond_mask_span=None,
+                 **kwargs):
+        super().__init__()
+        self.num_features = self.embed_dim = embed_dim
+        # input
+        self.in_chans = in_chans
+        self.input_type = input_type
+        if self.input_type == '2d':
+            num_patches = (img_size[0] // patch_size) * (img_size[1] // patch_size)
+        elif self.input_type == '1d':
+            num_patches = img_size // patch_size
+        self.patch_embed = PatchEmbed(patch_size=patch_size, in_chans=in_chans,
+                                      embed_dim=embed_dim, input_type=input_type)
+        out_chans = in_chans if out_chans is None else out_chans
+        self.out_chans = out_chans
+
+        # position embedding
+        self.rope = rope_mode
+        self.x_pe = PE_wrapper(dim=embed_dim, method=pe_method,
+                               length=num_patches)
+
+        print(f'x position embedding: {pe_method}')
+        print(f'rope mode: {self.rope}')
+
+        # time embed
+        self.time_embed = TimestepEmbedder(embed_dim)
+        self.time_fusion = time_fusion
+        self.use_adanorm = False
+
+        # cls embed
+        if cls_dim is not None:
+            self.cls_embed = nn.Sequential(
+                nn.Linear(cls_dim, embed_dim, bias=True),
+                nn.SiLU(),
+                nn.Linear(embed_dim, embed_dim, bias=True),)
+        else:
+            self.cls_embed = None
+
+        # time fusion
+        if time_fusion == 'token':
+            # put token at the beginning of sequence
+            self.extras = 2 if self.cls_embed else 1
+            self.time_pe = PE_wrapper(dim=embed_dim, method='abs', length=self.extras)
+        elif time_fusion in ['ada', 'ada_single', 'ada_lora', 'ada_lora_bias']:
+            self.use_adanorm = True
+            # avoid repetitive SiLU for each AdaLN block
+            self.time_act = nn.SiLU()
+            self.extras = 0
+            if time_fusion in ['ada_single', 'ada_lora', 'ada_lora_bias']:
+                # shared adaln
+                self.time_ada = nn.Linear(embed_dim, 6 * embed_dim, bias=True)
+            else:
+                self.time_ada = None
+        else:
+            raise NotImplementedError
+        print(f'time fusion mode: {self.time_fusion}')
+
+        # context
+        # use a simple projection
+        self.use_context = False
+        self.context_cross = False
+        self.context_max_length = context_max_length
+        self.context_fusion = 'none'
+        if context_dim is not None:
+            self.use_context = True
+            self.context_embed = nn.Sequential(
+                nn.Linear(context_dim, embed_dim, bias=True),
+                nn.SiLU(),
+                nn.Linear(embed_dim, embed_dim, bias=True),)
+            self.context_fusion = context_fusion
+            if context_fusion == 'concat' or context_fusion == 'joint':
+                self.extras += context_max_length
+                self.context_pe = PE_wrapper(dim=embed_dim,
+                                             method=context_pe_method,
+                                             length=context_max_length)
+                # no cross attention layers
+                context_dim = None
+            elif context_fusion == 'cross':
+                self.context_pe = PE_wrapper(dim=embed_dim,
+                                             method=context_pe_method,
+                                             length=context_max_length)
+                self.context_cross = True
+                context_dim = embed_dim
+            else:
+                raise NotImplementedError
+        print(f'context fusion mode: {context_fusion}')
+        print(f'context position embedding: {context_pe_method}')
+
+        if self.context_fusion == 'joint':
+            Block = JointDiTBlock
+        else:
+            Block = DiTBlock
+
+        # norm layers
+        if norm_layer == 'layernorm':
+            norm_layer = nn.LayerNorm
+        elif norm_layer == 'rmsnorm':
+            norm_layer = RMSNorm
+        else:
+            raise NotImplementedError
+
+        self.in_blocks = nn.ModuleList([
+            Block(
+                dim=embed_dim, context_dim=context_dim, num_heads=num_heads,
+                mlp_ratio=mlp_ratio,
+                qkv_bias=qkv_bias, qk_scale=qk_scale, qk_norm=qk_norm,
+                act_layer=act_layer, norm_layer=norm_layer,
+                time_fusion=time_fusion,
+                ada_lora_rank=ada_lora_rank, ada_lora_alpha=ada_lora_alpha,
+                skip=False, skip_norm=False,
+                rope_mode=self.rope,
+                context_norm=context_norm,
+                use_checkpoint=use_checkpoint)
+            for _ in range(depth // 2)])
+
+        self.controlnet_pre = DiTControlNetEmbed(in_chans=cond_in, out_chans=embed_dim,
+                                                 blocks=cond_blocks,
+                                                 cond_mask=cond_mask, 
+                                                 cond_mask_prob=cond_mask_prob,
+                                                 cond_mask_ratio=cond_mask_ratio,
+                                                 cond_mask_span=cond_mask_span)
+
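+        # One zero-initialized linear projection per in_block, so the ControlNet branch
+        # contributes nothing at the start of training (analogous to ControlNet's zero convolutions).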
+        controlnet_zero_blocks = []
+        for i in range(depth // 2):
+            block = nn.Linear(embed_dim, embed_dim)
+            nn.init.zeros_(block.weight)
+            nn.init.zeros_(block.bias)
+            controlnet_zero_blocks.append(block)
+        self.controlnet_zero_blocks = nn.ModuleList(controlnet_zero_blocks)
+
+        print('ControlNet ready \n')
+
+    def set_trainable(self):
+        for param in self.parameters():
+            param.requires_grad = False
+
+        # only train the ControlNet-specific modules: controlnet_pre, in_blocks, controlnet_zero_blocks
+        for module_name in ['controlnet_pre', 'in_blocks', 'controlnet_zero_blocks']:
+            module = getattr(self, module_name, None)
+            if module is not None:
+                for param in module.parameters():
+                    param.requires_grad = True
+                module.train()
+            else:
+                print(f'\nWarning: missing trainable module: {module_name}\n')
+
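+    def _concat_x_context(self, x, context, x_mask=None, context_mask=None):
+        # Note: forward() below calls this helper for the 'concat'/'joint' fusion paths, but no
+        # definition was provided in this class; this sketch mirrors UDiT._concat_x_context.
+        assert context.shape[-2] == self.context_max_length
+        B = x.shape[0]
+        # Create default all-ones masks if they are not provided
+        if x_mask is None:
+            x_mask = torch.ones(B, x.shape[-2], device=x.device).bool()
+        if context_mask is None:
+            context_mask = torch.ones(B, context.shape[-2],
+                                      device=context.device).bool()
+        # Concatenate masks and tokens along the sequence dimension, context first
+        x_mask = torch.cat([context_mask, x_mask], dim=1)
+        x = torch.cat((context, x), dim=1)
+        return x, x_mask
+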
+    def forward(self, x, timesteps, context,
+                x_mask=None, context_mask=None,
+                cls_token=None,
+                condition=None, cond_mask_infer=None,
+                conditioning_scale=1.0):
+        # make it compatible with int time step during inference
+        if timesteps.dim() == 0:
+            timesteps = timesteps.expand(x.shape[0]).to(x.device, dtype=torch.long)
+
+        x = self.patch_embed(x)
+        # add condition to x
+        condition = self.controlnet_pre(condition)
+        x = x + condition
+        x = self.x_pe(x)
+
+        B, L, D = x.shape
+
+        if self.use_context:
+            context_token = self.context_embed(context)
+            context_token = self.context_pe(context_token)
+            if self.context_fusion == 'concat' or self.context_fusion == 'joint':
+                x, x_mask = self._concat_x_context(x=x, context=context_token,
+                                                   x_mask=x_mask,
+                                                   context_mask=context_mask)
+                context_token, context_mask = None, None
+        else:
+            context_token, context_mask = None, None
+
+        time_token = self.time_embed(timesteps)
+        if self.cls_embed:
+            cls_token = self.cls_embed(cls_token)
+        time_ada = None
+        if self.use_adanorm:
+            if self.cls_embed:
+                time_token = time_token + cls_token
+            time_token = self.time_act(time_token)
+            if self.time_ada is not None:
+                time_ada = self.time_ada(time_token)
+        else:
+            time_token = time_token.unsqueeze(dim=1)
+            if self.cls_embed:
+                cls_token = cls_token.unsqueeze(dim=1)
+                time_token = torch.cat([time_token, cls_token], dim=1)
+            time_token = self.time_pe(time_token)
+            x = torch.cat((time_token, x), dim=1)
+            if x_mask is not None:
+                x_mask = torch.cat(
+                    [torch.ones(B, time_token.shape[1], device=x_mask.device).bool(),
+                     x_mask], dim=1)
+            time_token = None
+
+        skips = []
+        for blk in self.in_blocks:
+            x = blk(x=x, time_token=time_token, time_ada=time_ada,
+                    skip=None, context=context_token,
+                    x_mask=x_mask, context_mask=context_mask,
+                    extras=self.extras)
+            skips.append(x)
+
+        controlnet_skips = []
+        for skip, controlnet_block in zip(skips, self.controlnet_zero_blocks):
+            controlnet_skips.append(controlnet_block(skip) * conditioning_scale)
+
+        return controlnet_skips
\ No newline at end of file
diff --git a/src/models/udit.py b/src/models/udit.py
new file mode 100644
index 0000000000000000000000000000000000000000..e126efd370efabbfcc4f4359194f9c95c6e9d154
--- /dev/null
+++ b/src/models/udit.py
@@ -0,0 +1,365 @@
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint
+import math
+from .utils.modules import PatchEmbed, TimestepEmbedder
+from .utils.modules import PE_wrapper, RMSNorm
+from .blocks import DiTBlock, JointDiTBlock, FinalBlock
+
+
+class UDiT(nn.Module):
+    def __init__(self,
+                 img_size=224, patch_size=16, in_chans=3,
+                 input_type='2d', out_chans=None,
+                 embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.,
+                 qkv_bias=False, qk_scale=None, qk_norm=None,
+                 act_layer='gelu', norm_layer='layernorm',
+                 context_norm=False,
+                 use_checkpoint=False,
+                 # time fusion ada or token
+                 time_fusion='token',
+                 ada_lora_rank=None, ada_lora_alpha=None,
+                 cls_dim=None,
+                 # max length is only used for concat
+                 context_dim=768, context_fusion='concat',
+                 context_max_length=128, context_pe_method='sinu',
+                 pe_method='abs', rope_mode='none',
+                 use_conv=True,
+                 skip=True, skip_norm=True):
+        super().__init__()
+        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
+
+        # input
+        self.in_chans = in_chans
+        self.input_type = input_type
+        if self.input_type == '2d':
+            num_patches = (img_size[0] // patch_size) * (img_size[1] // patch_size)
+        elif self.input_type == '1d':
+            num_patches = img_size // patch_size
+        self.patch_embed = PatchEmbed(patch_size=patch_size, in_chans=in_chans,
+                                      embed_dim=embed_dim, input_type=input_type)
+        out_chans = in_chans if out_chans is None else out_chans
+        self.out_chans = out_chans
+
+        # position embedding
+        self.rope = rope_mode
+        self.x_pe = PE_wrapper(dim=embed_dim, method=pe_method,
+                               length=num_patches)
+
+        print(f'x position embedding: {pe_method}')
+        print(f'rope mode: {self.rope}')
+
+        # time embed
+        self.time_embed = TimestepEmbedder(embed_dim)
+        self.time_fusion = time_fusion
+        self.use_adanorm = False
+
+        # cls embed
+        if cls_dim is not None:
+            self.cls_embed = nn.Sequential(
+                nn.Linear(cls_dim, embed_dim, bias=True),
+                nn.SiLU(),
+                nn.Linear(embed_dim, embed_dim, bias=True),)
+        else:
+            self.cls_embed = None
+
+        # time fusion
+        if time_fusion == 'token':
+            # put token at the beginning of sequence
+            self.extras = 2 if self.cls_embed else 1
+            self.time_pe = PE_wrapper(dim=embed_dim, method='abs', length=self.extras)
+        elif time_fusion in ['ada', 'ada_single', 'ada_lora', 'ada_lora_bias']:
+            self.use_adanorm = True
+            # avoid repetitive SiLU for each AdaLN block
+            self.time_act = nn.SiLU()
+            self.extras = 0
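+            # The final output block gets its own AdaLN shift/scale (2 * embed_dim) from the time embedding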
+            self.time_ada_final = nn.Linear(embed_dim, 2 * embed_dim, bias=True)
+            if time_fusion in ['ada_single', 'ada_lora', 'ada_lora_bias']:
+                # shared adaln
+                self.time_ada = nn.Linear(embed_dim, 6 * embed_dim, bias=True)
+            else:
+                self.time_ada = None
+        else:
+            raise NotImplementedError
+        print(f'time fusion mode: {self.time_fusion}')
+
+        # context
+        # use a simple projection
+        self.use_context = False
+        self.context_cross = False
+        self.context_max_length = context_max_length
+        self.context_fusion = 'none'
+        if context_dim is not None:
+            self.use_context = True
+            self.context_embed = nn.Sequential(
+                nn.Linear(context_dim, embed_dim, bias=True),
+                nn.SiLU(),
+                nn.Linear(embed_dim, embed_dim, bias=True),)
+            self.context_fusion = context_fusion
+            if context_fusion == 'concat' or context_fusion == 'joint':
+                self.extras += context_max_length
+                self.context_pe = PE_wrapper(dim=embed_dim,
+                                             method=context_pe_method,
+                                             length=context_max_length)
+                # no cross attention layers
+                context_dim = None
+            elif context_fusion == 'cross':
+                self.context_pe = PE_wrapper(dim=embed_dim,
+                                             method=context_pe_method,
+                                             length=context_max_length)
+                self.context_cross = True
+                context_dim = embed_dim
+            else:
+                raise NotImplementedError
+        print(f'context fusion mode: {context_fusion}')
+        print(f'context position embedding: {context_pe_method}')
+
+        if self.context_fusion == 'joint':
+            Block = JointDiTBlock
+            self.use_skip = skip[0]
+        else:
+            Block = DiTBlock
+            self.use_skip = skip
+
+        # norm layers
+        if norm_layer == 'layernorm':
+            norm_layer = nn.LayerNorm
+        elif norm_layer == 'rmsnorm':
+            norm_layer = RMSNorm
+        else:
+            raise NotImplementedError
+
+        print(f'use long skip connection: {skip}')
+        self.in_blocks = nn.ModuleList([
+            Block(
+                dim=embed_dim, context_dim=context_dim, num_heads=num_heads,
+                mlp_ratio=mlp_ratio,
+                qkv_bias=qkv_bias, qk_scale=qk_scale, qk_norm=qk_norm,
+                act_layer=act_layer, norm_layer=norm_layer,
+                time_fusion=time_fusion,
+                ada_lora_rank=ada_lora_rank, ada_lora_alpha=ada_lora_alpha,
+                skip=False, skip_norm=False,
+                rope_mode=self.rope,
+                context_norm=context_norm,
+                use_checkpoint=use_checkpoint)
+            for _ in range(depth // 2)])
+
+        self.mid_block = Block(
+            dim=embed_dim, context_dim=context_dim, num_heads=num_heads, 
+            mlp_ratio=mlp_ratio,
+            qkv_bias=qkv_bias, qk_scale=qk_scale, qk_norm=qk_norm,
+            act_layer=act_layer, norm_layer=norm_layer,
+            time_fusion=time_fusion,
+            ada_lora_rank=ada_lora_rank, ada_lora_alpha=ada_lora_alpha,
+            skip=False, skip_norm=False,
+            rope_mode=self.rope,
+            context_norm=context_norm,
+            use_checkpoint=use_checkpoint)
+
+        self.out_blocks = nn.ModuleList([
+            Block(
+                dim=embed_dim, context_dim=context_dim, num_heads=num_heads, 
+                mlp_ratio=mlp_ratio,
+                qkv_bias=qkv_bias, qk_scale=qk_scale, qk_norm=qk_norm,
+                act_layer=act_layer, norm_layer=norm_layer,
+                time_fusion=time_fusion,
+                ada_lora_rank=ada_lora_rank, ada_lora_alpha=ada_lora_alpha,
+                skip=skip, skip_norm=skip_norm,
+                rope_mode=self.rope,
+                context_norm=context_norm,
+                use_checkpoint=use_checkpoint)
+            for _ in range(depth // 2)])
+
+        # FinalLayer block
+        self.use_conv = use_conv
+        self.final_block = FinalBlock(embed_dim=embed_dim,
+                                      patch_size=patch_size,
+                                      img_size=img_size,
+                                      in_chans=out_chans,
+                                      input_type=input_type,
+                                      norm_layer=norm_layer,
+                                      use_conv=use_conv,
+                                      use_adanorm=self.use_adanorm)
+        self.initialize_weights()
+
+    def _init_ada(self):
+        if self.time_fusion == 'ada':
+            nn.init.constant_(self.time_ada_final.weight, 0)
+            nn.init.constant_(self.time_ada_final.bias, 0)
+            for block in self.in_blocks:
+                nn.init.constant_(block.adaln.time_ada.weight, 0)
+                nn.init.constant_(block.adaln.time_ada.bias, 0)
+            nn.init.constant_(self.mid_block.adaln.time_ada.weight, 0)
+            nn.init.constant_(self.mid_block.adaln.time_ada.bias, 0)
+            for block in self.out_blocks:
+                nn.init.constant_(block.adaln.time_ada.weight, 0)
+                nn.init.constant_(block.adaln.time_ada.bias, 0)
+        elif self.time_fusion == 'ada_single':
+            nn.init.constant_(self.time_ada.weight, 0)
+            nn.init.constant_(self.time_ada.bias, 0)
+            nn.init.constant_(self.time_ada_final.weight, 0)
+            nn.init.constant_(self.time_ada_final.bias, 0)
+        elif self.time_fusion in ['ada_lora', 'ada_lora_bias']:
+            nn.init.constant_(self.time_ada.weight, 0)
+            nn.init.constant_(self.time_ada.bias, 0)
+            nn.init.constant_(self.time_ada_final.weight, 0)
+            nn.init.constant_(self.time_ada_final.bias, 0)
+            for block in self.in_blocks:
+                nn.init.kaiming_uniform_(block.adaln.lora_a.weight,
+                                         a=math.sqrt(5))
+                nn.init.constant_(block.adaln.lora_b.weight, 0)
+            nn.init.kaiming_uniform_(self.mid_block.adaln.lora_a.weight,
+                                     a=math.sqrt(5))
+            nn.init.constant_(self.mid_block.adaln.lora_b.weight, 0)
+            for block in self.out_blocks:
+                nn.init.kaiming_uniform_(block.adaln.lora_a.weight,
+                                         a=math.sqrt(5))
+                nn.init.constant_(block.adaln.lora_b.weight, 0)
+
+    def initialize_weights(self):
+        # Basic init for all layers
+        def _basic_init(module):
+            if isinstance(module, nn.Linear):
+                torch.nn.init.xavier_uniform_(module.weight)
+                if module.bias is not None:
+                    nn.init.constant_(module.bias, 0)
+        self.apply(_basic_init)
+
+        # init patch Conv like Linear
+        w = self.patch_embed.proj.weight.data
+        nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
+        nn.init.constant_(self.patch_embed.proj.bias, 0)
+
+        # Zero-out AdaLN
+        if self.use_adanorm:
+            self._init_ada()
+
+        # Zero-out Cross Attention
+        if self.context_cross:
+            for block in self.in_blocks:
+                nn.init.constant_(block.cross_attn.proj.weight, 0)
+                nn.init.constant_(block.cross_attn.proj.bias, 0)
+            nn.init.constant_(self.mid_block.cross_attn.proj.weight, 0)
+            nn.init.constant_(self.mid_block.cross_attn.proj.bias, 0)
+            for block in self.out_blocks:
+                nn.init.constant_(block.cross_attn.proj.weight, 0)
+                nn.init.constant_(block.cross_attn.proj.bias, 0)
+
+        # Zero-out cls embedding
+        if self.cls_embed:
+            if self.use_adanorm:
+                nn.init.constant_(self.cls_embed[-1].weight, 0)
+                nn.init.constant_(self.cls_embed[-1].bias, 0)
+
+        # Zero-out Output
+        # zero-initializing the output layer may be undesirable for v-prediction,
+        # but can be helpful for noise-prediction
+        # nn.init.constant_(self.final_block.linear.weight, 0)
+        # nn.init.constant_(self.final_block.linear.bias, 0)
+        # if self.use_conv:
+        #     nn.init.constant_(self.final_block.final_layer.weight.data, 0)
+        #     nn.init.constant_(self.final_block.final_layer.bias, 0)
+
+        # init out Conv
+        if self.use_conv:
+            nn.init.xavier_uniform_(self.final_block.final_layer.weight)
+            nn.init.constant_(self.final_block.final_layer.bias, 0)
+
+    def _concat_x_context(self, x, context, x_mask=None, context_mask=None):
+        assert context.shape[-2] == self.context_max_length
+        # Check if either x_mask or context_mask is provided
+        B = x.shape[0]
+        # Create default masks if they are not provided
+        if x_mask is None:
+            x_mask = torch.ones(B, x.shape[-2], device=x.device).bool()
+        if context_mask is None:
+            context_mask = torch.ones(B, context.shape[-2],
+                                      device=context.device).bool()
+        # Concatenate the masks along the second dimension (dim=1)
+        x_mask = torch.cat([context_mask, x_mask], dim=1)
+        # Concatenate context and x along the second dimension (dim=1)
+        x = torch.cat((context, x), dim=1)
+        return x, x_mask
+
+    def forward(self, x, timesteps, context,
+                x_mask=None, context_mask=None,
+                cls_token=None, controlnet_skips=None,
+               ):
+        # make it compatible with int time step during inference
+        if timesteps.dim() == 0:
+            timesteps = timesteps.expand(x.shape[0]).to(x.device, dtype=torch.long)
+
+        x = self.patch_embed(x)
+        x = self.x_pe(x)
+
+        B, L, D = x.shape
+
+        if self.use_context:
+            context_token = self.context_embed(context)
+            context_token = self.context_pe(context_token)
+            if self.context_fusion == 'concat' or self.context_fusion == 'joint':
+                x, x_mask = self._concat_x_context(x=x, context=context_token,
+                                                   x_mask=x_mask,
+                                                   context_mask=context_mask)
+                context_token, context_mask = None, None
+        else:
+            context_token, context_mask = None, None
+
+        time_token = self.time_embed(timesteps)
+        if self.cls_embed:
+            cls_token = self.cls_embed(cls_token)
+        time_ada = None
+        time_ada_final = None
+        if self.use_adanorm:
+            if self.cls_embed:
+                time_token = time_token + cls_token
+            time_token = self.time_act(time_token)
+            time_ada_final = self.time_ada_final(time_token)
+            if self.time_ada is not None:
+                time_ada = self.time_ada(time_token)
+        else:
+            time_token = time_token.unsqueeze(dim=1)
+            if self.cls_embed:
+                cls_token = cls_token.unsqueeze(dim=1)
+                time_token = torch.cat([time_token, cls_token], dim=1)
+            time_token = self.time_pe(time_token)
+            x = torch.cat((time_token, x), dim=1)
+            if x_mask is not None:
+                x_mask = torch.cat(
+                    [torch.ones(B, time_token.shape[1], device=x_mask.device).bool(),
+                     x_mask], dim=1)
+            time_token = None
+
+        skips = []
+        for blk in self.in_blocks:
+            x = blk(x=x, time_token=time_token, time_ada=time_ada,
+                    skip=None, context=context_token,
+                    x_mask=x_mask, context_mask=context_mask,
+                    extras=self.extras)
+            if self.use_skip:
+                skips.append(x)
+
+        x = self.mid_block(x=x, time_token=time_token, time_ada=time_ada,
+                           skip=None, context=context_token,
+                           x_mask=x_mask, context_mask=context_mask,
+                           extras=self.extras)
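+
+        # Long skip connections (U-ViT style): each out block consumes one saved in-block output
+        # (in reverse order); ControlNet features are added to that skip, or directly to x when
+        # long skips are disabled.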
+        for blk in self.out_blocks:
+            if self.use_skip:
+                skip = skips.pop()
+                if controlnet_skips:
+                    # add to skip like u-net controlnet
+                    skip = skip + controlnet_skips.pop()
+            else:
+                skip = None
+                if controlnet_skips:
+                    # directly add to x
+                    x = x + controlnet_skips.pop()
+
+            x = blk(x=x, time_token=time_token, time_ada=time_ada,
+                    skip=skip, context=context_token,
+                    x_mask=x_mask, context_mask=context_mask,
+                    extras=self.extras)
+
+        x = self.final_block(x, time_ada=time_ada_final, extras=self.extras)
+
+        return x
\ No newline at end of file
diff --git a/src/models/utils/.ipynb_checkpoints/__init__-checkpoint.py b/src/models/utils/.ipynb_checkpoints/__init__-checkpoint.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/models/utils/.ipynb_checkpoints/attention-checkpoint.py b/src/models/utils/.ipynb_checkpoints/attention-checkpoint.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f77c9148fc54916e1bedc2f36f77f6a2164986a
--- /dev/null
+++ b/src/models/utils/.ipynb_checkpoints/attention-checkpoint.py
@@ -0,0 +1,290 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint
+import einops
+from einops import rearrange, repeat
+from inspect import isfunction
+from .rotary import RotaryEmbedding
+from .modules import RMSNorm
+
+
+if hasattr(nn.functional, 'scaled_dot_product_attention'):
+    ATTENTION_MODE = 'flash'
+else:
+    ATTENTION_MODE = 'math'
+print(f'attention mode is {ATTENTION_MODE}')
+
+
+def add_mask(sim, mask):
+    b, ndim = sim.shape[0], mask.ndim
+    if ndim == 3:
+        mask = rearrange(mask, "b n m -> b 1 n m")
+    if ndim == 2:
+        mask = repeat(mask, "n m -> b 1 n m", b=b)
+    max_neg_value = -torch.finfo(sim.dtype).max
+    sim = sim.masked_fill(~mask, max_neg_value)
+    return sim
+
+
+def create_mask(q_shape, k_shape, device, q_mask=None, k_mask=None):
+    def default(val, d):
+        return val if val is not None else (d() if isfunction(d) else d)
+    b, i, j = q_shape[0], q_shape[-2], k_shape[-2]
+    q_mask = default(q_mask, torch.ones((b, i), device=device, dtype=torch.bool))
+    k_mask = default(k_mask, torch.ones((b, j), device=device, dtype=torch.bool))
+    attn_mask = rearrange(q_mask, 'b i -> b 1 i 1') * rearrange(k_mask, 'b j -> b 1 1 j')
+    return attn_mask
+
+
+class Attention(nn.Module):
+    def __init__(self, dim, context_dim=None, num_heads=8,
+                 qkv_bias=False, qk_scale=None, qk_norm=None,
+                 attn_drop=0., proj_drop=0., rope_mode='none'):
+        super().__init__()
+        self.num_heads = num_heads
+        head_dim = dim // num_heads
+        self.scale = qk_scale or head_dim ** -0.5
+
+        if context_dim is None:
+            self.cross_attn = False
+        else:
+            self.cross_attn = True
+
+        context_dim = dim if context_dim is None else context_dim
+
+        self.to_q = nn.Linear(dim, dim, bias=qkv_bias)
+        self.to_k = nn.Linear(context_dim, dim, bias=qkv_bias)
+        self.to_v = nn.Linear(context_dim, dim, bias=qkv_bias)
+
+        if qk_norm is None:
+            self.norm_q = nn.Identity()
+            self.norm_k = nn.Identity()
+        elif qk_norm == 'layernorm':
+            self.norm_q = nn.LayerNorm(head_dim)
+            self.norm_k = nn.LayerNorm(head_dim)
+        elif qk_norm == 'rmsnorm':
+            self.norm_q = RMSNorm(head_dim)
+            self.norm_k = RMSNorm(head_dim)
+        else:
+            raise NotImplementedError
+
+        self.attn_drop_p = attn_drop
+        self.attn_drop = nn.Dropout(attn_drop)
+        self.proj = nn.Linear(dim, dim)
+        self.proj_drop = nn.Dropout(proj_drop)
+
+        if self.cross_attn:
+            assert rope_mode == 'none'
+        self.rope_mode = rope_mode
+        if self.rope_mode == 'shared' or self.rope_mode == 'x_only':
+            self.rotary = RotaryEmbedding(dim=head_dim)
+        elif self.rope_mode == 'dual':
+            self.rotary_x = RotaryEmbedding(dim=head_dim)
+            self.rotary_c = RotaryEmbedding(dim=head_dim)
+
+    def _rotary(self, q, k, extras):
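+        # 'shared': apply one rotary embedding over the full sequence;
+        # 'x_only': rotate only the x tokens, leaving the first `extras` (time/context) tokens untouched;
+        # 'dual': use separate rotary embeddings for the context tokens and the x tokens.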
+        if self.rope_mode == 'shared':
+            q, k = self.rotary(q=q, k=k)
+        elif self.rope_mode == 'x_only':
+            q_x, k_x = self.rotary(q=q[:, :, extras:, :], k=k[:, :, extras:, :])
+            q_c, k_c = q[:, :, :extras, :], k[:, :, :extras, :]
+            q = torch.cat((q_c, q_x), dim=2)
+            k = torch.cat((k_c, k_x), dim=2)
+        elif self.rope_mode == 'dual':
+            q_x, k_x = self.rotary_x(q=q[:, :, extras:, :], k=k[:, :, extras:, :])
+            q_c, k_c = self.rotary_c(q=q[:, :, :extras, :], k=k[:, :, :extras, :])
+            q = torch.cat((q_c, q_x), dim=2)
+            k = torch.cat((k_c, k_x), dim=2)
+        elif self.rope_mode == 'none':
+            pass
+        else:
+            raise NotImplementedError
+        return q, k
+
+    def _attn(self, q, k, v, mask_binary):
+        if ATTENTION_MODE == 'flash':
+            x = F.scaled_dot_product_attention(q, k, v,
+                                               dropout_p=self.attn_drop_p,
+                                               attn_mask=mask_binary)
+            x = einops.rearrange(x, 'B H L D -> B L (H D)')
+        elif ATTENTION_MODE == 'math':
+            attn = (q @ k.transpose(-2, -1)) * self.scale
+            attn = add_mask(attn, mask_binary) if mask_binary is not None else attn
+            attn = attn.softmax(dim=-1)
+            attn = self.attn_drop(attn)
+            x = (attn @ v).transpose(1, 2)
+            x = einops.rearrange(x, 'B H L D -> B L (H D)')
+        else:
+            raise NotImplementedError
+        return x
+
+    def forward(self, x, context=None, context_mask=None, extras=0):
+        B, L, C = x.shape
+        if context is None:
+            context = x
+
+        q = self.to_q(x)
+        k = self.to_k(context)
+        v = self.to_v(context)
+
+        if context_mask is not None:
+            mask_binary = create_mask(x.shape, context.shape,
+                                      x.device, None, context_mask)
+        else:
+            mask_binary = None
+
+        q = einops.rearrange(q, 'B L (H D) -> B H L D', H=self.num_heads)
+        k = einops.rearrange(k, 'B L (H D) -> B H L D', H=self.num_heads)
+        v = einops.rearrange(v, 'B L (H D) -> B H L D', H=self.num_heads)
+
+        q = self.norm_q(q)
+        k = self.norm_k(k)
+
+        q, k = self._rotary(q, k, extras)
+
+        x = self._attn(q, k, v, mask_binary)
+
+        x = self.proj(x)
+        x = self.proj_drop(x)
+        return x
+
+
+class JointAttention(nn.Module):
+    def __init__(self, dim, num_heads=8,
+                 qkv_bias=False, qk_scale=None, qk_norm=None,
+                 attn_drop=0., proj_drop=0.,
+                 rope_mode='none'):
+        super().__init__()
+        self.num_heads = num_heads
+        head_dim = dim // num_heads
+        self.scale = qk_scale or head_dim ** -0.5
+
+        self.to_qx, self.to_kx, self.to_vx = self._make_qkv_layers(dim, qkv_bias)
+        self.to_qc, self.to_kc, self.to_vc = self._make_qkv_layers(dim, qkv_bias)
+
+        self.norm_qx, self.norm_kx = self._make_norm_layers(qk_norm, head_dim)
+        self.norm_qc, self.norm_kc = self._make_norm_layers(qk_norm, head_dim)
+
+        self.attn_drop_p = attn_drop
+        self.attn_drop = nn.Dropout(attn_drop)
+
+        self.proj_x = nn.Linear(dim, dim)
+        self.proj_drop_x = nn.Dropout(proj_drop)
+
+        self.proj_c = nn.Linear(dim, dim)
+        self.proj_drop_c = nn.Dropout(proj_drop)
+
+        self.rope_mode = rope_mode
+        if self.rope_mode == 'shared' or self.rope_mode == 'x_only':
+            self.rotary = RotaryEmbedding(dim=head_dim)
+        elif self.rope_mode == 'dual':
+            self.rotary_x = RotaryEmbedding(dim=head_dim)
+            self.rotary_c = RotaryEmbedding(dim=head_dim)
+
+    def _make_qkv_layers(self, dim, qkv_bias):
+        return (nn.Linear(dim, dim, bias=qkv_bias),
+                nn.Linear(dim, dim, bias=qkv_bias),
+                nn.Linear(dim, dim, bias=qkv_bias))
+
+    def _make_norm_layers(self, qk_norm, head_dim):
+        if qk_norm is None:
+            norm_q = nn.Identity()
+            norm_k = nn.Identity()
+        elif qk_norm == 'layernorm':
+            norm_q = nn.LayerNorm(head_dim)
+            norm_k = nn.LayerNorm(head_dim)
+        elif qk_norm == 'rmsnorm':
+            norm_q = RMSNorm(head_dim)
+            norm_k = RMSNorm(head_dim)
+        else:
+            raise NotImplementedError
+        return norm_q, norm_k
+
+    def _rotary(self, q, k, extras):
+        if self.rope_mode == 'shared':
+            q, k = self.rotary(q=q, k=k)
+        elif self.rope_mode == 'x_only':
+            q_x, k_x = self.rotary(q=q[:, :, extras:, :], k=k[:, :, extras:, :])
+            q_c, k_c = q[:, :, :extras, :], k[:, :, :extras, :]
+            q = torch.cat((q_c, q_x), dim=2)
+            k = torch.cat((k_c, k_x), dim=2)
+        elif self.rope_mode == 'dual':
+            q_x, k_x = self.rotary_x(q=q[:, :, extras:, :], k=k[:, :, extras:, :])
+            q_c, k_c = self.rotary_c(q=q[:, :, :extras, :], k=k[:, :, :extras, :])
+            q = torch.cat((q_c, q_x), dim=2)
+            k = torch.cat((k_c, k_x), dim=2)
+        elif self.rope_mode == 'none':
+            pass
+        else:
+            raise NotImplementedError
+        return q, k
+
+    def _attn(self, q, k, v, mask_binary):
+        if ATTENTION_MODE == 'flash':
+            x = F.scaled_dot_product_attention(q, k, v,
+                                               dropout_p=self.attn_drop_p,
+                                               attn_mask=mask_binary)
+            x = einops.rearrange(x, 'B H L D -> B L (H D)')
+        elif ATTENTION_MODE == 'math':
+            attn = (q @ k.transpose(-2, -1)) * self.scale
+            attn = add_mask(attn, mask_binary) if mask_binary is not None else attn
+            attn = attn.softmax(dim=-1)
+            attn = self.attn_drop(attn)
+            x = (attn @ v).transpose(1, 2)
+            x = einops.rearrange(x, 'B H L D -> B L (H D)')
+        else:
+            raise NotImplementedError
+        return x
+
+    def _cat_mask(self, x, context, x_mask=None, context_mask=None):
+        B = x.shape[0]
+        if x_mask is None:
+            x_mask = torch.ones(B, x.shape[-2], device=x.device).bool()
+        if context_mask is None:
+            context_mask = torch.ones(B, context.shape[-2], device=context.device).bool()
+        mask = torch.cat([context_mask, x_mask], dim=1)
+        return mask
+
+    def forward(self, x, context, x_mask=None, context_mask=None, extras=0):
+        B, Lx, C = x.shape
+        _, Lc, _ = context.shape
+        if x_mask is not None or context_mask is not None:
+            mask = self._cat_mask(x, context,
+                                  x_mask=x_mask,
+                                  context_mask=context_mask)
+            shape = [B, Lx+Lc, C]
+            mask_binary = create_mask(q_shape=shape, k_shape=shape,
+                                      device=x.device,
+                                      q_mask=None, k_mask=mask)
+        else:
+            mask_binary = None
+
+        qx, kx, vx = self.to_qx(x), self.to_kx(x), self.to_vx(x)
+        qc, kc, vc = self.to_qc(context), self.to_kc(context), self.to_vc(context)
+
+        qx, kx, vx = map(lambda t: einops.rearrange(t, 'B L (H D) -> B H L D',
+                                                    H=self.num_heads), [qx, kx, vx])
+        qc, kc, vc = map(lambda t: einops.rearrange(t, 'B L (H D) -> B H L D',
+                                                    H=self.num_heads), [qc, kc, vc])
+
+        qx, kx = self.norm_qx(qx), self.norm_kx(kx)
+        qc, kc = self.norm_qc(qc), self.norm_kc(kc)
+
+        q, k, v = (torch.cat([qc, qx], dim=2),
+                   torch.cat([kc, kx], dim=2),
+                   torch.cat([vc, vx], dim=2))
+
+        q, k = self._rotary(q, k, extras)
+
+        x = self._attn(q, k, v, mask_binary)
+
+        context, x = x[:, :Lc, :], x[:, Lc:, :]
+
+        x = self.proj_x(x)
+        x = self.proj_drop_x(x)
+
+        context = self.proj_c(context)
+        context = self.proj_drop_c(context)
+
+        return x, context
\ No newline at end of file
diff --git a/src/models/utils/.ipynb_checkpoints/modules-checkpoint.py b/src/models/utils/.ipynb_checkpoints/modules-checkpoint.py
new file mode 100644
index 0000000000000000000000000000000000000000..c825b988439d9b91e1e1d30c1cf842880252c0bf
--- /dev/null
+++ b/src/models/utils/.ipynb_checkpoints/modules-checkpoint.py
@@ -0,0 +1,374 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint
+from torch.cuda.amp import autocast
+import math
+import einops
+from einops import rearrange, repeat
+from inspect import isfunction
+from .timm import trunc_normal_
+
+
+# disable in checkpoint mode
+# @torch.jit.script
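+# FiLM-style modulation used by the AdaLN blocks: `shift` and `scale` are typically
+# predicted from the timestep embedding.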
+def film_modulate(x, shift, scale):
+    return x * (1 + scale) + shift
+
+
+def timestep_embedding(timesteps, dim, max_period=10000):
+    """
+    Create sinusoidal timestep embeddings.
+
+    :param timesteps: a 1-D Tensor of N indices, one per batch element.
+                      These may be fractional.
+    :param dim: the dimension of the output.
+    :param max_period: controls the minimum frequency of the embeddings.
+    :return: an [N x dim] Tensor of positional embeddings.
+    """
+    half = dim // 2
+    freqs = torch.exp(
+        -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
+    ).to(device=timesteps.device)
+    args = timesteps[:, None].float() * freqs[None]
+    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+    if dim % 2:
+        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+    return embedding
+
+
+class TimestepEmbedder(nn.Module):
+    """
+    Embeds scalar timesteps into vector representations.
+    """
+
+    def __init__(self, hidden_size, frequency_embedding_size=256, 
+                 out_size=None):
+        super().__init__()
+        if out_size is None:
+            out_size = hidden_size
+        self.mlp = nn.Sequential(
+            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
+            nn.SiLU(),
+            nn.Linear(hidden_size, out_size, bias=True),
+        )
+        self.frequency_embedding_size = frequency_embedding_size
+
+    def forward(self, t):
+        t_freq = timestep_embedding(t, self.frequency_embedding_size).type(
+            self.mlp[0].weight.dtype)
+        t_emb = self.mlp(t_freq)
+        return t_emb
+
+
+def patchify(imgs, patch_size, input_type='2d'):
+    if input_type == '2d':
+        x = einops.rearrange(imgs, 'B C (h p1) (w p2) -> B (h w) (p1 p2 C)', p1=patch_size, p2=patch_size)
+    elif input_type == '1d':
+        x = einops.rearrange(imgs, 'B C (h p1) -> B h (p1 C)', p1=patch_size)
+    return x
+
+
+def unpatchify(x, channels=3, input_type='2d', img_size=None):
+    if input_type == '2d':
+        patch_size = int((x.shape[2] // channels) ** 0.5)
+        # h = w = int(x.shape[1] ** .5)
+        h, w = img_size[0] // patch_size, img_size[1] // patch_size
+        assert h * w == x.shape[1] and patch_size ** 2 * channels == x.shape[2]
+        x = einops.rearrange(x, 'B (h w) (p1 p2 C) -> B C (h p1) (w p2)', h=h,
+                             p1=patch_size, p2=patch_size)
+    elif input_type == '1d':
+        patch_size = int((x.shape[2] // channels))
+        h = x.shape[1]
+        assert patch_size * channels == x.shape[2]
+        x = einops.rearrange(x, 'B h (p1 C) -> B C (h p1)', h=h, p1=patch_size)
+    return x
+
+
+class PatchEmbed(nn.Module):
+    """
+     Image to Patch Embedding
+    """
+
+    def __init__(self, patch_size, in_chans=3, embed_dim=768, input_type='2d'):
+        super().__init__()
+        self.patch_size = patch_size
+        self.input_type = input_type
+        if input_type == '2d':
+            self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=True)
+        elif input_type == '1d':
+            self.proj = nn.Conv1d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=True)
+
+    def forward(self, x):
+        if self.input_type == '2d':
+            B, C, H, W = x.shape
+            assert H % self.patch_size == 0 and W % self.patch_size == 0
+        elif self.input_type == '1d':
+            B, C, H = x.shape
+            assert H % self.patch_size == 0
+
+        x = self.proj(x).flatten(2).transpose(1, 2)
+        return x
+
+
+class PositionalConvEmbedding(nn.Module):
+    """
+    Relative positional embedding used in HuBERT
+    """
+
+    def __init__(self, dim=768, kernel_size=128, groups=16):
+        super().__init__()
+        self.conv = nn.Conv1d(
+            dim,
+            dim,
+            kernel_size=kernel_size,
+            padding=kernel_size // 2,
+            groups=groups,
+            bias=True
+        )
+        self.conv = nn.utils.parametrizations.weight_norm(self.conv, name="weight", dim=2)
+
+    def forward(self, x):
+        # B C T
+        x = self.conv(x)
+        x = F.gelu(x[:, :, :-1])
+        return x
+
+
+class SinusoidalPositionalEncoding(nn.Module):
+    def __init__(self, dim, length):
+        super(SinusoidalPositionalEncoding, self).__init__()
+        self.length = length
+        self.dim = dim
+        self.register_buffer('pe', self._generate_positional_encoding(length, dim))
+
+    def _generate_positional_encoding(self, length, dim):
+        pe = torch.zeros(length, dim)
+        position = torch.arange(0, length, dtype=torch.float).unsqueeze(1)
+        div_term = torch.exp(torch.arange(0, dim, 2).float() * (-math.log(10000.0) / dim))
+
+        pe[:, 0::2] = torch.sin(position * div_term)
+        pe[:, 1::2] = torch.cos(position * div_term)
+
+        pe = pe.unsqueeze(0)
+        return pe
+
+    def forward(self, x):
+        x = x + self.pe[:, :x.size(1)]
+        return x
+
+
+class PE_wrapper(nn.Module):
+    def __init__(self, dim=768, method='abs', length=None, **kwargs):
+        super().__init__()
+        self.method = method
+        if method == 'abs':
+            # init absolute pe like UViT
+            self.length = length
+            self.abs_pe = nn.Parameter(torch.zeros(1, length, dim))
+            trunc_normal_(self.abs_pe, std=.02)
+        elif method == 'conv':
+            self.conv_pe = PositionalConvEmbedding(dim=dim, **kwargs)
+        elif method == 'sinu':
+            self.sinu_pe = SinusoidalPositionalEncoding(dim=dim, length=length)
+        elif method == 'none':
+            # skip pe
+            self.id = nn.Identity()
+        else:
+            raise NotImplementedError
+
+    def forward(self, x):
+        if self.method == 'abs':
+            _, L, _ = x.shape
+            assert L <= self.length
+            x = x + self.abs_pe[:, :L, :]
+        elif self.method == 'conv':
+            x = x + self.conv_pe(x)
+        elif self.method == 'sinu':
+            x = self.sinu_pe(x)
+        elif self.method == 'none':
+            x = self.id(x)
+        else:
+            raise NotImplementedError
+        return x
+
+
+class RMSNorm(torch.nn.Module):
+    def __init__(self, dim: int, eps: float = 1e-6):
+        """
+        Initialize the RMSNorm normalization layer.
+
+        Args:
+            dim (int): The dimension of the input tensor.
+            eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
+
+        Attributes:
+            eps (float): A small value added to the denominator for numerical stability.
+            weight (nn.Parameter): Learnable scaling parameter.
+
+        """
+        super().__init__()
+        self.eps = eps
+        self.weight = nn.Parameter(torch.ones(dim))
+
+    def _norm(self, x):
+        """
+        Apply the RMSNorm normalization to the input tensor.
+
+        Args:
+            x (torch.Tensor): The input tensor.
+
+        Returns:
+            torch.Tensor: The normalized tensor.
+
+        """
+        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+
+    def forward(self, x):
+        """
+        Forward pass through the RMSNorm layer.
+
+        Args:
+            x (torch.Tensor): The input tensor.
+
+        Returns:
+            torch.Tensor: The output tensor after applying RMSNorm.
+
+        """
+        output = self._norm(x.float()).type_as(x)
+        return output * self.weight
+
+
+class GELU(nn.Module):
+
+    def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", 
+                 bias: bool = True):
+        super().__init__()
+        self.proj = nn.Linear(dim_in, dim_out, bias=bias)
+        self.approximate = approximate
+
+    def gelu(self, gate: torch.Tensor) -> torch.Tensor:
+        if gate.device.type != "mps":
+            return F.gelu(gate, approximate=self.approximate)
+        # mps: gelu is not implemented for float16
+        return F.gelu(gate.to(dtype=torch.float32),
+                      approximate=self.approximate).to(dtype=gate.dtype)
+
+    def forward(self, hidden_states):
+        hidden_states = self.proj(hidden_states)
+        hidden_states = self.gelu(hidden_states)
+        return hidden_states
+
+
+class GEGLU(nn.Module):
+    def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
+        super().__init__()
+        self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias)
+
+    def gelu(self, gate: torch.Tensor) -> torch.Tensor:
+        if gate.device.type != "mps":
+            return F.gelu(gate)
+        # mps: gelu is not implemented for float16
+        return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
+
+    def forward(self, hidden_states):
+        hidden_states = self.proj(hidden_states)
+        hidden_states, gate = hidden_states.chunk(2, dim=-1)
+        return hidden_states * self.gelu(gate)
+
+
+class ApproximateGELU(nn.Module):
+    def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
+        super().__init__()
+        self.proj = nn.Linear(dim_in, dim_out, bias=bias)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.proj(x)
+        return x * torch.sigmoid(1.702 * x)
+
+
+# disable in checkpoint mode
+# @torch.jit.script
+def snake_beta(x, alpha, beta):
+    return x + beta * torch.sin(x * alpha).pow(2)
+
+
+class Snake(nn.Module):
+    def __init__(self, dim_in, dim_out, bias,
+                 alpha_trainable=True):
+        super().__init__()
+        self.proj = nn.Linear(dim_in, dim_out, bias=bias)
+        self.alpha = nn.Parameter(torch.ones(1, 1, dim_out))
+        self.beta = nn.Parameter(torch.ones(1, 1, dim_out))
+        self.alpha.requires_grad = alpha_trainable
+        self.beta.requires_grad = alpha_trainable
+
+    def forward(self, x):
+        x = self.proj(x)
+        x = snake_beta(x, self.alpha, self.beta)
+        return x
+
+
+class GESnake(nn.Module):
+    def __init__(self, dim_in, dim_out, bias,
+                 alpha_trainable=True):
+        super().__init__()
+        self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias)
+        self.alpha = nn.Parameter(torch.ones(1, 1, dim_out))
+        self.beta = nn.Parameter(torch.ones(1, 1, dim_out))
+        self.alpha.requires_grad = alpha_trainable
+        self.beta.requires_grad = alpha_trainable
+
+    def forward(self, x):
+        x = self.proj(x)
+        x, gate = x.chunk(2, dim=-1)
+        return x * snake_beta(gate, self.alpha, self.beta)
+
+
+class FeedForward(nn.Module):
+    def __init__(
+        self,
+        dim,
+        dim_out=None,
+        mult=4,
+        dropout=0.0,
+        activation_fn="geglu",
+        final_dropout=False,
+        inner_dim=None,
+        bias=True,
+    ):
+        super().__init__()
+        if inner_dim is None:
+            inner_dim = int(dim * mult)
+        dim_out = dim_out if dim_out is not None else dim
+
+        if activation_fn == "gelu":
+            act_fn = GELU(dim, inner_dim, bias=bias)
+        elif activation_fn == "gelu-approximate":
+            act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias)
+        elif activation_fn == "geglu":
+            act_fn = GEGLU(dim, inner_dim, bias=bias)
+        elif activation_fn == "geglu-approximate":
+            act_fn = ApproximateGELU(dim, inner_dim, bias=bias)
+        elif activation_fn == "snake":
+            act_fn = Snake(dim, inner_dim, bias=bias)
+        elif activation_fn == "gesnake":
+            act_fn = GESnake(dim, inner_dim, bias=bias)
+        else:
+            raise NotImplementedError
+
+        self.net = nn.ModuleList([])
+        # project in
+        self.net.append(act_fn)
+        # project dropout
+        self.net.append(nn.Dropout(dropout))
+        # project out
+        self.net.append(nn.Linear(inner_dim, dim_out, bias=bias))
+        # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
+        if final_dropout:
+            self.net.append(nn.Dropout(dropout))
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        for module in self.net:
+            hidden_states = module(hidden_states)
+        return hidden_states
\ No newline at end of file
diff --git a/src/models/utils/.ipynb_checkpoints/rotary-checkpoint.py b/src/models/utils/.ipynb_checkpoints/rotary-checkpoint.py
new file mode 100644
index 0000000000000000000000000000000000000000..636fbf6558b0d469f6802b10c180bbbb6fc431cc
--- /dev/null
+++ b/src/models/utils/.ipynb_checkpoints/rotary-checkpoint.py
@@ -0,0 +1,91 @@
+import torch
+
+"this rope is faster than llama rope with jit script"
+
+
+def rotate_half(x):
+    x1, x2 = x.chunk(2, dim=-1)
+    return torch.cat((-x2, x1), dim=-1)
+
+
+# disable in checkpoint mode
+# @torch.jit.script
+def apply_rotary_pos_emb(x, cos, sin):
+    # NOTE: This could probably be moved to Triton
+    # Handle a possible sequence length mismatch in between q and k
+    cos = cos[:, :, : x.shape[-2], :]
+    sin = sin[:, :, : x.shape[-2], :]
+    return (x * cos) + (rotate_half(x) * sin)
+
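+# A brief note on the shapes assumed above (taken from the code in this file, not an
+# external API): apply_rotary_pos_emb expects x of shape (B, H, L, D) and cos/sin of
+# shape (1, 1, L_cached, D), slicing them down to the sequence length of x. A minimal,
+# hypothetical usage sketch:
+#
+#   rope = RotaryEmbedding(dim=64)      # dim = per-head dimension
+#   q = torch.randn(2, 8, 100, 64)      # (B, H, L, D)
+#   k = torch.randn(2, 8, 100, 64)
+#   q_rot, k_rot = rope(q, k)           # same shapes, positions encoded by rotation
+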
+
+class RotaryEmbedding(torch.nn.Module):
+    """
+    The rotary position embeddings from RoFormer_ (Su et. al).
+    A crucial insight from the method is that the query and keys are
+    transformed by rotation matrices which depend on the relative positions.
+
+    Other implementations are available in the Rotary Transformer repo_ and in
+    GPT-NeoX_, GPT-NeoX was an inspiration
+
+    .. _RoFormer: https://arxiv.org/abs/2104.09864
+    .. _repo: https://github.com/ZhuiyiTechnology/roformer
+    .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox
+
+
+    .. warning: Please note that this embedding is not registered on purpose, as it is transformative
+        (it does not create the embedding dimension) and will likely be picked up (imported) on a ad-hoc basis
+    """
+
+    def __init__(self, dim: int):
+        super().__init__()
+        # Generate and save the inverse frequency buffer (non trainable)
+        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
+        self.register_buffer("inv_freq", inv_freq)
+        self._seq_len_cached = None
+        self._cos_cached = None
+        self._sin_cached = None
+
+    def _update_cos_sin_tables(self, x, seq_dimension=-2):
+        # expect input: B, H, L, D
+        seq_len = x.shape[seq_dimension]
+
+        # Reset the tables if the sequence length has changed,
+        # or if we're on a new device (possibly due to tracing for instance)
+        # also make sure dtype wont change
+        if (
+            seq_len != self._seq_len_cached
+            or self._cos_cached.device != x.device
+            or self._cos_cached.dtype != x.dtype
+        ):
+            self._seq_len_cached = seq_len
+            t = torch.arange(
+                x.shape[seq_dimension], device=x.device, dtype=torch.float32
+            )
+            freqs = torch.einsum("i,j->ij", t, self.inv_freq.to(x.dtype))
+            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
+
+            self._cos_cached = emb.cos()[None, None, :, :].to(x.dtype)
+            self._sin_cached = emb.sin()[None, None, :, :].to(x.dtype)
+
+        return self._cos_cached, self._sin_cached
+
+    def forward(self, q, k):
+        self._cos_cached, self._sin_cached = self._update_cos_sin_tables(
+            q.float(), seq_dimension=-2
+        )
+        if k is not None:
+            return (
+                apply_rotary_pos_emb(q.float(),
+                                     self._cos_cached,
+                                     self._sin_cached).type_as(q),
+                apply_rotary_pos_emb(k.float(),
+                                     self._cos_cached,
+                                     self._sin_cached).type_as(k),
+            )
+        else:
+            return (
+                apply_rotary_pos_emb(q.float(),
+                                     self._cos_cached,
+                                     self._sin_cached).type_as(q),
+                None
+            )
\ No newline at end of file
diff --git a/src/models/utils/.ipynb_checkpoints/span_mask-checkpoint.py b/src/models/utils/.ipynb_checkpoints/span_mask-checkpoint.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d003a6c08c1675967f992e3d052b293a202d446
--- /dev/null
+++ b/src/models/utils/.ipynb_checkpoints/span_mask-checkpoint.py
@@ -0,0 +1,146 @@
+import numpy as np
+import torch
+from typing import Optional, Tuple
+
+
+def compute_mask_indices(
+    shape: Tuple[int, int],
+    padding_mask: Optional[torch.Tensor],
+    mask_prob: float,
+    mask_length: int,
+    mask_type: str = "static",
+    mask_other: float = 0.0,
+    min_masks: int = 0,
+    no_overlap: bool = False,
+    min_space: int = 0,
+) -> np.ndarray:
+    """
+    Computes random mask spans for a given shape.
+
+    Args:
+        shape: the shape for which to compute masks.
+            Should be of size 2, where the first element is the batch size and the second is the number of timesteps.
+        padding_mask: optional padding mask of the same size as shape, which prevents masking padded elements.
+        mask_prob: probability for each token to be chosen as the start of a span to be masked. This is multiplied by
+            the number of timesteps divided by the mask span length, so that approximately this fraction of all
+            elements is masked. Due to overlaps, the actual number is smaller (unless no_overlap is True).
+        mask_length: length of each mask span (interpreted according to mask_type).
+        mask_type: how to compute mask lengths
+            static = fixed size
+            uniform = sample from uniform distribution [mask_other, mask_length*2]
+            normal = sample from normal distribution with mean mask_length and stdev mask_other; mask is min 1 element
+            poisson = sample from Poisson distribution with lambda = mask_length
+        mask_other: secondary parameter whose meaning depends on mask_type (see above).
+        min_masks: minimum number of masked spans.
+        no_overlap: if true, use an alternative recursive algorithm that prevents spans from overlapping.
+        min_space: only used if no_overlap is True; how many elements to keep unmasked between spans.
+    """
+    
+    bsz, all_sz = shape
+    mask = np.full((bsz, all_sz), False)
+    
+    # Convert mask_prob to a NumPy array
+    mask_prob = np.array(mask_prob)
+    
+    # Calculate all_num_mask for each element in the batch
+    all_num_mask = np.floor(mask_prob * all_sz / float(mask_length) + np.random.rand(bsz)).astype(int)
+    
+    # Apply the max operation with min_masks for each element
+    all_num_mask = np.maximum(min_masks, all_num_mask)
+
+    mask_idcs = []
+    for i in range(bsz):
+        if padding_mask is not None:
+            sz = all_sz - padding_mask[i].long().sum().item()
+            # mask_prob may be a scalar or a per-item array; pick this item's probability
+            prob_i = float(mask_prob) if mask_prob.ndim == 0 else float(mask_prob[i])
+            num_mask = int(
+                # add a random number for probabilistic rounding
+                prob_i * sz / float(mask_length)
+                + np.random.rand()
+            )
+            num_mask = max(min_masks, num_mask)
+        else:
+            sz = all_sz
+            num_mask = all_num_mask[i]
+
+        if mask_type == "static":
+            lengths = np.full(num_mask, mask_length)
+        elif mask_type == "uniform":
+            lengths = np.random.randint(mask_other, mask_length * 2 + 1, size=num_mask)
+        elif mask_type == "normal":
+            lengths = np.random.normal(mask_length, mask_other, size=num_mask)
+            lengths = [max(1, int(round(x))) for x in lengths]
+        elif mask_type == "poisson":
+            lengths = np.random.poisson(mask_length, size=num_mask)
+            lengths = [int(round(x)) for x in lengths]
+        else:
+            raise Exception("unknown mask selection " + mask_type)
+
+        if sum(lengths) == 0:
+            lengths[0] = min(mask_length, sz - 1)
+
+        if no_overlap:
+            mask_idc = []
+
+            def arrange(s, e, length, keep_length):
+                span_start = np.random.randint(s, e - length)
+                mask_idc.extend(span_start + i for i in range(length))
+
+                new_parts = []
+                if span_start - s - min_space >= keep_length:
+                    new_parts.append((s, span_start - min_space + 1))
+                if e - span_start - keep_length - min_space > keep_length:
+                    new_parts.append((span_start + length + min_space, e))
+                return new_parts
+
+            parts = [(0, sz)]
+            min_length = min(lengths)
+            for length in sorted(lengths, reverse=True):
+                lens = np.fromiter(
+                    (e - s if e - s >= length + min_space else 0 for s, e in parts),
+                    dtype=int,  # np.int was removed in NumPy >= 1.24
+                )
+                l_sum = np.sum(lens)
+                if l_sum == 0:
+                    break
+                probs = lens / np.sum(lens)
+                c = np.random.choice(len(parts), p=probs)
+                s, e = parts.pop(c)
+                parts.extend(arrange(s, e, length, min_length))
+            mask_idc = np.asarray(mask_idc)
+        else:
+            min_len = min(lengths)
+            if sz - min_len <= num_mask:
+                min_len = sz - num_mask - 1
+
+            mask_idc = np.random.choice(sz - min_len, num_mask, replace=False)
+
+            mask_idc = np.asarray(
+                [
+                    mask_idc[j] + offset
+                    for j in range(len(mask_idc))
+                    for offset in range(lengths[j])
+                ]
+            )
+
+        mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))
+    # min_len = min([len(m) for m in mask_idcs])
+    for i, mask_idc in enumerate(mask_idcs):
+        # if len(mask_idc) > min_len:
+            # mask_idc = np.random.choice(mask_idc, min_len, replace=False)
+        mask[i, mask_idc] = True
+
+    return torch.tensor(mask)
+
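+# How the returned mask is typically consumed (a hedged sketch, not part of this file's
+# API): in HuBERT/wav2vec-style masked prediction, the True positions are replaced by a
+# learned mask embedding before the encoder, e.g.
+#
+#   mask = compute_mask_indices((B, T), None, mask_prob=0.65, mask_length=10)  # (B, T) bool
+#   x[mask] = mask_token  # x: (B, T, C); mask_token: learned (C,) parameter (assumed)
+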
+
+if __name__ == '__main__':
+    mask = compute_mask_indices(
+        shape=[4, 500],
+        padding_mask=None,
+        mask_prob=[0.65, 0.5, 0.65, 0.65],
+        mask_length=10,
+        mask_type="static",
+        mask_other=0.0,
+        min_masks=1,
+        no_overlap=False,
+        min_space=0,
+    )
+    print(mask)
+    print(mask.sum(dim=1))
\ No newline at end of file
diff --git a/src/models/utils/.ipynb_checkpoints/timm-checkpoint.py b/src/models/utils/.ipynb_checkpoints/timm-checkpoint.py
new file mode 100644
index 0000000000000000000000000000000000000000..ae72f4fca681617778028b64da5daed26b3ed3d4
--- /dev/null
+++ b/src/models/utils/.ipynb_checkpoints/timm-checkpoint.py
@@ -0,0 +1,114 @@
+# code from timm 0.3.2
+import torch
+import torch.nn as nn
+import math
+import warnings
+
+
+def _no_grad_trunc_normal_(tensor, mean, std, a, b):
+    # Cut & paste from PyTorch official master until it's in a few official releases - RW
+    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
+    def norm_cdf(x):
+        # Computes standard normal cumulative distribution function
+        return (1. + math.erf(x / math.sqrt(2.))) / 2.
+
+    if (mean < a - 2 * std) or (mean > b + 2 * std):
+        warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
+                      "The distribution of values may be incorrect.",
+                      stacklevel=2)
+
+    with torch.no_grad():
+        # Values are generated by using a truncated uniform distribution and
+        # then using the inverse CDF for the normal distribution.
+        # Get upper and lower cdf values
+        l = norm_cdf((a - mean) / std)
+        u = norm_cdf((b - mean) / std)
+
+        # Uniformly fill tensor with values from [l, u], then translate to
+        # [2l-1, 2u-1].
+        tensor.uniform_(2 * l - 1, 2 * u - 1)
+
+        # Use inverse cdf transform for normal distribution to get truncated
+        # standard normal
+        tensor.erfinv_()
+
+        # Transform to proper mean, std
+        tensor.mul_(std * math.sqrt(2.))
+        tensor.add_(mean)
+
+        # Clamp to ensure it's in the proper range
+        tensor.clamp_(min=a, max=b)
+        return tensor
+
+
+def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
+    # type: (Tensor, float, float, float, float) -> Tensor
+    r"""Fills the input Tensor with values drawn from a truncated
+    normal distribution. The values are effectively drawn from the
+    normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
+    with values outside :math:`[a, b]` redrawn until they are within
+    the bounds. The method used for generating the random values works
+    best when :math:`a \leq \text{mean} \leq b`.
+    Args:
+        tensor: an n-dimensional `torch.Tensor`
+        mean: the mean of the normal distribution
+        std: the standard deviation of the normal distribution
+        a: the minimum cutoff value
+        b: the maximum cutoff value
+    Examples:
+        >>> w = torch.empty(3, 5)
+        >>> nn.init.trunc_normal_(w)
+    """
+    return _no_grad_trunc_normal_(tensor, mean, std, a, b)
+
+
+def drop_path(x, drop_prob: float = 0., training: bool = False):
+    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+
+    This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
+    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
+    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
+    changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
+    'survival rate' as the argument.
+
+    """
+    if drop_prob == 0. or not training:
+        return x
+    keep_prob = 1 - drop_prob
+    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
+    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
+    random_tensor.floor_()  # binarize
+    output = x.div(keep_prob) * random_tensor
+    return output
+
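+# Note on the scaling above: surviving samples are divided by keep_prob, so the expected
+# value of the output matches the input; e.g. with drop_prob=0.1 roughly 90% of samples
+# pass through scaled by 1/0.9 and the rest are zeroed.
+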
+
+class DropPath(nn.Module):
+    """Drop paths (Stochastic Depth) per sample  (when applied in main path of residual blocks).
+    """
+
+    def __init__(self, drop_prob=None):
+        super(DropPath, self).__init__()
+        self.drop_prob = drop_prob
+
+    def forward(self, x):
+        return drop_path(x, self.drop_prob, self.training)
+
+
+class Mlp(nn.Module):
+    def __init__(self, in_features, hidden_features=None, out_features=None, 
+                 act_layer=nn.GELU, drop=0.):
+        super().__init__()
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features
+        self.fc1 = nn.Linear(in_features, hidden_features)
+        self.act = act_layer()
+        self.fc2 = nn.Linear(hidden_features, out_features)
+        self.drop = nn.Dropout(drop)
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = self.act(x)
+        x = self.drop(x)
+        x = self.fc2(x)
+        x = self.drop(x)
+        return x
\ No newline at end of file
diff --git a/src/models/utils/__init__.py b/src/models/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/models/utils/attention.py b/src/models/utils/attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f77c9148fc54916e1bedc2f36f77f6a2164986a
--- /dev/null
+++ b/src/models/utils/attention.py
@@ -0,0 +1,290 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint
+import einops
+from einops import rearrange, repeat
+from inspect import isfunction
+from .rotary import RotaryEmbedding
+from .modules import RMSNorm
+
+
+if hasattr(nn.functional, 'scaled_dot_product_attention'):
+    ATTENTION_MODE = 'flash'
+else:
+    ATTENTION_MODE = 'math'
+print(f'attention mode is {ATTENTION_MODE}')
+
+
+def add_mask(sim, mask):
+    b, ndim = sim.shape[0], mask.ndim
+    if ndim == 3:
+        mask = rearrange(mask, "b n m -> b 1 n m")
+    if ndim == 2:
+        mask = repeat(mask, "n m -> b 1 n m", b=b)
+    max_neg_value = -torch.finfo(sim.dtype).max
+    sim = sim.masked_fill(~mask, max_neg_value)
+    return sim
+
+
+def create_mask(q_shape, k_shape, device, q_mask=None, k_mask=None):
+    def default(val, d):
+        return val if val is not None else (d() if isfunction(d) else d)
+    b, i, j = q_shape[0], q_shape[-2], k_shape[-2]
+    q_mask = default(q_mask, torch.ones((b, i), device=device, dtype=torch.bool))
+    k_mask = default(k_mask, torch.ones((b, j), device=device, dtype=torch.bool))
+    attn_mask = rearrange(q_mask, 'b i -> b 1 i 1') * rearrange(k_mask, 'b j -> b 1 1 j')
+    return attn_mask
+
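+# The mask returned above is a boolean tensor of shape (B, 1, i, j) built by broadcasting
+# (B, 1, i, 1) * (B, 1, 1, j); True means "may attend". It is passed either to
+# F.scaled_dot_product_attention as attn_mask or to add_mask, which fills the False
+# positions with a large negative value before the softmax.
+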
+
+class Attention(nn.Module):
+    def __init__(self, dim, context_dim=None, num_heads=8,
+                 qkv_bias=False, qk_scale=None, qk_norm=None,
+                 attn_drop=0., proj_drop=0., rope_mode='none'):
+        super().__init__()
+        self.num_heads = num_heads
+        head_dim = dim // num_heads
+        self.scale = qk_scale or head_dim ** -0.5
+
+        if context_dim is None:
+            self.cross_attn = False
+        else:
+            self.cross_attn = True
+
+        context_dim = dim if context_dim is None else context_dim
+
+        self.to_q = nn.Linear(dim, dim, bias=qkv_bias)
+        self.to_k = nn.Linear(context_dim, dim, bias=qkv_bias)
+        self.to_v = nn.Linear(context_dim, dim, bias=qkv_bias)
+
+        if qk_norm is None:
+            self.norm_q = nn.Identity()
+            self.norm_k = nn.Identity()
+        elif qk_norm == 'layernorm':
+            self.norm_q = nn.LayerNorm(head_dim)
+            self.norm_k = nn.LayerNorm(head_dim)
+        elif qk_norm == 'rmsnorm':
+            self.norm_q = RMSNorm(head_dim)
+            self.norm_k = RMSNorm(head_dim)
+        else:
+            raise NotImplementedError
+
+        self.attn_drop_p = attn_drop
+        self.attn_drop = nn.Dropout(attn_drop)
+        self.proj = nn.Linear(dim, dim)
+        self.proj_drop = nn.Dropout(proj_drop)
+
+        if self.cross_attn:
+            assert rope_mode == 'none'
+        self.rope_mode = rope_mode
+        if self.rope_mode == 'shared' or self.rope_mode == 'x_only':
+            self.rotary = RotaryEmbedding(dim=head_dim)
+        elif self.rope_mode == 'dual':
+            self.rotary_x = RotaryEmbedding(dim=head_dim)
+            self.rotary_c = RotaryEmbedding(dim=head_dim)
+
+    def _rotary(self, q, k, extras):
+        if self.rope_mode == 'shared':
+            q, k = self.rotary(q=q, k=k)
+        elif self.rope_mode == 'x_only':
+            q_x, k_x = self.rotary(q=q[:, :, extras:, :], k=k[:, :, extras:, :])
+            q_c, k_c = q[:, :, :extras, :], k[:, :, :extras, :]
+            q = torch.cat((q_c, q_x), dim=2)
+            k = torch.cat((k_c, k_x), dim=2)
+        elif self.rope_mode == 'dual':
+            q_x, k_x = self.rotary_x(q=q[:, :, extras:, :], k=k[:, :, extras:, :])
+            q_c, k_c = self.rotary_c(q=q[:, :, :extras, :], k=k[:, :, :extras, :])
+            q = torch.cat((q_c, q_x), dim=2)
+            k = torch.cat((k_c, k_x), dim=2)
+        elif self.rope_mode == 'none':
+            pass
+        else:
+            raise NotImplementedError
+        return q, k
+
+    def _attn(self, q, k, v, mask_binary):
+        if ATTENTION_MODE == 'flash':
+            x = F.scaled_dot_product_attention(q, k, v,
+                                               dropout_p=self.attn_drop_p,
+                                               attn_mask=mask_binary)
+            x = einops.rearrange(x, 'B H L D -> B L (H D)')
+        elif ATTENTION_MODE == 'math':
+            attn = (q @ k.transpose(-2, -1)) * self.scale
+            attn = add_mask(attn, mask_binary) if mask_binary is not None else attn
+            attn = attn.softmax(dim=-1)
+            attn = self.attn_drop(attn)
+            # (B, H, L, D) -> (B, L, H*D); rearrange already moves the head axis, so no transpose is needed
+            x = einops.rearrange(attn @ v, 'B H L D -> B L (H D)')
+        else:
+            raise NotImplementedError
+        return x
+
+    def forward(self, x, context=None, context_mask=None, extras=0):
+        B, L, C = x.shape
+        if context is None:
+            context = x
+
+        q = self.to_q(x)
+        k = self.to_k(context)
+        v = self.to_v(context)
+
+        if context_mask is not None:
+            mask_binary = create_mask(x.shape, context.shape,
+                                      x.device, None, context_mask)
+        else:
+            mask_binary = None
+
+        q = einops.rearrange(q, 'B L (H D) -> B H L D', H=self.num_heads)
+        k = einops.rearrange(k, 'B L (H D) -> B H L D', H=self.num_heads)
+        v = einops.rearrange(v, 'B L (H D) -> B H L D', H=self.num_heads)
+
+        q = self.norm_q(q)
+        k = self.norm_k(k)
+
+        q, k = self._rotary(q, k, extras)
+
+        x = self._attn(q, k, v, mask_binary)
+
+        x = self.proj(x)
+        x = self.proj_drop(x)
+        return x
+
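+# A minimal usage sketch for the class above (values are illustrative assumptions):
+#
+#   self_attn = Attention(dim=768, num_heads=8, qk_norm='rmsnorm', rope_mode='shared')
+#   y = self_attn(x)                                   # x: (B, L, 768), self-attention
+#
+#   cross_attn = Attention(dim=768, context_dim=1024, num_heads=8)   # rope_mode must stay 'none'
+#   y = cross_attn(x, context=c, context_mask=c_mask)  # c: (B, Lc, 1024), c_mask: (B, Lc) bool
+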
+
+class JointAttention(nn.Module):
+    def __init__(self, dim, num_heads=8,
+                 qkv_bias=False, qk_scale=None, qk_norm=None,
+                 attn_drop=0., proj_drop=0.,
+                 rope_mode='none'):
+        super().__init__()
+        self.num_heads = num_heads
+        head_dim = dim // num_heads
+        self.scale = qk_scale or head_dim ** -0.5
+
+        self.to_qx, self.to_kx, self.to_vx = self._make_qkv_layers(dim, qkv_bias)
+        self.to_qc, self.to_kc, self.to_vc = self._make_qkv_layers(dim, qkv_bias)
+
+        self.norm_qx, self.norm_kx = self._make_norm_layers(qk_norm, head_dim)
+        self.norm_qc, self.norm_kc = self._make_norm_layers(qk_norm, head_dim)
+
+        self.attn_drop_p = attn_drop
+        self.attn_drop = nn.Dropout(attn_drop)
+
+        self.proj_x = nn.Linear(dim, dim)
+        self.proj_drop_x = nn.Dropout(proj_drop)
+
+        self.proj_c = nn.Linear(dim, dim)
+        self.proj_drop_c = nn.Dropout(proj_drop)
+
+        self.rope_mode = rope_mode
+        if self.rope_mode == 'shared' or self.rope_mode == 'x_only':
+            self.rotary = RotaryEmbedding(dim=head_dim)
+        elif self.rope_mode == 'dual':
+            self.rotary_x = RotaryEmbedding(dim=head_dim)
+            self.rotary_c = RotaryEmbedding(dim=head_dim)
+
+    def _make_qkv_layers(self, dim, qkv_bias):
+        return (nn.Linear(dim, dim, bias=qkv_bias),
+                nn.Linear(dim, dim, bias=qkv_bias),
+                nn.Linear(dim, dim, bias=qkv_bias))
+
+    def _make_norm_layers(self, qk_norm, head_dim):
+        if qk_norm is None:
+            norm_q = nn.Identity()
+            norm_k = nn.Identity()
+        elif qk_norm == 'layernorm':
+            norm_q = nn.LayerNorm(head_dim)
+            norm_k = nn.LayerNorm(head_dim)
+        elif qk_norm == 'rmsnorm':
+            norm_q = RMSNorm(head_dim)
+            norm_k = RMSNorm(head_dim)
+        else:
+            raise NotImplementedError
+        return norm_q, norm_k
+
+    def _rotary(self, q, k, extras):
+        if self.rope_mode == 'shared':
+            q, k = self.rotary(q=q, k=k)
+        elif self.rope_mode == 'x_only':
+            q_x, k_x = self.rotary(q=q[:, :, extras:, :], k=k[:, :, extras:, :])
+            q_c, k_c = q[:, :, :extras, :], k[:, :, :extras, :]
+            q = torch.cat((q_c, q_x), dim=2)
+            k = torch.cat((k_c, k_x), dim=2)
+        elif self.rope_mode == 'dual':
+            q_x, k_x = self.rotary_x(q=q[:, :, extras:, :], k=k[:, :, extras:, :])
+            q_c, k_c = self.rotary_c(q=q[:, :, :extras, :], k=k[:, :, :extras, :])
+            q = torch.cat((q_c, q_x), dim=2)
+            k = torch.cat((k_c, k_x), dim=2)
+        elif self.rope_mode == 'none':
+            pass
+        else:
+            raise NotImplementedError
+        return q, k
+
+    def _attn(self, q, k, v, mask_binary):
+        if ATTENTION_MODE == 'flash':
+            x = F.scaled_dot_product_attention(q, k, v,
+                                               dropout_p=self.attn_drop_p,
+                                               attn_mask=mask_binary)
+            x = einops.rearrange(x, 'B H L D -> B L (H D)')
+        elif ATTENTION_MODE == 'math':
+            attn = (q @ k.transpose(-2, -1)) * self.scale
+            attn = add_mask(attn, mask_binary) if mask_binary is not None else attn
+            attn = attn.softmax(dim=-1)
+            attn = self.attn_drop(attn)
+            # (B, H, L, D) -> (B, L, H*D); rearrange already moves the head axis, so no transpose is needed
+            x = einops.rearrange(attn @ v, 'B H L D -> B L (H D)')
+        else:
+            raise NotImplementedError
+        return x
+
+    def _cat_mask(self, x, context, x_mask=None, context_mask=None):
+        B = x.shape[0]
+        if x_mask is None:
+            x_mask = torch.ones(B, x.shape[-2], device=x.device).bool()
+        if context_mask is None:
+            context_mask = torch.ones(B, context.shape[-2], device=context.device).bool()
+        mask = torch.cat([context_mask, x_mask], dim=1)
+        return mask
+
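+    # Note on ordering: the concatenated mask above is [context_mask, x_mask], matching the
+    # forward pass below, which concatenates context tokens before x tokens along the sequence
+    # axis and splits the attention output back into (context, x) in the same order.
+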
+    def forward(self, x, context, x_mask=None, context_mask=None, extras=0):
+        B, Lx, C = x.shape
+        _, Lc, _ = context.shape
+        if x_mask is not None or context_mask is not None:
+            mask = self._cat_mask(x, context,
+                                  x_mask=x_mask,
+                                  context_mask=context_mask)
+            shape = [B, Lx+Lc, C]
+            mask_binary = create_mask(q_shape=shape, k_shape=shape,
+                                      device=x.device,
+                                      q_mask=None, k_mask=mask)
+        else:
+            mask_binary = None
+
+        qx, kx, vx = self.to_qx(x), self.to_kx(x), self.to_vx(x)
+        qc, kc, vc = self.to_qc(context), self.to_kc(context), self.to_vc(context)
+
+        qx, kx, vx = map(lambda t: einops.rearrange(t, 'B L (H D) -> B H L D',
+                                                    H=self.num_heads), [qx, kx, vx])
+        qc, kc, vc = map(lambda t: einops.rearrange(t, 'B L (H D) -> B H L D',
+                                                    H=self.num_heads), [qc, kc, vc])
+
+        qx, kx = self.norm_qx(qx), self.norm_kx(kx)
+        qc, kc = self.norm_qc(qc), self.norm_kc(kc)
+
+        q, k, v = (torch.cat([qc, qx], dim=2),
+                   torch.cat([kc, kx], dim=2),
+                   torch.cat([vc, vx], dim=2))
+
+        q, k = self._rotary(q, k, extras)
+
+        x = self._attn(q, k, v, mask_binary)
+
+        context, x = x[:, :Lc, :], x[:, Lc:, :]
+
+        x = self.proj_x(x)
+        x = self.proj_drop_x(x)
+
+        context = self.proj_c(context)
+        context = self.proj_drop_c(context)
+
+        return x, context
\ No newline at end of file
diff --git a/src/models/utils/bk/attention.py b/src/models/utils/bk/attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..7ba4f700842a91611ad1eda0f872df04162d1e59
--- /dev/null
+++ b/src/models/utils/bk/attention.py
@@ -0,0 +1,99 @@
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint
+import einops
+from einops import rearrange, repeat
+from inspect import isfunction
+from .rotary import RotaryEmbedding
+
+if hasattr(nn.functional, 'scaled_dot_product_attention'):
+    ATTENTION_MODE = 'flash'
+else:
+    ATTENTION_MODE = 'math'
+print(f'attention mode is {ATTENTION_MODE}')
+
+
+def add_mask(sim, mask):
+    b, ndim = sim.shape[0], mask.ndim
+    if ndim == 3:
+        mask = rearrange(mask, "b n m -> b 1 n m")
+    if ndim == 2:
+        mask = repeat(mask, "n m -> b 1 n m", b=b)
+    max_neg_value = -torch.finfo(sim.dtype).max
+    sim = sim.masked_fill(~mask, max_neg_value)
+    return sim
+
+
+def create_mask(q, k, q_mask=None, k_mask=None):
+    def default(val, d):
+        return val if val is not None else (d() if isfunction(d) else d)
+
+    b, i, j, device = q.shape[0], q.shape[-2], k.shape[-2], q.device
+    q_mask = default(q_mask, torch.ones((b, i), device=device, dtype=torch.bool))
+    k_mask = default(k_mask, torch.ones((b, j), device=device, dtype=torch.bool))
+    attn_mask = rearrange(q_mask, 'b i -> b 1 i 1') * rearrange(k_mask, 'b j -> b 1 1 j')
+    return attn_mask
+
+
+class Attention(nn.Module):
+    def __init__(self, dim, context_dim=None, num_heads=8, qkv_bias=False, qk_scale=None,
+                 attn_drop=0., proj_drop=0., use_rope=False):
+        super().__init__()
+        self.num_heads = num_heads
+        head_dim = dim // num_heads
+        self.scale = qk_scale or head_dim ** -0.5
+
+        context_dim = dim if context_dim is None else context_dim
+
+        self.to_q = nn.Linear(dim, dim, bias=qkv_bias)
+        self.to_k = nn.Linear(context_dim, dim, bias=qkv_bias)
+        self.to_v = nn.Linear(context_dim, dim, bias=qkv_bias)
+        self.attn_drop_p = attn_drop
+        self.attn_drop = nn.Dropout(attn_drop)
+        self.proj = nn.Linear(dim, dim)
+        self.proj_drop = nn.Dropout(proj_drop)
+
+        self.use_rope = use_rope
+        if self.use_rope:
+            self.rotary = RotaryEmbedding(dim=head_dim)
+
+    def forward(self, x, context=None, context_mask=None):
+        B, L, C = x.shape
+        q = self.to_q(x)
+        if context is None:
+            context = x
+        else:
+            assert self.use_rope is False
+
+        k = self.to_k(context)
+        v = self.to_v(context)
+
+        if context_mask is not None:
+            mask_binary = create_mask(x, context, None, context_mask)
+        else:
+            mask_binary = None
+
+        q = einops.rearrange(q, 'B L (H D) -> B H L D', H=self.num_heads).float()
+        k = einops.rearrange(k, 'B L (H D) -> B H L D', H=self.num_heads).float()
+        v = einops.rearrange(v, 'B L (H D) -> B H L D', H=self.num_heads).float()
+
+        if self.use_rope:
+            q, k = self.rotary(q=q, k=k)
+
+        if ATTENTION_MODE == 'flash':
+            x = torch.nn.functional.scaled_dot_product_attention(q, k, v,
+                                                                 dropout_p=self.attn_drop_p,
+                                                                 attn_mask=mask_binary)
+            x = einops.rearrange(x, 'B H L D -> B L (H D)')
+        elif ATTENTION_MODE == 'math':
+            attn = (q @ k.transpose(-2, -1)) * self.scale
+            attn = add_mask(attn, mask_binary) if mask_binary is not None else attn
+            attn = attn.softmax(dim=-1)
+            attn = self.attn_drop(attn)
+            x = (attn @ v).transpose(1, 2).reshape(B, L, C)
+        else:
+            raise NotImplementedError
+
+        x = self.proj(x)
+        x = self.proj_drop(x)
+        return x
\ No newline at end of file
diff --git a/src/models/utils/bk/llama_rotary.py b/src/models/utils/bk/llama_rotary.py
new file mode 100644
index 0000000000000000000000000000000000000000..1126d1a589ca106a6399896742f511df43b0d0ae
--- /dev/null
+++ b/src/models/utils/bk/llama_rotary.py
@@ -0,0 +1,74 @@
+import torch
+from typing import Tuple
+from rotary import RotaryEmbedding
+import time
+
+
+def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
+    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
+    t = torch.arange(end, device=freqs.device, dtype=torch.float32)
+    freqs = torch.outer(t, freqs)
+    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
+    return freqs_cis
+
+
+def reshape_for_broadcast(freqs_cis: torch.Tensor,
+                          x: torch.Tensor,):
+    ndim = x.ndim
+    assert 0 <= 1 < ndim
+    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
+    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
+    return freqs_cis.view(*shape)
+
+
+def compute_rope(q, freqs_cis):
+    return q * freqs_cis
+
+
+def apply_rotary_emb(
+    xq: torch.Tensor,
+    xk: torch.Tensor,
+    freqs_cis: torch.Tensor,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    # xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
+    # xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
+    xq1, xq2 = xq.chunk(2, dim=-1)
+    xq_ = torch.view_as_complex(torch.stack((xq1, xq2), dim=-1).float())
+
+    xk1, xk2 = xk.chunk(2, dim=-1)
+    xk_ = torch.view_as_complex(torch.stack((xk1, xk2), dim=-1).float())
+
+    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
+    xq_out = torch.view_as_real(compute_rope(xq_, freqs_cis)).flatten(3)
+    xk_out = torch.view_as_real(compute_rope(xk_, freqs_cis)).flatten(3)
+    return xq_out.type_as(xq), xk_out.type_as(xk)
+
+
+if __name__ == '__main__':
+    # Move data to CUDA
+    freq_cis = precompute_freqs_cis(4, 5).cuda()
+    x = torch.rand(1, 5, 1, 4).cuda()
+    y = torch.rand(1, 5, 1, 4).cuda()
+
+    # First method
+    start_time = time.time()
+    for _ in range(20000):
+        x1, y1 = apply_rotary_emb(x, y, freq_cis)
+    end_time = time.time()
+    print(f"Method 1 time cost: {end_time - start_time} seconds")
+
+    # Prepare data for the second method
+    x = x.permute(0, 2, 1, 3)
+    y = y.permute(0, 2, 1, 3)
+    rope = RotaryEmbedding(4).cuda()
+
+    # Second method
+    start_time = time.time()
+    for _ in range(20000):
+        x2, y2 = rope(x, y)
+    end_time = time.time()
+    print(f"Method 2 time cost: {end_time - start_time} seconds")
+
+    # Print the results
+    print(x1)
+    print(x2)
\ No newline at end of file
diff --git a/src/models/utils/modules.py b/src/models/utils/modules.py
new file mode 100644
index 0000000000000000000000000000000000000000..c825b988439d9b91e1e1d30c1cf842880252c0bf
--- /dev/null
+++ b/src/models/utils/modules.py
@@ -0,0 +1,374 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint
+from torch.cuda.amp import autocast
+import math
+import einops
+from einops import rearrange, repeat
+from inspect import isfunction
+from .timm import trunc_normal_
+
+
+# disable in checkpoint mode
+# @torch.jit.script
+def film_modulate(x, shift, scale):
+    return x * (1 + scale) + shift
+
+
+def timestep_embedding(timesteps, dim, max_period=10000):
+    """
+    Create sinusoidal timestep embeddings.
+
+    :param timesteps: a 1-D Tensor of N indices, one per batch element.
+                      These may be fractional.
+    :param dim: the dimension of the output.
+    :param max_period: controls the minimum frequency of the embeddings.
+    :return: an [N x dim] Tensor of positional embeddings.
+    """
+    half = dim // 2
+    freqs = torch.exp(
+        -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
+    ).to(device=timesteps.device)
+    args = timesteps[:, None].float() * freqs[None]
+    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+    if dim % 2:
+        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+    return embedding
+
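+# Shape note for the function above: timesteps of shape (N,) map to embeddings of shape
+# (N, dim), with the first half of the channels holding cosines and the second half sines
+# over geometrically spaced frequencies; an odd dim is zero-padded by one channel.
+# For example, timestep_embedding(torch.tensor([0., 10.]), dim=128) -> (2, 128).
+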
+
+class TimestepEmbedder(nn.Module):
+    """
+    Embeds scalar timesteps into vector representations.
+    """
+
+    def __init__(self, hidden_size, frequency_embedding_size=256, 
+                 out_size=None):
+        super().__init__()
+        if out_size is None:
+            out_size = hidden_size
+        self.mlp = nn.Sequential(
+            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
+            nn.SiLU(),
+            nn.Linear(hidden_size, out_size, bias=True),
+        )
+        self.frequency_embedding_size = frequency_embedding_size
+
+    def forward(self, t):
+        t_freq = timestep_embedding(t, self.frequency_embedding_size).type(
+            self.mlp[0].weight.dtype)
+        t_emb = self.mlp(t_freq)
+        return t_emb
+
+
+def patchify(imgs, patch_size, input_type='2d'):
+    if input_type == '2d':
+        x = einops.rearrange(imgs, 'B C (h p1) (w p2) -> B (h w) (p1 p2 C)', p1=patch_size, p2=patch_size)
+    elif input_type == '1d':
+        x = einops.rearrange(imgs, 'B C (h p1) -> B h (p1 C)', p1=patch_size)
+    else:
+        raise NotImplementedError(f'unsupported input_type: {input_type}')
+    return x
+
+
+def unpatchify(x, channels=3, input_type='2d', img_size=None):
+    if input_type == '2d':
+        patch_size = int((x.shape[2] // channels) ** 0.5)
+        # h = w = int(x.shape[1] ** .5)
+        h, w = img_size[0] // patch_size, img_size[1] // patch_size
+        assert h * w == x.shape[1] and patch_size ** 2 * channels == x.shape[2]
+        x = einops.rearrange(x, 'B (h w) (p1 p2 C) -> B C (h p1) (w p2)', h=h,
+                             p1=patch_size, p2=patch_size)
+    elif input_type == '1d':
+        patch_size = int((x.shape[2] // channels))
+        h = x.shape[1]
+        assert patch_size * channels == x.shape[2]
+        x = einops.rearrange(x, 'B h (p1 C) -> B C (h p1)', h=h, p1=patch_size)
+    return x
+
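+# Round-trip sketch for the two helpers above (shapes follow directly from the rearrange
+# patterns): for the '1d' case, patchify maps (B, C, T) -> (B, T // p, p * C) and
+# unpatchify inverts it given channels=C; for '2d', (B, C, H, W) -> (B, (H//p)*(W//p), p*p*C),
+# with unpatchify requiring img_size=(H, W) to recover the spatial layout.
+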
+
+class PatchEmbed(nn.Module):
+    """
+     Image to Patch Embedding
+    """
+
+    def __init__(self, patch_size, in_chans=3, embed_dim=768, input_type='2d'):
+        super().__init__()
+        self.patch_size = patch_size
+        self.input_type = input_type
+        if input_type == '2d':
+            self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=True)
+        elif input_type == '1d':
+            self.proj = nn.Conv1d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=True)
+
+    def forward(self, x):
+        if self.input_type == '2d':
+            B, C, H, W = x.shape
+            assert H % self.patch_size == 0 and W % self.patch_size == 0
+        elif self.input_type == '1d':
+            B, C, H = x.shape
+            assert H % self.patch_size == 0
+
+        x = self.proj(x).flatten(2).transpose(1, 2)
+        return x
+
+
+class PositionalConvEmbedding(nn.Module):
+    """
+    Relative positional embedding used in HuBERT
+    """
+
+    def __init__(self, dim=768, kernel_size=128, groups=16):
+        super().__init__()
+        self.conv = nn.Conv1d(
+            dim,
+            dim,
+            kernel_size=kernel_size,
+            padding=kernel_size // 2,
+            groups=groups,
+            bias=True
+        )
+        self.conv = nn.utils.parametrizations.weight_norm(self.conv, name="weight", dim=2)
+
+    def forward(self, x):
+        # B C T
+        x = self.conv(x)
+        x = F.gelu(x[:, :, :-1])
+        return x
+
+
+class SinusoidalPositionalEncoding(nn.Module):
+    def __init__(self, dim, length):
+        super(SinusoidalPositionalEncoding, self).__init__()
+        self.length = length
+        self.dim = dim
+        self.register_buffer('pe', self._generate_positional_encoding(length, dim))
+
+    def _generate_positional_encoding(self, length, dim):
+        pe = torch.zeros(length, dim)
+        position = torch.arange(0, length, dtype=torch.float).unsqueeze(1)
+        div_term = torch.exp(torch.arange(0, dim, 2).float() * (-math.log(10000.0) / dim))
+
+        pe[:, 0::2] = torch.sin(position * div_term)
+        pe[:, 1::2] = torch.cos(position * div_term)
+
+        pe = pe.unsqueeze(0)
+        return pe
+
+    def forward(self, x):
+        x = x + self.pe[:, :x.size(1)]
+        return x
+
+
+class PE_wrapper(nn.Module):
+    def __init__(self, dim=768, method='abs', length=None, **kwargs):
+        super().__init__()
+        self.method = method
+        if method == 'abs':
+            # init absolute pe like UViT
+            self.length = length
+            self.abs_pe = nn.Parameter(torch.zeros(1, length, dim))
+            trunc_normal_(self.abs_pe, std=.02)
+        elif method == 'conv':
+            self.conv_pe = PositionalConvEmbedding(dim=dim, **kwargs)
+        elif method == 'sinu':
+            self.sinu_pe = SinusoidalPositionalEncoding(dim=dim, length=length)
+        elif method == 'none':
+            # skip pe
+            self.id = nn.Identity()
+        else:
+            raise NotImplementedError
+
+    def forward(self, x):
+        if self.method == 'abs':
+            _, L, _ = x.shape
+            assert L <= self.length
+            x = x + self.abs_pe[:, :L, :]
+        elif self.method == 'conv':
+            # PositionalConvEmbedding works on channels-first (B, C, T) input,
+            # while x here is (B, L, C); transpose in and out around the convolution.
+            x = x + self.conv_pe(x.transpose(1, 2)).transpose(1, 2)
+        elif self.method == 'sinu':
+            x = self.sinu_pe(x)
+        elif self.method == 'none':
+            x = self.id(x)
+        else:
+            raise NotImplementedError
+        return x
+
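+# A minimal sketch of the positional-embedding wrapper above, assuming inputs of shape
+# (B, L, dim) as used throughout this file (the length value below is illustrative):
+#
+#   pe = PE_wrapper(dim=768, method='abs', length=1000)   # learned absolute PE, UViT-style init
+#   pe = PE_wrapper(dim=768, method='sinu', length=1000)  # fixed sinusoidal PE
+#   pe = PE_wrapper(dim=768, method='none')               # identity (no PE)
+#   x = pe(x)
+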
+
+class RMSNorm(torch.nn.Module):
+    def __init__(self, dim: int, eps: float = 1e-6):
+        """
+        Initialize the RMSNorm normalization layer.
+
+        Args:
+            dim (int): The dimension of the input tensor.
+            eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
+
+        Attributes:
+            eps (float): A small value added to the denominator for numerical stability.
+            weight (nn.Parameter): Learnable scaling parameter.
+
+        """
+        super().__init__()
+        self.eps = eps
+        self.weight = nn.Parameter(torch.ones(dim))
+
+    def _norm(self, x):
+        """
+        Apply the RMSNorm normalization to the input tensor.
+
+        Args:
+            x (torch.Tensor): The input tensor.
+
+        Returns:
+            torch.Tensor: The normalized tensor.
+
+        """
+        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+
+    def forward(self, x):
+        """
+        Forward pass through the RMSNorm layer.
+
+        Args:
+            x (torch.Tensor): The input tensor.
+
+        Returns:
+            torch.Tensor: The output tensor after applying RMSNorm.
+
+        """
+        output = self._norm(x.float()).type_as(x)
+        return output * self.weight
+
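+# In equation form, the layer above computes y = w * x / sqrt(mean(x^2, dim=-1) + eps),
+# i.e. LayerNorm without mean-centering or bias, with the statistics computed in float32
+# and the result cast back to the input dtype.
+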
+
+class GELU(nn.Module):
+
+    def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", 
+                 bias: bool = True):
+        super().__init__()
+        self.proj = nn.Linear(dim_in, dim_out, bias=bias)
+        self.approximate = approximate
+
+    def gelu(self, gate: torch.Tensor) -> torch.Tensor:
+        if gate.device.type != "mps":
+            return F.gelu(gate, approximate=self.approximate)
+        # mps: gelu is not implemented for float16
+        return F.gelu(gate.to(dtype=torch.float32),
+                      approximate=self.approximate).to(dtype=gate.dtype)
+
+    def forward(self, hidden_states):
+        hidden_states = self.proj(hidden_states)
+        hidden_states = self.gelu(hidden_states)
+        return hidden_states
+
+
+class GEGLU(nn.Module):
+    def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
+        super().__init__()
+        self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias)
+
+    def gelu(self, gate: torch.Tensor) -> torch.Tensor:
+        if gate.device.type != "mps":
+            return F.gelu(gate)
+        # mps: gelu is not implemented for float16
+        return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
+
+    def forward(self, hidden_states):
+        hidden_states = self.proj(hidden_states)
+        hidden_states, gate = hidden_states.chunk(2, dim=-1)
+        return hidden_states * self.gelu(gate)
+
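+# GEGLU is the GLU variant with a GELU gate: the projection produces 2 * dim_out
+# features, split into a value half and a gate half, and the output is
+# value * gelu(gate), mapping (B, L, dim_in) -> (B, L, dim_out).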
+
+class ApproximateGELU(nn.Module):
+    def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
+        super().__init__()
+        self.proj = nn.Linear(dim_in, dim_out, bias=bias)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.proj(x)
+        return x * torch.sigmoid(1.702 * x)
+
+
+# disable in checkpoint mode
+# @torch.jit.script
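+# Snake-style activation as implemented here: f(x) = x + beta * sin(alpha * x)^2,
+# where alpha and beta are learnable per-channel parameters (a periodic inductive
+# bias that is popular for audio waveform modelling).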
+def snake_beta(x, alpha, beta):
+    return x + beta * torch.sin(x * alpha).pow(2)
+
+
+class Snake(nn.Module):
+    def __init__(self, dim_in, dim_out, bias,
+                 alpha_trainable=True):
+        super().__init__()
+        self.proj = nn.Linear(dim_in, dim_out, bias=bias)
+        self.alpha = nn.Parameter(torch.ones(1, 1, dim_out))
+        self.beta = nn.Parameter(torch.ones(1, 1, dim_out))
+        self.alpha.requires_grad = alpha_trainable
+        self.beta.requires_grad = alpha_trainable
+
+    def forward(self, x):
+        x = self.proj(x)
+        x = snake_beta(x, self.alpha, self.beta)
+        return x
+
+
+class GESnake(nn.Module):
+    def __init__(self, dim_in, dim_out, bias,
+                 alpha_trainable=True):
+        super().__init__()
+        self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias)
+        self.alpha = nn.Parameter(torch.ones(1, 1, dim_out))
+        self.beta = nn.Parameter(torch.ones(1, 1, dim_out))
+        self.alpha.requires_grad = alpha_trainable
+        self.beta.requires_grad = alpha_trainable
+
+    def forward(self, x):
+        x = self.proj(x)
+        x, gate = x.chunk(2, dim=-1)
+        return x * snake_beta(gate, self.alpha, self.beta)
+
+
+class FeedForward(nn.Module):
+    def __init__(
+        self,
+        dim,
+        dim_out=None,
+        mult=4,
+        dropout=0.0,
+        activation_fn="geglu",
+        final_dropout=False,
+        inner_dim=None,
+        bias=True,
+    ):
+        super().__init__()
+        if inner_dim is None:
+            inner_dim = int(dim * mult)
+        dim_out = dim_out if dim_out is not None else dim
+
+        if activation_fn == "gelu":
+            act_fn = GELU(dim, inner_dim, bias=bias)
+        elif activation_fn == "gelu-approximate":
+            act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias)
+        elif activation_fn == "geglu":
+            act_fn = GEGLU(dim, inner_dim, bias=bias)
+        elif activation_fn == "geglu-approximate":
+            act_fn = ApproximateGELU(dim, inner_dim, bias=bias)
+        elif activation_fn == "snake":
+            act_fn = Snake(dim, inner_dim, bias=bias)
+        elif activation_fn == "gesnake":
+            act_fn = GESnake(dim, inner_dim, bias=bias)
+        else:
+            raise NotImplementedError
+
+        self.net = nn.ModuleList([])
+        # project in
+        self.net.append(act_fn)
+        # project dropout
+        self.net.append(nn.Dropout(dropout))
+        # project out
+        self.net.append(nn.Linear(inner_dim, dim_out, bias=bias))
+        # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
+        if final_dropout:
+            self.net.append(nn.Dropout(dropout))
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        for module in self.net:
+            hidden_states = module(hidden_states)
+        return hidden_states
\ No newline at end of file
diff --git a/src/models/utils/rotary.py b/src/models/utils/rotary.py
new file mode 100644
index 0000000000000000000000000000000000000000..636fbf6558b0d469f6802b10c180bbbb6fc431cc
--- /dev/null
+++ b/src/models/utils/rotary.py
@@ -0,0 +1,91 @@
+import torch
+
+"this rope is faster than llama rope with jit script"
+
+
+def rotate_half(x):
+    x1, x2 = x.chunk(2, dim=-1)
+    return torch.cat((-x2, x1), dim=-1)
+
+
+# disable in checkpoint mode
+# @torch.jit.script
+def apply_rotary_pos_emb(x, cos, sin):
+    # NOTE: This could probably be moved to Triton
+    # Handle a possible sequence length mismatch in between q and k
+    cos = cos[:, :, : x.shape[-2], :]
+    sin = sin[:, :, : x.shape[-2], :]
+    return (x * cos) + (rotate_half(x) * sin)
+
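+# A minimal sketch (illustrative; shapes are assumptions): with query/key tensors of
+# shape (batch, heads, seq_len, head_dim), the RotaryEmbedding module defined below
+# rotates channel pairs by position-dependent angles, so attention scores depend on
+# relative positions only.
+#
+#   rope = RotaryEmbedding(dim=64)
+#   q = torch.randn(2, 8, 128, 64)
+#   k = torch.randn(2, 8, 128, 64)
+#   q_rot, k_rot = rope(q, k)   # same shapes as q and k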
+
+class RotaryEmbedding(torch.nn.Module):
+    """
+    The rotary position embeddings from RoFormer_ (Su et. al).
+    A crucial insight from the method is that the query and keys are
+    transformed by rotation matrices which depend on the relative positions.
+
+    Other implementations are available in the Rotary Transformer repo_ and in
+    GPT-NeoX_; the GPT-NeoX implementation was an inspiration for this one.
+
+    .. _RoFormer: https://arxiv.org/abs/2104.09864
+    .. _repo: https://github.com/ZhuiyiTechnology/roformer
+    .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox
+
+
+    .. warning: Please note that this embedding is not registered on purpose, as it is transformative
+        (it does not create the embedding dimension) and will likely be picked up (imported) on an ad-hoc basis.
+    """
+
+    def __init__(self, dim: int):
+        super().__init__()
+        # Generate and save the inverse frequency buffer (non trainable)
+        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
+        self.register_buffer("inv_freq", inv_freq)
+        self._seq_len_cached = None
+        self._cos_cached = None
+        self._sin_cached = None
+
+    def _update_cos_sin_tables(self, x, seq_dimension=-2):
+        # expect input: B, H, L, D
+        seq_len = x.shape[seq_dimension]
+
+        # Reset the tables if the sequence length has changed,
+        # or if we're on a new device (possibly due to tracing for instance)
+        # also make sure dtype wont change
+        if (
+            seq_len != self._seq_len_cached
+            or self._cos_cached.device != x.device
+            or self._cos_cached.dtype != x.dtype
+        ):
+            self._seq_len_cached = seq_len
+            t = torch.arange(
+                x.shape[seq_dimension], device=x.device, dtype=torch.float32
+            )
+            freqs = torch.einsum("i,j->ij", t, self.inv_freq.to(x.dtype))
+            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
+
+            self._cos_cached = emb.cos()[None, None, :, :].to(x.dtype)
+            self._sin_cached = emb.sin()[None, None, :, :].to(x.dtype)
+
+        return self._cos_cached, self._sin_cached
+
+    def forward(self, q, k):
+        self._cos_cached, self._sin_cached = self._update_cos_sin_tables(
+            q.float(), seq_dimension=-2
+        )
+        if k is not None:
+            return (
+                apply_rotary_pos_emb(q.float(),
+                                     self._cos_cached,
+                                     self._sin_cached).type_as(q),
+                apply_rotary_pos_emb(k.float(),
+                                     self._cos_cached,
+                                     self._sin_cached).type_as(k),
+            )
+        else:
+            return (
+                apply_rotary_pos_emb(q.float(),
+                                     self._cos_cached,
+                                     self._sin_cached).type_as(q),
+                None
+            )
\ No newline at end of file
diff --git a/src/models/utils/span_mask.py b/src/models/utils/span_mask.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d003a6c08c1675967f992e3d052b293a202d446
--- /dev/null
+++ b/src/models/utils/span_mask.py
@@ -0,0 +1,146 @@
+import numpy as np
+import torch
+from typing import Optional, Tuple
+
+
+def compute_mask_indices(
+    shape: Tuple[int, int],
+    padding_mask: Optional[torch.Tensor],
+    mask_prob: float,
+    mask_length: int,
+    mask_type: str = "static",
+    mask_other: float = 0.0,
+    min_masks: int = 0,
+    no_overlap: bool = False,
+    min_space: int = 0,
+) -> torch.Tensor:
+    """
+    Computes random mask spans for a given shape
+
+    Args:
+        shape: the shape for which to compute masks.
+            should be of size 2 where first element is batch size and 2nd is timesteps
+        padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements
+        mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by
+            number of timesteps divided by length of mask span to mask approximately this percentage of all elements.
+            however, due to overlaps, the actual number will be smaller (unless no_overlap is True)
+        mask_type: how to compute mask lengths
+            static = fixed size
+            uniform = sample from uniform distribution [mask_other, mask_length*2]
+            normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element
+            poisson = sample from poisson distribution with lambda = mask_length
+        min_masks: minimum number of masked spans
+        no_overlap: if True, uses an alternative recursive algorithm that prevents spans from overlapping
+        min_space: only used if no_overlap is True; this is how many elements to keep unmasked between spans
+    """
+    
+    bsz, all_sz = shape
+    mask = np.full((bsz, all_sz), False)
+    
+    # Convert mask_prob to a NumPy array
+    mask_prob = np.array(mask_prob)
+    
+    # Calculate all_num_mask for each element in the batch
+    all_num_mask = np.floor(mask_prob * all_sz / float(mask_length) + np.random.rand(bsz)).astype(int)
+    
+    # Apply the max operation with min_masks for each element
+    all_num_mask = np.maximum(min_masks, all_num_mask)
+
+    mask_idcs = []
+    for i in range(bsz):
+        if padding_mask is not None:
+            sz = all_sz - padding_mask[i].long().sum().item()
+            # mask_prob may be a scalar or a per-batch-element array
+            prob = float(mask_prob[i]) if mask_prob.ndim > 0 else float(mask_prob)
+            num_mask = int(
+                # add a random number for probabilistic rounding
+                prob * sz / float(mask_length)
+                + np.random.rand()
+            )
+            num_mask = max(min_masks, num_mask)
+        else:
+            sz = all_sz
+            num_mask = all_num_mask[i]
+
+        if mask_type == "static":
+            lengths = np.full(num_mask, mask_length)
+        elif mask_type == "uniform":
+            lengths = np.random.randint(mask_other, mask_length * 2 + 1, size=num_mask)
+        elif mask_type == "normal":
+            lengths = np.random.normal(mask_length, mask_other, size=num_mask)
+            lengths = [max(1, int(round(x))) for x in lengths]
+        elif mask_type == "poisson":
+            lengths = np.random.poisson(mask_length, size=num_mask)
+            lengths = [int(round(x)) for x in lengths]
+        else:
+            raise Exception("unknown mask selection " + mask_type)
+
+        if sum(lengths) == 0:
+            lengths[0] = min(mask_length, sz - 1)
+
+        if no_overlap:
+            mask_idc = []
+
+            def arrange(s, e, length, keep_length):
+                span_start = np.random.randint(s, e - length)
+                mask_idc.extend(span_start + i for i in range(length))
+
+                new_parts = []
+                if span_start - s - min_space >= keep_length:
+                    new_parts.append((s, span_start - min_space + 1))
+                if e - span_start - keep_length - min_space > keep_length:
+                    new_parts.append((span_start + length + min_space, e))
+                return new_parts
+
+            parts = [(0, sz)]
+            min_length = min(lengths)
+            for length in sorted(lengths, reverse=True):
+                lens = np.fromiter(
+                    (e - s if e - s >= length + min_space else 0 for s, e in parts),
+                    np.int64,  # np.int was removed in recent NumPy versions
+                )
+                l_sum = np.sum(lens)
+                if l_sum == 0:
+                    break
+                probs = lens / np.sum(lens)
+                c = np.random.choice(len(parts), p=probs)
+                s, e = parts.pop(c)
+                parts.extend(arrange(s, e, length, min_length))
+            mask_idc = np.asarray(mask_idc)
+        else:
+            min_len = min(lengths)
+            if sz - min_len <= num_mask:
+                min_len = sz - num_mask - 1
+
+            mask_idc = np.random.choice(sz - min_len, num_mask, replace=False)
+
+            mask_idc = np.asarray(
+                [
+                    mask_idc[j] + offset
+                    for j in range(len(mask_idc))
+                    for offset in range(lengths[j])
+                ]
+            )
+
+        mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))
+    # min_len = min([len(m) for m in mask_idcs])
+    for i, mask_idc in enumerate(mask_idcs):
+        # if len(mask_idc) > min_len:
+            # mask_idc = np.random.choice(mask_idc, min_len, replace=False)
+        mask[i, mask_idc] = True
+
+    return torch.tensor(mask)
+
+
+if __name__ == '__main__':
+    mask = compute_mask_indices(
+        shape=[4, 500],
+        padding_mask=None,
+        mask_prob=[0.65, 0.5, 0.65, 0.65],
+        mask_length=10,
+        mask_type="static",
+        mask_other=0.0,
+        min_masks=1,
+        no_overlap=False,
+        min_space=0,
+    )
+    print(mask)
+    print(mask.sum(dim=1))
\ No newline at end of file
diff --git a/src/models/utils/timm.py b/src/models/utils/timm.py
new file mode 100644
index 0000000000000000000000000000000000000000..ae72f4fca681617778028b64da5daed26b3ed3d4
--- /dev/null
+++ b/src/models/utils/timm.py
@@ -0,0 +1,114 @@
+# code from timm 0.3.2
+import torch
+import torch.nn as nn
+import math
+import warnings
+
+
+def _no_grad_trunc_normal_(tensor, mean, std, a, b):
+    # Cut & paste from PyTorch official master until it's in a few official releases - RW
+    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
+    def norm_cdf(x):
+        # Computes standard normal cumulative distribution function
+        return (1. + math.erf(x / math.sqrt(2.))) / 2.
+
+    if (mean < a - 2 * std) or (mean > b + 2 * std):
+        warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
+                      "The distribution of values may be incorrect.",
+                      stacklevel=2)
+
+    with torch.no_grad():
+        # Values are generated by using a truncated uniform distribution and
+        # then using the inverse CDF for the normal distribution.
+        # Get upper and lower cdf values
+        l = norm_cdf((a - mean) / std)
+        u = norm_cdf((b - mean) / std)
+
+        # Uniformly fill tensor with values from [l, u], then translate to
+        # [2l-1, 2u-1].
+        tensor.uniform_(2 * l - 1, 2 * u - 1)
+
+        # Use inverse cdf transform for normal distribution to get truncated
+        # standard normal
+        tensor.erfinv_()
+
+        # Transform to proper mean, std
+        tensor.mul_(std * math.sqrt(2.))
+        tensor.add_(mean)
+
+        # Clamp to ensure it's in the proper range
+        tensor.clamp_(min=a, max=b)
+        return tensor
+
+
+def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
+    # type: (Tensor, float, float, float, float) -> Tensor
+    r"""Fills the input Tensor with values drawn from a truncated
+    normal distribution. The values are effectively drawn from the
+    normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
+    with values outside :math:`[a, b]` redrawn until they are within
+    the bounds. The method used for generating the random values works
+    best when :math:`a \leq \text{mean} \leq b`.
+    Args:
+        tensor: an n-dimensional `torch.Tensor`
+        mean: the mean of the normal distribution
+        std: the standard deviation of the normal distribution
+        a: the minimum cutoff value
+        b: the maximum cutoff value
+    Examples:
+        >>> w = torch.empty(3, 5)
+        >>> nn.init.trunc_normal_(w)
+    """
+    return _no_grad_trunc_normal_(tensor, mean, std, a, b)
+
+
+def drop_path(x, drop_prob: float = 0., training: bool = False):
+    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+
+    This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
+    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
+    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
+    changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
+    'survival rate' as the argument.
+
+    """
+    if drop_prob == 0. or not training:
+        return x
+    keep_prob = 1 - drop_prob
+    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
+    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
+    random_tensor.floor_()  # binarize
+    output = x.div(keep_prob) * random_tensor
+    return output
+
+
+class DropPath(nn.Module):
+    """Drop paths (Stochastic Depth) per sample  (when applied in main path of residual blocks).
+    """
+
+    def __init__(self, drop_prob=None):
+        super(DropPath, self).__init__()
+        self.drop_prob = drop_prob
+
+    def forward(self, x):
+        return drop_path(x, self.drop_prob, self.training)
+
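+# DropPath usage sketch (illustrative): it is typically applied to a residual branch,
+# e.g. x = x + drop_path(block(x)). With drop_prob=0.1, roughly 10% of the samples in
+# a batch skip the branch during training (kept samples are rescaled by 1/keep_prob);
+# at eval time it is the identity.
+#
+#   dp = DropPath(drop_prob=0.1)
+#   out = dp(torch.randn(4, 16, 768))   # same shape as the input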
+
+class Mlp(nn.Module):
+    def __init__(self, in_features, hidden_features=None, out_features=None, 
+                 act_layer=nn.GELU, drop=0.):
+        super().__init__()
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features
+        self.fc1 = nn.Linear(in_features, hidden_features)
+        self.act = act_layer()
+        self.fc2 = nn.Linear(hidden_features, out_features)
+        self.drop = nn.Dropout(drop)
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = self.act(x)
+        x = self.drop(x)
+        x = self.fc2(x)
+        x = self.drop(x)
+        return x
\ No newline at end of file
diff --git a/src/modules/autoencoder_wrapper.py b/src/modules/autoencoder_wrapper.py
new file mode 100644
index 0000000000000000000000000000000000000000..5cf799cb010332a6adbb2c74213df24c2602a6e8
--- /dev/null
+++ b/src/modules/autoencoder_wrapper.py
@@ -0,0 +1,83 @@
+import torch
+import torch.nn as nn
+from .dac import DAC
+from .stable_vae import load_vae
+
+
+class Autoencoder(nn.Module):
+    def __init__(self, ckpt_path, model_type='dac', quantization_first=False):
+        super(Autoencoder, self).__init__()
+        self.model_type = model_type
+        if self.model_type == 'dac':
+            model = DAC.load(ckpt_path)
+        elif self.model_type == 'stable_vae':
+            model = load_vae(ckpt_path)
+        else:
+            raise NotImplementedError(f"Model type not implemented: {self.model_type}")
+        self.ae = model.eval()
+        self.quantization_first = quantization_first
+        print(f'Autoencoder quantization first mode: {quantization_first}')
+
+    @torch.no_grad()
+    def forward(self, audio=None, embedding=None):
+        if self.model_type == 'dac':
+            return self.process_dac(audio, embedding)
+        elif self.model_type == 'encodec':
+            return self.process_encodec(audio, embedding)
+        elif self.model_type == 'stable_vae':
+            return self.process_stable_vae(audio, embedding)
+        else:
+            raise NotImplementedError(f"Model type not implemented: {self.model_type}")
+
+    def process_dac(self, audio=None, embedding=None):
+        if audio is not None:
+            z = self.ae.encoder(audio)
+            if self.quantization_first:
+                z, *_ = self.ae.quantizer(z, None)
+            return z
+        elif embedding is not None:
+            z = embedding
+            if self.quantization_first:
+                audio = self.ae.decoder(z)
+            else:
+                z, *_ = self.ae.quantizer(z, None)
+                audio = self.ae.decoder(z)
+            return audio
+        else:
+            raise ValueError("Either audio or embedding must be provided.")
+
+    def process_encodec(self, audio=None, embedding=None):
+        if audio is not None:
+            z = self.ae.encoder(audio)
+            if self.quantization_first:
+                code = self.ae.quantizer.encode(z)
+                z = self.ae.quantizer.decode(code)
+            return z
+        elif embedding is not None:
+            z = embedding
+            if self.quantization_first:
+                audio = self.ae.decoder(z)
+            else:
+                code = self.ae.quantizer.encode(z)
+                z = self.ae.quantizer.decode(code)
+                audio = self.ae.decoder(z)
+            return audio
+        else:
+            raise ValueError("Either audio or embedding must be provided.")
+
+    def process_stable_vae(self, audio=None, embedding=None):
+        if audio is not None:
+            z = self.ae.encoder(audio)
+            if self.quantization_first:
+                z = self.ae.bottleneck.encode(z)
+            return z
+        if embedding is not None:
+            z = embedding
+            if self.quantization_first:
+                audio = self.ae.decoder(z)
+            else:
+                z = self.ae.bottleneck.encode(z)
+                audio = self.ae.decoder(z)
+            return audio
+        else:
+            raise ValueError("Either audio or embedding must be provided.")
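+
+
+# Usage sketch (illustrative; the checkpoint path and shapes are placeholders):
+#
+#   ae = Autoencoder('path/to/checkpoint', model_type='dac', quantization_first=True)
+#   latents = ae(audio=waveform)       # waveform: (B, 1, T) tensor -> latent z
+#   recon = ae(embedding=latents)      # latent z -> reconstructed audio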
diff --git a/src/modules/clap_wrapper.py b/src/modules/clap_wrapper.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/modules/dac/__init__.py b/src/modules/dac/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..51205ef6ded9c6735a988b76008e0f6bdce8e215
--- /dev/null
+++ b/src/modules/dac/__init__.py
@@ -0,0 +1,16 @@
+__version__ = "1.0.0"
+
+# preserved here for legacy reasons
+__model_version__ = "latest"
+
+import audiotools
+
+audiotools.ml.BaseModel.INTERN += ["dac.**"]
+audiotools.ml.BaseModel.EXTERN += ["einops"]
+
+
+from . import nn
+from . import model
+from . import utils
+from .model import DAC
+from .model import DACFile
diff --git a/src/modules/dac/__main__.py b/src/modules/dac/__main__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2fa8d15307997663f8143669c2bd56e0889cb021
--- /dev/null
+++ b/src/modules/dac/__main__.py
@@ -0,0 +1,36 @@
+import sys
+
+import argbind
+
+from dac.utils import download
+from dac.utils.decode import decode
+from dac.utils.encode import encode
+
+STAGES = ["encode", "decode", "download"]
+
+
+def run(stage: str):
+    """Run stages.
+
+    Parameters
+    ----------
+    stage : str
+        Stage to run
+    """
+    if stage not in STAGES:
+        raise ValueError(f"Unknown command: {stage}. Allowed commands are {STAGES}")
+    stage_fn = globals()[stage]
+
+    if stage == "download":
+        stage_fn()
+        return
+
+    stage_fn()
+
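+# Typical invocation (hedged; assumes the package is importable as `dac`, and the
+# exact flags are resolved by argbind from the encode/decode/download signatures):
+#
+#   python -m dac download
+#   python -m dac encode <input path> ...
+#   python -m dac decode <input path> ...
+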
+
+if __name__ == "__main__":
+    group = sys.argv.pop(1)
+    args = argbind.parse_args(group=group)
+
+    with argbind.scope(args):
+        run(group)
diff --git a/src/modules/dac/compare/__init__.py b/src/modules/dac/compare/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/modules/dac/compare/encodec.py b/src/modules/dac/compare/encodec.py
new file mode 100644
index 0000000000000000000000000000000000000000..42877de3cffa7d681b28266e4e1f537d48b749eb
--- /dev/null
+++ b/src/modules/dac/compare/encodec.py
@@ -0,0 +1,54 @@
+import torch
+from audiotools import AudioSignal
+from audiotools.ml import BaseModel
+from encodec import EncodecModel
+
+
+class Encodec(BaseModel):
+    def __init__(self, sample_rate: int = 24000, bandwidth: float = 24.0):
+        super().__init__()
+
+        if sample_rate == 24000:
+            self.model = EncodecModel.encodec_model_24khz()
+        else:
+            self.model = EncodecModel.encodec_model_48khz()
+        self.model.set_target_bandwidth(bandwidth)
+        self.sample_rate = 44100
+
+    def forward(
+        self,
+        audio_data: torch.Tensor,
+        sample_rate: int = 44100,
+        n_quantizers: int = None,
+    ):
+        signal = AudioSignal(audio_data, sample_rate)
+        signal.resample(self.model.sample_rate)
+        recons = self.model(signal.audio_data)
+        recons = AudioSignal(recons, self.model.sample_rate)
+        recons.resample(sample_rate)
+        return {"audio": recons.audio_data}
+
+
+if __name__ == "__main__":
+    import numpy as np
+    from functools import partial
+
+    model = Encodec()
+
+    for n, m in model.named_modules():
+        o = m.extra_repr()
+        p = sum([np.prod(p.size()) for p in m.parameters()])
+        fn = lambda o, p: o + f" {p/1e6:<.3f}M params."
+        setattr(m, "extra_repr", partial(fn, o=o, p=p))
+    print(model)
+    print("Total # of params: ", sum([np.prod(p.size()) for p in model.parameters()]))
+
+    length = 88200 * 2
+    x = torch.randn(1, 1, length).to(model.device)
+    x.requires_grad_(True)
+    x.retain_grad()
+
+    # Make a forward pass
+    out = model(x)["audio"]
+
+    print(x.shape, out.shape)
diff --git a/src/modules/dac/model/__init__.py b/src/modules/dac/model/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..02a75b7ad6028f5c41b6a8285b0257d4c23bdfcf
--- /dev/null
+++ b/src/modules/dac/model/__init__.py
@@ -0,0 +1,4 @@
+from .base import CodecMixin
+from .base import DACFile
+from .dac import DAC
+from .discriminator import Discriminator
diff --git a/src/modules/dac/model/base.py b/src/modules/dac/model/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..546b3cb7092d6bd1837ec780228d2a5b3e01fe8d
--- /dev/null
+++ b/src/modules/dac/model/base.py
@@ -0,0 +1,294 @@
+import math
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Union
+
+import numpy as np
+import torch
+import tqdm
+from audiotools import AudioSignal
+from torch import nn
+
+SUPPORTED_VERSIONS = ["1.0.0"]
+
+
+@dataclass
+class DACFile:
+    codes: torch.Tensor
+
+    # Metadata
+    chunk_length: int
+    original_length: int
+    input_db: float
+    channels: int
+    sample_rate: int
+    padding: bool
+    dac_version: str
+
+    def save(self, path):
+        artifacts = {
+            "codes": self.codes.numpy().astype(np.uint16),
+            "metadata": {
+                "input_db": self.input_db.numpy().astype(np.float32),
+                "original_length": self.original_length,
+                "sample_rate": self.sample_rate,
+                "chunk_length": self.chunk_length,
+                "channels": self.channels,
+                "padding": self.padding,
+                "dac_version": SUPPORTED_VERSIONS[-1],
+            },
+        }
+        path = Path(path).with_suffix(".dac")
+        with open(path, "wb") as f:
+            np.save(f, artifacts)
+        return path
+
+    @classmethod
+    def load(cls, path):
+        artifacts = np.load(path, allow_pickle=True)[()]
+        codes = torch.from_numpy(artifacts["codes"].astype(int))
+        if artifacts["metadata"].get("dac_version", None) not in SUPPORTED_VERSIONS:
+            raise RuntimeError(
+                f"Given file {path} can't be loaded with this version of descript-audio-codec."
+            )
+        return cls(codes=codes, **artifacts["metadata"])
+
+
+class CodecMixin:
+    @property
+    def padding(self):
+        if not hasattr(self, "_padding"):
+            self._padding = True
+        return self._padding
+
+    @padding.setter
+    def padding(self, value):
+        assert isinstance(value, bool)
+
+        layers = [
+            l for l in self.modules() if isinstance(l, (nn.Conv1d, nn.ConvTranspose1d))
+        ]
+
+        for layer in layers:
+            if value:
+                if hasattr(layer, "original_padding"):
+                    layer.padding = layer.original_padding
+            else:
+                layer.original_padding = layer.padding
+                layer.padding = tuple(0 for _ in range(len(layer.padding)))
+
+        self._padding = value
+
+    def get_delay(self):
+        # Any number works here, delay is invariant to input length
+        l_out = self.get_output_length(0)
+        L = l_out
+
+        layers = []
+        for layer in self.modules():
+            if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)):
+                layers.append(layer)
+
+        for layer in reversed(layers):
+            d = layer.dilation[0]
+            k = layer.kernel_size[0]
+            s = layer.stride[0]
+
+            if isinstance(layer, nn.ConvTranspose1d):
+                L = ((L - d * (k - 1) - 1) / s) + 1
+            elif isinstance(layer, nn.Conv1d):
+                L = (L - 1) * s + d * (k - 1) + 1
+
+            L = math.ceil(L)
+
+        l_in = L
+
+        return (l_in - l_out) // 2
+
+    def get_output_length(self, input_length):
+        L = input_length
+        # Calculate output length
+        for layer in self.modules():
+            if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)):
+                d = layer.dilation[0]
+                k = layer.kernel_size[0]
+                s = layer.stride[0]
+
+                if isinstance(layer, nn.Conv1d):
+                    L = ((L - d * (k - 1) - 1) / s) + 1
+                elif isinstance(layer, nn.ConvTranspose1d):
+                    L = (L - 1) * s + d * (k - 1) + 1
+
+                L = math.floor(L)
+        return L
+
+    @torch.no_grad()
+    def compress(
+        self,
+        audio_path_or_signal: Union[str, Path, AudioSignal],
+        win_duration: float = 1.0,
+        verbose: bool = False,
+        normalize_db: float = -16,
+        n_quantizers: int = None,
+    ) -> DACFile:
+        """Processes an audio signal from a file or AudioSignal object into
+        discrete codes. This function processes the signal in short windows,
+        using constant GPU memory.
+
+        Parameters
+        ----------
+        audio_path_or_signal : Union[str, Path, AudioSignal]
+            audio signal to reconstruct
+        win_duration : float, optional
+            window duration in seconds, by default 1.0
+        verbose : bool, optional
+            by default False
+        normalize_db : float, optional
+            normalize db, by default -16
+
+        Returns
+        -------
+        DACFile
+            Object containing compressed codes and metadata
+            required for decompression
+        """
+        audio_signal = audio_path_or_signal
+        if isinstance(audio_signal, (str, Path)):
+            audio_signal = AudioSignal.load_from_file_with_ffmpeg(str(audio_signal))
+
+        self.eval()
+        original_padding = self.padding
+        original_device = audio_signal.device
+
+        audio_signal = audio_signal.clone()
+        original_sr = audio_signal.sample_rate
+
+        resample_fn = audio_signal.resample
+        loudness_fn = audio_signal.loudness
+
+        # If audio is at least 10 hours long, use the ffmpeg versions
+        if audio_signal.signal_duration >= 10 * 60 * 60:
+            resample_fn = audio_signal.ffmpeg_resample
+            loudness_fn = audio_signal.ffmpeg_loudness
+
+        original_length = audio_signal.signal_length
+        resample_fn(self.sample_rate)
+        input_db = loudness_fn()
+
+        if normalize_db is not None:
+            audio_signal.normalize(normalize_db)
+        audio_signal.ensure_max_of_audio()
+
+        nb, nac, nt = audio_signal.audio_data.shape
+        audio_signal.audio_data = audio_signal.audio_data.reshape(nb * nac, 1, nt)
+        win_duration = (
+            audio_signal.signal_duration if win_duration is None else win_duration
+        )
+
+        if audio_signal.signal_duration <= win_duration:
+            # Unchunked compression (used if signal length < win duration)
+            self.padding = True
+            n_samples = nt
+            hop = nt
+        else:
+            # Chunked inference
+            self.padding = False
+            # Zero-pad signal on either side by the delay
+            audio_signal.zero_pad(self.delay, self.delay)
+            n_samples = int(win_duration * self.sample_rate)
+            # Round n_samples to nearest hop length multiple
+            n_samples = int(math.ceil(n_samples / self.hop_length) * self.hop_length)
+            hop = self.get_output_length(n_samples)
+
+        codes = []
+        range_fn = range if not verbose else tqdm.trange
+
+        for i in range_fn(0, nt, hop):
+            x = audio_signal[..., i : i + n_samples]
+            x = x.zero_pad(0, max(0, n_samples - x.shape[-1]))
+
+            audio_data = x.audio_data.to(self.device)
+            audio_data = self.preprocess(audio_data, self.sample_rate)
+            _, c, _, _, _ = self.encode(audio_data, n_quantizers)
+            codes.append(c.to(original_device))
+            chunk_length = c.shape[-1]
+
+        codes = torch.cat(codes, dim=-1)
+
+        dac_file = DACFile(
+            codes=codes,
+            chunk_length=chunk_length,
+            original_length=original_length,
+            input_db=input_db,
+            channels=nac,
+            sample_rate=original_sr,
+            padding=self.padding,
+            dac_version=SUPPORTED_VERSIONS[-1],
+        )
+
+        if n_quantizers is not None:
+            codes = codes[:, :n_quantizers, :]
+
+        self.padding = original_padding
+        return dac_file
+
+    @torch.no_grad()
+    def decompress(
+        self,
+        obj: Union[str, Path, DACFile],
+        verbose: bool = False,
+    ) -> AudioSignal:
+        """Reconstruct audio from a given .dac file
+
+        Parameters
+        ----------
+        obj : Union[str, Path, DACFile]
+            .dac file location or corresponding DACFile object.
+        verbose : bool, optional
+            Prints progress if True, by default False
+
+        Returns
+        -------
+        AudioSignal
+            Object with the reconstructed audio
+        """
+        self.eval()
+        if isinstance(obj, (str, Path)):
+            obj = DACFile.load(obj)
+
+        original_padding = self.padding
+        self.padding = obj.padding
+
+        range_fn = range if not verbose else tqdm.trange
+        codes = obj.codes
+        original_device = codes.device
+        chunk_length = obj.chunk_length
+        recons = []
+
+        for i in range_fn(0, codes.shape[-1], chunk_length):
+            c = codes[..., i : i + chunk_length].to(self.device)
+            z = self.quantizer.from_codes(c)[0]
+            r = self.decode(z)
+            recons.append(r.to(original_device))
+
+        recons = torch.cat(recons, dim=-1)
+        recons = AudioSignal(recons, self.sample_rate)
+
+        resample_fn = recons.resample
+        loudness_fn = recons.loudness
+
+        # If audio is at least 10 hours long, use the ffmpeg versions
+        if recons.signal_duration >= 10 * 60 * 60:
+            resample_fn = recons.ffmpeg_resample
+            loudness_fn = recons.ffmpeg_loudness
+
+        recons.normalize(obj.input_db)
+        resample_fn(obj.sample_rate)
+        recons = recons[..., : obj.original_length]
+        loudness_fn()
+        recons.audio_data = recons.audio_data.reshape(
+            -1, obj.channels, obj.original_length
+        )
+
+        self.padding = original_padding
+        return recons
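+
+
+# Round-trip sketch (illustrative; the file names, window duration, and the `model`
+# instance are assumptions -- any codec mixing in CodecMixin, e.g. a loaded DAC, works):
+#
+#   signal = AudioSignal('input.wav')
+#   dac_file = model.compress(signal, win_duration=5.0, verbose=True)
+#   dac_file.save('compressed.dac')
+#   recon = model.decompress('compressed.dac')   # AudioSignal with reconstructed audio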
diff --git a/src/modules/dac/model/dac.py b/src/modules/dac/model/dac.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d44a18a9e98fdcce9377a744b6f9d7dfa6a607b
--- /dev/null
+++ b/src/modules/dac/model/dac.py
@@ -0,0 +1,364 @@
+import math
+from typing import List
+from typing import Union
+
+import numpy as np
+import torch
+from audiotools import AudioSignal
+from audiotools.ml import BaseModel
+from torch import nn
+
+from .base import CodecMixin
+from ..nn.layers import Snake1d
+from ..nn.layers import WNConv1d
+from ..nn.layers import WNConvTranspose1d
+from ..nn.quantize import ResidualVectorQuantize
+
+
+def init_weights(m):
+    if isinstance(m, nn.Conv1d):
+        nn.init.trunc_normal_(m.weight, std=0.02)
+        nn.init.constant_(m.bias, 0)
+
+
+class ResidualUnit(nn.Module):
+    def __init__(self, dim: int = 16, dilation: int = 1):
+        super().__init__()
+        pad = ((7 - 1) * dilation) // 2
+        self.block = nn.Sequential(
+            Snake1d(dim),
+            WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad),
+            Snake1d(dim),
+            WNConv1d(dim, dim, kernel_size=1),
+        )
+
+    def forward(self, x):
+        y = self.block(x)
+        pad = (x.shape[-1] - y.shape[-1]) // 2
+        if pad > 0:
+            x = x[..., pad:-pad]
+        return x + y
+
+
+class EncoderBlock(nn.Module):
+    def __init__(self, dim: int = 16, stride: int = 1):
+        super().__init__()
+        self.block = nn.Sequential(
+            ResidualUnit(dim // 2, dilation=1),
+            ResidualUnit(dim // 2, dilation=3),
+            ResidualUnit(dim // 2, dilation=9),
+            Snake1d(dim // 2),
+            WNConv1d(
+                dim // 2,
+                dim,
+                kernel_size=2 * stride,
+                stride=stride,
+                padding=math.ceil(stride / 2),
+            ),
+        )
+
+    def forward(self, x):
+        return self.block(x)
+
+
+class Encoder(nn.Module):
+    def __init__(
+        self,
+        d_model: int = 64,
+        strides: list = [2, 4, 8, 8],
+        d_latent: int = 64,
+    ):
+        super().__init__()
+        # Create first convolution
+        self.block = [WNConv1d(1, d_model, kernel_size=7, padding=3)]
+
+        # Create EncoderBlocks that double channels as they downsample by `stride`
+        for stride in strides:
+            d_model *= 2
+            self.block += [EncoderBlock(d_model, stride=stride)]
+
+        # Create last convolution
+        self.block += [
+            Snake1d(d_model),
+            WNConv1d(d_model, d_latent, kernel_size=3, padding=1),
+        ]
+
+        # Wrap block into nn.Sequential
+        self.block = nn.Sequential(*self.block)
+        self.enc_dim = d_model
+
+    def forward(self, x):
+        return self.block(x)
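+
+    # Editorial note: with the default strides [2, 4, 8, 8] the encoder downsamples
+    # by 2 * 4 * 8 * 8 = 512, which is why DAC sets hop_length = np.prod(encoder_rates)
+    # below; e.g. one second of 44.1 kHz audio maps to ceil(44100 / 512) = 87 latent frames.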
+
+
+class DecoderBlock(nn.Module):
+    def __init__(self, input_dim: int = 16, output_dim: int = 8, stride: int = 1):
+        super().__init__()
+        self.block = nn.Sequential(
+            Snake1d(input_dim),
+            WNConvTranspose1d(
+                input_dim,
+                output_dim,
+                kernel_size=2 * stride,
+                stride=stride,
+                padding=math.ceil(stride / 2),
+            ),
+            ResidualUnit(output_dim, dilation=1),
+            ResidualUnit(output_dim, dilation=3),
+            ResidualUnit(output_dim, dilation=9),
+        )
+
+    def forward(self, x):
+        return self.block(x)
+
+
+class Decoder(nn.Module):
+    def __init__(
+        self,
+        input_channel,
+        channels,
+        rates,
+        d_out: int = 1,
+    ):
+        super().__init__()
+
+        # Add first conv layer
+        layers = [WNConv1d(input_channel, channels, kernel_size=7, padding=3)]
+
+        # Add upsampling + MRF blocks
+        for i, stride in enumerate(rates):
+            input_dim = channels // 2**i
+            output_dim = channels // 2 ** (i + 1)
+            layers += [DecoderBlock(input_dim, output_dim, stride)]
+
+        # Add final conv layer
+        layers += [
+            Snake1d(output_dim),
+            WNConv1d(output_dim, d_out, kernel_size=7, padding=3),
+            nn.Tanh(),
+        ]
+
+        self.model = nn.Sequential(*layers)
+
+    def forward(self, x):
+        return self.model(x)
+
+
+class DAC(BaseModel, CodecMixin):
+    def __init__(
+        self,
+        encoder_dim: int = 64,
+        encoder_rates: List[int] = [2, 4, 8, 8],
+        latent_dim: int = None,
+        decoder_dim: int = 1536,
+        decoder_rates: List[int] = [8, 8, 4, 2],
+        n_codebooks: int = 9,
+        codebook_size: int = 1024,
+        codebook_dim: Union[int, list] = 8,
+        quantizer_dropout: bool = False,
+        sample_rate: int = 44100,
+    ):
+        super().__init__()
+
+        self.encoder_dim = encoder_dim
+        self.encoder_rates = encoder_rates
+        self.decoder_dim = decoder_dim
+        self.decoder_rates = decoder_rates
+        self.sample_rate = sample_rate
+
+        if latent_dim is None:
+            latent_dim = encoder_dim * (2 ** len(encoder_rates))
+
+        self.latent_dim = latent_dim
+
+        self.hop_length = np.prod(encoder_rates)
+        self.encoder = Encoder(encoder_dim, encoder_rates, latent_dim)
+
+        self.n_codebooks = n_codebooks
+        self.codebook_size = codebook_size
+        self.codebook_dim = codebook_dim
+        self.quantizer = ResidualVectorQuantize(
+            input_dim=latent_dim,
+            n_codebooks=n_codebooks,
+            codebook_size=codebook_size,
+            codebook_dim=codebook_dim,
+            quantizer_dropout=quantizer_dropout,
+        )
+
+        self.decoder = Decoder(
+            latent_dim,
+            decoder_dim,
+            decoder_rates,
+        )
+        self.sample_rate = sample_rate
+        self.apply(init_weights)
+
+        self.delay = self.get_delay()
+
+    def preprocess(self, audio_data, sample_rate):
+        if sample_rate is None:
+            sample_rate = self.sample_rate
+        assert sample_rate == self.sample_rate
+
+        length = audio_data.shape[-1]
+        right_pad = math.ceil(length / self.hop_length) * self.hop_length - length
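+        # e.g. (illustrative): length = 44100 and hop_length = 512 give
+        # right_pad = 87 * 512 - 44100 = 444, i.e. the input is padded to 44544 samples.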
+        audio_data = nn.functional.pad(audio_data, (0, right_pad))
+
+        return audio_data
+
+    def encode(
+        self,
+        audio_data: torch.Tensor,
+        n_quantizers: int = None,
+    ):
+        """Encode given audio data and return quantized latent codes
+
+        Parameters
+        ----------
+        audio_data : Tensor[B x 1 x T]
+            Audio data to encode
+        n_quantizers : int, optional
+            Number of quantizers to use, by default None
+            If None, all quantizers are used.
+
+        Returns
+        -------
+        tuple
+            A tuple ``(z, codes, latents, commitment_loss, codebook_loss)``:
+            z : Tensor[B x D x T]
+                Quantized continuous representation of input
+            codes : Tensor[B x N x T]
+                Codebook indices for each codebook
+                (quantized discrete representation of input)
+            latents : Tensor[B x N*D x T]
+                Projected latents (continuous representation of input before quantization)
+            commitment_loss : Tensor[1]
+                Commitment loss to train encoder to predict vectors closer to codebook
+                entries
+            codebook_loss : Tensor[1]
+                Codebook loss to update the codebook
+        """
+        z = self.encoder(audio_data)
+        z, codes, latents, commitment_loss, codebook_loss = self.quantizer(
+            z, n_quantizers
+        )
+        return z, codes, latents, commitment_loss, codebook_loss
+
+    def decode(self, z: torch.Tensor):
+        """Decode given latent codes and return audio data
+
+        Parameters
+        ----------
+        z : Tensor[B x D x T]
+            Quantized continuous representation of input
+
+        Returns
+        -------
+        Tensor[B x 1 x T]
+            Decoded audio data.
+        """
+        return self.decoder(z)
+
+    def forward(
+        self,
+        audio_data: torch.Tensor,
+        sample_rate: int = None,
+        n_quantizers: int = None,
+    ):
+        """Model forward pass
+
+        Parameters
+        ----------
+        audio_data : Tensor[B x 1 x T]
+            Audio data to encode
+        sample_rate : int, optional
+            Sample rate of audio data in Hz, by default None
+            If None, defaults to `self.sample_rate`
+        n_quantizers : int, optional
+            Number of quantizers to use, by default None.
+            If None, all quantizers are used.
+
+        Returns
+        -------
+        dict
+            A dictionary with the following keys:
+            "z" : Tensor[B x D x T]
+                Quantized continuous representation of input
+            "codes" : Tensor[B x N x T]
+                Codebook indices for each codebook
+                (quantized discrete representation of input)
+            "latents" : Tensor[B x N*D x T]
+                Projected latents (continuous representation of input before quantization)
+            "vq/commitment_loss" : Tensor[1]
+                Commitment loss to train encoder to predict vectors closer to codebook
+                entries
+            "vq/codebook_loss" : Tensor[1]
+                Codebook loss to update the codebook
+            "length" : int
+                Number of samples in input audio
+            "audio" : Tensor[B x 1 x length]
+                Decoded audio data.
+        """
+        length = audio_data.shape[-1]
+        audio_data = self.preprocess(audio_data, sample_rate)
+        z, codes, latents, commitment_loss, codebook_loss = self.encode(
+            audio_data, n_quantizers
+        )
+
+        x = self.decode(z)
+        return {
+            "audio": x[..., :length],
+            "z": z,
+            "codes": codes,
+            "latents": latents,
+            "vq/commitment_loss": commitment_loss,
+            "vq/codebook_loss": codebook_loss,
+        }
+
+
+if __name__ == "__main__":
+    import numpy as np
+    from functools import partial
+
+    model = DAC().to("cpu")
+
+    for n, m in model.named_modules():
+        o = m.extra_repr()
+        p = sum([np.prod(p.size()) for p in m.parameters()])
+        fn = lambda o, p: o + f" {p/1e6:<.3f}M params."
+        setattr(m, "extra_repr", partial(fn, o=o, p=p))
+    print(model)
+    print("Total # of params: ", sum([np.prod(p.size()) for p in model.parameters()]))
+
+    length = 88200 * 2
+    x = torch.randn(1, 1, length).to(model.device)
+    x.requires_grad_(True)
+    x.retain_grad()
+
+    # Make a forward pass
+    out = model(x)["audio"]
+    print("Input shape:", x.shape)
+    print("Output shape:", out.shape)
+
+    # Create gradient variable
+    grad = torch.zeros_like(out)
+    grad[:, :, grad.shape[-1] // 2] = 1
+
+    # Make a backward pass
+    out.backward(grad)
+
+    # Check non-zero values
+    gradmap = x.grad.squeeze(0)
+    gradmap = (gradmap != 0).sum(0)  # sum across features
+    rf = (gradmap != 0).sum()
+
+    print(f"Receptive field: {rf.item()}")
+
+    x = AudioSignal(torch.randn(1, 1, 44100 * 60), 44100)
+    model.decompress(model.compress(x, verbose=True), verbose=True)
diff --git a/src/modules/dac/model/discriminator.py b/src/modules/dac/model/discriminator.py
new file mode 100644
index 0000000000000000000000000000000000000000..09c79d1342ca46bef21daca64667577f05e61638
--- /dev/null
+++ b/src/modules/dac/model/discriminator.py
@@ -0,0 +1,228 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from audiotools import AudioSignal
+from audiotools import ml
+from audiotools import STFTParams
+from einops import rearrange
+from torch.nn.utils import weight_norm
+
+
+def WNConv1d(*args, **kwargs):
+    act = kwargs.pop("act", True)
+    conv = weight_norm(nn.Conv1d(*args, **kwargs))
+    if not act:
+        return conv
+    return nn.Sequential(conv, nn.LeakyReLU(0.1))
+
+
+def WNConv2d(*args, **kwargs):
+    act = kwargs.pop("act", True)
+    conv = weight_norm(nn.Conv2d(*args, **kwargs))
+    if not act:
+        return conv
+    return nn.Sequential(conv, nn.LeakyReLU(0.1))
+
+
+class MPD(nn.Module):
+    def __init__(self, period):
+        super().__init__()
+        self.period = period
+        self.convs = nn.ModuleList(
+            [
+                WNConv2d(1, 32, (5, 1), (3, 1), padding=(2, 0)),
+                WNConv2d(32, 128, (5, 1), (3, 1), padding=(2, 0)),
+                WNConv2d(128, 512, (5, 1), (3, 1), padding=(2, 0)),
+                WNConv2d(512, 1024, (5, 1), (3, 1), padding=(2, 0)),
+                WNConv2d(1024, 1024, (5, 1), 1, padding=(2, 0)),
+            ]
+        )
+        self.conv_post = WNConv2d(
+            1024, 1, kernel_size=(3, 1), padding=(1, 0), act=False
+        )
+
+    def pad_to_period(self, x):
+        t = x.shape[-1]
+        x = F.pad(x, (0, self.period - t % self.period), mode="reflect")
+        return x
+
+    def forward(self, x):
+        fmap = []
+
+        x = self.pad_to_period(x)
+        x = rearrange(x, "b c (l p) -> b c l p", p=self.period)
+
+        for layer in self.convs:
+            x = layer(x)
+            fmap.append(x)
+
+        x = self.conv_post(x)
+        fmap.append(x)
+
+        return fmap
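+
+# Editorial note: MPD folds the 1-D waveform into a 2-D view so that its (5, 1)
+# kernels slide over samples spaced `period` apart, i.e. a (B, 1, T) input becomes
+# (B, 1, L, period) with L * period equal to the padded length from pad_to_period.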
+
+
+class MSD(nn.Module):
+    def __init__(self, rate: int = 1, sample_rate: int = 44100):
+        super().__init__()
+        self.convs = nn.ModuleList(
+            [
+                WNConv1d(1, 16, 15, 1, padding=7),
+                WNConv1d(16, 64, 41, 4, groups=4, padding=20),
+                WNConv1d(64, 256, 41, 4, groups=16, padding=20),
+                WNConv1d(256, 1024, 41, 4, groups=64, padding=20),
+                WNConv1d(1024, 1024, 41, 4, groups=256, padding=20),
+                WNConv1d(1024, 1024, 5, 1, padding=2),
+            ]
+        )
+        self.conv_post = WNConv1d(1024, 1, 3, 1, padding=1, act=False)
+        self.sample_rate = sample_rate
+        self.rate = rate
+
+    def forward(self, x):
+        x = AudioSignal(x, self.sample_rate)
+        x.resample(self.sample_rate // self.rate)
+        x = x.audio_data
+
+        fmap = []
+
+        for l in self.convs:
+            x = l(x)
+            fmap.append(x)
+        x = self.conv_post(x)
+        fmap.append(x)
+
+        return fmap
+
+
+BANDS = [(0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)]
+
+
+class MRD(nn.Module):
+    def __init__(
+        self,
+        window_length: int,
+        hop_factor: float = 0.25,
+        sample_rate: int = 44100,
+        bands: list = BANDS,
+    ):
+        """Complex multi-band spectrogram discriminator.
+        Parameters
+        ----------
+        window_length : int
+            Window length of STFT.
+        hop_factor : float, optional
+            Hop factor of the STFT; the hop length is ``hop_factor * window_length``. By default 0.25.
+        sample_rate : int, optional
+            Sampling rate of audio in Hz, by default 44100
+        bands : list, optional
+            Bands to run discriminator over.
+        """
+        super().__init__()
+
+        self.window_length = window_length
+        self.hop_factor = hop_factor
+        self.sample_rate = sample_rate
+        self.stft_params = STFTParams(
+            window_length=window_length,
+            hop_length=int(window_length * hop_factor),
+            match_stride=True,
+        )
+
+        n_fft = window_length // 2 + 1
+        bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands]
+        self.bands = bands
+
+        ch = 32
+        convs = lambda: nn.ModuleList(
+            [
+                WNConv2d(2, ch, (3, 9), (1, 1), padding=(1, 4)),
+                WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)),
+                WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)),
+                WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)),
+                WNConv2d(ch, ch, (3, 3), (1, 1), padding=(1, 1)),
+            ]
+        )
+        self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))])
+        self.conv_post = WNConv2d(ch, 1, (3, 3), (1, 1), padding=(1, 1), act=False)
+
+    def spectrogram(self, x):
+        x = AudioSignal(x, self.sample_rate, stft_params=self.stft_params)
+        x = torch.view_as_real(x.stft())
+        x = rearrange(x, "b 1 f t c -> (b 1) c t f")
+        # Split into bands
+        x_bands = [x[..., b[0] : b[1]] for b in self.bands]
+        return x_bands
+
+    def forward(self, x):
+        x_bands = self.spectrogram(x)
+        fmap = []
+
+        x = []
+        for band, stack in zip(x_bands, self.band_convs):
+            for layer in stack:
+                band = layer(band)
+                fmap.append(band)
+            x.append(band)
+
+        x = torch.cat(x, dim=-1)
+        x = self.conv_post(x)
+        fmap.append(x)
+
+        return fmap
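+
+# Editorial note: for window_length = 2048 there are 2048 // 2 + 1 = 1025 frequency
+# bins, so the default BANDS correspond roughly to bins [0, 102), [102, 256),
+# [256, 512), [512, 768) and [768, 1025); each band is run through its own conv
+# stack and the outputs are concatenated along the frequency axis before conv_post.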
+
+
+class Discriminator(ml.BaseModel):
+    def __init__(
+        self,
+        rates: list = [],
+        periods: list = [2, 3, 5, 7, 11],
+        fft_sizes: list = [2048, 1024, 512],
+        sample_rate: int = 44100,
+        bands: list = BANDS,
+    ):
+        """Discriminator that combines multiple discriminators.
+
+        Parameters
+        ----------
+        rates : list, optional
+            sampling rates (in Hz) to run MSD at, by default []
+            If empty, MSD is not used.
+        periods : list, optional
+            periods (of samples) to run MPD at, by default [2, 3, 5, 7, 11]
+        fft_sizes : list, optional
+            Window sizes of the FFT to run MRD at, by default [2048, 1024, 512]
+        sample_rate : int, optional
+            Sampling rate of audio in Hz, by default 44100
+        bands : list, optional
+            Bands to run MRD at, by default `BANDS`
+        """
+        super().__init__()
+        discs = []
+        discs += [MPD(p) for p in periods]
+        discs += [MSD(r, sample_rate=sample_rate) for r in rates]
+        discs += [MRD(f, sample_rate=sample_rate, bands=bands) for f in fft_sizes]
+        self.discriminators = nn.ModuleList(discs)
+
+    def preprocess(self, y):
+        # Remove DC offset
+        y = y - y.mean(dim=-1, keepdims=True)
+        # Peak normalize the volume of input audio
+        y = 0.8 * y / (y.abs().max(dim=-1, keepdim=True)[0] + 1e-9)
+        return y
+
+    def forward(self, x):
+        x = self.preprocess(x)
+        fmaps = [d(x) for d in self.discriminators]
+        return fmaps
+
+
+if __name__ == "__main__":
+    disc = Discriminator()
+    x = torch.zeros(1, 1, 44100)
+    results = disc(x)
+    for i, result in enumerate(results):
+        print(f"disc{i}")
+        for r in result:
+            print(r.shape, r.mean(), r.min(), r.max())
+        print()
diff --git a/src/modules/dac/nn/__init__.py b/src/modules/dac/nn/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..6718c8b1a3d36c31655b030f4c515a144cde4db7
--- /dev/null
+++ b/src/modules/dac/nn/__init__.py
@@ -0,0 +1,3 @@
+from . import layers
+from . import loss
+from . import quantize
diff --git a/src/modules/dac/nn/layers.py b/src/modules/dac/nn/layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..44fbc2929715e11d843b24195d7042a528969a94
--- /dev/null
+++ b/src/modules/dac/nn/layers.py
@@ -0,0 +1,33 @@
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange
+from torch.nn.utils import weight_norm
+
+
+def WNConv1d(*args, **kwargs):
+    return weight_norm(nn.Conv1d(*args, **kwargs))
+
+
+def WNConvTranspose1d(*args, **kwargs):
+    return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
+
+
+# Scripting this brings model speed up 1.4x
+@torch.jit.script
+def snake(x, alpha):
+    shape = x.shape
+    x = x.reshape(shape[0], shape[1], -1)
+    x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
+    x = x.reshape(shape)
+    return x
+
+
+class Snake1d(nn.Module):
+    def __init__(self, channels):
+        super().__init__()
+        self.alpha = nn.Parameter(torch.ones(1, channels, 1))
+
+    def forward(self, x):
+        return snake(x, self.alpha)
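+
+
+# Minimal usage sketch (editorial addition, not in the upstream file): the snake
+# activation is shape-preserving and, with alpha = 1, reduces to roughly x + sin(x)^2.
+if __name__ == "__main__":
+    layer = Snake1d(channels=8)
+    x = torch.randn(2, 8, 100)
+    y = layer(x)
+    assert y.shape == x.shape
+    print(y.shape)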
diff --git a/src/modules/dac/nn/loss.py b/src/modules/dac/nn/loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..9bb3dd6ce08a7a24f18f941eeb5b68fe9461e86b
--- /dev/null
+++ b/src/modules/dac/nn/loss.py
@@ -0,0 +1,368 @@
+import typing
+from typing import List
+
+import torch
+import torch.nn.functional as F
+from audiotools import AudioSignal
+from audiotools import STFTParams
+from torch import nn
+
+
+class L1Loss(nn.L1Loss):
+    """L1 Loss between AudioSignals. Defaults
+    to comparing ``audio_data``, but any
+    attribute of an AudioSignal can be used.
+
+    Parameters
+    ----------
+    attribute : str, optional
+        Attribute of signal to compare, defaults to ``audio_data``.
+    weight : float, optional
+        Weight of this loss, defaults to 1.0.
+
+    Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/distance.py
+    """
+
+    def __init__(self, attribute: str = "audio_data", weight: float = 1.0, **kwargs):
+        self.attribute = attribute
+        self.weight = weight
+        super().__init__(**kwargs)
+
+    def forward(self, x: AudioSignal, y: AudioSignal):
+        """
+        Parameters
+        ----------
+        x : AudioSignal
+            Estimate AudioSignal
+        y : AudioSignal
+            Reference AudioSignal
+
+        Returns
+        -------
+        torch.Tensor
+            L1 loss between AudioSignal attributes.
+        """
+        if isinstance(x, AudioSignal):
+            x = getattr(x, self.attribute)
+            y = getattr(y, self.attribute)
+        return super().forward(x, y)
+
+
+class SISDRLoss(nn.Module):
+    """
+    Computes the Scale-Invariant Source-to-Distortion Ratio between a batch
+    of estimated and reference audio signals or aligned features.
+
+    Parameters
+    ----------
+    scaling : int, optional
+        Whether to use scale-invariant (True) or
+        signal-to-noise ratio (False), by default True
+    reduction : str, optional
+        How to reduce across the batch (either 'mean',
+        'sum', or 'none' for no reduction), by default 'mean'
+    zero_mean : int, optional
+        Zero mean the references and estimates before
+        computing the loss, by default True
+    clip_min : int, optional
+        The minimum possible loss value. Helps network
+        to not focus on making already good examples better, by default None
+    weight : float, optional
+        Weight of this loss, defaults to 1.0.
+
+    Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/distance.py
+    """
+
+    def __init__(
+        self,
+        scaling: int = True,
+        reduction: str = "mean",
+        zero_mean: int = True,
+        clip_min: int = None,
+        weight: float = 1.0,
+    ):
+        self.scaling = scaling
+        self.reduction = reduction
+        self.zero_mean = zero_mean
+        self.clip_min = clip_min
+        self.weight = weight
+        super().__init__()
+
+    def forward(self, x: AudioSignal, y: AudioSignal):
+        eps = 1e-8
+        # nb, nc, nt
+        if isinstance(x, AudioSignal):
+            references = x.audio_data
+            estimates = y.audio_data
+        else:
+            references = x
+            estimates = y
+
+        nb = references.shape[0]
+        references = references.reshape(nb, 1, -1).permute(0, 2, 1)
+        estimates = estimates.reshape(nb, 1, -1).permute(0, 2, 1)
+
+        # samples now on axis 1
+        if self.zero_mean:
+            mean_reference = references.mean(dim=1, keepdim=True)
+            mean_estimate = estimates.mean(dim=1, keepdim=True)
+        else:
+            mean_reference = 0
+            mean_estimate = 0
+
+        _references = references - mean_reference
+        _estimates = estimates - mean_estimate
+
+        references_projection = (_references**2).sum(dim=-2) + eps
+        references_on_estimates = (_estimates * _references).sum(dim=-2) + eps
+
+        scale = (
+            (references_on_estimates / references_projection).unsqueeze(1)
+            if self.scaling
+            else 1
+        )
+
+        e_true = scale * _references
+        e_res = _estimates - e_true
+
+        signal = (e_true**2).sum(dim=1)
+        noise = (e_res**2).sum(dim=1)
+        sdr = -10 * torch.log10(signal / noise + eps)
+
+        if self.clip_min is not None:
+            sdr = torch.clamp(sdr, min=self.clip_min)
+
+        if self.reduction == "mean":
+            sdr = sdr.mean()
+        elif self.reduction == "sum":
+            sdr = sdr.sum()
+        return sdr
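+
+    # Editorial note: the value returned above is the negated SI-SDR in dB, so lower is
+    # better and it can be minimized directly. Either AudioSignal pairs or raw tensors
+    # with a leading batch dimension may be passed, e.g. (sketch):
+    #   loss = SISDRLoss()(AudioSignal(estimate, 44100), AudioSignal(reference, 44100))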
+
+
+class MultiScaleSTFTLoss(nn.Module):
+    """Computes the multi-scale STFT loss from [1].
+
+    Parameters
+    ----------
+    window_lengths : List[int], optional
+        Length of each window of each STFT, by default [2048, 512]
+    loss_fn : typing.Callable, optional
+        How to compare each loss, by default nn.L1Loss()
+    clamp_eps : float, optional
+        Lower bound used to clamp the magnitude before taking the log, by default 1e-5
+    mag_weight : float, optional
+        Weight of raw magnitude portion of loss, by default 1.0
+    log_weight : float, optional
+        Weight of log magnitude portion of loss, by default 1.0
+    pow : float, optional
+        Power to raise magnitude to before taking log, by default 2.0
+    weight : float, optional
+        Weight of this loss, by default 1.0
+    match_stride : bool, optional
+        Whether to match the stride of convolutional layers, by default False
+
+    References
+    ----------
+
+    1.  Engel, Jesse, Chenjie Gu, and Adam Roberts.
+        "DDSP: Differentiable Digital Signal Processing."
+        International Conference on Learning Representations. 2019.
+
+    Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py
+    """
+
+    def __init__(
+        self,
+        window_lengths: List[int] = [2048, 512],
+        loss_fn: typing.Callable = nn.L1Loss(),
+        clamp_eps: float = 1e-5,
+        mag_weight: float = 1.0,
+        log_weight: float = 1.0,
+        pow: float = 2.0,
+        weight: float = 1.0,
+        match_stride: bool = False,
+        window_type: str = None,
+    ):
+        super().__init__()
+        self.stft_params = [
+            STFTParams(
+                window_length=w,
+                hop_length=w // 4,
+                match_stride=match_stride,
+                window_type=window_type,
+            )
+            for w in window_lengths
+        ]
+        self.loss_fn = loss_fn
+        self.log_weight = log_weight
+        self.mag_weight = mag_weight
+        self.clamp_eps = clamp_eps
+        self.weight = weight
+        self.pow = pow
+
+    def forward(self, x: AudioSignal, y: AudioSignal):
+        """Computes multi-scale STFT between an estimate and a reference
+        signal.
+
+        Parameters
+        ----------
+        x : AudioSignal
+            Estimate signal
+        y : AudioSignal
+            Reference signal
+
+        Returns
+        -------
+        torch.Tensor
+            Multi-scale STFT loss.
+        """
+        loss = 0.0
+        for s in self.stft_params:
+            x.stft(s.window_length, s.hop_length, s.window_type)
+            y.stft(s.window_length, s.hop_length, s.window_type)
+            loss += self.log_weight * self.loss_fn(
+                x.magnitude.clamp(self.clamp_eps).pow(self.pow).log10(),
+                y.magnitude.clamp(self.clamp_eps).pow(self.pow).log10(),
+            )
+            loss += self.mag_weight * self.loss_fn(x.magnitude, y.magnitude)
+        return loss
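+
+# Usage sketch (editorial): the spectral losses in this file compare AudioSignal pairs
+# directly, e.g.
+#   stft_loss = MultiScaleSTFTLoss(window_lengths=[2048, 512])
+#   value = stft_loss(AudioSignal(estimate, 44100), AudioSignal(reference, 44100))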
+
+
+class MelSpectrogramLoss(nn.Module):
+    """Compute distance between mel spectrograms. Can be used
+    in a multi-scale way.
+
+    Parameters
+    ----------
+    n_mels : List[int], optional
+        Number of mels per STFT, by default [150, 80]
+    window_lengths : List[int], optional
+        Length of each window of each STFT, by default [2048, 512]
+    loss_fn : typing.Callable, optional
+        How to compare each loss, by default nn.L1Loss()
+    clamp_eps : float, optional
+        Lower bound used to clamp the magnitude before taking the log, by default 1e-5
+    mag_weight : float, optional
+        Weight of raw magnitude portion of loss, by default 1.0
+    log_weight : float, optional
+        Weight of log magnitude portion of loss, by default 1.0
+    pow : float, optional
+        Power to raise magnitude to before taking log, by default 2.0
+    weight : float, optional
+        Weight of this loss, by default 1.0
+    match_stride : bool, optional
+        Whether to match the stride of convolutional layers, by default False
+
+    Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py
+    """
+
+    def __init__(
+        self,
+        n_mels: List[int] = [150, 80],
+        window_lengths: List[int] = [2048, 512],
+        loss_fn: typing.Callable = nn.L1Loss(),
+        clamp_eps: float = 1e-5,
+        mag_weight: float = 1.0,
+        log_weight: float = 1.0,
+        pow: float = 2.0,
+        weight: float = 1.0,
+        match_stride: bool = False,
+        mel_fmin: List[float] = [0.0, 0.0],
+        mel_fmax: List[float] = [None, None],
+        window_type: str = None,
+    ):
+        super().__init__()
+        self.stft_params = [
+            STFTParams(
+                window_length=w,
+                hop_length=w // 4,
+                match_stride=match_stride,
+                window_type=window_type,
+            )
+            for w in window_lengths
+        ]
+        self.n_mels = n_mels
+        self.loss_fn = loss_fn
+        self.clamp_eps = clamp_eps
+        self.log_weight = log_weight
+        self.mag_weight = mag_weight
+        self.weight = weight
+        self.mel_fmin = mel_fmin
+        self.mel_fmax = mel_fmax
+        self.pow = pow
+
+    def forward(self, x: AudioSignal, y: AudioSignal):
+        """Computes mel loss between an estimate and a reference
+        signal.
+
+        Parameters
+        ----------
+        x : AudioSignal
+            Estimate signal
+        y : AudioSignal
+            Reference signal
+
+        Returns
+        -------
+        torch.Tensor
+            Mel loss.
+        """
+        loss = 0.0
+        for n_mels, fmin, fmax, s in zip(
+            self.n_mels, self.mel_fmin, self.mel_fmax, self.stft_params
+        ):
+            kwargs = {
+                "window_length": s.window_length,
+                "hop_length": s.hop_length,
+                "window_type": s.window_type,
+            }
+            x_mels = x.mel_spectrogram(n_mels, mel_fmin=fmin, mel_fmax=fmax, **kwargs)
+            y_mels = y.mel_spectrogram(n_mels, mel_fmin=fmin, mel_fmax=fmax, **kwargs)
+
+            loss += self.log_weight * self.loss_fn(
+                x_mels.clamp(self.clamp_eps).pow(self.pow).log10(),
+                y_mels.clamp(self.clamp_eps).pow(self.pow).log10(),
+            )
+            loss += self.mag_weight * self.loss_fn(x_mels, y_mels)
+        return loss
+
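+# A minimal usage sketch (not part of the library API): both spectral losses above take two
+# AudioSignal objects and return a scalar tensor, e.g.
+#
+#     stft_loss = MultiScaleSTFTLoss()
+#     mel_loss = MelSpectrogramLoss()
+#     estimate = AudioSignal(torch.randn(1, 1, 44100), sample_rate=44100)
+#     reference = AudioSignal(torch.randn(1, 1, 44100), sample_rate=44100)
+#     total = stft_loss(estimate, reference) + mel_loss(estimate, reference)
+#
+# Note that the `weight` attribute is stored on each module but not applied inside `forward`,
+# so any weighting between the two terms is expected to happen in the training loop.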
+
+class GANLoss(nn.Module):
+    """
+    Computes a discriminator loss, given a discriminator on
+    generated waveforms/spectrograms compared to ground truth
+    waveforms/spectrograms. Computes the loss for both the
+    discriminator and the generator in separate functions.
+    """
+
+    def __init__(self, discriminator):
+        super().__init__()
+        self.discriminator = discriminator
+
+    def forward(self, fake, real):
+        d_fake = self.discriminator(fake.audio_data)
+        d_real = self.discriminator(real.audio_data)
+        return d_fake, d_real
+
+    def discriminator_loss(self, fake, real):
+        d_fake, d_real = self.forward(fake.clone().detach(), real)
+
+        loss_d = 0
+        for x_fake, x_real in zip(d_fake, d_real):
+            loss_d += torch.mean(x_fake[-1] ** 2)
+            loss_d += torch.mean((1 - x_real[-1]) ** 2)
+        return loss_d
+
+    def generator_loss(self, fake, real):
+        d_fake, d_real = self.forward(fake, real)
+
+        loss_g = 0
+        for x_fake in d_fake:
+            loss_g += torch.mean((1 - x_fake[-1]) ** 2)
+
+        loss_feature = 0
+
+        for i in range(len(d_fake)):
+            for j in range(len(d_fake[i]) - 1):
+                loss_feature += F.l1_loss(d_fake[i][j], d_real[i][j].detach())
+        return loss_g, loss_feature
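+
+# Usage sketch (illustrative only): in a typical adversarial training step the two losses are
+# used on opposite sides of the same GANLoss instance, roughly:
+#
+#     gan = GANLoss(discriminator)
+#     d_loss = gan.discriminator_loss(fake, real)        # backprop into the discriminator
+#     g_adv, g_feat = gan.generator_loss(fake, real)     # backprop into the generator
+#     g_loss = g_adv + feature_weight * g_feat
+#
+# `feature_weight` is a hypothetical scalar chosen by the training config; it is not defined
+# in this file.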
diff --git a/src/modules/dac/nn/quantize.py b/src/modules/dac/nn/quantize.py
new file mode 100644
index 0000000000000000000000000000000000000000..bc462bc56ab5191b65d58ebd8bcaff4af7fb5927
--- /dev/null
+++ b/src/modules/dac/nn/quantize.py
@@ -0,0 +1,262 @@
+from typing import Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange
+from torch.nn.utils import weight_norm
+
+from .layers import WNConv1d
+
+
+class VectorQuantize(nn.Module):
+    """
+    Implementation of VQ similar to Karpathy's repo:
+    https://github.com/karpathy/deep-vector-quantization
+    Additionally uses following tricks from Improved VQGAN
+    (https://arxiv.org/pdf/2110.04627.pdf):
+        1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space
+            for improved codebook usage
+        2. l2-normalized codes: Converts euclidean distance to cosine similarity which
+            improves training stability
+    """
+
+    def __init__(self, input_dim: int, codebook_size: int, codebook_dim: int):
+        super().__init__()
+        self.codebook_size = codebook_size
+        self.codebook_dim = codebook_dim
+
+        self.in_proj = WNConv1d(input_dim, codebook_dim, kernel_size=1)
+        self.out_proj = WNConv1d(codebook_dim, input_dim, kernel_size=1)
+        self.codebook = nn.Embedding(codebook_size, codebook_dim)
+
+    def forward(self, z):
+        """Quantized the input tensor using a fixed codebook and returns
+        the corresponding codebook vectors
+
+        Parameters
+        ----------
+        z : Tensor[B x D x T]
+
+        Returns
+        -------
+        Tensor[B x D x T]
+            Quantized continuous representation of input
+        Tensor[1]
+            Commitment loss to train encoder to predict vectors closer to codebook
+            entries
+        Tensor[1]
+            Codebook loss to update the codebook
+        Tensor[B x T]
+            Codebook indices (quantized discrete representation of input)
+        Tensor[B x D x T]
+            Projected latents (continuous representation of input before quantization)
+        """
+
+        # Factorized codes (ViT-VQGAN) Project input into low-dimensional space
+        z_e = self.in_proj(z)  # z_e : (B x D x T)
+        z_q, indices = self.decode_latents(z_e)
+
+        commitment_loss = F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2])
+        codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2])
+
+        z_q = (
+            z_e + (z_q - z_e).detach()
+        )  # noop in forward pass, straight-through gradient estimator in backward pass
+
+        z_q = self.out_proj(z_q)
+
+        return z_q, commitment_loss, codebook_loss, indices, z_e
+
+    def embed_code(self, embed_id):
+        return F.embedding(embed_id, self.codebook.weight)
+
+    def decode_code(self, embed_id):
+        return self.embed_code(embed_id).transpose(1, 2)
+
+    def decode_latents(self, latents):
+        encodings = rearrange(latents, "b d t -> (b t) d")
+        codebook = self.codebook.weight  # codebook: (N x D)
+
+        # L2 normalize encodings and codebook (ViT-VQGAN)
+        encodings = F.normalize(encodings)
+        codebook = F.normalize(codebook)
+
+        # Compute euclidean distance with codebook
+        dist = (
+            encodings.pow(2).sum(1, keepdim=True)
+            - 2 * encodings @ codebook.t()
+            + codebook.pow(2).sum(1, keepdim=True).t()
+        )
+        indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
+        z_q = self.decode_code(indices)
+        return z_q, indices
+
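+# Note (a sketch, not part of the original file): because `decode_latents` L2-normalizes both
+# the encodings and the codebook, the squared euclidean distance reduces to
+# 2 - 2 * cosine_similarity, so taking the argmax of `-dist` is a cosine-similarity
+# nearest-neighbour lookup. Up to ties, the following check should hold:
+#
+#     e = F.normalize(torch.randn(4, 8))
+#     c = F.normalize(torch.randn(16, 8))
+#     dist = e.pow(2).sum(1, keepdim=True) - 2 * e @ c.t() + c.pow(2).sum(1, keepdim=True).t()
+#     assert torch.equal((-dist).argmax(1), (e @ c.t()).argmax(1))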
+
+class ResidualVectorQuantize(nn.Module):
+    """
+    Introduced in SoundStream: An end2end neural audio codec
+    https://arxiv.org/abs/2107.03312
+    """
+
+    def __init__(
+        self,
+        input_dim: int = 512,
+        n_codebooks: int = 9,
+        codebook_size: int = 1024,
+        codebook_dim: Union[int, list] = 8,
+        quantizer_dropout: float = 0.0,
+    ):
+        super().__init__()
+        if isinstance(codebook_dim, int):
+            codebook_dim = [codebook_dim for _ in range(n_codebooks)]
+
+        self.n_codebooks = n_codebooks
+        self.codebook_dim = codebook_dim
+        self.codebook_size = codebook_size
+
+        self.quantizers = nn.ModuleList(
+            [
+                VectorQuantize(input_dim, codebook_size, codebook_dim[i])
+                for i in range(n_codebooks)
+            ]
+        )
+        self.quantizer_dropout = quantizer_dropout
+
+    def forward(self, z, n_quantizers: int = None):
+        """Quantized the input tensor using a fixed set of `n` codebooks and returns
+        the corresponding codebook vectors
+        Parameters
+        ----------
+        z : Tensor[B x D x T]
+        n_quantizers : int, optional
+            No. of quantizers to use
+            (n_quantizers < self.n_codebooks ex: for quantizer dropout)
+            Note: if `self.quantizer_dropout` is True, this argument is ignored
+                when in training mode, and a random number of quantizers is used.
+        Returns
+        -------
+        dict
+            A dictionary with the following keys:
+
+            "z" : Tensor[B x D x T]
+                Quantized continuous representation of input
+            "codes" : Tensor[B x N x T]
+                Codebook indices for each codebook
+                (quantized discrete representation of input)
+            "latents" : Tensor[B x N*D x T]
+                Projected latents (continuous representation of input before quantization)
+            "vq/commitment_loss" : Tensor[1]
+                Commitment loss to train encoder to predict vectors closer to codebook
+                entries
+            "vq/codebook_loss" : Tensor[1]
+                Codebook loss to update the codebook
+        """
+        z_q = 0
+        residual = z
+        commitment_loss = 0
+        codebook_loss = 0
+
+        codebook_indices = []
+        latents = []
+
+        if n_quantizers is None:
+            n_quantizers = self.n_codebooks
+        if self.training:
+            n_quantizers = torch.ones((z.shape[0],)) * self.n_codebooks + 1
+            dropout = torch.randint(1, self.n_codebooks + 1, (z.shape[0],))
+            n_dropout = int(z.shape[0] * self.quantizer_dropout)
+            n_quantizers[:n_dropout] = dropout[:n_dropout]
+            n_quantizers = n_quantizers.to(z.device)
+
+        for i, quantizer in enumerate(self.quantizers):
+            if self.training is False and i >= n_quantizers:
+                break
+
+            z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer(
+                residual
+            )
+
+            # Create mask to apply quantizer dropout
+            mask = (
+                torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers
+            )
+            z_q = z_q + z_q_i * mask[:, None, None]
+            residual = residual - z_q_i
+
+            # Sum losses
+            commitment_loss += (commitment_loss_i * mask).mean()
+            codebook_loss += (codebook_loss_i * mask).mean()
+
+            codebook_indices.append(indices_i)
+            latents.append(z_e_i)
+
+        codes = torch.stack(codebook_indices, dim=1)
+        latents = torch.cat(latents, dim=1)
+
+        return z_q, codes, latents, commitment_loss, codebook_loss
+
+    def from_codes(self, codes: torch.Tensor):
+        """Given the quantized codes, reconstruct the continuous representation
+        Parameters
+        ----------
+        codes : Tensor[B x N x T]
+            Quantized discrete representation of input
+        Returns
+        -------
+        Tensor[B x D x T]
+            Quantized continuous representation of input
+        """
+        z_q = 0.0
+        z_p = []
+        n_codebooks = codes.shape[1]
+        for i in range(n_codebooks):
+            z_p_i = self.quantizers[i].decode_code(codes[:, i, :])
+            z_p.append(z_p_i)
+
+            z_q_i = self.quantizers[i].out_proj(z_p_i)
+            z_q = z_q + z_q_i
+        return z_q, torch.cat(z_p, dim=1), codes
+
+    def from_latents(self, latents: torch.Tensor):
+        """Given the unquantized latents, reconstruct the
+        continuous representation after quantization.
+
+        Parameters
+        ----------
+        latents : Tensor[B x N x T]
+            Continuous representation of input after projection
+
+        Returns
+        -------
+        Tensor[B x D x T]
+            Quantized representation of full-projected space
+        Tensor[B x D x T]
+            Quantized representation of latent space
+        """
+        z_q = 0
+        z_p = []
+        codes = []
+        dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers])
+
+        n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[
+            0
+        ]
+        for i in range(n_codebooks):
+            j, k = dims[i], dims[i + 1]
+            z_p_i, codes_i = self.quantizers[i].decode_latents(latents[:, j:k, :])
+            z_p.append(z_p_i)
+            codes.append(codes_i)
+
+            z_q_i = self.quantizers[i].out_proj(z_p_i)
+            z_q = z_q + z_q_i
+
+        return z_q, torch.cat(z_p, dim=1), torch.stack(codes, dim=1)
+
+
+if __name__ == "__main__":
+    rvq = ResidualVectorQuantize(quantizer_dropout=True)
+    x = torch.randn(16, 512, 80)
+    y = rvq(x)
+    print(y["latents"].shape)
diff --git a/src/modules/dac/utils/__init__.py b/src/modules/dac/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..36fbd70cf223b04135af71c4b322e1a92431d6ca
--- /dev/null
+++ b/src/modules/dac/utils/__init__.py
@@ -0,0 +1,122 @@
+from pathlib import Path
+
+import argbind
+from audiotools import ml
+
+from ..model import DAC
+
+Accelerator = ml.Accelerator
+
+__MODEL_LATEST_TAGS__ = {
+    ("44khz", "8kbps"): "0.0.1",
+    ("24khz", "8kbps"): "0.0.4",
+    ("16khz", "8kbps"): "0.0.5",
+    ("44khz", "16kbps"): "1.0.0",
+}
+
+__MODEL_URLS__ = {
+    (
+        "44khz",
+        "0.0.1",
+        "8kbps",
+    ): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.1/weights.pth",
+    (
+        "24khz",
+        "0.0.4",
+        "8kbps",
+    ): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.4/weights_24khz.pth",
+    (
+        "16khz",
+        "0.0.5",
+        "8kbps",
+    ): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.5/weights_16khz.pth",
+    (
+        "44khz",
+        "1.0.0",
+        "16kbps",
+    ): "https://github.com/descriptinc/descript-audio-codec/releases/download/1.0.0/weights_44khz_16kbps.pth",
+}
+
+
+@argbind.bind(group="download", positional=True, without_prefix=True)
+def download(
+    model_type: str = "44khz", model_bitrate: str = "8kbps", tag: str = "latest"
+):
+    """
+    Function that downloads the weights file from URL if a local cache is not found.
+
+    Parameters
+    ----------
+    model_type : str
+        The type of model to download. Must be one of "44khz", "24khz", or "16khz". Defaults to "44khz".
+    model_bitrate: str
+        Bitrate of the model. Must be one of "8kbps", or "16kbps". Defaults to "8kbps".
+        Only 44khz model supports 16kbps.
+    tag : str
+        The tag of the model to download. Defaults to "latest".
+
+    Returns
+    -------
+    Path
+        Path to the downloaded (or locally cached) weights file.
+    """
+    model_type = model_type.lower()
+    tag = tag.lower()
+
+    assert model_type in [
+        "44khz",
+        "24khz",
+        "16khz",
+    ], "model_type must be one of '44khz', '24khz', or '16khz'"
+
+    assert model_bitrate in [
+        "8kbps",
+        "16kbps",
+    ], "model_bitrate must be one of '8kbps', or '16kbps'"
+
+    if tag == "latest":
+        tag = __MODEL_LATEST_TAGS__[(model_type, model_bitrate)]
+
+    download_link = __MODEL_URLS__.get((model_type, tag, model_bitrate), None)
+
+    if download_link is None:
+        raise ValueError(
+            f"Could not find model with tag {tag} and model type {model_type}"
+        )
+
+    local_path = (
+        Path.home()
+        / ".cache"
+        / "descript"
+        / "dac"
+        / f"weights_{model_type}_{model_bitrate}_{tag}.pth"
+    )
+    if not local_path.exists():
+        local_path.parent.mkdir(parents=True, exist_ok=True)
+
+        # Download the model
+        import requests
+
+        response = requests.get(download_link)
+
+        if response.status_code != 200:
+            raise ValueError(
+                f"Could not download model. Received response code {response.status_code}"
+            )
+        local_path.write_bytes(response.content)
+
+    return local_path
+
+
+def load_model(
+    model_type: str = "44khz",
+    model_bitrate: str = "8kbps",
+    tag: str = "latest",
+    load_path: str = None,
+):
+    if not load_path:
+        load_path = download(
+            model_type=model_type, model_bitrate=model_bitrate, tag=tag
+        )
+    generator = DAC.load(load_path)
+    return generator
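+
+# Usage sketch (assumes either network access or a pre-populated ~/.cache/descript/dac
+# directory):
+#
+#     model = load_model(model_type="44khz", model_bitrate="8kbps", tag="latest")
+#     model.eval()
+#
+# Passing `load_path` skips the download and loads local weights directly.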
diff --git a/src/modules/dac/utils/decode.py b/src/modules/dac/utils/decode.py
new file mode 100644
index 0000000000000000000000000000000000000000..08d44e8453ec4fa3433c2a9952d1a4da15315939
--- /dev/null
+++ b/src/modules/dac/utils/decode.py
@@ -0,0 +1,95 @@
+import warnings
+from pathlib import Path
+
+import argbind
+import numpy as np
+import torch
+from audiotools import AudioSignal
+from tqdm import tqdm
+
+from dac import DACFile
+from dac.utils import load_model
+
+warnings.filterwarnings("ignore", category=UserWarning)
+
+
+@argbind.bind(group="decode", positional=True, without_prefix=True)
+@torch.inference_mode()
+@torch.no_grad()
+def decode(
+    input: str,
+    output: str = "",
+    weights_path: str = "",
+    model_tag: str = "latest",
+    model_bitrate: str = "8kbps",
+    device: str = "cuda",
+    model_type: str = "44khz",
+    verbose: bool = False,
+):
+    """Decode audio from codes.
+
+    Parameters
+    ----------
+    input : str
+        Path to input directory or file
+    output : str, optional
+        Path to output directory, by default "".
+        If `input` is a directory, the directory sub-tree relative to `input` is re-created in `output`.
+    weights_path : str, optional
+        Path to weights file, by default "". If not specified, the weights file will be downloaded from the internet using the
+        model_tag and model_type.
+    model_tag : str, optional
+        Tag of the model to use, by default "latest". Ignored if `weights_path` is specified.
+    model_bitrate: str
+        Bitrate of the model. Must be one of "8kbps", or "16kbps". Defaults to "8kbps".
+    device : str, optional
+        Device to use, by default "cuda". If "cpu", the model will be loaded on the CPU.
+    model_type : str, optional
+        The type of model to use. Must be one of "44khz", "24khz", or "16khz". Defaults to "44khz". Ignored if `weights_path` is specified.
+    verbose : bool, optional
+        Whether to show progress while decompressing each file, by default False.
+    """
+    generator = load_model(
+        model_type=model_type,
+        model_bitrate=model_bitrate,
+        tag=model_tag,
+        load_path=weights_path,
+    )
+    generator.to(device)
+    generator.eval()
+
+    # Find all .dac files in input directory
+    _input = Path(input)
+    input_files = list(_input.glob("**/*.dac"))
+
+    # If input is a .dac file, add it to the list
+    if _input.suffix == ".dac":
+        input_files.append(_input)
+
+    # Create output directory
+    output = Path(output)
+    output.mkdir(parents=True, exist_ok=True)
+
+    for i in tqdm(range(len(input_files)), desc=f"Decoding files"):
+        # Load file
+        artifact = DACFile.load(input_files[i])
+
+        # Reconstruct audio from codes
+        recons = generator.decompress(artifact, verbose=verbose)
+
+        # Compute output path
+        relative_path = input_files[i].relative_to(input)
+        output_dir = output / relative_path.parent
+        if not relative_path.name:
+            output_dir = output
+            relative_path = input_files[i]
+        output_name = relative_path.with_suffix(".wav").name
+        output_path = output_dir / output_name
+        output_path.parent.mkdir(parents=True, exist_ok=True)
+
+        # Write to file
+        recons.write(output_path)
+
+
+if __name__ == "__main__":
+    args = argbind.parse_args()
+    with argbind.scope(args):
+        decode()
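+
+# Example invocation (a sketch; the exact flags depend on how argbind exposes the parameters
+# above, and on the `dac` package being importable when this module is run directly):
+#
+#     python decode.py /path/to/dac_files --output /path/to/wavs --device cuda
+#
+# Every .dac artifact found under the input path is decompressed back to a .wav file.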
diff --git a/src/modules/dac/utils/encode.py b/src/modules/dac/utils/encode.py
new file mode 100644
index 0000000000000000000000000000000000000000..aa3f6f44b3c210f485da1b1726b85494ff5e7804
--- /dev/null
+++ b/src/modules/dac/utils/encode.py
@@ -0,0 +1,94 @@
+import math
+import warnings
+from pathlib import Path
+
+import argbind
+import numpy as np
+import torch
+from audiotools import AudioSignal
+from audiotools.core import util
+from tqdm import tqdm
+
+from dac.utils import load_model
+
+warnings.filterwarnings("ignore", category=UserWarning)
+
+
+@argbind.bind(group="encode", positional=True, without_prefix=True)
+@torch.inference_mode()
+@torch.no_grad()
+def encode(
+    input: str,
+    output: str = "",
+    weights_path: str = "",
+    model_tag: str = "latest",
+    model_bitrate: str = "8kbps",
+    n_quantizers: int = None,
+    device: str = "cuda",
+    model_type: str = "44khz",
+    win_duration: float = 5.0,
+    verbose: bool = False,
+):
+    """Encode audio files in input path to .dac format.
+
+    Parameters
+    ----------
+    input : str
+        Path to input audio file or directory
+    output : str, optional
+        Path to output directory, by default "". If `input` is a directory, the directory sub-tree relative to `input` is re-created in `output`.
+    weights_path : str, optional
+        Path to weights file, by default "". If not specified, the weights file will be downloaded from the internet using the
+        model_tag and model_type.
+    model_tag : str, optional
+        Tag of the model to use, by default "latest". Ignored if `weights_path` is specified.
+    model_bitrate: str
+        Bitrate of the model. Must be one of "8kbps", or "16kbps". Defaults to "8kbps".
+    n_quantizers : int, optional
+        Number of quantizers to use, by default None. If not specified, all the quantizers will be used and the model will compress at maximum bitrate.
+    device : str, optional
+        Device to use, by default "cuda"
+    model_type : str, optional
+        The type of model to use. Must be one of "44khz", "24khz", or "16khz". Defaults to "44khz". Ignored if `weights_path` is specified.
+    win_duration : float, optional
+        Duration (in seconds) of the windows used to chunk the audio during compression, by default 5.0.
+    verbose : bool, optional
+        Whether to show progress while compressing each file, by default False.
+    """
+    generator = load_model(
+        model_type=model_type,
+        model_bitrate=model_bitrate,
+        tag=model_tag,
+        load_path=weights_path,
+    )
+    generator.to(device)
+    generator.eval()
+    kwargs = {"n_quantizers": n_quantizers}
+
+    # Find all audio files in input path
+    input = Path(input)
+    audio_files = util.find_audio(input)
+
+    output = Path(output)
+    output.mkdir(parents=True, exist_ok=True)
+
+    for i in tqdm(range(len(audio_files)), desc="Encoding files"):
+        # Load file
+        signal = AudioSignal(audio_files[i])
+
+        # Encode audio to .dac format
+        artifact = generator.compress(signal, win_duration, verbose=verbose, **kwargs)
+
+        # Compute output path
+        relative_path = audio_files[i].relative_to(input)
+        output_dir = output / relative_path.parent
+        if not relative_path.name:
+            output_dir = output
+            relative_path = audio_files[i]
+        output_name = relative_path.with_suffix(".dac").name
+        output_path = output_dir / output_name
+        output_path.parent.mkdir(parents=True, exist_ok=True)
+
+        artifact.save(output_path)
+
+
+if __name__ == "__main__":
+    args = argbind.parse_args()
+    with argbind.scope(args):
+        encode()
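+
+# Example invocation (a sketch; the exact flags depend on how argbind exposes the parameters
+# above, and on the `dac` package being importable when this module is run directly):
+#
+#     python encode.py /path/to/audio --output /path/to/dac_files --model_type 44khz
+#
+# Each audio file found under the input path is compressed to a .dac artifact.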
diff --git a/src/modules/stable_vae/__init__.py b/src/modules/stable_vae/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a8065ea65729a235c519de4cb86c5ea07e0ab7be
--- /dev/null
+++ b/src/modules/stable_vae/__init__.py
@@ -0,0 +1,40 @@
+from .models.autoencoders import create_autoencoder_from_config
+import os
+import json
+import torch
+from torch.nn.utils import remove_weight_norm
+
+
+def remove_all_weight_norm(model):
+    for name, module in model.named_modules():
+        if hasattr(module, 'weight_g'):
+            remove_weight_norm(module)
+
+
+def load_vae(ckpt_path, remove_weight_norm=False):
+    config_file = os.path.join(os.path.dirname(ckpt_path), 'config.json')
+
+    # Load the model configuration
+    with open(config_file) as f:
+        model_config = json.load(f)
+
+    # Create the model from the configuration
+    model = create_autoencoder_from_config(model_config)
+
+    # Load the state dictionary from the checkpoint
+    model_dict = torch.load(ckpt_path, map_location='cpu')['state_dict']
+
+    # Strip the "autoencoder." prefix from the keys
+    model_dict = {key[len("autoencoder."):]: value for key, value in model_dict.items() if key.startswith("autoencoder.")}
+
+    # Load the state dictionary into the model
+    model.load_state_dict(model_dict)
+
+    # Remove weight normalization
+    if remove_weight_norm:
+        remove_all_weight_norm(model)
+
+    # Set the model to evaluation mode
+    model.eval()
+
+    return model
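+
+# Usage sketch (the checkpoint path is hypothetical; a config.json describing the autoencoder
+# is expected to sit next to the checkpoint file):
+#
+#     vae = load_vae("/path/to/vae/model.ckpt", remove_weight_norm=True)
+#
+# The returned model is an AudioAutoencoder in eval mode; see encode_audio / decode_audio in
+# models/autoencoders.py for how to run it on waveforms.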
diff --git a/src/modules/stable_vae/models/.ipynb_checkpoints/autoencoders-checkpoint.py b/src/modules/stable_vae/models/.ipynb_checkpoints/autoencoders-checkpoint.py
new file mode 100644
index 0000000000000000000000000000000000000000..2741d45e70e3c25bc28f2b4e43b1a3e925a4c9e3
--- /dev/null
+++ b/src/modules/stable_vae/models/.ipynb_checkpoints/autoencoders-checkpoint.py
@@ -0,0 +1,683 @@
+import torch
+import math
+import numpy as np
+
+from torch import nn
+from torch.nn import functional as F
+from torchaudio import transforms as T
+from alias_free_torch import Activation1d
+from .nn.layers import WNConv1d, WNConvTranspose1d
+from typing import Literal, Dict, Any
+
+# from .inference.sampling import sample
+from .utils import prepare_audio
+from .blocks import SnakeBeta
+from .bottleneck import Bottleneck, DiscreteBottleneck
+from .factory import create_pretransform_from_config, create_bottleneck_from_config
+from .pretransforms import Pretransform
+
+def checkpoint(function, *args, **kwargs):
+    kwargs.setdefault("use_reentrant", False)
+    return torch.utils.checkpoint.checkpoint(function, *args, **kwargs)
+
+def get_activation(activation: Literal["elu", "snake", "none"], antialias=False, channels=None) -> nn.Module:
+    if activation == "elu":
+        act = nn.ELU()
+    elif activation == "snake":
+        act = SnakeBeta(channels)
+    elif activation == "none":
+        act = nn.Identity()
+    else:
+        raise ValueError(f"Unknown activation {activation}")
+    
+    if antialias:
+        act = Activation1d(act)
+    
+    return act
+
+class ResidualUnit(nn.Module):
+    def __init__(self, in_channels, out_channels, dilation, use_snake=False, antialias_activation=False):
+        super().__init__()
+        
+        self.dilation = dilation
+
+        padding = (dilation * (7-1)) // 2
+
+        self.layers = nn.Sequential(
+            get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=out_channels),
+            WNConv1d(in_channels=in_channels, out_channels=out_channels,
+                      kernel_size=7, dilation=dilation, padding=padding),
+            get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=out_channels),
+            WNConv1d(in_channels=out_channels, out_channels=out_channels,
+                      kernel_size=1)
+        )
+
+    def forward(self, x):
+        res = x
+        
+        #x = checkpoint(self.layers, x)
+        x = self.layers(x)
+
+        return x + res
+
+class EncoderBlock(nn.Module):
+    def __init__(self, in_channels, out_channels, stride, use_snake=False, antialias_activation=False):
+        super().__init__()
+
+        self.layers = nn.Sequential(
+            ResidualUnit(in_channels=in_channels,
+                         out_channels=in_channels, dilation=1, use_snake=use_snake),
+            ResidualUnit(in_channels=in_channels,
+                         out_channels=in_channels, dilation=3, use_snake=use_snake),
+            ResidualUnit(in_channels=in_channels,
+                         out_channels=in_channels, dilation=9, use_snake=use_snake),
+            get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=in_channels),
+            WNConv1d(in_channels=in_channels, out_channels=out_channels,
+                      kernel_size=2*stride, stride=stride, padding=math.ceil(stride/2)),
+        )
+
+    def forward(self, x):
+        return self.layers(x)
+
+class DecoderBlock(nn.Module):
+    def __init__(self, in_channels, out_channels, stride, use_snake=False, antialias_activation=False, use_nearest_upsample=False):
+        super().__init__()
+
+        if use_nearest_upsample:
+            upsample_layer = nn.Sequential(
+                nn.Upsample(scale_factor=stride, mode="nearest"),
+                WNConv1d(in_channels=in_channels,
+                        out_channels=out_channels, 
+                        kernel_size=2*stride,
+                        stride=1,
+                        bias=False,
+                        padding='same')
+            )
+        else:
+            upsample_layer = WNConvTranspose1d(in_channels=in_channels,
+                               out_channels=out_channels,
+                               kernel_size=2*stride, stride=stride, padding=math.ceil(stride/2))
+
+        self.layers = nn.Sequential(
+            get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=in_channels),
+            upsample_layer,
+            ResidualUnit(in_channels=out_channels, out_channels=out_channels,
+                         dilation=1, use_snake=use_snake),
+            ResidualUnit(in_channels=out_channels, out_channels=out_channels,
+                         dilation=3, use_snake=use_snake),
+            ResidualUnit(in_channels=out_channels, out_channels=out_channels,
+                         dilation=9, use_snake=use_snake),
+        )
+
+    def forward(self, x):
+        return self.layers(x)
+
+class OobleckEncoder(nn.Module):
+    def __init__(self, 
+                 in_channels=2, 
+                 channels=128, 
+                 latent_dim=32, 
+                 c_mults = [1, 2, 4, 8], 
+                 strides = [2, 4, 8, 8],
+                 use_snake=False,
+                 antialias_activation=False
+        ):
+        super().__init__()
+          
+        c_mults = [1] + c_mults
+
+        self.depth = len(c_mults)
+
+        layers = [
+            WNConv1d(in_channels=in_channels, out_channels=c_mults[0] * channels, kernel_size=7, padding=3)
+        ]
+        
+        for i in range(self.depth-1):
+            layers += [EncoderBlock(in_channels=c_mults[i]*channels, out_channels=c_mults[i+1]*channels, stride=strides[i], use_snake=use_snake)]
+
+        layers += [
+            get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=c_mults[-1] * channels),
+            WNConv1d(in_channels=c_mults[-1]*channels, out_channels=latent_dim, kernel_size=3, padding=1)
+        ]
+
+        self.layers = nn.Sequential(*layers)
+
+    def forward(self, x):
+        return self.layers(x)
+
+
+class OobleckDecoder(nn.Module):
+    def __init__(self, 
+                 out_channels=2, 
+                 channels=128, 
+                 latent_dim=32, 
+                 c_mults = [1, 2, 4, 8], 
+                 strides = [2, 4, 8, 8],
+                 use_snake=False,
+                 antialias_activation=False,
+                 use_nearest_upsample=False,
+                 final_tanh=True):
+        super().__init__()
+
+        c_mults = [1] + c_mults
+        
+        self.depth = len(c_mults)
+
+        layers = [
+            WNConv1d(in_channels=latent_dim, out_channels=c_mults[-1]*channels, kernel_size=7, padding=3),
+        ]
+        
+        for i in range(self.depth-1, 0, -1):
+            layers += [DecoderBlock(
+                in_channels=c_mults[i]*channels, 
+                out_channels=c_mults[i-1]*channels, 
+                stride=strides[i-1], 
+                use_snake=use_snake, 
+                antialias_activation=antialias_activation,
+                use_nearest_upsample=use_nearest_upsample
+                )
+            ]
+
+        layers += [
+            get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=c_mults[0] * channels),
+            WNConv1d(in_channels=c_mults[0] * channels, out_channels=out_channels, kernel_size=7, padding=3, bias=False),
+            nn.Tanh() if final_tanh else nn.Identity()
+        ]
+
+        self.layers = nn.Sequential(*layers)
+
+    def forward(self, x):
+        return self.layers(x)
+
+
+class DACEncoderWrapper(nn.Module):
+    def __init__(self, in_channels=1, **kwargs):
+        super().__init__()
+
+        from dac.model.dac import Encoder as DACEncoder
+
+        latent_dim = kwargs.pop("latent_dim", None)
+
+        encoder_out_dim = kwargs["d_model"] * (2 ** len(kwargs["strides"]))
+        self.encoder = DACEncoder(d_latent=encoder_out_dim, **kwargs)
+        self.latent_dim = latent_dim
+
+        # Latent-dim support was added to DAC after this was first written, and implemented differently, so this is for backwards compatibility
+        self.proj_out = nn.Conv1d(self.encoder.enc_dim, latent_dim, kernel_size=1) if latent_dim is not None else nn.Identity()
+
+        if in_channels != 1:
+            self.encoder.block[0] = WNConv1d(in_channels, kwargs.get("d_model", 64), kernel_size=7, padding=3)
+
+    def forward(self, x):
+        x = self.encoder(x)
+        x = self.proj_out(x)
+        return x
+
+class DACDecoderWrapper(nn.Module):
+    def __init__(self, latent_dim, out_channels=1, **kwargs):
+        super().__init__()
+
+        from dac.model.dac import Decoder as DACDecoder
+
+        self.decoder = DACDecoder(**kwargs, input_channel = latent_dim, d_out=out_channels)
+
+        self.latent_dim = latent_dim
+
+    def forward(self, x):
+        return self.decoder(x)
+
+class AudioAutoencoder(nn.Module):
+    def __init__(
+        self,
+        encoder,
+        decoder,
+        latent_dim,
+        downsampling_ratio,
+        sample_rate,
+        io_channels=2,
+        bottleneck: Bottleneck = None,
+        pretransform: Pretransform = None,
+        in_channels = None,
+        out_channels = None,
+        soft_clip = False
+    ):
+        super().__init__()
+
+        self.downsampling_ratio = downsampling_ratio
+        self.sample_rate = sample_rate
+
+        self.latent_dim = latent_dim
+        self.io_channels = io_channels
+        self.in_channels = io_channels
+        self.out_channels = io_channels
+
+        self.min_length = self.downsampling_ratio
+
+        if in_channels is not None:
+            self.in_channels = in_channels
+
+        if out_channels is not None:
+            self.out_channels = out_channels
+
+        self.bottleneck = bottleneck
+
+        self.encoder = encoder
+
+        self.decoder = decoder
+
+        self.pretransform = pretransform
+
+        self.soft_clip = soft_clip
+ 
+        self.is_discrete = self.bottleneck is not None and self.bottleneck.is_discrete
+
+    def encode(self, audio, return_info=False, skip_pretransform=False, iterate_batch=False, **kwargs):
+
+        info = {}
+
+        if self.pretransform is not None and not skip_pretransform:
+            if self.pretransform.enable_grad:
+                if iterate_batch:
+                    audios = []
+                    for i in range(audio.shape[0]):
+                        audios.append(self.pretransform.encode(audio[i:i+1]))
+                    audio = torch.cat(audios, dim=0)
+                else:
+                    audio = self.pretransform.encode(audio)
+            else:
+                with torch.no_grad():
+                    if iterate_batch:
+                        audios = []
+                        for i in range(audio.shape[0]):
+                            audios.append(self.pretransform.encode(audio[i:i+1]))
+                        audio = torch.cat(audios, dim=0)
+                    else:
+                        audio = self.pretransform.encode(audio)
+
+        if self.encoder is not None:
+            if iterate_batch:
+                latents = []
+                for i in range(audio.shape[0]):
+                    latents.append(self.encoder(audio[i:i+1]))
+                latents = torch.cat(latents, dim=0)
+            else:
+                latents = self.encoder(audio)
+        else:
+            latents = audio
+
+        if self.bottleneck is not None:
+            # TODO: Add iterate batch logic, needs to merge the info dicts
+            latents, bottleneck_info = self.bottleneck.encode(latents, return_info=True, **kwargs)
+
+            info.update(bottleneck_info)
+
+        if return_info:
+            return latents, info
+
+        return latents
+
+    def decode(self, latents, iterate_batch=False, **kwargs):
+
+        if self.bottleneck is not None:
+            if iterate_batch:
+                decoded = []
+                for i in range(latents.shape[0]):
+                    decoded.append(self.bottleneck.decode(latents[i:i+1]))
+                latents = torch.cat(decoded, dim=0)
+            else:
+                latents = self.bottleneck.decode(latents)
+
+        if iterate_batch:
+            decoded = []
+            for i in range(latents.shape[0]):
+                decoded.append(self.decoder(latents[i:i+1]))
+            decoded = torch.cat(decoded, dim=0)
+        else:
+            decoded = self.decoder(latents, **kwargs)
+
+        if self.pretransform is not None:
+            if self.pretransform.enable_grad:
+                if iterate_batch:
+                    decodeds = []
+                    for i in range(decoded.shape[0]):
+                        decodeds.append(self.pretransform.decode(decoded[i:i+1]))
+                    decoded = torch.cat(decodeds, dim=0)
+                else:
+                    decoded = self.pretransform.decode(decoded)
+            else:
+                with torch.no_grad():
+                    if iterate_batch:
+                        decodeds = []
+                        for i in range(latents.shape[0]):
+                            decodeds.append(self.pretransform.decode(decoded[i:i+1]))
+                        decoded = torch.cat(decodeds, dim=0)
+                    else:
+                        decoded = self.pretransform.decode(decoded)
+
+        if self.soft_clip:
+            decoded = torch.tanh(decoded)
+
+        return decoded
+
+    def decode_tokens(self, tokens, **kwargs):
+        '''
+        Decode discrete tokens to audio
+        Only works with discrete autoencoders
+        '''
+
+        assert isinstance(self.bottleneck, DiscreteBottleneck), "decode_tokens only works with discrete autoencoders"
+
+        latents = self.bottleneck.decode_tokens(tokens, **kwargs)
+
+        return self.decode(latents, **kwargs)
+        
+    
+    def preprocess_audio_for_encoder(self, audio, in_sr):
+        '''
+        Preprocess single audio tensor (Channels x Length) to be compatible with the encoder.
+        If the model is mono, stereo audio will be converted to mono.
+        Audio will be silence-padded to be a multiple of the model's downsampling ratio.
+        Audio will be resampled to the model's sample rate. 
+        The output will have batch size 1 and be shape (1 x Channels x Length)
+        '''
+        return self.preprocess_audio_list_for_encoder([audio], [in_sr])
+
+    def preprocess_audio_list_for_encoder(self, audio_list, in_sr_list):
+        '''
+        Preprocess a [list] of audio (Channels x Length) into a batch tensor to be compatible with the encoder.
+        The audio in that list can be of different lengths and channels. 
+        in_sr can be an integer or list. If it's an integer it will be assumed it is the input sample_rate for every audio.
+        All audio will be resampled to the model's sample rate. 
+        Audio will be silence-padded to the longest length, and further padded to be a multiple of the model's downsampling ratio. 
+        If the model is mono, all audio will be converted to mono. 
+        The output will be a tensor of shape (Batch x Channels x Length)
+        '''
+        batch_size = len(audio_list)
+        if isinstance(in_sr_list, int):
+            in_sr_list = [in_sr_list]*batch_size
+        assert len(in_sr_list) == batch_size, "list of sample rates must be the same length as audio_list"
+        new_audio = []
+        max_length = 0
+        # resample & find the max length
+        for i in range(batch_size):
+            audio = audio_list[i]
+            in_sr = in_sr_list[i]
+            if len(audio.shape) == 3 and audio.shape[0] == 1:
+                # batchsize 1 was given by accident. Just squeeze it.
+                audio = audio.squeeze(0)
+            elif len(audio.shape) == 1:
+                # Mono signal, channel dimension is missing, unsqueeze it in
+                audio = audio.unsqueeze(0)
+            assert len(audio.shape)==2, "Audio should be shape (Channels x Length) with no batch dimension" 
+            # Resample audio
+            if in_sr != self.sample_rate:
+                resample_tf = T.Resample(in_sr, self.sample_rate).to(audio.device)
+                audio = resample_tf(audio)
+            new_audio.append(audio)
+            if audio.shape[-1] > max_length:
+                max_length = audio.shape[-1]
+        # Pad every audio to the same length, multiple of model's downsampling ratio
+        padded_audio_length = max_length + (self.min_length - (max_length % self.min_length)) % self.min_length
+        for i in range(batch_size):
+            # Pad it & if necessary, mixdown/duplicate stereo/mono channels to support model
+            new_audio[i] = prepare_audio(new_audio[i], in_sr=in_sr, target_sr=in_sr, target_length=padded_audio_length, 
+                target_channels=self.in_channels, device=new_audio[i].device).squeeze(0)
+        # convert to tensor 
+        return torch.stack(new_audio) 
+
+    def encode_audio(self, audio, chunked=False, overlap=32, chunk_size=128, **kwargs):
+        '''
+        Encode audio into latents. The audio should already be preprocessed by
+        preprocess_audio_for_encoder.
+        If chunked is True, split the audio into chunks of at most chunk_size, with the given
+        overlap. overlap and chunk_size are both measured in number of latents (not audio samples),
+        so you can likely use the same values with decode_audio.
+        An overlap of zero will cause discontinuity artefacts; the overlap should be >= the
+        receptive field size. Every autoencoder has a different receptive field size, and thus
+        a different ideal overlap. You can determine it empirically by diffing unchunked vs.
+        chunked output and looking at the maximum difference.
+        The final chunk may have a longer overlap in order to keep chunk_size consistent for all chunks.
+        A smaller chunk_size uses less memory but more compute. The chunk_size vs. memory tradeoff
+        isn't linear and may depend on the GPU and CUDA version; for example, on an A6000,
+        chunk_size 128 is overall faster than 256 and 512 even though it has more chunks.
+        '''
+        if not chunked:
+            # default behavior. Encode the entire audio in parallel
+            return self.encode(audio, **kwargs)
+        else:
+            # CHUNKED ENCODING
+            # samples_per_latent is just the downsampling ratio (which is also the upsampling ratio)
+            samples_per_latent = self.downsampling_ratio
+            total_size = audio.shape[2] # in samples
+            batch_size = audio.shape[0]
+            chunk_size *= samples_per_latent # converting metric in latents to samples
+            overlap *= samples_per_latent # converting metric in latents to samples
+            hop_size = chunk_size - overlap
+            chunks = []
+            for i in range(0, total_size - chunk_size + 1, hop_size):
+                chunk = audio[:,:,i:i+chunk_size]
+                chunks.append(chunk)
+            if i+chunk_size != total_size:
+                # Final chunk
+                chunk = audio[:,:,-chunk_size:]
+                chunks.append(chunk)
+            chunks = torch.stack(chunks)
+            num_chunks = chunks.shape[0]
+            # Note: y_size might be a different value from the latent length used in diffusion training
+            # because we can encode audio of varying lengths
+            # However, the audio should've been padded to a multiple of samples_per_latent by now.
+            y_size = total_size // samples_per_latent
+            # Create an empty latent, we will populate it with chunks as we encode them
+            y_final = torch.zeros((batch_size,self.latent_dim,y_size)).to(audio.device)
+            for i in range(num_chunks):
+                x_chunk = chunks[i,:]
+                # encode the chunk
+                y_chunk = self.encode(x_chunk)
+                # figure out where to put the audio along the time domain
+                if i == num_chunks-1:
+                    # final chunk always goes at the end
+                    t_end = y_size
+                    t_start = t_end - y_chunk.shape[2]
+                else:
+                    t_start = i * hop_size // samples_per_latent
+                    t_end = t_start + chunk_size // samples_per_latent
+                #  remove the edges of the overlaps
+                ol = overlap//samples_per_latent//2
+                chunk_start = 0
+                chunk_end = y_chunk.shape[2]
+                if i > 0:
+                    # no overlap for the start of the first chunk
+                    t_start += ol
+                    chunk_start += ol
+                if i < num_chunks-1:
+                    # no overlap for the end of the last chunk
+                    t_end -= ol
+                    chunk_end -= ol
+                # paste the chunked audio into our y_final output audio
+                y_final[:,:,t_start:t_end] = y_chunk[:,:,chunk_start:chunk_end]
+            return y_final
+    
+    def decode_audio(self, latents, chunked=False, overlap=32, chunk_size=128, **kwargs):
+        '''
+        Decode latents to audio.
+        If chunked is True, split the latents into chunks of at most chunk_size, with the given
+        overlap, both of which are measured in number of latents.
+        An overlap of zero will cause discontinuity artefacts; the overlap should be >= the
+        receptive field size. Every autoencoder has a different receptive field size, and thus
+        a different ideal overlap. You can determine it empirically by diffing unchunked vs.
+        chunked audio and looking at the maximum difference.
+        The final chunk may have a longer overlap in order to keep chunk_size consistent for all chunks.
+        A smaller chunk_size uses less memory but more compute. The chunk_size vs. memory tradeoff
+        isn't linear and may depend on the GPU and CUDA version; for example, on an A6000,
+        chunk_size 128 is overall faster than 256 and 512 even though it has more chunks.
+        '''
+        if not chunked:
+            # default behavior. Decode the entire latent in parallel
+            return self.decode(latents, **kwargs)
+        else:
+            # chunked decoding
+            hop_size = chunk_size - overlap
+            total_size = latents.shape[2]
+            batch_size = latents.shape[0]
+            chunks = []
+            for i in range(0, total_size - chunk_size + 1, hop_size):
+                chunk = latents[:,:,i:i+chunk_size]
+                chunks.append(chunk)
+            if i+chunk_size != total_size:
+                # Final chunk
+                chunk = latents[:,:,-chunk_size:]
+                chunks.append(chunk)
+            chunks = torch.stack(chunks)
+            num_chunks = chunks.shape[0]
+            # samples_per_latent is just the downsampling ratio
+            samples_per_latent = self.downsampling_ratio
+            # Create an empty waveform; we will populate it with chunks as we decode them
+            y_size = total_size * samples_per_latent
+            y_final = torch.zeros((batch_size,self.out_channels,y_size)).to(latents.device)
+            for i in range(num_chunks):
+                x_chunk = chunks[i,:]
+                # decode the chunk
+                y_chunk = self.decode(x_chunk)
+                # figure out where to put the audio along the time domain
+                if i == num_chunks-1:
+                    # final chunk always goes at the end
+                    t_end = y_size
+                    t_start = t_end - y_chunk.shape[2]
+                else:
+                    t_start = i * hop_size * samples_per_latent
+                    t_end = t_start + chunk_size * samples_per_latent
+                #  remove the edges of the overlaps
+                ol = (overlap//2) * samples_per_latent
+                chunk_start = 0
+                chunk_end = y_chunk.shape[2]
+                if i > 0:
+                    # no overlap for the start of the first chunk
+                    t_start += ol
+                    chunk_start += ol
+                if i < num_chunks-1:
+                    # no overlap for the end of the last chunk
+                    t_end -= ol
+                    chunk_end -= ol
+                # paste the chunked audio into our y_final output audio
+                y_final[:,:,t_start:t_end] = y_chunk[:,:,chunk_start:chunk_end]
+            return y_final
+
+    
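+# Chunked round-trip sketch (illustrative; `ae` is a hypothetical AudioAutoencoder built by
+# create_autoencoder_from_config below, and `audio` a preprocessed (B x C x T) tensor):
+#
+#     latents = ae.encode_audio(audio, chunked=True, overlap=32, chunk_size=128)
+#     recon = ae.decode_audio(latents, chunked=True, overlap=32, chunk_size=128)
+#
+# overlap and chunk_size are measured in latents, so the same values can be reused on both
+# sides of the round trip.
+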
+# AE factories
+
+def create_encoder_from_config(encoder_config: Dict[str, Any]):
+    encoder_type = encoder_config.get("type", None)
+    assert encoder_type is not None, "Encoder type must be specified"
+
+    if encoder_type == "oobleck":
+        encoder = OobleckEncoder(
+            **encoder_config["config"]
+        )
+    
+    elif encoder_type == "seanet":
+        from encodec.modules import SEANetEncoder
+        seanet_encoder_config = encoder_config["config"]
+
+        #SEANet encoder expects strides in reverse order
+        seanet_encoder_config["ratios"] = list(reversed(seanet_encoder_config.get("ratios", [2, 2, 2, 2, 2])))
+        encoder = SEANetEncoder(
+            **seanet_encoder_config
+        )
+    elif encoder_type == "dac":
+        dac_config = encoder_config["config"]
+
+        encoder = DACEncoderWrapper(**dac_config)
+    elif encoder_type == "local_attn":
+        from .local_attention import TransformerEncoder1D
+
+        local_attn_config = encoder_config["config"]
+
+        encoder = TransformerEncoder1D(
+            **local_attn_config
+        )
+    else:
+        raise ValueError(f"Unknown encoder type {encoder_type}")
+    
+    requires_grad = encoder_config.get("requires_grad", True)
+    if not requires_grad:
+        for param in encoder.parameters():
+            param.requires_grad = False
+
+    return encoder
+
+def create_decoder_from_config(decoder_config: Dict[str, Any]):
+    decoder_type = decoder_config.get("type", None)
+    assert decoder_type is not None, "Decoder type must be specified"
+
+    if decoder_type == "oobleck":
+        decoder = OobleckDecoder(
+            **decoder_config["config"]
+        )
+    elif decoder_type == "seanet":
+        from encodec.modules import SEANetDecoder
+
+        decoder = SEANetDecoder(
+            **decoder_config["config"]
+        )
+    elif decoder_type == "dac":
+        dac_config = decoder_config["config"]
+
+        decoder = DACDecoderWrapper(**dac_config)
+    elif decoder_type == "local_attn":
+        from .local_attention import TransformerDecoder1D
+
+        local_attn_config = decoder_config["config"]
+
+        decoder = TransformerDecoder1D(
+            **local_attn_config
+        )
+    else:
+        raise ValueError(f"Unknown decoder type {decoder_type}")
+    
+    requires_grad = decoder_config.get("requires_grad", True)
+    if not requires_grad:
+        for param in decoder.parameters():
+            param.requires_grad = False
+
+    return decoder
+
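+# A minimal, hypothetical config for create_autoencoder_from_config below.
+# The keys mirror the lookups in this file; the concrete values are illustrative
+# assumptions, not a canonical preset. With a "vae" bottleneck the encoder emits
+# 2 * latent_dim channels (mean/scale), hence 128 vs 64 here:
+#
+# {
+#     "sample_rate": 44100,
+#     "model": {
+#         "encoder": {"type": "oobleck", "config": {"in_channels": 2, "latent_dim": 128}},
+#         "decoder": {"type": "oobleck", "config": {"out_channels": 2, "latent_dim": 64}},
+#         "bottleneck": {"type": "vae"},
+#         "latent_dim": 64,
+#         "downsampling_ratio": 512,
+#         "io_channels": 2
+#     }
+# }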
+def create_autoencoder_from_config(config: Dict[str, Any]):
+    
+    ae_config = config["model"]
+
+    encoder = create_encoder_from_config(ae_config["encoder"])
+    decoder = create_decoder_from_config(ae_config["decoder"])
+
+    bottleneck = ae_config.get("bottleneck", None)
+
+    latent_dim = ae_config.get("latent_dim", None)
+    assert latent_dim is not None, "latent_dim must be specified in model config"
+    downsampling_ratio = ae_config.get("downsampling_ratio", None)
+    assert downsampling_ratio is not None, "downsampling_ratio must be specified in model config"
+    io_channels = ae_config.get("io_channels", None)
+    assert io_channels is not None, "io_channels must be specified in model config"
+    sample_rate = config.get("sample_rate", None)
+    assert sample_rate is not None, "sample_rate must be specified in the top-level config"
+
+    in_channels = ae_config.get("in_channels", None)
+    out_channels = ae_config.get("out_channels", None)
+
+    pretransform = ae_config.get("pretransform", None)
+
+    if pretransform is not None:
+        pretransform = create_pretransform_from_config(pretransform, sample_rate)
+
+    if bottleneck is not None:
+        bottleneck = create_bottleneck_from_config(bottleneck)
+
+    soft_clip = ae_config["decoder"].get("soft_clip", False)
+
+    return AudioAutoencoder(
+        encoder,
+        decoder,
+        io_channels=io_channels,
+        latent_dim=latent_dim,
+        downsampling_ratio=downsampling_ratio,
+        sample_rate=sample_rate,
+        bottleneck=bottleneck,
+        pretransform=pretransform,
+        in_channels=in_channels,
+        out_channels=out_channels,
+        soft_clip=soft_clip
+    )
\ No newline at end of file
diff --git a/src/modules/stable_vae/models/.ipynb_checkpoints/blocks-checkpoint.py b/src/modules/stable_vae/models/.ipynb_checkpoints/blocks-checkpoint.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb310c8980ef5dc0f138e6f9f3478d4cdc63354d
--- /dev/null
+++ b/src/modules/stable_vae/models/.ipynb_checkpoints/blocks-checkpoint.py
@@ -0,0 +1,359 @@
+from functools import reduce
+import math
+import numpy as np
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from torch.backends.cuda import sdp_kernel
+from packaging import version
+
+from .nn.layers import Snake1d
+
+
+class ResidualBlock(nn.Module):
+    def __init__(self, main, skip=None):
+        super().__init__()
+        self.main = nn.Sequential(*main)
+        self.skip = skip if skip else nn.Identity()
+
+    def forward(self, input):
+        return self.main(input) + self.skip(input)
+
+
+class ResConvBlock(ResidualBlock):
+    def __init__(self, c_in, c_mid, c_out, is_last=False, kernel_size=5, conv_bias=True, use_snake=False):
+        skip = None if c_in == c_out else nn.Conv1d(c_in, c_out, 1, bias=False)
+        super().__init__([
+            nn.Conv1d(c_in, c_mid, kernel_size, padding=kernel_size//2, bias=conv_bias),
+            nn.GroupNorm(1, c_mid),
+            Snake1d(c_mid) if use_snake else nn.GELU(),
+            nn.Conv1d(c_mid, c_out, kernel_size, padding=kernel_size//2, bias=conv_bias),
+            nn.GroupNorm(1, c_out) if not is_last else nn.Identity(),
+            (Snake1d(c_out) if use_snake else nn.GELU()) if not is_last else nn.Identity(),
+        ], skip)
+
+
+class SelfAttention1d(nn.Module):
+    def __init__(self, c_in, n_head=1, dropout_rate=0.):
+        super().__init__()
+        assert c_in % n_head == 0
+        self.norm = nn.GroupNorm(1, c_in)
+        self.n_head = n_head
+        self.qkv_proj = nn.Conv1d(c_in, c_in * 3, 1)
+        self.out_proj = nn.Conv1d(c_in, c_in, 1)
+        self.dropout = nn.Dropout(dropout_rate, inplace=True)
+
+        self.use_flash = torch.cuda.is_available() and version.parse(torch.__version__) >= version.parse('2.0.0')
+
+        if not self.use_flash:
+            return
+
+        device_properties = torch.cuda.get_device_properties(torch.device('cuda'))
+
+        if device_properties.major == 8 and device_properties.minor == 0:
+            # Use flash attention for A100 GPUs
+            self.sdp_kernel_config = (True, False, False)
+        else:
+            # Don't use flash attention for other GPUs
+            self.sdp_kernel_config = (False, True, True)
+
+    def forward(self, input):
+        n, c, s = input.shape
+        qkv = self.qkv_proj(self.norm(input))
+        qkv = qkv.view(
+            [n, self.n_head * 3, c // self.n_head, s]).transpose(2, 3)
+        q, k, v = qkv.chunk(3, dim=1)
+        scale = k.shape[3]**-0.25
+
+        if self.use_flash:
+            with sdp_kernel(*self.sdp_kernel_config):
+                y = F.scaled_dot_product_attention(q, k, v, is_causal=False).contiguous().view([n, c, s])
+        else:
+            att = ((q * scale) @ (k.transpose(2, 3) * scale)).softmax(3)
+            y = (att @ v).transpose(2, 3).contiguous().view([n, c, s])
+
+
+        return input + self.dropout(self.out_proj(y))
+
+
+class SkipBlock(nn.Module):
+    def __init__(self, *main):
+        super().__init__()
+        self.main = nn.Sequential(*main)
+
+    def forward(self, input):
+        return torch.cat([self.main(input), input], dim=1)
+
+
+class FourierFeatures(nn.Module):
+    def __init__(self, in_features, out_features, std=1.):
+        super().__init__()
+        assert out_features % 2 == 0
+        self.weight = nn.Parameter(torch.randn(
+            [out_features // 2, in_features]) * std)
+
+    def forward(self, input):
+        f = 2 * math.pi * input @ self.weight.T
+        return torch.cat([f.cos(), f.sin()], dim=-1)
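+
+# FourierFeatures projects an in_features-dim input to out_features features by
+# concatenating cos and sin of 2*pi*(x @ W.T), where W is a Gaussian-initialized
+# (and trainable) matrix of shape (out_features // 2, in_features).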
+
+
+def expand_to_planes(input, shape):
+    return input[..., None].repeat([1, 1, shape[2]])
+
+_kernels = {
+    'linear':
+        [1 / 8, 3 / 8, 3 / 8, 1 / 8],
+    'cubic': 
+        [-0.01171875, -0.03515625, 0.11328125, 0.43359375,
+        0.43359375, 0.11328125, -0.03515625, -0.01171875],
+    'lanczos3': 
+        [0.003689131001010537, 0.015056144446134567, -0.03399861603975296,
+        -0.066637322306633, 0.13550527393817902, 0.44638532400131226,
+        0.44638532400131226, 0.13550527393817902, -0.066637322306633,
+        -0.03399861603975296, 0.015056144446134567, 0.003689131001010537]
+}
+
+
+class Downsample1d(nn.Module):
+    def __init__(self, kernel='linear', pad_mode='reflect', channels_last=False):
+        super().__init__()
+        self.pad_mode = pad_mode
+        kernel_1d = torch.tensor(_kernels[kernel])
+        self.pad = kernel_1d.shape[0] // 2 - 1
+        self.register_buffer('kernel', kernel_1d)
+        self.channels_last = channels_last
+    
+    def forward(self, x):
+        if self.channels_last:
+            x = x.permute(0, 2, 1)
+        x = F.pad(x, (self.pad,) * 2, self.pad_mode)
+        weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0]])
+        indices = torch.arange(x.shape[1], device=x.device)
+        weight[indices, indices] = self.kernel.to(weight)
+        x = F.conv1d(x, weight, stride=2)
+        if self.channels_last:
+            x = x.permute(0, 2, 1)
+        return x
+
+
+class Upsample1d(nn.Module):
+    def __init__(self, kernel='linear', pad_mode='reflect', channels_last=False):
+        super().__init__()
+        self.pad_mode = pad_mode
+        kernel_1d = torch.tensor(_kernels[kernel]) * 2
+        self.pad = kernel_1d.shape[0] // 2 - 1
+        self.register_buffer('kernel', kernel_1d)
+        self.channels_last = channels_last
+    
+    def forward(self, x):
+        if self.channels_last:
+            x = x.permute(0, 2, 1)
+        x = F.pad(x, ((self.pad + 1) // 2,) * 2, self.pad_mode)
+        weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0]])
+        indices = torch.arange(x.shape[1], device=x.device)
+        weight[indices, indices] = self.kernel.to(weight)
+        x = F.conv_transpose1d(x, weight, stride=2, padding=self.pad * 2 + 1)
+        if self.channels_last:
+            x = x.permute(0, 2, 1)
+        return x
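+
+# Shape sketch (assuming even T): for x of shape (B, C, T),
+#   Downsample1d()(x) -> (B, C, T // 2)   # anti-aliased stride-2 FIR
+#   Upsample1d()(x)   -> (B, C, 2 * T)    # transposed conv with the same kernel, scaled by 2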
+
+
+def Downsample1d_2(
+    in_channels: int, out_channels: int, factor: int, kernel_multiplier: int = 2
+) -> nn.Module:
+    assert kernel_multiplier % 2 == 0, "Kernel multiplier must be even"
+
+    return nn.Conv1d(
+        in_channels=in_channels,
+        out_channels=out_channels,
+        kernel_size=factor * kernel_multiplier + 1,
+        stride=factor,
+        padding=factor * (kernel_multiplier // 2),
+    )
+
+
+def Upsample1d_2(
+    in_channels: int, out_channels: int, factor: int, use_nearest: bool = False
+) -> nn.Module:
+
+    if factor == 1:
+        return nn.Conv1d(
+            in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1
+        )
+
+    if use_nearest:
+        return nn.Sequential(
+            nn.Upsample(scale_factor=factor, mode="nearest"),
+            nn.Conv1d(
+                in_channels=in_channels,
+                out_channels=out_channels,
+                kernel_size=3,
+                padding=1,
+            ),
+        )
+    else:
+        return nn.ConvTranspose1d(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            kernel_size=factor * 2,
+            stride=factor,
+            padding=factor // 2 + factor % 2,
+            output_padding=factor % 2,
+        )
+
+
+def zero_init(layer):
+    nn.init.zeros_(layer.weight)
+    if layer.bias is not None:
+        nn.init.zeros_(layer.bias)
+    return layer
+
+
+def rms_norm(x, scale, eps):
+    dtype = reduce(torch.promote_types, (x.dtype, scale.dtype, torch.float32))
+    mean_sq = torch.mean(x.to(dtype)**2, dim=-1, keepdim=True)
+    scale = scale.to(dtype) * torch.rsqrt(mean_sq + eps)
+    return x * scale.to(x.dtype)
+
+#rms_norm = torch.compile(rms_norm)
+
+class AdaRMSNorm(nn.Module):
+    def __init__(self, features, cond_features, eps=1e-6):
+        super().__init__()
+        self.eps = eps
+        self.linear = zero_init(nn.Linear(cond_features, features, bias=False))
+  
+    def extra_repr(self):
+        return f"eps={self.eps},"
+
+    def forward(self, x, cond):
+        return rms_norm(x, self.linear(cond)[:, None, :] + 1, self.eps)
+
+
+def normalize(x, eps=1e-4):
+    dim = list(range(1, x.ndim))
+    n = torch.linalg.vector_norm(x, dim=dim, keepdim=True)
+    alpha = np.sqrt(n.numel() / x.numel())
+    return x / torch.add(eps, n, alpha=alpha)
+
+
+class ForcedWNConv1d(nn.Module):
+    def __init__(self, in_channels, out_channels, kernel_size=1):
+        super().__init__()
+        self.weight = nn.Parameter(torch.randn([out_channels, in_channels, kernel_size]))
+
+    def forward(self, x):
+        if self.training:
+            with torch.no_grad():
+                self.weight.copy_(normalize(self.weight))
+        
+        fan_in = self.weight[0].numel()
+
+        w = normalize(self.weight) / math.sqrt(fan_in)
+
+        return F.conv1d(x, w, padding='same')
+        
+# Kernels
+
+use_compile = True
+
+def compile(function, *args, **kwargs):
+    if not use_compile:
+        return function
+    try:
+        return torch.compile(function, *args, **kwargs)
+    except RuntimeError:
+        return function
+
+
+@compile
+def linear_geglu(x, weight, bias=None):
+    x = x @ weight.mT
+    if bias is not None:
+        x = x + bias
+    x, gate = x.chunk(2, dim=-1)
+    return x * F.gelu(gate)
+
+
+@compile
+def rms_norm(x, scale, eps):
+    dtype = reduce(torch.promote_types, (x.dtype, scale.dtype, torch.float32))
+    mean_sq = torch.mean(x.to(dtype)**2, dim=-1, keepdim=True)
+    scale = scale.to(dtype) * torch.rsqrt(mean_sq + eps)
+    return x * scale.to(x.dtype)
+
+# Layers
+
+
+class LinearGEGLU(nn.Linear):
+    def __init__(self, in_features, out_features, bias=True):
+        super().__init__(in_features, out_features * 2, bias=bias)
+        self.out_features = out_features
+
+    def forward(self, x):
+        return linear_geglu(x, self.weight, self.bias)
+
+
+class RMSNorm(nn.Module):
+    def __init__(self, shape, fix_scale = False, eps=1e-6):
+        super().__init__()
+        self.eps = eps
+
+        if fix_scale:
+            self.register_buffer("scale", torch.ones(shape))
+        else:
+            self.scale = nn.Parameter(torch.ones(shape))
+
+    def extra_repr(self):
+        return f"shape={tuple(self.scale.shape)}, eps={self.eps}"
+
+    def forward(self, x):
+        return rms_norm(x, self.scale, self.eps)
+
+
+# torch.jit.script makes this 1.4x faster and saves GPU memory
+@torch.jit.script
+def snake_beta(x, alpha, beta):
+    return x + (1.0 / (beta + 0.000000001)) * pow(torch.sin(x * alpha), 2)
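+
+# snake_beta applies x + sin^2(alpha * x) / beta elementwise; the tiny epsilon
+# guards against beta == 0 (alpha and beta are broadcast per-channel by SnakeBeta below).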
+
+# try:
+#     snake_beta = torch.compile(snake_beta)
+# except RuntimeError:
+#     pass
+
+
+# Adapted from https://github.com/NVIDIA/BigVGAN/blob/main/activations.py under MIT license
+# License available in LICENSES/LICENSE_NVIDIA.txt
+class SnakeBeta(nn.Module):
+
+    def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True):
+        super(SnakeBeta, self).__init__()
+        self.in_features = in_features
+
+        # initialize alpha
+        self.alpha_logscale = alpha_logscale
+        if self.alpha_logscale: 
+            # log scale alphas initialized to zeros
+            self.alpha = nn.Parameter(torch.zeros(in_features) * alpha)
+            self.beta = nn.Parameter(torch.zeros(in_features) * alpha)
+        else:
+            # linear scale alphas initialized to ones
+            self.alpha = nn.Parameter(torch.ones(in_features) * alpha)
+            self.beta = nn.Parameter(torch.ones(in_features) * alpha)
+
+        self.alpha.requires_grad = alpha_trainable
+        self.beta.requires_grad = alpha_trainable
+
+        # self.no_div_by_zero = 0.000000001
+
+    def forward(self, x):
+        # unsqueeze alpha/beta so they broadcast against x of shape [B, C, T]
+        alpha = self.alpha.unsqueeze(0).unsqueeze(-1)
+        beta = self.beta.unsqueeze(0).unsqueeze(-1)
+        if self.alpha_logscale:
+            alpha = torch.exp(alpha)
+            beta = torch.exp(beta)
+        x = snake_beta(x, alpha, beta)
+
+        return x
\ No newline at end of file
diff --git a/src/modules/stable_vae/models/.ipynb_checkpoints/bottleneck-checkpoint.py b/src/modules/stable_vae/models/.ipynb_checkpoints/bottleneck-checkpoint.py
new file mode 100644
index 0000000000000000000000000000000000000000..df88c5f1b1f5fa3675c1a42f42e5e31e27d00ed3
--- /dev/null
+++ b/src/modules/stable_vae/models/.ipynb_checkpoints/bottleneck-checkpoint.py
@@ -0,0 +1,346 @@
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from einops import rearrange
+from vector_quantize_pytorch import ResidualVQ, FSQ
+from .nn.quantize import ResidualVectorQuantize as DACResidualVQ
+
+
+class Bottleneck(nn.Module):
+    def __init__(self, is_discrete: bool = False):
+        super().__init__()
+
+        self.is_discrete = is_discrete
+
+    def encode(self, x, return_info=False, **kwargs):
+        raise NotImplementedError
+
+    def decode(self, x):
+        raise NotImplementedError
+
+
+class DiscreteBottleneck(Bottleneck):
+    def __init__(self, num_quantizers, codebook_size, tokens_id):
+        super().__init__(is_discrete=True)
+
+        self.num_quantizers = num_quantizers
+        self.codebook_size = codebook_size
+        self.tokens_id = tokens_id
+
+    def decode_tokens(self, codes, **kwargs):
+        raise NotImplementedError
+
+
+class TanhBottleneck(Bottleneck):
+    def __init__(self):
+        super().__init__(is_discrete=False)
+        self.tanh = nn.Tanh()
+
+    def encode(self, x, return_info=False):
+        info = {}
+
+        x = torch.tanh(x)
+
+        if return_info:
+            return x, info
+        else:
+            return x
+
+    def decode(self, x):
+        return x
+
+
+@torch.jit.script
+def vae_sample_kl(mean, scale):
+    stdev = nn.functional.softplus(scale) + 1e-4
+    var = stdev * stdev
+    logvar = torch.log(var)
+    latents = torch.randn_like(mean) * stdev + mean
+
+    kl = (mean * mean + var - logvar - 1).sum(1).mean()
+
+    return latents, kl
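+
+# Note: the `kl` above is the closed-form KL(N(mean, stdev^2) || N(0, 1)) summed
+# over the channel dimension and averaged over the rest, without the usual 1/2
+# factor (presumably absorbed into the KL loss weight in the training config).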
+
+
+@torch.jit.script
+def vae_sample(mean, scale):
+    stdev = nn.functional.softplus(scale) + 1e-4
+    latents = torch.randn_like(mean) * stdev + mean
+    return latents
+
+
+class VAEBottleneck(Bottleneck):
+    def __init__(self):
+        super().__init__(is_discrete=False)
+
+    def encode(self, x, return_info=False, **kwargs):
+        mean, scale = x.chunk(2, dim=1)
+
+        if return_info:
+            info = {}
+            x, kl = vae_sample_kl(mean, scale)
+            info["kl"] = kl
+            return x, info
+        else:
+            x = vae_sample(mean, scale)
+            return x
+
+    def decode(self, x):
+        return x
+
+
+def compute_mean_kernel(x, y):
+    kernel_input = (x[:, None] - y[None]).pow(2).mean(2) / x.shape[-1]
+    return torch.exp(-kernel_input).mean()
+
+
+def compute_mmd(latents):
+    latents_reshaped = latents.permute(0, 2, 1).reshape(-1, latents.shape[1])
+    noise = torch.randn_like(latents_reshaped)
+
+    latents_kernel = compute_mean_kernel(latents_reshaped, latents_reshaped)
+    noise_kernel = compute_mean_kernel(noise, noise)
+    latents_noise_kernel = compute_mean_kernel(latents_reshaped, noise)
+    
+    mmd = latents_kernel + noise_kernel - 2 * latents_noise_kernel
+    return mmd.mean()
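+
+# compute_mmd estimates a (biased) RBF-kernel maximum mean discrepancy between the
+# flattened latents and standard-normal noise; WassersteinBottleneck below exposes
+# it as info["mmd"] during training so the loop can regularize latents toward N(0, I).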
+
+
+class WassersteinBottleneck(Bottleneck):
+    def __init__(self, noise_augment_dim: int = 0):
+        super().__init__(is_discrete=False)
+
+        self.noise_augment_dim = noise_augment_dim
+    
+    def encode(self, x, return_info=False):
+        info = {}
+
+        if self.training and return_info:
+            mmd = compute_mmd(x)
+            info["mmd"] = mmd
+        
+        if return_info:
+            return x, info
+        
+        return x
+
+    def decode(self, x):
+
+        if self.noise_augment_dim > 0:
+            noise = torch.randn(x.shape[0], self.noise_augment_dim,
+                                x.shape[-1]).type_as(x)
+            x = torch.cat([x, noise], dim=1)
+
+        return x
+
+
+class L2Bottleneck(Bottleneck):
+    def __init__(self):
+        super().__init__(is_discrete=False)
+    
+    def encode(self, x, return_info=False):
+        info = {}
+
+        x = F.normalize(x, dim=1)
+
+        if return_info:
+            return x, info
+        else:
+            return x
+        
+    def decode(self, x):
+        return F.normalize(x, dim=1)
+
+
+class RVQBottleneck(DiscreteBottleneck):
+    def __init__(self, **quantizer_kwargs):
+        super().__init__(num_quantizers = quantizer_kwargs["num_quantizers"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "quantizer_indices")
+        self.quantizer = ResidualVQ(**quantizer_kwargs)
+        self.num_quantizers = quantizer_kwargs["num_quantizers"]
+
+    def encode(self, x, return_info=False, **kwargs):
+        info = {}
+
+        x = rearrange(x, "b c n -> b n c")
+        x, indices, loss = self.quantizer(x)
+        x = rearrange(x, "b n c -> b c n")
+
+        info["quantizer_indices"] = indices
+        info["quantizer_loss"] = loss.mean()
+
+        if return_info:
+            return x, info
+        else:
+            return x
+        
+    def decode(self, x):
+        return x
+    
+    def decode_tokens(self, codes, **kwargs):
+        latents = self.quantizer.get_outputs_from_indices(codes)
+
+        return self.decode(latents, **kwargs)
+
+
+class RVQVAEBottleneck(DiscreteBottleneck):
+    def __init__(self, **quantizer_kwargs):
+        super().__init__(num_quantizers = quantizer_kwargs["num_quantizers"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "quantizer_indices")
+        self.quantizer = ResidualVQ(**quantizer_kwargs)
+        self.num_quantizers = quantizer_kwargs["num_quantizers"]
+
+    def encode(self, x, return_info=False):
+        info = {}
+
+        # vae_sample_kl returns both the sampled latents and the KL term
+        x, kl = vae_sample_kl(*x.chunk(2, dim=1))
+
+        info["kl"] = kl
+
+        x = rearrange(x, "b c n -> b n c")
+        x, indices, loss = self.quantizer(x)
+        x = rearrange(x, "b n c -> b c n")
+
+        info["quantizer_indices"] = indices
+        info["quantizer_loss"] = loss.mean()
+
+        if return_info:
+            return x, info
+        else:
+            return x
+        
+    def decode(self, x):
+        return x
+    
+    def decode_tokens(self, codes, **kwargs):
+        latents = self.quantizer.get_outputs_from_indices(codes)
+
+        return self.decode(latents, **kwargs)
+
+
+class DACRVQBottleneck(DiscreteBottleneck):
+    def __init__(self, quantize_on_decode=False, **quantizer_kwargs):
+        super().__init__(num_quantizers = quantizer_kwargs["n_codebooks"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "codes")
+        self.quantizer = DACResidualVQ(**quantizer_kwargs)
+        self.num_quantizers = quantizer_kwargs["n_codebooks"]
+        self.quantize_on_decode = quantize_on_decode
+
+    def encode(self, x, return_info=False, **kwargs):
+        info = {}
+
+        info["pre_quantizer"] = x
+
+        if self.quantize_on_decode:
+            return (x, info) if return_info else x
+
+        z, codes, latents, commitment_loss, codebook_loss = self.quantizer(x, **kwargs)
+
+        output = {
+            "z": z,
+            "codes": codes,
+            "latents": latents,
+            "vq/commitment_loss": commitment_loss,
+            "vq/codebook_loss": codebook_loss,
+        }
+
+        output["vq/commitment_loss"] /= self.num_quantizers
+        output["vq/codebook_loss"] /= self.num_quantizers
+
+        info.update(output)
+
+        if return_info:
+            return output["z"], info
+        
+        return output["z"]
+    
+    def decode(self, x):
+
+        if self.quantize_on_decode:
+            x = self.quantizer(x)[0]
+
+        return x
+    
+    def decode_tokens(self, codes, **kwargs):
+        latents, _, _ = self.quantizer.from_codes(codes)
+
+        return self.decode(latents, **kwargs)
+
+
+class DACRVQVAEBottleneck(DiscreteBottleneck):
+    def __init__(self, quantize_on_decode=False, **quantizer_kwargs):
+        super().__init__(num_quantizers = quantizer_kwargs["n_codebooks"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "codes")
+        self.quantizer = DACResidualVQ(**quantizer_kwargs)
+        self.num_quantizers = quantizer_kwargs["n_codebooks"]
+        self.quantize_on_decode = quantize_on_decode
+
+    def encode(self, x, return_info=False, n_quantizers: int = None):
+        info = {}
+
+        mean, scale = x.chunk(2, dim=1)
+
+        x, kl = vae_sample_kl(mean, scale)
+
+        info["pre_quantizer"] = x
+        info["kl"] = kl
+
+        if self.quantize_on_decode:
+            return (x, info) if return_info else x
+
+        z, codes, latents, commitment_loss, codebook_loss = self.quantizer(x, n_quantizers=n_quantizers)
+
+        output = {
+            "z": z,
+            "codes": codes,
+            "latents": latents,
+            "vq/commitment_loss": commitment_loss,
+            "vq/codebook_loss": codebook_loss,
+        }
+
+        output["vq/commitment_loss"] /= self.num_quantizers
+        output["vq/codebook_loss"] /= self.num_quantizers
+
+        info.update(output)
+
+        if return_info:
+            return output["z"], info
+        
+        return output["z"]
+    
+    def decode(self, x):
+
+        if self.quantize_on_decode:
+            x = self.quantizer(x)[0]
+
+        return x
+
+    def decode_tokens(self, codes, **kwargs):
+        latents, _, _ = self.quantizer.from_codes(codes)
+
+        return self.decode(latents, **kwargs)
+
+
+class FSQBottleneck(DiscreteBottleneck):
+    def __init__(self, dim, levels):
+        super().__init__(num_quantizers = 1, codebook_size = levels ** dim, tokens_id = "quantizer_indices")
+        self.quantizer = FSQ(levels=[levels] * dim)
+
+    def encode(self, x, return_info=False):
+        info = {}
+
+        x = rearrange(x, "b c n -> b n c")
+        x, indices = self.quantizer(x)
+        x = rearrange(x, "b n c -> b c n")
+
+        info["quantizer_indices"] = indices
+
+        if return_info:
+            return x, info
+        else:
+            return x
+        
+    def decode(self, x):
+        return x
+    
+    def decode_tokens(self, tokens, **kwargs):
+        latents = self.quantizer.indices_to_codes(tokens)
+
+        return self.decode(latents, **kwargs)
\ No newline at end of file
diff --git a/src/modules/stable_vae/models/.ipynb_checkpoints/factory-checkpoint.py b/src/modules/stable_vae/models/.ipynb_checkpoints/factory-checkpoint.py
new file mode 100644
index 0000000000000000000000000000000000000000..4188703000ee176342c7f329342f18d6fe747b04
--- /dev/null
+++ b/src/modules/stable_vae/models/.ipynb_checkpoints/factory-checkpoint.py
@@ -0,0 +1,153 @@
+import json
+
+def create_model_from_config(model_config):
+    model_type = model_config.get('model_type', None)
+
+    assert model_type is not None, 'model_type must be specified in model config'
+
+    if model_type == 'autoencoder':
+        from .autoencoders import create_autoencoder_from_config
+        return create_autoencoder_from_config(model_config)
+    elif model_type == 'diffusion_uncond':
+        from .diffusion import create_diffusion_uncond_from_config
+        return create_diffusion_uncond_from_config(model_config)
+    elif model_type == 'diffusion_cond' or model_type == 'diffusion_cond_inpaint' or model_type == "diffusion_prior":
+        from .diffusion import create_diffusion_cond_from_config
+        return create_diffusion_cond_from_config(model_config)
+    elif model_type == 'diffusion_autoencoder':
+        from .autoencoders import create_diffAE_from_config
+        return create_diffAE_from_config(model_config)
+    elif model_type == 'lm':
+        from .lm import create_audio_lm_from_config
+        return create_audio_lm_from_config(model_config)
+    else:
+        raise NotImplementedError(f'Unknown model type: {model_type}')
+
+def create_model_from_config_path(model_config_path):
+    with open(model_config_path) as f:
+        model_config = json.load(f)
+    
+    return create_model_from_config(model_config)
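+
+# Usage sketch (the path is hypothetical; the JSON must contain a "model_type"
+# key handled by create_model_from_config above):
+#   model = create_model_from_config_path("configs/autoencoder_44k.json")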
+
+def create_pretransform_from_config(pretransform_config, sample_rate):
+    pretransform_type = pretransform_config.get('type', None)
+
+    assert pretransform_type is not None, 'type must be specified in pretransform config'
+
+    if pretransform_type == 'autoencoder':
+        from .autoencoders import create_autoencoder_from_config
+        from .pretransforms import AutoencoderPretransform
+
+        # Create fake top-level config to pass sample rate to autoencoder constructor
+        # This is a bit of a hack but it keeps us from re-defining the sample rate in the config
+        autoencoder_config = {"sample_rate": sample_rate, "model": pretransform_config["config"]}
+        autoencoder = create_autoencoder_from_config(autoencoder_config)
+
+        scale = pretransform_config.get("scale", 1.0)
+        model_half = pretransform_config.get("model_half", False)
+        iterate_batch = pretransform_config.get("iterate_batch", False)
+        chunked = pretransform_config.get("chunked", False)
+
+        pretransform = AutoencoderPretransform(autoencoder, scale=scale, model_half=model_half, iterate_batch=iterate_batch, chunked=chunked)
+    elif pretransform_type == 'wavelet':
+        from .pretransforms import WaveletPretransform
+
+        wavelet_config = pretransform_config["config"]
+        channels = wavelet_config["channels"]
+        levels = wavelet_config["levels"]
+        wavelet = wavelet_config["wavelet"]
+
+        pretransform = WaveletPretransform(channels, levels, wavelet)
+    elif pretransform_type == 'pqmf':
+        from .pretransforms import PQMFPretransform
+        pqmf_config = pretransform_config["config"]
+        pretransform = PQMFPretransform(**pqmf_config)
+    elif pretransform_type == 'dac_pretrained':
+        from .pretransforms import PretrainedDACPretransform
+        pretrained_dac_config = pretransform_config["config"]
+        pretransform = PretrainedDACPretransform(**pretrained_dac_config)
+    elif pretransform_type == "audiocraft_pretrained":
+        from .pretransforms import AudiocraftCompressionPretransform
+
+        audiocraft_config = pretransform_config["config"]
+        pretransform = AudiocraftCompressionPretransform(**audiocraft_config)
+    else:
+        raise NotImplementedError(f'Unknown pretransform type: {pretransform_type}')
+    
+    enable_grad = pretransform_config.get('enable_grad', False)
+    pretransform.enable_grad = enable_grad
+
+    pretransform.eval().requires_grad_(pretransform.enable_grad)
+
+    return pretransform
+
+def create_bottleneck_from_config(bottleneck_config):
+    bottleneck_type = bottleneck_config.get('type', None)
+
+    assert bottleneck_type is not None, 'type must be specified in bottleneck config'
+
+    if bottleneck_type == 'tanh':
+        from .bottleneck import TanhBottleneck
+        bottleneck = TanhBottleneck()
+    elif bottleneck_type == 'vae':
+        from .bottleneck import VAEBottleneck
+        bottleneck = VAEBottleneck()
+    elif bottleneck_type == 'rvq':
+        from .bottleneck import RVQBottleneck
+
+        quantizer_params = {
+            "dim": 128,
+            "codebook_size": 1024,
+            "num_quantizers": 8,
+            "decay": 0.99,
+            "kmeans_init": True,
+            "kmeans_iters": 50,
+            "threshold_ema_dead_code": 2,
+        }
+
+        quantizer_params.update(bottleneck_config["config"])
+
+        bottleneck = RVQBottleneck(**quantizer_params)
+    elif bottleneck_type == "dac_rvq":
+        from .bottleneck import DACRVQBottleneck
+
+        bottleneck = DACRVQBottleneck(**bottleneck_config["config"])
+    
+    elif bottleneck_type == 'rvq_vae':
+        from .bottleneck import RVQVAEBottleneck
+
+        quantizer_params = {
+            "dim": 128,
+            "codebook_size": 1024,
+            "num_quantizers": 8,
+            "decay": 0.99,
+            "kmeans_init": True,
+            "kmeans_iters": 50,
+            "threshold_ema_dead_code": 2,
+        }
+
+        quantizer_params.update(bottleneck_config["config"])
+
+        bottleneck = RVQVAEBottleneck(**quantizer_params)
+        
+    elif bottleneck_type == 'dac_rvq_vae':
+        from .bottleneck import DACRVQVAEBottleneck
+        bottleneck = DACRVQVAEBottleneck(**bottleneck_config["config"])
+    elif bottleneck_type == 'l2_norm':
+        from .bottleneck import L2Bottleneck
+        bottleneck = L2Bottleneck()
+    elif bottleneck_type == "wasserstein":
+        from .bottleneck import WassersteinBottleneck
+        bottleneck = WassersteinBottleneck(**bottleneck_config.get("config", {}))
+    elif bottleneck_type == "fsq":
+        from .bottleneck import FSQBottleneck
+        bottleneck = FSQBottleneck(**bottleneck_config["config"])
+    else:
+        raise NotImplementedError(f'Unknown bottleneck type: {bottleneck_type}')
+    
+    requires_grad = bottleneck_config.get('requires_grad', True)
+    if not requires_grad:
+        for param in bottleneck.parameters():
+            param.requires_grad = False
+
+    return bottleneck
diff --git a/src/modules/stable_vae/models/.ipynb_checkpoints/pretransforms-checkpoint.py b/src/modules/stable_vae/models/.ipynb_checkpoints/pretransforms-checkpoint.py
new file mode 100644
index 0000000000000000000000000000000000000000..c9942db59908ce8135f44e45090e928f6c60393a
--- /dev/null
+++ b/src/modules/stable_vae/models/.ipynb_checkpoints/pretransforms-checkpoint.py
@@ -0,0 +1,258 @@
+import torch
+from einops import rearrange
+from torch import nn
+
+class Pretransform(nn.Module):
+    def __init__(self, enable_grad, io_channels, is_discrete):
+        super().__init__()
+
+        self.is_discrete = is_discrete
+        self.io_channels = io_channels
+        self.encoded_channels = None
+        self.downsampling_ratio = None
+
+        self.enable_grad = enable_grad
+
+    def encode(self, x):
+        raise NotImplementedError
+
+    def decode(self, z):
+        raise NotImplementedError
+    
+    def tokenize(self, x):
+        raise NotImplementedError
+    
+    def decode_tokens(self, tokens):
+        raise NotImplementedError
+
+class AutoencoderPretransform(Pretransform):
+    def __init__(self, model, scale=1.0, model_half=False, iterate_batch=False, chunked=False):
+        super().__init__(enable_grad=False, io_channels=model.io_channels, is_discrete=model.bottleneck is not None and model.bottleneck.is_discrete)
+        self.model = model
+        self.model.requires_grad_(False).eval()
+        self.scale=scale
+        self.downsampling_ratio = model.downsampling_ratio
+        self.io_channels = model.io_channels
+        self.sample_rate = model.sample_rate
+        
+        self.model_half = model_half
+        self.iterate_batch = iterate_batch
+
+        self.encoded_channels = model.latent_dim
+
+        self.chunked = chunked
+        self.num_quantizers = model.bottleneck.num_quantizers if model.bottleneck is not None and model.bottleneck.is_discrete else None
+        self.codebook_size = model.bottleneck.codebook_size if model.bottleneck is not None and model.bottleneck.is_discrete else None
+
+        if self.model_half:
+            self.model.half()
+    
+    def encode(self, x, **kwargs):
+        
+        if self.model_half:
+            x = x.half()
+            self.model.to(torch.float16)
+
+        encoded = self.model.encode_audio(x, chunked=self.chunked, iterate_batch=self.iterate_batch, **kwargs)
+
+        if self.model_half:
+            encoded = encoded.float()
+
+        return encoded / self.scale
+
+    def decode(self, z, **kwargs):
+        z = z * self.scale
+
+        if self.model_half:
+            z = z.half()
+            self.model.to(torch.float16)
+
+        decoded = self.model.decode_audio(z, chunked=self.chunked, iterate_batch=self.iterate_batch, **kwargs)
+
+        if self.model_half:
+            decoded = decoded.float()
+
+        return decoded
+    
+    def tokenize(self, x, **kwargs):
+        assert self.model.is_discrete, "Cannot tokenize with a continuous model"
+
+        _, info = self.model.encode(x, return_info = True, **kwargs)
+
+        return info[self.model.bottleneck.tokens_id]
+    
+    def decode_tokens(self, tokens, **kwargs):
+        assert self.model.is_discrete, "Cannot decode tokens with a continuous model"
+
+        return self.model.decode_tokens(tokens, **kwargs)
+    
+    def load_state_dict(self, state_dict, strict=True):
+        self.model.load_state_dict(state_dict, strict=strict)
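+
+# Note: encode() divides latents by `scale` and decode() multiplies them back, so
+# setting `scale` to (roughly) the latent standard deviation gives downstream models
+# approximately unit-variance latents (a convention, not enforced here).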
+
+class WaveletPretransform(Pretransform):
+    def __init__(self, channels, levels, wavelet):
+        super().__init__(enable_grad=False, io_channels=channels, is_discrete=False)
+
+        from .wavelets import WaveletEncode1d, WaveletDecode1d
+
+        self.encoder = WaveletEncode1d(channels, levels, wavelet)
+        self.decoder = WaveletDecode1d(channels, levels, wavelet)
+
+        self.downsampling_ratio = 2 ** levels
+        self.io_channels = channels
+        self.encoded_channels = channels * self.downsampling_ratio
+    
+    def encode(self, x):
+        return self.encoder(x)
+    
+    def decode(self, z):
+        return self.decoder(z)
+    
+class PQMFPretransform(Pretransform):
+    def __init__(self, attenuation=100, num_bands=16):
+        # TODO: Fix PQMF to take in in-channels
+        super().__init__(enable_grad=False, io_channels=1, is_discrete=False)
+        from .pqmf import PQMF
+        self.pqmf = PQMF(attenuation, num_bands)
+
+
+    def encode(self, x):
+        # x is (Batch x Channels x Time)
+        x = self.pqmf.forward(x)
+        # pqmf.forward returns (Batch x Channels x Bands x Time)
+        # but Pretransform needs Batch x Channels x Time
+        # so concatenate channels and bands into one axis
+        return rearrange(x, "b c n t -> b (c n) t")
+
+    def decode(self, x):
+        # x is (Batch x (Channels Bands) x Time), convert back to (Batch x Channels x Bands x Time) 
+        x = rearrange(x, "b (c n) t -> b c n t", n=self.pqmf.num_bands)
+        # returns (Batch x Channels x Time) 
+        return self.pqmf.inverse(x)
+        
+class PretrainedDACPretransform(Pretransform):
+    def __init__(self, model_type="44khz", model_bitrate="8kbps", scale=1.0, quantize_on_decode: bool = True, chunked=True):
+        super().__init__(enable_grad=False, io_channels=1, is_discrete=True)
+        
+        import dac
+        
+        model_path = dac.utils.download(model_type=model_type, model_bitrate=model_bitrate)
+        
+        self.model = dac.DAC.load(model_path)
+
+        self.quantize_on_decode = quantize_on_decode
+
+        if model_type == "44khz":
+            self.downsampling_ratio = 512
+        else:
+            self.downsampling_ratio = 320
+
+        self.io_channels = 1
+
+        self.scale = scale
+
+        self.chunked = chunked
+
+        self.encoded_channels = self.model.latent_dim
+
+        self.num_quantizers = self.model.n_codebooks
+
+        self.codebook_size = self.model.codebook_size
+
+    def encode(self, x):
+
+        latents = self.model.encoder(x)
+
+        if self.quantize_on_decode:
+            output = latents
+        else:
+            z, _, _, _, _ = self.model.quantizer(latents, n_quantizers=self.model.n_codebooks)
+            output = z
+        
+        if self.scale != 1.0:
+            output = output / self.scale
+        
+        return output
+
+    def decode(self, z):
+        
+        if self.scale != 1.0:
+            z = z * self.scale
+
+        if self.quantize_on_decode:
+            z, _, _, _, _ = self.model.quantizer(z, n_quantizers=self.model.n_codebooks)
+
+        return self.model.decode(z)
+
+    def tokenize(self, x):
+        return self.model.encode(x)[1]
+    
+    def decode_tokens(self, tokens):
+        latents = self.model.quantizer.from_codes(tokens)[0]  # from_codes returns a tuple; the first element is the quantized latents
+        return self.model.decode(latents)
+    
+class AudiocraftCompressionPretransform(Pretransform):
+    def __init__(self, model_type="facebook/encodec_32khz", scale=1.0, quantize_on_decode: bool = True):
+        super().__init__(enable_grad=False, io_channels=1, is_discrete=True)
+        
+        try:
+            from audiocraft.models import CompressionModel
+        except ImportError:
+            raise ImportError("Audiocraft is not installed. Please install audiocraft to use Audiocraft models.")
+               
+        self.model = CompressionModel.get_pretrained(model_type)
+
+        self.quantize_on_decode = quantize_on_decode
+
+        self.downsampling_ratio = round(self.model.sample_rate / self.model.frame_rate)
+
+        self.sample_rate = self.model.sample_rate
+
+        self.io_channels = self.model.channels
+
+        self.scale = scale
+
+        #self.encoded_channels = self.model.latent_dim
+
+        self.num_quantizers = self.model.num_codebooks
+
+        self.codebook_size = self.model.cardinality
+
+        self.model.to(torch.float16).eval().requires_grad_(False)
+
+    def encode(self, x):
+
+        raise NotImplementedError("Audiocraft compression models do not support continuous encoding")
+
+        # latents = self.model.encoder(x)
+
+        # if self.quantize_on_decode:
+        #     output = latents
+        # else:
+        #     z, _, _, _, _ = self.model.quantizer(latents, n_quantizers=self.model.n_codebooks)
+        #     output = z
+        
+        # if self.scale != 1.0:
+        #     output = output / self.scale
+        
+        # return output
+
+    def decode(self, z):
+        
+        raise NotImplementedError("Audiocraft compression models do not support continuous decoding")
+
+        # if self.scale != 1.0:
+        #     z = z * self.scale
+
+        # if self.quantize_on_decode:
+        #     z, _, _, _, _ = self.model.quantizer(z, n_quantizers=self.model.n_codebooks)
+
+        # return self.model.decode(z)
+
+    def tokenize(self, x):
+        with torch.cuda.amp.autocast(enabled=False):
+            return self.model.encode(x.to(torch.float16))[0]
+    
+    def decode_tokens(self, tokens):
+        with torch.cuda.amp.autocast(enabled=False):
+            return self.model.decode(tokens)
diff --git a/src/modules/stable_vae/models/.ipynb_checkpoints/utils-checkpoint.py b/src/modules/stable_vae/models/.ipynb_checkpoints/utils-checkpoint.py
new file mode 100644
index 0000000000000000000000000000000000000000..ec8eeaf773d47db2c000a3b2237d88d310214dcf
--- /dev/null
+++ b/src/modules/stable_vae/models/.ipynb_checkpoints/utils-checkpoint.py
@@ -0,0 +1,51 @@
+import torch
+import torch.nn as nn
+from torchaudio import transforms as T
+
+
+class PadCrop(nn.Module):
+    def __init__(self, n_samples, randomize=True):
+        super().__init__()
+        self.n_samples = n_samples
+        self.randomize = randomize
+
+    def __call__(self, signal):
+        n, s = signal.shape
+        start = 0 if (not self.randomize) else torch.randint(0, max(0, s - self.n_samples) + 1, []).item()
+        end = start + self.n_samples
+        output = signal.new_zeros([n, self.n_samples])
+        output[:, :min(s, self.n_samples)] = signal[:, start:end]
+        return output
+
+
+def set_audio_channels(audio, target_channels):
+    if target_channels == 1:
+        # Convert to mono
+        audio = audio.mean(1, keepdim=True)
+    elif target_channels == 2:
+        # Convert to stereo
+        if audio.shape[1] == 1:
+            audio = audio.repeat(1, 2, 1)
+        elif audio.shape[1] > 2:
+            audio = audio[:, :2, :]
+    return audio
+
+def prepare_audio(audio, in_sr, target_sr, target_length, target_channels, device):
+    
+    audio = audio.to(device)
+
+    if in_sr != target_sr:
+        resample_tf = T.Resample(in_sr, target_sr).to(device)
+        audio = resample_tf(audio)
+
+    audio = PadCrop(target_length, randomize=False)(audio)
+
+    # Add batch dimension
+    if audio.dim() == 1:
+        audio = audio.unsqueeze(0).unsqueeze(0)
+    elif audio.dim() == 2:
+        audio = audio.unsqueeze(0)
+
+    audio = set_audio_channels(audio, target_channels)
+
+    return audio
\ No newline at end of file
diff --git a/src/modules/stable_vae/models/autoencoders.py b/src/modules/stable_vae/models/autoencoders.py
new file mode 100644
index 0000000000000000000000000000000000000000..2741d45e70e3c25bc28f2b4e43b1a3e925a4c9e3
--- /dev/null
+++ b/src/modules/stable_vae/models/autoencoders.py
@@ -0,0 +1,683 @@
+import torch
+import math
+import numpy as np
+
+from torch import nn
+from torch.nn import functional as F
+from torchaudio import transforms as T
+from alias_free_torch import Activation1d
+from .nn.layers import WNConv1d, WNConvTranspose1d
+from typing import Literal, Dict, Any
+
+# from .inference.sampling import sample
+from .utils import prepare_audio
+from .blocks import SnakeBeta
+from .bottleneck import Bottleneck, DiscreteBottleneck
+from .factory import create_pretransform_from_config, create_bottleneck_from_config
+from .pretransforms import Pretransform
+
+def checkpoint(function, *args, **kwargs):
+    kwargs.setdefault("use_reentrant", False)
+    return torch.utils.checkpoint.checkpoint(function, *args, **kwargs)
+
+def get_activation(activation: Literal["elu", "snake", "none"], antialias=False, channels=None) -> nn.Module:
+    if activation == "elu":
+        act = nn.ELU()
+    elif activation == "snake":
+        act = SnakeBeta(channels)
+    elif activation == "none":
+        act = nn.Identity()
+    else:
+        raise ValueError(f"Unknown activation {activation}")
+    
+    if antialias:
+        act = Activation1d(act)
+    
+    return act
+
+class ResidualUnit(nn.Module):
+    def __init__(self, in_channels, out_channels, dilation, use_snake=False, antialias_activation=False):
+        super().__init__()
+        
+        self.dilation = dilation
+
+        padding = (dilation * (7-1)) // 2
+
+        self.layers = nn.Sequential(
+            get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=out_channels),
+            WNConv1d(in_channels=in_channels, out_channels=out_channels,
+                      kernel_size=7, dilation=dilation, padding=padding),
+            get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=out_channels),
+            WNConv1d(in_channels=out_channels, out_channels=out_channels,
+                      kernel_size=1)
+        )
+
+    def forward(self, x):
+        res = x
+        
+        #x = checkpoint(self.layers, x)
+        x = self.layers(x)
+
+        return x + res
+
+class EncoderBlock(nn.Module):
+    def __init__(self, in_channels, out_channels, stride, use_snake=False, antialias_activation=False):
+        super().__init__()
+
+        self.layers = nn.Sequential(
+            ResidualUnit(in_channels=in_channels,
+                         out_channels=in_channels, dilation=1, use_snake=use_snake),
+            ResidualUnit(in_channels=in_channels,
+                         out_channels=in_channels, dilation=3, use_snake=use_snake),
+            ResidualUnit(in_channels=in_channels,
+                         out_channels=in_channels, dilation=9, use_snake=use_snake),
+            get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=in_channels),
+            WNConv1d(in_channels=in_channels, out_channels=out_channels,
+                      kernel_size=2*stride, stride=stride, padding=math.ceil(stride/2)),
+        )
+
+    def forward(self, x):
+        return self.layers(x)
+
+class DecoderBlock(nn.Module):
+    def __init__(self, in_channels, out_channels, stride, use_snake=False, antialias_activation=False, use_nearest_upsample=False):
+        super().__init__()
+
+        if use_nearest_upsample:
+            upsample_layer = nn.Sequential(
+                nn.Upsample(scale_factor=stride, mode="nearest"),
+                WNConv1d(in_channels=in_channels,
+                        out_channels=out_channels, 
+                        kernel_size=2*stride,
+                        stride=1,
+                        bias=False,
+                        padding='same')
+            )
+        else:
+            upsample_layer = WNConvTranspose1d(in_channels=in_channels,
+                               out_channels=out_channels,
+                               kernel_size=2*stride, stride=stride, padding=math.ceil(stride/2))
+
+        self.layers = nn.Sequential(
+            get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=in_channels),
+            upsample_layer,
+            ResidualUnit(in_channels=out_channels, out_channels=out_channels,
+                         dilation=1, use_snake=use_snake),
+            ResidualUnit(in_channels=out_channels, out_channels=out_channels,
+                         dilation=3, use_snake=use_snake),
+            ResidualUnit(in_channels=out_channels, out_channels=out_channels,
+                         dilation=9, use_snake=use_snake),
+        )
+
+    def forward(self, x):
+        return self.layers(x)
+
+class OobleckEncoder(nn.Module):
+    def __init__(self, 
+                 in_channels=2, 
+                 channels=128, 
+                 latent_dim=32, 
+                 c_mults = [1, 2, 4, 8], 
+                 strides = [2, 4, 8, 8],
+                 use_snake=False,
+                 antialias_activation=False
+        ):
+        super().__init__()
+          
+        c_mults = [1] + c_mults
+
+        self.depth = len(c_mults)
+
+        layers = [
+            WNConv1d(in_channels=in_channels, out_channels=c_mults[0] * channels, kernel_size=7, padding=3)
+        ]
+        
+        for i in range(self.depth-1):
+            layers += [EncoderBlock(in_channels=c_mults[i]*channels, out_channels=c_mults[i+1]*channels, stride=strides[i], use_snake=use_snake)]
+
+        layers += [
+            get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=c_mults[-1] * channels),
+            WNConv1d(in_channels=c_mults[-1]*channels, out_channels=latent_dim, kernel_size=3, padding=1)
+        ]
+
+        self.layers = nn.Sequential(*layers)
+
+    def forward(self, x):
+        return self.layers(x)
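+
+# With the default strides [2, 4, 8, 8] the encoder shrinks the time axis by
+# 2 * 4 * 8 * 8 = 512; this product is what `downsampling_ratio` in the model
+# config is expected to match (an observation, not enforced by the code).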
+
+
+class OobleckDecoder(nn.Module):
+    def __init__(self, 
+                 out_channels=2, 
+                 channels=128, 
+                 latent_dim=32, 
+                 c_mults = [1, 2, 4, 8], 
+                 strides = [2, 4, 8, 8],
+                 use_snake=False,
+                 antialias_activation=False,
+                 use_nearest_upsample=False,
+                 final_tanh=True):
+        super().__init__()
+
+        c_mults = [1] + c_mults
+        
+        self.depth = len(c_mults)
+
+        layers = [
+            WNConv1d(in_channels=latent_dim, out_channels=c_mults[-1]*channels, kernel_size=7, padding=3),
+        ]
+        
+        for i in range(self.depth-1, 0, -1):
+            layers += [DecoderBlock(
+                in_channels=c_mults[i]*channels, 
+                out_channels=c_mults[i-1]*channels, 
+                stride=strides[i-1], 
+                use_snake=use_snake, 
+                antialias_activation=antialias_activation,
+                use_nearest_upsample=use_nearest_upsample
+                )
+            ]
+
+        layers += [
+            get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=c_mults[0] * channels),
+            WNConv1d(in_channels=c_mults[0] * channels, out_channels=out_channels, kernel_size=7, padding=3, bias=False),
+            nn.Tanh() if final_tanh else nn.Identity()
+        ]
+
+        self.layers = nn.Sequential(*layers)
+
+    def forward(self, x):
+        return self.layers(x)
+
+
+class DACEncoderWrapper(nn.Module):
+    def __init__(self, in_channels=1, **kwargs):
+        super().__init__()
+
+        from dac.model.dac import Encoder as DACEncoder
+
+        latent_dim = kwargs.pop("latent_dim", None)
+
+        encoder_out_dim = kwargs["d_model"] * (2 ** len(kwargs["strides"]))
+        self.encoder = DACEncoder(d_latent=encoder_out_dim, **kwargs)
+        self.latent_dim = latent_dim
+
+        # Latent-dim support was added to DAC after this was first written, and implemented differently, so this is for backwards compatibility
+        self.proj_out = nn.Conv1d(self.encoder.enc_dim, latent_dim, kernel_size=1) if latent_dim is not None else nn.Identity()
+
+        if in_channels != 1:
+            self.encoder.block[0] = WNConv1d(in_channels, kwargs.get("d_model", 64), kernel_size=7, padding=3)
+
+    def forward(self, x):
+        x = self.encoder(x)
+        x = self.proj_out(x)
+        return x
+
+class DACDecoderWrapper(nn.Module):
+    def __init__(self, latent_dim, out_channels=1, **kwargs):
+        super().__init__()
+
+        from dac.model.dac import Decoder as DACDecoder
+
+        self.decoder = DACDecoder(**kwargs, input_channel = latent_dim, d_out=out_channels)
+
+        self.latent_dim = latent_dim
+
+    def forward(self, x):
+        return self.decoder(x)
+
+class AudioAutoencoder(nn.Module):
+    def __init__(
+        self,
+        encoder,
+        decoder,
+        latent_dim,
+        downsampling_ratio,
+        sample_rate,
+        io_channels=2,
+        bottleneck: Bottleneck = None,
+        pretransform: Pretransform = None,
+        in_channels = None,
+        out_channels = None,
+        soft_clip = False
+    ):
+        super().__init__()
+
+        self.downsampling_ratio = downsampling_ratio
+        self.sample_rate = sample_rate
+
+        self.latent_dim = latent_dim
+        self.io_channels = io_channels
+        self.in_channels = io_channels
+        self.out_channels = io_channels
+
+        self.min_length = self.downsampling_ratio
+
+        if in_channels is not None:
+            self.in_channels = in_channels
+
+        if out_channels is not None:
+            self.out_channels = out_channels
+
+        self.bottleneck = bottleneck
+
+        self.encoder = encoder
+
+        self.decoder = decoder
+
+        self.pretransform = pretransform
+
+        self.soft_clip = soft_clip
+ 
+        self.is_discrete = self.bottleneck is not None and self.bottleneck.is_discrete
+
+    def encode(self, audio, return_info=False, skip_pretransform=False, iterate_batch=False, **kwargs):
+
+        info = {}
+
+        if self.pretransform is not None and not skip_pretransform:
+            if self.pretransform.enable_grad:
+                if iterate_batch:
+                    audios = []
+                    for i in range(audio.shape[0]):
+                        audios.append(self.pretransform.encode(audio[i:i+1]))
+                    audio = torch.cat(audios, dim=0)
+                else:
+                    audio = self.pretransform.encode(audio)
+            else:
+                with torch.no_grad():
+                    if iterate_batch:
+                        audios = []
+                        for i in range(audio.shape[0]):
+                            audios.append(self.pretransform.encode(audio[i:i+1]))
+                        audio = torch.cat(audios, dim=0)
+                    else:
+                        audio = self.pretransform.encode(audio)
+
+        if self.encoder is not None:
+            if iterate_batch:
+                latents = []
+                for i in range(audio.shape[0]):
+                    latents.append(self.encoder(audio[i:i+1]))
+                latents = torch.cat(latents, dim=0)
+            else:
+                latents = self.encoder(audio)
+        else:
+            latents = audio
+
+        if self.bottleneck is not None:
+            # TODO: Add iterate batch logic, needs to merge the info dicts
+            latents, bottleneck_info = self.bottleneck.encode(latents, return_info=True, **kwargs)
+
+            info.update(bottleneck_info)
+
+        if return_info:
+            return latents, info
+
+        return latents
+
+    def decode(self, latents, iterate_batch=False, **kwargs):
+
+        if self.bottleneck is not None:
+            if iterate_batch:
+                decoded = []
+                for i in range(latents.shape[0]):
+                    decoded.append(self.bottleneck.decode(latents[i:i+1]))
+                latents = torch.cat(decoded, dim=0)
+            else:
+                latents = self.bottleneck.decode(latents)
+
+        if iterate_batch:
+            decoded = []
+            for i in range(latents.shape[0]):
+                decoded.append(self.decoder(latents[i:i+1]))
+            decoded = torch.cat(decoded, dim=0)
+        else:
+            decoded = self.decoder(latents, **kwargs)
+
+        if self.pretransform is not None:
+            if self.pretransform.enable_grad:
+                if iterate_batch:
+                    decodeds = []
+                    for i in range(decoded.shape[0]):
+                        decodeds.append(self.pretransform.decode(decoded[i:i+1]))
+                    decoded = torch.cat(decodeds, dim=0)
+                else:
+                    decoded = self.pretransform.decode(decoded)
+            else:
+                with torch.no_grad():
+                    if iterate_batch:
+                        decodeds = []
+                        for i in range(decoded.shape[0]):
+                            decodeds.append(self.pretransform.decode(decoded[i:i+1]))
+                        decoded = torch.cat(decodeds, dim=0)
+                    else:
+                        decoded = self.pretransform.decode(decoded)
+
+        if self.soft_clip:
+            decoded = torch.tanh(decoded)
+
+        return decoded
+
+    def decode_tokens(self, tokens, **kwargs):
+        '''
+        Decode discrete tokens to audio
+        Only works with discrete autoencoders
+        '''
+
+        assert isinstance(self.bottleneck, DiscreteBottleneck), "decode_tokens only works with discrete autoencoders"
+
+        latents = self.bottleneck.decode_tokens(tokens, **kwargs)
+
+        return self.decode(latents, **kwargs)
+        
+    
+    def preprocess_audio_for_encoder(self, audio, in_sr):
+        '''
+        Preprocess single audio tensor (Channels x Length) to be compatible with the encoder.
+        If the model is mono, stereo audio will be converted to mono.
+        Audio will be silence-padded to be a multiple of the model's downsampling ratio.
+        Audio will be resampled to the model's sample rate. 
+        The output will have batch size 1 and be of shape (1 x Channels x Length)
+        '''
+        return self.preprocess_audio_list_for_encoder([audio], [in_sr])
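+    # Sketch of typical use (the waveform tensor and the 44100 sample rate are illustrative only):
+    #   x = model.preprocess_audio_for_encoder(waveform, in_sr=44100)   # -> (1, Channels, Length)
+    #   latents = model.encode_audio(x, chunked=True)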
+
+    def preprocess_audio_list_for_encoder(self, audio_list, in_sr_list):
+        '''
+        Preprocess a [list] of audio (Channels x Length) into a batch tensor to be compatible with the encoder. 
+        The audio in that list can be of different lengths and channels. 
+        in_sr can be an integer or list. If it's an integer it will be assumed it is the input sample_rate for every audio.
+        All audio will be resampled to the model's sample rate. 
+        Audio will be silence-padded to the longest length, and further padded to be a multiple of the model's downsampling ratio. 
+        If the model is mono, all audio will be converted to mono. 
+        The output will be a tensor of shape (Batch x Channels x Length)
+        '''
+        batch_size = len(audio_list)
+        if isinstance(in_sr_list, int):
+            in_sr_list = [in_sr_list]*batch_size
+        assert len(in_sr_list) == batch_size, "list of sample rates must be the same length as audio_list"
+        new_audio = []
+        max_length = 0
+        # resample & find the max length
+        for i in range(batch_size):
+            audio = audio_list[i]
+            in_sr = in_sr_list[i]
+            if len(audio.shape) == 3 and audio.shape[0] == 1:
+                # batchsize 1 was given by accident. Just squeeze it.
+                audio = audio.squeeze(0)
+            elif len(audio.shape) == 1:
+                # Mono signal, channel dimension is missing, unsqueeze it in
+                audio = audio.unsqueeze(0)
+            assert len(audio.shape)==2, "Audio should be shape (Channels x Length) with no batch dimension" 
+            # Resample audio
+            if in_sr != self.sample_rate:
+                resample_tf = T.Resample(in_sr, self.sample_rate).to(audio.device)
+                audio = resample_tf(audio)
+            new_audio.append(audio)
+            if audio.shape[-1] > max_length:
+                max_length = audio.shape[-1]
+        # Pad every audio to the same length, multiple of model's downsampling ratio
+        padded_audio_length = max_length + (self.min_length - (max_length % self.min_length)) % self.min_length
+        for i in range(batch_size):
+            # Pad it & if necessary, mixdown/duplicate stereo/mono channels to support model.
+            # The audio has already been resampled above, so in_sr == target_sr and no further resampling happens here.
+            new_audio[i] = prepare_audio(new_audio[i], in_sr=self.sample_rate, target_sr=self.sample_rate, target_length=padded_audio_length, 
+                target_channels=self.in_channels, device=new_audio[i].device).squeeze(0)
+        # convert to tensor 
+        return torch.stack(new_audio) 
+
+    def encode_audio(self, audio, chunked=False, overlap=32, chunk_size=128, **kwargs):
+        '''
+        Encode audio into latents. Audio should already be preprocessed by preprocess_audio_for_encoder.
+        If chunked is True, split the audio into chunks of a given maximum size chunk_size, with given overlap.
+        Overlap and chunk_size params are both measured in number of latents (not audio samples),
+        and therefore you can likely use the same values with decode_audio.
+        An overlap of zero will cause discontinuity artefacts. Overlap should be >= the receptive field size.
+        Every autoencoder will have a different receptive field size, and thus an ideal overlap.
+        You can determine it empirically by diffing unchunked vs chunked output and looking at the maximum diff.
+        The final chunk may have a longer overlap in order to keep chunk_size consistent for all chunks.
+        Smaller chunk_size uses less memory, but more compute.
+        The chunk_size vs memory tradeoff isn't linear, and possibly depends on the GPU and CUDA version.
+        For example, on an A6000, chunk_size 128 is overall faster than 256 and 512 even though it has more chunks.
+        '''
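+        # With the defaults above (chunk_size=128, overlap=32, both in latents) the hop between
+        # chunk starts is 96 latents, i.e. 96 * samples_per_latent audio samples.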
+        if not chunked:
+            # default behavior. Encode the entire audio in parallel
+            return self.encode(audio, **kwargs)
+        else:
+            # CHUNKED ENCODING
+            # samples_per_latent is just the downsampling ratio (which is also the upsampling ratio)
+            samples_per_latent = self.downsampling_ratio
+            total_size = audio.shape[2] # in samples
+            batch_size = audio.shape[0]
+            chunk_size *= samples_per_latent # converting metric in latents to samples
+            overlap *= samples_per_latent # converting metric in latents to samples
+            hop_size = chunk_size - overlap
+            chunks = []
+            for i in range(0, total_size - chunk_size + 1, hop_size):
+                chunk = audio[:,:,i:i+chunk_size]
+                chunks.append(chunk)
+            if i+chunk_size != total_size:
+                # Final chunk
+                chunk = audio[:,:,-chunk_size:]
+                chunks.append(chunk)
+            chunks = torch.stack(chunks)
+            num_chunks = chunks.shape[0]
+            # Note: y_size might be a different value from the latent length used in diffusion training
+            # because we can encode audio of varying lengths
+            # However, the audio should've been padded to a multiple of samples_per_latent by now.
+            y_size = total_size // samples_per_latent
+            # Create an empty latent, we will populate it with chunks as we encode them
+            y_final = torch.zeros((batch_size,self.latent_dim,y_size)).to(audio.device)
+            for i in range(num_chunks):
+                x_chunk = chunks[i,:]
+                # encode the chunk
+                y_chunk = self.encode(x_chunk)
+                # figure out where to put the audio along the time domain
+                if i == num_chunks-1:
+                    # final chunk always goes at the end
+                    t_end = y_size
+                    t_start = t_end - y_chunk.shape[2]
+                else:
+                    t_start = i * hop_size // samples_per_latent
+                    t_end = t_start + chunk_size // samples_per_latent
+                #  remove the edges of the overlaps
+                ol = overlap//samples_per_latent//2
+                chunk_start = 0
+                chunk_end = y_chunk.shape[2]
+                if i > 0:
+                    # no overlap for the start of the first chunk
+                    t_start += ol
+                    chunk_start += ol
+                if i < num_chunks-1:
+                    # no overlap for the end of the last chunk
+                    t_end -= ol
+                    chunk_end -= ol
+                # paste the chunked audio into our y_final output audio
+                y_final[:,:,t_start:t_end] = y_chunk[:,:,chunk_start:chunk_end]
+            return y_final
+    
+    def decode_audio(self, latents, chunked=False, overlap=32, chunk_size=128, **kwargs):
+        '''
+        Decode latents to audio. 
+        If chunked is True, split the latents into chunks of a given maximum size chunk_size, with given overlap, both of which are measured in number of latents. 
+        An overlap of zero will cause discontinuity artefacts. Overlap should be >= the receptive field size. 
+        Every autoencoder will have a different receptive field size, and thus an ideal overlap.
+        You can determine it empirically by diffing unchunked vs chunked audio and looking at maximum diff.
+        The final chunk may have a longer overlap in order to keep chunk_size consistent for all chunks.
+        Smaller chunk_size uses less memory, but more compute.
+        The chunk_size vs memory tradeoff isn't linear, and possibly depends on the GPU and CUDA version.
+        For example, on an A6000, chunk_size 128 is overall faster than 256 and 512 even though it has more chunks.
+        '''
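+        # The decoded waveform has total_size * samples_per_latent samples; at each interior chunk
+        # boundary, overlap // 2 latents' worth of audio is trimmed from both neighbouring chunks
+        # before pasting, which is what suppresses the edge artefacts mentioned above.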
+        if not chunked:
+            # default behavior. Decode the entire latent in parallel
+            return self.decode(latents, **kwargs)
+        else:
+            # chunked decoding
+            hop_size = chunk_size - overlap
+            total_size = latents.shape[2]
+            batch_size = latents.shape[0]
+            chunks = []
+            for i in range(0, total_size - chunk_size + 1, hop_size):
+                chunk = latents[:,:,i:i+chunk_size]
+                chunks.append(chunk)
+            if i+chunk_size != total_size:
+                # Final chunk
+                chunk = latents[:,:,-chunk_size:]
+                chunks.append(chunk)
+            chunks = torch.stack(chunks)
+            num_chunks = chunks.shape[0]
+            # samples_per_latent is just the downsampling ratio
+            samples_per_latent = self.downsampling_ratio
+            # Create an empty waveform; we will populate it with chunks as we decode them
+            y_size = total_size * samples_per_latent
+            y_final = torch.zeros((batch_size,self.out_channels,y_size)).to(latents.device)
+            for i in range(num_chunks):
+                x_chunk = chunks[i,:]
+                # decode the chunk
+                y_chunk = self.decode(x_chunk)
+                # figure out where to put the audio along the time domain
+                if i == num_chunks-1:
+                    # final chunk always goes at the end
+                    t_end = y_size
+                    t_start = t_end - y_chunk.shape[2]
+                else:
+                    t_start = i * hop_size * samples_per_latent
+                    t_end = t_start + chunk_size * samples_per_latent
+                #  remove the edges of the overlaps
+                ol = (overlap//2) * samples_per_latent
+                chunk_start = 0
+                chunk_end = y_chunk.shape[2]
+                if i > 0:
+                    # no overlap for the start of the first chunk
+                    t_start += ol
+                    chunk_start += ol
+                if i < num_chunks-1:
+                    # no overlap for the end of the last chunk
+                    t_end -= ol
+                    chunk_end -= ol
+                # paste the chunked audio into our y_final output audio
+                y_final[:,:,t_start:t_end] = y_chunk[:,:,chunk_start:chunk_end]
+            return y_final
+
+    
+# AE factories
+
+def create_encoder_from_config(encoder_config: Dict[str, Any]):
+    encoder_type = encoder_config.get("type", None)
+    assert encoder_type is not None, "Encoder type must be specified"
+
+    if encoder_type == "oobleck":
+        encoder = OobleckEncoder(
+            **encoder_config["config"]
+        )
+    
+    elif encoder_type == "seanet":
+        from encodec.modules import SEANetEncoder
+        seanet_encoder_config = encoder_config["config"]
+
+        #SEANet encoder expects strides in reverse order
+        seanet_encoder_config["ratios"] = list(reversed(seanet_encoder_config.get("ratios", [2, 2, 2, 2, 2])))
+        encoder = SEANetEncoder(
+            **seanet_encoder_config
+        )
+    elif encoder_type == "dac":
+        dac_config = encoder_config["config"]
+
+        encoder = DACEncoderWrapper(**dac_config)
+    elif encoder_type == "local_attn":
+        from .local_attention import TransformerEncoder1D
+
+        local_attn_config = encoder_config["config"]
+
+        encoder = TransformerEncoder1D(
+            **local_attn_config
+        )
+    else:
+        raise ValueError(f"Unknown encoder type {encoder_type}")
+    
+    requires_grad = encoder_config.get("requires_grad", True)
+    if not requires_grad:
+        for param in encoder.parameters():
+            param.requires_grad = False
+
+    return encoder
+
+def create_decoder_from_config(decoder_config: Dict[str, Any]):
+    decoder_type = decoder_config.get("type", None)
+    assert decoder_type is not None, "Decoder type must be specified"
+
+    if decoder_type == "oobleck":
+        decoder = OobleckDecoder(
+            **decoder_config["config"]
+        )
+    elif decoder_type == "seanet":
+        from encodec.modules import SEANetDecoder
+
+        decoder = SEANetDecoder(
+            **decoder_config["config"]
+        )
+    elif decoder_type == "dac":
+        dac_config = decoder_config["config"]
+
+        decoder = DACDecoderWrapper(**dac_config)
+    elif decoder_type == "local_attn":
+        from .local_attention import TransformerDecoder1D
+
+        local_attn_config = decoder_config["config"]
+
+        decoder = TransformerDecoder1D(
+            **local_attn_config
+        )
+    else:
+        raise ValueError(f"Unknown decoder type {decoder_type}")
+    
+    requires_grad = decoder_config.get("requires_grad", True)
+    if not requires_grad:
+        for param in decoder.parameters():
+            param.requires_grad = False
+
+    return decoder
+
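+# Expected config layout for create_autoencoder_from_config (sketch; keys match the accessors
+# below, numeric values are only illustrative):
+# {
+#   "sample_rate": 44100,
+#   "model": {
+#     "encoder": {"type": "oobleck", "config": {...}},
+#     "decoder": {"type": "oobleck", "config": {...}, "soft_clip": false},
+#     "bottleneck": {"type": "vae"},      # optional
+#     "pretransform": {...},              # optional
+#     "latent_dim": 64,
+#     "downsampling_ratio": 2048,
+#     "io_channels": 2
+#   }
+# }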
+def create_autoencoder_from_config(config: Dict[str, Any]):
+    
+    ae_config = config["model"]
+
+    encoder = create_encoder_from_config(ae_config["encoder"])
+    decoder = create_decoder_from_config(ae_config["decoder"])
+
+    bottleneck = ae_config.get("bottleneck", None)
+
+    latent_dim = ae_config.get("latent_dim", None)
+    assert latent_dim is not None, "latent_dim must be specified in model config"
+    downsampling_ratio = ae_config.get("downsampling_ratio", None)
+    assert downsampling_ratio is not None, "downsampling_ratio must be specified in model config"
+    io_channels = ae_config.get("io_channels", None)
+    assert io_channels is not None, "io_channels must be specified in model config"
+    sample_rate = config.get("sample_rate", None)
+    assert sample_rate is not None, "sample_rate must be specified in model config"
+
+    in_channels = ae_config.get("in_channels", None)
+    out_channels = ae_config.get("out_channels", None)
+
+    pretransform = ae_config.get("pretransform", None)
+
+    if pretransform is not None:
+        pretransform = create_pretransform_from_config(pretransform, sample_rate)
+
+    if bottleneck is not None:
+        bottleneck = create_bottleneck_from_config(bottleneck)
+
+    soft_clip = ae_config["decoder"].get("soft_clip", False)
+
+    return AudioAutoencoder(
+        encoder,
+        decoder,
+        io_channels=io_channels,
+        latent_dim=latent_dim,
+        downsampling_ratio=downsampling_ratio,
+        sample_rate=sample_rate,
+        bottleneck=bottleneck,
+        pretransform=pretransform,
+        in_channels=in_channels,
+        out_channels=out_channels,
+        soft_clip=soft_clip
+    )
\ No newline at end of file
diff --git a/src/modules/stable_vae/models/blocks.py b/src/modules/stable_vae/models/blocks.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb310c8980ef5dc0f138e6f9f3478d4cdc63354d
--- /dev/null
+++ b/src/modules/stable_vae/models/blocks.py
@@ -0,0 +1,359 @@
+from functools import reduce
+import math
+import numpy as np
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from torch.backends.cuda import sdp_kernel
+from packaging import version
+
+from .nn.layers import Snake1d
+
+
+class ResidualBlock(nn.Module):
+    def __init__(self, main, skip=None):
+        super().__init__()
+        self.main = nn.Sequential(*main)
+        self.skip = skip if skip else nn.Identity()
+
+    def forward(self, input):
+        return self.main(input) + self.skip(input)
+
+
+class ResConvBlock(ResidualBlock):
+    def __init__(self, c_in, c_mid, c_out, is_last=False, kernel_size=5, conv_bias=True, use_snake=False):
+        skip = None if c_in == c_out else nn.Conv1d(c_in, c_out, 1, bias=False)
+        super().__init__([
+            nn.Conv1d(c_in, c_mid, kernel_size, padding=kernel_size//2, bias=conv_bias),
+            nn.GroupNorm(1, c_mid),
+            Snake1d(c_mid) if use_snake else nn.GELU(),
+            nn.Conv1d(c_mid, c_out, kernel_size, padding=kernel_size//2, bias=conv_bias),
+            nn.GroupNorm(1, c_out) if not is_last else nn.Identity(),
+            (Snake1d(c_out) if use_snake else nn.GELU()) if not is_last else nn.Identity(),
+        ], skip)
+
+
+class SelfAttention1d(nn.Module):
+    def __init__(self, c_in, n_head=1, dropout_rate=0.):
+        super().__init__()
+        assert c_in % n_head == 0
+        self.norm = nn.GroupNorm(1, c_in)
+        self.n_head = n_head
+        self.qkv_proj = nn.Conv1d(c_in, c_in * 3, 1)
+        self.out_proj = nn.Conv1d(c_in, c_in, 1)
+        self.dropout = nn.Dropout(dropout_rate, inplace=True)
+
+        self.use_flash = torch.cuda.is_available() and version.parse(torch.__version__) >= version.parse('2.0.0')
+
+        if not self.use_flash:
+            return
+
+        device_properties = torch.cuda.get_device_properties(torch.device('cuda'))
+
+        if device_properties.major == 8 and device_properties.minor == 0:
+            # Use flash attention for A100 GPUs
+            self.sdp_kernel_config = (True, False, False)
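+            # The tuple is (enable_flash, enable_math, enable_mem_efficient), passed positionally
+            # to torch.backends.cuda.sdp_kernel in forward()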
+        else:
+            # Don't use flash attention for other GPUs
+            self.sdp_kernel_config = (False, True, True)
+
+    def forward(self, input):
+        n, c, s = input.shape
+        qkv = self.qkv_proj(self.norm(input))
+        qkv = qkv.view(
+            [n, self.n_head * 3, c // self.n_head, s]).transpose(2, 3)
+        q, k, v = qkv.chunk(3, dim=1)
+        scale = k.shape[3]**-0.25
+
+        if self.use_flash:
+            with sdp_kernel(*self.sdp_kernel_config):
+                y = F.scaled_dot_product_attention(q, k, v, is_causal=False).contiguous().view([n, c, s])
+        else:
+            att = ((q * scale) @ (k.transpose(2, 3) * scale)).softmax(3)
+            y = (att @ v).transpose(2, 3).contiguous().view([n, c, s])
+
+
+        return input + self.dropout(self.out_proj(y))
+
+
+class SkipBlock(nn.Module):
+    def __init__(self, *main):
+        super().__init__()
+        self.main = nn.Sequential(*main)
+
+    def forward(self, input):
+        return torch.cat([self.main(input), input], dim=1)
+
+
+class FourierFeatures(nn.Module):
+    def __init__(self, in_features, out_features, std=1.):
+        super().__init__()
+        assert out_features % 2 == 0
+        self.weight = nn.Parameter(torch.randn(
+            [out_features // 2, in_features]) * std)
+
+    def forward(self, input):
+        f = 2 * math.pi * input @ self.weight.T
+        return torch.cat([f.cos(), f.sin()], dim=-1)
+
+
+def expand_to_planes(input, shape):
+    return input[..., None].repeat([1, 1, shape[2]])
+
+_kernels = {
+    'linear':
+        [1 / 8, 3 / 8, 3 / 8, 1 / 8],
+    'cubic': 
+        [-0.01171875, -0.03515625, 0.11328125, 0.43359375,
+        0.43359375, 0.11328125, -0.03515625, -0.01171875],
+    'lanczos3': 
+        [0.003689131001010537, 0.015056144446134567, -0.03399861603975296,
+        -0.066637322306633, 0.13550527393817902, 0.44638532400131226,
+        0.44638532400131226, 0.13550527393817902, -0.066637322306633,
+        -0.03399861603975296, 0.015056144446134567, 0.003689131001010537]
+}
+
+
+class Downsample1d(nn.Module):
+    def __init__(self, kernel='linear', pad_mode='reflect', channels_last=False):
+        super().__init__()
+        self.pad_mode = pad_mode
+        kernel_1d = torch.tensor(_kernels[kernel])
+        self.pad = kernel_1d.shape[0] // 2 - 1
+        self.register_buffer('kernel', kernel_1d)
+        self.channels_last = channels_last
+    
+    def forward(self, x):
+        if self.channels_last:
+            x = x.permute(0, 2, 1)
+        x = F.pad(x, (self.pad,) * 2, self.pad_mode)
+        weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0]])
+        indices = torch.arange(x.shape[1], device=x.device)
+        weight[indices, indices] = self.kernel.to(weight)
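+        # Placing the 1-D kernel on the diagonal of a (C, C, K) weight makes this an independent
+        # per-channel low-pass filter, applied before the stride-2 decimation below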
+        x = F.conv1d(x, weight, stride=2)
+        if self.channels_last:
+            x = x.permute(0, 2, 1)
+        return x
+
+
+class Upsample1d(nn.Module):
+    def __init__(self, kernel='linear', pad_mode='reflect', channels_last=False):
+        super().__init__()
+        self.pad_mode = pad_mode
+        kernel_1d = torch.tensor(_kernels[kernel]) * 2
+        self.pad = kernel_1d.shape[0] // 2 - 1
+        self.register_buffer('kernel', kernel_1d)
+        self.channels_last = channels_last
+    
+    def forward(self, x):
+        if self.channels_last:
+            x = x.permute(0, 2, 1)
+        x = F.pad(x, ((self.pad + 1) // 2,) * 2, self.pad_mode)
+        weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0]])
+        indices = torch.arange(x.shape[1], device=x.device)
+        weight[indices, indices] = self.kernel.to(weight)
+        x = F.conv_transpose1d(x, weight, stride=2, padding=self.pad * 2 + 1)
+        if self.channels_last:
+            x = x.permute(0, 2, 1)
+        return x
+
+
+def Downsample1d_2(
+    in_channels: int, out_channels: int, factor: int, kernel_multiplier: int = 2
+) -> nn.Module:
+    assert kernel_multiplier % 2 == 0, "Kernel multiplier must be even"
+
+    return nn.Conv1d(
+        in_channels=in_channels,
+        out_channels=out_channels,
+        kernel_size=factor * kernel_multiplier + 1,
+        stride=factor,
+        padding=factor * (kernel_multiplier // 2),
+    )
+
+
+def Upsample1d_2(
+    in_channels: int, out_channels: int, factor: int, use_nearest: bool = False
+) -> nn.Module:
+
+    if factor == 1:
+        return nn.Conv1d(
+            in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1
+        )
+
+    if use_nearest:
+        return nn.Sequential(
+            nn.Upsample(scale_factor=factor, mode="nearest"),
+            nn.Conv1d(
+                in_channels=in_channels,
+                out_channels=out_channels,
+                kernel_size=3,
+                padding=1,
+            ),
+        )
+    else:
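+        # kernel_size, stride, padding and output_padding are chosen so the transposed conv
+        # upsamples the input length by exactly `factor`, for both even and odd factors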
+        return nn.ConvTranspose1d(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            kernel_size=factor * 2,
+            stride=factor,
+            padding=factor // 2 + factor % 2,
+            output_padding=factor % 2,
+        )
+
+
+def zero_init(layer):
+    nn.init.zeros_(layer.weight)
+    if layer.bias is not None:
+        nn.init.zeros_(layer.bias)
+    return layer
+
+
+def rms_norm(x, scale, eps):
+    dtype = reduce(torch.promote_types, (x.dtype, scale.dtype, torch.float32))
+    mean_sq = torch.mean(x.to(dtype)**2, dim=-1, keepdim=True)
+    scale = scale.to(dtype) * torch.rsqrt(mean_sq + eps)
+    return x * scale.to(x.dtype)
+
+#rms_norm = torch.compile(rms_norm)
+
+class AdaRMSNorm(nn.Module):
+    def __init__(self, features, cond_features, eps=1e-6):
+        super().__init__()
+        self.eps = eps
+        self.linear = zero_init(nn.Linear(cond_features, features, bias=False))
+  
+    def extra_repr(self):
+        return f"eps={self.eps},"
+
+    def forward(self, x, cond):
+        return rms_norm(x, self.linear(cond)[:, None, :] + 1, self.eps)
+
+
+def normalize(x, eps=1e-4):
+    dim = list(range(1, x.ndim))
+    n = torch.linalg.vector_norm(x, dim=dim, keepdim=True)
+    alpha = np.sqrt(n.numel() / x.numel())
+    return x / torch.add(eps, n, alpha=alpha)
+
+
+class ForcedWNConv1d(nn.Module):
+    def __init__(self, in_channels, out_channels, kernel_size=1):
+        super().__init__()
+        self.weight = nn.Parameter(torch.randn([out_channels, in_channels, kernel_size]))
+
+    def forward(self, x):
+        if self.training:
+            with torch.no_grad():
+                self.weight.copy_(normalize(self.weight))
+        
+        fan_in = self.weight[0].numel()
+
+        w = normalize(self.weight) / math.sqrt(fan_in)
+
+        return F.conv1d(x, w, padding='same')
+        
+# Kernels
+
+use_compile = True
+
+def compile(function, *args, **kwargs):
+    if not use_compile:
+        return function
+    try:
+        return torch.compile(function, *args, **kwargs)
+    except RuntimeError:
+        return function
+
+
+@compile
+def linear_geglu(x, weight, bias=None):
+    x = x @ weight.mT
+    if bias is not None:
+        x = x + bias
+    x, gate = x.chunk(2, dim=-1)
+    return x * F.gelu(gate)
+
+
+@compile
+def rms_norm(x, scale, eps):
+    dtype = reduce(torch.promote_types, (x.dtype, scale.dtype, torch.float32))
+    mean_sq = torch.mean(x.to(dtype)**2, dim=-1, keepdim=True)
+    scale = scale.to(dtype) * torch.rsqrt(mean_sq + eps)
+    return x * scale.to(x.dtype)
+
+# Layers
+
+
+class LinearGEGLU(nn.Linear):
+    def __init__(self, in_features, out_features, bias=True):
+        super().__init__(in_features, out_features * 2, bias=bias)
+        self.out_features = out_features
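+        # The parent Linear projects to 2 * out_features; linear_geglu then splits that projection
+        # in half and gates one half with GELU, so the effective output width is out_features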
+
+    def forward(self, x):
+        return linear_geglu(x, self.weight, self.bias)
+
+
+class RMSNorm(nn.Module):
+    def __init__(self, shape, fix_scale = False, eps=1e-6):
+        super().__init__()
+        self.eps = eps
+
+        if fix_scale:
+            self.register_buffer("scale", torch.ones(shape))
+        else:
+            self.scale = nn.Parameter(torch.ones(shape))
+
+    def extra_repr(self):
+        return f"shape={tuple(self.scale.shape)}, eps={self.eps}"
+
+    def forward(self, x):
+        return rms_norm(x, self.scale, self.eps)
+
+
+# jit script makes it 1.4x faster and saves GPU memory
+@torch.jit.script
+def snake_beta(x, alpha, beta):
+    return x + (1.0 / (beta + 0.000000001)) * pow(torch.sin(x * alpha), 2)
+
+# try:
+#     snake_beta = torch.compile(snake_beta)
+# except RuntimeError:
+#     pass
+
+
+# Adapted from https://github.com/NVIDIA/BigVGAN/blob/main/activations.py under MIT license
+# License available in LICENSES/LICENSE_NVIDIA.txt
+class SnakeBeta(nn.Module):
+
+    def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True):
+        super(SnakeBeta, self).__init__()
+        self.in_features = in_features
+
+        # initialize alpha
+        self.alpha_logscale = alpha_logscale
+        if self.alpha_logscale: 
+            # log scale alphas initialized to zeros
+            self.alpha = nn.Parameter(torch.zeros(in_features) * alpha)
+            self.beta = nn.Parameter(torch.zeros(in_features) * alpha)
+        else:
+            # linear scale alphas initialized to ones
+            self.alpha = nn.Parameter(torch.ones(in_features) * alpha)
+            self.beta = nn.Parameter(torch.ones(in_features) * alpha)
+
+        self.alpha.requires_grad = alpha_trainable
+        self.beta.requires_grad = alpha_trainable
+
+        # self.no_div_by_zero = 0.000000001
+
+    def forward(self, x):
+        alpha = self.alpha.unsqueeze(0).unsqueeze(-1) 
+        # line up with x to [B, C, T]
+        beta = self.beta.unsqueeze(0).unsqueeze(-1)
+        if self.alpha_logscale:
+            alpha = torch.exp(alpha)
+            beta = torch.exp(beta)
+        x = snake_beta(x, alpha, beta)
+
+        return x
\ No newline at end of file
diff --git a/src/modules/stable_vae/models/bottleneck.py b/src/modules/stable_vae/models/bottleneck.py
new file mode 100644
index 0000000000000000000000000000000000000000..df88c5f1b1f5fa3675c1a42f42e5e31e27d00ed3
--- /dev/null
+++ b/src/modules/stable_vae/models/bottleneck.py
@@ -0,0 +1,346 @@
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from einops import rearrange
+from vector_quantize_pytorch import ResidualVQ, FSQ
+from .nn.quantize import ResidualVectorQuantize as DACResidualVQ
+
+
+class Bottleneck(nn.Module):
+    def __init__(self, is_discrete: bool = False):
+        super().__init__()
+
+        self.is_discrete = is_discrete
+
+    def encode(self, x, return_info=False, **kwargs):
+        raise NotImplementedError
+
+    def decode(self, x):
+        raise NotImplementedError
+
+
+class DiscreteBottleneck(Bottleneck):
+    def __init__(self, num_quantizers, codebook_size, tokens_id):
+        super().__init__(is_discrete=True)
+
+        self.num_quantizers = num_quantizers
+        self.codebook_size = codebook_size
+        self.tokens_id = tokens_id
+
+    def decode_tokens(self, codes, **kwargs):
+        raise NotImplementedError
+
+
+class TanhBottleneck(Bottleneck):
+    def __init__(self):
+        super().__init__(is_discrete=False)
+        self.tanh = nn.Tanh()
+
+    def encode(self, x, return_info=False):
+        info = {}
+
+        x = torch.tanh(x)
+
+        if return_info:
+            return x, info
+        else:
+            return x
+
+    def decode(self, x):
+        return x
+
+
+@torch.jit.script
+def vae_sample_kl(mean, scale):
+    stdev = nn.functional.softplus(scale) + 1e-4
+    var = stdev * stdev
+    logvar = torch.log(var)
+    latents = torch.randn_like(mean) * stdev + mean
+
+    kl = (mean * mean + var - logvar - 1).sum(1).mean()
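+    # Gaussian KL to N(0, I) without the usual 1/2 factor, i.e. proportional to
+    # 0.5 * sum(mu^2 + sigma^2 - log(sigma^2) - 1); dropping the 1/2 only rescales the returned KL term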
+
+    return latents, kl
+
+
+@torch.jit.script
+def vae_sample(mean, scale):
+    stdev = nn.functional.softplus(scale) + 1e-4
+    latents = torch.randn_like(mean) * stdev + mean
+    return latents
+
+
+class VAEBottleneck(Bottleneck):
+    def __init__(self):
+        super().__init__(is_discrete=False)
+
+    def encode(self, x, return_info=False, **kwargs):
+        mean, scale = x.chunk(2, dim=1)
+
+        if return_info:
+            info = {}
+            x, kl = vae_sample_kl(mean, scale)
+            info["kl"] = kl
+            return x, info
+        else:
+            x = vae_sample(mean, scale)
+            return x
+
+    def decode(self, x):
+        return x
+
+
+def compute_mean_kernel(x, y):
+    kernel_input = (x[:, None] - y[None]).pow(2).mean(2) / x.shape[-1]
+    return torch.exp(-kernel_input).mean()
+
+
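+# Kernel-based MMD estimate between the flattened latents and samples from a standard normal
+# of the same shape, computed with the RBF-style kernel defined above.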
+def compute_mmd(latents):
+    latents_reshaped = latents.permute(0, 2, 1).reshape(-1, latents.shape[1])
+    noise = torch.randn_like(latents_reshaped)
+
+    latents_kernel = compute_mean_kernel(latents_reshaped, latents_reshaped)
+    noise_kernel = compute_mean_kernel(noise, noise)
+    latents_noise_kernel = compute_mean_kernel(latents_reshaped, noise)
+    
+    mmd = latents_kernel + noise_kernel - 2 * latents_noise_kernel
+    return mmd.mean()
+
+
+class WassersteinBottleneck(Bottleneck):
+    def __init__(self, noise_augment_dim: int = 0):
+        super().__init__(is_discrete=False)
+
+        self.noise_augment_dim = noise_augment_dim
+    
+    def encode(self, x, return_info=False):
+        info = {}
+
+        if self.training and return_info:
+            mmd = compute_mmd(x)
+            info["mmd"] = mmd
+        
+        if return_info:
+            return x, info
+        
+        return x
+
+    def decode(self, x):
+
+        if self.noise_augment_dim > 0:
+            noise = torch.randn(x.shape[0], self.noise_augment_dim,
+                                x.shape[-1]).type_as(x)
+            x = torch.cat([x, noise], dim=1)
+
+        return x
+
+
+class L2Bottleneck(Bottleneck):
+    def __init__(self):
+        super().__init__(is_discrete=False)
+    
+    def encode(self, x, return_info=False):
+        info = {}
+
+        x = F.normalize(x, dim=1)
+
+        if return_info:
+            return x, info
+        else:
+            return x
+        
+    def decode(self, x):
+        return F.normalize(x, dim=1)
+
+
+class RVQBottleneck(DiscreteBottleneck):
+    def __init__(self, **quantizer_kwargs):
+        super().__init__(num_quantizers = quantizer_kwargs["num_quantizers"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "quantizer_indices")
+        self.quantizer = ResidualVQ(**quantizer_kwargs)
+        self.num_quantizers = quantizer_kwargs["num_quantizers"]
+
+    def encode(self, x, return_info=False, **kwargs):
+        info = {}
+
+        x = rearrange(x, "b c n -> b n c")
+        x, indices, loss = self.quantizer(x)
+        x = rearrange(x, "b n c -> b c n")
+
+        info["quantizer_indices"] = indices
+        info["quantizer_loss"] = loss.mean()
+
+        if return_info:
+            return x, info
+        else:
+            return x
+        
+    def decode(self, x):
+        return x
+    
+    def decode_tokens(self, codes, **kwargs):
+        latents = self.quantizer.get_outputs_from_indices(codes)
+
+        return self.decode(latents, **kwargs)
+
+
+class RVQVAEBottleneck(DiscreteBottleneck):
+    def __init__(self, **quantizer_kwargs):
+        super().__init__(num_quantizers = quantizer_kwargs["num_quantizers"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "quantizer_indices")
+        self.quantizer = ResidualVQ(**quantizer_kwargs)
+        self.num_quantizers = quantizer_kwargs["num_quantizers"]
+
+    def encode(self, x, return_info=False):
+        info = {}
+
+        x, kl = vae_sample_kl(*x.chunk(2, dim=1))
+
+        info["kl"] = kl
+
+        x = rearrange(x, "b c n -> b n c")
+        x, indices, loss = self.quantizer(x)
+        x = rearrange(x, "b n c -> b c n")
+
+        info["quantizer_indices"] = indices
+        info["quantizer_loss"] = loss.mean()
+
+        if return_info:
+            return x, info
+        else:
+            return x
+        
+    def decode(self, x):
+        return x
+    
+    def decode_tokens(self, codes, **kwargs):
+        latents = self.quantizer.get_outputs_from_indices(codes)
+
+        return self.decode(latents, **kwargs)
+
+
+class DACRVQBottleneck(DiscreteBottleneck):
+    def __init__(self, quantize_on_decode=False, **quantizer_kwargs):
+        super().__init__(num_quantizers = quantizer_kwargs["n_codebooks"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "codes")
+        self.quantizer = DACResidualVQ(**quantizer_kwargs)
+        self.num_quantizers = quantizer_kwargs["n_codebooks"]
+        self.quantize_on_decode = quantize_on_decode
+
+    def encode(self, x, return_info=False, **kwargs):
+        info = {}
+
+        info["pre_quantizer"] = x
+
+        if self.quantize_on_decode:
+            return (x, info) if return_info else x
+
+        z, codes, latents, commitment_loss, codebook_loss = self.quantizer(x, **kwargs)
+
+        output = {
+            "z": z,
+            "codes": codes,
+            "latents": latents,
+            "vq/commitment_loss": commitment_loss,
+            "vq/codebook_loss": codebook_loss,
+        }
+
+        output["vq/commitment_loss"] /= self.num_quantizers
+        output["vq/codebook_loss"] /= self.num_quantizers
+
+        info.update(output)
+
+        if return_info:
+            return output["z"], info
+        
+        return output["z"]
+    
+    def decode(self, x):
+
+        if self.quantize_on_decode:
+            x = self.quantizer(x)[0]
+
+        return x
+    
+    def decode_tokens(self, codes, **kwargs):
+        latents, _, _ = self.quantizer.from_codes(codes)
+
+        return self.decode(latents, **kwargs)
+
+
+class DACRVQVAEBottleneck(DiscreteBottleneck):
+    def __init__(self, quantize_on_decode=False, **quantizer_kwargs):
+        super().__init__(num_quantizers = quantizer_kwargs["n_codebooks"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "codes")
+        self.quantizer = DACResidualVQ(**quantizer_kwargs)
+        self.num_quantizers = quantizer_kwargs["n_codebooks"]
+        self.quantize_on_decode = quantize_on_decode
+
+    def encode(self, x, return_info=False, n_quantizers: int = None):
+        info = {}
+
+        mean, scale = x.chunk(2, dim=1)
+
+        x, kl = vae_sample_kl(mean, scale)
+
+        info["pre_quantizer"] = x
+        info["kl"] = kl
+
+        if self.quantize_on_decode:
+            return (x, info) if return_info else x
+
+        z, codes, latents, commitment_loss, codebook_loss = self.quantizer(x, n_quantizers=n_quantizers)
+
+        output = {
+            "z": z,
+            "codes": codes,
+            "latents": latents,
+            "vq/commitment_loss": commitment_loss,
+            "vq/codebook_loss": codebook_loss,
+        }
+
+        output["vq/commitment_loss"] /= self.num_quantizers
+        output["vq/codebook_loss"] /= self.num_quantizers
+
+        info.update(output)
+
+        if return_info:
+            return output["z"], info
+        
+        return output["z"]
+    
+    def decode(self, x):
+
+        if self.quantize_on_decode:
+            x = self.quantizer(x)[0]
+
+        return x
+
+    def decode_tokens(self, codes, **kwargs):
+        latents, _, _ = self.quantizer.from_codes(codes)
+
+        return self.decode(latents, **kwargs)
+
+
+class FSQBottleneck(DiscreteBottleneck):
+    def __init__(self, dim, levels):
+        super().__init__(num_quantizers = 1, codebook_size = levels ** dim, tokens_id = "quantizer_indices")
+        self.quantizer = FSQ(levels=[levels] * dim)
+
+    def encode(self, x, return_info=False):
+        info = {}
+
+        x = rearrange(x, "b c n -> b n c")
+        x, indices = self.quantizer(x)
+        x = rearrange(x, "b n c -> b c n")
+
+        info["quantizer_indices"] = indices
+
+        if return_info:
+            return x, info
+        else:
+            return x
+        
+    def decode(self, x):
+        return x
+    
+    def decode_tokens(self, tokens, **kwargs):
+        latents = self.quantizer.indices_to_codes(tokens)
+
+        return self.decode(latents, **kwargs)
\ No newline at end of file
diff --git a/src/modules/stable_vae/models/factory.py b/src/modules/stable_vae/models/factory.py
new file mode 100644
index 0000000000000000000000000000000000000000..4188703000ee176342c7f329342f18d6fe747b04
--- /dev/null
+++ b/src/modules/stable_vae/models/factory.py
@@ -0,0 +1,153 @@
+import json
+
+def create_model_from_config(model_config):
+    model_type = model_config.get('model_type', None)
+
+    assert model_type is not None, 'model_type must be specified in model config'
+
+    if model_type == 'autoencoder':
+        from .autoencoders import create_autoencoder_from_config
+        return create_autoencoder_from_config(model_config)
+    elif model_type == 'diffusion_uncond':
+        from .diffusion import create_diffusion_uncond_from_config
+        return create_diffusion_uncond_from_config(model_config)
+    elif model_type == 'diffusion_cond' or model_type == 'diffusion_cond_inpaint' or model_type == "diffusion_prior":
+        from .diffusion import create_diffusion_cond_from_config
+        return create_diffusion_cond_from_config(model_config)
+    elif model_type == 'diffusion_autoencoder':
+        from .autoencoders import create_diffAE_from_config
+        return create_diffAE_from_config(model_config)
+    elif model_type == 'lm':
+        from .lm import create_audio_lm_from_config
+        return create_audio_lm_from_config(model_config)
+    else:
+        raise NotImplementedError(f'Unknown model type: {model_type}')
+
+def create_model_from_config_path(model_config_path):
+    with open(model_config_path) as f:
+        model_config = json.load(f)
+    
+    return create_model_from_config(model_config)
+
+def create_pretransform_from_config(pretransform_config, sample_rate):
+    pretransform_type = pretransform_config.get('type', None)
+
+    assert pretransform_type is not None, 'type must be specified in pretransform config'
+
+    if pretransform_type == 'autoencoder':
+        from .autoencoders import create_autoencoder_from_config
+        from .pretransforms import AutoencoderPretransform
+
+        # Create fake top-level config to pass sample rate to autoencoder constructor
+        # This is a bit of a hack but it keeps us from re-defining the sample rate in the config
+        autoencoder_config = {"sample_rate": sample_rate, "model": pretransform_config["config"]}
+        autoencoder = create_autoencoder_from_config(autoencoder_config)
+
+        scale = pretransform_config.get("scale", 1.0)
+        model_half = pretransform_config.get("model_half", False)
+        iterate_batch = pretransform_config.get("iterate_batch", False)
+        chunked = pretransform_config.get("chunked", False)
+
+        pretransform = AutoencoderPretransform(autoencoder, scale=scale, model_half=model_half, iterate_batch=iterate_batch, chunked=chunked)
+    elif pretransform_type == 'wavelet':
+        from .pretransforms import WaveletPretransform
+
+        wavelet_config = pretransform_config["config"]
+        channels = wavelet_config["channels"]
+        levels = wavelet_config["levels"]
+        wavelet = wavelet_config["wavelet"]
+
+        pretransform = WaveletPretransform(channels, levels, wavelet)
+    elif pretransform_type == 'pqmf':
+        from .pretransforms import PQMFPretransform
+        pqmf_config = pretransform_config["config"]
+        pretransform = PQMFPretransform(**pqmf_config)
+    elif pretransform_type == 'dac_pretrained':
+        from .pretransforms import PretrainedDACPretransform
+        pretrained_dac_config = pretransform_config["config"]
+        pretransform = PretrainedDACPretransform(**pretrained_dac_config)
+    elif pretransform_type == "audiocraft_pretrained":
+        from .pretransforms import AudiocraftCompressionPretransform
+
+        audiocraft_config = pretransform_config["config"]
+        pretransform = AudiocraftCompressionPretransform(**audiocraft_config)
+    else:
+        raise NotImplementedError(f'Unknown pretransform type: {pretransform_type}')
+    
+    enable_grad = pretransform_config.get('enable_grad', False)
+    pretransform.enable_grad = enable_grad
+
+    pretransform.eval().requires_grad_(pretransform.enable_grad)
+
+    return pretransform
+
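+# Example bottleneck configs (sketch; the "rvq" numbers mirror the defaults set below):
+#   {"type": "vae"}
+#   {"type": "rvq", "config": {"dim": 128, "codebook_size": 1024, "num_quantizers": 8}}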
+def create_bottleneck_from_config(bottleneck_config):
+    bottleneck_type = bottleneck_config.get('type', None)
+
+    assert bottleneck_type is not None, 'type must be specified in bottleneck config'
+
+    if bottleneck_type == 'tanh':
+        from .bottleneck import TanhBottleneck
+        bottleneck = TanhBottleneck()
+    elif bottleneck_type == 'vae':
+        from .bottleneck import VAEBottleneck
+        bottleneck = VAEBottleneck()
+    elif bottleneck_type == 'rvq':
+        from .bottleneck import RVQBottleneck
+
+        quantizer_params = {
+            "dim": 128,
+            "codebook_size": 1024,
+            "num_quantizers": 8,
+            "decay": 0.99,
+            "kmeans_init": True,
+            "kmeans_iters": 50,
+            "threshold_ema_dead_code": 2,
+        }
+
+        quantizer_params.update(bottleneck_config["config"])
+
+        bottleneck = RVQBottleneck(**quantizer_params)
+    elif bottleneck_type == "dac_rvq":
+        from .bottleneck import DACRVQBottleneck
+
+        bottleneck = DACRVQBottleneck(**bottleneck_config["config"])
+    
+    elif bottleneck_type == 'rvq_vae':
+        from .bottleneck import RVQVAEBottleneck
+
+        quantizer_params = {
+            "dim": 128,
+            "codebook_size": 1024,
+            "num_quantizers": 8,
+            "decay": 0.99,
+            "kmeans_init": True,
+            "kmeans_iters": 50,
+            "threshold_ema_dead_code": 2,
+        }
+
+        quantizer_params.update(bottleneck_config["config"])
+
+        bottleneck = RVQVAEBottleneck(**quantizer_params)
+        
+    elif bottleneck_type == 'dac_rvq_vae':
+        from .bottleneck import DACRVQVAEBottleneck
+        bottleneck = DACRVQVAEBottleneck(**bottleneck_config["config"])
+    elif bottleneck_type == 'l2_norm':
+        from .bottleneck import L2Bottleneck
+        bottleneck = L2Bottleneck()
+    elif bottleneck_type == "wasserstein":
+        from .bottleneck import WassersteinBottleneck
+        bottleneck = WassersteinBottleneck(**bottleneck_config.get("config", {}))
+    elif bottleneck_type == "fsq":
+        from .bottleneck import FSQBottleneck
+        bottleneck = FSQBottleneck(**bottleneck_config["config"])
+    else:
+        raise NotImplementedError(f'Unknown bottleneck type: {bottleneck_type}')
+    
+    requires_grad = bottleneck_config.get('requires_grad', True)
+    if not requires_grad:
+        for param in bottleneck.parameters():
+            param.requires_grad = False
+
+    return bottleneck
diff --git a/src/modules/stable_vae/models/nn/.ipynb_checkpoints/__init__-checkpoint.py b/src/modules/stable_vae/models/nn/.ipynb_checkpoints/__init__-checkpoint.py
new file mode 100644
index 0000000000000000000000000000000000000000..6718c8b1a3d36c31655b030f4c515a144cde4db7
--- /dev/null
+++ b/src/modules/stable_vae/models/nn/.ipynb_checkpoints/__init__-checkpoint.py
@@ -0,0 +1,3 @@
+from . import layers
+from . import loss
+from . import quantize
diff --git a/src/modules/stable_vae/models/nn/.ipynb_checkpoints/layers-checkpoint.py b/src/modules/stable_vae/models/nn/.ipynb_checkpoints/layers-checkpoint.py
new file mode 100644
index 0000000000000000000000000000000000000000..44fbc2929715e11d843b24195d7042a528969a94
--- /dev/null
+++ b/src/modules/stable_vae/models/nn/.ipynb_checkpoints/layers-checkpoint.py
@@ -0,0 +1,33 @@
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange
+from torch.nn.utils import weight_norm
+
+
+def WNConv1d(*args, **kwargs):
+    return weight_norm(nn.Conv1d(*args, **kwargs))
+
+
+def WNConvTranspose1d(*args, **kwargs):
+    return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
+
+
+# Scripting this brings model speed up 1.4x
+@torch.jit.script
+def snake(x, alpha):
+    shape = x.shape
+    x = x.reshape(shape[0], shape[1], -1)
+    x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
+    x = x.reshape(shape)
+    return x
+
+
+class Snake1d(nn.Module):
+    def __init__(self, channels):
+        super().__init__()
+        self.alpha = nn.Parameter(torch.ones(1, channels, 1))
+
+    def forward(self, x):
+        return snake(x, self.alpha)
diff --git a/src/modules/stable_vae/models/nn/.ipynb_checkpoints/loss-checkpoint.py b/src/modules/stable_vae/models/nn/.ipynb_checkpoints/loss-checkpoint.py
new file mode 100644
index 0000000000000000000000000000000000000000..9bb3dd6ce08a7a24f18f941eeb5b68fe9461e86b
--- /dev/null
+++ b/src/modules/stable_vae/models/nn/.ipynb_checkpoints/loss-checkpoint.py
@@ -0,0 +1,368 @@
+import typing
+from typing import List
+
+import torch
+import torch.nn.functional as F
+from audiotools import AudioSignal
+from audiotools import STFTParams
+from torch import nn
+
+
+class L1Loss(nn.L1Loss):
+    """L1 Loss between AudioSignals. Defaults
+    to comparing ``audio_data``, but any
+    attribute of an AudioSignal can be used.
+
+    Parameters
+    ----------
+    attribute : str, optional
+        Attribute of signal to compare, defaults to ``audio_data``.
+    weight : float, optional
+        Weight of this loss, defaults to 1.0.
+
+    Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/distance.py
+    """
+
+    def __init__(self, attribute: str = "audio_data", weight: float = 1.0, **kwargs):
+        self.attribute = attribute
+        self.weight = weight
+        super().__init__(**kwargs)
+
+    def forward(self, x: AudioSignal, y: AudioSignal):
+        """
+        Parameters
+        ----------
+        x : AudioSignal
+            Estimate AudioSignal
+        y : AudioSignal
+            Reference AudioSignal
+
+        Returns
+        -------
+        torch.Tensor
+            L1 loss between AudioSignal attributes.
+        """
+        if isinstance(x, AudioSignal):
+            x = getattr(x, self.attribute)
+            y = getattr(y, self.attribute)
+        return super().forward(x, y)
+
+
+class SISDRLoss(nn.Module):
+    """
+    Computes the Scale-Invariant Source-to-Distortion Ratio between a batch
+    of estimated and reference audio signals or aligned features.
+
+    Parameters
+    ----------
+    scaling : int, optional
+        Whether to use scale-invariant (True) or
+        signal-to-noise ratio (False), by default True
+    reduction : str, optional
+        How to reduce across the batch (either 'mean',
+        'sum', or 'none'), by default 'mean'
+    zero_mean : int, optional
+        Zero mean the references and estimates before
+        computing the loss, by default True
+    clip_min : int, optional
+        The minimum possible loss value. Helps network
+        to not focus on making already good examples better, by default None
+    weight : float, optional
+        Weight of this loss, defaults to 1.0.
+
+    Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/distance.py
+    """
+
+    def __init__(
+        self,
+        scaling: int = True,
+        reduction: str = "mean",
+        zero_mean: int = True,
+        clip_min: int = None,
+        weight: float = 1.0,
+    ):
+        self.scaling = scaling
+        self.reduction = reduction
+        self.zero_mean = zero_mean
+        self.clip_min = clip_min
+        self.weight = weight
+        super().__init__()
+
+    def forward(self, x: AudioSignal, y: AudioSignal):
+        eps = 1e-8
+        # nb, nc, nt
+        if isinstance(x, AudioSignal):
+            references = x.audio_data
+            estimates = y.audio_data
+        else:
+            references = x
+            estimates = y
+
+        nb = references.shape[0]
+        references = references.reshape(nb, 1, -1).permute(0, 2, 1)
+        estimates = estimates.reshape(nb, 1, -1).permute(0, 2, 1)
+
+        # samples now on axis 1
+        if self.zero_mean:
+            mean_reference = references.mean(dim=1, keepdim=True)
+            mean_estimate = estimates.mean(dim=1, keepdim=True)
+        else:
+            mean_reference = 0
+            mean_estimate = 0
+
+        _references = references - mean_reference
+        _estimates = estimates - mean_estimate
+
+        references_projection = (_references**2).sum(dim=-2) + eps
+        references_on_estimates = (_estimates * _references).sum(dim=-2) + eps
+
+        scale = (
+            (references_on_estimates / references_projection).unsqueeze(1)
+            if self.scaling
+            else 1
+        )
+
+        e_true = scale * _references
+        e_res = _estimates - e_true
+
+        signal = (e_true**2).sum(dim=1)
+        noise = (e_res**2).sum(dim=1)
+        sdr = -10 * torch.log10(signal / noise + eps)
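+        # Negated so that minimizing this value maximizes SI-SDR (or SNR when scaling is False)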
+
+        if self.clip_min is not None:
+            sdr = torch.clamp(sdr, min=self.clip_min)
+
+        if self.reduction == "mean":
+            sdr = sdr.mean()
+        elif self.reduction == "sum":
+            sdr = sdr.sum()
+        return sdr
+
+
+class MultiScaleSTFTLoss(nn.Module):
+    """Computes the multi-scale STFT loss from [1].
+
+    Parameters
+    ----------
+    window_lengths : List[int], optional
+        Length of each window of each STFT, by default [2048, 512]
+    loss_fn : typing.Callable, optional
+        How to compare each loss, by default nn.L1Loss()
+    clamp_eps : float, optional
+        Clamp on the log magnitude, below, by default 1e-5
+    mag_weight : float, optional
+        Weight of raw magnitude portion of loss, by default 1.0
+    log_weight : float, optional
+        Weight of log magnitude portion of loss, by default 1.0
+    pow : float, optional
+        Power to raise magnitude to before taking log, by default 2.0
+    weight : float, optional
+        Weight of this loss, by default 1.0
+    match_stride : bool, optional
+        Whether to match the stride of convolutional layers, by default False
+
+    References
+    ----------
+
+    1.  Engel, Jesse, Chenjie Gu, and Adam Roberts.
+        "DDSP: Differentiable Digital Signal Processing."
+        International Conference on Learning Representations. 2020.
+
+    Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py
+    """
+
+    def __init__(
+        self,
+        window_lengths: List[int] = [2048, 512],
+        loss_fn: typing.Callable = nn.L1Loss(),
+        clamp_eps: float = 1e-5,
+        mag_weight: float = 1.0,
+        log_weight: float = 1.0,
+        pow: float = 2.0,
+        weight: float = 1.0,
+        match_stride: bool = False,
+        window_type: str = None,
+    ):
+        super().__init__()
+        self.stft_params = [
+            STFTParams(
+                window_length=w,
+                hop_length=w // 4,
+                match_stride=match_stride,
+                window_type=window_type,
+            )
+            for w in window_lengths
+        ]
+        self.loss_fn = loss_fn
+        self.log_weight = log_weight
+        self.mag_weight = mag_weight
+        self.clamp_eps = clamp_eps
+        self.weight = weight
+        self.pow = pow
+
+    def forward(self, x: AudioSignal, y: AudioSignal):
+        """Computes multi-scale STFT between an estimate and a reference
+        signal.
+
+        Parameters
+        ----------
+        x : AudioSignal
+            Estimate signal
+        y : AudioSignal
+            Reference signal
+
+        Returns
+        -------
+        torch.Tensor
+            Multi-scale STFT loss.
+        """
+        loss = 0.0
+        for s in self.stft_params:
+            x.stft(s.window_length, s.hop_length, s.window_type)
+            y.stft(s.window_length, s.hop_length, s.window_type)
+            loss += self.log_weight * self.loss_fn(
+                x.magnitude.clamp(self.clamp_eps).pow(self.pow).log10(),
+                y.magnitude.clamp(self.clamp_eps).pow(self.pow).log10(),
+            )
+            loss += self.mag_weight * self.loss_fn(x.magnitude, y.magnitude)
+        return loss
+
+
+class MelSpectrogramLoss(nn.Module):
+    """Compute distance between mel spectrograms. Can be used
+    in a multi-scale way.
+
+    Parameters
+    ----------
+    n_mels : List[int]
+        Number of mels per STFT, by default [150, 80],
+    window_lengths : List[int], optional
+        Length of each window of each STFT, by default [2048, 512]
+    loss_fn : typing.Callable, optional
+        How to compare each loss, by default nn.L1Loss()
+    clamp_eps : float, optional
+        Clamp on the log magnitude, below, by default 1e-5
+    mag_weight : float, optional
+        Weight of raw magnitude portion of loss, by default 1.0
+    log_weight : float, optional
+        Weight of log magnitude portion of loss, by default 1.0
+    pow : float, optional
+        Power to raise magnitude to before taking log, by default 2.0
+    weight : float, optional
+        Weight of this loss, by default 1.0
+    match_stride : bool, optional
+        Whether to match the stride of convolutional layers, by default False
+
+    Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py
+    """
+
+    def __init__(
+        self,
+        n_mels: List[int] = [150, 80],
+        window_lengths: List[int] = [2048, 512],
+        loss_fn: typing.Callable = nn.L1Loss(),
+        clamp_eps: float = 1e-5,
+        mag_weight: float = 1.0,
+        log_weight: float = 1.0,
+        pow: float = 2.0,
+        weight: float = 1.0,
+        match_stride: bool = False,
+        mel_fmin: List[float] = [0.0, 0.0],
+        mel_fmax: List[float] = [None, None],
+        window_type: str = None,
+    ):
+        super().__init__()
+        self.stft_params = [
+            STFTParams(
+                window_length=w,
+                hop_length=w // 4,
+                match_stride=match_stride,
+                window_type=window_type,
+            )
+            for w in window_lengths
+        ]
+        self.n_mels = n_mels
+        self.loss_fn = loss_fn
+        self.clamp_eps = clamp_eps
+        self.log_weight = log_weight
+        self.mag_weight = mag_weight
+        self.weight = weight
+        self.mel_fmin = mel_fmin
+        self.mel_fmax = mel_fmax
+        self.pow = pow
+
+    def forward(self, x: AudioSignal, y: AudioSignal):
+        """Computes mel loss between an estimate and a reference
+        signal.
+
+        Parameters
+        ----------
+        x : AudioSignal
+            Estimate signal
+        y : AudioSignal
+            Reference signal
+
+        Returns
+        -------
+        torch.Tensor
+            Mel loss.
+        """
+        loss = 0.0
+        for n_mels, fmin, fmax, s in zip(
+            self.n_mels, self.mel_fmin, self.mel_fmax, self.stft_params
+        ):
+            kwargs = {
+                "window_length": s.window_length,
+                "hop_length": s.hop_length,
+                "window_type": s.window_type,
+            }
+            x_mels = x.mel_spectrogram(n_mels, mel_fmin=fmin, mel_fmax=fmax, **kwargs)
+            y_mels = y.mel_spectrogram(n_mels, mel_fmin=fmin, mel_fmax=fmax, **kwargs)
+
+            loss += self.log_weight * self.loss_fn(
+                x_mels.clamp(self.clamp_eps).pow(self.pow).log10(),
+                y_mels.clamp(self.clamp_eps).pow(self.pow).log10(),
+            )
+            loss += self.mag_weight * self.loss_fn(x_mels, y_mels)
+        return loss
+
+
+class GANLoss(nn.Module):
+    """
+    Computes a discriminator loss, given a discriminator on
+    generated waveforms/spectrograms compared to ground truth
+    waveforms/spectrograms. Computes the loss for both the
+    discriminator and the generator in separate functions.
+    """
+
+    def __init__(self, discriminator):
+        super().__init__()
+        self.discriminator = discriminator
+
+    def forward(self, fake, real):
+        d_fake = self.discriminator(fake.audio_data)
+        d_real = self.discriminator(real.audio_data)
+        return d_fake, d_real
+
+    def discriminator_loss(self, fake, real):
+        d_fake, d_real = self.forward(fake.clone().detach(), real)
+
+        loss_d = 0
+        for x_fake, x_real in zip(d_fake, d_real):
+            loss_d += torch.mean(x_fake[-1] ** 2)
+            loss_d += torch.mean((1 - x_real[-1]) ** 2)
+        return loss_d
+
+    def generator_loss(self, fake, real):
+        d_fake, d_real = self.forward(fake, real)
+
+        loss_g = 0
+        for x_fake in d_fake:
+            loss_g += torch.mean((1 - x_fake[-1]) ** 2)
+
+        loss_feature = 0
+
+        for i in range(len(d_fake)):
+            for j in range(len(d_fake[i]) - 1):
+                loss_feature += F.l1_loss(d_fake[i][j], d_real[i][j].detach())
+        return loss_g, loss_feature
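
As a minimal usage sketch of the GANLoss defined above (the toy one-layer discriminator and all shapes are illustrative assumptions; the only real constraint is that the discriminator return a list of per-scale outputs whose last entry is treated as the logits):

import torch
import torch.nn as nn
from audiotools import AudioSignal

class ToyDiscriminator(nn.Module):
    # Hypothetical stand-in: returns one "scale" whose entries are intermediate
    # feature maps followed by the final logits, matching what GANLoss expects.
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv1d(1, 16, kernel_size=15, stride=4, padding=7)
        self.conv2 = nn.Conv1d(16, 1, kernel_size=15, stride=4, padding=7)

    def forward(self, x):
        feat = torch.relu(self.conv1(x))
        logits = self.conv2(feat)
        return [[feat, logits]]

gan = GANLoss(ToyDiscriminator())
fake = AudioSignal(torch.randn(2, 1, 44100), sample_rate=44100)
real = AudioSignal(torch.randn(2, 1, 44100), sample_rate=44100)
loss_d = gan.discriminator_loss(fake, real)           # drives the discriminator update
loss_g, loss_feat = gan.generator_loss(fake, real)    # adversarial + feature-matching terms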
diff --git a/src/modules/stable_vae/models/nn/__init__.py b/src/modules/stable_vae/models/nn/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..6718c8b1a3d36c31655b030f4c515a144cde4db7
--- /dev/null
+++ b/src/modules/stable_vae/models/nn/__init__.py
@@ -0,0 +1,3 @@
+from . import layers
+from . import loss
+from . import quantize
diff --git a/src/modules/stable_vae/models/nn/layers.py b/src/modules/stable_vae/models/nn/layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..44fbc2929715e11d843b24195d7042a528969a94
--- /dev/null
+++ b/src/modules/stable_vae/models/nn/layers.py
@@ -0,0 +1,33 @@
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange
+from torch.nn.utils import weight_norm
+
+
+def WNConv1d(*args, **kwargs):
+    return weight_norm(nn.Conv1d(*args, **kwargs))
+
+
+def WNConvTranspose1d(*args, **kwargs):
+    return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
+
+
+# Scripting this brings model speed up 1.4x
+@torch.jit.script
+def snake(x, alpha):
+    shape = x.shape
+    x = x.reshape(shape[0], shape[1], -1)
+    x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
+    x = x.reshape(shape)
+    return x
+
+
+class Snake1d(nn.Module):
+    def __init__(self, channels):
+        super().__init__()
+        self.alpha = nn.Parameter(torch.ones(1, channels, 1))
+
+    def forward(self, x):
+        return snake(x, self.alpha)
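
A quick sketch of the layers above in use (shapes chosen arbitrarily): WNConv1d is a weight-normalized Conv1d, and Snake1d applies the periodic snake activation x + sin²(αx)/α with one learnable α per channel.

import torch

conv = WNConv1d(1, 16, kernel_size=7, padding=3)   # weight-normalized 1-D convolution
act = Snake1d(16)                                  # one learnable alpha per channel

x = torch.randn(4, 1, 8000)                        # (batch, channels, time)
h = act(conv(x))
print(h.shape)                                     # torch.Size([4, 16, 8000])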
diff --git a/src/modules/stable_vae/models/nn/loss.py b/src/modules/stable_vae/models/nn/loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..9bb3dd6ce08a7a24f18f941eeb5b68fe9461e86b
--- /dev/null
+++ b/src/modules/stable_vae/models/nn/loss.py
@@ -0,0 +1,368 @@
+import typing
+from typing import List
+
+import torch
+import torch.nn.functional as F
+from audiotools import AudioSignal
+from audiotools import STFTParams
+from torch import nn
+
+
+class L1Loss(nn.L1Loss):
+    """L1 Loss between AudioSignals. Defaults
+    to comparing ``audio_data``, but any
+    attribute of an AudioSignal can be used.
+
+    Parameters
+    ----------
+    attribute : str, optional
+        Attribute of signal to compare, defaults to ``audio_data``.
+    weight : float, optional
+        Weight of this loss, defaults to 1.0.
+
+    Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/distance.py
+    """
+
+    def __init__(self, attribute: str = "audio_data", weight: float = 1.0, **kwargs):
+        self.attribute = attribute
+        self.weight = weight
+        super().__init__(**kwargs)
+
+    def forward(self, x: AudioSignal, y: AudioSignal):
+        """
+        Parameters
+        ----------
+        x : AudioSignal
+            Estimate AudioSignal
+        y : AudioSignal
+            Reference AudioSignal
+
+        Returns
+        -------
+        torch.Tensor
+            L1 loss between AudioSignal attributes.
+        """
+        if isinstance(x, AudioSignal):
+            x = getattr(x, self.attribute)
+            y = getattr(y, self.attribute)
+        return super().forward(x, y)
+
+
+class SISDRLoss(nn.Module):
+    """
+    Computes the Scale-Invariant Source-to-Distortion Ratio between a batch
+    of estimated and reference audio signals or aligned features.
+
+    Parameters
+    ----------
+    scaling : int, optional
+        Whether to use scale-invariant (True) or
+        signal-to-noise ratio (False), by default True
+    reduction : str, optional
+        How to reduce across the batch (either 'mean',
+        'sum', or 'none'), by default 'mean'
+    zero_mean : int, optional
+        Zero mean the references and estimates before
+        computing the loss, by default True
+    clip_min : int, optional
+        The minimum possible loss value. Helps network
+        to not focus on making already good examples better, by default None
+    weight : float, optional
+        Weight of this loss, defaults to 1.0.
+
+    Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/distance.py
+    """
+
+    def __init__(
+        self,
+        scaling: int = True,
+        reduction: str = "mean",
+        zero_mean: int = True,
+        clip_min: int = None,
+        weight: float = 1.0,
+    ):
+        self.scaling = scaling
+        self.reduction = reduction
+        self.zero_mean = zero_mean
+        self.clip_min = clip_min
+        self.weight = weight
+        super().__init__()
+
+    def forward(self, x: AudioSignal, y: AudioSignal):
+        eps = 1e-8
+        # nb, nc, nt
+        if isinstance(x, AudioSignal):
+            references = x.audio_data
+            estimates = y.audio_data
+        else:
+            references = x
+            estimates = y
+
+        nb = references.shape[0]
+        references = references.reshape(nb, 1, -1).permute(0, 2, 1)
+        estimates = estimates.reshape(nb, 1, -1).permute(0, 2, 1)
+
+        # samples now on axis 1
+        if self.zero_mean:
+            mean_reference = references.mean(dim=1, keepdim=True)
+            mean_estimate = estimates.mean(dim=1, keepdim=True)
+        else:
+            mean_reference = 0
+            mean_estimate = 0
+
+        _references = references - mean_reference
+        _estimates = estimates - mean_estimate
+
+        references_projection = (_references**2).sum(dim=-2) + eps
+        references_on_estimates = (_estimates * _references).sum(dim=-2) + eps
+
+        scale = (
+            (references_on_estimates / references_projection).unsqueeze(1)
+            if self.scaling
+            else 1
+        )
+
+        e_true = scale * _references
+        e_res = _estimates - e_true
+
+        signal = (e_true**2).sum(dim=1)
+        noise = (e_res**2).sum(dim=1)
+        sdr = -10 * torch.log10(signal / noise + eps)
+
+        if self.clip_min is not None:
+            sdr = torch.clamp(sdr, min=self.clip_min)
+
+        if self.reduction == "mean":
+            sdr = sdr.mean()
+        elif self.reduction == "sum":
+            sdr = sdr.sum()
+        return sdr
+
+
+class MultiScaleSTFTLoss(nn.Module):
+    """Computes the multi-scale STFT loss from [1].
+
+    Parameters
+    ----------
+    window_lengths : List[int], optional
+        Length of each window of each STFT, by default [2048, 512]
+    loss_fn : typing.Callable, optional
+        How to compare each loss, by default nn.L1Loss()
+    clamp_eps : float, optional
+        Clamp on the log magnitude, below, by default 1e-5
+    mag_weight : float, optional
+        Weight of raw magnitude portion of loss, by default 1.0
+    log_weight : float, optional
+        Weight of log magnitude portion of loss, by default 1.0
+    pow : float, optional
+        Power to raise magnitude to before taking log, by default 2.0
+    weight : float, optional
+        Weight of this loss, by default 1.0
+    match_stride : bool, optional
+        Whether to match the stride of convolutional layers, by default False
+
+    References
+    ----------
+
+    1.  Engel, Jesse, Chenjie Gu, and Adam Roberts.
+        "DDSP: Differentiable Digital Signal Processing."
+        International Conference on Learning Representations. 2020.
+
+    Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py
+    """
+
+    def __init__(
+        self,
+        window_lengths: List[int] = [2048, 512],
+        loss_fn: typing.Callable = nn.L1Loss(),
+        clamp_eps: float = 1e-5,
+        mag_weight: float = 1.0,
+        log_weight: float = 1.0,
+        pow: float = 2.0,
+        weight: float = 1.0,
+        match_stride: bool = False,
+        window_type: str = None,
+    ):
+        super().__init__()
+        self.stft_params = [
+            STFTParams(
+                window_length=w,
+                hop_length=w // 4,
+                match_stride=match_stride,
+                window_type=window_type,
+            )
+            for w in window_lengths
+        ]
+        self.loss_fn = loss_fn
+        self.log_weight = log_weight
+        self.mag_weight = mag_weight
+        self.clamp_eps = clamp_eps
+        self.weight = weight
+        self.pow = pow
+
+    def forward(self, x: AudioSignal, y: AudioSignal):
+        """Computes multi-scale STFT between an estimate and a reference
+        signal.
+
+        Parameters
+        ----------
+        x : AudioSignal
+            Estimate signal
+        y : AudioSignal
+            Reference signal
+
+        Returns
+        -------
+        torch.Tensor
+            Multi-scale STFT loss.
+        """
+        loss = 0.0
+        for s in self.stft_params:
+            x.stft(s.window_length, s.hop_length, s.window_type)
+            y.stft(s.window_length, s.hop_length, s.window_type)
+            loss += self.log_weight * self.loss_fn(
+                x.magnitude.clamp(self.clamp_eps).pow(self.pow).log10(),
+                y.magnitude.clamp(self.clamp_eps).pow(self.pow).log10(),
+            )
+            loss += self.mag_weight * self.loss_fn(x.magnitude, y.magnitude)
+        return loss
+
+
+class MelSpectrogramLoss(nn.Module):
+    """Compute distance between mel spectrograms. Can be used
+    in a multi-scale way.
+
+    Parameters
+    ----------
+    n_mels : List[int]
+        Number of mels per STFT, by default [150, 80],
+    window_lengths : List[int], optional
+        Length of each window of each STFT, by default [2048, 512]
+    loss_fn : typing.Callable, optional
+        How to compare each loss, by default nn.L1Loss()
+    clamp_eps : float, optional
+        Clamp on the log magnitude, below, by default 1e-5
+    mag_weight : float, optional
+        Weight of raw magnitude portion of loss, by default 1.0
+    log_weight : float, optional
+        Weight of log magnitude portion of loss, by default 1.0
+    pow : float, optional
+        Power to raise magnitude to before taking log, by default 2.0
+    weight : float, optional
+        Weight of this loss, by default 1.0
+    match_stride : bool, optional
+        Whether to match the stride of convolutional layers, by default False
+
+    Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py
+    """
+
+    def __init__(
+        self,
+        n_mels: List[int] = [150, 80],
+        window_lengths: List[int] = [2048, 512],
+        loss_fn: typing.Callable = nn.L1Loss(),
+        clamp_eps: float = 1e-5,
+        mag_weight: float = 1.0,
+        log_weight: float = 1.0,
+        pow: float = 2.0,
+        weight: float = 1.0,
+        match_stride: bool = False,
+        mel_fmin: List[float] = [0.0, 0.0],
+        mel_fmax: List[float] = [None, None],
+        window_type: str = None,
+    ):
+        super().__init__()
+        self.stft_params = [
+            STFTParams(
+                window_length=w,
+                hop_length=w // 4,
+                match_stride=match_stride,
+                window_type=window_type,
+            )
+            for w in window_lengths
+        ]
+        self.n_mels = n_mels
+        self.loss_fn = loss_fn
+        self.clamp_eps = clamp_eps
+        self.log_weight = log_weight
+        self.mag_weight = mag_weight
+        self.weight = weight
+        self.mel_fmin = mel_fmin
+        self.mel_fmax = mel_fmax
+        self.pow = pow
+
+    def forward(self, x: AudioSignal, y: AudioSignal):
+        """Computes mel loss between an estimate and a reference
+        signal.
+
+        Parameters
+        ----------
+        x : AudioSignal
+            Estimate signal
+        y : AudioSignal
+            Reference signal
+
+        Returns
+        -------
+        torch.Tensor
+            Mel loss.
+        """
+        loss = 0.0
+        for n_mels, fmin, fmax, s in zip(
+            self.n_mels, self.mel_fmin, self.mel_fmax, self.stft_params
+        ):
+            kwargs = {
+                "window_length": s.window_length,
+                "hop_length": s.hop_length,
+                "window_type": s.window_type,
+            }
+            x_mels = x.mel_spectrogram(n_mels, mel_fmin=fmin, mel_fmax=fmax, **kwargs)
+            y_mels = y.mel_spectrogram(n_mels, mel_fmin=fmin, mel_fmax=fmax, **kwargs)
+
+            loss += self.log_weight * self.loss_fn(
+                x_mels.clamp(self.clamp_eps).pow(self.pow).log10(),
+                y_mels.clamp(self.clamp_eps).pow(self.pow).log10(),
+            )
+            loss += self.mag_weight * self.loss_fn(x_mels, y_mels)
+        return loss
+
+
+class GANLoss(nn.Module):
+    """
+    Computes a discriminator loss, given a discriminator on
+    generated waveforms/spectrograms compared to ground truth
+    waveforms/spectrograms. Computes the loss for both the
+    discriminator and the generator in separate functions.
+    """
+
+    def __init__(self, discriminator):
+        super().__init__()
+        self.discriminator = discriminator
+
+    def forward(self, fake, real):
+        d_fake = self.discriminator(fake.audio_data)
+        d_real = self.discriminator(real.audio_data)
+        return d_fake, d_real
+
+    def discriminator_loss(self, fake, real):
+        d_fake, d_real = self.forward(fake.clone().detach(), real)
+
+        loss_d = 0
+        for x_fake, x_real in zip(d_fake, d_real):
+            loss_d += torch.mean(x_fake[-1] ** 2)
+            loss_d += torch.mean((1 - x_real[-1]) ** 2)
+        return loss_d
+
+    def generator_loss(self, fake, real):
+        d_fake, d_real = self.forward(fake, real)
+
+        loss_g = 0
+        for x_fake in d_fake:
+            loss_g += torch.mean((1 - x_fake[-1]) ** 2)
+
+        loss_feature = 0
+
+        for i in range(len(d_fake)):
+            for j in range(len(d_fake[i]) - 1):
+                loss_feature += F.l1_loss(d_fake[i][j], d_real[i][j].detach())
+        return loss_g, loss_feature
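
As a usage sketch (assuming these classes are imported from this loss module), the spectral and SI-SDR losses all share the same (estimate, reference) call signature, so they can be summed into one reconstruction objective:

import torch
from audiotools import AudioSignal

estimate  = AudioSignal(torch.randn(1, 1, 44100), sample_rate=44100)
reference = AudioSignal(torch.randn(1, 1, 44100), sample_rate=44100)

stft_loss = MultiScaleSTFTLoss()    # STFT windows of 2048 and 512 samples
mel_loss  = MelSpectrogramLoss()    # 150- and 80-bin mel spectrograms
sisdr     = SISDRLoss()

recon = stft_loss(estimate, reference) + mel_loss(estimate, reference) + sisdr(estimate, reference)
print(recon.item())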
diff --git a/src/modules/stable_vae/models/nn/quantize.py b/src/modules/stable_vae/models/nn/quantize.py
new file mode 100644
index 0000000000000000000000000000000000000000..bc462bc56ab5191b65d58ebd8bcaff4af7fb5927
--- /dev/null
+++ b/src/modules/stable_vae/models/nn/quantize.py
@@ -0,0 +1,262 @@
+from typing import Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange
+from torch.nn.utils import weight_norm
+
+from .layers import WNConv1d
+
+
+class VectorQuantize(nn.Module):
+    """
+    Implementation of VQ similar to Karpathy's repo:
+    https://github.com/karpathy/deep-vector-quantization
+    Additionally uses following tricks from Improved VQGAN
+    (https://arxiv.org/pdf/2110.04627.pdf):
+        1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space
+            for improved codebook usage
+        2. l2-normalized codes: Converts euclidean distance to cosine similarity which
+            improves training stability
+    """
+
+    def __init__(self, input_dim: int, codebook_size: int, codebook_dim: int):
+        super().__init__()
+        self.codebook_size = codebook_size
+        self.codebook_dim = codebook_dim
+
+        self.in_proj = WNConv1d(input_dim, codebook_dim, kernel_size=1)
+        self.out_proj = WNConv1d(codebook_dim, input_dim, kernel_size=1)
+        self.codebook = nn.Embedding(codebook_size, codebook_dim)
+
+    def forward(self, z):
+        """Quantized the input tensor using a fixed codebook and returns
+        the corresponding codebook vectors
+
+        Parameters
+        ----------
+        z : Tensor[B x D x T]
+
+        Returns
+        -------
+        Tensor[B x D x T]
+            Quantized continuous representation of input
+        Tensor[1]
+            Commitment loss to train encoder to predict vectors closer to codebook
+            entries
+        Tensor[1]
+            Codebook loss to update the codebook
+        Tensor[B x T]
+            Codebook indices (quantized discrete representation of input)
+        Tensor[B x D x T]
+            Projected latents (continuous representation of input before quantization)
+        """
+
+        # Factorized codes (ViT-VQGAN) Project input into low-dimensional space
+        z_e = self.in_proj(z)  # z_e : (B x D x T)
+        z_q, indices = self.decode_latents(z_e)
+
+        commitment_loss = F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2])
+        codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2])
+
+        z_q = (
+            z_e + (z_q - z_e).detach()
+        )  # noop in forward pass, straight-through gradient estimator in backward pass
+
+        z_q = self.out_proj(z_q)
+
+        return z_q, commitment_loss, codebook_loss, indices, z_e
+
+    def embed_code(self, embed_id):
+        return F.embedding(embed_id, self.codebook.weight)
+
+    def decode_code(self, embed_id):
+        return self.embed_code(embed_id).transpose(1, 2)
+
+    def decode_latents(self, latents):
+        encodings = rearrange(latents, "b d t -> (b t) d")
+        codebook = self.codebook.weight  # codebook: (N x D)
+
+        # L2 normalize encodings and codebook (ViT-VQGAN)
+        encodings = F.normalize(encodings)
+        codebook = F.normalize(codebook)
+
+        # Compute euclidean distance with codebook
+        dist = (
+            encodings.pow(2).sum(1, keepdim=True)
+            - 2 * encodings @ codebook.t()
+            + codebook.pow(2).sum(1, keepdim=True).t()
+        )
+        indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
+        z_q = self.decode_code(indices)
+        return z_q, indices
+
+
+class ResidualVectorQuantize(nn.Module):
+    """
+    Introduced in SoundStream: An end2end neural audio codec
+    https://arxiv.org/abs/2107.03312
+    """
+
+    def __init__(
+        self,
+        input_dim: int = 512,
+        n_codebooks: int = 9,
+        codebook_size: int = 1024,
+        codebook_dim: Union[int, list] = 8,
+        quantizer_dropout: float = 0.0,
+    ):
+        super().__init__()
+        if isinstance(codebook_dim, int):
+            codebook_dim = [codebook_dim for _ in range(n_codebooks)]
+
+        self.n_codebooks = n_codebooks
+        self.codebook_dim = codebook_dim
+        self.codebook_size = codebook_size
+
+        self.quantizers = nn.ModuleList(
+            [
+                VectorQuantize(input_dim, codebook_size, codebook_dim[i])
+                for i in range(n_codebooks)
+            ]
+        )
+        self.quantizer_dropout = quantizer_dropout
+
+    def forward(self, z, n_quantizers: int = None):
+        """Quantized the input tensor using a fixed set of `n` codebooks and returns
+        the corresponding codebook vectors
+        Parameters
+        ----------
+        z : Tensor[B x D x T]
+        n_quantizers : int, optional
+            No. of quantizers to use
+            (n_quantizers < self.n_codebooks ex: for quantizer dropout)
+            Note: if `self.quantizer_dropout` is True, this argument is ignored
+                when in training mode, and a random number of quantizers is used.
+        Returns
+        -------
+        z_q : Tensor[B x D x T]
+            Quantized continuous representation of input
+        codes : Tensor[B x N x T]
+            Codebook indices for each codebook
+            (quantized discrete representation of input)
+        latents : Tensor[B x N*D x T]
+            Projected latents (continuous representation of input before quantization)
+        commitment_loss : Tensor[1]
+            Commitment loss to train encoder to predict vectors closer to codebook
+            entries
+        codebook_loss : Tensor[1]
+            Codebook loss to update the codebook
+        """
+        z_q = 0
+        residual = z
+        commitment_loss = 0
+        codebook_loss = 0
+
+        codebook_indices = []
+        latents = []
+
+        if n_quantizers is None:
+            n_quantizers = self.n_codebooks
+        if self.training:
+            n_quantizers = torch.ones((z.shape[0],)) * self.n_codebooks + 1
+            dropout = torch.randint(1, self.n_codebooks + 1, (z.shape[0],))
+            n_dropout = int(z.shape[0] * self.quantizer_dropout)
+            n_quantizers[:n_dropout] = dropout[:n_dropout]
+            n_quantizers = n_quantizers.to(z.device)
+
+        for i, quantizer in enumerate(self.quantizers):
+            if self.training is False and i >= n_quantizers:
+                break
+
+            z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer(
+                residual
+            )
+
+            # Create mask to apply quantizer dropout
+            mask = (
+                torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers
+            )
+            z_q = z_q + z_q_i * mask[:, None, None]
+            residual = residual - z_q_i
+
+            # Sum losses
+            commitment_loss += (commitment_loss_i * mask).mean()
+            codebook_loss += (codebook_loss_i * mask).mean()
+
+            codebook_indices.append(indices_i)
+            latents.append(z_e_i)
+
+        codes = torch.stack(codebook_indices, dim=1)
+        latents = torch.cat(latents, dim=1)
+
+        return z_q, codes, latents, commitment_loss, codebook_loss
+
+    def from_codes(self, codes: torch.Tensor):
+        """Given the quantized codes, reconstruct the continuous representation
+        Parameters
+        ----------
+        codes : Tensor[B x N x T]
+            Quantized discrete representation of input
+        Returns
+        -------
+        Tensor[B x D x T]
+            Quantized continuous representation of input
+        """
+        z_q = 0.0
+        z_p = []
+        n_codebooks = codes.shape[1]
+        for i in range(n_codebooks):
+            z_p_i = self.quantizers[i].decode_code(codes[:, i, :])
+            z_p.append(z_p_i)
+
+            z_q_i = self.quantizers[i].out_proj(z_p_i)
+            z_q = z_q + z_q_i
+        return z_q, torch.cat(z_p, dim=1), codes
+
+    def from_latents(self, latents: torch.Tensor):
+        """Given the unquantized latents, reconstruct the
+        continuous representation after quantization.
+
+        Parameters
+        ----------
+        latents : Tensor[B x N x T]
+            Continuous representation of input after projection
+
+        Returns
+        -------
+        Tensor[B x D x T]
+            Quantized representation of full-projected space
+        Tensor[B x D x T]
+            Quantized representation of latent space
+        """
+        z_q = 0
+        z_p = []
+        codes = []
+        dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers])
+
+        n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[
+            0
+        ]
+        for i in range(n_codebooks):
+            j, k = dims[i], dims[i + 1]
+            z_p_i, codes_i = self.quantizers[i].decode_latents(latents[:, j:k, :])
+            z_p.append(z_p_i)
+            codes.append(codes_i)
+
+            z_q_i = self.quantizers[i].out_proj(z_p_i)
+            z_q = z_q + z_q_i
+
+        return z_q, torch.cat(z_p, dim=1), torch.stack(codes, dim=1)
+
+
+if __name__ == "__main__":
+    rvq = ResidualVectorQuantize(quantizer_dropout=True)
+    x = torch.randn(16, 512, 80)
+    z_q, codes, latents, commitment_loss, codebook_loss = rvq(x)
+    print(latents.shape)
diff --git a/src/modules/stable_vae/models/pretransforms.py b/src/modules/stable_vae/models/pretransforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..c9942db59908ce8135f44e45090e928f6c60393a
--- /dev/null
+++ b/src/modules/stable_vae/models/pretransforms.py
@@ -0,0 +1,258 @@
+import torch
+from einops import rearrange
+from torch import nn
+
+class Pretransform(nn.Module):
+    def __init__(self, enable_grad, io_channels, is_discrete):
+        super().__init__()
+
+        self.is_discrete = is_discrete
+        self.io_channels = io_channels
+        self.encoded_channels = None
+        self.downsampling_ratio = None
+
+        self.enable_grad = enable_grad
+
+    def encode(self, x):
+        raise NotImplementedError
+
+    def decode(self, z):
+        raise NotImplementedError
+    
+    def tokenize(self, x):
+        raise NotImplementedError
+    
+    def decode_tokens(self, tokens):
+        raise NotImplementedError
+
+class AutoencoderPretransform(Pretransform):
+    def __init__(self, model, scale=1.0, model_half=False, iterate_batch=False, chunked=False):
+        super().__init__(enable_grad=False, io_channels=model.io_channels, is_discrete=model.bottleneck is not None and model.bottleneck.is_discrete)
+        self.model = model
+        self.model.requires_grad_(False).eval()
+        self.scale=scale
+        self.downsampling_ratio = model.downsampling_ratio
+        self.io_channels = model.io_channels
+        self.sample_rate = model.sample_rate
+        
+        self.model_half = model_half
+        self.iterate_batch = iterate_batch
+
+        self.encoded_channels = model.latent_dim
+
+        self.chunked = chunked
+        self.num_quantizers = model.bottleneck.num_quantizers if model.bottleneck is not None and model.bottleneck.is_discrete else None
+        self.codebook_size = model.bottleneck.codebook_size if model.bottleneck is not None and model.bottleneck.is_discrete else None
+
+        if self.model_half:
+            self.model.half()
+    
+    def encode(self, x, **kwargs):
+        
+        if self.model_half:
+            x = x.half()
+            self.model.to(torch.float16)
+
+        encoded = self.model.encode_audio(x, chunked=self.chunked, iterate_batch=self.iterate_batch, **kwargs)
+
+        if self.model_half:
+            encoded = encoded.float()
+
+        return encoded / self.scale
+
+    def decode(self, z, **kwargs):
+        z = z * self.scale
+
+        if self.model_half:
+            z = z.half()
+            self.model.to(torch.float16)
+
+        decoded = self.model.decode_audio(z, chunked=self.chunked, iterate_batch=self.iterate_batch, **kwargs)
+
+        if self.model_half:
+            decoded = decoded.float()
+
+        return decoded
+    
+    def tokenize(self, x, **kwargs):
+        assert self.model.is_discrete, "Cannot tokenize with a continuous model"
+
+        _, info = self.model.encode(x, return_info = True, **kwargs)
+
+        return info[self.model.bottleneck.tokens_id]
+    
+    def decode_tokens(self, tokens, **kwargs):
+        assert self.model.is_discrete, "Cannot decode tokens with a continuous model"
+
+        return self.model.decode_tokens(tokens, **kwargs)
+    
+    def load_state_dict(self, state_dict, strict=True):
+        self.model.load_state_dict(state_dict, strict=strict)
+
+class WaveletPretransform(Pretransform):
+    def __init__(self, channels, levels, wavelet):
+        super().__init__(enable_grad=False, io_channels=channels, is_discrete=False)
+
+        from .wavelets import WaveletEncode1d, WaveletDecode1d
+
+        self.encoder = WaveletEncode1d(channels, levels, wavelet)
+        self.decoder = WaveletDecode1d(channels, levels, wavelet)
+
+        self.downsampling_ratio = 2 ** levels
+        self.io_channels = channels
+        self.encoded_channels = channels * self.downsampling_ratio
+    
+    def encode(self, x):
+        return self.encoder(x)
+    
+    def decode(self, z):
+        return self.decoder(z)
+    
+class PQMFPretransform(Pretransform):
+    def __init__(self, attenuation=100, num_bands=16):
+        # TODO: Fix PQMF to take in in-channels
+        super().__init__(enable_grad=False, io_channels=1, is_discrete=False)
+        from .pqmf import PQMF
+        self.pqmf = PQMF(attenuation, num_bands)
+
+
+    def encode(self, x):
+        # x is (Batch x Channels x Time)
+        x = self.pqmf.forward(x)
+        # pqmf.forward returns (Batch x Channels x Bands x Time)
+        # but Pretransform needs Batch x Channels x Time
+        # so concatenate channels and bands into one axis
+        return rearrange(x, "b c n t -> b (c n) t")
+
+    def decode(self, x):
+        # x is (Batch x (Channels Bands) x Time), convert back to (Batch x Channels x Bands x Time) 
+        x = rearrange(x, "b (c n) t -> b c n t", n=self.pqmf.num_bands)
+        # returns (Batch x Channels x Time) 
+        return self.pqmf.inverse(x)
+        
+class PretrainedDACPretransform(Pretransform):
+    def __init__(self, model_type="44khz", model_bitrate="8kbps", scale=1.0, quantize_on_decode: bool = True, chunked=True):
+        super().__init__(enable_grad=False, io_channels=1, is_discrete=True)
+        
+        import dac
+        
+        model_path = dac.utils.download(model_type=model_type, model_bitrate=model_bitrate)
+        
+        self.model = dac.DAC.load(model_path)
+
+        self.quantize_on_decode = quantize_on_decode
+
+        if model_type == "44khz":
+            self.downsampling_ratio = 512
+        else:
+            self.downsampling_ratio = 320
+
+        self.io_channels = 1
+
+        self.scale = scale
+
+        self.chunked = chunked
+
+        self.encoded_channels = self.model.latent_dim
+
+        self.num_quantizers = self.model.n_codebooks
+
+        self.codebook_size = self.model.codebook_size
+
+    def encode(self, x):
+
+        latents = self.model.encoder(x)
+
+        if self.quantize_on_decode:
+            output = latents
+        else:
+            z, _, _, _, _ = self.model.quantizer(latents, n_quantizers=self.model.n_codebooks)
+            output = z
+        
+        if self.scale != 1.0:
+            output = output / self.scale
+        
+        return output
+
+    def decode(self, z):
+        
+        if self.scale != 1.0:
+            z = z * self.scale
+
+        if self.quantize_on_decode:
+            z, _, _, _, _ = self.model.quantizer(z, n_quantizers=self.model.n_codebooks)
+
+        return self.model.decode(z)
+
+    def tokenize(self, x):
+        return self.model.encode(x)[1]
+    
+    def decode_tokens(self, tokens):
+        latents = self.model.quantizer.from_codes(tokens)
+        return self.model.decode(latents)
+    
+class AudiocraftCompressionPretransform(Pretransform):
+    def __init__(self, model_type="facebook/encodec_32khz", scale=1.0, quantize_on_decode: bool = True):
+        super().__init__(enable_grad=False, io_channels=1, is_discrete=True)
+        
+        try:
+            from audiocraft.models import CompressionModel
+        except ImportError:
+            raise ImportError("Audiocraft is not installed. Please install audiocraft to use Audiocraft models.")
+               
+        self.model = CompressionModel.get_pretrained(model_type)
+
+        self.quantize_on_decode = quantize_on_decode
+
+        self.downsampling_ratio = round(self.model.sample_rate / self.model.frame_rate)
+
+        self.sample_rate = self.model.sample_rate
+
+        self.io_channels = self.model.channels
+
+        self.scale = scale
+
+        #self.encoded_channels = self.model.latent_dim
+
+        self.num_quantizers = self.model.num_codebooks
+
+        self.codebook_size = self.model.cardinality
+
+        self.model.to(torch.float16).eval().requires_grad_(False)
+
+    def encode(self, x):
+
+        assert False, "Audiocraft compression models do not support continuous encoding"
+
+        # latents = self.model.encoder(x)
+
+        # if self.quantize_on_decode:
+        #     output = latents
+        # else:
+        #     z, _, _, _, _ = self.model.quantizer(latents, n_quantizers=self.model.n_codebooks)
+        #     output = z
+        
+        # if self.scale != 1.0:
+        #     output = output / self.scale
+        
+        # return output
+
+    def decode(self, z):
+        
+        assert False, "Audiocraft compression models do not support continuous decoding"
+
+        # if self.scale != 1.0:
+        #     z = z * self.scale
+
+        # if self.quantize_on_decode:
+        #     z, _, _, _, _ = self.model.quantizer(z, n_quantizers=self.model.n_codebooks)
+
+        # return self.model.decode(z)
+
+    def tokenize(self, x):
+        with torch.cuda.amp.autocast(enabled=False):
+            return self.model.encode(x.to(torch.float16))[0]
+    
+    def decode_tokens(self, tokens):
+        with torch.cuda.amp.autocast(enabled=False):
+            return self.model.decode(tokens)
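
Pretransform is only an interface; as a minimal, hypothetical illustration of the contract the concrete subclasses above fulfil, an identity pretransform just sets the bookkeeping attributes and implements encode/decode:

import torch

class IdentityPretransform(Pretransform):
    # Hypothetical no-op pretransform, shown only to illustrate the interface.
    def __init__(self, io_channels=2):
        super().__init__(enable_grad=False, io_channels=io_channels, is_discrete=False)
        self.downsampling_ratio = 1
        self.encoded_channels = io_channels

    def encode(self, x):
        return x

    def decode(self, z):
        return z

pre = IdentityPretransform()
audio = torch.randn(1, 2, 48000)
assert torch.equal(pre.decode(pre.encode(audio)), audio)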
diff --git a/src/modules/stable_vae/models/utils.py b/src/modules/stable_vae/models/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..ec8eeaf773d47db2c000a3b2237d88d310214dcf
--- /dev/null
+++ b/src/modules/stable_vae/models/utils.py
@@ -0,0 +1,51 @@
+import torch
+import torch.nn as nn
+from torchaudio import transforms as T
+
+
+class PadCrop(nn.Module):
+    def __init__(self, n_samples, randomize=True):
+        super().__init__()
+        self.n_samples = n_samples
+        self.randomize = randomize
+
+    def __call__(self, signal):
+        n, s = signal.shape
+        start = 0 if (not self.randomize) else torch.randint(0, max(0, s - self.n_samples) + 1, []).item()
+        end = start + self.n_samples
+        output = signal.new_zeros([n, self.n_samples])
+        output[:, :min(s, self.n_samples)] = signal[:, start:end]
+        return output
+
+
+def set_audio_channels(audio, target_channels):
+    if target_channels == 1:
+        # Convert to mono
+        audio = audio.mean(1, keepdim=True)
+    elif target_channels == 2:
+        # Convert to stereo
+        if audio.shape[1] == 1:
+            audio = audio.repeat(1, 2, 1)
+        elif audio.shape[1] > 2:
+            audio = audio[:, :2, :]
+    return audio
+
+def prepare_audio(audio, in_sr, target_sr, target_length, target_channels, device):
+    
+    audio = audio.to(device)
+
+    if in_sr != target_sr:
+        resample_tf = T.Resample(in_sr, target_sr).to(device)
+        audio = resample_tf(audio)
+
+    audio = PadCrop(target_length, randomize=False)(audio)
+
+    # Add batch dimension
+    if audio.dim() == 1:
+        audio = audio.unsqueeze(0).unsqueeze(0)
+    elif audio.dim() == 2:
+        audio = audio.unsqueeze(0)
+
+    audio = set_audio_channels(audio, target_channels)
+
+    return audio
\ No newline at end of file
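
A small sketch of the helpers above (import path depends on how the package is installed): set_audio_channels up- or down-mixes the channel axis, and prepare_audio resamples, pad-crops, batches, and channel-matches a raw clip.

import torch

mono = torch.randn(1, 1, 44100)                         # (batch, channels, samples)
stereo = set_audio_channels(mono, target_channels=2)
print(stereo.shape)                                     # torch.Size([1, 2, 44100])

clip = torch.randn(2, 96000)                            # un-batched stereo clip at 48 kHz
ready = prepare_audio(clip, in_sr=48000, target_sr=44100,
                      target_length=44100, target_channels=2, device="cpu")
print(ready.shape)                                      # torch.Size([1, 2, 44100])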
diff --git a/src/utils/.ipynb_checkpoints/utils-checkpoint.py b/src/utils/.ipynb_checkpoints/utils-checkpoint.py
new file mode 100644
index 0000000000000000000000000000000000000000..82a6bc56e9e341e54dc6a136f1f78261dde0f655
--- /dev/null
+++ b/src/utils/.ipynb_checkpoints/utils-checkpoint.py
@@ -0,0 +1,94 @@
+import torch
+import numpy as np
+import yaml
+import os
+
+
+def load_yaml_with_includes(yaml_file):
+    def loader_with_include(loader, node):
+        # Load the included file
+        include_path = os.path.join(os.path.dirname(yaml_file), loader.construct_scalar(node))
+        with open(include_path, 'r') as f:
+            return yaml.load(f, Loader=yaml.FullLoader)
+
+    yaml.add_constructor('!include', loader_with_include, Loader=yaml.FullLoader)
+
+    with open(yaml_file, 'r') as f:
+        return yaml.load(f, Loader=yaml.FullLoader)
+
+
+def scale_shift(x, scale, shift):
+    return (x+shift) * scale
+
+
+def scale_shift_re(x, scale, shift):
+    return (x/scale) - shift
+
+
+def align_seq(source, target_length, mapping_method='hard'):
+    source_len = source.shape[1]
+    if mapping_method == 'hard':
+        mapping_idx = np.round(np.arange(target_length) * source_len / target_length).astype(int)
+        mapping_idx = np.clip(mapping_idx, 0, source_len - 1)
+        output = source[:, mapping_idx]
+    else:
+        # TBD
+        raise NotImplementedError
+
+    return output
+
+
+def customized_lr_scheduler(optimizer, warmup_steps=-1):
+    from torch.optim.lr_scheduler import LambdaLR
+
+    def fn(step):
+        if warmup_steps > 0:
+            return min(step / warmup_steps, 1)
+        else:
+            return 1
+    return LambdaLR(optimizer, fn)
+
+
+def get_lr_scheduler(optimizer, name, **kwargs):
+    if name == 'customized':
+        return customized_lr_scheduler(optimizer, **kwargs)
+    elif name == 'cosine':
+        from torch.optim.lr_scheduler import CosineAnnealingLR
+        return CosineAnnealingLR(optimizer, **kwargs)
+    else:
+        raise NotImplementedError(name)
+
+
+def compute_snr(noise_scheduler, timesteps):
+    """
+    Computes SNR as per
+    https://github.com/TiankaiHang/Min-SNR-Diffusion
+    Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
+    """
+    alphas_cumprod = noise_scheduler.alphas_cumprod
+    sqrt_alphas_cumprod = alphas_cumprod**0.5
+    sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5
+
+    # Expand the tensors.
+    # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion
+    # Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
+    sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
+    while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
+        sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
+    alpha = sqrt_alphas_cumprod.expand(timesteps.shape)
+
+    sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
+    while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
+        sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
+    sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)
+
+    # Compute SNR.
+    snr = (alpha / sigma) ** 2
+    return snr
+
+
+if __name__ == "__main__":
+
+    a = torch.rand(2, 10)
+    target_len = 15
+
+    b = align_seq(a, target_len)
\ No newline at end of file
diff --git a/src/utils/__init__.py b/src/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..90f60fdd89ad8575faafe45188bd1d968852fc67
--- /dev/null
+++ b/src/utils/__init__.py
@@ -0,0 +1 @@
+from .utils import *
\ No newline at end of file
diff --git a/src/utils/utils.py b/src/utils/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..82a6bc56e9e341e54dc6a136f1f78261dde0f655
--- /dev/null
+++ b/src/utils/utils.py
@@ -0,0 +1,94 @@
+import torch
+import numpy as np
+import yaml
+import os
+
+
+def load_yaml_with_includes(yaml_file):
+    def loader_with_include(loader, node):
+        # Load the included file
+        include_path = os.path.join(os.path.dirname(yaml_file), loader.construct_scalar(node))
+        with open(include_path, 'r') as f:
+            return yaml.load(f, Loader=yaml.FullLoader)
+
+    yaml.add_constructor('!include', loader_with_include, Loader=yaml.FullLoader)
+
+    with open(yaml_file, 'r') as f:
+        return yaml.load(f, Loader=yaml.FullLoader)
+
+
+def scale_shift(x, scale, shift):
+    return (x+shift) * scale
+
+
+def scale_shift_re(x, scale, shift):
+    return (x/scale) - shift
+
+
+def align_seq(source, target_length, mapping_method='hard'):
+    source_len = source.shape[1]
+    if mapping_method == 'hard':
+        mapping_idx = np.clip(np.round(np.arange(target_length) * source_len / target_length), 0, source_len - 1).astype(int)
+        output = source[:, mapping_idx]
+    else:
+        # TBD
+        raise NotImplementedError
+
+    return output
+
+
+def customized_lr_scheduler(optimizer, warmup_steps=-1):
+    from torch.optim.lr_scheduler import LambdaLR
+
+    def fn(step):
+        if warmup_steps > 0:
+            return min(step / warmup_steps, 1)
+        else:
+            return 1
+    return LambdaLR(optimizer, fn)
+
+
+def get_lr_scheduler(optimizer, name, **kwargs):
+    if name == 'customized':
+        return customized_lr_scheduler(optimizer, **kwargs)
+    elif name == 'cosine':
+        from torch.optim.lr_scheduler import CosineAnnealingLR
+        return CosineAnnealingLR(optimizer, **kwargs)
+    else:
+        raise NotImplementedError(name)
+
+
+def compute_snr(noise_scheduler, timesteps):
+    """
+    Computes SNR as per
+    https://github.com/TiankaiHang/Min-SNR-Diffusion
+    Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
+    """
+    alphas_cumprod = noise_scheduler.alphas_cumprod
+    sqrt_alphas_cumprod = alphas_cumprod**0.5
+    sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5
+
+    # Expand the tensors.
+    # Adapted from the same Min-SNR-Diffusion-Training repository:
+    # https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
+    sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
+    while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
+        sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
+    alpha = sqrt_alphas_cumprod.expand(timesteps.shape)
+
+    sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
+    while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
+        sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
+    sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)
+
+    # Compute SNR.
+    snr = (alpha / sigma) ** 2
+    return snr
+
+
+if __name__ == "__main__":
+
+    a = torch.rand(2, 10)
+    target_len = 15
+
+    b = align_seq(a, target_len)
\ No newline at end of file
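As a quick sanity check of the `!include` constructor registered by `load_yaml_with_includes`, a hypothetical pair of config files could be resolved as below. The file names and keys are made up for illustration, the included path is resolved relative to the including file's directory, and the repository root is assumed to be importable.

# configs/model.yaml (hypothetical):
#   diffusion:
#     steps: 1000
#
# configs/train.yaml (hypothetical):
#   model: !include model.yaml
#   lr: 1.0e-4
from src.utils.utils import load_yaml_with_includes

config = load_yaml_with_includes("configs/train.yaml")
print(config["model"]["diffusion"]["steps"])  # -> 1000
print(config["lr"])                           # -> 0.0001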
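One plausible consumer of `compute_snr` is min-SNR-gamma loss weighting. The sketch below assumes an epsilon-prediction objective, a `diffusers` `DDPMScheduler`, and `gamma = 5.0`; none of these choices are mandated by this diff.

import torch
from diffusers import DDPMScheduler

from src.utils.utils import compute_snr

noise_scheduler = DDPMScheduler(num_train_timesteps=1000)
timesteps = torch.randint(0, 1000, (8,))

snr = compute_snr(noise_scheduler, timesteps)   # shape [8]
gamma = 5.0
# Min-SNR weighting for epsilon prediction: w(t) = min(SNR(t), gamma) / SNR(t)
weights = torch.clamp(snr, max=gamma) / snr

# In a real training loop this would be the per-example MSE between predicted
# and true noise; a random stand-in keeps the sketch self-contained.
per_example_mse = torch.rand(8)
loss = (weights * per_example_mse).mean()
print(loss)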