import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm

from configs import MusicGenConfig
from extensions import CodeGenBlock
from TTS.tts.layers.xtts.transformer import XTransformerEncoder, XTransformerDecoder
from TTS.tts.layers.xtts.flow import VitsFlowModules
from TTS.tts.layers.xtts.tokenizer import VoiceBPE

# Shared device used by MusicGenModel.sample() and XTTSModel.inference() below.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class SentimentClassifierModel(nn.Module):
    """Bidirectional-LSTM sentiment classifier over token ids (3 output classes)."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embedding = nn.Embedding(config.vocab_size, config.d_model)
        self.lstm = nn.LSTM(config.d_model, config.d_model, batch_first=True, bidirectional=True)
        self.fc = nn.Linear(config.d_model * 2, 3)

    def forward(self, input_ids):
        embedded = self.embedding(input_ids)
        # No padding mask is passed in, so every sequence is treated as full length.
        lengths = [input_ids.size(1)] * input_ids.size(0)
        packed_embedded = nn.utils.rnn.pack_padded_sequence(
            embedded, lengths=lengths, batch_first=True, enforce_sorted=False
        )
        packed_output, _ = self.lstm(packed_embedded)
        output, _ = nn.utils.rnn.pad_packed_sequence(packed_output, batch_first=True)
        # Pool the last time step of the bidirectional LSTM output, then project to 3 logits.
        pooled = output[:, -1, :]
        logits = self.fc(pooled)
        return logits
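
# Minimal usage sketch for SentimentClassifierModel (illustrative only). The
# config object here is a hypothetical stand-in: only vocab_size and d_model,
# the fields the class actually reads, are assumed.
def _sentiment_classifier_example():
    from types import SimpleNamespace
    cfg = SimpleNamespace(vocab_size=1000, d_model=64)
    model = SentimentClassifierModel(cfg)
    input_ids = torch.randint(0, cfg.vocab_size, (2, 16))  # (batch, seq_len)
    logits = model(input_ids)                               # (batch, 3)
    return logits.shape
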
class STTModel(nn.Module):
    """Convolutional + bidirectional-LSTM speech-to-text model producing per-frame token logits."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.conv1 = nn.Conv1d(1, 16, kernel_size=3, stride=2, padding=1)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool1d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv1d(16, 32, kernel_size=3, padding=1)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool1d(kernel_size=2, stride=2)
        # The conv stack downsamples time by 8 and yields 32 channels per frame,
        # which is the per-step input the LSTM consumes.
        self.lstm = nn.LSTM(32, 128, batch_first=True, bidirectional=True)
        self.fc = nn.Linear(128 * 2, config.vocab_size)

    def forward(self, audio_data):
        # audio_data: (batch, samples) raw waveform; add a channel dim for Conv1d.
        x = self.pool1(self.relu1(self.conv1(audio_data.unsqueeze(1))))
        x = self.pool2(self.relu2(self.conv2(x)))
        # (batch, 32, frames) -> (batch, frames, 32) so the LSTM runs over time.
        x = x.transpose(1, 2).contiguous()
        lengths = [x.size(1)] * x.size(0)
        packed = nn.utils.rnn.pack_padded_sequence(x, lengths=lengths, batch_first=True, enforce_sorted=False)
        packed_output, _ = self.lstm(packed)
        output, _ = nn.utils.rnn.pad_packed_sequence(packed_output, batch_first=True)
        logits = self.fc(output)  # (batch, frames, vocab_size)
        return logits
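
# Minimal usage sketch for STTModel (illustrative only). The config is a
# hypothetical stand-in exposing just vocab_size; the one-second 16 kHz
# waveform length is likewise arbitrary.
def _stt_example():
    from types import SimpleNamespace
    cfg = SimpleNamespace(vocab_size=32)
    model = STTModel(cfg)
    waveform = torch.randn(2, 16000)  # (batch, samples)
    logits = model(waveform)          # (batch, frames, vocab_size)
    return logits.shape
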
class TTSModel(nn.Module):
    """Toy text-to-speech model: embeds token ids and regresses one audio value per token."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embedding = nn.Embedding(config.vocab_size, config.d_model)
        self.lstm = nn.LSTM(config.d_model, config.d_model, batch_first=True, bidirectional=True)
        self.fc = nn.Linear(config.d_model * 2, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input_ids):
        embedded = self.embedding(input_ids)
        lengths = [input_ids.size(1)] * input_ids.size(0)
        packed_embedded = nn.utils.rnn.pack_padded_sequence(
            embedded, lengths=lengths, batch_first=True, enforce_sorted=False
        )
        packed_output, _ = self.lstm(packed_embedded)
        output, _ = nn.utils.rnn.pad_packed_sequence(packed_output, batch_first=True)
        logits = self.fc(output)
        audio = self.sigmoid(logits)  # (batch, seq_len, 1), values in [0, 1]
        return audio
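
# Minimal usage sketch for TTSModel (illustrative only); the config is a
# hypothetical stand-in with just the fields the class reads.
def _tts_example():
    from types import SimpleNamespace
    cfg = SimpleNamespace(vocab_size=256, d_model=64)
    model = TTSModel(cfg)
    input_ids = torch.randint(0, cfg.vocab_size, (1, 32))  # (batch, seq_len)
    audio = model(input_ids)                                # (batch, seq_len, 1)
    return audio.shape
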
class MusicGenModel(nn.Module):
    """Decoder-only transformer over audio tokens, built from CodeGenBlock layers."""

    def __init__(self, config: MusicGenConfig):
        super().__init__()
        self.config = config
        self.embedding = nn.Embedding(config.vocab_size, config.hidden_size)
        self.transformer_layers = nn.ModuleList([CodeGenBlock(config) for _ in range(config.num_hidden_layers)])
        self.fc_out = nn.Linear(config.hidden_size, config.vocab_size)

    def forward(self, input_ids):
        hidden_states = self.embedding(input_ids)
        for layer in self.transformer_layers:
            hidden_states = layer(hidden_states)
        logits = self.fc_out(hidden_states)
        return logits

    def sample(self, attributes, sample_rate, duration):
        # Greedy autoregressive decoding: one token per step, 1024 audio samples per token.
        input_tokens = torch.randint(0, self.config.vocab_size, (1, 1), dtype=torch.long).to(device)
        audio_output = []
        num_steps = int(duration * sample_rate / 1024)
        for _ in tqdm(range(num_steps), desc="Generating music"):
            logits = self.forward(input_tokens)
            predicted_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
            audio_output.append(predicted_token.cpu())
            input_tokens = torch.cat((input_tokens, predicted_token), dim=1)
        audio_output = torch.cat(audio_output, dim=1).float()
        return audio_output
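
# Minimal usage sketch for MusicGenModel (illustrative only). It assumes that
# MusicGenConfig can be instantiated with defaults and exposes vocab_size,
# hidden_size and num_hidden_layers, and that CodeGenBlock(config)(hidden_states)
# returns a tensor of the same shape; both live in this Space's own
# configs/extensions modules, so the exact interfaces may differ.
def _musicgen_example():
    cfg = MusicGenConfig()
    model = MusicGenModel(cfg).to(device)
    # Roughly one second of audio at 32 kHz, 1024 samples per generated token.
    tokens = model.sample(attributes=None, sample_rate=32000, duration=1.0)
    return tokens.shape
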
class XTTSModelClass(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.xtts = XTTSModel(config, num_speakers=1024, num_languages=25)

    def forward(self, text_tokens, text_lengths, speaker_ids, language_ids, voice_samples, voice_sample_lengths):
        return self.xtts.forward(text_tokens, text_lengths, speaker_ids, language_ids, voice_samples, voice_sample_lengths)

    def inference(self, text, language_id, speaker_id, voice_sample, temperature=0.7, length_penalty=1.0):
        return self.xtts.inference(text, language_id, speaker_id, voice_sample, temperature, length_penalty)

class XTTSModel(nn.Module):
    """XTTS-style model: transformer encoder/decoder plus flow modules, conditioned on summed language and speaker embeddings."""

    def __init__(self, config, num_speakers, num_languages):
        super().__init__()
        self.config = config
        self.num_speakers = num_speakers
        self.num_languages = num_languages
        self.encoder = XTransformerEncoder(**config.encoder_config)
        self.decoder = XTransformerDecoder(**config.decoder_config)
        self.flow_modules = VitsFlowModules(**config.flow_config)
        self.voice_tokenizer = VoiceBPE(
            vocab_path=config.voice_tokenizer_config.vocab_path,
            vocab_size=config.voice_tokenizer_config.vocab_size,
        )
        self.language_embedding = nn.Embedding(num_languages, config.embedding_dim)
        self.speaker_embedding = nn.Embedding(num_speakers, config.embedding_dim)
        self.text_embedding = nn.Embedding(config.num_chars, config.embedding_dim)

    def forward(self, text_tokens, text_lengths, speaker_ids, language_ids, voice_samples, voice_sample_lengths):
        # Language and speaker embeddings are summed into a single conditioning vector
        # that is fed to both the encoder and the decoder.
        lang_embed = self.language_embedding(language_ids)
        spk_embed = self.speaker_embedding(speaker_ids)
        text_embed = self.text_embedding(text_tokens)
        encoder_outputs, _ = self.encoder(text_embed, text_lengths, lang_embed + spk_embed)
        mel_outputs, _ = self.decoder(encoder_outputs, lang_embed + spk_embed, voice_samples)
        return mel_outputs, None

    def inference(self, text, language_id, speaker_id, voice_sample, temperature=0.7, length_penalty=1.0):
        language_ids = torch.tensor([language_id], dtype=torch.long).to(device)
        speaker_ids = torch.tensor([speaker_id], dtype=torch.long).to(device)
        text_tokens = self.voice_tokenizer.text_to_ids(text).to(device)
        text_lengths = torch.tensor([text_tokens.shape[0]], dtype=torch.long).to(device)
        voice_sample_lengths = torch.tensor([voice_sample.shape[0]], dtype=torch.long).to(device)
        lang_embed = self.language_embedding(language_ids)
        spk_embed = self.speaker_embedding(speaker_ids)
        text_embed = self.text_embedding(text_tokens)
        encoder_outputs, _ = self.encoder(text_embed, text_lengths, lang_embed + spk_embed)
        mel_outputs, _ = self.decoder.inference(
            encoder_outputs, lang_embed + spk_embed, voice_sample,
            temperature=temperature, length_penalty=length_penalty,
        )
        return mel_outputs
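
# Minimal sketch of the conditioning path used by XTTSModel.forward/inference
# (illustrative only): language and speaker ids are embedded and summed into a
# single vector that conditions both the encoder and the decoder. The
# dimensions below are placeholders, not values taken from the real config.
def _xtts_conditioning_example():
    embedding_dim = 512
    language_embedding = nn.Embedding(25, embedding_dim)
    speaker_embedding = nn.Embedding(1024, embedding_dim)
    lang_embed = language_embedding(torch.tensor([3]))  # (1, embedding_dim)
    spk_embed = speaker_embedding(torch.tensor([42]))   # (1, embedding_dim)
    conditioning = lang_embed + spk_embed                # summed conditioning vector
    return conditioning.shape
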