|
"""Configuration management module for the Dia model. |
|
|
|
This module provides comprehensive configuration management for the Dia model, |
|
utilizing Pydantic for validation. It defines configurations for data processing, |
|
model architecture (encoder and decoder), and training settings. |
|
|
|
Key components: |
|
- DataConfig: Parameters for data loading and preprocessing. |
|
- EncoderConfig: Architecture details for the encoder module. |
|
- DecoderConfig: Architecture details for the decoder module. |
|
- ModelConfig: Combined model architecture settings. |
|
- TrainingConfig: Training hyperparameters and settings. |
|
- DiaConfig: Master configuration combining all components. |
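
Example:
    A minimal save/load round trip (the path shown is illustrative):

        config = DiaConfig(encoder_config=EncoderConfig(), decoder_config=DecoderConfig())
        config.save("checkpoints/dia/config.json")
        restored = DiaConfig.load("checkpoints/dia/config.json")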
|
""" |
|
|
|
import os |
|
|
|
from pydantic import BaseModel, Field |
|
|
|
|
|
class EncoderConfig(BaseModel, frozen=True):
    """Configuration for the encoder component of the Dia model.

    Attributes:
        model_type: Type of the model, defaults to "dia_encoder".
        hidden_size: Size of the encoder layers, defaults to 1024.
        intermediate_size: Size of the "intermediate" (i.e., feed-forward) layer in the encoder, defaults to 4096.
        num_hidden_layers: Number of hidden layers in the encoder, defaults to 12.
        num_attention_heads: Number of attention heads in the encoder, defaults to 16.
        num_key_value_heads: Number of key-value heads in the encoder, defaults to 16.
        head_dim: Dimension of each attention head, defaults to 128.
        hidden_act: Activation function in the encoder, defaults to "silu".
        max_position_embeddings: Maximum number of position embeddings, defaults to 1024.
        initializer_range: Range for initializing weights, defaults to 0.02.
        norm_eps: Epsilon value for normalization layers, defaults to 1e-5.
        rope_theta: Theta value for RoPE, defaults to 10000.0.
        rope_scaling: Optional scaling factor for RoPE.
        vocab_size: Vocabulary size, defaults to 256.
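
    Example:
        A brief sketch of overriding a default and of the frozen (immutable)
        behavior; the override value is illustrative:

            enc = EncoderConfig(num_hidden_layers=6)
            enc.num_hidden_layers = 12  # raises pydantic.ValidationError: the instance is frozen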
|
""" |
|
|
|
head_dim: int = Field(default=128, gt=0) |
|
hidden_act: str = Field(default="silu") |
|
hidden_size: int = Field(default=1024, gt=0) |
|
initializer_range: float = Field(default=0.02) |
|
intermediate_size: int = Field(default=4096, gt=0) |
|
max_position_embeddings: int = Field(default=1024, gt=0) |
|
model_type: str = Field(default="dia_encoder") |
|
norm_eps: float = Field(default=1e-5) |
|
num_attention_heads: int = Field(default=16, gt=0) |
|
num_hidden_layers: int = Field(default=12, gt=0) |
|
num_key_value_heads: int = Field(default=16, gt=0) |
|
rope_scaling: float | None = Field(default=None) |
|
rope_theta: float = Field(default=10000.0) |
|
vocab_size: int = Field(default=256, gt=0) |
|
|
|
|
|
class DecoderConfig(BaseModel, frozen=True):
    """Configuration for the decoder component of the Dia model.

    Attributes:
        model_type: Type of the model, defaults to "dia_decoder".
        hidden_size: Size of the decoder layers, defaults to 2048.
        intermediate_size: Size of the "intermediate" (i.e., feed-forward) layer in the decoder, defaults to 8192.
        num_hidden_layers: Number of hidden layers in the decoder, defaults to 18.
        num_attention_heads: Number of attention heads in the decoder, defaults to 16.
        num_key_value_heads: Number of key-value heads in the decoder, defaults to 4.
        head_dim: Dimension of each attention head, defaults to 128.
        cross_hidden_size: Size of the cross-attention layers, defaults to 1024.
        cross_num_attention_heads: Number of attention heads in the cross-attention mechanism, defaults to 16.
        cross_num_key_value_heads: Number of key-value heads in the cross-attention mechanism, defaults to 16.
        cross_head_dim: Dimension of each cross-attention head, defaults to 128.
        hidden_act: Activation function in the decoder, defaults to "silu".
        max_position_embeddings: Maximum number of position embeddings in the decoder, defaults to 3072.
        initializer_range: Range for initializing weights in the decoder, defaults to 0.02.
        norm_eps: Epsilon value for normalization layers in the decoder, defaults to 1e-5.
        rope_theta: Theta value for RoPE in the decoder, defaults to 10000.0.
        rope_scaling: Optional scaling factor for RoPE in the decoder.
        vocab_size: Vocabulary size for the decoder, defaults to 1028.
        num_channels: Number of channels in the decoder, defaults to 9.
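
    Example:
        By default cross_hidden_size (1024) mirrors EncoderConfig.hidden_size,
        since cross-attention consumes encoder states. This sketch pairs the
        decoder with a smaller, hypothetical encoder width:

            dec = DecoderConfig(cross_hidden_size=512)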
|
""" |
|
|
|
cross_head_dim: int = Field(default=128, gt=0) |
|
cross_hidden_size: int = Field(default=1024, gt=0) |
|
cross_num_attention_heads: int = Field(default=16, gt=0) |
|
cross_num_key_value_heads: int = Field(default=16, gt=0) |
|
head_dim: int = Field(default=128, gt=0) |
|
hidden_act: str = Field(default="silu") |
|
hidden_size: int = Field(default=2048, gt=0) |
|
initializer_range: float = Field(default=0.02) |
|
intermediate_size: int = Field(default=8192, gt=0) |
|
max_position_embeddings: int = Field(default=3072, gt=0) |
|
model_type: str = Field(default="dia_decoder") |
|
norm_eps: float = Field(default=1e-5) |
|
num_attention_heads: int = Field(default=16, gt=0) |
|
num_channels: int = Field(default=9, gt=0) |
|
num_hidden_layers: int = Field(default=18, gt=0) |
|
num_key_value_heads: int = Field(default=4, gt=0) |
|
rope_scaling: float | None = Field(default=None) |
|
rope_theta: float = Field(default=10000.0) |
|
vocab_size: int = Field(default=1028, gt=0) |
|
|
|
|
|
class DiaConfig(BaseModel, frozen=True):
    """Main configuration container for the Dia model architecture.

    Attributes:
        model_type: Type of the model, defaults to "dia".
        is_encoder_decoder: Flag indicating if the model is an encoder-decoder type, defaults to True.
        encoder_config: Configuration for the encoder component.
        decoder_config: Configuration for the decoder component.
        initializer_range: Range for initializing weights, defaults to 0.02.
        norm_eps: Epsilon value for normalization layers, defaults to 1e-5.
        torch_dtype: Data type for model weights in PyTorch, defaults to "float32".
        bos_token_id: Beginning-of-sequence token ID, defaults to 1026.
        eos_token_id: End-of-sequence token ID, defaults to 1024.
        pad_token_id: Padding token ID, defaults to 1025.
        transformers_version: Version of the transformers library, defaults to "4.53.0.dev0".
        architectures: List of model architectures, defaults to ["DiaForConditionalGeneration"].
        delay_pattern: List of delay values for each audio channel, defaults to [0, 8, 9, 10, 11, 12, 13, 14, 15].
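
    Example:
        Note that delay_pattern has one entry per audio channel, matching the
        decoder's num_channels default of 9. A minimal construction sketch:

            config = DiaConfig(
                encoder_config=EncoderConfig(),
                decoder_config=DecoderConfig(),
            )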
|
""" |
|
|
|
architectures: list[str] = Field(default_factory=lambda: ["DiaForConditionalGeneration"]) |
|
bos_token_id: int = Field(default=1026) |
|
decoder_config: DecoderConfig |
|
delay_pattern: list[int] = Field(default_factory=lambda: [0, 8, 9, 10, 11, 12, 13, 14, 15]) |
|
encoder_config: EncoderConfig |
|
eos_token_id: int = Field(default=1024) |
|
initializer_range: float = Field(default=0.02) |
|
is_encoder_decoder: bool = Field(default=True) |
|
model_type: str = Field(default="dia") |
|
norm_eps: float = Field(default=1e-5) |
|
pad_token_id: int = Field(default=1025) |
|
torch_dtype: str = Field(default="float32") |
|
transformers_version: str = Field(default="4.53.0.dev0") |
|
|
|
    def save(self, path: str) -> None:
        """Save the current configuration instance to a JSON file.

        Ensures the parent directory exists before writing the JSON-serialized
        configuration.

        Args:
            path: The target file path to save the configuration.
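
        Example:
            A brief usage sketch (the path shown is illustrative):

                config.save("checkpoints/dia/config.json")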
|
""" |
|
os.makedirs(os.path.dirname(path), exist_ok=True) |
|
config_json = self.model_dump_json(indent=2) |
|
with open(path, "w") as f: |
|
f.write(config_json) |
|
|
|
    @classmethod
    def load(cls, path: str) -> "DiaConfig | None":
        """Load and validate a Dia configuration from a JSON file.

        Args:
            path: The path to the configuration file.

        Returns:
            A validated DiaConfig instance if the file exists and is valid,
            otherwise None if the file is not found.

        Raises:
            pydantic.ValidationError: If the JSON content fails validation against the DiaConfig schema.
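
        Example:
            A brief usage sketch (the path shown is illustrative):

                config = DiaConfig.load("checkpoints/dia/config.json")
                if config is None:
                    print("config file not found")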
|
""" |
|
try: |
|
with open(path, "r") as f: |
|
content = f.read() |
|
return cls.model_validate_json(content) |
|
except FileNotFoundError: |
|
return None |
|
|