import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import copy
from tqdm import tqdm

from configs import *      # expected to provide MusicGenConfig (and possibly a global `device`)
from extensions import *   # expected to provide CodeGenBlock

class SentimentClassifierModel(nn.Module):
    """Bidirectional LSTM sentiment classifier with three output classes."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embedding = nn.Embedding(config.vocab_size, config.d_model)
        self.lstm = nn.LSTM(config.d_model, config.d_model, batch_first=True, bidirectional=True)
        self.fc = nn.Linear(config.d_model * 2, 3)

    def forward(self, input_ids):
        embedded = self.embedding(input_ids)
        # Every sequence is packed with its full (padded) length, so packing is
        # effectively a no-op here; masking real padding would need true lengths.
        lengths = [input_ids.size(1)] * input_ids.size(0)
        packed_embedded = nn.utils.rnn.pack_padded_sequence(
            embedded, lengths=lengths, batch_first=True, enforce_sorted=False)
        packed_output, _ = self.lstm(packed_embedded)
        output, _ = nn.utils.rnn.pad_packed_sequence(packed_output, batch_first=True)
        # Pool by taking the final time step of the concatenated directions.
        pooled = output[:, -1, :]
        logits = self.fc(pooled)
        return logits

class STTModel(nn.Module):
    """Speech-to-text model: 1-D convolutional frontend followed by a bidirectional LSTM."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.conv1 = nn.Conv1d(1, 16, kernel_size=3, stride=2, padding=1)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool1d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv1d(16, 32, kernel_size=3, padding=1)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool1d(kernel_size=2, stride=2)
        # The conv stack downsamples the time axis by 8, so the flattened feature
        # size below assumes inputs of length config.max_position_embeddings.
        self.lstm = nn.LSTM(32 * (config.max_position_embeddings // 8), 128, batch_first=True, bidirectional=True)
        self.fc = nn.Linear(128 * 2, config.vocab_size)

    def forward(self, audio_data):
        # audio_data: (batch, samples) -> add a channel dim for Conv1d.
        x = self.pool1(self.relu1(self.conv1(audio_data.unsqueeze(1))))
        x = self.pool2(self.relu2(self.conv2(x)))
        x = x.transpose(1, 2).contiguous()
        # Flatten channels and downsampled time into one feature vector per example
        # so the size matches the LSTM's declared input_size.
        x = x.view(x.size(0), 1, -1)
        lengths = [x.size(1)] * x.size(0)
        packed_input = nn.utils.rnn.pack_padded_sequence(
            x, lengths=lengths, batch_first=True, enforce_sorted=False)
        packed_output, _ = self.lstm(packed_input)
        output, _ = nn.utils.rnn.pad_packed_sequence(packed_output, batch_first=True)
        logits = self.fc(output)
        return logits

class TTSModel(nn.Module):
    """Text-to-speech model: embeds token ids, runs a bidirectional LSTM, and
    predicts one audio value in [0, 1] per input time step."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embedding = nn.Embedding(config.vocab_size, config.d_model)
        self.lstm = nn.LSTM(config.d_model, config.d_model, batch_first=True, bidirectional=True)
        self.fc = nn.Linear(config.d_model * 2, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input_ids):
        embedded = self.embedding(input_ids)
        lengths = [input_ids.size(1)] * input_ids.size(0)
        packed_embedded = nn.utils.rnn.pack_padded_sequence(
            embedded, lengths=lengths, batch_first=True, enforce_sorted=False)
        packed_output, _ = self.lstm(packed_embedded)
        output, _ = nn.utils.rnn.pad_packed_sequence(packed_output, batch_first=True)
        logits = self.fc(output)
        audio = self.sigmoid(logits)
        return audio

class MusicGenModel(nn.Module):
    """Autoregressive transformer over audio tokens, built from CodeGenBlock layers."""

    def __init__(self, config: MusicGenConfig):
        super().__init__()
        self.config = config
        self.embedding = nn.Embedding(config.vocab_size, config.hidden_size)
        self.transformer_layers = nn.ModuleList(
            [CodeGenBlock(config) for _ in range(config.num_hidden_layers)])
        self.fc_out = nn.Linear(config.hidden_size, config.vocab_size)

    def forward(self, input_ids):
        embedded_tokens = self.embedding(input_ids)
        hidden_states = embedded_tokens
        for layer in self.transformer_layers:
            hidden_states = layer(hidden_states)
        logits = self.fc_out(hidden_states)
        return logits

    def sample(self, attributes, sample_rate, duration):
        # Greedy autoregressive decoding, seeded with a single random token.
        device = next(self.parameters()).device
        input_tokens = torch.randint(0, self.config.vocab_size, (1, 1), dtype=torch.long, device=device)
        audio_output = []
        # One generated token is assumed to cover 1024 audio samples.
        num_steps = int(duration * sample_rate / 1024)
        for _ in tqdm(range(num_steps), desc="Generating music"):
            logits = self.forward(input_tokens)
            predicted_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
            audio_output.append(predicted_token.cpu())
            input_tokens = torch.cat((input_tokens, predicted_token), dim=1)
        audio_output = torch.cat(audio_output, dim=1).float()
        return audio_output
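

# --- Minimal usage sketch (illustrative only) --------------------------------
# A hedged smoke test showing how the first three models might be exercised.
# `DummyConfig` is a hypothetical stand-in, not the project's real config from
# configs.py; it only defines the attributes these classes actually read
# (vocab_size, d_model, max_position_embeddings, hidden_size, num_hidden_layers).
# MusicGenModel is skipped because it depends on CodeGenBlock from extensions,
# whose constructor arguments are not shown in this file.
if __name__ == "__main__":
    class DummyConfig:
        vocab_size = 1000
        d_model = 64
        max_position_embeddings = 1024
        hidden_size = 64
        num_hidden_layers = 2

    cfg = DummyConfig()

    # Sentiment classification: (batch, seq_len) token ids -> (batch, 3) logits.
    sentiment_model = SentimentClassifierModel(cfg)
    print(sentiment_model(torch.randint(0, cfg.vocab_size, (2, 16))).shape)

    # Speech-to-text: (batch, max_position_embeddings) audio -> token logits.
    stt_model = STTModel(cfg)
    print(stt_model(torch.randn(2, cfg.max_position_embeddings)).shape)

    # Text-to-speech: (batch, seq_len) token ids -> (batch, seq_len, 1) audio.
    tts_model = TTSModel(cfg)
    print(tts_model(torch.randint(0, cfg.vocab_size, (2, 16))).shape)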