KB40aa committed (verified)
Commit 9a7fcc5 · Parent: 7044679

Upload 6 files

Files changed (6):
  1. __init__.py +6 -0
  2. audio.py +163 -0
  3. config.py +176 -0
  4. layers.py +888 -0
  5. model.py +802 -0
  6. state.py +217 -0
__init__.py ADDED
from .model import Dia

__all__ = [
    "Dia",
]
audio.py ADDED
import typing as tp

import torch


def build_delay_indices(B: int, T: int, C: int, delay_pattern: tp.List[int]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
    """
    Precompute (t_idx_BxTxC, indices_BTCx3) so that out[t, c] = in[t - delay[c], c].
    Negative t_idx => BOS; t_idx >= T => PAD.
    """
    delay_arr = torch.tensor(delay_pattern, dtype=torch.int32)

    t_idx_BxT = torch.broadcast_to(
        torch.arange(T, dtype=torch.int32)[None, :],
        [B, T],
    )
    t_idx_BxTx1 = t_idx_BxT[..., None]
    t_idx_BxTxC = t_idx_BxTx1 - delay_arr.view(1, 1, C)

    b_idx_BxTxC = torch.broadcast_to(
        torch.arange(B, dtype=torch.int32).view(B, 1, 1),
        [B, T, C],
    )
    c_idx_BxTxC = torch.broadcast_to(
        torch.arange(C, dtype=torch.int32).view(1, 1, C),
        [B, T, C],
    )

    # Clamp time indices to [0..T-1] so the advanced-indexing gather below never goes out of bounds
    t_clamped_BxTxC = torch.clamp(t_idx_BxTxC, 0, T - 1)

    indices_BTCx3 = torch.stack(
        [
            b_idx_BxTxC.reshape(-1),
            t_clamped_BxTxC.reshape(-1),
            c_idx_BxTxC.reshape(-1),
        ],
        dim=1,
    ).long()  # Ensure indices are long type for indexing

    return t_idx_BxTxC, indices_BTCx3


def apply_audio_delay(
    audio_BxTxC: torch.Tensor,
    pad_value: int,
    bos_value: int,
    precomp: tp.Tuple[torch.Tensor, torch.Tensor],
) -> torch.Tensor:
    """
    Applies the delay pattern to batched audio tokens using precomputed indices,
    inserting BOS where t_idx < 0 and PAD where t_idx >= T.

    Args:
        audio_BxTxC: [B, T, C] audio tokens (int16/int32/float)
        pad_value: the padding token
        bos_value: the BOS token
        precomp: (t_idx_BxTxC, indices_BTCx3) from build_delay_indices

    Returns:
        result_BxTxC: [B, T, C] delayed audio tokens
    """
    device = audio_BxTxC.device  # Get device from input tensor
    t_idx_BxTxC, indices_BTCx3 = precomp
    t_idx_BxTxC = t_idx_BxTxC.to(device)  # Move precomputed indices to device
    indices_BTCx3 = indices_BTCx3.to(device)

    # Equivalent of tf.gather_nd using advanced indexing
    gathered_flat = audio_BxTxC[indices_BTCx3[:, 0], indices_BTCx3[:, 1], indices_BTCx3[:, 2]]
    gathered_BxTxC = gathered_flat.view(audio_BxTxC.shape)

    # Create masks on the correct device
    mask_bos = t_idx_BxTxC < 0  # => place bos_value
    mask_pad = t_idx_BxTxC >= audio_BxTxC.shape[1]  # => place pad_value

    # Create scalar tensors on the correct device
    bos_tensor = torch.tensor(bos_value, dtype=audio_BxTxC.dtype, device=device)
    pad_tensor = torch.tensor(pad_value, dtype=audio_BxTxC.dtype, device=device)

    # If mask_bos, BOS; else if mask_pad, PAD; else the gathered original value
    result_BxTxC = torch.where(mask_bos, bos_tensor, torch.where(mask_pad, pad_tensor, gathered_BxTxC))

    return result_BxTxC


def build_revert_indices(B: int, T: int, C: int, delay_pattern: tp.List[int]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
    """
    Precompute indices for the revert operation using PyTorch.

    Returns:
        A tuple (t_idx_BxTxC, indices_BTCx3) where:
            - t_idx_BxTxC is a tensor of shape [B, T, C] computed as time indices plus the delay.
            - indices_BTCx3 is a tensor of shape [B*T*C, 3] used for gathering, computed from:
              batch indices, clamped time indices, and channel indices.
    """
    # Built on the default device; revert_audio_delay moves the indices to the
    # audio tensor's device before use.
    device = None

    delay_arr = torch.tensor(delay_pattern, dtype=torch.int32, device=device)

    t_idx_BT1 = torch.broadcast_to(torch.arange(T, device=device).unsqueeze(0), [B, T])
    t_idx_BT1 = t_idx_BT1.unsqueeze(-1)

    t_idx_BxTxC = torch.minimum(
        t_idx_BT1 + delay_arr.view(1, 1, C),
        torch.tensor(T - 1, device=device),
    )
    b_idx_BxTxC = torch.broadcast_to(torch.arange(B, device=device).view(B, 1, 1), [B, T, C])
    c_idx_BxTxC = torch.broadcast_to(torch.arange(C, device=device).view(1, 1, C), [B, T, C])

    indices_BTCx3 = torch.stack(
        [
            b_idx_BxTxC.reshape(-1),
            t_idx_BxTxC.reshape(-1),
            c_idx_BxTxC.reshape(-1),
        ],
        dim=1,
    ).long()  # Ensure indices are long type

    return t_idx_BxTxC, indices_BTCx3


def revert_audio_delay(
    audio_BxTxC: torch.Tensor,
    pad_value: int,
    precomp: tp.Tuple[torch.Tensor, torch.Tensor],
    T: int,
) -> torch.Tensor:
    """
    Reverts a delay pattern from batched audio tokens using precomputed indices (PyTorch version).

    Args:
        audio_BxTxC: Input delayed audio tensor
        pad_value: Padding value for out-of-bounds indices
        precomp: Precomputed revert indices tuple containing:
            - t_idx_BxTxC: Time offset indices tensor
            - indices_BTCx3: Gather indices tensor for original audio
        T: Original sequence length before padding

    Returns:
        Reverted audio tensor with same shape as input
    """
    t_idx_BxTxC, indices_BTCx3 = precomp
    device = audio_BxTxC.device  # Get device from input tensor

    # Move precomputed indices to the same device as audio_BxTxC if they aren't already
    t_idx_BxTxC = t_idx_BxTxC.to(device)
    indices_BTCx3 = indices_BTCx3.to(device)

    # PyTorch advanced indexing (equivalent of tf.gather_nd)
    gathered_flat = audio_BxTxC[indices_BTCx3[:, 0], indices_BTCx3[:, 1], indices_BTCx3[:, 2]]
    gathered_BxTxC = gathered_flat.view(audio_BxTxC.size())

    # Positions whose source index ran past the original length become PAD
    pad_tensor = torch.tensor(pad_value, dtype=audio_BxTxC.dtype, device=device)
    T_tensor = torch.tensor(T, device=device)

    result_BxTxC = torch.where(t_idx_BxTxC >= T_tensor, pad_tensor, gathered_BxTxC)

    return result_BxTxC
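
A minimal round-trip sketch of the delay helpers above (illustrative only, not part of the uploaded file). It assumes the package is importable as dia and reuses the default delay pattern and BOS/PAD token ids from config.py:

# Usage sketch for the delay helpers (illustrative, assumes the package imports as dia).
import torch

from dia.audio import apply_audio_delay, build_delay_indices, build_revert_indices, revert_audio_delay

B, T, C = 2, 32, 9
delay_pattern = [0, 8, 9, 10, 11, 12, 13, 14, 15]  # DiaConfig default
codes = torch.randint(0, 1024, (B, T, C), dtype=torch.long)

# Shift each channel c back by delay_pattern[c]; the first delay_pattern[c] steps become BOS (1026).
delayed = apply_audio_delay(codes, pad_value=1025, bos_value=1026, precomp=build_delay_indices(B, T, C, delay_pattern))

# Undo the shift; positions whose source index runs past T become PAD (1025).
restored = revert_audio_delay(delayed, pad_value=1025, precomp=build_revert_indices(B, T, C, delay_pattern), T=T)

# Everything before the largest delay survives the round trip unchanged.
valid = T - max(delay_pattern)
assert torch.equal(restored[:, :valid, :], codes[:, :valid, :])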
config.py ADDED
"""Configuration management module for the Dia model.

This module provides configuration management for the Dia model, utilizing
Pydantic for validation. It defines configurations for the model architecture
(encoder and decoder) and the overall model settings.

Key components:
    - EncoderConfig: Architecture details for the encoder module.
    - DecoderConfig: Architecture details for the decoder module.
    - DiaConfig: Master configuration combining all components.
"""

import os

from pydantic import BaseModel, Field


class EncoderConfig(BaseModel, frozen=True):
    """Configuration for the encoder component of the Dia model.

    Attributes:
        model_type: Type of the model, defaults to "dia_encoder".
        hidden_size: Size of the encoder layers, defaults to 1024.
        intermediate_size: Size of the "intermediate" (i.e., feed-forward) layer in the encoder, defaults to 4096.
        num_hidden_layers: Number of hidden layers in the encoder, defaults to 12.
        num_attention_heads: Number of attention heads in the encoder, defaults to 16.
        num_key_value_heads: Number of key-value heads in the encoder, defaults to 16.
        head_dim: Dimension of each attention head, defaults to 128.
        hidden_act: Activation function in the encoder, defaults to "silu".
        max_position_embeddings: Maximum number of position embeddings, defaults to 1024.
        initializer_range: Range for initializing weights, defaults to 0.02.
        norm_eps: Epsilon value for normalization layers, defaults to 1e-5.
        rope_theta: Theta value for RoPE, defaults to 10000.0.
        rope_scaling: Optional scaling factor for RoPE.
        vocab_size: Vocabulary size, defaults to 256.
    """

    head_dim: int = Field(default=128, gt=0)
    hidden_act: str = Field(default="silu")
    hidden_size: int = Field(default=1024, gt=0)
    initializer_range: float = Field(default=0.02)
    intermediate_size: int = Field(default=4096, gt=0)
    max_position_embeddings: int = Field(default=1024, gt=0)
    model_type: str = Field(default="dia_encoder")
    norm_eps: float = Field(default=1e-5)
    num_attention_heads: int = Field(default=16, gt=0)
    num_hidden_layers: int = Field(default=12, gt=0)
    num_key_value_heads: int = Field(default=16, gt=0)
    rope_scaling: float | None = Field(default=None)
    rope_theta: float = Field(default=10000.0)
    vocab_size: int = Field(default=256, gt=0)


class DecoderConfig(BaseModel, frozen=True):
    """Configuration for the decoder component of the Dia model.

    Attributes:
        model_type: Type of the model, defaults to "dia_decoder".
        hidden_size: Size of the decoder layers, defaults to 2048.
        intermediate_size: Size of the "intermediate" (i.e., feed-forward) layer in the decoder, defaults to 8192.
        num_hidden_layers: Number of hidden layers in the decoder, defaults to 18.
        num_attention_heads: Number of attention heads in the decoder, defaults to 16.
        num_key_value_heads: Number of key-value heads in the decoder, defaults to 4.
        head_dim: Dimension of each attention head, defaults to 128.
        cross_hidden_size: Size of the cross-attention layers, defaults to 1024.
        cross_num_attention_heads: Number of attention heads in the cross-attention mechanism, defaults to 16.
        cross_num_key_value_heads: Number of key-value heads in the cross-attention mechanism, defaults to 16.
        cross_head_dim: Dimension of each cross-attention head, defaults to 128.
        hidden_act: Activation function in the decoder, defaults to "silu".
        max_position_embeddings: Maximum number of position embeddings in the decoder, defaults to 3072.
        initializer_range: Range for initializing weights in the decoder, defaults to 0.02.
        norm_eps: Epsilon value for normalization layers in the decoder, defaults to 1e-5.
        rope_theta: Theta value for RoPE in the decoder, defaults to 10000.0.
        rope_scaling: Optional scaling factor for RoPE in the decoder.
        vocab_size: Vocabulary size for the decoder, defaults to 1028.
        num_channels: Number of audio codebook channels in the decoder, defaults to 9.
    """

    cross_head_dim: int = Field(default=128, gt=0)
    cross_hidden_size: int = Field(default=1024, gt=0)
    cross_num_attention_heads: int = Field(default=16, gt=0)
    cross_num_key_value_heads: int = Field(default=16, gt=0)
    head_dim: int = Field(default=128, gt=0)
    hidden_act: str = Field(default="silu")
    hidden_size: int = Field(default=2048, gt=0)
    initializer_range: float = Field(default=0.02)
    intermediate_size: int = Field(default=8192, gt=0)
    max_position_embeddings: int = Field(default=3072, gt=0)
    model_type: str = Field(default="dia_decoder")
    norm_eps: float = Field(default=1e-5)
    num_attention_heads: int = Field(default=16, gt=0)
    num_channels: int = Field(default=9, gt=0)
    num_hidden_layers: int = Field(default=18, gt=0)
    num_key_value_heads: int = Field(default=4, gt=0)
    rope_scaling: float | None = Field(default=None)
    rope_theta: float = Field(default=10000.0)
    vocab_size: int = Field(default=1028, gt=0)


class DiaConfig(BaseModel, frozen=True):
    """Main configuration container for the Dia model architecture.

    Attributes:
        model_type: Type of the model, defaults to "dia".
        is_encoder_decoder: Flag indicating if the model is an encoder-decoder type, defaults to True.
        encoder_config: Configuration for the encoder component.
        decoder_config: Configuration for the decoder component.
        initializer_range: Range for initializing weights, defaults to 0.02.
        norm_eps: Epsilon value for normalization layers, defaults to 1e-5.
        torch_dtype: Data type for model weights in PyTorch, defaults to "float32".
        bos_token_id: Beginning-of-sequence token ID, defaults to 1026.
        eos_token_id: End-of-sequence token ID, defaults to 1024.
        pad_token_id: Padding token ID, defaults to 1025.
        transformers_version: Version of the transformers library, defaults to "4.53.0.dev0".
        architectures: List of model architectures, defaults to ["DiaForConditionalGeneration"].
        delay_pattern: List of delay values for each audio channel, defaults to [0, 8, 9, 10, 11, 12, 13, 14, 15].
    """

    architectures: list[str] = Field(default_factory=lambda: ["DiaForConditionalGeneration"])
    bos_token_id: int = Field(default=1026)
    decoder_config: DecoderConfig
    delay_pattern: list[int] = Field(default_factory=lambda: [0, 8, 9, 10, 11, 12, 13, 14, 15])
    encoder_config: EncoderConfig
    eos_token_id: int = Field(default=1024)
    initializer_range: float = Field(default=0.02)
    is_encoder_decoder: bool = Field(default=True)
    model_type: str = Field(default="dia")
    norm_eps: float = Field(default=1e-5)
    pad_token_id: int = Field(default=1025)
    torch_dtype: str = Field(default="float32")
    transformers_version: str = Field(default="4.53.0.dev0")

    def save(self, path: str) -> None:
        """Save the current configuration instance to a JSON file.

        Ensures the parent directory exists before writing.

        Args:
            path: The target file path to save the configuration.
        """
        os.makedirs(os.path.dirname(path), exist_ok=True)
        config_json = self.model_dump_json(indent=2)
        with open(path, "w") as f:
            f.write(config_json)

    @classmethod
    def load(cls, path: str) -> "DiaConfig | None":
        """Load and validate a Dia configuration from a JSON file.

        Args:
            path: The path to the configuration file.

        Returns:
            A validated DiaConfig instance if the file exists and is valid,
            otherwise None if the file is not found.

        Raises:
            pydantic.ValidationError: If the JSON content fails validation against the DiaConfig schema.
        """
        try:
            with open(path, "r") as f:
                content = f.read()
            return cls.model_validate_json(content)
        except FileNotFoundError:
            return None
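
A short sketch of how the configuration round-trips through save() and load() (illustrative only, not part of the uploaded file; it assumes the package is importable as dia and that the target directory is writable):

# Usage sketch for DiaConfig (illustrative, assumes the package imports as dia).
from dia.config import DecoderConfig, DiaConfig, EncoderConfig

config = DiaConfig(encoder_config=EncoderConfig(), decoder_config=DecoderConfig())
config.save("checkpoints/config.json")  # creates the parent directory if needed

loaded = DiaConfig.load("checkpoints/config.json")
assert loaded is not None
assert loaded.decoder_config.num_channels == len(loaded.delay_pattern)  # 9 channels, 9 delays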
layers.py ADDED
@@ -0,0 +1,888 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from huggingface_hub import PyTorchModelHubMixin
5
+ from torch import Tensor
6
+ from torch.nn import RMSNorm
7
+
8
+ from .config import DecoderConfig, DiaConfig, EncoderConfig
9
+ from .state import DecoderInferenceState, EncoderInferenceState, KVCache
10
+
11
+
12
+ def _normalize_axes(axes: tuple[int, ...], ndim: int) -> tuple[int, ...]:
13
+ return tuple(ax if ax >= 0 else ndim + ax for ax in axes)
14
+
15
+
16
+ class DenseGeneral(nn.Module):
17
+ """
18
+ PyTorch equivalent of flax.linen.DenseGeneral with shapes defined at init.
19
+ Stores the weight (`kernel`) in the same layout as JAX and uses torch.tensordot
+ for the generalized matrix multiplication. The weight shape is calculated and the
+ parameter is created during initialization based on the configured shapes.
+ Attributes:
+ axis (Tuple[int, ...]): Input axis or axes to contract.
+ in_shapes (Tuple[int, ...]): Sizes of the input dimensions specified by `axis`.
+ out_features (Tuple[int, ...]): Shape of the output features (non-contracted dims).
+ weight (nn.Parameter): The kernel parameter.
30
+ """
31
+
32
+ def __init__(
33
+ self,
34
+ in_shapes: tuple[int, ...],
35
+ out_features: tuple[int, ...],
36
+ axis: tuple[int, ...] = (-1,),
37
+ weight_dtype: torch.dtype | None = None,
38
+ device: torch.device | None = None,
39
+ ):
40
+ super().__init__()
41
+ self.in_shapes = in_shapes
42
+ self.out_features = out_features
43
+ self.axis = axis
44
+ self.kernel_shape = self.in_shapes + self.out_features
45
+
46
+ factory_kwargs = {"device": device, "dtype": weight_dtype}
47
+ self.weight = nn.Parameter(torch.empty(self.kernel_shape, **factory_kwargs))
48
+
49
+ def forward(self, inputs: Tensor) -> Tensor:
50
+ norm_axis = _normalize_axes(self.axis, inputs.ndim)
51
+ kernel_contract_axes = tuple(range(len(norm_axis)))
52
+
53
+ output = torch.tensordot(
54
+ inputs.to(self.weight.dtype),
55
+ self.weight,
56
+ dims=(norm_axis, kernel_contract_axes),
57
+ ).to(inputs.dtype)
58
+ return output
59
+
60
+
61
+ class MlpBlock(nn.Module):
62
+ """MLP block using DenseGeneral."""
63
+
64
+ def __init__(self, embed_dim: int, intermediate_dim: int, compute_dtype: torch.dtype):
65
+ super().__init__()
66
+ self.dtype = compute_dtype
67
+
68
+ self.wi_fused = DenseGeneral(
69
+ in_shapes=(embed_dim,),
70
+ out_features=(2, intermediate_dim),
71
+ axis=(-1,),
72
+ weight_dtype=compute_dtype,
73
+ )
74
+
75
+ self.wo = DenseGeneral(
76
+ in_shapes=(intermediate_dim,),
77
+ out_features=(embed_dim,),
78
+ axis=(-1,),
79
+ weight_dtype=compute_dtype,
80
+ )
81
+
82
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
83
+ """Forward pass."""
84
+ fused_x = self.wi_fused(x)
85
+
86
+ gate = fused_x[..., 0, :]
87
+ up = fused_x[..., 1, :]
88
+
89
+ hidden = torch.mul(F.silu(gate), up).to(self.dtype)
90
+
91
+ output = self.wo(hidden)
92
+ return output
93
+
94
+
95
+ class RotaryEmbedding(nn.Module):
96
+ """Rotary Position Embedding (RoPE) implementation in PyTorch."""
97
+
98
+ def __init__(
99
+ self,
100
+ embedding_dims: int,
101
+ min_timescale: float = 1.0,
102
+ max_timescale: float = 10000.0,
103
+ dtype: torch.dtype = torch.float32,
104
+ ):
105
+ super().__init__()
106
+ if embedding_dims % 2 != 0:
107
+ raise ValueError("Embedding dim must be even for RoPE.")
108
+ self.embedding_dims = embedding_dims
109
+ self.min_timescale = min_timescale
110
+ self.max_timescale = max_timescale
111
+ self.compute_dtype = dtype
112
+
113
+ half_embedding_dim = embedding_dims // 2
114
+ fraction = (2.0 * torch.arange(0, half_embedding_dim)) / embedding_dims
115
+ timescale = (self.min_timescale * (self.max_timescale / self.min_timescale) ** fraction).to(torch.float32)
116
+ self.register_buffer("timescale", timescale, persistent=False)
117
+
118
+ def forward(self, inputs: torch.Tensor, position: torch.Tensor):
119
+ """Applies RoPE."""
120
+ position = position.unsqueeze(-1).unsqueeze(-1)
121
+ sinusoid_inp = position / self.timescale
122
+ sin = torch.sin(sinusoid_inp)
123
+ cos = torch.cos(sinusoid_inp)
124
+ first_half, second_half = torch.chunk(inputs.to(torch.float32), 2, dim=-1)
125
+ first_part = first_half * cos - second_half * sin
126
+ second_part = second_half * cos + first_half * sin
127
+ return torch.cat(
128
+ (first_part.to(self.compute_dtype), second_part.to(self.compute_dtype)),
129
+ dim=-1,
130
+ )
131
+
132
+ def apply_rope(self, inputs: torch.Tensor, sin: torch.Tensor, cos: torch.Tensor):
133
+ first_half, second_half = torch.chunk(inputs.to(torch.float32), 2, dim=-1)
134
+ first_part = first_half * cos - second_half * sin
135
+ second_part = second_half * cos + first_half * sin
136
+ return torch.cat((first_part.to(self.compute_dtype), second_part.to(self.compute_dtype)), dim=-1)
137
+
138
+
139
+ def custom_scaled_dot_product_attention(
140
+ query: torch.Tensor,
141
+ key: torch.Tensor,
142
+ value: torch.Tensor,
143
+ attn_mask: torch.Tensor | None = None,
144
+ scale: float = 1.0,
145
+ is_causal: bool = False,
146
+ num_gqa_groups: int = 1,
147
+ ) -> torch.Tensor:
148
+ """
149
+ Custom scaled dot-product attention with GQA support for MPS compatibility.
150
+
151
+ Args:
152
+ query: (B, N_q, T, H) - Query tensor, N_q = num_query_heads
153
+ key: (B, N_kv, S, H) - Key tensor, N_kv = num_kv_heads
154
+ value: (B, N_kv, S, H) - Value tensor
155
+ attn_mask: (B, 1, T, S) - Attention mask, optional
156
+ scale: Scaling factor for attention scores
157
+ is_causal: If True, apply causal masking
158
+ num_gqa_groups: Number of query heads per KV head (N_q / N_kv)
159
+
160
+ Returns:
161
+ output: (B, N_q, T, H) - Attention output
162
+ """
163
+ B, N_q, T, H = query.shape
164
+ _, N_kv, S, _ = key.shape
165
+
166
+ # For GQA, repeat key and value tensors to match query heads
167
+ if num_gqa_groups > 1:
168
+ key = key.repeat_interleave(num_gqa_groups, dim=1) # (B, N_q, S, H)
169
+ value = value.repeat_interleave(num_gqa_groups, dim=1) # (B, N_q, S, H)
170
+
171
+ # Compute attention scores: (B, N_q, T, H) @ (B, N_q, H, S) -> (B, N_q, T, S)
172
+ scores = torch.matmul(query, key.transpose(-1, -2)) * scale
173
+
174
+ # Apply causal mask if needed
175
+ if is_causal:
176
+ causal_mask = torch.tril(torch.ones(T, S, dtype=torch.bool, device=query.device))
177
+ scores = scores.masked_fill(~causal_mask, float("-inf"))
178
+
179
+ # Apply attention mask if provided
180
+ if attn_mask is not None:
181
+ scores = scores.masked_fill(~attn_mask, float("-inf"))
182
+
183
+ # Softmax over the last dimension (S)
184
+ attn_weights = F.softmax(scores, dim=-1)
185
+
186
+ # Compute output: (B, N_q, T, S) @ (B, N_q, S, H) -> (B, N_q, T, H)
187
+ output = torch.matmul(attn_weights, value)
188
+
189
+ return output
190
+
191
+
192
+ class CrossAttention(nn.Module):
193
+ """Cross-Attention using DenseGeneral."""
194
+
195
+ def __init__(
196
+ self,
197
+ config: EncoderConfig | DecoderConfig,
198
+ q_embed_dim: int,
199
+ kv_embed_dim: int,
200
+ num_query_heads: int,
201
+ num_kv_heads: int,
202
+ head_dim: int,
203
+ compute_dtype: torch.dtype,
204
+ out_embed_dim: int | None = None,
205
+ ):
206
+ super().__init__()
207
+ self.num_query_heads = num_query_heads
208
+ self.num_kv_heads = num_kv_heads
209
+ self.head_dim = head_dim
210
+ self.output_dim = out_embed_dim if out_embed_dim is not None else q_embed_dim
211
+ self.projected_query_dim = num_query_heads * head_dim
212
+ if num_query_heads % num_kv_heads != 0:
213
+ raise ValueError(f"num_query_heads ({num_query_heads}) must be divisible by num_kv_heads ({num_kv_heads})")
214
+ self.num_gqa_groups = num_query_heads // num_kv_heads
215
+
216
+ # --- Projection Layers using DenseGeneral ---
217
+ self.q_proj = DenseGeneral(
218
+ in_shapes=(q_embed_dim,),
219
+ out_features=(num_query_heads, head_dim),
220
+ axis=(-1,),
221
+ weight_dtype=compute_dtype,
222
+ )
223
+ self.k_proj = DenseGeneral(
224
+ in_shapes=(kv_embed_dim,),
225
+ out_features=(num_kv_heads, head_dim),
226
+ axis=(-1,),
227
+ weight_dtype=compute_dtype,
228
+ )
229
+ self.v_proj = DenseGeneral(
230
+ in_shapes=(kv_embed_dim,),
231
+ out_features=(num_kv_heads, head_dim),
232
+ axis=(-1,),
233
+ weight_dtype=compute_dtype,
234
+ )
235
+ self.o_proj = DenseGeneral(
236
+ in_shapes=(num_query_heads, head_dim),
237
+ out_features=(self.output_dim,),
238
+ axis=(-2, -1),
239
+ weight_dtype=compute_dtype,
240
+ )
241
+
242
+ # --- Rotary Embedding ---
243
+ self.rotary_emb = RotaryEmbedding(
244
+ embedding_dims=self.head_dim,
245
+ max_timescale=config.rope_theta,
246
+ dtype=compute_dtype,
247
+ )
248
+
249
+ def forward(
250
+ self,
251
+ Xq: torch.Tensor, # (B, T, D) T = 1 in AR generation
252
+ q_positions: torch.Tensor, # (B, T)
253
+ kv_positions: torch.Tensor | None = None, # (B, S)
254
+ attn_mask: torch.Tensor | None = None, # None in Decoder Self Attention, Valid mask in Others
255
+ cache: KVCache | None = None, # None in Encoder, KVCache in Decoder
256
+ is_causal: bool = False,
257
+ ) -> torch.Tensor:
258
+ """
259
+ Performs attention calculation with optional KV caching.
260
+
261
+ Args:
262
+ Xq: Query tensor (B, T, D). T=1 during single-step decoding.
+ q_positions: Positions for queries (B, T).
+ kv_positions: Positions for keys/values (B, S). If None, uses q_positions.
+ attn_mask: Attention mask.
+ cache: KVCache holding the precomputed cross-attention keys/values.
+
+ Returns:
+ output: The attention output tensor (B, T, output_dim).
273
+ """
274
+ if kv_positions is None:
275
+ kv_positions = q_positions
276
+ original_dtype = Xq.dtype
277
+
278
+ Xq_BxTxNxH = self.q_proj(Xq)
279
+ Xq_BxNxTxH = Xq_BxTxNxH.transpose(1, 2)
280
+
281
+ attn_k: torch.Tensor | None = cache.k if cache is not None else None
282
+ attn_v: torch.Tensor | None = cache.v if cache is not None else None
283
+
284
+ # Use custom attention for MPS backend, otherwise use optimized PyTorch function
285
+ is_mps = Xq.device.type == "mps" and torch.backends.mps.is_available()
286
+ if is_mps:
287
+ attn_output = custom_scaled_dot_product_attention(
288
+ query=Xq_BxNxTxH,
289
+ key=attn_k,
290
+ value=attn_v,
291
+ attn_mask=attn_mask if not is_causal else None,
292
+ scale=1.0,
293
+ is_causal=is_causal,
294
+ num_gqa_groups=self.num_gqa_groups,
295
+ )
296
+ else:
297
+ attn_output = F.scaled_dot_product_attention(
298
+ Xq_BxNxTxH,
299
+ attn_k,
300
+ attn_v,
301
+ attn_mask=attn_mask if not is_causal else None,
302
+ scale=1.0,
303
+ enable_gqa=self.num_gqa_groups > 1,
304
+ is_causal=is_causal,
305
+ )
306
+
307
+ attn_output = attn_output.transpose(1, 2).contiguous() # (B, T, N, H)
308
+ output = self.o_proj(attn_output)
309
+
310
+ return output.to(original_dtype)
311
+
312
+
313
+ class FusedQKV(nn.Module):
314
+ def __init__(
315
+ self,
316
+ in_features: int,
317
+ out_features: int,
318
+ bias: bool = False,
319
+ num_q_heads: int = 1,
320
+ q_head_dim: int = 1,
321
+ num_kv_heads: int = 1,
322
+ kv_head_dim: int = 1,
323
+ ):
324
+ super().__init__()
325
+ self.num_q_heads = num_q_heads
326
+ self.q_head_dim = q_head_dim
327
+ self.num_kv_heads = num_kv_heads
328
+ self.kv_head_dim = kv_head_dim
329
+ self.q_output_dim = num_q_heads * q_head_dim
330
+ self.kv_output_dim = num_kv_heads * kv_head_dim
331
+ self.linear = nn.Linear(in_features, out_features, bias=bias)
332
+
333
+ def forward(self, inputs: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
334
+ x = self.linear(inputs)
335
+
336
+ q, k, v = x.split([self.q_output_dim, self.kv_output_dim, self.kv_output_dim], dim=-1)
337
+
338
+ q = q.reshape(q.shape[:-1] + (self.num_q_heads, self.q_head_dim))
339
+ k = k.reshape(k.shape[:-1] + (self.num_kv_heads, self.kv_head_dim))
340
+ v = v.reshape(v.shape[:-1] + (self.num_kv_heads, self.kv_head_dim))
341
+
342
+ return q, k, v
343
+
344
+
345
+ class SelfAttention(nn.Module):
346
+ """Attention using DenseGeneral."""
347
+
348
+ def __init__(
349
+ self,
350
+ config: EncoderConfig | DecoderConfig,
351
+ q_embed_dim: int,
352
+ kv_embed_dim: int,
353
+ num_query_heads: int,
354
+ num_kv_heads: int,
355
+ head_dim: int,
356
+ compute_dtype: torch.dtype,
357
+ out_embed_dim: int | None = None,
358
+ ):
359
+ super().__init__()
360
+ self.num_query_heads = num_query_heads
361
+ self.num_kv_heads = num_kv_heads
362
+ self.head_dim = head_dim
363
+ self.output_dim = out_embed_dim if out_embed_dim is not None else q_embed_dim
364
+ self.projected_query_dim = num_query_heads * head_dim
365
+ if num_query_heads % num_kv_heads != 0:
366
+ raise ValueError(f"num_query_heads ({num_query_heads}) must be divisible by num_kv_heads ({num_kv_heads})")
367
+ self.num_gqa_groups = num_query_heads // num_kv_heads
368
+ self.kv_embed_dim = kv_embed_dim
369
+ self.q_embed_dim = q_embed_dim
370
+
371
+ # --- Projection Layers using DenseGeneral ---
372
+ self.q_proj = DenseGeneral(
373
+ in_shapes=(q_embed_dim,),
374
+ out_features=(num_query_heads, head_dim),
375
+ axis=(-1,),
376
+ weight_dtype=compute_dtype,
377
+ )
378
+ self.k_proj = DenseGeneral(
379
+ in_shapes=(kv_embed_dim,),
380
+ out_features=(num_kv_heads, head_dim),
381
+ axis=(-1,),
382
+ weight_dtype=compute_dtype,
383
+ )
384
+ self.v_proj = DenseGeneral(
385
+ in_shapes=(kv_embed_dim,),
386
+ out_features=(num_kv_heads, head_dim),
387
+ axis=(-1,),
388
+ weight_dtype=compute_dtype,
389
+ )
390
+ self.o_proj = DenseGeneral(
391
+ in_shapes=(num_query_heads, head_dim),
392
+ out_features=(self.output_dim,),
393
+ axis=(-2, -1),
394
+ weight_dtype=compute_dtype,
395
+ )
396
+
397
+ # --- Rotary Embedding ---
398
+ self.rotary_emb = RotaryEmbedding(
399
+ embedding_dims=self.head_dim,
400
+ max_timescale=config.rope_theta,
401
+ dtype=compute_dtype,
402
+ )
403
+
404
+ self.is_fused_qkv = False
405
+
406
+ def get_linear_weight(self, dense: DenseGeneral):
407
+ W_dg = dense.weight.data
408
+
409
+ out_features = 1
410
+ input_features = 1
411
+ for dim in dense.out_features:
412
+ out_features *= dim
413
+ for dim in dense.in_shapes:
414
+ input_features *= dim
415
+
416
+ W_dg_reshaped_for_linear_T = W_dg.reshape(input_features, out_features)
417
+ linear_weight = W_dg_reshaped_for_linear_T.transpose(0, 1).contiguous()
418
+ return linear_weight
419
+
420
+ def patch_fused_qkv(self):
421
+ q_proj_weight = self.get_linear_weight(self.q_proj)
422
+ k_proj_weight = self.get_linear_weight(self.k_proj)
423
+ v_proj_weight = self.get_linear_weight(self.v_proj)
424
+
425
+ self.qkv = FusedQKV(
426
+ self.kv_embed_dim,
427
+ (self.num_query_heads * self.head_dim + 2 * (self.num_kv_heads * self.head_dim)),
428
+ bias=False,
429
+ num_q_heads=self.num_query_heads,
430
+ q_head_dim=self.head_dim,
431
+ num_kv_heads=self.num_kv_heads,
432
+ kv_head_dim=self.head_dim,
433
+ )
434
+ self.qkv.linear.weight.data = torch.cat([q_proj_weight, k_proj_weight, v_proj_weight], dim=0)
435
+
436
+ # print(f"qkv.weight.shape: {self.qkv.linear.weight.shape}")
437
+ self.is_fused_qkv = True
438
+
439
+ def forward(
440
+ self,
441
+ X: torch.Tensor, # (B, T, D) T = 1 in AR generation
442
+ q_positions: torch.Tensor, # (B, T)
443
+ kv_positions: torch.Tensor | None = None, # (B, S)
444
+ attn_mask: torch.Tensor | None = None, # None in Decoder Self Attention, Valid mask in Others
445
+ cache: KVCache | None = None, # None in Encoder, KVCache in Decoder
446
+ prefill: bool = False,
447
+ is_causal: bool = False,
448
+ current_idx: torch.Tensor | None = None,
449
+ ) -> torch.Tensor:
450
+ """
451
+ Performs attention calculation with optional KV caching.
452
+ Args:
453
+ X: Input tensor (B, T, D) used for queries, keys, and values. T=1 during single-step decoding.
+ q_positions: Positions for queries (B, T).
+ kv_positions: Positions for keys/values (B, S). If None, uses q_positions.
+ attn_mask: Attention mask.
+ cache: KVCache (None in the encoder, the layer's self-attention cache in the decoder).
+ prefill: If True, write the full K/V sequence into the cache.
+ current_idx: Cache write position for single-step decoding.
+ Returns:
+ output: The attention output tensor (B, T, output_dim).
464
+ """
465
+ if kv_positions is None:
466
+ kv_positions = q_positions
467
+
468
+ original_dtype = X.dtype
469
+
470
+ if self.is_fused_qkv:
471
+ Xq_BxTxNxH, Xk_BxSxKxH, Xv_BxSxKxH = self.qkv(X)
472
+ else:
473
+ Xq_BxTxNxH = self.q_proj(X)
474
+ Xk_BxSxKxH = self.k_proj(X)
475
+ Xv_BxSxKxH = self.v_proj(X)
476
+
477
+ position = q_positions.unsqueeze(-1).unsqueeze(-1)
478
+ sinusoid_inp = position / self.rotary_emb.timescale
479
+ sin = torch.sin(sinusoid_inp)
480
+ cos = torch.cos(sinusoid_inp)
481
+
482
+ Xq_BxTxNxH = self.rotary_emb.apply_rope(Xq_BxTxNxH, sin, cos)
483
+ Xk_BxSxKxH = self.rotary_emb.apply_rope(Xk_BxSxKxH, sin, cos)
484
+
485
+ Xq_BxNxTxH = Xq_BxTxNxH.transpose(1, 2)
486
+
487
+ attn_k: torch.Tensor | None = cache.k if cache is not None else None
488
+ attn_v: torch.Tensor | None = cache.v if cache is not None else None
489
+
490
+ Xk_BxKxSxH = Xk_BxSxKxH.transpose(1, 2) # (B, K, S, H)
491
+ Xv_BxKxSxH = Xv_BxSxKxH.transpose(1, 2) # (B, K, S, H)
492
+
493
+ if cache is None:
494
+ attn_k = Xk_BxKxSxH
495
+ attn_v = Xv_BxKxSxH
496
+ elif prefill:
497
+ attn_k, attn_v = Xk_BxKxSxH, Xv_BxKxSxH
498
+ cache.prefill(attn_k, attn_v)
499
+ else:
500
+ attn_k, attn_v = cache.update(Xk_BxKxSxH, Xv_BxKxSxH, current_idx)
501
+
502
+ # Use custom attention for MPS backend, otherwise use optimized PyTorch function
503
+ is_mps = Xv_BxSxKxH.device.type == "mps" and torch.backends.mps.is_available()
504
+ if is_mps:
505
+ attn_output = custom_scaled_dot_product_attention(
506
+ query=Xq_BxNxTxH,
507
+ key=attn_k,
508
+ value=attn_v,
509
+ attn_mask=attn_mask if not is_causal else None,
510
+ scale=1.0,
511
+ is_causal=is_causal,
512
+ num_gqa_groups=self.num_gqa_groups,
513
+ )
514
+ else:
515
+ attn_output = F.scaled_dot_product_attention(
516
+ Xq_BxNxTxH,
517
+ attn_k,
518
+ attn_v,
519
+ attn_mask=attn_mask if not is_causal else None,
520
+ scale=1.0,
521
+ enable_gqa=self.num_gqa_groups > 1,
522
+ is_causal=is_causal,
523
+ )
524
+
525
+ attn_output = attn_output.transpose(1, 2).contiguous() # (B, T, N, H)
526
+ output = self.o_proj(attn_output)
527
+
528
+ return output.to(original_dtype)
529
+
530
+
531
+ class EncoderLayer(nn.Module):
532
+ """Transformer Encoder Layer using DenseGeneral."""
533
+
534
+ def __init__(self, config: DiaConfig, compute_dtype: torch.dtype):
535
+ super().__init__()
536
+ self.config = config
537
+ enc_config = config.encoder_config
538
+ embed_dim = enc_config.hidden_size
539
+ self.compute_dtype = compute_dtype
540
+
541
+ self.pre_sa_norm = RMSNorm(
542
+ embed_dim,
543
+ eps=enc_config.norm_eps,
544
+ dtype=torch.float32,
545
+ )
546
+ self.self_attention = SelfAttention(
547
+ enc_config,
548
+ q_embed_dim=embed_dim,
549
+ kv_embed_dim=embed_dim,
550
+ num_query_heads=enc_config.num_attention_heads,
551
+ num_kv_heads=enc_config.num_key_value_heads,
552
+ head_dim=enc_config.head_dim,
553
+ compute_dtype=compute_dtype,
554
+ out_embed_dim=embed_dim,
555
+ )
556
+ self.post_sa_norm = RMSNorm(
557
+ embed_dim,
558
+ eps=enc_config.norm_eps,
559
+ dtype=torch.float32,
560
+ )
561
+ self.mlp = MlpBlock(
562
+ embed_dim=embed_dim,
563
+ intermediate_dim=enc_config.intermediate_size,
564
+ compute_dtype=compute_dtype,
565
+ )
566
+
567
+ def forward(
568
+ self,
569
+ x: torch.Tensor,
570
+ state: EncoderInferenceState,
571
+ ) -> torch.Tensor:
572
+ residual = x
573
+ x_norm = self.pre_sa_norm(x).to(self.compute_dtype)
574
+
575
+ sa_out = self.self_attention(
576
+ X=x_norm,
577
+ q_positions=state.positions,
578
+ kv_positions=state.positions,
579
+ attn_mask=state.attn_mask,
580
+ )
581
+ x = residual + sa_out
582
+
583
+ residual = x
584
+ x_norm = self.post_sa_norm(x).to(self.compute_dtype)
585
+ mlp_out = self.mlp(x_norm)
586
+ x = residual + mlp_out
587
+
588
+ return x
589
+
590
+
591
+ class Encoder(nn.Module):
592
+ """Transformer Encoder Stack using DenseGeneral."""
593
+
594
+ def __init__(self, config: DiaConfig, compute_dtype: torch.dtype):
595
+ super().__init__()
596
+ self.config = config
597
+ enc_config = config.encoder_config
598
+ self.compute_dtype = compute_dtype
599
+
600
+ self.embedding = nn.Embedding(
601
+ enc_config.vocab_size,
602
+ enc_config.hidden_size,
603
+ dtype=compute_dtype,
604
+ )
605
+ self.layers = nn.ModuleList([EncoderLayer(config, compute_dtype) for _ in range(enc_config.num_hidden_layers)])
606
+ self.norm = RMSNorm(
607
+ enc_config.hidden_size,
608
+ eps=enc_config.norm_eps,
609
+ dtype=torch.float32,
610
+ )
611
+
612
+ def forward(
613
+ self,
614
+ x_ids: torch.Tensor,
615
+ state: EncoderInferenceState,
616
+ ) -> torch.Tensor:
617
+ x = self.embedding(x_ids)
618
+
619
+ for layer in self.layers:
620
+ x = layer(x, state)
621
+
622
+ x = self.norm(x).to(self.compute_dtype)
623
+ return x
624
+
625
+
626
+ class DecoderLayer(nn.Module):
627
+ """Transformer Decoder Layer using DenseGeneral."""
628
+
629
+ def __init__(self, config: DiaConfig, compute_dtype: torch.dtype):
630
+ super().__init__()
631
+ self.config = config
632
+ dec_config = config.decoder_config
633
+ enc_config = config.encoder_config
634
+ dec_embed_dim = dec_config.hidden_size
635
+ enc_embed_dim = enc_config.hidden_size
636
+ self.compute_dtype = compute_dtype
637
+
638
+ # Norms
639
+ self.pre_sa_norm = RMSNorm(
640
+ dec_embed_dim,
641
+ eps=dec_config.norm_eps,
642
+ dtype=torch.float32,
643
+ )
644
+ self.pre_ca_norm = RMSNorm(
645
+ dec_embed_dim,
646
+ eps=dec_config.norm_eps,
647
+ dtype=torch.float32,
648
+ )
649
+ self.pre_mlp_norm = RMSNorm(
650
+ dec_embed_dim,
651
+ eps=dec_config.norm_eps,
652
+ dtype=torch.float32,
653
+ )
654
+
655
+ # Self-Attention (GQA) with Causal Masking
656
+ self.self_attention = SelfAttention(
657
+ dec_config,
658
+ q_embed_dim=dec_embed_dim,
659
+ kv_embed_dim=dec_embed_dim,
660
+ num_query_heads=dec_config.num_attention_heads,
661
+ num_kv_heads=dec_config.num_key_value_heads,
662
+ head_dim=dec_config.head_dim,
663
+ compute_dtype=compute_dtype,
664
+ out_embed_dim=dec_embed_dim,
665
+ )
666
+ # Cross-Attention (MHA)
667
+ self.cross_attention = CrossAttention(
668
+ dec_config,
669
+ q_embed_dim=dec_embed_dim,
670
+ kv_embed_dim=enc_embed_dim, # Note kv_embed_dim
671
+ num_query_heads=dec_config.cross_num_attention_heads,
672
+ num_kv_heads=dec_config.cross_num_key_value_heads,
673
+ head_dim=dec_config.cross_head_dim,
674
+ compute_dtype=compute_dtype,
675
+ out_embed_dim=dec_embed_dim,
676
+ )
677
+ # MLP
678
+ self.mlp = MlpBlock(
679
+ embed_dim=dec_embed_dim,
680
+ intermediate_dim=dec_config.intermediate_size,
681
+ compute_dtype=compute_dtype,
682
+ )
683
+
684
+ def forward(
685
+ self,
686
+ x: torch.Tensor,
687
+ state: DecoderInferenceState,
688
+ self_attn_cache: KVCache | None = None,
689
+ cross_attn_cache: KVCache | None = None,
690
+ prefill: bool = False,
691
+ current_idx: int = 0,
692
+ ) -> torch.Tensor:
693
+ residual = x
694
+ x_norm = self.pre_sa_norm(x).to(self.compute_dtype)
695
+
696
+ self_attn_mask = state.casual_attn_mask[None, None, current_idx]
697
+
698
+ sa_out = self.self_attention(
699
+ X=x_norm, # (2, 1, D)
700
+ q_positions=state.dec_positions, # (2, 1)
701
+ kv_positions=state.dec_positions, # (2, 1)
702
+ attn_mask=self_attn_mask,
703
+ cache=self_attn_cache,
704
+ prefill=prefill,
705
+ is_causal=prefill,
706
+ current_idx=current_idx,
707
+ )
708
+
709
+ x = residual + sa_out
710
+
711
+ residual = x
712
+ x_norm = self.pre_ca_norm(x).to(self.compute_dtype)
713
+ ca_out = self.cross_attention(
714
+ Xq=x_norm,
715
+ q_positions=state.dec_positions,
716
+ kv_positions=state.enc_positions,
717
+ attn_mask=state.cross_attn_mask,
718
+ cache=cross_attn_cache,
719
+ )
720
+ x = residual + ca_out
721
+
722
+ residual = x
723
+ x_norm = self.pre_mlp_norm(x).to(self.compute_dtype)
724
+ mlp_out = self.mlp(x_norm)
725
+ x = residual + mlp_out
726
+
727
+ return x
728
+
729
+
730
+ class Decoder(nn.Module):
731
+ """Transformer Decoder Stack using DenseGeneral."""
732
+
733
+ def __init__(self, config: DiaConfig, compute_dtype: torch.dtype):
734
+ super().__init__()
735
+ self.config = config
736
+ dec_config = config.decoder_config
737
+ self.num_channels = dec_config.num_channels
738
+ self.num_layers = dec_config.num_hidden_layers
739
+
740
+ self.embeddings = nn.ModuleList(
741
+ [
742
+ nn.Embedding(dec_config.vocab_size, dec_config.hidden_size, dtype=compute_dtype)
743
+ for _ in range(self.num_channels)
744
+ ]
745
+ )
746
+ self.layers = nn.ModuleList(
747
+ [DecoderLayer(config=config, compute_dtype=compute_dtype) for _ in range(self.num_layers)]
748
+ )
749
+
750
+ self.norm = RMSNorm(
751
+ dec_config.hidden_size,
752
+ eps=dec_config.norm_eps,
753
+ dtype=torch.float32,
754
+ )
755
+
756
+ self.logits_dense = DenseGeneral(
757
+ in_shapes=(dec_config.hidden_size,),
758
+ out_features=(self.num_channels, dec_config.vocab_size),
759
+ axis=(-1,),
760
+ weight_dtype=compute_dtype,
761
+ )
762
+
763
+ def precompute_cross_attn_cache(
764
+ self,
765
+ enc_out: torch.Tensor, # (B, S, E)
766
+ ) -> list[KVCache]:
767
+ """
768
+ Computes the Key and Value tensors for cross-attention for each layer from the encoder output.
769
+ """
770
+ per_layer_kv_cache: list[KVCache] = []
771
+
772
+ for layer in self.layers:
773
+ cross_attn_module = layer.cross_attention
774
+ k_proj = cross_attn_module.k_proj(enc_out)
775
+ v_proj = cross_attn_module.v_proj(enc_out)
776
+
777
+ k = k_proj.transpose(1, 2)
778
+ v = v_proj.transpose(1, 2)
779
+
780
+ per_layer_kv_cache.append(KVCache.from_kv(k, v))
781
+
782
+ return per_layer_kv_cache
783
+
784
+ def decode_step(
785
+ self,
786
+ tgt_ids_Bx1xC: torch.Tensor, # [B, 1, C]
787
+ state: DecoderInferenceState,
788
+ current_idx: int,
789
+ ) -> torch.Tensor:
790
+ """
791
+ Performs a single decoding step, managing KV caches layer by layer.
792
+ Returns:
+ logits_Bx1xCxV: The output logits for the current step (B, 1, C, V), cast to float32.
795
+ """
796
+
797
+ x = None
798
+ for i in range(self.num_channels):
799
+ channel_tokens = tgt_ids_Bx1xC[..., i]
800
+ channel_embed = self.embeddings[i](channel_tokens)
801
+ x = channel_embed if x is None else x + channel_embed
802
+
803
+ for i, layer in enumerate(self.layers):
804
+ self_cache = state.self_attn_cache[i]
805
+ cross_cache = state.cross_attn_cache[i]
806
+ x = layer(
807
+ x, # (2, 1, D)
808
+ state,
809
+ self_attn_cache=self_cache,
810
+ cross_attn_cache=cross_cache,
811
+ current_idx=current_idx,
812
+ )
813
+
814
+ x = self.norm(x)
815
+ logits_Bx1xCxV = self.logits_dense(x)
816
+
817
+ return logits_Bx1xCxV.to(torch.float32)
818
+
819
+ def forward(self, tgt_ids_BxTxC: torch.Tensor, state: DecoderInferenceState) -> torch.Tensor:
820
+ """
821
+ Forward pass for the Decoder stack, managing KV caches.
822
+ Args:
823
+ tgt_ids_BxTxC: Target token IDs (B, T, C).
824
+ state: DecoderInferenceState holding the encoder output, positions, attention
+ masks, and the per-layer self-/cross-attention KV caches. This forward pass
+ runs every layer in prefill mode, filling the self-attention caches.
+ Returns:
+ logits: The final output logits (B, T, C, V), cast to float32.
840
+ """
841
+ _, _, num_channels_in = tgt_ids_BxTxC.shape
842
+ assert num_channels_in == self.num_channels, "Input channels mismatch"
843
+
844
+ # Embeddings
845
+ x = None
846
+ for i in range(self.num_channels):
847
+ channel_tokens = tgt_ids_BxTxC[..., i]
848
+ channel_embed = self.embeddings[i](channel_tokens)
849
+ x = channel_embed if x is None else x + channel_embed
850
+
851
+ for i, layer in enumerate(self.layers):
852
+ self_cache = state.self_attn_cache[i]
853
+ cross_cache = state.cross_attn_cache[i]
854
+ x = layer(
855
+ x,
856
+ state,
857
+ self_attn_cache=self_cache,
858
+ cross_attn_cache=cross_cache,
859
+ prefill=True,
860
+ )
861
+
862
+ # Final Norm
863
+ x = self.norm(x)
864
+ logits_BxTxCxV = self.logits_dense(x)
865
+
866
+ return logits_BxTxCxV.to(torch.float32)
867
+
868
+
869
+ class DiaModel(
870
+ nn.Module,
871
+ PyTorchModelHubMixin,
872
+ repo_url="https://github.com/nari-labs/dia",
873
+ pipeline_tag="text-to-speech",
874
+ license="apache-2.0",
875
+ coders={
876
+ DiaConfig: (
877
+ lambda x: x.model_dump(),
878
+ lambda data: DiaConfig.model_validate(data),
879
+ ),
880
+ },
881
+ ):
882
+ """PyTorch Dia Model using DenseGeneral."""
883
+
884
+ def __init__(self, config: DiaConfig, compute_dtype: torch.dtype):
885
+ super().__init__()
886
+ self.config = config
887
+ self.encoder = Encoder(config, compute_dtype)
888
+ self.decoder = Decoder(config, compute_dtype)
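
The custom_scaled_dot_product_attention fallback above exists for MPS compatibility. The sketch below (illustrative only, not part of the uploaded file) checks on CPU that the fallback agrees with PyTorch's fused kernel for a grouped-query shape matching the decoder defaults (16 query heads, 4 KV heads). It assumes the package is importable as dia and a PyTorch version that supports the enable_gqa flag, which the layer code itself already uses:

# Equivalence check for the MPS attention fallback (illustrative, assumes the package imports as dia).
import torch
import torch.nn.functional as F

from dia.layers import custom_scaled_dot_product_attention

B, N_q, N_kv, T, S, H = 2, 16, 4, 5, 7, 64
q = torch.randn(B, N_q, T, H)
k = torch.randn(B, N_kv, S, H)
v = torch.randn(B, N_kv, S, H)

mask = torch.ones(B, 1, T, S, dtype=torch.bool)  # True = attend
mask[..., -1] = False  # block the last key position everywhere

ref = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, scale=1.0, enable_gqa=True)
out = custom_scaled_dot_product_attention(q, k, v, attn_mask=mask, scale=1.0, num_gqa_groups=N_q // N_kv)
assert torch.allclose(out, ref, atol=1e-5)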
model.py ADDED
@@ -0,0 +1,802 @@
1
+ import time
2
+ from enum import Enum
3
+ from typing import Callable
4
+
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn.functional as F
8
+ import torchaudio
9
+
10
+ from .audio import apply_audio_delay, build_delay_indices, build_revert_indices, revert_audio_delay
11
+ from .config import DiaConfig
12
+ from .layers import DiaModel
13
+ from .state import DecoderInferenceState, DecoderOutput, EncoderInferenceState
14
+
15
+
16
+ DEFAULT_SAMPLE_RATE = 44100
17
+ SAMPLE_RATE_RATIO = 512
18
+
19
+
20
+ def _get_default_device():
21
+ if torch.cuda.is_available():
22
+ return torch.device("cuda")
23
+ elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
24
+ return torch.device("mps")
25
+ return torch.device("cpu")
26
+
27
+
28
+ def _sample_next_token(
29
+ logits_BCxV: torch.Tensor,
30
+ temperature: float,
31
+ top_p: float,
32
+ top_k: int | None,
33
+ audio_eos_value: int,
34
+ ) -> torch.Tensor:
35
+ if temperature == 0.0:
36
+ return torch.argmax(logits_BCxV, dim=-1)
37
+
38
+ logits_BCxV = logits_BCxV / temperature
39
+
40
+ if audio_eos_value is not None and audio_eos_value >= 0:
41
+ top_logit_indices_BC = torch.argmax(logits_BCxV, dim=-1)
42
+ eos_not_highest_mask_BC = top_logit_indices_BC != audio_eos_value
43
+ mask_eos_unless_highest_BCxV = torch.zeros_like(logits_BCxV, dtype=torch.bool)
44
+ mask_eos_unless_highest_BCxV[eos_not_highest_mask_BC, audio_eos_value] = True
45
+ logits_BCxV = logits_BCxV.masked_fill(mask_eos_unless_highest_BCxV, -torch.inf)
46
+ eos_highest_mask_BC = top_logit_indices_BC == audio_eos_value
47
+ mask_eos_highest_BCxV = torch.zeros_like(logits_BCxV, dtype=torch.bool)
48
+ mask_eos_highest_BCxV[eos_highest_mask_BC, :audio_eos_value] = True
49
+ logits_BCxV = logits_BCxV.masked_fill(mask_eos_highest_BCxV, -torch.inf)
50
+
51
+ if top_k is not None:
52
+ _, top_k_indices_BCxV = torch.topk(logits_BCxV, k=top_k, dim=-1)
53
+ mask = torch.ones_like(logits_BCxV, dtype=torch.bool)
54
+ mask = mask.scatter(dim=-1, index=top_k_indices_BCxV, value=False)
55
+ logits_BCxV = logits_BCxV.masked_fill(mask, -torch.inf)
56
+
57
+ if top_p < 1.0:
58
+ probs_BCxV = torch.softmax(logits_BCxV, dim=-1)
59
+ sorted_probs_BCxV, sorted_indices_BCxV = torch.sort(probs_BCxV, dim=-1, descending=True)
60
+ cumulative_probs_BCxV = torch.cumsum(sorted_probs_BCxV, dim=-1)
61
+
62
+ sorted_indices_to_remove_BCxV = cumulative_probs_BCxV > top_p
63
+ sorted_indices_to_remove_BCxV = torch.roll(sorted_indices_to_remove_BCxV, shifts=1, dims=-1)
64
+ sorted_indices_to_remove_BCxV[..., 0] = torch.zeros_like(sorted_indices_to_remove_BCxV[..., 0])
65
+
66
+ indices_to_remove_BCxV = torch.zeros_like(sorted_indices_to_remove_BCxV)
67
+ indices_to_remove_BCxV = indices_to_remove_BCxV.scatter(
68
+ dim=-1, index=sorted_indices_BCxV, src=sorted_indices_to_remove_BCxV
69
+ )
70
+ logits_BCxV = logits_BCxV.masked_fill(indices_to_remove_BCxV, -torch.inf)
71
+
72
+ final_probs_BCxV = torch.softmax(logits_BCxV, dim=-1)
73
+
74
+ sampled_indices_BC = torch.multinomial(final_probs_BCxV, num_samples=1)
75
+ sampled_indices_C = sampled_indices_BC.squeeze(-1)
76
+ return sampled_indices_C
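
A quick sketch of calling the sampler above on dummy logits (illustrative only, not part of the uploaded file; _sample_next_token is module-private, so the call assumes it runs inside this module, and the channel count, vocabulary size, and EOS id follow the decoder defaults in config.py):

# Sampling sketch (illustrative): one draw per audio channel from random logits.
import torch

num_channels, vocab_size = 9, 1028
logits_CxV = torch.randn(num_channels, vocab_size)

next_token_per_channel = _sample_next_token(
    logits_CxV,
    temperature=1.2,
    top_p=0.95,
    top_k=45,
    audio_eos_value=1024,
)
print(next_token_per_channel.shape)  # torch.Size([9])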
77
+
78
+
79
+ class ComputeDtype(str, Enum):
80
+ FLOAT32 = "float32"
81
+ FLOAT16 = "float16"
82
+ BFLOAT16 = "bfloat16"
83
+
84
+ def to_dtype(self) -> torch.dtype:
85
+ if self == ComputeDtype.FLOAT32:
86
+ return torch.float32
87
+ elif self == ComputeDtype.FLOAT16:
88
+ return torch.float16
89
+ elif self == ComputeDtype.BFLOAT16:
90
+ return torch.bfloat16
91
+ else:
92
+ raise ValueError(f"Unsupported compute dtype: {self}")
93
+
94
+
95
+ class Dia:
96
+ def __init__(
97
+ self,
98
+ config: DiaConfig,
99
+ compute_dtype: str | ComputeDtype = ComputeDtype.FLOAT32,
100
+ device: torch.device | None = None,
101
+ load_dac: bool = True,
102
+ ):
103
+ """Initializes the Dia model.
104
+
105
+ Args:
106
+ config: The configuration object for the model.
107
+ compute_dtype: The computation dtype to use.
108
+ device: The device to load the model onto. If None, will automatically select the best available device.
109
+ load_dac: Whether to load the DAC model.
110
+
111
+ Raises:
112
+ RuntimeError: If there is an error loading the DAC model.
113
+ """
114
+ super().__init__()
115
+ self.config = config
116
+ self.device = device if device is not None else _get_default_device()
117
+ if isinstance(compute_dtype, str):
118
+ compute_dtype = ComputeDtype(compute_dtype)
119
+ self.compute_dtype = compute_dtype.to_dtype()
120
+ self.model: DiaModel = DiaModel(config, self.compute_dtype)
121
+ self.dac_model = None
122
+ self._compiled_step = None
123
+ self.load_dac = load_dac
124
+
125
+ if not self.load_dac:
126
+ print("Warning: DAC model will not be loaded. This is not recommended.")
127
+
128
+ if torch.cuda.is_available():
129
+ torch.backends.cuda.matmul.allow_tf32 = True
130
+
131
+ @classmethod
132
+ def from_local(
133
+ cls,
134
+ config_path: str,
135
+ checkpoint_path: str,
136
+ compute_dtype: str | ComputeDtype = ComputeDtype.FLOAT32,
137
+ device: torch.device | None = None,
138
+ load_dac: bool = True,
139
+ ) -> "Dia":
140
+ """Loads the Dia model from local configuration and checkpoint files.
141
+
142
+ Args:
143
+ config_path: Path to the configuration JSON file.
144
+ checkpoint_path: Path to the model checkpoint (.pth) file.
145
+ compute_dtype: The computation dtype to use.
146
+ device: The device to load the model onto. If None, will automatically select the best available device.
147
+ load_dac: Whether to load the DAC model.
148
+
149
+ Returns:
150
+ An instance of the Dia model loaded with weights and set to eval mode.
151
+
152
+ Raises:
153
+ FileNotFoundError: If the config or checkpoint file is not found.
154
+ RuntimeError: If there is an error loading the checkpoint.
155
+ """
156
+ config = DiaConfig.load(config_path)
157
+ if config is None:
158
+ raise FileNotFoundError(f"Config file not found at {config_path}")
159
+
160
+ dia = cls(config, compute_dtype, device, load_dac)
161
+
162
+ try:
163
+ state_dict = torch.load(checkpoint_path, map_location=dia.device)
164
+ dia.model.load_state_dict(state_dict)
165
+ except FileNotFoundError:
166
+ raise FileNotFoundError(f"Checkpoint file not found at {checkpoint_path}")
167
+ except Exception as e:
168
+ raise RuntimeError(f"Error loading checkpoint from {checkpoint_path}") from e
169
+
170
+ dia.model.to(dia.device)
171
+ dia.model.eval()
172
+ if load_dac:
173
+ dia._load_dac_model()
174
+ return dia
175
+
176
+ @classmethod
177
+ def from_pretrained(
178
+ cls,
179
+ model_name: str = "nari-labs/Dia-1.6B-0626",
180
+ compute_dtype: str | ComputeDtype = ComputeDtype.FLOAT32,
181
+ device: torch.device | None = None,
182
+ load_dac: bool = True,
183
+ ) -> "Dia":
184
+ """Loads the Dia model from a Hugging Face Hub repository.
185
+
186
+ Downloads the configuration and checkpoint files from the specified
187
+ repository ID and then loads the model.
188
+
189
+ Args:
190
+ model_name: The Hugging Face Hub repository ID (e.g., "nari-labs/Dia-1.6B-0626").
191
+ compute_dtype: The computation dtype to use.
192
+ device: The device to load the model onto. If None, will automatically select the best available device.
193
+ load_dac: Whether to load the DAC model.
194
+
195
+ Returns:
196
+ An instance of the Dia model loaded with weights and set to eval mode.
197
+
198
+ Raises:
199
+ FileNotFoundError: If config or checkpoint download/loading fails.
200
+ RuntimeError: If there is an error loading the checkpoint.
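+
+ Example (illustrative; the output file name is a placeholder):
+ >>> dia = Dia.from_pretrained("nari-labs/Dia-1.6B-0626")
+ >>> audio = dia.generate("[S1] Hello there. [S2] Hi, how are you?")
+ >>> dia.save_audio("hello.wav", audio)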
201
+ """
202
+ if isinstance(compute_dtype, str):
203
+ compute_dtype = ComputeDtype(compute_dtype)
204
+
205
+ # Load model directly using DiaModel's from_pretrained which handles HF download
206
+ try:
207
+ loaded_model = DiaModel.from_pretrained(model_name, compute_dtype=compute_dtype.to_dtype())
208
+ except Exception as e:
209
+ raise RuntimeError(f"Error loading model from Hugging Face Hub ({model_name})") from e
210
+
211
+ config = loaded_model.config # Get config from the loaded model
212
+ dia = cls(config, compute_dtype, device, load_dac)
213
+
214
+ dia.model = loaded_model # Assign the already loaded model
215
+ dia.model.to(dia.device)
216
+ dia.model.eval()
217
+ if load_dac:
218
+ dia._load_dac_model()
219
+ return dia
220
+
221
+ def _load_dac_model(self):
222
+ """Loads the Descript Audio Codec (DAC) model.
223
+
224
+ Downloads the DAC model if necessary and loads it onto the specified device.
225
+ Sets the DAC model to evaluation mode.
226
+
227
+ Raises:
228
+ RuntimeError: If downloading or loading the DAC model fails.
229
+ """
230
+ import dac
231
+
232
+ try:
233
+ dac_model_path = dac.utils.download()
234
+ dac_model = dac.DAC.load(dac_model_path).to(self.device)
235
+ dac_model.eval() # Ensure DAC is in eval mode
236
+ except Exception as e:
237
+ raise RuntimeError("Failed to load DAC model") from e
238
+ self.dac_model = dac_model
239
+
240
+ def _encode_text(self, text: str) -> torch.Tensor:
241
+ """Encodes the input text string into a tensor of token IDs using byte-level encoding.
242
+
243
+ Special tokens [S1] and [S2] are replaced by their byte values. The resulting
244
+ sequence is truncated to the maximum configured text length.
245
+
246
+ Args:
247
+ text: The input text string.
248
+
249
+ Returns:
250
+ A tensor containing the encoded byte token IDs.
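+
+ Example (illustrative): "[S1] Hi" is encoded byte-wise with the [S1] tag
+ mapped to byte 1, giving token IDs [1, 32, 72, 105].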
251
+ """
252
+ max_len = self.config.encoder_config.max_position_embeddings
253
+
254
+ byte_text = text.encode("utf-8")
255
+ # Replace the [S1]/[S2] speaker tags with single byte values before
256
+ # byte-level tokenization (1 for [S1], 2 for [S2]).
257
+ replaced_bytes = byte_text.replace(b"[S1]", b"\x01").replace(b"[S2]", b"\x02")
258
+ text_tokens = list(replaced_bytes)
259
+ return torch.tensor(
260
+ text_tokens[:max_len],
261
+ dtype=torch.long,
262
+ device=self.device,
263
+ )
264
+
265
+ def _pad_text_input(self, text_tokens: list[torch.Tensor]) -> torch.Tensor:
266
+ """Pads the text input to the maximum length."""
267
+ text_pad_value = 0
268
+ max_len = self.config.encoder_config.max_position_embeddings
269
+ batch_size = len(text_tokens)
270
+
271
+ src_tokens = torch.full(
272
+ (batch_size, 1, max_len),
273
+ fill_value=text_pad_value,
274
+ dtype=torch.long,
275
+ device=self.device,
276
+ )
277
+ for i in range(batch_size):
278
+ current_len = len(text_tokens[i])
279
+ src_tokens[i, 0, :current_len] = text_tokens[i]
280
+ return src_tokens
281
+
282
+ def _prepare_audio_prompt(self, audio_prompts: list[torch.Tensor | None]) -> tuple[torch.Tensor, list[int]]:
283
+ """Prepares the audio prompt tensor for the decoder.
284
+
285
+ Handles padding, adds the beginning-of-sequence (BOS) token, applies the
286
+ delay pattern, and determines the number of prefill steps for each item
287
+ in the batch.
288
+
289
+ Args:
290
+ audio_prompts: A list of audio prompt tensors (encoded DAC frames) or None.
291
+ Each tensor should have shape [T, C].
292
+
293
+ Returns:
294
+ A tuple containing:
295
+ - delayed_batch (torch.Tensor): The prepared audio prompt tensor with
296
+ delays applied, shape [B, T_max_padded, C].
297
+ - prefill_steps (list[int]): A list containing the number of valid
298
+ tokens (including BOS) for each prompt in the batch.
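+
+ Illustrative example: with a hypothetical delay_pattern of [0, 1, 2] (the
+ configured pattern may differ), channel c of the delayed output at step t
+ holds the prompt token from step t - c, and the first c steps of channel c
+ are filled with BOS.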
299
+ """
300
+ num_channels = self.config.decoder_config.num_channels
301
+ audio_bos_value = self.config.bos_token_id
302
+ delay_pattern = self.config.delay_pattern
303
+ max_delay_pattern = max(delay_pattern)
304
+ batch_size = len(audio_prompts)
305
+
306
+ max_len = max(p.shape[0] if p is not None else 0 for p in audio_prompts) + max_delay_pattern
307
+ prefill_steps = []
308
+
309
+ prefill = torch.full(
310
+ (batch_size, max_len, num_channels),
311
+ fill_value=-1,
312
+ dtype=torch.int,
313
+ device=self.device,
314
+ )
315
+
316
+ prefill[:, 0, :] = audio_bos_value
317
+
318
+ for i in range(batch_size):
319
+ prompt = audio_prompts[i]
320
+ if prompt is not None:
321
+ prompt = prompt.to(device=self.device, dtype=torch.int)
322
+ prefill[i, 1 : prompt.shape[0] + 1, :] = prompt
323
+ prefill_steps.append(prompt.shape[0] + 1)
324
+ else:
325
+ prefill_steps.append(1)
326
+
327
+ delay_precomp = build_delay_indices(
328
+ B=batch_size,
329
+ T=max_len,
330
+ C=num_channels,
331
+ delay_pattern=delay_pattern,
332
+ )
333
+
334
+ delayed_batch = apply_audio_delay(
335
+ audio_BxTxC=prefill,
336
+ pad_value=-1,
337
+ bos_value=audio_bos_value,
338
+ precomp=delay_precomp,
339
+ )
340
+
341
+ return delayed_batch, prefill_steps
342
+
343
+ def _prepare_generation(
344
+ self,
345
+ text: torch.Tensor,
346
+ audio_prompts: list[torch.Tensor | None],
347
+ max_tokens: int | None = None,
348
+ attn_fn: Callable = F.scaled_dot_product_attention,
349
+ ):
350
+ """Initializes the model state for generation.
351
+
352
+ Encodes the text input (conditional and unconditional), prepares the
353
+ encoder and decoder states (including KV caches and cross-attention),
354
+ prepares the audio prompt, and performs the initial decoder prefill steps
355
+ based on the audio prompts.
356
+
357
+ Args:
358
+ text: The padded text input tensor, shape [B, 1, T_text].
359
+ audio_prompts: A list of prepared audio prompt tensors or None.
+ max_tokens: Optional maximum generation length; falls back to the
+ decoder's configured maximum if None.
+ attn_fn: Attention function to use (currently unused in this method).
360
+
361
+ Returns:
362
+ A tuple containing:
363
+ - dec_state (DecoderInferenceState): The initialized decoder state.
364
+ - dec_output (DecoderOutput): The initialized decoder output manager,
365
+ containing the prefilled audio tokens.
366
+ """
367
+ batch_size = text.shape[0]
368
+
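+ # Classifier-free guidance: pair each text with an all-zero (unconditional)
+ # copy so both passes run in a single encoder batch of size 2 * B.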
369
+ enc_input_uncond = torch.zeros_like(text)
370
+ enc_input_cond = text
371
+ stacked_inputs = torch.stack([enc_input_uncond, enc_input_cond], dim=1)
372
+ enc_input = stacked_inputs.view(2 * batch_size, -1)
373
+
374
+ enc_state = EncoderInferenceState.new(self.config, enc_input_cond)
375
+ encoder_out = self.model.encoder(enc_input, enc_state)
376
+
377
+ dec_cross_attn_cache = self.model.decoder.precompute_cross_attn_cache(encoder_out)
378
+ dec_state = DecoderInferenceState.new(
379
+ self.config,
380
+ enc_state,
381
+ encoder_out,
382
+ dec_cross_attn_cache,
383
+ self.compute_dtype,
384
+ max_generation_length=max_tokens,
385
+ )
386
+ prefill, prefill_steps = self._prepare_audio_prompt(audio_prompts)
387
+
388
+ dec_output = DecoderOutput.new(batch_size, self.config, self.device)
389
+ dec_output.prefill(prefill, prefill_steps)
390
+
391
+ dec_step = min(prefill_steps) - 1
392
+ if dec_step > 0:
393
+ dec_state.prepare_step(0, dec_step)
394
+ tokens_BxTxC = dec_output.get_tokens_at(0, dec_step).repeat_interleave(2, dim=0)
395
+ self.model.decoder.forward(tokens_BxTxC, dec_state)
396
+
397
+ return dec_state, dec_output
398
+
399
+ def _decoder_step(
400
+ self,
401
+ tokens_Bx1xC: torch.Tensor,
402
+ dec_state: DecoderInferenceState,
403
+ cfg_scale: float,
404
+ temperature: float,
405
+ top_p: float,
406
+ top_k: int,
407
+ current_idx: int,
408
+ ) -> torch.Tensor:
409
+ """Performs a single step of the decoder inference.
410
+
411
+ Takes the tokens from the previous step, runs them through the decoder
412
+ (for both conditional and unconditional paths), applies classifier-free
413
+ guidance (CFG), samples the next token using temperature, top-p, and top-k
414
+ sampling, and applies constraints (e.g., preventing EOS in certain channels).
415
+
416
+ Args:
417
+ tokens_Bx1xC: The input tokens for the current step, shape [2*B, 1, C].
418
+ Repeated for CFG (unconditional and conditional).
419
+ dec_state: The current state of the decoder (KV caches, etc.).
420
+ cfg_scale: The scale factor for classifier-free guidance.
421
+ temperature: The temperature for sampling.
422
+ top_p: The cumulative probability threshold for top-p sampling.
423
+ top_k: The number of top logits to consider for top-k sampling.
424
+ current_idx: The current generation step index.
425
+
426
+ Returns:
427
+ torch.Tensor: The sampled next tokens for each item in the batch,
428
+ shape [B, C].
429
+ """
430
+ B = tokens_Bx1xC.shape[0] // 2
431
+
432
+ audio_eos_value = self.config.eos_token_id
433
+ logits_Bx1xCxV = self.model.decoder.decode_step(tokens_Bx1xC, dec_state, current_idx)
434
+
435
+ logits_last_2BxCxV = logits_Bx1xCxV[:, -1]
436
+ logits_last_Bx2xCxV = logits_last_2BxCxV.view(B, 2, *logits_last_2BxCxV.shape[1:])
437
+
438
+ uncond_logits_BxCxV = logits_last_Bx2xCxV[:, 0, :, :] # Shape [B, C, V]
439
+ cond_logits_BxCxV = logits_last_Bx2xCxV[:, 1, :, :] # Shape [B, C, V]
440
+ logits_BxCxV = cond_logits_BxCxV + cfg_scale * (cond_logits_BxCxV - uncond_logits_BxCxV)
441
+
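+ # CFG filter: pick the top-k candidates from the guided logits above, then
+ # sample from the conditional logits restricted to those candidates
+ # (see the masked_fill below).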
442
+ _, top_k_indices_BxCxk = torch.topk(logits_BxCxV, k=top_k, dim=-1)
443
+ mask_BxCxV = torch.ones_like(logits_BxCxV, dtype=torch.bool)
444
+ mask_BxCxV = mask_BxCxV.scatter(dim=-1, index=top_k_indices_BxCxk, value=False)
445
+ logits_BxCxV = cond_logits_BxCxV.masked_fill(mask_BxCxV, -torch.inf)
446
+
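+ # Restrict the vocabulary: ids above EOS are never valid, and EOS itself
+ # may only be produced in the first channel.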
447
+ logits_BxCxV[:, :, audio_eos_value + 1 :] = torch.full_like(
448
+ logits_BxCxV[:, :, audio_eos_value + 1 :],
449
+ fill_value=-torch.inf,
450
+ )
451
+ logits_BxCxV[:, 1:, audio_eos_value:] = torch.full_like(
452
+ logits_BxCxV[:, 1:, audio_eos_value:],
453
+ fill_value=-torch.inf,
454
+ )
455
+
456
+ flat_logits_BCxV = logits_BxCxV.view(B * self.config.decoder_config.num_channels, -1)
457
+
458
+ pred_BC = _sample_next_token(
459
+ flat_logits_BCxV.float(),
460
+ temperature=temperature,
461
+ top_p=top_p,
462
+ top_k=top_k,
463
+ audio_eos_value=audio_eos_value,
464
+ )
465
+
466
+ pred_BxC = pred_BC.view(B, self.config.decoder_config.num_channels)
467
+ return pred_BxC
468
+
469
+ def _generate_output(self, generated_codes: torch.Tensor, lengths_Bx: torch.Tensor) -> list[np.ndarray]:
470
+ """Converts generated delayed codes into audio waveforms.
471
+
472
+ Reverts the delay pattern applied during generation, decodes the resulting
473
+ codebook using the DAC model (if loaded), and returns a list of audio
474
+ waveforms as NumPy arrays. If DAC is not loaded, returns the raw codebook indices.
475
+
476
+ Args:
477
+ generated_codes: The tensor of generated audio codes with delays,
478
+ shape [B, T_gen, C].
479
+ lengths_Bx: A tensor containing the valid length of generated codes
480
+ (excluding padding and BOS/EOS markers) for each item
481
+ in the batch, shape [B].
482
+
483
+ Returns:
484
+ A list of NumPy arrays, where each array represents the generated audio
485
+ waveform for one item in the batch. If DAC is not loaded, returns the
486
+ raw, reverted codebook indices as NumPy arrays.
487
+ """
488
+ num_channels = self.config.decoder_config.num_channels
489
+ batch_size = generated_codes.shape[0]
490
+ seq_length = generated_codes.shape[1]
491
+ delay_pattern = self.config.delay_pattern
492
+ audio_pad_value = self.config.pad_token_id
493
+ max_delay_pattern = max(delay_pattern)
494
+
495
+ revert_precomp = build_revert_indices(
496
+ B=batch_size,
497
+ T=seq_length,
498
+ C=num_channels,
499
+ delay_pattern=delay_pattern,
500
+ )
501
+
502
+ codebook = revert_audio_delay(
503
+ audio_BxTxC=generated_codes,
504
+ pad_value=audio_pad_value,
505
+ precomp=revert_precomp,
506
+ T=seq_length,
507
+ )[:, :-max_delay_pattern, :]
508
+
509
+ min_valid_index = 0
510
+ max_valid_index = 1023
511
+ invalid_mask = (codebook < min_valid_index) | (codebook > max_valid_index)
512
+ codebook[invalid_mask] = 0
513
+
514
+ audios = []
515
+
516
+ if self.load_dac:
517
+ for i in range(batch_size):
518
+ audio = self._decode(codebook[i, : lengths_Bx[i], :])
519
+ audio_np = audio.cpu().numpy()
520
+ audios.append(audio_np)
521
+ else:
522
+ for i in range(batch_size):
523
+ audios.append(codebook[i, : lengths_Bx[i], :].cpu().numpy())
524
+ return audios
525
+
526
+ @torch.no_grad()
527
+ @torch.inference_mode()
528
+ def _encode(self, audio: torch.Tensor) -> torch.Tensor:
529
+ """
530
+ Encodes the given audio waveform into a tensor of DAC codebook indices
531
+ """
532
+ audio = audio.unsqueeze(0)
533
+ audio_data = self.dac_model.preprocess(audio, DEFAULT_SAMPLE_RATE)
534
+ _, encoded_frame, _, _, _ = self.dac_model.encode(audio_data)
535
+ encoded_frame: torch.Tensor
536
+ return encoded_frame.squeeze(0).transpose(0, 1)
537
+
538
+ @torch.no_grad()
539
+ @torch.inference_mode()
540
+ def _decode(self, audio_codes: torch.Tensor) -> torch.Tensor:
541
+ """
542
+ Decodes the given frames into an output audio waveform
543
+ """
544
+ audio_codes = audio_codes.unsqueeze(0).transpose(1, 2)
545
+ audio_values, _, _ = self.dac_model.quantizer.from_codes(audio_codes)
546
+ audio_values = self.dac_model.decode(audio_values)
547
+ audio_values: torch.Tensor
548
+ return audio_values.squeeze()
549
+
550
+ def load_audio(self, audio_path: str) -> torch.Tensor:
551
+ """Loads and preprocesses an audio file for use as a prompt.
552
+
553
+ Loads the audio file, resamples it to the target sample rate if necessary,
554
+ preprocesses it using the DAC model's preprocessing, and encodes it into
555
+ DAC codebook indices.
556
+
557
+ Args:
558
+ audio_path: Path to the audio file.
559
+
560
+ Returns:
561
+ torch.Tensor: The encoded audio prompt as DAC codebook indices,
562
+ shape [T, C].
563
+
564
+ Raises:
565
+ RuntimeError: If the DAC model is not loaded (`load_dac=False` during init).
566
+ FileNotFoundError: If the audio file cannot be found.
567
+ Exception: If there's an error during loading or processing.
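+
+ Example (illustrative; "prompt.wav" is a placeholder path):
+ >>> prompt = dia.load_audio("prompt.wav")
+ >>> audio = dia.generate("[S1] Continue in this voice.", audio_prompt=prompt)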
568
+ """
569
+ if self.dac_model is None:
570
+ raise RuntimeError("DAC model is required for loading audio prompts but was not loaded.")
571
+ audio, sr = torchaudio.load(audio_path, channels_first=True) # C, T
572
+ if sr != DEFAULT_SAMPLE_RATE:
573
+ audio = torchaudio.functional.resample(audio, sr, DEFAULT_SAMPLE_RATE)
574
+ # Convert to mono if stereo
575
+ if audio.shape[0] > 1:
576
+ audio = torch.mean(audio, dim=0, keepdim=True) # Average channels to get mono
577
+ return self._encode(audio.to(self.device))
578
+
579
+ def save_audio(self, path: str, audio: np.ndarray):
580
+ """Saves the generated audio waveform to a file.
581
+
582
+ Uses the soundfile library to write the NumPy audio array to the specified
583
+ path with the default sample rate.
584
+
585
+ Args:
586
+ path: The path where the audio file will be saved.
587
+ audio: The audio waveform as a NumPy array.
588
+ """
589
+ import soundfile as sf
590
+
591
+ sf.write(path, audio, DEFAULT_SAMPLE_RATE)
592
+
593
+ @torch.inference_mode()
594
+ def generate(
595
+ self,
596
+ text: str | list[str],
597
+ max_tokens: int = 3072,
598
+ cfg_scale: float = 3.0,
599
+ temperature: float = 1.2,
600
+ top_p: float = 0.95,
601
+ use_torch_compile: bool = False,
602
+ cfg_filter_top_k: int = 45,
603
+ audio_prompt: list[str | torch.Tensor | None] | str | torch.Tensor | None = None,
604
+ audio_prompt_path: list[str | torch.Tensor | None] | str | torch.Tensor | None = None,
605
+ use_cfg_filter: bool | None = None,
606
+ verbose: bool = False,
607
+ ) -> np.ndarray | list[np.ndarray]:
608
+ """Generates audio corresponding to the input text.
609
+
610
+ Args:
611
+ text: The input text prompt, or a list of text prompts for batch generation.
612
+ max_tokens: The maximum number of audio tokens to generate per prompt.
613
+ Defaults to 3072.
614
+ cfg_scale: The scale factor for classifier-free guidance (CFG). Higher values
615
+ lead to stronger guidance towards the text prompt.
616
+ temperature: The temperature for sampling. Higher values increase randomness.
617
+ top_p: The cumulative probability threshold for nucleus (top-p) sampling.
618
+ use_torch_compile: Whether to compile the generation steps using torch.compile.
619
+ Can significantly speed up generation after the initial
620
+ compilation overhead. Defaults to False.
621
+ cfg_filter_top_k: The number of top logits to consider during CFG filtering.
622
+ (This value is passed through `_decoder_step` as the top-k
623
+ cutoff for `_sample_next_token`.)
624
+ audio_prompt: An audio prompt or list of prompts to condition the generation.
625
+ Can be a file path (str), a pre-loaded tensor (DAC codes), or None.
626
+ If a list, its length must match the batch size of the text input.
627
+ audio_prompt_path: (Deprecated) Use `audio_prompt` instead.
628
+ use_cfg_filter: (Deprecated) This parameter is no longer used.
629
+ verbose: If True, prints progress information during generation, including
630
+ speed metrics.
631
+
632
+ Returns:
633
+ If a single text prompt was provided, returns a NumPy array containing the
634
+ generated audio waveform.
635
+ If a list of text prompts was provided, returns a list of NumPy arrays,
636
+ each corresponding to a prompt in the input list. Returns None for a
637
+ sequence if no audio was generated for it.
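+
+ Example (illustrative; the output path is a placeholder):
+ >>> audio = dia.generate("[S1] Hello there. [S2] Hi, how are you?")
+ >>> dia.save_audio("dialogue.wav", audio)
+ >>> # Batched generation returns one waveform per prompt.
+ >>> clips = dia.generate(["[S1] First take.", "[S1] Second take."])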
638
+ """
639
+ batch_size = len(text) if isinstance(text, list) else 1
640
+ audio_eos_value = self.config.eos_token_id
641
+ audio_pad_value = self.config.pad_token_id
642
+ delay_pattern = self.config.delay_pattern
643
+ max_delay_pattern = max(delay_pattern)
644
+ delay_pattern_Cx = torch.tensor(delay_pattern, device=self.device, dtype=torch.long)
645
+ self.model.eval()
646
+
647
+ if audio_prompt_path:
648
+ print("Warning: audio_prompt_path is deprecated. Use audio_prompt instead.")
649
+ audio_prompt = audio_prompt_path
650
+ if use_cfg_filter is not None:
651
+ print("Warning: use_cfg_filter is deprecated.")
652
+
653
+ if verbose:
654
+ total_start_time = time.time()
655
+
656
+ if use_torch_compile and not hasattr(self, "_compiled"):
657
+ # Compilation can take about a minute.
658
+ self._prepare_generation = torch.compile(self._prepare_generation, dynamic=True, fullgraph=True)
659
+ self._decoder_step = torch.compile(self._decoder_step, fullgraph=True, mode="max-autotune")
660
+ self._compiled = True
661
+
662
+ if isinstance(audio_prompt, list):
663
+ audio_prompt = [self.load_audio(p) if isinstance(p, str) else p for p in audio_prompt]
664
+ elif isinstance(audio_prompt, str):
665
+ audio_prompt = [self.load_audio(audio_prompt)]
666
+ elif isinstance(audio_prompt, torch.Tensor):
667
+ audio_prompt = [audio_prompt]
668
+ elif audio_prompt is None:
669
+ audio_prompt = [None] * batch_size
670
+
671
+ assert len(audio_prompt) == batch_size, "Number of audio prompts must match batch size"
672
+
673
+ if isinstance(text, list):
674
+ text = [self._encode_text(t) for t in text]
675
+ else:
676
+ text = [self._encode_text(text)]
677
+ text = self._pad_text_input(text)
678
+
679
+ dec_state, dec_output = self._prepare_generation(text, audio_prompt, max_tokens=max_tokens)
680
+ dec_step = min(dec_output.prefill_steps) - 1
681
+ current_idx = torch.tensor([dec_step], device=self.device)
682
+
683
+ eos_detected_Bx = torch.zeros((batch_size,), dtype=torch.bool, device=self.device)
684
+ eos_countdown_Bx = torch.full((batch_size,), -1, dtype=torch.long, device=self.device)
685
+ finished_step_Bx = torch.full((batch_size,), -1, dtype=torch.long, device=self.device)
686
+
687
+ bos_over = False
688
+
689
+ if verbose:
690
+ print("generate: starting generation loop")
691
+ if use_torch_compile:
692
+ print("generate: using use_torch_compile=True, the first step may be slow")
693
+ start_time = time.time()
694
+
695
+ # --- Generation Loop ---
696
+ while dec_step < max_tokens:
697
+ if (eos_countdown_Bx == 0).all():
698
+ break
699
+
700
+ current_step_idx = dec_step + 1
701
+ torch.compiler.cudagraph_mark_step_begin()
702
+ dec_state.prepare_step(dec_step)
703
+ tokens_Bx1xC = dec_output.get_tokens_at(dec_step).repeat_interleave(2, dim=0) # Repeat for CFG
704
+
705
+ pred_BxC = self._decoder_step(
706
+ tokens_Bx1xC,
707
+ dec_state,
708
+ cfg_scale,
709
+ temperature,
710
+ top_p,
711
+ cfg_filter_top_k,
712
+ current_idx,
713
+ )
714
+
715
+ current_idx += 1
716
+
717
+ active_mask_Bx = eos_countdown_Bx != 0
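+ # EOS handling: once a sequence emits EOS in channel 0 (or reaches the length
+ # budget), a countdown of max_delay_pattern steps forces EOS/PAD into each
+ # channel according to its delay before the sequence is treated as finished.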
718
+ eos_trigger_Bx = torch.zeros_like(active_mask_Bx)
719
+ if active_mask_Bx.any():
720
+ is_eos_token = (~eos_detected_Bx[active_mask_Bx]) & (pred_BxC[active_mask_Bx, 0] == audio_eos_value)
721
+ is_max_len = current_step_idx >= max_tokens - max_delay_pattern
722
+ eos_trigger_Bx[active_mask_Bx] = is_eos_token | is_max_len
723
+ eos_detected_Bx |= eos_trigger_Bx
724
+ start_countdown_mask_Bx = eos_trigger_Bx & (eos_countdown_Bx < 0)
725
+ if start_countdown_mask_Bx.any():
726
+ eos_countdown_Bx[start_countdown_mask_Bx] = max_delay_pattern
727
+ finished_step_Bx[start_countdown_mask_Bx] = current_step_idx
728
+
729
+ padding_mask_Bx = eos_countdown_Bx > 0
730
+ if padding_mask_Bx.any():
731
+ pred_active_BxC = pred_BxC[padding_mask_Bx].clone()
732
+ countdown_active_Bx = eos_countdown_Bx[padding_mask_Bx]
733
+ step_after_eos_Bx = max_delay_pattern - countdown_active_Bx
734
+ step_after_eos_Bx_ = step_after_eos_Bx.unsqueeze(1)
735
+ delay_pattern_Cx_ = delay_pattern_Cx.unsqueeze(0)
736
+ eos_mask_NxC = step_after_eos_Bx_ == delay_pattern_Cx_
737
+ pad_mask_NxC = step_after_eos_Bx_ > delay_pattern_Cx_
738
+ pred_active_BxC[eos_mask_NxC] = audio_eos_value
739
+ pred_active_BxC[pad_mask_NxC] = audio_pad_value
740
+ pred_BxC[padding_mask_Bx] = pred_active_BxC
741
+ eos_countdown_Bx[padding_mask_Bx] -= 1
742
+
743
+ # --- Update BOS flag ---
744
+ if not bos_over:
745
+ bos_over = all(
746
+ dec_step - prefill_step > max_delay_pattern for prefill_step in dec_output.prefill_steps
747
+ )
748
+
749
+ dec_output.update_one(pred_BxC, current_step_idx, not bos_over)
750
+
751
+ dec_step += 1
752
+
753
+ if verbose and dec_step % 86 == 0:
754
+ duration = time.time() - start_time
755
+ if duration > 0:
756
+ print(
757
+ f"generate step {dec_step}: speed={86 * batch_size / duration:.3f} tokens/s, realtime factor={batch_size / duration:.3f}x"
758
+ )
759
+ start_time = time.time()
760
+
761
+ # --- Finalize and Extract Output ---
762
+ final_step = dec_step + 1
763
+
764
+ finished_step_Bx[finished_step_Bx == -1] = final_step - max_delay_pattern
765
+
766
+ prefill_steps_tensor = torch.tensor(dec_output.prefill_steps, device=self.device)
767
+ lengths_Bx = finished_step_Bx - prefill_steps_tensor
768
+ lengths_Bx = torch.clamp(lengths_Bx, min=0)
769
+
770
+ max_len = lengths_Bx.max().item() + max_delay_pattern
771
+ outputs = []
772
+
773
+ if max_len > 0:
774
+ num_channels = self.config.decoder_config.num_channels
775
+ audio_pad_value = self.config.pad_token_id
776
+ generated_codes = torch.full(
777
+ (batch_size, max_len, num_channels),
778
+ fill_value=audio_pad_value,
779
+ dtype=torch.long,
780
+ device=self.device,
781
+ )
782
+
783
+ for i in range(batch_size):
784
+ start_step = dec_output.prefill_steps[i]
785
+ actual_len = lengths_Bx[i].item() + max_delay_pattern
786
+ if actual_len > 0:
787
+ tokens_to_copy = dec_output.generated_tokens[i, start_step : start_step + actual_len, :]
788
+ generated_codes[i, :actual_len, :] = tokens_to_copy
789
+
790
+ if verbose:
791
+ avg_steps = lengths_Bx.float().mean().item()
792
+ total_duration = time.time() - total_start_time
793
+ print(f"generate: avg steps={avg_steps:.1f}, total duration={total_duration:.3f}s")
794
+
795
+ del dec_state
796
+
797
+ outputs = self._generate_output(generated_codes, lengths_Bx)
798
+ else:
799
+ print("Warning: Nothing generated for any sequence in the batch.")
800
+ outputs = [None] * batch_size
801
+
802
+ return outputs if batch_size > 1 else outputs[0]
state.py ADDED
@@ -0,0 +1,217 @@
1
+ from dataclasses import dataclass
2
+ from typing import Optional
3
+
4
+ import torch
5
+
6
+ from .config import DiaConfig
7
+
8
+
9
+ def create_attn_mask(
10
+ q_padding_mask_1d: torch.Tensor,
11
+ k_padding_mask_1d: torch.Tensor,
12
+ device: torch.device,
13
+ is_causal: bool = False,
14
+ ) -> torch.Tensor:
15
+ """
16
+ Creates the attention mask (self or cross) mimicking JAX segment ID logic.
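+
+ Illustrative example: for padding masks of shape [B, Tq] and [B, Tk], the
+ result has shape [B, 1, Tq, Tk]; entry (q, k) is True only when the query
+ and key positions are both non-padding or both padding.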
17
+ """
18
+ # B1, Tq = q_padding_mask_1d.shape
19
+ # B2, Tk = k_padding_mask_1d.shape
20
+
21
+ p_mask_q = q_padding_mask_1d.unsqueeze(2) # Shape [B, Tq, 1]
22
+ p_mask_k = k_padding_mask_1d.unsqueeze(1) # Shape [B, 1, Tk]
23
+
24
+ # Condition A: Non-padding query attends to non-padding key
25
+ non_pad_attends_non_pad = p_mask_q & p_mask_k # Shape [B, Tq, Tk]
26
+
27
+ # Condition B: Padding query attends to padding key
28
+ pad_attends_pad = (~p_mask_q) & (~p_mask_k) # Shape [B, Tq, Tk]
29
+
30
+ # Combine: True if padding status is compatible (both non-pad OR both pad)
31
+ mask = non_pad_attends_non_pad | pad_attends_pad # Shape [B, Tq, Tk]
32
+
33
+ if is_causal:
34
+ # assert Tq == Tk, "Causal mask requires query and key sequence lengths to be equal"
35
+ causal_mask_2d = torch.tril(torch.ones_like(mask[0], dtype=torch.bool, device=device)) # Shape [Tq, Tk]
36
+ causal_mask = mask & causal_mask_2d # Shape [B, Tq, Tk]
37
+ return causal_mask.unsqueeze(1) # Shape [B, 1, Tq, Tk]
38
+ else:
39
+ return mask.unsqueeze(1) # Shape [B, 1, Tq, Tk]
40
+
41
+
42
+ @dataclass
43
+ class EncoderInferenceState:
44
+ """Parameters specifically for encoder inference."""
45
+
46
+ max_seq_len: int
47
+ device: torch.device
48
+ positions: torch.Tensor
49
+ padding_mask: torch.Tensor
50
+ attn_mask: torch.Tensor
51
+
52
+ @classmethod
53
+ def new(cls, config: DiaConfig, cond_src: torch.Tensor) -> "EncoderInferenceState":
54
+ """Creates EtorchrInferenceParams from DiaConfig and a device."""
55
+ device = cond_src.device
56
+
57
+ positions = torch.arange(
58
+ config.encoder_config.max_position_embeddings, dtype=torch.float32, device=device
59
+ ).unsqueeze(0)
60
+ padding_mask = (cond_src.squeeze(1) != 0).to(device).repeat_interleave(2, dim=0)
61
+ attn_mask = create_attn_mask(padding_mask, padding_mask, device, is_causal=False)
62
+
63
+ return cls(
64
+ max_seq_len=config.encoder_config.max_position_embeddings,
65
+ device=device,
66
+ positions=positions,
67
+ padding_mask=padding_mask,
68
+ attn_mask=attn_mask,
69
+ )
70
+
71
+
72
+ class KVCache(torch.nn.Module):
73
+ k: torch.Tensor
74
+ v: torch.Tensor
75
+
76
+ def __init__(
77
+ self,
78
+ batch_size: int,
79
+ num_heads: int,
80
+ max_len: int,
81
+ head_dim: int,
82
+ dtype: torch.dtype,
83
+ device: torch.device,
84
+ k: torch.Tensor | None = None,
85
+ v: torch.Tensor | None = None,
86
+ ):
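+ # The leading dimension is 2 * batch_size because the unconditional and
+ # conditional CFG passes share a single batch.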
87
+ k = torch.zeros((2 * batch_size, num_heads, max_len, head_dim), dtype=dtype, device=device) if k is None else k
88
+ v = torch.zeros((2 * batch_size, num_heads, max_len, head_dim), dtype=dtype, device=device) if v is None else v
89
+ super().__init__()
90
+
91
+ self.register_buffer("k", k)
92
+ self.register_buffer("v", v)
93
+
94
+ @classmethod
95
+ def from_kv(cls, k: torch.Tensor, v: torch.Tensor) -> "KVCache":
96
+ return cls(
97
+ batch_size=k.shape[0] // 2,
98
+ num_heads=k.shape[1],
99
+ max_len=k.shape[2],
100
+ head_dim=k.shape[3],
101
+ dtype=k.dtype,
102
+ device=k.device,
103
+ k=k,
104
+ v=v,
105
+ )
106
+
107
+ def update(self, k: torch.Tensor, v: torch.Tensor, current_idx: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
108
+ k_out, v_out = self.k, self.v
109
+ k_out[:, :, current_idx, :] = k
110
+ v_out[:, :, current_idx, :] = v
111
+ return self.k, self.v
112
+
113
+ def prefill(self, k: torch.Tensor, v: torch.Tensor):
114
+ prefill_len = k.shape[2]
115
+ self.k[:, :, :prefill_len, :] = k
116
+ self.v[:, :, :prefill_len, :] = v
117
+
118
+
119
+ @dataclass
120
+ class DecoderInferenceState:
121
+ """Parameters specifically for decoder inference."""
122
+
123
+ device: torch.device
124
+ dtype: torch.dtype
125
+ enc_out: torch.Tensor
126
+ enc_positions: torch.Tensor
127
+ dec_positions: torch.Tensor
128
+ self_attn_cache: list[KVCache]
129
+ cross_attn_cache: list[KVCache]
130
+ casual_attn_mask: torch.Tensor
131
+ cross_attn_mask: torch.Tensor
132
+
133
+ @classmethod
134
+ def new(
135
+ cls,
136
+ config: DiaConfig,
137
+ enc_state: EncoderInferenceState,
138
+ enc_out: torch.Tensor,
139
+ dec_cross_attn_cache: list[KVCache],
140
+ compute_dtype: torch.dtype,
141
+ max_generation_length: Optional[int] = None,
142
+ ) -> "DecoderInferenceState":
143
+ """Creates DecoderInferenceParams from DiaConfig and a device."""
144
+ device = enc_out.device
145
+ max_audio_len = max_generation_length or config.decoder_config.max_position_embeddings
146
+ batch_size = enc_out.shape[0] // 2
147
+
148
+ dec_positions = torch.full((2 * batch_size, 1), fill_value=0, dtype=torch.int32, device=device)
149
+ causal_mask = torch.tril(torch.ones(max_audio_len, max_audio_len, dtype=torch.bool, device=device))
150
+ dec_mask = torch.ones((2 * batch_size, 1), dtype=torch.bool, device=device)
151
+ cross_attn_mask = create_attn_mask(dec_mask, enc_state.padding_mask, device, is_causal=False)
152
+
153
+ self_attn_cache = [
154
+ KVCache(
155
+ batch_size,
156
+ config.decoder_config.num_key_value_heads,
157
+ max_audio_len,
158
+ config.decoder_config.head_dim,
159
+ compute_dtype,
160
+ device,
161
+ )
162
+ for _ in range(config.decoder_config.num_hidden_layers)
163
+ ]
164
+
165
+ return cls(
166
+ device=device,
167
+ dtype=compute_dtype,
168
+ enc_out=enc_out,
169
+ enc_positions=enc_state.positions,
170
+ dec_positions=dec_positions,
171
+ self_attn_cache=self_attn_cache,
172
+ cross_attn_cache=dec_cross_attn_cache,
173
+ casual_attn_mask=causal_mask,
174
+ cross_attn_mask=cross_attn_mask,
175
+ )
176
+
177
+ def prepare_step(self, step_from: int, step_to: int | None = None) -> None:
178
+ if step_to is None:
179
+ step_to = step_from + 1
180
+ self.dec_positions = torch.arange(step_from, step_to, dtype=torch.int32, device=self.device).unsqueeze(0)
181
+
182
+
183
+ @dataclass
184
+ class DecoderOutput:
185
+ generated_tokens: torch.Tensor
186
+ prefill_steps: list[int]
187
+
188
+ @classmethod
189
+ def new(cls, batch_size: int, config: DiaConfig, device: torch.device) -> "DecoderOutput":
190
+ max_audio_len = config.decoder_config.max_position_embeddings
191
+ return cls(
192
+ generated_tokens=torch.full(
193
+ (batch_size, max_audio_len, config.decoder_config.num_channels),
194
+ fill_value=-1,
195
+ dtype=torch.int,
196
+ device=device,
197
+ ),
198
+ prefill_steps=[],
199
+ )
200
+
201
+ def get_tokens_at(self, step_from: int, step_to: int | None = None) -> torch.Tensor:
202
+ if step_to is None:
203
+ step_to = step_from + 1
204
+ return self.generated_tokens[:, step_from:step_to, :]
205
+
206
+ def update_one(self, dec_out: torch.Tensor, step: int, apply_mask: bool = False):
207
+ dec_out = dec_out.to(self.generated_tokens.dtype)
208
+ if apply_mask:
209
+ mask = self.generated_tokens[:, step, :] == -1
210
+ self.generated_tokens[:, step, :] = torch.where(mask, dec_out, self.generated_tokens[:, step, :])
211
+ else:
212
+ self.generated_tokens[:, step, :] = dec_out
213
+
214
+ def prefill(self, dec_out: torch.Tensor, prefill_steps: list[int]):
215
+ length = dec_out.shape[1]
216
+ self.generated_tokens[:, :length, :] = dec_out
217
+ self.prefill_steps = prefill_steps