""" ein notation: b - batch n - sequence nt - text sequence nw - raw wave length d - dimension dt - dimension text """ from __future__ import annotations from pathlib import Path from random import random from functools import partial from itertools import zip_longest from collections import namedtuple from typing import Literal, Callable import jaxtyping from beartype import beartype import torch import torch.nn.functional as F from torch import nn, tensor, Tensor, from_numpy from torch.nn import Module, ModuleList, Sequential, Linear from torch.nn.utils.rnn import pad_sequence import torchaudio from torchaudio.functional import DB_to_amplitude from torchdiffeq import odeint import einx from einops.layers.torch import Rearrange from einops import rearrange, repeat, reduce, pack, unpack from x_transformers import ( Attention, FeedForward, RMSNorm, AdaptiveRMSNorm, ) from x_transformers.x_transformers import RotaryEmbedding import sys sys.path.insert(0, "/zhanghaomin/codes3/vocos-main/") from vocos import Vocos from transformers import AutoTokenizer from transformers import T5EncoderModel from transformers import EncodecModel, AutoProcessor sys.path.insert(0, "./src/audeo/") import Video2RollNet import torchvision.transforms as transforms ####transform = transforms.Compose([lambda x: x.resize((900,100)), #### lambda x: np.reshape(x,(100,900,1)), #### lambda x: np.transpose(x,[2,0,1]), #### lambda x: x/255.]) transform = transforms.Compose([lambda x: x.resize((100,900)), lambda x: np.reshape(x,(900,100,1)), lambda x: np.transpose(x,[2,1,0]), lambda x: x/255.]) ####NOTES = 51 ####NOTTE_MIN = 15 ####NOTE_MAX = 65 NOTES = 88 NOTTE_MIN = 0#15 NOTE_MAX = 87#72 import os import math import traceback import numpy as np from moviepy.editor import AudioFileClip, VideoFileClip from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection #import open_clip from transformers import AutoImageProcessor, AutoModel from PIL import Image import time import warnings warnings.filterwarnings("ignore") def normalize_wav(waveform): waveform = waveform - torch.mean(waveform) waveform = waveform / (torch.max(torch.abs(waveform[0, :])) + 1e-8) return waveform * 0.5 def read_frames_with_moviepy(video_path, max_frame_nums=None): try: clip = VideoFileClip(video_path) duration = clip.duration frames = [] for frame in clip.iter_frames(): frames.append(frame) except: print("Error read_frames_with_moviepy", video_path) traceback.print_exc() return None, None if max_frame_nums is not None: frames_idx = np.linspace(0, len(frames) - 1, max_frame_nums, dtype=int) return np.array(frames)[frames_idx, ...], duration else: return np.array(frames), duration pad_sequence = partial(pad_sequence, batch_first = True) # constants class TorchTyping: def __init__(self, abstract_dtype): self.abstract_dtype = abstract_dtype def __getitem__(self, shapes: str): return self.abstract_dtype[Tensor, shapes] Float = TorchTyping(jaxtyping.Float) Int = TorchTyping(jaxtyping.Int) Bool = TorchTyping(jaxtyping.Bool) # named tuples LossBreakdown = namedtuple('LossBreakdown', ['flow', 'velocity_consistency', 'a', 'b']) E2TTSReturn = namedtuple('E2TTS', ['loss', 'cond', 'pred_flow', 'pred_data', 'loss_breakdown']) # helpers def exists(v): return v is not None def default(v, d): return v if exists(v) else d def divisible_by(num, den): return (num % den) == 0 def pack_one_with_inverse(x, pattern): packed, packed_shape = pack([x], pattern) def inverse(x, inverse_pattern = None): inverse_pattern = default(inverse_pattern, pattern) return unpack(x, packed_shape, inverse_pattern)[0] return packed, inverse class Identity(Module): def forward(self, x, **kwargs): return x # tensor helpers def project(x, y): x, inverse = pack_one_with_inverse(x, 'b *') y, _ = pack_one_with_inverse(y, 'b *') dtype = x.dtype x, y = x.double(), y.double() unit = F.normalize(y, dim = -1) parallel = (x * unit).sum(dim = -1, keepdim = True) * unit orthogonal = x - parallel return inverse(parallel).to(dtype), inverse(orthogonal).to(dtype) # simple utf-8 tokenizer, since paper went character based def list_str_to_tensor( text: list[str], padding_value = -1 ) -> Int['b nt']: list_tensors = [tensor([*bytes(t, 'UTF-8')]) for t in text] padded_tensor = pad_sequence(list_tensors, padding_value = -1) return padded_tensor # simple english phoneme-based tokenizer from g2p_en import G2p import jieba from pypinyin import lazy_pinyin, Style def get_g2p_en_encode(): g2p = G2p() # used by @lucasnewman successfully here # https://github.com/lucasnewman/e2-tts-pytorch/blob/ljspeech-test/e2_tts_pytorch/e2_tts.py phoneme_to_index = g2p.p2idx num_phonemes = len(phoneme_to_index) extended_chars = [' ', ',', '.', '-', '!', '?', '\'', '"', '...', '..', '. .', '. . .', '. . . .', '. . . . .', '. ...', '... .', '.. ..'] num_extended_chars = len(extended_chars) extended_chars_dict = {p: (num_phonemes + i) for i, p in enumerate(extended_chars)} phoneme_to_index = {**phoneme_to_index, **extended_chars_dict} def encode( text: list[str], padding_value = -1 ) -> Int['b nt']: phonemes = [g2p(t) for t in text] list_tensors = [tensor([phoneme_to_index[p] for p in one_phoneme]) for one_phoneme in phonemes] padded_tensor = pad_sequence(list_tensors, padding_value = -1) return padded_tensor return encode, (num_phonemes + num_extended_chars) def all_en(word): res = word.replace("'", "").encode('utf-8').isalpha() return res def all_ch(word): res = True for w in word: if not '\u4e00' <= w <= '\u9fff': res = False return res def get_g2p_zh_encode(): puncs = [',', '。', '?', '、'] pinyins = ['a', 'a1', 'ai1', 'ai2', 'ai3', 'ai4', 'an1', 'an3', 'an4', 'ang1', 'ang2', 'ang4', 'ao1', 'ao2', 'ao3', 'ao4', 'ba', 'ba1', 'ba2', 'ba3', 'ba4', 'bai1', 'bai2', 'bai3', 'bai4', 'ban1', 'ban2', 'ban3', 'ban4', 'bang1', 'bang2', 'bang3', 'bang4', 'bao1', 'bao2', 'bao3', 'bao4', 'bei', 'bei1', 'bei2', 'bei3', 'bei4', 'ben1', 'ben2', 'ben3', 'ben4', 'beng1', 'beng2', 'beng4', 'bi1', 'bi2', 'bi3', 'bi4', 'bian1', 'bian2', 'bian3', 'bian4', 'biao1', 'biao2', 'biao3', 'bie1', 'bie2', 'bie3', 'bie4', 'bin1', 'bin4', 'bing1', 'bing2', 'bing3', 'bing4', 'bo', 'bo1', 'bo2', 'bo3', 'bo4', 'bu2', 'bu3', 'bu4', 'ca1', 'cai1', 'cai2', 'cai3', 'cai4', 'can1', 'can2', 'can3', 'can4', 'cang1', 'cang2', 'cao1', 'cao2', 'cao3', 'ce4', 'cen1', 'cen2', 'ceng1', 'ceng2', 'ceng4', 'cha1', 'cha2', 'cha3', 'cha4', 'chai1', 'chai2', 'chan1', 'chan2', 'chan3', 'chan4', 'chang1', 'chang2', 'chang3', 'chang4', 'chao1', 'chao2', 'chao3', 'che1', 'che2', 'che3', 'che4', 'chen1', 'chen2', 'chen3', 'chen4', 'cheng1', 'cheng2', 'cheng3', 'cheng4', 'chi1', 'chi2', 'chi3', 'chi4', 'chong1', 'chong2', 'chong3', 'chong4', 'chou1', 'chou2', 'chou3', 'chou4', 'chu1', 'chu2', 'chu3', 'chu4', 'chua1', 'chuai1', 'chuai2', 'chuai3', 'chuai4', 'chuan1', 'chuan2', 'chuan3', 'chuan4', 'chuang1', 'chuang2', 'chuang3', 'chuang4', 'chui1', 'chui2', 'chun1', 'chun2', 'chun3', 'chuo1', 'chuo4', 'ci1', 'ci2', 'ci3', 'ci4', 'cong1', 'cong2', 'cou4', 'cu1', 'cu4', 'cuan1', 'cuan2', 'cuan4', 'cui1', 'cui3', 'cui4', 'cun1', 'cun2', 'cun4', 'cuo1', 'cuo2', 'cuo4', 'da', 'da1', 'da2', 'da3', 'da4', 'dai1', 'dai3', 'dai4', 'dan1', 'dan2', 'dan3', 'dan4', 'dang1', 'dang2', 'dang3', 'dang4', 'dao1', 'dao2', 'dao3', 'dao4', 'de', 'de1', 'de2', 'dei3', 'den4', 'deng1', 'deng2', 'deng3', 'deng4', 'di1', 'di2', 'di3', 'di4', 'dia3', 'dian1', 'dian2', 'dian3', 'dian4', 'diao1', 'diao3', 'diao4', 'die1', 'die2', 'ding1', 'ding2', 'ding3', 'ding4', 'diu1', 'dong1', 'dong3', 'dong4', 'dou1', 'dou2', 'dou3', 'dou4', 'du1', 'du2', 'du3', 'du4', 'duan1', 'duan2', 'duan3', 'duan4', 'dui1', 'dui4', 'dun1', 'dun3', 'dun4', 'duo1', 'duo2', 'duo3', 'duo4', 'e1', 'e2', 'e3', 'e4', 'ei2', 'en1', 'en4', 'er', 'er2', 'er3', 'er4', 'fa1', 'fa2', 'fa3', 'fa4', 'fan1', 'fan2', 'fan3', 'fan4', 'fang1', 'fang2', 'fang3', 'fang4', 'fei1', 'fei2', 'fei3', 'fei4', 'fen1', 'fen2', 'fen3', 'fen4', 'feng1', 'feng2', 'feng3', 'feng4', 'fo2', 'fou2', 'fou3', 'fu1', 'fu2', 'fu3', 'fu4', 'ga1', 'ga2', 'ga4', 'gai1', 'gai3', 'gai4', 'gan1', 'gan2', 'gan3', 'gan4', 'gang1', 'gang2', 'gang3', 'gang4', 'gao1', 'gao2', 'gao3', 'gao4', 'ge1', 'ge2', 'ge3', 'ge4', 'gei2', 'gei3', 'gen1', 'gen2', 'gen3', 'gen4', 'geng1', 'geng3', 'geng4', 'gong1', 'gong3', 'gong4', 'gou1', 'gou2', 'gou3', 'gou4', 'gu', 'gu1', 'gu2', 'gu3', 'gu4', 'gua1', 'gua2', 'gua3', 'gua4', 'guai1', 'guai2', 'guai3', 'guai4', 'guan1', 'guan2', 'guan3', 'guan4', 'guang1', 'guang2', 'guang3', 'guang4', 'gui1', 'gui2', 'gui3', 'gui4', 'gun3', 'gun4', 'guo1', 'guo2', 'guo3', 'guo4', 'ha1', 'ha2', 'ha3', 'hai1', 'hai2', 'hai3', 'hai4', 'han1', 'han2', 'han3', 'han4', 'hang1', 'hang2', 'hang4', 'hao1', 'hao2', 'hao3', 'hao4', 'he1', 'he2', 'he4', 'hei1', 'hen2', 'hen3', 'hen4', 'heng1', 'heng2', 'heng4', 'hong1', 'hong2', 'hong3', 'hong4', 'hou1', 'hou2', 'hou3', 'hou4', 'hu1', 'hu2', 'hu3', 'hu4', 'hua1', 'hua2', 'hua4', 'huai2', 'huai4', 'huan1', 'huan2', 'huan3', 'huan4', 'huang1', 'huang2', 'huang3', 'huang4', 'hui1', 'hui2', 'hui3', 'hui4', 'hun1', 'hun2', 'hun4', 'huo', 'huo1', 'huo2', 'huo3', 'huo4', 'ji1', 'ji2', 'ji3', 'ji4', 'jia', 'jia1', 'jia2', 'jia3', 'jia4', 'jian1', 'jian2', 'jian3', 'jian4', 'jiang1', 'jiang2', 'jiang3', 'jiang4', 'jiao1', 'jiao2', 'jiao3', 'jiao4', 'jie1', 'jie2', 'jie3', 'jie4', 'jin1', 'jin2', 'jin3', 'jin4', 'jing1', 'jing2', 'jing3', 'jing4', 'jiong3', 'jiu1', 'jiu2', 'jiu3', 'jiu4', 'ju1', 'ju2', 'ju3', 'ju4', 'juan1', 'juan2', 'juan3', 'juan4', 'jue1', 'jue2', 'jue4', 'jun1', 'jun4', 'ka1', 'ka2', 'ka3', 'kai1', 'kai2', 'kai3', 'kai4', 'kan1', 'kan2', 'kan3', 'kan4', 'kang1', 'kang2', 'kang4', 'kao2', 'kao3', 'kao4', 'ke1', 'ke2', 'ke3', 'ke4', 'ken3', 'keng1', 'kong1', 'kong3', 'kong4', 'kou1', 'kou2', 'kou3', 'kou4', 'ku1', 'ku2', 'ku3', 'ku4', 'kua1', 'kua3', 'kua4', 'kuai3', 'kuai4', 'kuan1', 'kuan2', 'kuan3', 'kuang1', 'kuang2', 'kuang4', 'kui1', 'kui2', 'kui3', 'kui4', 'kun1', 'kun3', 'kun4', 'kuo4', 'la', 'la1', 'la2', 'la3', 'la4', 'lai2', 'lai4', 'lan2', 'lan3', 'lan4', 'lang1', 'lang2', 'lang3', 'lang4', 'lao1', 'lao2', 'lao3', 'lao4', 'le', 'le1', 'le4', 'lei', 'lei1', 'lei2', 'lei3', 'lei4', 'leng1', 'leng2', 'leng3', 'leng4', 'li', 'li1', 'li2', 'li3', 'li4', 'lia3', 'lian2', 'lian3', 'lian4', 'liang2', 'liang3', 'liang4', 'liao1', 'liao2', 'liao3', 'liao4', 'lie1', 'lie2', 'lie3', 'lie4', 'lin1', 'lin2', 'lin3', 'lin4', 'ling2', 'ling3', 'ling4', 'liu1', 'liu2', 'liu3', 'liu4', 'long1', 'long2', 'long3', 'long4', 'lou1', 'lou2', 'lou3', 'lou4', 'lu1', 'lu2', 'lu3', 'lu4', 'luan2', 'luan3', 'luan4', 'lun1', 'lun2', 'lun4', 'luo1', 'luo2', 'luo3', 'luo4', 'lv2', 'lv3', 'lv4', 'lve3', 'lve4', 'ma', 'ma1', 'ma2', 'ma3', 'ma4', 'mai2', 'mai3', 'mai4', 'man2', 'man3', 'man4', 'mang2', 'mang3', 'mao1', 'mao2', 'mao3', 'mao4', 'me', 'mei2', 'mei3', 'mei4', 'men', 'men1', 'men2', 'men4', 'meng1', 'meng2', 'meng3', 'meng4', 'mi1', 'mi2', 'mi3', 'mi4', 'mian2', 'mian3', 'mian4', 'miao1', 'miao2', 'miao3', 'miao4', 'mie1', 'mie4', 'min2', 'min3', 'ming2', 'ming3', 'ming4', 'miu4', 'mo1', 'mo2', 'mo3', 'mo4', 'mou1', 'mou2', 'mou3', 'mu2', 'mu3', 'mu4', 'n2', 'na1', 'na2', 'na3', 'na4', 'nai2', 'nai3', 'nai4', 'nan1', 'nan2', 'nan3', 'nan4', 'nang1', 'nang2', 'nao1', 'nao2', 'nao3', 'nao4', 'ne', 'ne2', 'ne4', 'nei3', 'nei4', 'nen4', 'neng2', 'ni1', 'ni2', 'ni3', 'ni4', 'nian1', 'nian2', 'nian3', 'nian4', 'niang2', 'niang4', 'niao2', 'niao3', 'niao4', 'nie1', 'nie4', 'nin2', 'ning2', 'ning3', 'ning4', 'niu1', 'niu2', 'niu3', 'niu4', 'nong2', 'nong4', 'nou4', 'nu2', 'nu3', 'nu4', 'nuan3', 'nuo2', 'nuo4', 'nv2', 'nv3', 'nve4', 'o1', 'o2', 'ou1', 'ou3', 'ou4', 'pa1', 'pa2', 'pa4', 'pai1', 'pai2', 'pai3', 'pai4', 'pan1', 'pan2', 'pan4', 'pang1', 'pang2', 'pang4', 'pao1', 'pao2', 'pao3', 'pao4', 'pei1', 'pei2', 'pei4', 'pen1', 'pen2', 'pen4', 'peng1', 'peng2', 'peng3', 'peng4', 'pi1', 'pi2', 'pi3', 'pi4', 'pian1', 'pian2', 'pian4', 'piao1', 'piao2', 'piao3', 'piao4', 'pie1', 'pie2', 'pie3', 'pin1', 'pin2', 'pin3', 'pin4', 'ping1', 'ping2', 'po1', 'po2', 'po3', 'po4', 'pou1', 'pu1', 'pu2', 'pu3', 'pu4', 'qi1', 'qi2', 'qi3', 'qi4', 'qia1', 'qia3', 'qia4', 'qian1', 'qian2', 'qian3', 'qian4', 'qiang1', 'qiang2', 'qiang3', 'qiang4', 'qiao1', 'qiao2', 'qiao3', 'qiao4', 'qie1', 'qie2', 'qie3', 'qie4', 'qin1', 'qin2', 'qin3', 'qin4', 'qing1', 'qing2', 'qing3', 'qing4', 'qiong1', 'qiong2', 'qiu1', 'qiu2', 'qiu3', 'qu1', 'qu2', 'qu3', 'qu4', 'quan1', 'quan2', 'quan3', 'quan4', 'que1', 'que2', 'que4', 'qun2', 'ran2', 'ran3', 'rang1', 'rang2', 'rang3', 'rang4', 'rao2', 'rao3', 'rao4', 're2', 're3', 're4', 'ren2', 'ren3', 'ren4', 'reng1', 'reng2', 'ri4', 'rong1', 'rong2', 'rong3', 'rou2', 'rou4', 'ru2', 'ru3', 'ru4', 'ruan2', 'ruan3', 'rui3', 'rui4', 'run4', 'ruo4', 'sa1', 'sa2', 'sa3', 'sa4', 'sai1', 'sai4', 'san1', 'san2', 'san3', 'san4', 'sang1', 'sang3', 'sang4', 'sao1', 'sao2', 'sao3', 'sao4', 'se4', 'sen1', 'seng1', 'sha1', 'sha2', 'sha3', 'sha4', 'shai1', 'shai2', 'shai3', 'shai4', 'shan1', 'shan3', 'shan4', 'shang', 'shang1', 'shang3', 'shang4', 'shao1', 'shao2', 'shao3', 'shao4', 'she1', 'she2', 'she3', 'she4', 'shei2', 'shen1', 'shen2', 'shen3', 'shen4', 'sheng1', 'sheng2', 'sheng3', 'sheng4', 'shi', 'shi1', 'shi2', 'shi3', 'shi4', 'shou1', 'shou2', 'shou3', 'shou4', 'shu1', 'shu2', 'shu3', 'shu4', 'shua1', 'shua2', 'shua3', 'shua4', 'shuai1', 'shuai3', 'shuai4', 'shuan1', 'shuan4', 'shuang1', 'shuang3', 'shui2', 'shui3', 'shui4', 'shun3', 'shun4', 'shuo1', 'shuo4', 'si1', 'si2', 'si3', 'si4', 'song1', 'song3', 'song4', 'sou1', 'sou3', 'sou4', 'su1', 'su2', 'su4', 'suan1', 'suan4', 'sui1', 'sui2', 'sui3', 'sui4', 'sun1', 'sun3', 'suo', 'suo1', 'suo2', 'suo3', 'ta1', 'ta3', 'ta4', 'tai1', 'tai2', 'tai4', 'tan1', 'tan2', 'tan3', 'tan4', 'tang1', 'tang2', 'tang3', 'tang4', 'tao1', 'tao2', 'tao3', 'tao4', 'te4', 'teng2', 'ti1', 'ti2', 'ti3', 'ti4', 'tian1', 'tian2', 'tian3', 'tiao1', 'tiao2', 'tiao3', 'tiao4', 'tie1', 'tie2', 'tie3', 'tie4', 'ting1', 'ting2', 'ting3', 'tong1', 'tong2', 'tong3', 'tong4', 'tou', 'tou1', 'tou2', 'tou4', 'tu1', 'tu2', 'tu3', 'tu4', 'tuan1', 'tuan2', 'tui1', 'tui2', 'tui3', 'tui4', 'tun1', 'tun2', 'tun4', 'tuo1', 'tuo2', 'tuo3', 'tuo4', 'wa', 'wa1', 'wa2', 'wa3', 'wa4', 'wai1', 'wai3', 'wai4', 'wan1', 'wan2', 'wan3', 'wan4', 'wang1', 'wang2', 'wang3', 'wang4', 'wei1', 'wei2', 'wei3', 'wei4', 'wen1', 'wen2', 'wen3', 'wen4', 'weng1', 'weng4', 'wo1', 'wo3', 'wo4', 'wu1', 'wu2', 'wu3', 'wu4', 'xi1', 'xi2', 'xi3', 'xi4', 'xia1', 'xia2', 'xia4', 'xian1', 'xian2', 'xian3', 'xian4', 'xiang1', 'xiang2', 'xiang3', 'xiang4', 'xiao1', 'xiao2', 'xiao3', 'xiao4', 'xie1', 'xie2', 'xie3', 'xie4', 'xin1', 'xin2', 'xin4', 'xing1', 'xing2', 'xing3', 'xing4', 'xiong1', 'xiong2', 'xiu1', 'xiu3', 'xiu4', 'xu', 'xu1', 'xu2', 'xu3', 'xu4', 'xuan1', 'xuan2', 'xuan3', 'xuan4', 'xue1', 'xue2', 'xue3', 'xue4', 'xun1', 'xun2', 'xun4', 'ya', 'ya1', 'ya2', 'ya3', 'ya4', 'yan1', 'yan2', 'yan3', 'yan4', 'yang1', 'yang2', 'yang3', 'yang4', 'yao1', 'yao2', 'yao3', 'yao4', 'ye1', 'ye2', 'ye3', 'ye4', 'yi1', 'yi2', 'yi3', 'yi4', 'yin1', 'yin2', 'yin3', 'yin4', 'ying1', 'ying2', 'ying3', 'ying4', 'yo1', 'yong1', 'yong3', 'yong4', 'you1', 'you2', 'you3', 'you4', 'yu1', 'yu2', 'yu3', 'yu4', 'yuan1', 'yuan2', 'yuan3', 'yuan4', 'yue1', 'yue4', 'yun1', 'yun2', 'yun3', 'yun4', 'za1', 'za2', 'za3', 'zai1', 'zai3', 'zai4', 'zan1', 'zan2', 'zan3', 'zan4', 'zang1', 'zang4', 'zao1', 'zao2', 'zao3', 'zao4', 'ze2', 'ze4', 'zei2', 'zen3', 'zeng1', 'zeng4', 'zha1', 'zha2', 'zha3', 'zha4', 'zhai1', 'zhai2', 'zhai3', 'zhai4', 'zhan1', 'zhan2', 'zhan3', 'zhan4', 'zhang1', 'zhang2', 'zhang3', 'zhang4', 'zhao1', 'zhao2', 'zhao3', 'zhao4', 'zhe', 'zhe1', 'zhe2', 'zhe3', 'zhe4', 'zhen1', 'zhen2', 'zhen3', 'zhen4', 'zheng1', 'zheng2', 'zheng3', 'zheng4', 'zhi1', 'zhi2', 'zhi3', 'zhi4', 'zhong1', 'zhong2', 'zhong3', 'zhong4', 'zhou1', 'zhou2', 'zhou3', 'zhou4', 'zhu1', 'zhu2', 'zhu3', 'zhu4', 'zhua1', 'zhua2', 'zhua3', 'zhuai1', 'zhuai3', 'zhuai4', 'zhuan1', 'zhuan2', 'zhuan3', 'zhuan4', 'zhuang1', 'zhuang4', 'zhui1', 'zhui4', 'zhun1', 'zhun2', 'zhun3', 'zhuo1', 'zhuo2', 'zi', 'zi1', 'zi2', 'zi3', 'zi4', 'zong1', 'zong2', 'zong3', 'zong4', 'zou1', 'zou2', 'zou3', 'zou4', 'zu1', 'zu2', 'zu3', 'zuan1', 'zuan3', 'zuan4', 'zui2', 'zui3', 'zui4', 'zun1', 'zuo1', 'zuo2', 'zuo3', 'zuo4'] ens = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', "'", ' '] ens_U = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z'] phoneme_to_index = {} num_phonemes = 0 for index, punc in enumerate(puncs): phoneme_to_index[punc] = index + num_phonemes num_phonemes += len(puncs) for index, pinyin in enumerate(pinyins): phoneme_to_index[pinyin] = index + num_phonemes num_phonemes += len(pinyins) for index, en in enumerate(ens): phoneme_to_index[en] = index + num_phonemes for index, en in enumerate(ens_U): phoneme_to_index[en] = index + num_phonemes num_phonemes += len(ens) #print(num_phonemes, phoneme_to_index) # 1342 def encode( text: list[str], padding_value = -1 ) -> Int['b nt']: phonemes = [] for t in text: one_phoneme = [] brk = False for word in jieba.cut(t): if all_ch(word): seg = lazy_pinyin(word, style=Style.TONE3, tone_sandhi=True) one_phoneme.extend(seg) elif all_en(word): for seg in word: one_phoneme.append(seg) elif word in [",", "。", "?", "、", "'", " "]: one_phoneme.append(word) else: for ch in word: if all_ch(ch): seg = lazy_pinyin(ch, style=Style.TONE3, tone_sandhi=True) one_phoneme.extend(seg) elif all_en(ch): for seg in ch: one_phoneme.append(seg) else: brk = True break if brk: break if not brk: phonemes.append(one_phoneme) else: print("Error Tokenized", t, list(jieba.cut(t))) list_tensors = [tensor([phoneme_to_index[p] for p in one_phoneme]) for one_phoneme in phonemes] padded_tensor = pad_sequence(list_tensors, padding_value = -1) return padded_tensor return encode, num_phonemes # tensor helpers def log(t, eps = 1e-5): return t.clamp(min = eps).log() def lens_to_mask( t: Int['b'], length: int | None = None ) -> Bool['b n']: if not exists(length): length = t.amax() seq = torch.arange(length, device = t.device) return einx.less('n, b -> b n', seq, t) def mask_from_start_end_indices( seq_len: Int['b'], start: Int['b'], end: Int['b'] ): max_seq_len = seq_len.max().item() seq = torch.arange(max_seq_len, device = start.device).long() return einx.greater_equal('n, b -> b n', seq, start) & einx.less('n, b -> b n', seq, end) def mask_from_frac_lengths( seq_len: Int['b'], frac_lengths: Float['b'], max_length: int | None = None, val = False ): lengths = (frac_lengths * seq_len).long() max_start = seq_len - lengths if not val: rand = torch.rand_like(frac_lengths) else: rand = torch.tensor([0.5]*frac_lengths.shape[0], device=frac_lengths.device).float() start = (max_start * rand).long().clamp(min = 0) end = start + lengths out = mask_from_start_end_indices(seq_len, start, end) if exists(max_length): out = pad_to_length(out, max_length) return out def maybe_masked_mean( t: Float['b n d'], mask: Bool['b n'] | None = None ) -> Float['b d']: if not exists(mask): return t.mean(dim = 1) t = einx.where('b n, b n d, -> b n d', mask, t, 0.) num = reduce(t, 'b n d -> b d', 'sum') den = reduce(mask.float(), 'b n -> b', 'sum') return einx.divide('b d, b -> b d', num, den.clamp(min = 1.)) def pad_to_length( t: Tensor, length: int, value = None ): seq_len = t.shape[-1] if length > seq_len: t = F.pad(t, (0, length - seq_len), value = value) return t[..., :length] def interpolate_1d( x: Tensor, length: int, mode = 'bilinear' ): x = rearrange(x, 'n d -> 1 d n 1') x = F.interpolate(x, (length, 1), mode = mode) return rearrange(x, '1 d n 1 -> n d') # to mel spec class MelSpec(Module): def __init__( self, filter_length = 1024, hop_length = 256, win_length = 1024, n_mel_channels = 100, sampling_rate = 24_000, normalize = False, power = 1, norm = None, center = True, ): super().__init__() self.n_mel_channels = n_mel_channels self.sampling_rate = sampling_rate self.mel_stft = torchaudio.transforms.MelSpectrogram( sample_rate = sampling_rate, n_fft = filter_length, win_length = win_length, hop_length = hop_length, n_mels = n_mel_channels, power = power, center = center, normalized = normalize, norm = norm, ) self.register_buffer('dummy', tensor(0), persistent = False) def forward(self, inp): if len(inp.shape) == 3: inp = rearrange(inp, 'b 1 nw -> b nw') assert len(inp.shape) == 2 if self.dummy.device != inp.device: self.to(inp.device) mel = self.mel_stft(inp) mel = log(mel) return mel class EncodecWrapper(Module): def __init__(self, path): super().__init__() self.model = EncodecModel.from_pretrained(path) self.processor = AutoProcessor.from_pretrained(path) for param in self.model.parameters(): param.requires_grad = False self.model.eval() def forward(self, waveform): with torch.no_grad(): inputs = self.processor(raw_audio=waveform[0], sampling_rate=self.processor.sampling_rate, return_tensors="pt") emb = self.model.encoder(inputs.input_values) return emb def decode(self, emb): with torch.no_grad(): output = self.model.decoder(emb) return output[0] from audioldm.audio.stft import TacotronSTFT from audioldm.variational_autoencoder import AutoencoderKL from audioldm.utils import default_audioldm_config, get_metadata def build_pretrained_models(name): checkpoint = torch.load(get_metadata()[name]["path"], map_location="cpu") scale_factor = checkpoint["state_dict"]["scale_factor"].item() vae_state_dict = {k[18:]: v for k, v in checkpoint["state_dict"].items() if "first_stage_model." in k} config = default_audioldm_config(name) vae_config = config["model"]["params"]["first_stage_config"]["params"] vae_config["scale_factor"] = scale_factor vae = AutoencoderKL(**vae_config) vae.load_state_dict(vae_state_dict) fn_STFT = TacotronSTFT( config["preprocessing"]["stft"]["filter_length"], config["preprocessing"]["stft"]["hop_length"], config["preprocessing"]["stft"]["win_length"], config["preprocessing"]["mel"]["n_mel_channels"], config["preprocessing"]["audio"]["sampling_rate"], config["preprocessing"]["mel"]["mel_fmin"], config["preprocessing"]["mel"]["mel_fmax"], ) vae.eval() fn_STFT.eval() return vae, fn_STFT class VaeWrapper(Module): def __init__(self): super().__init__() vae, stft = build_pretrained_models("audioldm-s-full") vae.eval() stft.eval() stft = stft.cpu() self.vae = vae for param in self.vae.parameters(): param.requires_grad = False def forward(self, waveform): return None def decode(self, emb): with torch.no_grad(): b, d, l = emb.shape latents = emb.transpose(1,2).reshape(b, l, 8, 16).transpose(1,2) mel = self.vae.decode_first_stage(latents) wave = self.vae.decode_to_waveform(mel) return wave # convolutional positional generating module # taken from https://github.com/lucidrains/voicebox-pytorch/blob/main/voicebox_pytorch/voicebox_pytorch.py#L203 class DepthwiseConv(Module): def __init__( self, dim, *, kernel_size, groups = None ): super().__init__() assert not divisible_by(kernel_size, 2) groups = default(groups, dim) # full depthwise conv by default self.dw_conv1d = nn.Sequential( nn.Conv1d(dim, dim, kernel_size, groups = groups, padding = kernel_size // 2), nn.SiLU() ) def forward( self, x, mask = None ): if exists(mask): x = einx.where('b n, b n d, -> b n d', mask, x, 0.) x = rearrange(x, 'b n c -> b c n') x = self.dw_conv1d(x) out = rearrange(x, 'b c n -> b n c') if exists(mask): out = einx.where('b n, b n d, -> b n d', mask, out, 0.) return out # adaln zero from DiT paper class AdaLNZero(Module): def __init__( self, dim, dim_condition = None, init_bias_value = -2. ): super().__init__() dim_condition = default(dim_condition, dim) self.to_gamma = nn.Linear(dim_condition, dim) nn.init.zeros_(self.to_gamma.weight) nn.init.constant_(self.to_gamma.bias, init_bias_value) def forward(self, x, *, condition): if condition.ndim == 2: condition = rearrange(condition, 'b d -> b 1 d') gamma = self.to_gamma(condition).sigmoid() return x * gamma # random projection fourier embedding class RandomFourierEmbed(Module): def __init__(self, dim): super().__init__() assert divisible_by(dim, 2) self.register_buffer('weights', torch.randn(dim // 2)) def forward(self, x): freqs = einx.multiply('i, j -> i j', x, self.weights) * 2 * torch.pi fourier_embed, _ = pack((x, freqs.sin(), freqs.cos()), 'b *') return fourier_embed # character embedding class CharacterEmbed(Module): def __init__( self, dim, num_embeds = 256, ): super().__init__() self.dim = dim self.embed = nn.Embedding(num_embeds + 1, dim) # will just use 0 as the 'filler token' def forward( self, text: Int['b nt'], max_seq_len: int, **kwargs ) -> Float['b n d']: text = text + 1 # shift all other token ids up by 1 and use 0 as filler token text = text[:, :max_seq_len] # just curtail if character tokens are more than the mel spec tokens, one of the edge cases the paper did not address text = pad_to_length(text, max_seq_len, value = 0) return self.embed(text) class InterpolatedCharacterEmbed(Module): def __init__( self, dim, num_embeds = 256, ): super().__init__() self.dim = dim self.embed = nn.Embedding(num_embeds, dim) self.abs_pos_mlp = Sequential( Rearrange('... -> ... 1'), Linear(1, dim), nn.SiLU(), Linear(dim, dim) ) def forward( self, text: Int['b nt'], max_seq_len: int, mask: Bool['b n'] | None = None ) -> Float['b n d']: device = text.device mask = default(mask, (None,)) interp_embeds = [] interp_abs_positions = [] for one_text, one_mask in zip_longest(text, mask): valid_text = one_text >= 0 one_text = one_text[valid_text] one_text_embed = self.embed(one_text) # save the absolute positions text_seq_len = one_text.shape[0] # determine audio sequence length from mask audio_seq_len = max_seq_len if exists(one_mask): audio_seq_len = one_mask.sum().long().item() # interpolate text embedding to audio embedding length interp_text_embed = interpolate_1d(one_text_embed, audio_seq_len) interp_abs_pos = torch.linspace(0, text_seq_len, audio_seq_len, device = device) interp_embeds.append(interp_text_embed) interp_abs_positions.append(interp_abs_pos) interp_embeds = pad_sequence(interp_embeds) interp_abs_positions = pad_sequence(interp_abs_positions) interp_embeds = F.pad(interp_embeds, (0, 0, 0, max_seq_len - interp_embeds.shape[-2])) interp_abs_positions = pad_to_length(interp_abs_positions, max_seq_len) # pass interp absolute positions through mlp for implicit positions interp_embeds = interp_embeds + self.abs_pos_mlp(interp_abs_positions) if exists(mask): interp_embeds = einx.where('b n, b n d, -> b n d', mask, interp_embeds, 0.) return interp_embeds # text audio cross conditioning in multistream setup class TextAudioCrossCondition(Module): def __init__( self, dim, dim_text, dim_frames, cond_audio_to_text = True, ): super().__init__() #self.text_to_audio = nn.Linear(dim_text + dim, dim, bias = False) self.text_frames_to_audio = nn.Linear(dim + dim_text + dim_frames, dim, bias = False) nn.init.zeros_(self.text_frames_to_audio.weight) self.cond_audio_to_text = cond_audio_to_text if cond_audio_to_text: self.audio_to_text = nn.Linear(dim + dim_text, dim_text, bias = False) nn.init.zeros_(self.audio_to_text.weight) self.audio_to_frames = nn.Linear(dim + dim_frames, dim_frames, bias = False) nn.init.zeros_(self.audio_to_frames.weight) def forward( self, audio: Float['b n d'], text: Float['b n dt'], frames: Float['b n df'], ): #audio_text, _ = pack((audio, text), 'b n *') audio_text_frames, _ = pack((audio, text, frames), 'b n *') audio_text, _ = pack((audio, text), 'b n *') audio_frames, _ = pack((audio, frames), 'b n *') #text_cond = self.text_to_audio(audio_text) text_cond = self.text_frames_to_audio(audio_text_frames) audio_cond = self.audio_to_text(audio_text) if self.cond_audio_to_text else 0. audio_cond2 = self.audio_to_frames(audio_frames) if self.cond_audio_to_text else 0. return audio + text_cond, text + audio_cond, frames + audio_cond2 # attention and transformer backbone # for use in both e2tts as well as duration module class Transformer(Module): @beartype def __init__( self, *, dim, dim_text = None, # will default to half of audio dimension dim_frames = 512, depth = 8, heads = 8, dim_head = 64, ff_mult = 4, text_depth = None, text_heads = None, text_dim_head = None, text_ff_mult = None, cond_on_time = True, abs_pos_emb = True, max_seq_len = 8192, kernel_size = 31, dropout = 0.1, num_registers = 32, attn_kwargs: dict = dict( gate_value_heads = True, softclamp_logits = True, ), ff_kwargs: dict = dict(), if_text_modules = True, if_cross_attn = True, if_audio_conv = True, if_text_conv = False ): super().__init__() assert divisible_by(depth, 2), 'depth needs to be even' # absolute positional embedding self.max_seq_len = max_seq_len self.abs_pos_emb = nn.Embedding(max_seq_len, dim) if abs_pos_emb else None self.dim = dim dim_text = default(dim_text, dim // 2) self.dim_text = dim_text self.dim_frames = dim_frames text_heads = default(text_heads, heads) text_dim_head = default(text_dim_head, dim_head) text_ff_mult = default(text_ff_mult, ff_mult) text_depth = default(text_depth, depth) assert 1 <= text_depth <= depth, 'must have at least 1 layer of text conditioning, but less than total number of speech layers' self.depth = depth self.layers = ModuleList([]) # registers self.num_registers = num_registers self.registers = nn.Parameter(torch.zeros(num_registers, dim)) nn.init.normal_(self.registers, std = 0.02) if if_text_modules: self.text_registers = nn.Parameter(torch.zeros(num_registers, dim_text)) nn.init.normal_(self.text_registers, std = 0.02) self.frames_registers = nn.Parameter(torch.zeros(num_registers, dim_frames)) nn.init.normal_(self.frames_registers, std = 0.02) # rotary embedding self.rotary_emb = RotaryEmbedding(dim_head) self.text_rotary_emb = RotaryEmbedding(dim_head) self.frames_rotary_emb = RotaryEmbedding(dim_head) # time conditioning # will use adaptive rmsnorm self.cond_on_time = cond_on_time rmsnorm_klass = RMSNorm if not cond_on_time else AdaptiveRMSNorm postbranch_klass = Identity if not cond_on_time else partial(AdaLNZero, dim = dim) self.time_cond_mlp = Identity() if cond_on_time: self.time_cond_mlp = Sequential( RandomFourierEmbed(dim), Linear(dim + 1, dim), nn.SiLU() ) for ind in range(depth): is_later_half = ind >= (depth // 2) has_text = ind < text_depth # speech related if if_audio_conv: speech_conv = DepthwiseConv(dim, kernel_size = kernel_size) attn_norm = rmsnorm_klass(dim) attn = Attention(dim = dim, heads = heads, dim_head = dim_head, dropout = dropout, **attn_kwargs) attn_adaln_zero = postbranch_klass() if if_cross_attn: attn_norm2 = rmsnorm_klass(dim) attn2 = Attention(dim = dim, heads = heads, dim_head = dim_head, dropout = dropout, **attn_kwargs) attn_adaln_zero2 = postbranch_klass() ff_norm = rmsnorm_klass(dim) ff = FeedForward(dim = dim, glu = True, mult = ff_mult, dropout = dropout, **ff_kwargs) ff_adaln_zero = postbranch_klass() skip_proj = Linear(dim * 2, dim, bias = False) if is_later_half else None if if_cross_attn: if if_audio_conv: speech_modules = ModuleList([ skip_proj, speech_conv, attn_norm, attn, attn_adaln_zero, attn_norm2, attn2, attn_adaln_zero2, ff_norm, ff, ff_adaln_zero, ]) else: speech_modules = ModuleList([ skip_proj, attn_norm, attn, attn_adaln_zero, attn_norm2, attn2, attn_adaln_zero2, ff_norm, ff, ff_adaln_zero, ]) else: if if_audio_conv: speech_modules = ModuleList([ skip_proj, speech_conv, attn_norm, attn, attn_adaln_zero, ff_norm, ff, ff_adaln_zero, ]) else: speech_modules = ModuleList([ skip_proj, attn_norm, attn, attn_adaln_zero, ff_norm, ff, ff_adaln_zero, ]) text_modules = None if has_text and if_text_modules: # text related if if_text_conv: text_conv = DepthwiseConv(dim_text, kernel_size = kernel_size) text_attn_norm = RMSNorm(dim_text) text_attn = Attention(dim = dim_text, heads = text_heads, dim_head = text_dim_head, dropout = dropout, **attn_kwargs) text_ff_norm = RMSNorm(dim_text) text_ff = FeedForward(dim = dim_text, glu = True, mult = text_ff_mult, dropout = dropout, **ff_kwargs) # cross condition is_last = ind == (text_depth - 1) cross_condition = TextAudioCrossCondition(dim = dim, dim_text = dim_text, dim_frames = dim_frames, cond_audio_to_text = not is_last) if if_text_conv: text_modules = ModuleList([ text_conv, text_attn_norm, text_attn, text_ff_norm, text_ff, cross_condition ]) else: text_modules = ModuleList([ text_attn_norm, text_attn, text_ff_norm, text_ff, cross_condition ]) if True: frames_conv = DepthwiseConv(dim_frames, kernel_size = kernel_size) frames_attn_norm = RMSNorm(dim_frames) frames_attn = Attention(dim = dim_frames, heads = 8, dim_head = 64, dropout = dropout, **attn_kwargs) frames_ff_norm = RMSNorm(dim_frames) frames_ff = FeedForward(dim = dim_frames, glu = True, mult = 4, dropout = dropout, **ff_kwargs) # cross condition frames_modules = ModuleList([ frames_conv, frames_attn_norm, frames_attn, frames_ff_norm, frames_ff ]) self.layers.append(ModuleList([ speech_modules, text_modules, frames_modules ])) self.final_norm = RMSNorm(dim) self.if_cross_attn = if_cross_attn self.if_audio_conv = if_audio_conv self.if_text_conv = if_text_conv def forward( self, x: Float['b n d'], times: Float['b'] | Float[''] | None = None, mask: Bool['b n'] | None = None, text_embed: Float['b n dt'] | None = None, frames_embed: Float['b n df'] | None = None, context: Float['b nc dc'] | None = None, context_mask: Float['b nc'] | None = None ): batch, seq_len, device = *x.shape[:2], x.device assert not (exists(times) ^ self.cond_on_time), '`times` must be passed in if `cond_on_time` is set to `True` and vice versa' # handle absolute positions if needed if exists(self.abs_pos_emb): assert seq_len <= self.max_seq_len, f'{seq_len} exceeds the set `max_seq_len` ({self.max_seq_len}) on Transformer' seq = torch.arange(seq_len, device = device) x = x + self.abs_pos_emb(seq) # handle adaptive rmsnorm kwargs norm_kwargs = dict() if exists(times): if times.ndim == 0: times = repeat(times, ' -> b', b = batch) times = self.time_cond_mlp(times) norm_kwargs.update(condition = times) # register tokens registers = repeat(self.registers, 'r d -> b r d', b = batch) x, registers_packed_shape = pack((registers, x), 'b * d') if exists(mask): mask = F.pad(mask, (self.num_registers, 0), value = True) # rotary embedding rotary_pos_emb = self.rotary_emb.forward_from_seq_len(x.shape[-2]) # text related if exists(text_embed): text_rotary_pos_emb = self.text_rotary_emb.forward_from_seq_len(x.shape[-2]) text_registers = repeat(self.text_registers, 'r d -> b r d', b = batch) text_embed, _ = pack((text_registers, text_embed), 'b * d') if exists(frames_embed): frames_rotary_pos_emb = self.frames_rotary_emb.forward_from_seq_len(x.shape[-2]) frames_registers = repeat(self.frames_registers, 'r d -> b r d', b = batch) frames_embed, _ = pack((frames_registers, frames_embed), 'b * d') # skip connection related stuff skips = [] # go through the layers for ind, (speech_modules, text_modules, frames_modules) in enumerate(self.layers): layer = ind + 1 if self.if_cross_attn: if self.if_audio_conv: ( maybe_skip_proj, speech_conv, attn_norm, attn, maybe_attn_adaln_zero, attn_norm2, attn2, maybe_attn_adaln_zero2, ff_norm, ff, maybe_ff_adaln_zero ) = speech_modules else: ( maybe_skip_proj, attn_norm, attn, maybe_attn_adaln_zero, attn_norm2, attn2, maybe_attn_adaln_zero2, ff_norm, ff, maybe_ff_adaln_zero ) = speech_modules else: if self.if_audio_conv: ( maybe_skip_proj, speech_conv, attn_norm, attn, maybe_attn_adaln_zero, ff_norm, ff, maybe_ff_adaln_zero ) = speech_modules else: ( maybe_skip_proj, attn_norm, attn, maybe_attn_adaln_zero, ff_norm, ff, maybe_ff_adaln_zero ) = speech_modules # smaller text transformer if exists(text_embed) and exists(text_modules): if self.if_text_conv: ( text_conv, text_attn_norm, text_attn, text_ff_norm, text_ff, cross_condition ) = text_modules else: ( text_attn_norm, text_attn, text_ff_norm, text_ff, cross_condition ) = text_modules if self.if_text_conv: text_embed = text_conv(text_embed, mask = mask) + text_embed text_embed = text_attn(text_attn_norm(text_embed), rotary_pos_emb = text_rotary_pos_emb, mask = mask) + text_embed text_embed = text_ff(text_ff_norm(text_embed)) + text_embed # frames transformer ( frames_conv, frames_attn_norm, frames_attn, frames_ff_norm, frames_ff ) = frames_modules frames_embed = frames_conv(frames_embed, mask = mask) + frames_embed frames_embed = frames_attn(frames_attn_norm(frames_embed), rotary_pos_emb = frames_rotary_pos_emb, mask = mask) + frames_embed frames_embed = frames_ff(frames_ff_norm(frames_embed)) + frames_embed # cross condition x, text_embed, frames_embed = cross_condition(x, text_embed, frames_embed) # skip connection logic is_first_half = layer <= (self.depth // 2) is_later_half = not is_first_half if is_first_half: skips.append(x) if is_later_half: skip = skips.pop() x = torch.cat((x, skip), dim = -1) x = maybe_skip_proj(x) # position generating convolution if self.if_audio_conv: x = speech_conv(x, mask = mask) + x # attention and feedforward blocks attn_out = attn(attn_norm(x, **norm_kwargs), rotary_pos_emb = rotary_pos_emb, mask = mask) x = x + maybe_attn_adaln_zero(attn_out, **norm_kwargs) if self.if_cross_attn: attn_out = attn2(attn_norm2(x, **norm_kwargs), rotary_pos_emb = rotary_pos_emb, mask = mask, context = context, context_mask = context_mask) x = x + maybe_attn_adaln_zero2(attn_out, **norm_kwargs) ff_out = ff(ff_norm(x, **norm_kwargs)) x = x + maybe_ff_adaln_zero(ff_out, **norm_kwargs) assert len(skips) == 0 _, x = unpack(x, registers_packed_shape, 'b * d') return self.final_norm(x) # main classes class DurationPredictor(Module): @beartype def __init__( self, transformer: dict | Transformer, num_channels = None, mel_spec_kwargs: dict = dict(), char_embed_kwargs: dict = dict(), text_num_embeds = None, tokenizer: ( Literal['char_utf8', 'phoneme_en'] | Callable[[list[str]], Int['b nt']] ) = 'char_utf8' ): super().__init__() if isinstance(transformer, dict): transformer = Transformer( **transformer, cond_on_time = False ) # mel spec self.mel_spec = MelSpec(**mel_spec_kwargs) self.num_channels = default(num_channels, self.mel_spec.n_mel_channels) self.transformer = transformer dim = transformer.dim dim_text = transformer.dim_text self.dim = dim self.proj_in = Linear(self.num_channels, self.dim) # tokenizer and text embed if callable(tokenizer): assert exists(text_num_embeds), '`text_num_embeds` must be given if supplying your own tokenizer encode function' self.tokenizer = tokenizer elif tokenizer == 'char_utf8': text_num_embeds = 256 self.tokenizer = list_str_to_tensor elif tokenizer == 'phoneme_en': self.tokenizer, text_num_embeds = get_g2p_en_encode() elif tokenizer == 'phoneme_zh': self.tokenizer, text_num_embeds = get_g2p_zh_encode() else: raise ValueError(f'unknown tokenizer string {tokenizer}') self.embed_text = CharacterEmbed(dim_text, num_embeds = text_num_embeds, **char_embed_kwargs) # to prediction self.to_pred = Sequential( Linear(dim, 1, bias = False), nn.Softplus(), Rearrange('... 1 -> ...') ) def forward( self, x: Float['b n d'] | Float['b nw'], *, text: Int['b nt'] | list[str] | None = None, lens: Int['b'] | None = None, return_loss = True ): # raw wave if x.ndim == 2: x = self.mel_spec(x) x = rearrange(x, 'b d n -> b n d') assert x.shape[-1] == self.dim x = self.proj_in(x) batch, seq_len, device = *x.shape[:2], x.device # text text_embed = None if exists(text): if isinstance(text, list): text = list_str_to_tensor(text).to(device) assert text.shape[0] == batch text_embed = self.embed_text(text, seq_len) # handle lengths (duration) if not exists(lens): lens = torch.full((batch,), seq_len, device = device) mask = lens_to_mask(lens, length = seq_len) # if returning a loss, mask out randomly from an index and have it predict the duration if return_loss: rand_frac_index = x.new_zeros(batch).uniform_(0, 1) rand_index = (rand_frac_index * lens).long() seq = torch.arange(seq_len, device = device) mask &= einx.less('n, b -> b n', seq, rand_index) # attending x = self.transformer( x, mask = mask, text_embed = text_embed, ) x = maybe_masked_mean(x, mask) pred = self.to_pred(x) # return the prediction if not returning loss if not return_loss: return pred # loss return F.mse_loss(pred, lens.float()) class E2TTS(Module): @beartype def __init__( self, transformer: dict | Transformer = None, duration_predictor: dict | DurationPredictor | None = None, odeint_kwargs: dict = dict( #atol = 1e-5, #rtol = 1e-5, #method = 'midpoint' method = "euler" ), audiocond_drop_prob = 0.30, cond_drop_prob = 0.20, prompt_drop_prob = 0.10, num_channels = None, mel_spec_module: Module | None = None, char_embed_kwargs: dict = dict(), mel_spec_kwargs: dict = dict(), frac_lengths_mask: tuple[float, float] = (0.7, 1.), audiocond_snr: tuple[float, float] | None = None, concat_cond = False, interpolated_text = False, text_num_embeds: int | None = None, tokenizer: ( Literal['char_utf8', 'phoneme_en', 'phoneme_zh'] | Callable[[list[str]], Int['b nt']] ) = 'char_utf8', use_vocos = True, pretrained_vocos_path = 'charactr/vocos-mel-24khz', sampling_rate: int | None = None, frame_size: int = 320, #### dpo velocity_consistency_weight = -1e-5, #### dpo if_cond_proj_in = True, cond_proj_in_bias = True, if_embed_text = True, if_text_encoder2 = True, if_clip_encoder = False, video_encoder = "clip_vit" ): super().__init__() if isinstance(transformer, dict): transformer = Transformer( **transformer, cond_on_time = True ) if isinstance(duration_predictor, dict): duration_predictor = DurationPredictor(**duration_predictor) self.transformer = transformer dim = transformer.dim dim_text = transformer.dim_text dim_frames = transformer.dim_frames self.dim = dim self.dim_text = dim_text self.frac_lengths_mask = frac_lengths_mask self.audiocond_snr = audiocond_snr self.duration_predictor = duration_predictor # sampling self.odeint_kwargs = odeint_kwargs # mel spec self.mel_spec = default(mel_spec_module, None) num_channels = default(num_channels, None) self.num_channels = num_channels self.sampling_rate = default(sampling_rate, None) self.frame_size = frame_size # whether to concat condition and project rather than project both and sum self.concat_cond = concat_cond if concat_cond: self.proj_in = nn.Linear(num_channels * 2, dim) else: self.proj_in = nn.Linear(num_channels, dim) self.cond_proj_in = nn.Linear(num_channels, dim, bias=cond_proj_in_bias) if if_cond_proj_in else None #self.cond_proj_in = nn.Linear(NOTES, dim, bias=cond_proj_in_bias) if if_cond_proj_in else None # to prediction self.to_pred = Linear(dim, num_channels) # tokenizer and text embed if callable(tokenizer): assert exists(text_num_embeds), '`text_num_embeds` must be given if supplying your own tokenizer encode function' self.tokenizer = tokenizer elif tokenizer == 'char_utf8': text_num_embeds = 256 self.tokenizer = list_str_to_tensor elif tokenizer == 'phoneme_en': self.tokenizer, text_num_embeds = get_g2p_en_encode() elif tokenizer == 'phoneme_zh': self.tokenizer, text_num_embeds = get_g2p_zh_encode() else: raise ValueError(f'unknown tokenizer string {tokenizer}') self.audiocond_drop_prob = audiocond_drop_prob self.cond_drop_prob = cond_drop_prob self.prompt_drop_prob = prompt_drop_prob # text embedding text_embed_klass = CharacterEmbed if not interpolated_text else InterpolatedCharacterEmbed self.embed_text = text_embed_klass(dim_text, num_embeds = text_num_embeds, **char_embed_kwargs) if if_embed_text else None # weight for velocity consistency self.register_buffer('zero', torch.tensor(0.), persistent = False) self.velocity_consistency_weight = velocity_consistency_weight # default vocos for mel -> audio #if pretrained_vocos_path == 'charactr/vocos-mel-24khz': # self.vocos = Vocos.from_pretrained(pretrained_vocos_path) if use_vocos else None #elif pretrained_vocos_path == 'facebook/encodec_24khz': # self.vocos = EncodecWrapper("facebook/encodec_24khz") if use_vocos else None #elif pretrained_vocos_path == 'vae': # self.vocos = VaeWrapper() if use_vocos else None if if_text_encoder2: self.tokenizer2 = AutoTokenizer.from_pretrained("./ckpts/flan-t5-large") self.text_encoder2 = T5EncoderModel.from_pretrained("./ckpts/flan-t5-large") for param in self.text_encoder2.parameters(): param.requires_grad = False self.text_encoder2.eval() self.proj_text = None self.proj_frames = Linear(NOTES, dim_frames) if if_clip_encoder: if video_encoder == "clip_vit": ####pass self.image_processor = CLIPImageProcessor() #self.image_encoder = CLIPVisionModelWithProjection.from_pretrained("/ailab-train2/speech/zhanghaomin/models/IP-Adapter/", subfolder="models/image_encoder") self.image_encoder = CLIPVisionModelWithProjection.from_pretrained("./ckpts/IP-Adapter/", subfolder="sdxl_models/image_encoder") elif video_encoder == "clip_vit2": self.image_processor = AutoProcessor.from_pretrained("/ailab-train2/speech/zhanghaomin/models/clip-vit-large-patch14-336/") self.image_encoder = CLIPVisionModelWithProjection.from_pretrained("/ailab-train2/speech/zhanghaomin/models/clip-vit-large-patch14-336/") elif video_encoder == "clip_convnext": self.image_encoder, _, self.image_processor = open_clip.create_model_and_transforms("hf-hub:laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-soup") elif video_encoder == "dinov2": self.image_processor = AutoImageProcessor.from_pretrained("/ailab-train2/speech/zhanghaomin/models/dinov2-giant/") self.image_encoder = AutoModel.from_pretrained("/ailab-train2/speech/zhanghaomin/models/dinov2-giant/") elif video_encoder == "mixed": pass #self.image_processor1 = CLIPImageProcessor() #self.image_encoder1 = CLIPVisionModelWithProjection.from_pretrained("/ailab-train2/speech/zhanghaomin/models/IP-Adapter/", subfolder="sdxl_models/image_encoder") #self.image_processor2 = AutoProcessor.from_pretrained("/ailab-train2/speech/zhanghaomin/models/clip-vit-large-patch14-336/") #self.image_encoder2 = CLIPVisionModelWithProjection.from_pretrained("/ailab-train2/speech/zhanghaomin/models/clip-vit-large-patch14-336/") #self.image_encoder3, _, self.image_processor3 = open_clip.create_model_and_transforms("hf-hub:laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-soup") #self.image_processor4 = AutoImageProcessor.from_pretrained("/ailab-train2/speech/zhanghaomin/models/dinov2-giant/") #self.image_encoder4 = AutoModel.from_pretrained("/ailab-train2/speech/zhanghaomin/models/dinov2-giant/") else: self.image_processor = None self.image_encoder = None if video_encoder != "mixed": ####pass for param in self.image_encoder.parameters(): param.requires_grad = False self.image_encoder.eval() else: #for param in self.image_encoder1.parameters(): # param.requires_grad = False #self.image_encoder1.eval() #for param in self.image_encoder2.parameters(): # param.requires_grad = False #self.image_encoder2.eval() #for param in self.image_encoder3.parameters(): # param.requires_grad = False #self.image_encoder3.eval() #for param in self.image_encoder4.parameters(): # param.requires_grad = False #self.image_encoder4.eval() self.dim_text_raw = 4608 self.proj_text = Linear(self.dim_text_raw, dim_text) self.video_encoder = video_encoder #for param in self.vocos.parameters(): # param.requires_grad = False #self.vocos.eval() ########self.conv1 = nn.Conv3d(6, 64, kernel_size=(3, 3, 3), padding=(1, 1, 1)) ########self.pool1 = nn.Conv3d(64, 64, kernel_size=(1, 2, 2), stride=(1, 2, 2)) ####self.conv1 = nn.Conv3d(3, 16, kernel_size=(3, 3, 3), padding=(1, 1, 1)) ####self.pool1 = nn.Conv3d(16, 16, kernel_size=(1, 2, 2), stride=(1, 2, 2), padding=(0,0,0)) #### ########self.conv2 = nn.Conv3d(64, 128, kernel_size=(3, 3, 3), padding=(1, 1, 1)) ########self.pool2 = nn.Conv3d(128, 128, kernel_size=(2, 2, 2), stride=(2, 2, 2)) ####self.conv2 = nn.Conv3d(16, 32, kernel_size=(3, 3, 3), padding=(1, 1, 1)) ####self.pool2 = nn.Conv3d(32, 32, kernel_size=(2, 2, 2), stride=(1, 2, 2), padding=(0,0,0)) #### ########self.conv3a = nn.Conv3d(128, 256, kernel_size=(3, 3, 3), padding=(1, 1, 1)) ########self.conv3b = nn.Conv3d(256, 256, kernel_size=(3, 3, 3), padding=(1, 1, 1)) ########self.pool3 = nn.Conv3d(256, 256, kernel_size=(1, 2, 2), stride=(1, 2, 2)) ####self.conv3a = nn.Conv3d(32, 64, kernel_size=(3, 3, 3), padding=(1, 1, 1)) ####self.conv3b = nn.Conv3d(64, 64, kernel_size=(3, 3, 3), padding=(1, 1, 1)) ####self.pool3 = nn.Conv3d(64, 64, kernel_size=(1, 2, 2), stride=(1, 2, 2), padding=(0,0,0)) #### ########self.conv4a = nn.Conv3d(256, 512, kernel_size=(3, 3, 3), padding=(1, 1, 1)) ########self.conv4b = nn.Conv3d(512, 512, kernel_size=(3, 3, 3), padding=(1, 1, 1)) ########self.pool4 = nn.Conv3d(512, 512, kernel_size=(1, 2, 2), stride=(1, 2, 2)) #####self.conv4a = nn.Conv3d(64, 64, kernel_size=(3, 3, 3), padding=(1, 1, 1)) #####self.conv4b = nn.Conv3d(64, 64, kernel_size=(3, 3, 3), padding=(1, 1, 1)) ######self.pool4 = nn.Conv3d(64, 32, kernel_size=(1, 2, 2), stride=(1, 2, 2), padding=(0,0,0)) #####self.pool4 = nn.ConvTranspose3d(64, 32, kernel_size=(3, 3, 3), padding=(1, 1, 1)) #### ########self.conv5a = nn.Conv3d(512, 512, kernel_size=(3, 3, 3), padding=(1, 1, 1)) ########self.conv5b = nn.Conv3d(512, 512, kernel_size=(3, 3, 3), padding=(1, 1, 1)) ########self.pool5 = nn.ConvTranspose3d(512, 128, kernel_size=(5, 3, 3), stride=(5, 1, 1), padding=(0, 1, 1)) ####self.conv5a = nn.Conv3d(64, 64, kernel_size=(3, 3, 3), padding=(1, 1, 1)) ####self.conv5b = nn.Conv3d(64, 64, kernel_size=(3, 3, 3), padding=(1, 1, 1)) ####self.pool5 = nn.ConvTranspose3d(64, 32, kernel_size=(3, 3, 3), stride=(3, 1, 1), padding=(0, 1, 1)) #####self.pool5 = nn.Conv3d(256, 256, kernel_size=(1, 2, 2), stride=(1, 2, 2)) #### ####self.relu = nn.ReLU() ####self.final_activation = nn.Sigmoid() ####self.dropout = nn.Dropout(p=0.50) ########self.fc5 = nn.Linear(51200, NOTES) ####self.fc5 = nn.Linear(65536, 208) ####self.fc6 = nn.Linear(208, NOTES) #### #####self.rnn = nn.RNN(NOTES, NOTES, 1) #####self.fc7 = nn.Linear(NOTES, NOTES) #### #####self.bn1 = nn.BatchNorm3d(16) #####self.bn2 = nn.BatchNorm3d(32) #####self.bn3 = nn.BatchNorm3d(64) #####self.bn4 = nn.BatchNorm3d(64) #####self.bn5 = nn.BatchNorm3d(64) #####self.bn6 = nn.BatchNorm3d(64) #####self.bn7 = nn.BatchNorm1d(208) self.video2roll_net = Video2RollNet.resnet18(num_classes=NOTES) def encode_frames(self, x, l): #print("x input", x.shape, l) # [1, 1, 251, 100, 900] b, c, t, w, h = x.shape assert(c == 1) x_all = [] for i in range(t): frames = [] for j in [-2, -1, 0, 1, 2]: f = min(max(i+j, 0), t-1) frames.append(x[:,:,f:f+1,:,:]) frames = torch.cat(frames, dim=2) # [b, 1, 5, w, h] x_all.append(frames) x = torch.cat(x_all, dim=1).reshape(b*t, 5, w, h) # [b*t, 5, w, h] #print("x", x.shape, l) # [251, 5, 100, 900] x = self.video2roll_net(x) x = nn.Sigmoid()(x) #print("x output", x.shape) # [251, 51] ####video_multi ####x = x.reshape(b, t, 1, NOTES).repeat(1,1,3,1).reshape(b, t*3, NOTES) t5 = (t*5//2)*2 x = x.reshape(b, t, 1, NOTES).repeat(1,1,5,1).reshape(b, t*5, NOTES)[:,:t5,:].reshape(b, t5//2, 2, NOTES).mean(2) b, d, _ = x.shape #print("encode_frames", x.shape, l) if d > l: x = x[:,:l,:] elif d < l: x = torch.cat((x, torch.zeros(b,l-d,NOTES,device=x.device)), 1) return x return x ####def encode_frames(self, x, l): #### x = x[:,:3,:,...] #### #print("x", x.shape) # [2, 6, 301, 320, 320] # [2, 3, 251, 128, 1024] #### x = self.conv1(x) #### #x = self.bn1(x) #### x = self.relu(x) #### x = self.dropout(x) #### #print("conv1", x.shape) # [2, 64, 301, 320, 320] # [2, 16, 251, 128, 1024] #### x = self.pool1(x) #### #print("pool1", x.shape) # [2, 64, 301, 160, 160] # [2, 16, 251, 64, 512] #### #### x = self.conv2(x) #### #x = self.bn2(x) #### x = self.relu(x) #### x = self.dropout(x) #### #print("conv2", x.shape) # [2, 128, 301, 160, 160] # [2, 32, 250, 64, 512] #### x = self.pool2(x) #### #x = self.relu(x) #### #x = self.dropout(x) #### #print("pool2", x.shape) # [2, 128, 150, 80, 80] # [2, 32, 250, 32, 256] #### #### x = self.conv3a(x) #### #x = self.bn3(x) #### x = self.relu(x) #### x = self.dropout(x) #### x = self.conv3b(x) #### #x = self.bn4(x) #### x = self.relu(x) #### x = self.dropout(x) #### #print("conv3", x.shape) # [2, 256, 150, 80, 80] # [2, 64, 250, 32, 256] #### x = self.pool3(x) #### #x = self.relu(x) #### #x = self.dropout(x) #### #print("pool3", x.shape) # [2, 256, 150, 40, 40] # [2, 64, 250, 16, 128] #### #### #x = self.conv4a(x) #### #x = self.relu(x) #### ##x = self.dropout(x) #### #x = self.conv4b(x) #### #x = self.relu(x) #### ##x = self.dropout(x) #### ###print("conv4", x.shape) # [2, 512, 150, 40, 40] # [2, 64, 250, 16, 128] #### #x = self.pool4(x) #### ##print("pool4", x.shape) # [2, 512, 150, 20, 20] # [2, 32, 250, 8, 64] #### #### x = self.conv5a(x) #### #x = self.bn5(x) #### x = self.relu(x) #### x = self.dropout(x) #### x = self.conv5b(x) #### #x = self.bn6(x) #### x = self.relu(x) #### x = self.dropout(x) #### #print("conv5", x.shape) # [2, 512, 150, 20, 20] # [2, 64, 250, 16, 128] #### x = self.pool5(x) #### #x = self.relu(x) #### #x = self.dropout(x) #### #print("pool5", x.shape) # [2, 128, 750, 20, 20] # [2, 32, 750/250, 16, 128] #### #### b, c, d, w, h = x.shape #### x = x.permute(0,2,3,4,1).reshape(b,d,w*h*c) #### x = self.fc5(x) #### #x = x.reshape(b,208,d) #### #x = self.bn7(x) #### #x = x.reshape(b,d,208) #### x = self.relu(x) #### x = self.dropout(x) #### x = self.fc6(x) #### #### #x = self.relu(x) #### #x, _ = self.rnn(x) #### #x = self.fc7(x) #### #### x = self.final_activation(x) #### #### #x = x.reshape(b,d,1,NOTES).repeat(1,1,3,1).reshape(b,d*3,NOTES) #### #d = d * 3 #### #### #print("encode_frames", x.shape, l) #### if d > l: #### x = x[:,:l,:] #### elif d < l: #### x = torch.cat((x, torch.zeros(b,l-d,NOTES,device=x.device)), 1) #### return x @property def device(self): return next(self.parameters()).device def encode_text(self, prompt): device = self.device batch = self.tokenizer2(prompt, max_length=self.tokenizer2.model_max_length, padding=True, truncation=True, return_tensors="pt") input_ids, attention_mask = batch.input_ids.to(device), batch.attention_mask.to(device) with torch.no_grad(): encoder_hidden_states = self.text_encoder2(input_ids=input_ids, attention_mask=attention_mask)[0] boolean_encoder_mask = (attention_mask == 1).to(device) return encoder_hidden_states, boolean_encoder_mask def encode_video(self, video_paths, l): if self.proj_text is None: d = self.dim_text else: d = self.dim_text_raw device = self.device b = 20 with torch.no_grad(): video_embeddings = [] video_lens = [] for video_path in video_paths: if video_path is None: video_embeddings.append(None) video_lens.append(0) continue if isinstance(video_path, tuple): video_path, start_sample, max_sample = video_path else: start_sample = 0 max_sample = None if video_path.startswith("/ailab-train2/speech/zhanghaomin/VGGSound/"): if self.video_encoder == "clip_vit": feature_path = video_path.replace("/video/", "/feature/").replace(".mp4", ".npz") elif self.video_encoder == "clip_vit2": feature_path = video_path.replace("/video/", "/feature_clip_vit2/").replace(".mp4", ".npz") elif self.video_encoder == "clip_convnext": feature_path = video_path.replace("/video/", "/feature_clip_convnext/").replace(".mp4", ".npz") elif self.video_encoder == "dinov2": feature_path = video_path.replace("/video/", "/feature_dinov2/").replace(".mp4", ".npz") elif self.video_encoder == "mixed": feature_path = video_path.replace("/video/", "/feature_mixed/").replace(".mp4", ".npz") else: raise Exception("Invalid video_encoder " + self.video_encoder) else: if self.video_encoder == "clip_vit": feature_path = video_path.replace(".mp4", ".generated.npz") elif self.video_encoder == "clip_vit2": feature_path = video_path.replace(".mp4", ".generated.clip_vit2.npz") elif self.video_encoder == "clip_convnext": feature_path = video_path.replace(".mp4", ".generated.clip_convnext.npz") elif self.video_encoder == "dinov2": feature_path = video_path.replace(".mp4", ".generated.dinov2.npz") elif self.video_encoder == "mixed": feature_path = video_path.replace(".mp4", ".generated.mixed.npz") else: raise Exception("Invalid video_encoder " + self.video_encoder) if not os.path.exists(feature_path): #print("video not exist", video_path) frames, duration = read_frames_with_moviepy(video_path, max_frame_nums=None) if frames is None: video_embeddings.append(None) video_lens.append(0) continue if self.video_encoder in ["clip_vit", "clip_vit2", "dinov2"]: images = self.image_processor(images=frames, return_tensors="pt").to(device) #print("images", images["pixel_values"].shape, images["pixel_values"].max(), images["pixel_values"].min(), torch.abs(images["pixel_values"]).mean()) elif self.video_encoder in ["clip_convnext"]: images = [] for i in range(frames.shape[0]): images.append(self.image_processor(Image.fromarray(frames[i])).unsqueeze(0)) images = torch.cat(images, dim=0).to(device) #print("images", images.shape, images.max(), images.min(), torch.abs(images).mean()) elif self.video_encoder in ["mixed"]: #images1 = self.image_processor1(images=frames, return_tensors="pt").to(device) images2 = self.image_processor2(images=frames, return_tensors="pt").to(device) images4 = self.image_processor4(images=frames, return_tensors="pt").to(device) images3 = [] for i in range(frames.shape[0]): images3.append(self.image_processor3(Image.fromarray(frames[i])).unsqueeze(0)) images3 = torch.cat(images3, dim=0).to(device) else: raise Exception("Invalid video_encoder " + self.video_encoder) image_embeddings = [] if self.video_encoder == "clip_vit": for i in range(math.ceil(images["pixel_values"].shape[0] / b)): image_embeddings.append(self.image_encoder(pixel_values=images["pixel_values"][i*b: (i+1)*b]).image_embeds.cpu()) elif self.video_encoder == "clip_vit2": for i in range(math.ceil(images["pixel_values"].shape[0] / b)): image_embeddings.append(self.image_encoder(pixel_values=images["pixel_values"][i*b: (i+1)*b]).image_embeds.cpu()) elif self.video_encoder == "clip_convnext": for i in range(math.ceil(images.shape[0] / b)): image_embeddings.append(self.image_encoder.encode_image(images[i*b: (i+1)*b]).cpu()) elif self.video_encoder == "dinov2": for i in range(math.ceil(images["pixel_values"].shape[0] / b)): image_embeddings.append(self.image_encoder(pixel_values=images["pixel_values"][i*b: (i+1)*b]).pooler_output.cpu()) elif self.video_encoder == "mixed": feature_path1 = feature_path.replace("/feature_mixed/", "/feature/") if not os.path.exists(feature_path1): image_embeddings1 = [] for i in range(math.ceil(images1["pixel_values"].shape[0] / b)): image_embeddings1.append(self.image_encoder1(pixel_values=images1["pixel_values"][i*b: (i+1)*b]).image_embeds.cpu()) image_embeddings1 = torch.cat(image_embeddings1, dim=0) #np.savez(feature_path1, image_embeddings1, duration) else: data1 = np.load(feature_path1) image_embeddings1 = torch.from_numpy(data1["arr_0"]) feature_path2 = feature_path.replace("/feature_mixed/", "/feature_clip_vit2/") if not os.path.exists(feature_path2): image_embeddings2 = [] for i in range(math.ceil(images2["pixel_values"].shape[0] / b)): image_embeddings2.append(self.image_encoder2(pixel_values=images2["pixel_values"][i*b: (i+1)*b]).image_embeds.cpu()) image_embeddings2 = torch.cat(image_embeddings2, dim=0) np.savez(feature_path2, image_embeddings2, duration) else: data2 = np.load(feature_path2) image_embeddings2 = torch.from_numpy(data2["arr_0"]) feature_path3 = feature_path.replace("/feature_mixed/", "/feature_clip_convnext/") if not os.path.exists(feature_path3): image_embeddings3 = [] for i in range(math.ceil(images3.shape[0] / b)): image_embeddings3.append(self.image_encoder3.encode_image(images3[i*b: (i+1)*b]).cpu()) image_embeddings3 = torch.cat(image_embeddings3, dim=0) np.savez(feature_path3, image_embeddings3, duration) else: data3 = np.load(feature_path3) image_embeddings3 = torch.from_numpy(data3["arr_0"]) feature_path4 = feature_path.replace("/feature_mixed/", "/feature_dinov2/") if not os.path.exists(feature_path4): image_embeddings4 = [] for i in range(math.ceil(images4["pixel_values"].shape[0] / b)): image_embeddings4.append(self.image_encoder4(pixel_values=images4["pixel_values"][i*b: (i+1)*b]).pooler_output.cpu()) image_embeddings4 = torch.cat(image_embeddings4, dim=0) np.savez(feature_path4, image_embeddings4, duration) else: data4 = np.load(feature_path4) image_embeddings4 = torch.from_numpy(data4["arr_0"]) mixed_l = min([image_embeddings1.shape[0], image_embeddings2.shape[0], image_embeddings3.shape[0], image_embeddings4.shape[0]]) for i in range(mixed_l): image_embeddings.append(torch.cat([image_embeddings1[i:i+1,:], image_embeddings2[i:i+1,:], image_embeddings3[i:i+1,:], image_embeddings4[i:i+1,:]], dim=1)) else: raise Exception("Invalid video_encoder " + self.video_encoder) image_embeddings = torch.cat(image_embeddings, dim=0) #print("image_embeddings", image_embeddings.shape, image_embeddings.max(), image_embeddings.min(), torch.abs(image_embeddings).mean()) np.savez(feature_path, image_embeddings, duration) else: #print("video exist", feature_path) data = np.load(feature_path) image_embeddings = torch.from_numpy(data["arr_0"]) #print("image_embeddings", image_embeddings.shape, image_embeddings.max(), image_embeddings.min(), torch.abs(image_embeddings).mean()) duration = data["arr_1"].item() if max_sample is None: max_sample = int(duration * self.sampling_rate) interpolated = [] for i in range(start_sample, max_sample, self.frame_size): j = min(round((i+self.frame_size//2) / self.sampling_rate / (duration / (image_embeddings.shape[0] - 1))), image_embeddings.shape[0] - 1) interpolated.append(image_embeddings[j:j+1]) if len(interpolated) >= l: break interpolated = torch.cat(interpolated, dim=0) #ll = list(range(start_sample, max_sample, self.frame_size)) #print("encode_video l", len(ll), l, round((ll[-1]+self.frame_size//2) / self.sampling_rate / (duration / (image_embeddings.shape[0] - 1))), image_embeddings.shape[0] - 1) #print("encode_video one", video_path, duration, image_embeddings.shape, interpolated.shape, l) video_embeddings.append(interpolated.unsqueeze(0)) video_lens.append(interpolated.shape[1]) max_length = max(video_lens) if max_length == 0: max_length = l else: max_length = l for i in range(len(video_embeddings)): if video_embeddings[i] is None: video_embeddings[i] = torch.zeros(1, max_length, d) continue if video_embeddings[i].shape[1] < max_length: video_embeddings[i] = torch.cat([video_embeddings[i], torch.zeros(1, max_length-video_embeddings[i].shape[1], d)], 1) video_embeddings = torch.cat(video_embeddings, 0) #print("encode_video", l, video_embeddings.shape, video_lens) return video_embeddings.to(device) @staticmethod def encode_video_frames(video_paths, l): #### skip video frames train_video_encoder = True if not train_video_encoder: midi_gts = [] for video_path in video_paths: if video_path is None: #midi_gts.append(None) continue if isinstance(video_path, tuple): video_path, start_sample, max_sample = video_path else: start_sample = 0 max_sample = None ####if video_path.startswith("/ailab-train2/speech/zhanghaomin/scps/instruments/"): if "/piano_2h_cropped2_cuts/" in video_path: pass else: #midi_gts.append(None) continue ####midi_gt midi_gt = torch.from_numpy(np.load(video_path.replace(".mp4", ".3.npy")).astype(np.float32))[:,NOTTE_MIN:NOTE_MAX+1] #print("midi_gt", midi_gt.shape, midi_gt.max(), midi_gt.min(), torch.abs(midi_gt).mean()) midi_gts.append(midi_gt.unsqueeze(0)) if len(midi_gts) == 0: return None, None max_length = l for i in range(len(midi_gts)): if midi_gts[i] is None: midi_gts[i] = torch.zeros(1, max_length, NOTES) continue if midi_gts[i].shape[1] < max_length: midi_gts[i] = torch.cat([midi_gts[i], torch.zeros(1, max_length-midi_gts[i].shape[1], NOTES)], 1) elif midi_gts[i].shape[1] > max_length: midi_gts[i] = midi_gts[i][:, :max_length, :] midi_gts = torch.cat(midi_gts, 0) video_frames = 1.0 #print("encode_video_frames", l, midi_gts.shape, midi_gts.sum()) return video_frames, midi_gts video_frames = [] video_lens = [] midi_gts = [] for video_path in video_paths: if video_path is None: #video_frames.append(None) video_lens.append(0) #midi_gts.append(None) continue if isinstance(video_path, tuple): video_path, start_sample, max_sample = video_path else: start_sample = 0 max_sample = None ####if video_path.startswith("/ailab-train2/speech/zhanghaomin/scps/instruments/"): if "/piano_2h_cropped2_cuts/" in video_path: frames_raw_path = video_path.replace(".mp4", ".generated_frames_raw.2.npz") if not os.path.exists(frames_raw_path): frames, duration = read_frames_with_moviepy(video_path, max_frame_nums=None) if frames is None: #video_frames.append(None) video_lens.append(0) #midi_gts.append(None) continue #print("raw image size", frames.shape, video_path) frames_resized = [] for i in range(frames.shape[0]): ########frames_resized.append(np.asarray(Image.fromarray(frames[i]).resize((320, 320)))) ####frames_resized.append(np.asarray(Image.fromarray(frames[i]).resize((1024, 128)))) input_img = Image.fromarray(frames[i]).convert('L') binarr = np.array(input_img) input_img = Image.fromarray(binarr.astype(np.uint8)) frames_resized.append(transform(input_img)) ####frames_raw = np.array(frames_resized) frames_raw = np.concatenate(frames_resized).astype(np.float32)[...,np.newaxis] np.savez(frames_raw_path, frames_raw, duration) else: data = np.load(frames_raw_path) frames_raw = data["arr_0"] duration = data["arr_1"].item() ####frames_raw = frames_raw.astype(np.float32) / 255.0 #v_frames_raw = frames_raw[1:,:,:,:] - frames_raw[:-1,:,:,:] #v_frames_raw = np.concatenate((np.zeros((1,v_frames_raw.shape[1],v_frames_raw.shape[2],v_frames_raw.shape[3]), dtype=np.float32), v_frames_raw), axis=0) ##print("v_frames_raw", v_frames_raw.shape, v_frames_raw.max(), v_frames_raw.min(), np.abs(v_frames_raw).mean(), np.abs(v_frames_raw[0,:,:,:]).mean()) #frames_raw = np.concatenate((frames_raw, v_frames_raw), axis=3) frames_raw = torch.from_numpy(frames_raw) #print("frames_raw", frames_raw.shape, frames_raw.max(), frames_raw.min(), torch.abs(frames_raw).mean(), "image_embeddings", image_embeddings.shape, image_embeddings.max(), image_embeddings.min(), torch.abs(image_embeddings).mean()) else: #video_frames.append(None) video_lens.append(0) #midi_gts.append(None) continue #print("frames_raw", frames_raw.shape, l) if max_sample is None: max_sample = int(duration * 24000) ####video_multi = 3.0 video_multi = 2.5 interpolated_frames_raw = [] frame_size_video = int(video_multi*320) for i in range(start_sample, max_sample+frame_size_video, frame_size_video): j = min(round(i / 24000 / (duration / (frames_raw.shape[0] - 0))), frames_raw.shape[0] - 1) #print(j) interpolated_frames_raw.append(frames_raw[j:j+1]) if len(interpolated_frames_raw) >= math.floor(l/video_multi)+1: #print("break", len(interpolated_frames_raw), l, frames_raw.shape, j) break interpolated_frames_raw = torch.cat(interpolated_frames_raw, dim=0) ####v_interpolated_frames_raw = interpolated_frames_raw[1:,:,:,:] - interpolated_frames_raw[:-1,:,:,:] ####v_interpolated_frames_raw = torch.cat((torch.zeros(1,v_interpolated_frames_raw.shape[1],v_interpolated_frames_raw.shape[2],v_interpolated_frames_raw.shape[3]), v_interpolated_frames_raw), 0) #####print("v_interpolated_frames_raw", v_interpolated_frames_raw.shape, v_interpolated_frames_raw.max(), v_interpolated_frames_raw.min(), torch.abs(v_interpolated_frames_raw).mean(), torch.abs(v_interpolated_frames_raw[0,:,:,:]).mean()) ####interpolated_frames_raw = torch.cat((interpolated_frames_raw, v_interpolated_frames_raw), 3) video_frames.append(interpolated_frames_raw.unsqueeze(0)) video_lens.append(interpolated_frames_raw.shape[0]) ####midi_gt ####midi_gt = torch.from_numpy(np.load(video_path.replace(".mp4", ".3.npy")).astype(np.float32))[:,NOTTE_MIN:NOTE_MAX+1] #####print("midi_gt", midi_gt.shape, midi_gt.max(), midi_gt.min(), torch.abs(midi_gt).mean()) ####midi_gts.append(midi_gt.unsqueeze(0)) midi_gts.append(None) if len(video_frames) == 0: return None, None max_length = max(video_lens) if max_length == 0: max_length = l else: max_length = l max_length_video = max(math.floor(l/video_multi)+1, max(video_lens)) for i in range(len(video_frames)): if video_frames[i] is None: ########video_frames[i] = torch.zeros(1, max_length_video, 320, 320, 6) ####video_frames[i] = torch.zeros(1, max_length_video, 128, 1024, 6) video_frames[i] = torch.zeros(1, max_length_video, 100, 900, 1) continue if video_frames[i].shape[1] < max_length_video: ########video_frames[i] = torch.cat([video_frames[i], torch.zeros(1, max_length_video-video_frames[i].shape[1], 320, 320, 6)], 1) ####video_frames[i] = torch.cat([video_frames[i], torch.zeros(1, max_length_video-video_frames[i].shape[1], 128, 1024, 6)], 1) video_frames[i] = torch.cat([video_frames[i], torch.zeros(1, max_length_video-video_frames[i].shape[1], 100, 900, 1)], 1) video_frames = torch.cat(video_frames, 0) video_frames = video_frames.permute(0,4,1,2,3) for i in range(len(midi_gts)): if midi_gts[i] is None: midi_gts[i] = torch.zeros(1, max_length, NOTES) continue if midi_gts[i].shape[1] < max_length: midi_gts[i] = torch.cat([midi_gts[i], torch.zeros(1, max_length-midi_gts[i].shape[1], NOTES)], 1) elif midi_gts[i].shape[1] > max_length: midi_gts[i] = midi_gts[i][:, :max_length, :] midi_gts = torch.cat(midi_gts, 0) #print("encode_video_frames", l, video_frames.shape, video_lens, midi_gts.shape, midi_gts.sum()) return video_frames, midi_gts def transformer_with_pred_head( self, x: Float['b n d'], cond: Float['b n d'] | None = None, times: Float['b'] | None = None, mask: Bool['b n'] | None = None, text: Int['b nt'] | Float['b nt dt'] | None = None, frames_embed: Float['b nf df'] | None = None, prompt = None, video_drop_prompt = None, audio_drop_prompt = None, drop_audio_cond: bool | None = None, drop_text_cond: bool | None = None, drop_text_prompt: bool | None = None, return_drop_conditions = False ): seq_len = x.shape[-2] bs = x.shape[0] drop_audio_cond = [default(drop_audio_cond, self.training and random() < self.audiocond_drop_prob) for _ in range(bs)] drop_text_cond = default(drop_text_cond, self.training and random() < self.cond_drop_prob) drop_text_prompt = [default(drop_text_prompt, self.training and random() < self.prompt_drop_prob) for _ in range(bs)] if cond is not None: for b in range(bs): if drop_audio_cond[b]: cond[b] = 0 if audio_drop_prompt is not None and audio_drop_prompt[b]: cond[b] = 0 if cond is not None: if self.concat_cond: # concat condition, given as using voicebox-like scheme x = torch.cat((cond, x), dim = -1) x = self.proj_in(x) if cond is not None: if not self.concat_cond: # an alternative is to simply sum the condition # seems to work fine cond = self.cond_proj_in(cond) x = x + cond # whether to use a text embedding text_embed = None if exists(text) and len(text.shape) == 3: text_embed = text.clone() if drop_text_cond: for b in range(bs): text_embed[b] = 0 elif exists(text) and not drop_text_cond: text_embed = self.embed_text(text, seq_len, mask = mask) context, context_mask = None, None if prompt is not None: #for b in range(bs): # if drop_text_prompt[b]: # prompt[b] = "" if video_drop_prompt is not None: for b in range(bs): if video_drop_prompt[b]: prompt[b] = "the sound of X X" context, context_mask = self.encode_text(prompt) for b in range(bs): if drop_text_prompt[b]: context[b] = 0 if video_drop_prompt is not None and video_drop_prompt[b]: context[b] = 0 #print("cross attention", context.shape, context_mask.shape, x.shape, mask.shape, text_embed.shape if text_embed is not None else None, torch.mean(torch.abs(text_embed), dim=(1,2))) #print("video_drop_prompt", prompt, video_drop_prompt, context.shape, torch.mean(torch.abs(context), dim=(1,2))) #print("audio_drop_prompt", audio_drop_prompt, cond.shape, torch.mean(torch.abs(cond), dim=(1,2))) if self.proj_text is not None: text_embed = self.proj_text(text_embed) frames_embed = self.proj_frames(frames_embed) # attend attended = self.transformer( x, times = times, mask = mask, text_embed = text_embed, frames_embed = frames_embed, context = context, context_mask = context_mask ) pred = self.to_pred(attended) if not return_drop_conditions: return pred return pred, drop_audio_cond, drop_text_cond, drop_text_prompt def cfg_transformer_with_pred_head( self, *args, cfg_strength: float = 1., remove_parallel_component: bool = True, keep_parallel_frac: float = 0., **kwargs, ): pred = self.transformer_with_pred_head(*args, drop_audio_cond = False, drop_text_cond = False, drop_text_prompt = False, **kwargs) if cfg_strength < 1e-5: return pred null_pred = self.transformer_with_pred_head(*args, drop_audio_cond = True, drop_text_cond = True, drop_text_prompt = True, **kwargs) cfg_update = pred - null_pred if remove_parallel_component: # https://arxiv.org/abs/2410.02416 parallel, orthogonal = project(cfg_update, pred) cfg_update = orthogonal + parallel * keep_parallel_frac return pred + cfg_update * cfg_strength def add_noise(self, signal, mask, val): if self.audiocond_snr is None: return signal if not val: snr = np.random.uniform(self.audiocond_snr[0], self.audiocond_snr[1]) else: snr = (self.audiocond_snr[0] + self.audiocond_snr[1]) / 2.0 #print("add_noise", self.audiocond_snr, snr, signal.shape, mask) # [True, ..., False] noise = torch.randn_like(signal) w = torch.abs(signal[mask]).mean() / (torch.abs(noise[mask]).mean() + 1e-6) / snr return signal + noise * w @torch.no_grad() def sample( self, cond: Float['b n d'] | Float['b nw'] | None = None, *, text: Int['b nt'] | list[str] | None = None, lens: Int['b'] | None = None, duration: int | Int['b'] | None = None, steps = 32, cfg_strength = 1., # they used a classifier free guidance strength of 1. remove_parallel_component = True, sway_sampling = True, max_duration = 4096, # in case the duration predictor goes haywire vocoder: Callable[[Float['b d n']], list[Float['_']]] | None = None, return_raw_output: bool | None = None, save_to_filename: str | None = None, prompt = None, video_drop_prompt = None, audio_drop_prompt = None, video_paths = None, frames = None, midis = None ) -> ( Float['b n d'], list[Float['_']] ): self.eval() # raw wave if cond.ndim == 2: cond = self.mel_spec(cond) cond = rearrange(cond, 'b d n -> b n d') assert cond.shape[-1] == self.num_channels batch, cond_seq_len, device = *cond.shape[:2], cond.device if frames is None: frames_embed = torch.zeros(batch, cond_seq_len, NOTES, device=device) else: #### sampling settings train_video_encoder = True if train_video_encoder: frames_embed = self.encode_frames(frames, cond_seq_len) else: frames_embed = midis if frames_embed.shape[1] < cond_seq_len: frames_embed = torch.cat([frames_embed, torch.zeros(1, cond_seq_len-frames_embed.shape[1], NOTES)], 1) elif frames_embed.shape[1] > cond_seq_len: frames_embed = frames_embed[:, :cond_seq_len, :] #x0 = torch.zeros(batch, cond_seq_len, 128, device=device) print("frames_embed midis cond", frames_embed.shape if frames_embed is not None and not isinstance(frames_embed, float) else frames_embed, frames_embed.sum() if frames_embed is not None and not isinstance(frames_embed, float) else frames_embed, midis.shape if midis is not None else midis, midis.sum() if midis is not None else midis, cond.shape if cond is not None else cond, cond.sum() if cond is not None else cond) if not exists(lens): lens = torch.full((batch,), cond_seq_len, device = device, dtype = torch.long) if video_paths is not None: text = self.encode_video(video_paths, cond_seq_len) # text elif isinstance(text, list): text = self.tokenizer(text).to(device) assert text.shape[0] == batch if exists(text): text_lens = (text != -1).sum(dim = -1) lens = torch.maximum(text_lens, lens) # make sure lengths are at least those of the text characters # duration cond_mask = lens_to_mask(lens) if exists(duration): if isinstance(duration, int): duration = torch.full((batch,), duration, device = device, dtype = torch.long) elif exists(self.duration_predictor): duration = self.duration_predictor(cond, text = text, lens = lens, return_loss = False).long() duration = torch.maximum(lens, duration) # just add one token so something is generated duration = duration.clamp(max = max_duration) assert duration.shape[0] == batch max_duration = duration.amax() cond = F.pad(cond, (0, 0, 0, max_duration - cond_seq_len), value = 0.) cond_mask = F.pad(cond_mask, (0, max_duration - cond_mask.shape[-1]), value = False) cond_mask = rearrange(cond_mask, '... -> ... 1') mask = lens_to_mask(duration) #print("mask", duration, mask, mask.shape, lens, cond_mask, cond_mask.shape, text) # neural ode def fn(t, x): # at each step, conditioning is fixed if lens[0] == duration[0]: print("No cond", lens, duration) step_cond = None else: step_cond = torch.where(cond_mask, self.add_noise(cond, cond_mask, True), torch.zeros_like(cond)) #step_cond = cond # predict flow return self.cfg_transformer_with_pred_head( x, step_cond, times = t, text = text, frames_embed = frames_embed, mask = mask, prompt = prompt, video_drop_prompt = video_drop_prompt, audio_drop_prompt = audio_drop_prompt, cfg_strength = cfg_strength, remove_parallel_component = remove_parallel_component ) ####torch.manual_seed(0) y0 = torch.randn_like(cond) #y0 = torch.randn_like(x0) t = torch.linspace(0, 1, steps, device = self.device) if sway_sampling: t = t + -1.0 * (torch.cos(torch.pi / 2 * t) - 1 + t) #print("@@@@", t) trajectory = odeint(fn, y0, t, **self.odeint_kwargs) sampled = trajectory[-1] out = sampled if lens[0] != duration[0]: out = torch.where(cond_mask, cond, out) # able to return raw untransformed output, if not using mel rep if exists(return_raw_output) and return_raw_output: return out # take care of transforming mel to audio if `vocoder` is passed in, or if `use_vocos` is turned on if exists(vocoder): assert not exists(self.vocos), '`use_vocos` should not be turned on if you are passing in a custom `vocoder` on sampling' out = rearrange(out, 'b n d -> b d n') out = vocoder(out) elif exists(self.vocos): audio = [] for mel, one_mask in zip(out, mask): #one_out = DB_to_amplitude(mel[one_mask], ref = 1., power = 0.5) one_out = mel[one_mask] one_out = rearrange(one_out, 'n d -> 1 d n') one_audio = self.vocos.decode(one_out) one_audio = rearrange(one_audio, '1 nw -> nw') audio.append(one_audio) out = audio if exists(save_to_filename): assert exists(vocoder) or exists(self.vocos) assert exists(self.sampling_rate) path = Path(save_to_filename) parent_path = path.parents[0] parent_path.mkdir(exist_ok = True, parents = True) for ind, one_audio in enumerate(out): one_audio = rearrange(one_audio, 'nw -> 1 nw') if len(out) == 1: save_path = str(parent_path / f'{path.name}') else: save_path = str(parent_path / f'{ind + 1}.{path.name}') torchaudio.save(save_path, one_audio.detach().cpu(), sample_rate = self.sampling_rate) return out def forward( self, inp: Float['b n d'] | Float['b nw'], # mel or raw wave *, text: Int['b nt'] | list[str] | None = None, times: int | Int['b'] | None = None, lens: Int['b'] | None = None, velocity_consistency_model: E2TTS | None = None, velocity_consistency_delta = 1e-3, prompt = None, video_drop_prompt=None, audio_drop_prompt=None, val = False, video_paths=None, frames=None, midis=None ): need_velocity_loss = exists(velocity_consistency_model) and self.velocity_consistency_weight > 0. # handle raw wave if inp.ndim == 2: inp = self.mel_spec(inp) inp = rearrange(inp, 'b d n -> b n d') assert inp.shape[-1] == self.num_channels batch, seq_len, dtype, device = *inp.shape[:2], inp.dtype, self.device if video_paths is not None: text = self.encode_video(video_paths, seq_len) # handle text as string elif isinstance(text, list): text = self.tokenizer(text).to(device) #print("text tokenized", text[0]) assert text.shape[0] == batch # lens and mask if not exists(lens): lens = torch.full((batch,), seq_len, device = device) mask = lens_to_mask(lens, length = seq_len) # get a random span to mask out for training conditionally if not val: if self.audiocond_drop_prob > 1.0: frac_lengths = torch.zeros((batch,), device = self.device).float().uniform_(1.0,1.0) else: frac_lengths = torch.zeros((batch,), device = self.device).float().uniform_(*self.frac_lengths_mask) else: frac_lengths = torch.tensor([(0.7+1.0)/2.0]*batch, device = self.device).float() rand_span_mask = mask_from_frac_lengths(lens, frac_lengths, max_length = seq_len, val = val) if exists(mask): rand_span_mask &= mask # mel is x1 x1 = inp # main conditional flow training logic # just ~5 loc # x0 is gaussian noise if val: torch.manual_seed(0) x0 = torch.randn_like(x1) if val: torch.manual_seed(int(time.time()*1000)) # t is random times from above if times is None: times = torch.rand((batch,), dtype = dtype, device = self.device) else: times = torch.tensor((times,)*batch, dtype = dtype, device = self.device) t = rearrange(times, 'b -> b 1 1') # if need velocity consistency, make sure time does not exceed 1. if need_velocity_loss: t = t * (1. - velocity_consistency_delta) # sample xt (w in the paper) w = (1. - t) * x0 + t * x1 flow = x1 - x0 # only predict what is within the random mask span for infilling if self.audiocond_drop_prob > 1.0: cond = None else: cond = einx.where( 'b n, b n d, b n d -> b n d', rand_span_mask, torch.zeros_like(x1), self.add_noise(x1, ~rand_span_mask, val) ) #### training settings train_video_encoder = True train_v2a = True use_midi_gt = False train_video_encoder = train_video_encoder ####train_v2a = train_v2a or val #print("train_video_encoder", train_video_encoder, use_midi_gt, train_v2a) #### if frames is None: frames_embed = torch.zeros(batch, seq_len, NOTES, device=device) midis = torch.zeros(batch, seq_len, NOTES, device=device) else: if train_video_encoder: frames_embed = self.encode_frames(frames, seq_len) else: frames_embed = midis #print("frames_embed midis cond", frames_embed.shape if frames_embed is not None and not isinstance(frames_embed, float) else frames_embed, frames_embed.sum() if frames_embed is not None and not isinstance(frames_embed, float) else frames_embed, midis.shape, midis.sum(), cond.shape if cond is not None else cond, cond.sum() if cond is not None else cond, x1.shape) if train_video_encoder: #lw = 1.0 lw = torch.abs(midis-0.10) #lw = torch.max(torch.abs(midis-0.20), torch.tensor(0.20)) loss_midi = F.mse_loss(frames_embed, midis, reduction = 'none') * lw #loss_midi = nn.BCELoss(reduction = 'none')(frames_embed, midis) * lw #print("loss_midi", loss_midi.shape, mask.shape, mask, rand_span_mask.shape, rand_span_mask) loss_midi = loss_midi[mask[-frames_embed.shape[0]:,...]].mean() b, t, f = frames_embed.shape frames_embed_t = frames_embed[:,:(t//3)*3,:].reshape(b,t//3,3,f).mean(dim=2) midis_t = midis[:,:(t//3)*3,:].reshape(b,t//3,3,f).mean(dim=2) mask_t = mask[-frames_embed.shape[0]:,:(t//3)*3].reshape(b,t//3,3).to(torch.float32).mean(dim=2) >= 0.99 tp = ((frames_embed_t>=0.4)*(midis_t>=0.5)).to(torch.float)[mask_t[-frames_embed_t.shape[0]:,...]].sum() fp = ((frames_embed_t>=0.4)*(midis_t<0.5)).to(torch.float)[mask_t[-frames_embed_t.shape[0]:,...]].sum() fn = ((frames_embed_t<0.4)*(midis_t>=0.5)).to(torch.float)[mask_t[-frames_embed_t.shape[0]:,...]].sum() tn = ((frames_embed_t<0.4)*(midis_t<0.5)).to(torch.float)[mask_t[-frames_embed_t.shape[0]:,...]].sum() #print("tp fp fn tn", tp, fp, fn, tn) pre = tp / (tp + fp) if (tp + fp) != 0 else torch.tensor(0.0, device=device) rec = tp / (tp + fn) if (tp + fn) != 0 else torch.tensor(0.0, device=device) f1 = 2 * tp / (2 * tp + fp + fn) if (2 * tp + fp + fn) != 0 else torch.tensor(0.0, device=device) acc = tp / (tp + fp + fn) if (tp + fp + fn) != 0 else torch.tensor(0.0, device=device) else: loss_midi = torch.tensor(0.0, device=device) tp = torch.tensor(0.0, device=device) fp = torch.tensor(0.0, device=device) fn = torch.tensor(0.0, device=device) tn = torch.tensor(0.0, device=device) pre = torch.tensor(0.0, device=device) rec = torch.tensor(0.0, device=device) f1 = torch.tensor(0.0, device=device) acc = torch.tensor(0.0, device=device) #if train_video_encoder: # loss_midi_zeros * 100.0 # 0.2131/0.1856 # 0.2819/0.2417 # 2.451/ # loss_midi_zeros = F.mse_loss(torch.zeros_like(midis), midis, reduction = 'none') # loss_midi_zeros = loss_midi_zeros[mask[-frames_embed.shape[0]:,...]].mean() #else: # loss_midi_zeros = torch.tensor(0.0, device=device) if train_v2a: if use_midi_gt: frames_embed = midis if frames_embed.shape[0] < x1.shape[0]: frames_embed = torch.cat((torch.zeros(x1.shape[0]-frames_embed.shape[0],frames_embed.shape[1],frames_embed.shape[2],device=frames_embed.device), frames_embed), 0) # transformer and prediction head if not val: pred, did_drop_audio_cond, did_drop_text_cond, did_drop_text_prompt = self.transformer_with_pred_head( w, cond, times = times, text = text, frames_embed = frames_embed, mask = mask, prompt = prompt, video_drop_prompt = video_drop_prompt, audio_drop_prompt = audio_drop_prompt, return_drop_conditions = True ) else: pred, did_drop_audio_cond, did_drop_text_cond, did_drop_text_prompt = self.transformer_with_pred_head( w, cond, times = times, text = text, frames_embed = frames_embed, mask = mask, prompt = prompt, video_drop_prompt = video_drop_prompt, audio_drop_prompt = audio_drop_prompt, drop_audio_cond = False, drop_text_cond = False, drop_text_prompt = False, return_drop_conditions = True ) # maybe velocity consistency loss velocity_loss = self.zero if need_velocity_loss: #t_with_delta = t + velocity_consistency_delta #w_with_delta = (1. - t_with_delta) * x0 + t_with_delta * x1 with torch.no_grad(): ema_pred = velocity_consistency_model.transformer_with_pred_head( w, #w_with_delta, cond, times = times, #times + velocity_consistency_delta, text = text, frames_embed = frames_embed, mask = mask, prompt = prompt, video_drop_prompt = video_drop_prompt, audio_drop_prompt = audio_drop_prompt, drop_audio_cond = did_drop_audio_cond, drop_text_cond = did_drop_text_cond, drop_text_prompt = did_drop_text_prompt ) #velocity_loss = F.mse_loss(pred, ema_pred, reduction = 'none') velocity_loss = F.mse_loss(ema_pred, flow, reduction = 'none') velocity_loss = (velocity_loss.mean(-1)*rand_span_mask).mean(-1) #.mean() ref_losses = velocity_loss[-2:, ...] ref_losses_w, ref_losses_l = ref_losses.chunk(2) raw_ref_loss = 0.5 * (ref_losses_w.mean() + ref_losses_l.mean()) ref_diff = ref_losses_w - ref_losses_l else: ref_losses_w, ref_losses_l = 0, 0 # flow matching loss loss = F.mse_loss(pred, flow, reduction = 'none') #print("loss", loss.shape, loss, "rand_span_mask", rand_span_mask.shape, rand_span_mask, "loss[rand_span_mask]", loss[rand_span_mask].shape, loss[rand_span_mask]) #### dpo loss = loss[rand_span_mask].mean() loss_dpo = torch.tensor(0.0, device=device) ####if val: #### loss = loss[rand_span_mask].mean() #### loss_dpo = torch.tensor(0.0, device=device) #### model_losses_w, model_losses_l = 0, 0 ####else: #### loss_fm = loss[rand_span_mask].mean() #### loss = (loss.mean(-1)*rand_span_mask).mean(-1) #.mean() #### #### model_losses = loss[-2:, ...] #### model_losses_w, model_losses_l = model_losses.chunk(2) #### raw_model_loss = 0.5 * (model_losses_w.mean() + model_losses_l.mean()) #### model_diff = model_losses_w - model_losses_l #### #### scale_term = -1 #### inside_term = scale_term * (model_diff - ref_diff) #### loss_dpo = -1 * F.logsigmoid(inside_term).mean() #### loss = loss_fm #### dpo else: pred = torch.zeros_like(x0) loss = torch.tensor(0.0, device=device) # total loss and get breakdown #midi_w = 100.0 midi_w = 10.0 #total_loss = loss #total_loss = loss + loss_midi * midi_w total_loss = loss + loss_midi * midi_w + loss_dpo ####breakdown = LossBreakdown(loss, loss_midi * midi_w, pre, rec) breakdown = LossBreakdown(pre, rec, f1, acc) #breakdown = LossBreakdown(tp, fp, fn, tn) #### dpo print("loss", loss, loss_midi * midi_w) #print("loss", loss, loss_midi * midi_w, loss_dpo, model_losses_w, model_losses_l, ref_losses_w, ref_losses_l) #### dpo # return total loss and bunch of intermediates return E2TTSReturn(total_loss, cond if cond is not None else w, pred, x0 + pred, breakdown)