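# Training utilities for the E2-TTS cross-attention model (e2_tts_crossatt6):
# an Accelerate-based trainer (E2Trainer) plus dataset/collation helpers for
# text-to-audio (Text2AudioDataset) and text-to-speech (Text2SpeechDataset).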
from __future__ import annotations | |
import os | |
from tqdm import tqdm | |
import matplotlib | |
matplotlib.use("Agg") | |
import matplotlib.pyplot as plt
import torch | |
import torch.nn.functional as F | |
from torch.utils.data import DataLoader, Dataset | |
from torch.utils.tensorboard import SummaryWriter | |
from torch.optim.lr_scheduler import LinearLR, SequentialLR | |
import torchaudio | |
from einops import rearrange | |
from accelerate import Accelerator | |
from accelerate.utils import DistributedDataParallelKwargs | |
from ema_pytorch import EMA | |
from loguru import logger | |
from e2_tts_pytorch.e2_tts_crossatt6 import ( | |
E2TTS, | |
DurationPredictor, | |
MelSpec | |
) | |
import traceback | |
import numpy as np | |
from moviepy.editor import AudioFileClip, VideoFileClip | |
def exists(v): | |
return v is not None | |
def default(v, d): | |
return v if exists(v) else d | |
def to_numpy(t): | |
return t.detach().cpu().numpy() | |
# plot spectrogram | |
def plot_spectrogram(spectrogram): | |
spectrogram = to_numpy(spectrogram) | |
fig, ax = plt.subplots(figsize=(10, 4)) | |
im = ax.imshow(spectrogram.T, aspect="auto", origin="lower", interpolation="none") | |
plt.colorbar(im, ax=ax) | |
plt.xlabel("Frames") | |
plt.ylabel("Channels") | |
plt.tight_layout() | |
fig.canvas.draw() | |
plt.close() | |
return fig | |
# collation | |
def collate_fn(batch): | |
mel_specs = [item['mel_spec'].squeeze(0) for item in batch] | |
mel_lengths = torch.LongTensor([spec.shape[-1] for spec in mel_specs]) | |
max_mel_length = mel_lengths.amax() | |
padded_mel_specs = [] | |
for spec in mel_specs: | |
padding = (0, max_mel_length - spec.size(-1)) | |
padded_spec = F.pad(spec, padding, value = 0) | |
padded_mel_specs.append(padded_spec) | |
mel_specs = torch.stack(padded_mel_specs) | |
text = [item['text'] for item in batch] | |
text_lengths = torch.LongTensor([len(item) for item in text]) | |
return dict( | |
mel = mel_specs, | |
mel_lengths = mel_lengths, | |
text = text, | |
text_lengths = text_lengths, | |
) | |
# dataset | |
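# HFDataset wraps a Hugging Face audio dataset: clips shorter than 0.3 s or
# longer than 20 s are skipped (the next item is returned instead), audio is
# resampled to target_sample_rate, and a mel spectrogram is computed with MelSpec.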
class HFDataset(Dataset): | |
def __init__( | |
self, | |
hf_dataset: Dataset, | |
target_sample_rate = 24_000, | |
hop_length = 256 | |
): | |
self.data = hf_dataset | |
self.target_sample_rate = target_sample_rate | |
self.hop_length = hop_length | |
self.mel_spectrogram = MelSpec(sampling_rate=target_sample_rate) | |
def __len__(self): | |
return len(self.data) | |
def __getitem__(self, index): | |
row = self.data[index] | |
audio = row['audio']['array'] | |
#logger.info(f"Audio shape: {audio.shape}") | |
sample_rate = row['audio']['sampling_rate'] | |
duration = audio.shape[-1] / sample_rate | |
if duration > 20 or duration < 0.3: | |
logger.warning(f"Skipping due to duration out of bound: {duration}") | |
return self.__getitem__((index + 1) % len(self.data)) | |
audio_tensor = torch.from_numpy(audio).float() | |
if sample_rate != self.target_sample_rate: | |
resampler = torchaudio.transforms.Resample(sample_rate, self.target_sample_rate) | |
audio_tensor = resampler(audio_tensor) | |
audio_tensor = rearrange(audio_tensor, 't -> 1 t') | |
mel_spec = self.mel_spectrogram(audio_tensor) | |
mel_spec = rearrange(mel_spec, '1 d t -> d t') | |
text = row['transcript'] | |
return dict( | |
mel_spec = mel_spec, | |
text = text, | |
) | |
# trainer | |
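# E2Trainer wraps an E2TTS model with Hugging Face Accelerate for (optionally
# distributed) training: it owns the optimizer, a warmup + linear-decay LR
# schedule, TensorBoard logging, loguru logging, checkpointing, and periodic
# evaluation on one or more validation dataloaders.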
class E2Trainer: | |
def __init__( | |
self, | |
model: E2TTS, | |
optimizer, | |
num_warmup_steps=20000, | |
grad_accumulation_steps=1, | |
duration_predictor: DurationPredictor | None = None, | |
checkpoint_path = None, | |
log_file = "logs.txt", | |
max_grad_norm = 1.0, | |
sample_rate = 22050, | |
tensorboard_log_dir = 'runs/e2_tts_experiment', | |
accelerate_kwargs: dict = dict(), | |
ema_kwargs: dict = dict(), | |
use_switch_ema = False, | |
if_text = False, | |
if_prompt = False | |
): | |
logger.add(log_file) | |
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters = True) | |
self.accelerator = Accelerator( | |
log_with = "all", | |
kwargs_handlers = [ddp_kwargs], | |
gradient_accumulation_steps = grad_accumulation_steps, | |
**accelerate_kwargs | |
) | |
self.accelerator.wait_for_everyone() | |
self.target_sample_rate = sample_rate | |
self.model = model | |
self.need_velocity_consistent_loss = model.velocity_consistency_weight > 0. | |
#self.ema_model = EMA( | |
# model, | |
# include_online_model = False, | |
# **ema_kwargs | |
#) | |
self.use_switch_ema = use_switch_ema | |
self.duration_predictor = duration_predictor | |
self.optimizer = optimizer | |
self.num_warmup_steps = num_warmup_steps | |
self.checkpoint_path = default(checkpoint_path, 'model.pth') | |
self.mel_spectrogram = MelSpec(sampling_rate=self.target_sample_rate) | |
self.model, self.optimizer = self.accelerator.prepare( | |
self.model, self.optimizer | |
) | |
#self.ema_model = self.accelerator.prepare(self.ema_model) | |
self.max_grad_norm = max_grad_norm | |
self.writer = SummaryWriter(log_dir=tensorboard_log_dir) | |
self.tensorboard_log_dir = tensorboard_log_dir | |
self.if_text = if_text | |
self.if_prompt = if_prompt | |
self.device_id = self.accelerator.device.index | |
self.num_processes = self.accelerator.num_processes | |
    @property
    def is_main(self):
        return self.accelerator.is_main_process
def save_checkpoint(self, step, finetune=False): | |
self.accelerator.wait_for_everyone() | |
if self.is_main: | |
checkpoint = dict( | |
model_state_dict = self.accelerator.unwrap_model(self.model).state_dict(), | |
#optimizer_state_dict = self.accelerator.unwrap_model(self.optimizer).state_dict(), | |
#ema_model_state_dict = self.ema_model.state_dict(), | |
#scheduler_state_dict = self.scheduler.state_dict(), | |
#step = step, | |
) | |
self.accelerator.save(checkpoint, self.tensorboard_log_dir + "/" + str(step) + ".pt") | |
def load_checkpoint(self): | |
if not exists(self.checkpoint_path) or not os.path.exists(self.checkpoint_path): | |
return 0 | |
checkpoint = torch.load(self.checkpoint_path, map_location='cpu') | |
for key in list(checkpoint['model_state_dict'].keys()): | |
#if key.startswith('mel_spec.'): | |
# del checkpoint['model_state_dict'][key] | |
if key.startswith('transformer.text_registers'): | |
if checkpoint['model_state_dict'][key].shape[1] != self.accelerator.unwrap_model(self.model).transformer.text_registers.shape[1]: | |
print('miss match: transformer.text_registers', checkpoint['model_state_dict'][key].shape, self.accelerator.unwrap_model(self.model).transformer.text_registers.shape) | |
del checkpoint['model_state_dict'][key] | |
self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint['model_state_dict'], strict=False) | |
#self.accelerator.unwrap_model(self.optimizer).load_state_dict(checkpoint['optimizer_state_dict']) | |
#if self.is_main: | |
# self.ema_model.load_state_dict(checkpoint['ema_model_state_dict']) | |
#if self.scheduler: | |
# self.scheduler.load_state_dict(checkpoint['scheduler_state_dict']) | |
#return checkpoint['step'] | |
return 0 | |
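    # Evaluation runs the model at a fixed flow time (times=0.5) under
    # torch.no_grad() and gathers per-sample losses across processes with
    # accelerator.gather_for_metrics, so the reported averages cover the whole
    # eval set rather than a single rank.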
def evaluate(self, eval_dataloader, epoch, epochs, global_step): | |
if eval_dataloader is None: | |
return | |
total_val_loss, N, total_lossmore1, total_lossmore2 = 0, 0, 0, 0 | |
self.model.eval() | |
eval_progress_bar = tqdm(eval_dataloader, desc=f"Epoch {epoch}/{epochs}", unit="step", disable=not self.accelerator.is_local_main_process) | |
for step, batch in enumerate(eval_dataloader): | |
            with torch.no_grad():
text, mel_spec, video_paths, mel_lengths, video_drop_prompt, audio_drop_prompt = batch | |
val_loss, cond, pred, pred_data, lossmore = self.model( | |
mel_spec, | |
text=(text if self.if_text else None), | |
times=0.5, | |
lens=mel_lengths, | |
velocity_consistency_model=None, | |
prompt=(text if self.if_prompt else None), | |
video_drop_prompt=video_drop_prompt, | |
audio_drop_prompt=audio_drop_prompt, | |
val=True, | |
video_paths=video_paths | |
) | |
a = torch.tensor(val_loss.item()*len(text), dtype=torch.float32).reshape(1).to(val_loss.device) | |
b = torch.tensor(len(text), dtype=torch.int32).reshape(1).to(val_loss.device) | |
c = torch.tensor(lossmore[0].item()*len(text), dtype=torch.float32).reshape(1).to(lossmore[0].device) | |
d = torch.tensor(lossmore[1].item()*len(text), dtype=torch.float32).reshape(1).to(lossmore[1].device) | |
val_loss_gather, N_gather, lossmore_gather1, lossmore_gather2 = self.accelerator.gather_for_metrics((a, b, c, d)) | |
for i in range(val_loss_gather.shape[0]): | |
total_val_loss += val_loss_gather[i].item() | |
N += N_gather[i].item() | |
total_lossmore1 += lossmore_gather1[i].item() | |
total_lossmore2 += lossmore_gather2[i].item() | |
eval_progress_bar.update(1) | |
if self.accelerator.is_local_main_process: | |
total_val_loss = round(total_val_loss/float(N), 4) | |
total_lossmore1 = round(total_lossmore1/float(N), 4) | |
total_lossmore2 = round(total_lossmore2/float(N), 4) | |
result_string = "Epoch: {}, GlobalStep: {}, ValLoss: {}, N: {}, Lossmore1: {}, Lossmore2: {} (average loss)\n".format(epoch, global_step, total_val_loss, N, total_lossmore1, total_lossmore2) | |
logger.info(result_string) | |
torch.cuda.empty_cache() | |
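    # Training loop: datasets[0] is the train set, the remaining entries are
    # eval sets. Dataloaders are prepared with Accelerate, and a LinearLR
    # warmup followed by a LinearLR decay (via SequentialLR) is stepped once
    # per optimizer step. Every `save_step` steps a checkpoint, example
    # spectrograms, and an evaluation pass are produced.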
def train(self, datasets, epochs, batch_size, num_workers=12, save_step=1000): | |
params_d = {} | |
trainable_d = {} | |
for n, p in self.model.named_parameters(): | |
key = ".".join(n.split(".")[:2]) | |
if key not in params_d: | |
params_d[key] = 0 | |
trainable_d[key] = p.requires_grad | |
params_d[key] += p.numel() | |
assert(trainable_d[key] == p.requires_grad) | |
print(params_d) | |
print(trainable_d) | |
num_trainable_parameters = sum(p.numel() for p in self.model.parameters() if p.requires_grad) | |
print("Num trainable parameters: {}".format(num_trainable_parameters)) | |
train_dataset = datasets[0] | |
eval_datasets = datasets[1:] | |
#train_dataloader = DataLoader(train_dataset, batch_size=batch_size, collate_fn=collate_fn, shuffle=True, num_workers=num_workers, pin_memory=True) | |
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size*train_dataset.multi, collate_fn=train_dataset.collate_fn, num_workers=num_workers, drop_last=True, pin_memory=True) | |
eval_dataloaders = [DataLoader(eval_dataset, shuffle=False, batch_size=16, collate_fn=eval_dataset.collate_fn, num_workers=num_workers, drop_last=False, pin_memory=True) if eval_dataset is not None else None for eval_dataset in eval_datasets] | |
total_steps = len(train_dataloader) * epochs | |
decay_steps = total_steps - self.num_warmup_steps | |
warmup_scheduler = LinearLR(self.optimizer, start_factor=1e-8, end_factor=1.0, total_iters=self.num_warmup_steps) | |
decay_scheduler = LinearLR(self.optimizer, start_factor=1.0, end_factor=1e-8, total_iters=decay_steps) | |
self.scheduler = SequentialLR(self.optimizer, | |
schedulers=[warmup_scheduler, decay_scheduler], | |
milestones=[self.num_warmup_steps]) | |
train_dataloader, self.scheduler = self.accelerator.prepare(train_dataloader, self.scheduler) | |
eval_dataloaders = [self.accelerator.prepare(eval_dataloader) for eval_dataloader in eval_dataloaders if eval_dataloader is not None] | |
start_step = self.load_checkpoint() | |
global_step = start_step | |
for epoch in range(epochs): | |
if epoch == 0: | |
[self.evaluate(eval_dataloader, 1, epochs, 0) for eval_dataloader in eval_dataloaders] | |
self.model.train() | |
progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{epochs}", unit="step", disable=not self.accelerator.is_local_main_process) | |
epoch_loss = 0.0 | |
for batch in progress_bar: | |
with self.accelerator.accumulate(self.model): | |
#text_inputs = batch['text'] | |
#mel_spec = rearrange(batch['mel'], 'b d n -> b n d') | |
#mel_lengths = batch["mel_lengths"] | |
text, mel_spec, video_paths, mel_lengths, video_drop_prompt, audio_drop_prompt = batch | |
#print("batchsize", len(text)) | |
#print("batch", text, mel_spec.shape, video_paths, mel_lengths) | |
if exists(self.duration_predictor): | |
dur_loss = self.duration_predictor(mel_spec, lens=batch.get('durations')) | |
self.writer.add_scalar('duration loss', dur_loss.detach().cpu().item(), global_step) | |
velocity_consistency_model = None | |
#if self.need_velocity_consistent_loss and self.ema_model.initted: | |
# velocity_consistency_model = self.accelerator.unwrap_model(self.ema_model).ema_model | |
loss, cond, pred, pred_data, lossmore = self.model( | |
mel_spec, | |
text=(text if self.if_text else None), | |
lens=mel_lengths, | |
velocity_consistency_model=velocity_consistency_model, | |
prompt=(text if self.if_prompt else None), | |
video_drop_prompt=video_drop_prompt, | |
audio_drop_prompt=audio_drop_prompt, | |
video_paths=video_paths | |
) | |
self.accelerator.backward(loss) | |
if self.max_grad_norm > 0 and self.accelerator.sync_gradients: | |
self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm) | |
self.optimizer.step() | |
self.scheduler.step() | |
self.optimizer.zero_grad() | |
#self.accelerator.unwrap_model(self.ema_model).update() | |
if self.accelerator.is_local_main_process: | |
logger.info(f"step {global_step+1}: loss = {loss.detach().cpu().item():.4f}") | |
self.writer.add_scalar('loss', loss.detach().cpu().item(), global_step) | |
self.writer.add_scalar("lr", self.scheduler.get_last_lr()[0], global_step) | |
global_step += 1 | |
epoch_loss += loss.detach().cpu().item() | |
progress_bar.set_postfix(loss=loss.detach().cpu().item()) | |
if global_step % save_step == 0: | |
self.save_checkpoint(global_step) | |
self.writer.add_figure("mel/target", plot_spectrogram(mel_spec[0,:,:]), global_step) | |
self.writer.add_figure("mel/mask", plot_spectrogram(cond[0,:,:]), global_step) | |
self.writer.add_figure("mel/prediction", plot_spectrogram(pred_data[0,:,:]), global_step) | |
[self.evaluate(eval_dataloader, epoch+1, epochs, global_step) for eval_dataloader in eval_dataloaders] | |
#if global_step % 10 == 0: | |
# torch.cuda.empty_cache() | |
epoch_loss /= len(train_dataloader) | |
if self.accelerator.is_local_main_process: | |
logger.info(f"epoch {epoch+1}/{epochs} - average loss = {epoch_loss:.4f}") | |
self.writer.add_scalar('epoch average loss', epoch_loss, epoch) | |
#if self.use_switch_ema: | |
# self.ema_model.update_model_with_ema() | |
self.writer.close() | |
import json | |
import random | |
import pandas as pd | |
from e2_tts_pytorch import torch_tools | |
DURATION = torch_tools.total_length | |
#DURATION = 3000 | |
#beta = 1.5960 | |
#theta = 0.3259 | |
cand = 99999999 | |
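# text to audio dataset
# Text2AudioDataset aggregates captioned audio from several corpora
# (WavCaps AudioSet-SL/BBC/FreeSound/SoundBible, AudioCaps, AudioSet,
# TangoPromptBank, MusicCaps, 2-channel AudioSet-SL, sound-effect libraries,
# and VGGSound videos). A corpus is skipped entirely when its
# SCORE_THRESHOLD_TRAIN entry is >= 9000.0, and AudioCaps test-set utterances
# are filtered out of the training lists.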
class Text2AudioDataset(Dataset): | |
def __init__(self, dataset, part, prefix, text_column, audio_column, num_examples=-1, samples=-1, stft=None, augment=-1, main_process=True, SCORE_THRESHOLD_TRAIN="", train_file="", theta=0.0, vggsound=0, video_drop_prompt=None, audio_drop_prompt=None, device_id=0, vgg_test=None, video_encoder="clip_vit", val_length=None, num_processes=8): | |
#inputs = list(dataset[text_column]) | |
#self.inputs = [prefix + inp for inp in inputs] | |
#self.audios = list(dataset[audio_column]) | |
#self.indices = list(range(len(self.inputs))) | |
# | |
#print("audios", len(self.audios)) | |
#self.new_audios = [] | |
#for index, audio in enumerate(self.audios): | |
# utt, fmt = audio.split(".") | |
# new_audio = "/zhanghaomin/datas/audioset_sl/mnt/fast/nobackup/scratch4weeks/xm00178/WavCaps/data/waveforms/AudioSet_SL_flac/" + utt + ".flac" | |
# #if os.path.exists(new_audio): | |
# self.new_audios.append(new_audio) | |
#self.audios = self.new_audios | |
#N = len(self.audios) | |
#print("audios", len(self.new_audios)) | |
test_final = "/ailab-train/speech/zhanghaomin/scps/tango-master/data/test_audiocaps_subset.json" | |
test_utts = {} | |
with open(test_final, "r") as fr: | |
for line in fr.readlines(): | |
wav = json.loads(line.strip())["location"] | |
utt = wav.rsplit("/", 1)[-1].rsplit("_", 1)[0] | |
utt = "Y"+utt | |
assert(utt not in test_utts) | |
test_utts[utt] = 1 | |
main_process and print("test_final", len(test_utts)) | |
bbc_soundeffects_utts = {} | |
freesound_utts = {} | |
audioset_filter_labels = {"Music": 0, "Speech": 0, "Vehicle": 0, "Musical instrument": 0} | |
self.inputs = [] | |
self.audios = [] | |
self.indices = [] | |
N = 0 | |
audiocaps = True | |
if SCORE_THRESHOLD_TRAIN["/zhanghaomin/datas/audiocaps"] >= 9000.0: | |
audiocaps = False | |
audioset_sl = True | |
bbc_soundeffects = True | |
freesound = True | |
soundbible = True | |
if SCORE_THRESHOLD_TRAIN["/radiostorage/WavCaps"] >= 9000.0: | |
audioset_sl = False | |
bbc_soundeffects = False | |
freesound = False | |
soundbible = False | |
soundeffects = True | |
if SCORE_THRESHOLD_TRAIN["/radiostorage/AudioGroup"] >= 9000.0: | |
soundeffects = False | |
self.soundeffects = soundeffects | |
audioset = True | |
if SCORE_THRESHOLD_TRAIN["/ckptstorage/zhanghaomin/audioset"] >= 9000.0: | |
audioset = False | |
bbc_soundeffects2 = True | |
if SCORE_THRESHOLD_TRAIN["/ckptstorage/zhanghaomin/BBCSoundEffects"] >= 9000.0: | |
bbc_soundeffects2 = False | |
freesound2 = True | |
if SCORE_THRESHOLD_TRAIN["/ckptstorage/zhanghaomin/CLAP_freesound"] >= 9000.0: | |
freesound2 = False | |
musiccaps = True | |
if SCORE_THRESHOLD_TRAIN["/zhanghaomin/datas/musiccap"] >= 9000.0: | |
musiccaps = False | |
tangopromptbank = True | |
if SCORE_THRESHOLD_TRAIN["/ckptstorage/zhanghaomin/TangoPromptBank"] >= 9000.0: | |
tangopromptbank = False | |
audioset_sl_2ch = True | |
if SCORE_THRESHOLD_TRAIN["/ckptstorage/zhanghaomin/audiosetsl"] >= 9000.0: | |
audioset_sl_2ch = False | |
self.audioset_sl_2ch = audioset_sl_2ch | |
boom_epic = True | |
if SCORE_THRESHOLD_TRAIN["/ckptstorage/zhanghaomin/giantsoundeffects"] >= 9000.0: | |
boom_epic = False | |
self.boom_epic = boom_epic | |
if isinstance(part, list): | |
part, scp_ac, start_ac, end_ac = part | |
assert(part == "val_audiocaps") | |
else: | |
scp_ac = None | |
if (audioset_sl and part in ["train", "train_val_audioset_sl"]) or (part == "val_audioset_sl"): | |
self.audioset_sl_inputs = [] | |
self.audioset_sl_audios = [] | |
self.audioset_sl_indices = [] | |
audioset_sl_path_train = "/zhanghaomin/codes2/tango-master/data/train_audioset_sl.json" | |
audioset_sl_path_val = "/zhanghaomin/codes2/tango-master/data/val_audioset_sl.json" | |
audioset_sl_path_train_val = "/ailab-train/speech/zhanghaomin/scps/tango-master/data/train_val_audioset_sl.json" | |
if part == "train": | |
audioset_sl_path = audioset_sl_path_train | |
elif part == "train_val_audioset_sl": | |
audioset_sl_path = audioset_sl_path_train_val | |
else: | |
audioset_sl_path = audioset_sl_path_val | |
FN = 0 | |
with open(audioset_sl_path, "r") as fr: | |
for index, line in enumerate(fr.readlines()): | |
jsondata = json.loads(line.strip()) | |
utt = jsondata["id"].rsplit(".", 1)[0] | |
if part in ["train", "train_val_audioset_sl"] and utt in test_utts: | |
FN += 1 | |
continue | |
caption = jsondata["caption"] | |
audio = "/radiostorage/WavCaps/Zip_files/AudioSet_SL/mnt/fast/nobackup/scratch4weeks/xm00178/WavCaps/data/waveforms/AudioSet_SL_flac/" + utt + ".flac" | |
self.audioset_sl_inputs.append(caption) | |
self.audioset_sl_audios.append(audio) | |
self.audioset_sl_indices.append(N + index) | |
main_process and print(part, "audioset_sl", len(self.audioset_sl_audios), "filtered", FN) | |
self.inputs.extend(self.audioset_sl_inputs) | |
self.audios.extend(self.audioset_sl_audios) | |
self.indices.extend(self.audioset_sl_indices) | |
N = len(self.audios) | |
main_process and print(part, "audioset_sl audios", len(self.audios)) | |
if (audiocaps and part in ["train", "train_val_audioset_sl"]) or (part == "val_audiocaps"): | |
self.audiocaps_inputs = [] | |
self.audiocaps_audios = [] | |
self.audiocaps_indices = [] | |
audiocaps_path_train = "/ailab-train/speech/zhanghaomin/scps/tango-master/data/audiocaps/train_audiocaps.json" | |
audiocaps_path_val = "/ailab-train/speech/zhanghaomin/scps/tango-master/data/audiocaps/test_audiocaps.json" | |
if scp_ac is not None: | |
audiocaps_path_val = scp_ac | |
if part in ["train", "train_val_audioset_sl"]: | |
audiocaps_path = audiocaps_path_train | |
else: | |
audiocaps_path = audiocaps_path_val | |
FN = 0 | |
with open(audiocaps_path, "r") as fr: | |
lines = fr.readlines() | |
if scp_ac is not None: | |
lines = lines[start_ac: end_ac] | |
for index, line in enumerate(lines): | |
jsondata = json.loads(line.strip()) | |
utt = jsondata["id"] | |
if part in ["train", "train_val_audioset_sl"] and utt in test_utts: | |
FN += 1 | |
continue | |
caption = jsondata["caption"] | |
audio = jsondata["audio"] | |
self.audiocaps_inputs.append(caption) | |
self.audiocaps_audios.append(audio) | |
self.audiocaps_indices.append(N + index) | |
main_process and print(part, "audiocaps", len(self.audiocaps_audios), "filtered", FN) | |
self.inputs.extend(self.audiocaps_inputs) | |
self.audios.extend(self.audiocaps_audios) | |
self.indices.extend(self.audiocaps_indices) | |
N = len(self.audios) | |
main_process and print(part, "audiocaps audios", len(self.audios)) | |
if bbc_soundeffects and part in ["train", "train_val_audioset_sl"]: | |
self.bbc_soundeffects_inputs = [] | |
self.bbc_soundeffects_audios = [] | |
self.bbc_soundeffects_indices = [] | |
with open("/ailab-train/speech/zhanghaomin/scps/tango-master/data/train_bbc_sound_effects.json", "r") as fr: | |
for index, line in enumerate(fr.readlines()): | |
jsondata = json.loads(line.strip()) | |
utt = jsondata["id"] | |
bbc_soundeffects_utts[utt] = 1 | |
caption = jsondata["caption"] | |
audio = "/radiostorage/WavCaps/Zip_files/BBC_Sound_Effects/mnt/fast/nobackup/scratch4weeks/xm00178/WavCaps/data/waveforms/BBC_Sound_Effects_flac/" + utt + ".flac" | |
self.bbc_soundeffects_inputs.append(caption) | |
self.bbc_soundeffects_audios.append(audio) | |
self.bbc_soundeffects_indices.append(N + index) | |
main_process and print(part, "bbc_soundeffects", len(self.bbc_soundeffects_audios)) | |
self.inputs.extend(self.bbc_soundeffects_inputs) | |
self.audios.extend(self.bbc_soundeffects_audios) | |
self.indices.extend(self.bbc_soundeffects_indices) | |
N = len(self.audios) | |
main_process and print(part, "bbc_soundeffects audios", len(self.audios)) | |
if freesound and part in ["train", "train_val_audioset_sl"]: | |
self.freesound_inputs = [] | |
self.freesound_audios = [] | |
self.freesound_indices = [] | |
with open("/ailab-train/speech/zhanghaomin/scps/tango-master/data/train_freesound.json", "r") as fr: | |
for index, line in enumerate(fr.readlines()): | |
jsondata = json.loads(line.strip()) | |
utt = jsondata["id"] | |
freesound_utts[utt] = 1 | |
caption = jsondata["caption"] | |
audio = "/radiostorage/WavCaps/Zip_files/FreeSound/mnt/fast/nobackup/scratch4weeks/xm00178/WavCaps/data/waveforms/FreeSound_flac/" + utt + ".flac" | |
self.freesound_inputs.append(caption) | |
self.freesound_audios.append(audio) | |
self.freesound_indices.append(N + index) | |
main_process and print(part, "freesound", len(self.freesound_audios)) | |
self.inputs.extend(self.freesound_inputs) | |
self.audios.extend(self.freesound_audios) | |
self.indices.extend(self.freesound_indices) | |
N = len(self.audios) | |
main_process and print(part, "freesound audios", len(self.audios)) | |
if soundbible and part in ["train", "train_val_audioset_sl"]: | |
self.soundbible_inputs = [] | |
self.soundbible_audios = [] | |
self.soundbible_indices = [] | |
with open("/ailab-train/speech/zhanghaomin/scps/tango-master/data/train_soundbible.json", "r") as fr: | |
for index, line in enumerate(fr.readlines()): | |
jsondata = json.loads(line.strip()) | |
utt = jsondata["id"] | |
caption = jsondata["caption"] | |
audio = "/radiostorage/WavCaps/Zip_files/SoundBible/mnt/fast/nobackup/scratch4weeks/xm00178/WavCaps/data/waveforms/SoundBible_flac/" + utt + ".flac" | |
self.soundbible_inputs.append(caption) | |
self.soundbible_audios.append(audio) | |
self.soundbible_indices.append(N + index) | |
main_process and print(part, "soundbible", len(self.soundbible_audios)) | |
self.inputs.extend(self.soundbible_inputs) | |
self.audios.extend(self.soundbible_audios) | |
self.indices.extend(self.soundbible_indices) | |
N = len(self.audios) | |
main_process and print(part, "soundbible audios", len(self.audios)) | |
if (soundeffects and part in ["train", "train_val_audioset_sl"]) or (part == "val_soundeffects"): | |
self.soundeffects_inputs = [] | |
self.soundeffects_audios = [] | |
self.soundeffects_indices = [] | |
#soundeffects_path_train = "/zhanghaomin/codes2/audiocaption/wav_all_train.scp" | |
#soundeffects_path_val = "/zhanghaomin/codes2/audiocaption/wav_all_val.scp" | |
#soundeffects_path_train = "/zhanghaomin/codes2/audiocaption/wav_msclap_all_train.scp" | |
soundeffects_path_train = train_file | |
soundeffects_path_val = "/zhanghaomin/codes2/audiocaption/wav_msclap_all_val.scp" | |
if part in ["train", "train_val_audioset_sl"]: | |
soundeffects_path = soundeffects_path_train | |
else: | |
soundeffects_path = soundeffects_path_val | |
with open(soundeffects_path, 'r') as fr: | |
for index, line in enumerate(fr.readlines()): | |
if soundeffects_path.endswith("msclapcap_v1.list"): | |
utt, wav, caption1, score = line.strip().split('"@$&#"') | |
caption2 = "blank" | |
name = "blank" | |
else: | |
utt, wav, name, caption1, caption2 = line.strip().split('"@$&#"') | |
period = int(utt.split('_')[-1]) | |
self.soundeffects_inputs.append((caption1, caption2, name)) | |
self.soundeffects_audios.append((wav, utt, period)) | |
self.soundeffects_indices.append(N + index) | |
main_process and print(part, "soundeffects", len(self.soundeffects_audios)) | |
self.inputs.extend(self.soundeffects_inputs) | |
self.audios.extend(self.soundeffects_audios) | |
self.indices.extend(self.soundeffects_indices) | |
N = len(self.audios) | |
main_process and print(part, "soundeffects audios", len(self.audios)) | |
if audioset and part in ["train", "train_val_audioset_sl"]: | |
self.audioset_inputs = [] | |
self.audioset_audios = [] | |
self.audioset_indices = [] | |
FN = 0 | |
FN2 = 0 | |
if SCORE_THRESHOLD_TRAIN["audioset"] == "af-audioset": | |
audioset_path = "/ailab-train/speech/zhanghaomin/scps/audioset/audioset_train_af.json" | |
else: | |
audioset_path = "/ckptstorage/zhanghaomin/audioset/audioset_train.json" | |
with open(audioset_path, "r") as fr: | |
for index, line in enumerate(fr.readlines()): | |
jsondata = json.loads(line.strip()) | |
if SCORE_THRESHOLD_TRAIN["audioset"] == "af-audioset": | |
utt = jsondata["id"] | |
if part in ["train", "train_val_audioset_sl"] and utt in test_utts: | |
FN += 1 | |
continue | |
caption = jsondata["caption"] | |
audio = jsondata["audio"] | |
else: | |
utt = jsondata["id"] | |
if part in ["train", "train_val_audioset_sl"] and utt in test_utts: | |
FN += 1 | |
continue | |
caption = jsondata["caption"] | |
#caption = caption.replace("@", ", ") | |
captions = caption.split("@") | |
captions_new = [] | |
for c in captions: | |
if c in audioset_filter_labels: | |
audioset_filter_labels[c] += 1 | |
else: | |
captions_new.append(c) | |
if len(captions_new) == 0: | |
FN2 += 1 | |
continue | |
caption = "".join(captions_new) | |
audio = jsondata["audio"] | |
self.audioset_inputs.append(caption) | |
self.audioset_audios.append(audio) | |
self.audioset_indices.append(N + index) | |
main_process and print(part, "audioset", len(self.audioset_audios), "filtered", FN, FN2, audioset_filter_labels) | |
self.inputs.extend(self.audioset_inputs) | |
self.audios.extend(self.audioset_audios) | |
self.indices.extend(self.audioset_indices) | |
N = len(self.audios) | |
main_process and print(part, "audioset audios", len(self.audios)) | |
if bbc_soundeffects2 and part in ["train", "train_val_audioset_sl"]: | |
self.bbc_soundeffects2_inputs = [] | |
self.bbc_soundeffects2_audios = [] | |
self.bbc_soundeffects2_indices = [] | |
FN = 0 | |
with open("/ckptstorage/zhanghaomin/BBCSoundEffects/bbcsoundeffects_train.json", "r") as fr: | |
for index, line in enumerate(fr.readlines()): | |
jsondata = json.loads(line.strip()) | |
utt = jsondata["id"] | |
if part in ["train", "train_val_audioset_sl"] and utt in bbc_soundeffects_utts: | |
FN += 1 | |
continue | |
caption = jsondata["caption"] | |
caption = caption.split("(")[0].strip() | |
audio = jsondata["audio"] | |
self.bbc_soundeffects2_inputs.append(caption) | |
self.bbc_soundeffects2_audios.append(audio) | |
self.bbc_soundeffects2_indices.append(N + index) | |
main_process and print(part, "bbc_soundeffects2", len(self.bbc_soundeffects2_audios), "filtered", FN) | |
self.inputs.extend(self.bbc_soundeffects2_inputs) | |
self.audios.extend(self.bbc_soundeffects2_audios) | |
self.indices.extend(self.bbc_soundeffects2_indices) | |
N = len(self.audios) | |
main_process and print(part, "bbc_soundeffects2 audios", len(self.audios)) | |
if freesound2 and part in ["train", "train_val_audioset_sl"]: | |
self.freesound2_inputs = [] | |
self.freesound2_audios = [] | |
self.freesound2_indices = [] | |
FN = 0 | |
with open("/ckptstorage/zhanghaomin/CLAP_freesound/freesound_train.json", "r") as fr: | |
for index, line in enumerate(fr.readlines()): | |
jsondata = json.loads(line.strip()) | |
utt = jsondata["id"] | |
if part in ["train", "train_val_audioset_sl"] and utt in freesound_utts: | |
FN += 1 | |
continue | |
caption = jsondata["caption"] | |
caption = caption.split('"@$&#"') | |
#caption = caption[0].split("(")[0].strip() | |
caption = tuple([c.split("(")[0].strip() for c in caption]) | |
audio = jsondata["audio"] | |
self.freesound2_inputs.append(caption) | |
self.freesound2_audios.append(audio) | |
self.freesound2_indices.append(N + index) | |
main_process and print(part, "freesound2", len(self.freesound2_audios), "filtered", FN) | |
self.inputs.extend(self.freesound2_inputs) | |
self.audios.extend(self.freesound2_audios) | |
self.indices.extend(self.freesound2_indices) | |
N = len(self.audios) | |
main_process and print(part, "freesound2 audios", len(self.audios)) | |
if tangopromptbank and part in ["train", "train_val_audioset_sl"]: | |
self.tangopromptbank_inputs = [] | |
self.tangopromptbank_audios = [] | |
self.tangopromptbank_indices = [] | |
with open("/ailab-train/speech/zhanghaomin/scps/TangoPromptBank/data.json", "r") as fr: | |
for index, line in enumerate(fr.readlines()): | |
jsondata = json.loads(line.strip()) | |
caption = jsondata["captions"] | |
audio = jsondata["location"] | |
self.tangopromptbank_inputs.append(caption) | |
self.tangopromptbank_audios.append(audio) | |
self.tangopromptbank_indices.append(N + index) | |
main_process and print(part, "tangopromptbank", len(self.tangopromptbank_audios)) | |
self.inputs.extend(self.tangopromptbank_inputs) | |
self.audios.extend(self.tangopromptbank_audios) | |
self.indices.extend(self.tangopromptbank_indices) | |
N = len(self.audios) | |
main_process and print(part, "tangopromptbank audios", len(self.audios)) | |
if musiccaps and part in ["train", "train_val_audioset_sl"]: | |
self.musiccaps_inputs = [] | |
self.musiccaps_audios = [] | |
self.musiccaps_indices = [] | |
with open("/ailab-train/speech/zhanghaomin/scps/musiccap/musiccaps.jsonl", "r") as fr: | |
for index, line in enumerate(fr.readlines()): | |
jsondata = json.loads(line.strip()) | |
caption = jsondata["caption"] | |
audio = jsondata["audio"] | |
self.musiccaps_inputs.append(caption) | |
self.musiccaps_audios.append(audio) | |
self.musiccaps_indices.append(N + index) | |
main_process and print(part, "musiccaps", len(self.musiccaps_audios)) | |
self.inputs.extend(self.musiccaps_inputs) | |
self.audios.extend(self.musiccaps_audios) | |
self.indices.extend(self.musiccaps_indices) | |
N = len(self.audios) | |
main_process and print(part, "musiccaps audios", len(self.audios)) | |
if (audioset_sl_2ch and part in ["train", "train_val_audioset_sl"]) or (part == "val_audioset_sl_2ch"): | |
self.audioset_sl_2ch_inputs = [] | |
self.audioset_sl_2ch_audios = [] | |
self.audioset_sl_2ch_indices = [] | |
audioset_sl_2ch_train = "/ckptstorage/zhanghaomin/audiosetsl/wavs/train.jsonl" | |
audioset_sl_2ch_val = "/ckptstorage/zhanghaomin/audiosetsl/wavs/test.jsonl" | |
if part in ["train", "train_val_audioset_sl"]: | |
audioset_sl_2ch_path = audioset_sl_2ch_train | |
else: | |
audioset_sl_2ch_path = audioset_sl_2ch_val | |
with open(audioset_sl_2ch_path, "r") as fr: | |
for index, line in enumerate(fr.readlines()): | |
jsondata = json.loads(line.strip()) | |
caption = jsondata["caption"] | |
audio = jsondata["audio"] | |
self.audioset_sl_2ch_inputs.append(caption) | |
self.audioset_sl_2ch_audios.append(audio) | |
self.audioset_sl_2ch_indices.append(N + index) | |
main_process and print(part, "audioset_sl_2ch", len(self.audioset_sl_2ch_audios)) | |
self.inputs.extend(self.audioset_sl_2ch_inputs) | |
self.audios.extend(self.audioset_sl_2ch_audios) | |
self.indices.extend(self.audioset_sl_2ch_indices) | |
N = len(self.audios) | |
main_process and print(part, "audioset_sl_2ch audios", len(self.audios)) | |
if (boom_epic and part in ["train", "train_val_audioset_sl"]) or (part == "val_boom_epic"): | |
self.boom_epic_inputs = [] | |
self.boom_epic_audios = [] | |
self.boom_epic_indices = [] | |
#boom_epic_train = "/ckptstorage/zhanghaomin/giantsoundeffects/train_animals_mixture2.jsonl" | |
#boom_epic_val = "/ckptstorage/zhanghaomin/giantsoundeffects/test_animals_mixture2.jsonl" | |
boom_epic_train = "/ailab-train/speech/zhanghaomin/scps/giantsoundeffects/train.jsonl" | |
boom_epic_val = "/ailab-train/speech/zhanghaomin/scps/giantsoundeffects/test.jsonl" | |
if part in ["train", "train_val_audioset_sl"]: | |
boom_epic_path = boom_epic_train | |
else: | |
boom_epic_path = boom_epic_val | |
with open(boom_epic_path, "r") as fr: | |
for index, line in enumerate(fr.readlines()): | |
jsondata = json.loads(line.strip()) | |
caption = jsondata["caption"] | |
audio = jsondata["audio"] | |
self.boom_epic_inputs.append(caption) | |
self.boom_epic_audios.append(audio) | |
self.boom_epic_indices.append(N + index) | |
main_process and print(part, "boom_epic", len(self.boom_epic_audios)) | |
repeats = 1 | |
for _ in range(repeats): | |
self.inputs.extend(self.boom_epic_inputs) | |
self.audios.extend(self.boom_epic_audios) | |
self.indices.extend(self.boom_epic_indices) | |
N = len(self.audios) | |
main_process and print(part, "boom_epic audios", len(self.audios)) | |
self.boom_epic = boom_epic | |
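        # When vggsound > 0, VGGSound video/text pairs are loaded: the train
        # split is kept in separate *_vggsound buffers (mixed into batches
        # later by collate_fn), while the val_vggsound split is appended
        # directly to the main inputs/audios/indices lists.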
if vggsound: | |
self.inputs_vggsound = [] | |
self.audios_vggsound = [] | |
self.indices_vggsound = [] | |
if part in ["train", "train_val_audioset_sl"]: | |
path = "/ailab-train/speech/zhanghaomin/scps/VGGSound/train.scp" | |
with open(path, "r") as fr: | |
for index, line in enumerate(fr.readlines()): | |
video_path, text = line.strip().split("\t") | |
self.inputs_vggsound.append("the sound of " + text.strip().replace("(", "").replace(")", "")) | |
self.audios_vggsound.append(video_path) | |
self.indices_vggsound.append(index) | |
N = len(self.audios_vggsound) | |
print(part, "vggsound train audios", len(self.audios_vggsound), device_id, num_processes) | |
elif part == "val_vggsound": | |
if vgg_test is not None: | |
path = vgg_test[0] | |
start = vgg_test[1] | |
end = vgg_test[2] | |
else: | |
path = "/ailab-train/speech/zhanghaomin/scps/VGGSound/test.scp" | |
start = 0 | |
end = 200 | |
with open(path, "r") as fr: | |
for index, line in enumerate(fr.readlines()[start:end]): | |
video_path, text = line.strip().split("\t") | |
self.inputs.append("the sound of " + text.strip().replace("(", "").replace(")", "")) | |
self.audios.append(video_path) | |
self.indices.append(N + index) | |
N = len(self.audios) | |
print(part, "vggsound eval audios", len(self.audios), device_id, num_processes) | |
self.vggsound = vggsound | |
self.video_drop_prompt = video_drop_prompt | |
self.audio_drop_prompt = audio_drop_prompt | |
self.device_id = device_id | |
self.num_processes = num_processes | |
self.bad_ids = {} | |
self.video_encoder = video_encoder | |
self.val_length = val_length if val_length is not None else torch_tools.MAX_TARGET_LEN | |
print("val_length", self.val_length) | |
#self.mapper = {} | |
#for index, audio, text in zip(self.indices, self.audios, self.inputs): | |
# self.mapper[index] = [audio, text] | |
if num_examples != -1: | |
self.inputs, self.audios = self.inputs[:num_examples], self.audios[:num_examples] | |
self.indices = self.indices[:num_examples] | |
self.samples = samples | |
self.stft = stft | |
self.target_length = DURATION | |
self.augment = augment | |
self.part = part | |
self.main_process = main_process | |
self.SCORE_THRESHOLD_TRAIN = SCORE_THRESHOLD_TRAIN | |
self.theta = theta | |
self.multi = 4 | |
def __len__(self): | |
return len(self.inputs) | |
def get_num_instances(self): | |
return len(self.inputs) | |
def __getitem__(self, index): | |
s1, s2, s3 = self.inputs[index], self.audios[index], self.indices[index] | |
return s1, s2, s3 | |
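    # Extract (and cache) a mono waveform from a video: reuse a previously
    # generated .wav next to the video when it exists, otherwise decode the
    # audio track with moviepy, resample it to torch_tools.new_freq, normalize
    # it, and save it for later epochs. Returns None if decoding fails.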
def read_audio_from_video(self, video_path): | |
if video_path.startswith("/ailab-train/speech/zhanghaomin/VGGSound/"): | |
audio_path = video_path.replace("/video/", "/audio/").replace(".mp4", ".wav") | |
else: | |
audio_path = video_path.replace(".mp4", ".generated.wav") | |
if os.path.exists(audio_path): | |
#print("video wav exist", audio_path) | |
waveform, sr = torchaudio.load(audio_path) | |
else: | |
#print("video wav not exist", video_path) | |
try: | |
clip = AudioFileClip(video_path) | |
sound_array = np.array(list(clip.iter_frames())) | |
waveform = torch.from_numpy(sound_array).transpose(0,1).to(torch.float32) | |
waveform = waveform[0:1, :] | |
if clip.fps != torch_tools.new_freq: | |
waveform = torchaudio.functional.resample(waveform, orig_freq=clip.fps, new_freq=torch_tools.new_freq) | |
waveform = torch_tools.normalize_wav(waveform) | |
torchaudio.save(audio_path, waveform, torch_tools.new_freq) | |
            except Exception:
print("Error read_audio_from_video", audio_path) | |
traceback.print_exc() | |
return None | |
return waveform | |
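    # Batch collation. For training batches the candidate samples are first
    # split into sound-effect and non-sound-effect sources (via
    # torch_tools.SOUNDEFFECT) and re-sampled so their ratio approaches
    # `theta`, then converted to mel spectrograms with torch_tools.wav_to_fbank
    # (retrying inside the while-loop until a usable batch is produced).
    # Mixed-audio augmentation and on-the-fly VGGSound video samples may be
    # appended afterwards.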
def collate_fn(self, data): | |
# 452463+1471396->452463+3430704->452463+2978587 more 452463+1037241+15973+310169 real 1183416+2000 | |
# theta (1183416)*0.5/(452463+1037241+15973+310169)=0.3259 | |
# beta (452463+1037241+15973+310169+3430704)/(452463+1037241+15973+310169+1471396)=1.5960 (452463+1037241+15973+310169+2978587)/(452463+1037241+15973+310169+1471396)=1.4585 | |
if self.part in ["train", "train_val_audioset_sl"]: | |
val = False | |
else: | |
val = True | |
if self.audioset_sl_2ch: | |
nch = 2 | |
else: | |
nch = 1 | |
while True: | |
if self.part in ["train", "train_val_audioset_sl"]: | |
#print("data raw", len(data), data[0]) | |
#data_sampled = random.sample(data, self.samples) | |
if (self.soundeffects or self.boom_epic) and self.theta > 0: | |
data_len = len(data) | |
data_1 = [] | |
data_2 = [] | |
for sample in data: | |
if isinstance(sample[1], tuple): | |
if sample[1][0].startswith("/radiostorage/"): | |
prefix = "/".join(sample[1][0].split("/")[:3]) | |
else: | |
prefix = "/".join(sample[1][0].split("/")[:4]) | |
else: | |
if sample[1].startswith("/radiostorage/"): | |
prefix = "/".join(sample[1].split("/")[:3]) | |
else: | |
prefix = "/".join(sample[1].split("/")[:4]) | |
if torch_tools.SOUNDEFFECT[prefix]: | |
data_2.append(sample) | |
else: | |
data_1.append(sample) | |
#print("data splitted", len(data_1), len(data_2), float(len(data_1))/len(data_2)) | |
data_len_1 = len(data_1) | |
data_len_2 = len(data_2) | |
if data_len_1 == 0 or data_len_2 == 0: | |
data_1_sampled = data_1 | |
data_2_sampled = data_2 | |
else: | |
data_len_1_sampled = int(data_len_2 / self.theta) | |
data_len_2_sampled = int(data_len_1 * self.theta) | |
if data_len_1_sampled < data_len_1: | |
data_1_sampled = random.sample(data_1, data_len_1_sampled) | |
data_2_sampled = data_2 | |
else: | |
data_1_sampled = data_1 | |
data_2_sampled = random.sample(data_2, data_len_2_sampled) | |
#print("data sampled", len(data_1_sampled), len(data_2_sampled), float(len(data_1_sampled))/len(data_2_sampled), self.samples*cand) | |
data_sampled = data_1_sampled | |
data_sampled.extend(data_2_sampled) | |
data_sampled = random.sample(data_sampled, min(self.samples*cand, len(data_sampled))) | |
#print("data sampled", len(data_sampled)) | |
else: | |
data_sampled = random.sample(data, min(self.samples*cand, len(data))) | |
#print("data sampled", len(data_sampled)) | |
else: | |
data_sampled = data | |
dat = pd.DataFrame(data_sampled) | |
text, audios, indices = [dat[i].tolist() for i in dat] | |
if self.vggsound and val and self.part == "val_vggsound": | |
#print("vggsound val", len(audios), text) | |
fbanks = [] | |
fbank_lens = [] | |
video_paths = [] | |
text_selected = [] | |
for audio, txt in zip(audios, text): | |
waveform = self.read_audio_from_video(audio) | |
if waveform is None: | |
continue | |
length = self.val_length | |
waveform = waveform[:, :length*torch_tools.hop_size] | |
fbank = self.stft(waveform).transpose(-1,-2) | |
fbanks.append(fbank) | |
fbank_lens.append(fbank.shape[1]) | |
video_paths.append(audio) | |
text_selected.append(txt) | |
#print("stft", waveform.shape, fbank.shape) | |
max_length = max(fbank_lens) | |
for i in range(len(fbanks)): | |
if fbanks[i].shape[1] < max_length: | |
fbanks[i] = torch.cat([fbanks[i], torch.zeros(fbanks[i].shape[0], max_length-fbanks[i].shape[1], fbanks[i].shape[2])], 1) | |
mel = torch.cat(fbanks, 0) | |
mel_len = torch.Tensor(fbank_lens).to(torch.int32) | |
break | |
            # CLAP-based filtering is currently disabled for every split.
            if_clap_filter = False
mel, text_selected, _, _, _, mel_len = torch_tools.wav_to_fbank(audios, text, self.samples, self.target_length, self.stft, val, if_clap_filter, self.main_process, self.SCORE_THRESHOLD_TRAIN, nch) | |
if mel is not None: | |
if self.part in ["train", "train_val_audioset_sl"]: | |
if len(text_selected) > self.samples: | |
mel = mel[:self.samples,...] | |
text_selected = text_selected[:self.samples] | |
#waveform = waveform[:self.samples,...] | |
mel_len = mel_len[:self.samples] | |
if self.vggsound: | |
video_paths = [None] * len(text_selected) | |
else: | |
video_paths = None | |
#print("mel", mel.shape if mel is not None else None, len(text_selected) if text_selected is not None else 0, mel_len, video_paths) | |
break | |
#mel = mel.unsqueeze(1) | |
if self.augment != 0 and len(text_selected) > 1 and (not val): | |
aug_num = len(text_selected) if self.augment == -1 else self.augment | |
# the last batch of the training data may have only one instance | |
# we check the length here so that the augmentation function doesn't throw an error | |
mixed_mel, _, _, mixed_captions, _, mixed_mel_len = torch_tools.augment_wav_to_fbank(audios, text, aug_num, self.target_length, self.stft, self.main_process, self.SCORE_THRESHOLD_TRAIN, nch) | |
#print("mixed_mel", mixed_mel.shape if mixed_mel is not None else None, len(mixed_captions) if mixed_captions is not None else 0, mixed_mel_len) | |
if mixed_mel is not None: | |
if mel.shape[1] < mixed_mel.shape[1]: | |
mel = torch.cat([mel, torch.zeros(mel.shape[0], mixed_mel.shape[1]-mel.shape[1], mel.shape[2])], 1) | |
elif mixed_mel.shape[1] < mel.shape[1]: | |
mixed_mel = torch.cat([mixed_mel, torch.zeros(mixed_mel.shape[0], mel.shape[1]-mixed_mel.shape[1], mixed_mel.shape[2])], 1) | |
#mixed_mel = mixed_mel.unsqueeze(1) | |
mel = torch.cat([mel, mixed_mel], 0) | |
text_selected += mixed_captions | |
mel_len = torch.cat([mel_len, mixed_mel_len], 0) | |
if self.vggsound: | |
video_paths.extend([None] * len(mixed_captions)) | |
else: | |
video_paths = None | |
#print("mel_final", mel.shape if mel is not None else None, len(text_selected) if text_selected is not None else 0, mel_len) | |
if self.vggsound and (not val): | |
video_paths = [None] * len(text_selected) | |
fbanks = [] | |
fbank_lens = [] | |
audios = [] | |
video_captions = [] | |
indices = random.sample([self.indices_vggsound[i] for i in range(self.device_id, len(self.indices_vggsound), self.num_processes)], self.vggsound*10) | |
indices_featured = [] | |
indices_nonfeatured = [] | |
for i in indices: | |
if i in self.bad_ids: | |
continue | |
if self.audios_vggsound[i].startswith("/ailab-train/speech/zhanghaomin/VGGSound/"): | |
if self.video_encoder == "clip_vit": | |
feature_path = self.audios_vggsound[i].replace("/video/", "/feature/").replace(".mp4", ".npz") | |
elif self.video_encoder == "clip_vit2": | |
feature_path = self.audios_vggsound[i].replace("/video/", "/feature_clip_vit2/").replace(".mp4", ".npz") | |
elif self.video_encoder == "clip_convnext": | |
feature_path = self.audios_vggsound[i].replace("/video/", "/feature_clip_convnext/").replace(".mp4", ".npz") | |
elif self.video_encoder == "dinov2": | |
feature_path = self.audios_vggsound[i].replace("/video/", "/feature_dinov2/").replace(".mp4", ".npz") | |
elif self.video_encoder == "mixed": | |
feature_path = self.audios_vggsound[i].replace("/video/", "/feature_mixed/").replace(".mp4", ".npz") | |
else: | |
raise Exception("Invalid video_encoder " + self.video_encoder) | |
else: | |
if self.video_encoder == "clip_vit": | |
feature_path = self.audios_vggsound[i].replace(".mp4", ".generated.npz") | |
elif self.video_encoder == "clip_vit2": | |
feature_path = self.audios_vggsound[i].replace(".mp4", ".generated.clip_vit2.npz") | |
elif self.video_encoder == "clip_convnext": | |
feature_path = self.audios_vggsound[i].replace(".mp4", ".generated.clip_convnext.npz") | |
elif self.video_encoder == "dinov2": | |
feature_path = self.audios_vggsound[i].replace(".mp4", ".generated.dinov2.npz") | |
elif self.video_encoder == "mixed": | |
feature_path = self.audios_vggsound[i].replace(".mp4", ".generated.mixed.npz") | |
else: | |
raise Exception("Invalid video_encoder " + self.video_encoder) | |
if os.path.exists(feature_path): | |
indices_featured.append(i) | |
else: | |
indices_nonfeatured.append(i) | |
if len(indices_nonfeatured) >= self.vggsound: | |
break | |
#print(self.device_id, self.bad_ids, indices, indices_featured, indices_nonfeatured) | |
indices = indices_nonfeatured[:self.vggsound] | |
if len(indices) < self.vggsound: | |
indices.extend(indices_featured[:self.vggsound-len(indices)]) | |
for i in indices: | |
waveform = self.read_audio_from_video(self.audios_vggsound[i]) | |
if waveform is None: | |
print("Error audio in video", i, self.audios_vggsound[i], self.bad_ids) | |
self.bad_ids[i] = 1 | |
continue | |
length = random.randint(torch_tools.MIN_TARGET_LEN, torch_tools.MAX_TARGET_LEN) | |
waveform = waveform[:, :length*torch_tools.hop_size] | |
fbank = self.stft(waveform).transpose(-1,-2) | |
fbanks.append(fbank) | |
fbank_lens.append(fbank.shape[1]) | |
audios.append(self.audios_vggsound[i]) | |
video_captions.append(self.inputs_vggsound[i]) | |
#print("stft", waveform.shape, fbank.shape) | |
max_length = max(fbank_lens) | |
for i in range(len(fbanks)): | |
if fbanks[i].shape[1] < max_length: | |
fbanks[i] = torch.cat([fbanks[i], torch.zeros(fbanks[i].shape[0], max_length-fbanks[i].shape[1], fbanks[i].shape[2])], 1) | |
video_mel = torch.cat(fbanks, 0) | |
video_mel_len = torch.Tensor(fbank_lens).to(torch.int32) | |
#print("video_mel", video_mel.shape if video_mel is not None else None, len(video_captions) if video_captions is not None else 0, video_mel_len) | |
if video_mel is not None: | |
if mel.shape[1] < video_mel.shape[1]: | |
mel = torch.cat([mel, torch.zeros(mel.shape[0], video_mel.shape[1]-mel.shape[1], mel.shape[2])], 1) | |
elif video_mel.shape[1] < mel.shape[1]: | |
video_mel = torch.cat([video_mel, torch.zeros(video_mel.shape[0], mel.shape[1]-video_mel.shape[1], video_mel.shape[2])], 1) | |
#video_mel = video_mel.unsqueeze(1) | |
mel = torch.cat([mel, video_mel], 0) | |
text_selected += video_captions | |
mel_len = torch.cat([mel_len, video_mel_len], 0) | |
video_paths.extend(audios) | |
#print("mel_final", mel.shape if mel is not None else None, len(text_selected) if text_selected is not None else 0, mel_len, video_paths) | |
return [text_selected, mel, video_paths, mel_len, self.video_drop_prompt, self.audio_drop_prompt] | |
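# text to speech dataset
# Text2SpeechDataset reads (text, wav) pairs from train/test json scps,
# resamples audio to 24 kHz, and collates mel spectrograms; utterances shorter
# than 1 s or longer than 20 s (2 s / 15 s at validation time) are dropped in
# collate_fn.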
class Text2SpeechDataset(Dataset): | |
def __init__(self, samples=-1, stft=None, val=False): | |
self.inputs = [] | |
self.audios = [] | |
self.indices = [] | |
train_scp = "/ckptstorage/zhanghaomin/docker/ximalaya/ximalaya_process/data_scp/train.json" | |
test_scp = "/ckptstorage/zhanghaomin/docker/ximalaya/ximalaya_process/data_scp/test.json" | |
scp = train_scp if not val else test_scp | |
index = 0 | |
with open(scp, "r") as fr: | |
for line in fr.readlines(): | |
data = json.loads(line.strip()) | |
wav = data["wav"] | |
text = data["text"] | |
if len(text) < 2: | |
continue | |
self.inputs.append(text) | |
self.audios.append(wav) | |
self.indices.append(index) | |
index += 1 | |
print("data size", len(self.inputs), val) | |
self.samples = samples | |
self.stft = stft | |
self.sample_rate = 24000 | |
self.multi = 8 | |
self.val = val | |
def __len__(self): | |
return len(self.inputs) | |
def get_num_instances(self): | |
return len(self.inputs) | |
def __getitem__(self, index): | |
s1, s2, s3 = self.inputs[index], self.audios[index], self.indices[index] | |
return s1, s2, s3 | |
def collate_fn(self, data): | |
dat = pd.DataFrame(data) | |
texts, audios, indices = [dat[i].tolist() for i in dat] | |
fbanks = [] | |
fbank_lens = [] | |
text_selected = [] | |
for text, audio in zip(texts, audios): | |
waveform, sr = torchaudio.load(audio) | |
waveform = waveform[0:1, :] | |
if sr != self.sample_rate: | |
waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=self.sample_rate) | |
waveform = torch_tools.normalize_wav(waveform) | |
            # Filter by duration first so the STFT is only computed for clips that are kept.
            duration = waveform.shape[1] / float(self.sample_rate)
            if self.val:
                if duration < 2.0 or duration > 15.0:
                    continue
            else:
                if duration < 1.0 or duration > 20.0:
                    continue
            fbank = self.stft(waveform).transpose(-1,-2)
            #print("stft", waveform.shape, fbank.shape)
fbanks.append(fbank) | |
fbank_lens.append(fbank.shape[1]) | |
text_selected.append(text) | |
if self.samples > 0 and len(text_selected) >= self.samples: | |
break | |
if self.samples > 0 and len(text_selected) > self.samples: | |
fbanks = fbanks[:self.samples] | |
fbank_lens = fbank_lens[:self.samples] | |
text_selected = text_selected[:self.samples] | |
max_length = max(fbank_lens) | |
for i in range(len(fbanks)): | |
if fbanks[i].shape[1] < max_length: | |
fbanks[i] = torch.cat([fbanks[i], torch.zeros(fbanks[i].shape[0], max_length-fbanks[i].shape[1], fbanks[i].shape[2])], 1) | |
mel = torch.cat(fbanks, 0) | |
mel_len = torch.Tensor(fbank_lens).to(torch.int32) | |
        return [text_selected, mel, None, mel_len, None, None]
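# Illustrative usage sketch, kept as comments. The E2TTS construction, the
# STFT choice, and the hyperparameters below are assumptions rather than
# values taken from this file; the original pipeline likely builds its STFT
# via torch_tools.
#
#   from torch.optim import AdamW
#   model = E2TTS(...)  # configure as required by e2_tts_crossatt6
#   optimizer = AdamW(model.parameters(), lr=1e-4)
#   stft = MelSpec(sampling_rate=24_000)
#   train_set = Text2SpeechDataset(samples=8, stft=stft, val=False)
#   val_set = Text2SpeechDataset(samples=8, stft=stft, val=True)
#   trainer = E2Trainer(model, optimizer, tensorboard_log_dir="runs/exp1", if_text=True)
#   trainer.train([train_set, val_set], epochs=10, batch_size=1)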