import sys
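
# Usage (paths are illustrative, not shipped with the repo):
#   python <this_script>.py <ckpt.pt> <drop_prompt 0|1> <test.scp> <start> <end> <out_dir>
# With fewer than six positional arguments, the hard-coded defaults below are used.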

if len(sys.argv) >= 7:
    ckpt = sys.argv[1]
    drop_prompt = bool(int(sys.argv[2]))
    test_scp = sys.argv[3]
    start = int(sys.argv[4])
    end = int(sys.argv[5])
    step = 1
    out_dir = sys.argv[6]
    print("inference", ckpt, drop_prompt, test_scp, start, end, out_dir)
else:
    #ckpt = "/ailab-train/speech/zhanghaomin/e2/e2_tts_experiment_v2a_encodec_more_more/98500.pt"
    #ckpt = "/ailab-train/speech/zhanghaomin/e2/e2_tts_experiment_v2a_encodec_more_more_more/190000.pt"
    #ckpt = "/ailab-train/speech/zhanghaomin/e2/e2_tts_experiment_v2a_encodec_more_more_more/315000.pt"
    #ckpt = "/ailab-train/speech/zhanghaomin/e2/e2_tts_experiment_v2a_encodec_more_more_more_more/60000.pt"
    #ckpt = "/ailab-train/speech/zhanghaomin/e2/e2_tts_experiment_v2a_encodec_more_more_more_more_piano5/4_2_8000.pt"
    ckpt = "./ckpts/piano5_4_2_8000.pt"
    #ckpt = "/ailab-train/speech/zhanghaomin/e2/e2_tts_experiment_v2a_encodec_more_more_more_more_piano6/dpo_100.pt"
    drop_prompt = False
    test_scp = "/ailab-train/speech/zhanghaomin/scps/VGGSound/test.scp"
    #test_scp = "./tests/vgg_test.scp"
    ####test_scp = "/ailab-train/speech/zhanghaomin/scps/instruments/test.scp"
    ####test_scp = "/ailab-train/speech/zhanghaomin/scps/instruments/piano_2h/test.scp"
    ####test_scp = "/ailab-train/speech/zhanghaomin/scps/instruments/piano_20h/v2a_giant_piano2/test.scp"
    start = 0
    end = 2
    step = 1
    out_dir = "./outputs_vgg/"
    ####out_dir = "./outputs_piano/"
    #####out_dir = "./outputs2t_20h_dpo/"


import torch
from e2_tts_pytorch.e2_tts_crossatt3 import E2TTS, DurationPredictor
from e2_tts_pytorch.e2_tts_crossatt3 import MelSpec, EncodecWrapper

from torch.optim import Adam
from torch.utils.data import DataLoader, Dataset
from datasets import load_dataset

from e2_tts_pytorch.trainer_multigpus_alldatas3 import HFDataset, Text2AudioDataset

from einops import rearrange  # only rearrange is referenced (in commented-out debug code)
import torchaudio

from datetime import datetime
import json
import numpy as np
import os
import shutil
from moviepy.editor import VideoFileClip, AudioFileClip
import traceback


# Condition drop probabilities. A value > 1.0 means that condition is always
# dropped (its modules are not even built); a value < 0.0 means it is never dropped.
audiocond_drop_prob = 1.1  # no audio-prompt conditioning
#audiocond_drop_prob = 0.3
#cond_proj_in_bias = True
#cond_drop_prob = 1.1
cond_drop_prob = -0.1  # always keep the text/video condition
prompt_drop_prob = -0.1  # always keep the cross-attention prompt
#prompt_drop_prob = 1.1
video_text = True  # condition on CLIP video features instead of embedded text

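
# With the values above, the constructor flags used below resolve to:
#   if_cond_proj_in  = (1.1 < 1.0)                -> False (no audio-condition projection)
#   if_embed_text    = (-0.1 < 1.0) and not True  -> False (video features replace text)
#   if_text_encoder2 = (-0.1 < 1.0)               -> True
#   if_clip_encoder  = video_text                 -> True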

def main():
    #duration_predictor = DurationPredictor(
    #    transformer = dict(
    #        dim = 512,
    #        depth = 6,
    #    )
    #)
    duration_predictor = None

    e2tts = E2TTS(
        duration_predictor = duration_predictor,
        transformer = dict(
            #depth = 12,
            #dim = 512,
            #heads = 8,
            #dim_head = 64,
            depth = 12,
            dim = 1024,
            dim_text = 1280,
            heads = 16,
            dim_head = 64,
            if_text_modules = (cond_drop_prob < 1.0),
            if_cross_attn = (prompt_drop_prob < 1.0),
            if_audio_conv = True,
            if_text_conv = True,
        ),
        #tokenizer = 'char_utf8',
        tokenizer = 'phoneme_zh',
        audiocond_drop_prob = audiocond_drop_prob,
        cond_drop_prob = cond_drop_prob,
        prompt_drop_prob = prompt_drop_prob,
        frac_lengths_mask = (0.7, 1.0),
        #audiocond_snr = None,
        #audiocond_snr = (5.0, 10.0),
        
        if_cond_proj_in = (audiocond_drop_prob < 1.0),
        #cond_proj_in_bias = cond_proj_in_bias,
        if_embed_text = (cond_drop_prob < 1.0) and (not video_text),
        if_text_encoder2 = (prompt_drop_prob < 1.0),
        if_clip_encoder = video_text,
        video_encoder = "clip_vit",
        
        pretrained_vocos_path = 'facebook/encodec_24khz',
        num_channels = 128,  # EnCodec latent dimension
        sampling_rate = 24000,
    )
    e2tts = e2tts.to("cuda")

    #checkpoint = torch.load("/ckptstorage/zhanghaomin/e2/e2_tts_experiment_v2a_encodec/3000.pt", map_location="cpu")
    #checkpoint = torch.load("/ckptstorage/zhanghaomin/e2/e2_tts_experiment_v2a_encodec_more/500.pt", map_location="cpu")
    #checkpoint = torch.load("/ckptstorage/zhanghaomin/e2/e2_tts_experiment_v2a_encodec_more_more/98500.pt", map_location="cpu")
    #checkpoint = torch.load("/ckptstorage/zhanghaomin/e2/e2_tts_experiment_v2a_encodec_more_more_more/190000.pt", map_location="cpu")
    checkpoint = torch.load(ckpt, map_location="cpu")

    #for key in list(checkpoint['model_state_dict'].keys()):
    #    if key.startswith('mel_spec.'):
    #        del checkpoint['model_state_dict'][key]
    #    if key.startswith('transformer.text_registers'):
    #        del checkpoint['model_state_dict'][key]
    # strict=False tolerates missing/extra auxiliary keys (see the pruning above).
    e2tts.load_state_dict(checkpoint['model_state_dict'], strict=False)
    
    e2tts.vocos = EncodecWrapper("facebook/encodec_24khz")
    for param in e2tts.vocos.parameters():
        param.requires_grad = False
    e2tts.vocos.eval()
    e2tts.vocos.to("cuda")
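
    # Swap in a frozen EnCodec codec in place of the Vocos vocoder: the model
    # predicts EnCodec latents and vocos.decode() renders them back to a waveform.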

    #dataset = HFDataset(load_dataset("parquet", data_files={"test": "/ckptstorage/zhanghaomin/tts/GLOBE/data/test-*.parquet"})["test"])
    #sample = dataset[1]
    #mel_spec_raw = sample["mel_spec"].unsqueeze(0)
    #mel_spec = rearrange(mel_spec_raw, 'b d n -> b n d')
    #print(mel_spec.shape, sample["text"])

    #out_dir = "/user-fs/zhanghaomin/v2a_generated/v2a_190000_tests/"
    #out_dir = "/user-fs/zhanghaomin/v2a_generated/tv2a_98500_clips/"
    os.makedirs(out_dir, exist_ok=True)

    #bs = list(range(10)) + [14,16]
    bs = None  # optional whitelist of batch indices to synthesize; None runs all
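
    # Per-dataset score thresholds consumed by Text2AudioDataset. The sentinel
    # -9999.0 is widened to -inf below so those sources are never filtered out;
    # very large positive values presumably exclude a source. The "audioset" entry
    # appears to name a caption set rather than a threshold.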
    
    SCORE_THRESHOLD_TRAIN = '{"/zhanghaomin/datas/audiocaps": -9999.0, "/radiostorage/WavCaps": -9999.0, "/radiostorage/AudioGroup": 9999.0, "/ckptstorage/zhanghaomin/audioset": -9999.0, "/ckptstorage/zhanghaomin/BBCSoundEffects": 9999.0, "/ckptstorage/zhanghaomin/CLAP_freesound": 9999.0, "/zhanghaomin/datas/musiccap": -9999.0, "/ckptstorage/zhanghaomin/TangoPromptBank": -9999.0, "audioset": "af-audioset", "/ckptstorage/zhanghaomin/audiosetsl": 9999.0, "/ckptstorage/zhanghaomin/giantsoundeffects": -9999.0}'  # /root/datasets/ /radiostorage/
    SCORE_THRESHOLD_TRAIN = json.loads(SCORE_THRESHOLD_TRAIN)
    for key in SCORE_THRESHOLD_TRAIN:
        if key == "audioset":
            continue
        if SCORE_THRESHOLD_TRAIN[key] <= -9000.0:
            SCORE_THRESHOLD_TRAIN[key] = -np.inf
    print("SCORE_THRESHOLD_TRAIN", SCORE_THRESHOLD_TRAIN)
    stft = EncodecWrapper("facebook/encodec_24khz")  # feature extractor for the dataset (EnCodec latents, despite the name)
    ####eval_dataset = Text2AudioDataset(None, "val_instruments", None, None, None, -1, -1, stft, 0, True, SCORE_THRESHOLD_TRAIN, "/zhanghaomin/codes2/audiocaption/msclapcap_v1.list", -1.0, 1, 1, [drop_prompt], None, 0, vgg_test=[test_scp, start, end, step], video_encoder="clip_vit")
    eval_dataset = Text2AudioDataset(None, "val_vggsound", None, None, None, -1, -1, stft, 0, True, SCORE_THRESHOLD_TRAIN, "/zhanghaomin/codes2/audiocaption/msclapcap_v1.list", -1.0, 1, 1, [drop_prompt], None, 0, vgg_test=[test_scp, start, end, step], video_encoder="clip_vit")
    eval_dataloader = DataLoader(eval_dataset, shuffle=False, batch_size=1, collate_fn=eval_dataset.collate_fn, num_workers=1, drop_last=False, pin_memory=True)
    i = 0  # index within each batch; batch_size is 1, so always the first item
    for b, batch in enumerate(eval_dataloader):
        if (bs is not None) and (b not in bs):
            continue
        #text, mel_spec, _, mel_lengths = batch
        text, mel_spec, video_paths, mel_lengths, video_drop_prompt, audio_drop_prompt, frames, midis = batch
        print(mel_spec.shape, mel_lengths, text, video_paths, video_drop_prompt, audio_drop_prompt, frames.shape if frames is not None and not isinstance(frames, float) else frames, midis.shape if midis is not None else midis, midis.sum() if midis is not None else midis)
        text = text[i:i+1]
        mel_spec = mel_spec[i:i+1, 0:mel_lengths[i], :]
        mel_lengths = mel_lengths[i:i+1]
        video_paths = video_paths[i:i+1]
        video_path = out_dir + video_paths[0].replace("/", "__")
        audio_path = video_path.replace(".mp4", ".wav")
        
        name = video_paths[0].rsplit("/", 1)[1].rsplit(".", 1)[0]
        
        num = 1  # number of samples to generate per clip
        
        l = mel_lengths[0]
        #cond = mel_spec.repeat(num, 1, 1)
        cond = torch.randn(num, l, e2tts.num_channels)
        duration = torch.tensor([l]*num, dtype=torch.int32)
        lens = torch.tensor([l]*num, dtype=torch.int32)
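
        # audiocond_drop_prob > 1.0 drops the audio condition unconditionally, so
        # `cond` only pins the output length/shape and random noise suffices.
        # duration/lens count feature frames (EnCodec at 24 kHz yields 75 per second).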
        print(datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S.%f")[:-3], "start")
        #e2tts.sample(text=[""]*num, duration=duration.to("cuda"), lens=lens.to("cuda"), cond=cond.to("cuda"), save_to_filename="test.wav", steps=16, cfg_strength=3.0, remove_parallel_component=False, sway_sampling=True)
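        # 64 sampling steps with classifier-free guidance (cfg_strength=2.0) and sway
        # sampling; the test caption is the prompt, conditioned on the video frames.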
        e2tts.sample(text=None, duration=duration.to("cuda"), lens=lens.to("cuda"), cond=cond.to("cuda"), save_to_filename=audio_path, steps=64, prompt=text*num, video_drop_prompt=video_drop_prompt, audio_drop_prompt=audio_drop_prompt, cfg_strength=2.0, remove_parallel_component=False, sway_sampling=True, video_paths=video_paths, frames=(frames if frames is None or isinstance(frames, float) else frames.to("cuda")), midis=(midis if midis is None else midis.to("cuda")))
        print(datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S.%f")[:-3], "sample")
        #one_audio = e2tts.vocos.decode(mel_spec_raw.to("cuda"))
        #one_audio = e2tts.vocos.decode(cond.transpose(-1,-2).to("cuda"))
        #print(datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S.%f")[:-3], "vocoder")
        #torchaudio.save("ref.wav", one_audio.detach().cpu(), sample_rate = e2tts.sampling_rate)
        
        try:
            shutil.copy(video_paths[0], video_path)
            video = VideoFileClip(video_path)
            audio = AudioFileClip(audio_path)
            print("duration", video.duration, audio.duration)
            if video.duration >= audio.duration:
                video = video.subclip(0, audio.duration)
            else:
                audio = audio.subclip(0, video.duration)
            final_video = video.set_audio(audio)
            final_video.write_videofile(video_path.replace(".mp4", ".v2a.mp4"), codec="libx264", audio_codec="aac")
            print("\"" + video_path.replace(".mp4", ".v2a.mp4") + "\"")
        except Exception:
            print("Exception write_videofile:", video_path.replace(".mp4", ".v2a.mp4"))
            traceback.print_exc()
        
        # Disabled by default: export matched ground-truth/generated audio pairs,
        # trimmed to the shorter duration, for objective evaluation.
        if False:
            os.makedirs(out_dir+"groundtruth/", exist_ok=True)
            os.makedirs(out_dir+"generated/", exist_ok=True)
            duration_gt = video.duration
            duration_gr = final_video.duration
            duration = min(duration_gt, duration_gr)
            audio_gt = video.audio.subclip(0, duration)
            audio_gr = final_video.audio.subclip(0, duration)
            audio_gt.write_audiofile(out_dir+"groundtruth/"+name+".wav", fps=24000)
            audio_gr.write_audiofile(out_dir+"generated/"+name+".wav", fps=24000)


if __name__ == "__main__":
    main()