# Video-to-audio (V2A) inference script: samples audio for video clips with an
# E2-TTS-style model and muxes the result back into the source videos.
import sys

# Command-line usage:
#   script.py CKPT DROP_PROMPT TEST_SCP START END OUT_DIR
# OUT_DIR is sys.argv[6], so seven argv entries (program name included) are
# required. The original guard tested `len(sys.argv) >= 6`, which let a call
# with exactly six entries crash on `sys.argv[6]` with IndexError.
if len(sys.argv) >= 7:
    ckpt = sys.argv[1]
    drop_prompt = bool(int(sys.argv[2]))  # "0" -> False, any nonzero -> True
    test_scp = sys.argv[3]
    start = int(sys.argv[4])  # first clip index (inclusive)
    end = int(sys.argv[5])    # last clip index (exclusive, per vgg_test usage)
    step = 1
    out_dir = sys.argv[6]
    print("inference", ckpt, drop_prompt, test_scp, start, end, out_dir)
else:
    # Hard-coded defaults for interactive / no-argument runs.
    ckpt = "./ckpts/piano5_4_2_8000.pt"
    drop_prompt = False
    test_scp = "/ailab-train/speech/zhanghaomin/scps/VGGSound/test.scp"
    start = 0
    end = 2
    step = 1
    out_dir = "./outputs_vgg/"
import torch
from e2_tts_pytorch.e2_tts_crossatt3 import E2TTS, DurationPredictor
from e2_tts_pytorch.e2_tts_crossatt3 import MelSpec, EncodecWrapper
from torch.optim import Adam
from torch.utils.data import DataLoader, Dataset
from datasets import load_dataset
from e2_tts_pytorch.trainer_multigpus_alldatas3 import HFDataset, Text2AudioDataset
from einops import einsum, rearrange, repeat, reduce, pack, unpack
import torchaudio
from datetime import datetime
import json
import numpy as np
import os
from moviepy.editor import VideoFileClip, AudioFileClip
import traceback
# Classifier-free-guidance drop probabilities passed to E2TTS below.
# NOTE(review): out-of-range sentinels appear deliberate — a value > 1.0
# presumably means "always drop" and a value < 0.0 "never drop" when compared
# against a uniform random draw; confirm against the E2TTS implementation.
# They also gate module construction: see the `< 1.0` comparisons in main().
audiocond_drop_prob = 1.1
#audiocond_drop_prob = 0.3
#cond_proj_in_bias = True
#cond_drop_prob = 1.1
cond_drop_prob = -0.1
prompt_drop_prob = -0.1
#prompt_drop_prob = 1.1
# When True, the conditioning text branch is replaced by video (CLIP) features.
video_text = True
def main():
    """Run V2A inference over the configured test scp.

    Builds the E2TTS model, loads the module-level checkpoint ``ckpt``,
    iterates the evaluation dataloader, samples audio per clip, and muxes the
    generated audio with a copy of the source video into ``*.v2a.mp4``.

    Reads module-level globals: ckpt, drop_prompt, test_scp, start, end, step,
    out_dir, audiocond_drop_prob, cond_drop_prob, prompt_drop_prob, video_text.
    Requires a CUDA device; writes files under ``out_dir``.
    """
    #duration_predictor = DurationPredictor(
    #    transformer = dict(
    #        dim = 512,
    #        depth = 6,
    #    )
    #)
    duration_predictor = None
    # Optional sub-modules are only instantiated when their drop probability
    # is < 1.0 (a probability >= 1.0 makes the module useless, so skip it).
    e2tts = E2TTS(
        duration_predictor = duration_predictor,
        transformer = dict(
            #depth = 12,
            #dim = 512,
            #heads = 8,
            #dim_head = 64,
            depth = 12,
            dim = 1024,
            dim_text = 1280,
            heads = 16,
            dim_head = 64,
            if_text_modules = (cond_drop_prob < 1.0),
            if_cross_attn = (prompt_drop_prob < 1.0),
            if_audio_conv = True,
            if_text_conv = True,
        ),
        #tokenizer = 'char_utf8',
        tokenizer = 'phoneme_zh',
        audiocond_drop_prob = audiocond_drop_prob,
        cond_drop_prob = cond_drop_prob,
        prompt_drop_prob = prompt_drop_prob,
        frac_lengths_mask = (0.7, 1.0),
        #audiocond_snr = None,
        #audiocond_snr = (5.0, 10.0),
        if_cond_proj_in = (audiocond_drop_prob < 1.0),
        #cond_proj_in_bias = cond_proj_in_bias,
        # Text embedding is only needed when text conditioning is active AND
        # the text branch has not been replaced by video features.
        if_embed_text = (cond_drop_prob < 1.0) and (not video_text),
        if_text_encoder2 = (prompt_drop_prob < 1.0),
        if_clip_encoder = video_text,
        video_encoder = "clip_vit",
        pretrained_vocos_path = 'facebook/encodec_24khz',
        num_channels = 128,
        sampling_rate = 24000,
    )
    e2tts = e2tts.to("cuda")
    # Load weights on CPU first; the model itself already lives on CUDA.
    checkpoint = torch.load(ckpt, map_location="cpu")
    #for key in list(checkpoint['model_state_dict'].keys()):
    #    if key.startswith('mel_spec.'):
    #        del checkpoint['model_state_dict'][key]
    #    if key.startswith('transformer.text_registers'):
    #        del checkpoint['model_state_dict'][key]
    # strict=False: tolerate checkpoints with missing/extra keys (see the
    # commented-out key filtering above).
    e2tts.load_state_dict(checkpoint['model_state_dict'], strict=False)
    # Replace the vocoder with a frozen Encodec wrapper in eval mode.
    e2tts.vocos = EncodecWrapper("facebook/encodec_24khz")
    for param in e2tts.vocos.parameters():
        param.requires_grad = False
    e2tts.vocos.eval()
    e2tts.vocos.to("cuda")
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)
    # Optional filter: when `bs` is a collection, only those batch indices run.
    #bs = list(range(10)) + [14,16]
    bs = None
    # Per-dataset score thresholds (JSON string -> dict). Entries <= -9000.0
    # are treated as "no threshold" and mapped to -inf below; the "audioset"
    # entry holds a dataset name, not a number, and is skipped.
    SCORE_THRESHOLD_TRAIN = '{"/zhanghaomin/datas/audiocaps": -9999.0, "/radiostorage/WavCaps": -9999.0, "/radiostorage/AudioGroup": 9999.0, "/ckptstorage/zhanghaomin/audioset": -9999.0, "/ckptstorage/zhanghaomin/BBCSoundEffects": 9999.0, "/ckptstorage/zhanghaomin/CLAP_freesound": 9999.0, "/zhanghaomin/datas/musiccap": -9999.0, "/ckptstorage/zhanghaomin/TangoPromptBank": -9999.0, "audioset": "af-audioset", "/ckptstorage/zhanghaomin/audiosetsl": 9999.0, "/ckptstorage/zhanghaomin/giantsoundeffects": -9999.0}' # /root/datasets/ /radiostorage/
    SCORE_THRESHOLD_TRAIN = json.loads(SCORE_THRESHOLD_TRAIN)
    for key in SCORE_THRESHOLD_TRAIN:
        if key == "audioset":
            continue  # string-valued entry, not a numeric threshold
        if SCORE_THRESHOLD_TRAIN[key] <= -9000.0:
            SCORE_THRESHOLD_TRAIN[key] = -np.inf
    print("SCORE_THRESHOLD_TRAIN", SCORE_THRESHOLD_TRAIN)
    stft = EncodecWrapper("facebook/encodec_24khz")
    ####eval_dataset = Text2AudioDataset(None, "val_instruments", None, None, None, -1, -1, stft, 0, True, SCORE_THRESHOLD_TRAIN, "/zhanghaomin/codes2/audiocaption/msclapcap_v1.list", -1.0, 1, 1, [drop_prompt], None, 0, vgg_test=[test_scp, start, end, step], video_encoder="clip_vit")
    eval_dataset = Text2AudioDataset(None, "val_vggsound", None, None, None, -1, -1, stft, 0, True, SCORE_THRESHOLD_TRAIN, "/zhanghaomin/codes2/audiocaption/msclapcap_v1.list", -1.0, 1, 1, [drop_prompt], None, 0, vgg_test=[test_scp, start, end, step], video_encoder="clip_vit")
    eval_dataloader = DataLoader(eval_dataset, shuffle=False, batch_size=1, collate_fn=eval_dataset.collate_fn, num_workers=1, drop_last=False, pin_memory=True)
    # Index of the sample kept within each (batch_size=1) batch.
    i = 0
    for b, batch in enumerate(eval_dataloader):
        if (bs is not None) and (b not in bs):
            continue
        #text, mel_spec, _, mel_lengths = batch
        text, mel_spec, video_paths, mel_lengths, video_drop_prompt, audio_drop_prompt, frames, midis = batch
        print(mel_spec.shape, mel_lengths, text, video_paths, video_drop_prompt, audio_drop_prompt, frames.shape if frames is not None and not isinstance(frames, float) else frames, midis.shape if midis is not None else midis, midis.sum() if midis is not None else midis)
        # Keep only sample i and trim its mel to the true (unpadded) length.
        text = text[i:i+1]
        mel_spec = mel_spec[i:i+1, 0:mel_lengths[i], :]
        mel_lengths = mel_lengths[i:i+1]
        video_paths = video_paths[i:i+1]
        # Flatten the source path into a unique output filename ("/" -> "__").
        video_path = out_dir + video_paths[0].replace("/", "__")
        audio_path = video_path.replace(".mp4", ".wav")
        name = video_paths[0].rsplit("/", 1)[1].rsplit(".", 1)[0]
        num = 1  # samples drawn per clip
        l = mel_lengths[0]
        #cond = mel_spec.repeat(num, 1, 1)
        # NOTE(review): the audio condition is pure noise here; with
        # audiocond_drop_prob = 1.1 it is presumably always dropped — confirm.
        cond = torch.randn(num, l, e2tts.num_channels)
        duration = torch.tensor([l]*num, dtype=torch.int32)
        lens = torch.tensor([l]*num, dtype=torch.int32)
        print(datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S.%f")[:-3], "start")
        #e2tts.sample(text=[""]*num, duration=duration.to("cuda"), lens=lens.to("cuda"), cond=cond.to("cuda"), save_to_filename="test.wav", steps=16, cfg_strength=3.0, remove_parallel_component=False, sway_sampling=True)
        # Sample audio conditioned on video frames / prompt and write the wav.
        e2tts.sample(text=None, duration=duration.to("cuda"), lens=lens.to("cuda"), cond=cond.to("cuda"), save_to_filename=audio_path, steps=64, prompt=text*num, video_drop_prompt=video_drop_prompt, audio_drop_prompt=audio_drop_prompt, cfg_strength=2.0, remove_parallel_component=False, sway_sampling=True, video_paths=video_paths, frames=(frames if frames is None or isinstance(frames, float) else frames.to("cuda")), midis=(midis if midis is None else midis.to("cuda")))
        print(datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S.%f")[:-3], "sample")
        #one_audio = e2tts.vocos.decode(mel_spec_raw.to("cuda"))
        #one_audio = e2tts.vocos.decode(cond.transpose(-1,-2).to("cuda"))
        #print(datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S.%f")[:-3], "vocoder")
        #torchaudio.save("ref.wav", one_audio.detach().cpu(), sample_rate = e2tts.sampling_rate)
        try:
            # Copy the source video next to the generated wav, then mux the
            # two into NAME.v2a.mp4; best-effort, failures are only logged.
            os.system("cp \"" + video_paths[0] + "\" \"" + video_path + "\"")
            video = VideoFileClip(video_path)
            audio = AudioFileClip(audio_path)
            print("duration", video.duration, audio.duration)
            # Trim the longer stream so video and audio durations match.
            if video.duration >= audio.duration:
                video = video.subclip(0, audio.duration)
            else:
                audio = audio.subclip(0, video.duration)
            final_video = video.set_audio(audio)
            final_video.write_videofile(video_path.replace(".mp4", ".v2a.mp4"), codec="libx264", audio_codec="aac")
            print("\"" + video_path.replace(".mp4", ".v2a.mp4") + "\"")
        except Exception as e:
            print("Exception write_videofile:", video_path.replace(".mp4", ".v2a.mp4"))
            traceback.print_exc()
        # Disabled: export duration-matched ground-truth vs generated audio
        # tracks for offline evaluation.
        if False:
            if not os.path.exists(out_dir+"groundtruth/"):
                os.makedirs(out_dir+"groundtruth/")
            if not os.path.exists(out_dir+"generated/"):
                os.makedirs(out_dir+"generated/")
            duration_gt = video.duration
            duration_gr = final_video.duration
            duration = min(duration_gt, duration_gr)
            audio_gt = video.audio.subclip(0, duration)
            audio_gr = final_video.audio.subclip(0, duration)
            audio_gt.write_audiofile(out_dir+"groundtruth/"+name+".wav", fps=24000)
            audio_gr.write_audiofile(out_dir+"generated/"+name+".wav", fps=24000)
# Script entry point.
if __name__ == "__main__":
    main()