""" | |
ein notation: | |
b - batch | |
n - sequence | |
nt - text sequence | |
nw - raw wave length | |
d - dimension | |
dt - dimension text | |
""" | |
from __future__ import annotations
from pathlib import Path
from random import random
from random import sample as random_sample
from functools import partial
from itertools import zip_longest
from collections import namedtuple
from typing import Literal, Callable

import jaxtyping
from beartype import beartype

import torch
import torch.nn.functional as F
from torch import nn, tensor, Tensor, from_numpy
from torch.nn import Module, ModuleList, Sequential, Linear
from torch.nn.utils.rnn import pad_sequence

import torchaudio
from torchaudio.functional import DB_to_amplitude
from torchdiffeq import odeint

import einx
from einops.layers.torch import Rearrange
from einops import rearrange, repeat, reduce, pack, unpack

from x_transformers import (
    Attention,
    FeedForward,
    RMSNorm,
    AdaptiveRMSNorm,
)

from x_transformers.x_transformers import RotaryEmbedding

import sys
sys.path.insert(0, "/zhanghaomin/codes3/vocos-main/")
from vocos import Vocos

from transformers import AutoTokenizer
from transformers import T5EncoderModel
from transformers import EncodecModel, AutoProcessor

import os
import math
import traceback
import numpy as np
from moviepy.editor import AudioFileClip, VideoFileClip
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

# open_clip is only required for the "clip_convnext" and "mixed" video encoders below;
# import it tolerantly so the other encoders still work without it installed
try:
    import open_clip
except ImportError:
    open_clip = None

from transformers import AutoImageProcessor, AutoModel
from PIL import Image

import time
import warnings
warnings.filterwarnings("ignore")
def normalize_wav(waveform):
    waveform = waveform - torch.mean(waveform)
    waveform = waveform / (torch.max(torch.abs(waveform[0, :])) + 1e-8)
    return waveform * 0.5
def read_frames_with_moviepy(video_path, max_frame_nums=None):
    try:
        clip = VideoFileClip(video_path)
        duration = clip.duration
        frames = []
        for frame in clip.iter_frames():
            frames.append(frame)
    except Exception:
        print("Error read_frames_with_moviepy", video_path)
        traceback.print_exc()
        return None, None
    if max_frame_nums is not None:
        # uniformly subsample to at most `max_frame_nums` frames
        frames_idx = np.linspace(0, len(frames) - 1, max_frame_nums, dtype=int)
        return np.array(frames)[frames_idx, ...], duration
    else:
        return np.array(frames), duration
pad_sequence = partial(pad_sequence, batch_first = True)

# constants

class TorchTyping:
    def __init__(self, abstract_dtype):
        self.abstract_dtype = abstract_dtype

    def __getitem__(self, shapes: str):
        return self.abstract_dtype[Tensor, shapes]

Float = TorchTyping(jaxtyping.Float)
Int = TorchTyping(jaxtyping.Int)
Bool = TorchTyping(jaxtyping.Bool)

# named tuples

LossBreakdown = namedtuple('LossBreakdown', ['flow', 'velocity_consistency'])

E2TTSReturn = namedtuple('E2TTS', ['loss', 'cond', 'pred_flow', 'pred_data', 'loss_breakdown'])

# helpers

def exists(v):
    return v is not None

def default(v, d):
    return v if exists(v) else d

def divisible_by(num, den):
    return (num % den) == 0

def pack_one_with_inverse(x, pattern):
    packed, packed_shape = pack([x], pattern)

    def inverse(x, inverse_pattern = None):
        inverse_pattern = default(inverse_pattern, pattern)
        return unpack(x, packed_shape, inverse_pattern)[0]

    return packed, inverse

class Identity(Module):
    def forward(self, x, **kwargs):
        return x

# tensor helpers

def project(x, y):
    x, inverse = pack_one_with_inverse(x, 'b *')
    y, _ = pack_one_with_inverse(y, 'b *')

    dtype = x.dtype
    x, y = x.double(), y.double()
    unit = F.normalize(y, dim = -1)

    parallel = (x * unit).sum(dim = -1, keepdim = True) * unit
    orthogonal = x - parallel

    return inverse(parallel).to(dtype), inverse(orthogonal).to(dtype)
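# usage sketch for `project` (illustrative tensors; kept as a comment so importing
# this module stays side-effect free):
#
#   x = torch.randn(2, 8)
#   y = torch.randn(2, 8)
#   parallel, orthogonal = project(x, y)
#   # `parallel` is the component of x along y, `orthogonal` the remainder,
#   # so x == parallel + orthogonal up to float error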
# simple utf-8 tokenizer, since paper went character based

def list_str_to_tensor(
    text: list[str],
    padding_value = -1
) -> Int['b nt']:
    list_tensors = [tensor([*bytes(t, 'UTF-8')]) for t in text]
    padded_tensor = pad_sequence(list_tensors, padding_value = padding_value)
    return padded_tensor
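# usage sketch (comment-only; utf-8 byte values become token ids, -1 pads the batch):
#
#   tokens = list_str_to_tensor(["hi", "hello"])
#   # tokens.shape == (2, 5); tokens[0] == tensor([104, 105, -1, -1, -1])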
# simple english phoneme-based tokenizer

from g2p_en import G2p
import jieba
from pypinyin import lazy_pinyin, Style

def get_g2p_en_encode():
    g2p = G2p()

    # used by @lucasnewman successfully here
    # https://github.com/lucasnewman/e2-tts-pytorch/blob/ljspeech-test/e2_tts_pytorch/e2_tts.py

    phoneme_to_index = g2p.p2idx
    num_phonemes = len(phoneme_to_index)

    extended_chars = [' ', ',', '.', '-', '!', '?', '\'', '"', '...', '..', '. .', '. . .', '. . . .', '. . . . .', '. ...', '... .', '.. ..']
    num_extended_chars = len(extended_chars)

    extended_chars_dict = {p: (num_phonemes + i) for i, p in enumerate(extended_chars)}
    phoneme_to_index = {**phoneme_to_index, **extended_chars_dict}

    def encode(
        text: list[str],
        padding_value = -1
    ) -> Int['b nt']:
        phonemes = [g2p(t) for t in text]
        list_tensors = [tensor([phoneme_to_index[p] for p in one_phoneme]) for one_phoneme in phonemes]
        padded_tensor = pad_sequence(list_tensors, padding_value = padding_value)
        return padded_tensor

    return encode, (num_phonemes + num_extended_chars)
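# usage sketch (comment-only; the encoder closure comes back together with its vocab size):
#
#   encode, num_embeds = get_g2p_en_encode()
#   ids = encode(["hello world"])   # Int['b nt'], padded with -1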
def all_en(word):
    res = word.replace("'", "").encode('utf-8').isalpha()
    return res

def all_ch(word):
    # true only if every character is a CJK unified ideograph
    for w in word:
        if not '\u4e00' <= w <= '\u9fff':
            return False
    return True
def get_g2p_zh_encode():
    puncs = [',', '。', '?', '、']
    pinyins = ['a', 'a1', 'ai1', 'ai2', 'ai3', 'ai4', 'an1', 'an3', 'an4', 'ang1', 'ang2', 'ang4', 'ao1', 'ao2', 'ao3', 'ao4', 'ba', 'ba1', 'ba2', 'ba3', 'ba4', 'bai1', 'bai2', 'bai3', 'bai4', 'ban1', 'ban2', 'ban3', 'ban4', 'bang1', 'bang2', 'bang3', 'bang4', 'bao1', 'bao2', 'bao3', 'bao4', 'bei', 'bei1', 'bei2', 'bei3', 'bei4', 'ben1', 'ben2', 'ben3', 'ben4', 'beng1', 'beng2', 'beng4', 'bi1', 'bi2', 'bi3', 'bi4', 'bian1', 'bian2', 'bian3', 'bian4', 'biao1', 'biao2', 'biao3', 'bie1', 'bie2', 'bie3', 'bie4', 'bin1', 'bin4', 'bing1', 'bing2', 'bing3', 'bing4', 'bo', 'bo1', 'bo2', 'bo3', 'bo4', 'bu2', 'bu3', 'bu4', 'ca1', 'cai1', 'cai2', 'cai3', 'cai4', 'can1', 'can2', 'can3', 'can4', 'cang1', 'cang2', 'cao1', 'cao2', 'cao3', 'ce4', 'cen1', 'cen2', 'ceng1', 'ceng2', 'ceng4', 'cha1', 'cha2', 'cha3', 'cha4', 'chai1', 'chai2', 'chan1', 'chan2', 'chan3', 'chan4', 'chang1', 'chang2', 'chang3', 'chang4', 'chao1', 'chao2', 'chao3', 'che1', 'che2', 'che3', 'che4', 'chen1', 'chen2', 'chen3', 'chen4', 'cheng1', 'cheng2', 'cheng3', 'cheng4', 'chi1', 'chi2', 'chi3', 'chi4', 'chong1', 'chong2', 'chong3', 'chong4', 'chou1', 'chou2', 'chou3', 'chou4', 'chu1', 'chu2', 'chu3', 'chu4', 'chua1', 'chuai1', 'chuai2', 'chuai3', 'chuai4', 'chuan1', 'chuan2', 'chuan3', 'chuan4', 'chuang1', 'chuang2', 'chuang3', 'chuang4', 'chui1', 'chui2', 'chun1', 'chun2', 'chun3', 'chuo1', 'chuo4', 'ci1', 'ci2', 'ci3', 'ci4', 'cong1', 'cong2', 'cou4', 'cu1', 'cu4', 'cuan1', 'cuan2', 'cuan4', 'cui1', 'cui3', 'cui4', 'cun1', 'cun2', 'cun4', 'cuo1', 'cuo2', 'cuo4', 'da', 'da1', 'da2', 'da3', 'da4', 'dai1', 'dai3', 'dai4', 'dan1', 'dan2', 'dan3', 'dan4', 'dang1', 'dang2', 'dang3', 'dang4', 'dao1', 'dao2', 'dao3', 'dao4', 'de', 'de1', 'de2', 'dei3', 'den4', 'deng1', 'deng2', 'deng3', 'deng4', 'di1', 'di2', 'di3', 'di4', 'dia3', 'dian1', 'dian2', 'dian3', 'dian4', 'diao1', 'diao3', 'diao4', 'die1', 'die2', 'ding1', 'ding2', 'ding3', 'ding4', 'diu1', 'dong1', 'dong3', 'dong4', 'dou1', 'dou2', 'dou3', 'dou4', 'du1', 'du2', 'du3', 'du4', 'duan1', 'duan2', 'duan3', 'duan4', 'dui1', 'dui4', 'dun1', 'dun3', 'dun4', 'duo1', 'duo2', 'duo3', 'duo4', 'e1', 'e2', 'e3', 'e4', 'ei2', 'en1', 'en4', 'er', 'er2', 'er3', 'er4', 'fa1', 'fa2', 'fa3', 'fa4', 'fan1', 'fan2', 'fan3', 'fan4', 'fang1', 'fang2', 'fang3', 'fang4', 'fei1', 'fei2', 'fei3', 'fei4', 'fen1', 'fen2', 'fen3', 'fen4', 'feng1', 'feng2', 'feng3', 'feng4', 'fo2', 'fou2', 'fou3', 'fu1', 'fu2', 'fu3', 'fu4', 'ga1', 'ga2', 'ga4', 'gai1', 'gai3', 'gai4', 'gan1', 'gan2', 'gan3', 'gan4', 'gang1', 'gang2', 'gang3', 'gang4', 'gao1', 'gao2', 'gao3', 'gao4', 'ge1', 'ge2', 'ge3', 'ge4', 'gei2', 'gei3', 'gen1', 'gen2', 'gen3', 'gen4', 'geng1', 'geng3', 'geng4', 'gong1', 'gong3', 'gong4', 'gou1', 'gou2', 'gou3', 'gou4', 'gu', 'gu1', 'gu2', 'gu3', 'gu4', 'gua1', 'gua2', 'gua3', 'gua4', 'guai1', 'guai2', 'guai3', 'guai4', 'guan1', 'guan2', 'guan3', 'guan4', 'guang1', 'guang2', 'guang3', 'guang4', 'gui1', 'gui2', 'gui3', 'gui4', 'gun3', 'gun4', 'guo1', 'guo2', 'guo3', 'guo4', 'ha1', 'ha2', 'ha3', 'hai1', 'hai2', 'hai3', 'hai4', 'han1', 'han2', 'han3', 'han4', 'hang1', 'hang2', 'hang4', 'hao1', 'hao2', 'hao3', 'hao4', 'he1', 'he2', 'he4', 'hei1', 'hen2', 'hen3', 'hen4', 'heng1', 'heng2', 'heng4', 'hong1', 'hong2', 'hong3', 'hong4', 'hou1', 'hou2', 'hou3', 'hou4', 'hu1', 'hu2', 'hu3', 'hu4', 'hua1', 'hua2', 'hua4', 'huai2', 'huai4', 'huan1', 'huan2', 'huan3', 'huan4', 'huang1', 'huang2', 'huang3', 'huang4', 'hui1', 'hui2', 'hui3', 'hui4', 'hun1', 'hun2', 'hun4', 'huo', 'huo1', 'huo2', 'huo3', 'huo4', 'ji1', 'ji2', 'ji3', 'ji4', 'jia', 'jia1', 'jia2', 'jia3', 'jia4', 'jian1', 'jian2', 'jian3', 'jian4', 'jiang1', 'jiang2', 'jiang3', 'jiang4', 'jiao1', 'jiao2', 'jiao3', 'jiao4', 'jie1', 'jie2', 'jie3', 'jie4', 'jin1', 'jin2', 'jin3', 'jin4', 'jing1', 'jing2', 'jing3', 'jing4', 'jiong3', 'jiu1', 'jiu2', 'jiu3', 'jiu4', 'ju1', 'ju2', 'ju3', 'ju4', 'juan1', 'juan2', 'juan3', 'juan4', 'jue1', 'jue2', 'jue4', 'jun1', 'jun4', 'ka1', 'ka2', 'ka3', 'kai1', 'kai2', 'kai3', 'kai4', 'kan1', 'kan2', 'kan3', 'kan4', 'kang1', 'kang2', 'kang4', 'kao2', 'kao3', 'kao4', 'ke1', 'ke2', 'ke3', 'ke4', 'ken3', 'keng1', 'kong1', 'kong3', 'kong4', 'kou1', 'kou2', 'kou3', 'kou4', 'ku1', 'ku2', 'ku3', 'ku4', 'kua1', 'kua3', 'kua4', 'kuai3', 'kuai4', 'kuan1', 'kuan2', 'kuan3', 'kuang1', 'kuang2', 'kuang4', 'kui1', 'kui2', 'kui3', 'kui4', 'kun1', 'kun3', 'kun4', 'kuo4', 'la', 'la1', 'la2', 'la3', 'la4', 'lai2', 'lai4', 'lan2', 'lan3', 'lan4', 'lang1', 'lang2', 'lang3', 'lang4', 'lao1', 'lao2', 'lao3', 'lao4', 'le', 'le1', 'le4', 'lei', 'lei1', 'lei2', 'lei3', 'lei4', 'leng1', 'leng2', 'leng3', 'leng4', 'li', 'li1', 'li2', 'li3', 'li4', 'lia3', 'lian2', 'lian3', 'lian4', 'liang2', 'liang3', 'liang4', 'liao1', 'liao2', 'liao3', 'liao4', 'lie1', 'lie2', 'lie3', 'lie4', 'lin1', 'lin2', 'lin3', 'lin4', 'ling2', 'ling3', 'ling4', 'liu1', 'liu2', 'liu3', 'liu4', 'long1', 'long2', 'long3', 'long4', 'lou1', 'lou2', 'lou3', 'lou4', 'lu1', 'lu2', 'lu3', 'lu4', 'luan2', 'luan3', 'luan4', 'lun1', 'lun2', 'lun4', 'luo1', 'luo2', 'luo3', 'luo4', 'lv2', 'lv3', 'lv4', 'lve3', 'lve4', 'ma', 'ma1', 'ma2', 'ma3', 'ma4', 'mai2', 'mai3', 'mai4', 'man2', 'man3', 'man4', 'mang2', 'mang3', 'mao1', 'mao2', 'mao3', 'mao4', 'me', 'mei2', 'mei3', 'mei4', 'men', 'men1', 'men2', 'men4', 'meng1', 'meng2', 'meng3', 'meng4', 'mi1', 'mi2', 'mi3', 'mi4', 'mian2', 'mian3', 'mian4', 'miao1', 'miao2', 'miao3', 'miao4', 'mie1', 'mie4', 'min2', 'min3', 'ming2', 'ming3', 'ming4', 'miu4', 'mo1', 'mo2', 'mo3', 'mo4', 'mou1', 'mou2', 'mou3', 'mu2', 'mu3', 'mu4', 'n2', 'na1', 'na2', 'na3', 'na4', 'nai2', 'nai3', 'nai4', 'nan1', 'nan2', 'nan3', 'nan4', 'nang1', 'nang2', 'nao1', 'nao2', 'nao3', 'nao4', 'ne', 'ne2', 'ne4', 'nei3', 'nei4', 'nen4', 'neng2', 'ni1', 'ni2', 'ni3', 'ni4', 'nian1', 'nian2', 'nian3', 'nian4', 'niang2', 'niang4', 'niao2', 'niao3', 'niao4', 'nie1', 'nie4', 'nin2', 'ning2', 'ning3', 'ning4', 'niu1', 'niu2', 'niu3', 'niu4', 'nong2', 'nong4', 'nou4', 'nu2', 'nu3', 'nu4', 'nuan3', 'nuo2', 'nuo4', 'nv2', 'nv3', 'nve4', 'o1', 'o2', 'ou1', 'ou3', 'ou4', 'pa1', 'pa2', 'pa4', 'pai1', 'pai2', 'pai3', 'pai4', 'pan1', 'pan2', 'pan4', 'pang1', 'pang2', 'pang4', 'pao1', 'pao2', 'pao3', 'pao4', 'pei1', 'pei2', 'pei4', 'pen1', 'pen2', 'pen4', 'peng1', 'peng2', 'peng3', 'peng4', 'pi1', 'pi2', 'pi3', 'pi4', 'pian1', 'pian2', 'pian4', 'piao1', 'piao2', 'piao3', 'piao4', 'pie1', 'pie2', 'pie3', 'pin1', 'pin2', 'pin3', 'pin4', 'ping1', 'ping2', 'po1', 'po2', 'po3', 'po4', 'pou1', 'pu1', 'pu2', 'pu3', 'pu4', 'qi1', 'qi2', 'qi3', 'qi4', 'qia1', 'qia3', 'qia4', 'qian1', 'qian2', 'qian3', 'qian4', 'qiang1', 'qiang2', 'qiang3', 'qiang4', 'qiao1', 'qiao2', 'qiao3', 'qiao4', 'qie1', 'qie2', 'qie3', 'qie4', 'qin1', 'qin2', 'qin3', 'qin4', 'qing1', 'qing2', 'qing3', 'qing4', 'qiong1', 'qiong2', 'qiu1', 'qiu2', 'qiu3', 'qu1', 'qu2', 'qu3', 'qu4', 'quan1', 'quan2', 'quan3', 'quan4', 'que1', 'que2', 'que4', 'qun2', 'ran2', 'ran3', 'rang1', 'rang2', 'rang3', 'rang4', 'rao2', 'rao3', 'rao4', 're2', 're3', 're4', 'ren2', 'ren3', 'ren4', 'reng1', 'reng2', 'ri4', 'rong1', 'rong2', 'rong3', 'rou2', 'rou4', 'ru2', 'ru3', 'ru4', 'ruan2', 'ruan3', 'rui3', 'rui4', 'run4', 'ruo4', 'sa1', 'sa2', 'sa3', 'sa4', 'sai1', 'sai4', 'san1', 'san2', 'san3', 'san4', 'sang1', 'sang3', 'sang4', 'sao1', 'sao2', 'sao3', 'sao4', 'se4', 'sen1', 'seng1', 'sha1', 'sha2', 'sha3', 'sha4', 'shai1', 'shai2', 'shai3', 'shai4', 'shan1', 'shan3', 'shan4', 'shang', 'shang1', 'shang3', 'shang4', 'shao1', 'shao2', 'shao3', 'shao4', 'she1', 'she2', 'she3', 'she4', 'shei2', 'shen1', 'shen2', 'shen3', 'shen4', 'sheng1', 'sheng2', 'sheng3', 'sheng4', 'shi', 'shi1', 'shi2', 'shi3', 'shi4', 'shou1', 'shou2', 'shou3', 'shou4', 'shu1', 'shu2', 'shu3', 'shu4', 'shua1', 'shua2', 'shua3', 'shua4', 'shuai1', 'shuai3', 'shuai4', 'shuan1', 'shuan4', 'shuang1', 'shuang3', 'shui2', 'shui3', 'shui4', 'shun3', 'shun4', 'shuo1', 'shuo4', 'si1', 'si2', 'si3', 'si4', 'song1', 'song3', 'song4', 'sou1', 'sou3', 'sou4', 'su1', 'su2', 'su4', 'suan1', 'suan4', 'sui1', 'sui2', 'sui3', 'sui4', 'sun1', 'sun3', 'suo', 'suo1', 'suo2', 'suo3', 'ta1', 'ta3', 'ta4', 'tai1', 'tai2', 'tai4', 'tan1', 'tan2', 'tan3', 'tan4', 'tang1', 'tang2', 'tang3', 'tang4', 'tao1', 'tao2', 'tao3', 'tao4', 'te4', 'teng2', 'ti1', 'ti2', 'ti3', 'ti4', 'tian1', 'tian2', 'tian3', 'tiao1', 'tiao2', 'tiao3', 'tiao4', 'tie1', 'tie2', 'tie3', 'tie4', 'ting1', 'ting2', 'ting3', 'tong1', 'tong2', 'tong3', 'tong4', 'tou', 'tou1', 'tou2', 'tou4', 'tu1', 'tu2', 'tu3', 'tu4', 'tuan1', 'tuan2', 'tui1', 'tui2', 'tui3', 'tui4', 'tun1', 'tun2', 'tun4', 'tuo1', 'tuo2', 'tuo3', 'tuo4', 'wa', 'wa1', 'wa2', 'wa3', 'wa4', 'wai1', 'wai3', 'wai4', 'wan1', 'wan2', 'wan3', 'wan4', 'wang1', 'wang2', 'wang3', 'wang4', 'wei1', 'wei2', 'wei3', 'wei4', 'wen1', 'wen2', 'wen3', 'wen4', 'weng1', 'weng4', 'wo1', 'wo3', 'wo4', 'wu1', 'wu2', 'wu3', 'wu4', 'xi1', 'xi2', 'xi3', 'xi4', 'xia1', 'xia2', 'xia4', 'xian1', 'xian2', 'xian3', 'xian4', 'xiang1', 'xiang2', 'xiang3', 'xiang4', 'xiao1', 'xiao2', 'xiao3', 'xiao4', 'xie1', 'xie2', 'xie3', 'xie4', 'xin1', 'xin2', 'xin4', 'xing1', 'xing2', 'xing3', 'xing4', 'xiong1', 'xiong2', 'xiu1', 'xiu3', 'xiu4', 'xu', 'xu1', 'xu2', 'xu3', 'xu4', 'xuan1', 'xuan2', 'xuan3', 'xuan4', 'xue1', 'xue2', 'xue3', 'xue4', 'xun1', 'xun2', 'xun4', 'ya', 'ya1', 'ya2', 'ya3', 'ya4', 'yan1', 'yan2', 'yan3', 'yan4', 'yang1', 'yang2', 'yang3', 'yang4', 'yao1', 'yao2', 'yao3', 'yao4', 'ye1', 'ye2', 'ye3', 'ye4', 'yi1', 'yi2', 'yi3', 'yi4', 'yin1', 'yin2', 'yin3', 'yin4', 'ying1', 'ying2', 'ying3', 'ying4', 'yo1', 'yong1', 'yong3', 'yong4', 'you1', 'you2', 'you3', 'you4', 'yu1', 'yu2', 'yu3', 'yu4', 'yuan1', 'yuan2', 'yuan3', 'yuan4', 'yue1', 'yue4', 'yun1', 'yun2', 'yun3', 'yun4', 'za1', 'za2', 'za3', 'zai1', 'zai3', 'zai4', 'zan1', 'zan2', 'zan3', 'zan4', 'zang1', 'zang4', 'zao1', 'zao2', 'zao3', 'zao4', 'ze2', 'ze4', 'zei2', 'zen3', 'zeng1', 'zeng4', 'zha1', 'zha2', 'zha3', 'zha4', 'zhai1', 'zhai2', 'zhai3', 'zhai4', 'zhan1', 'zhan2', 'zhan3', 'zhan4', 'zhang1', 'zhang2', 'zhang3', 'zhang4', 'zhao1', 'zhao2', 'zhao3', 'zhao4', 'zhe', 'zhe1', 'zhe2', 'zhe3', 'zhe4', 'zhen1', 'zhen2', 'zhen3', 'zhen4', 'zheng1', 'zheng2', 'zheng3', 'zheng4', 'zhi1', 'zhi2', 'zhi3', 'zhi4', 'zhong1', 'zhong2', 'zhong3', 'zhong4', 'zhou1', 'zhou2', 'zhou3', 'zhou4', 'zhu1', 'zhu2', 'zhu3', 'zhu4', 'zhua1', 'zhua2', 'zhua3', 'zhuai1', 'zhuai3', 'zhuai4', 'zhuan1', 'zhuan2', 'zhuan3', 'zhuan4', 'zhuang1', 'zhuang4', 'zhui1', 'zhui4', 'zhun1', 'zhun2', 'zhun3', 'zhuo1', 'zhuo2', 'zi', 'zi1', 'zi2', 'zi3', 'zi4', 'zong1', 'zong2', 'zong3', 'zong4', 'zou1', 'zou2', 'zou3', 'zou4', 'zu1', 'zu2', 'zu3', 'zuan1', 'zuan3', 'zuan4', 'zui2', 'zui3', 'zui4', 'zun1', 'zuo1', 'zuo2', 'zuo3', 'zuo4']
    ens = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', "'", ' ']
    ens_U = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z']

    phoneme_to_index = {}
    num_phonemes = 0
    for index, punc in enumerate(puncs):
        phoneme_to_index[punc] = index + num_phonemes
    num_phonemes += len(puncs)
    for index, pinyin in enumerate(pinyins):
        phoneme_to_index[pinyin] = index + num_phonemes
    num_phonemes += len(pinyins)
    for index, en in enumerate(ens):
        phoneme_to_index[en] = index + num_phonemes
    # uppercase english letters deliberately share indices with their lowercase counterparts
    for index, en in enumerate(ens_U):
        phoneme_to_index[en] = index + num_phonemes
    num_phonemes += len(ens)
    #print(num_phonemes, phoneme_to_index) # 1342

    def encode(
        text: list[str],
        padding_value = -1
    ) -> Int['b nt']:
        phonemes = []
        for t in text:
            one_phoneme = []
            brk = False
            for word in jieba.cut(t):
                if all_ch(word):
                    seg = lazy_pinyin(word, style=Style.TONE3, tone_sandhi=True)
                    one_phoneme.extend(seg)
                elif all_en(word):
                    for seg in word:
                        one_phoneme.append(seg)
                elif word in [",", "。", "?", "、", "'", " "]:
                    one_phoneme.append(word)
                else:
                    # mixed or unrecognized segment: fall back to per-character handling
                    for ch in word:
                        if all_ch(ch):
                            seg = lazy_pinyin(ch, style=Style.TONE3, tone_sandhi=True)
                            one_phoneme.extend(seg)
                        elif all_en(ch):
                            for seg in ch:
                                one_phoneme.append(seg)
                        else:
                            brk = True
                            break
                if brk:
                    break
            if not brk:
                phonemes.append(one_phoneme)
            else:
                print("Error Tokenized", t, list(jieba.cut(t)))
        list_tensors = [tensor([phoneme_to_index[p] for p in one_phoneme]) for one_phoneme in phonemes]
        padded_tensor = pad_sequence(list_tensors, padding_value = padding_value)
        return padded_tensor

    return encode, num_phonemes
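# usage sketch (comment-only; jieba segments the text, pypinyin emits TONE3
# pinyin, english words fall back to per-letter tokens):
#
#   encode, num_embeds = get_g2p_zh_encode()   # num_embeds == 1342
#   ids = encode(["你好"])                      # Int['b nt'], padded with -1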
# tensor helpers

def log(t, eps = 1e-5):
    return t.clamp(min = eps).log()

def lens_to_mask(
    t: Int['b'],
    length: int | None = None
) -> Bool['b n']:
    if not exists(length):
        length = t.amax()

    seq = torch.arange(length, device = t.device)
    return einx.less('n, b -> b n', seq, t)

def mask_from_start_end_indices(
    seq_len: Int['b'],
    start: Int['b'],
    end: Int['b']
):
    max_seq_len = seq_len.max().item()
    seq = torch.arange(max_seq_len, device = start.device).long()
    return einx.greater_equal('n, b -> b n', seq, start) & einx.less('n, b -> b n', seq, end)

def mask_from_frac_lengths(
    seq_len: Int['b'],
    frac_lengths: Float['b'],
    max_length: int | None = None,
    val = False
):
    lengths = (frac_lengths * seq_len).long()
    max_start = seq_len - lengths

    if not val:
        rand = torch.rand_like(frac_lengths)
    else:
        # deterministic midpoint placement for validation
        rand = torch.tensor([0.5] * frac_lengths.shape[0], device=frac_lengths.device).float()
    start = (max_start * rand).long().clamp(min = 0)
    end = start + lengths

    out = mask_from_start_end_indices(seq_len, start, end)

    if exists(max_length):
        out = pad_to_length(out, max_length)

    return out
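# usage sketch (comment-only): with seq_len = tensor([10]) and
# frac_lengths = tensor([0.5]), a contiguous span covering 5 of the 10
# positions comes back True - the span that later gets infilled
#
#   mask = mask_from_frac_lengths(tensor([10]), tensor([0.5]))   # Bool[1, 10]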
def maybe_masked_mean(
    t: Float['b n d'],
    mask: Bool['b n'] | None = None
) -> Float['b d']:
    if not exists(mask):
        return t.mean(dim = 1)

    t = einx.where('b n, b n d, -> b n d', mask, t, 0.)
    num = reduce(t, 'b n d -> b d', 'sum')
    den = reduce(mask.float(), 'b n -> b', 'sum')

    return einx.divide('b d, b -> b d', num, den.clamp(min = 1.))

def pad_to_length(
    t: Tensor,
    length: int,
    value = None
):
    seq_len = t.shape[-1]
    if length > seq_len:
        t = F.pad(t, (0, length - seq_len), value = value)

    return t[..., :length]

def interpolate_1d(
    x: Tensor,
    length: int,
    mode = 'bilinear'
):
    x = rearrange(x, 'n d -> 1 d n 1')
    x = F.interpolate(x, (length, 1), mode = mode)
    return rearrange(x, '1 d n 1 -> n d')
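# usage sketch (comment-only): resamples a sequence of embeddings along its
# length, e.g. stretching 4 text embeddings over 10 audio frames
#
#   emb = torch.randn(4, 512)
#   stretched = interpolate_1d(emb, 10)   # (10, 512)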
# to mel spec

class MelSpec(Module):
    def __init__(
        self,
        filter_length = 1024,
        hop_length = 256,
        win_length = 1024,
        n_mel_channels = 100,
        sampling_rate = 24_000,
        normalize = False,
        power = 1,
        norm = None,
        center = True,
    ):
        super().__init__()
        self.n_mel_channels = n_mel_channels
        self.sampling_rate = sampling_rate

        self.mel_stft = torchaudio.transforms.MelSpectrogram(
            sample_rate = sampling_rate,
            n_fft = filter_length,
            win_length = win_length,
            hop_length = hop_length,
            n_mels = n_mel_channels,
            power = power,
            center = center,
            normalized = normalize,
            norm = norm,
        )

        self.register_buffer('dummy', tensor(0), persistent = False)

    def forward(self, inp):
        if len(inp.shape) == 3:
            inp = rearrange(inp, 'b 1 nw -> b nw')

        assert len(inp.shape) == 2

        if self.dummy.device != inp.device:
            self.to(inp.device)

        mel = self.mel_stft(inp)
        mel = log(mel)
        return mel
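# usage sketch (comment-only; one second of 24kHz audio -> log-mel frames,
# where frame count = floor(24000 / 256) + 1 with center padding):
#
#   mel_spec = MelSpec()
#   wave = torch.randn(2, 24_000)
#   mel = mel_spec(wave)   # (2, 100, 94)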
class EncodecWrapper(Module):
    def __init__(self, path):
        super().__init__()
        self.model = EncodecModel.from_pretrained(path)
        self.processor = AutoProcessor.from_pretrained(path)
        for param in self.model.parameters():
            param.requires_grad = False
        self.model.eval()

    def forward(self, waveform):
        with torch.no_grad():
            inputs = self.processor(raw_audio=waveform[0], sampling_rate=self.processor.sampling_rate, return_tensors="pt")
            emb = self.model.encoder(inputs.input_values)
        return emb

    def decode(self, emb):
        with torch.no_grad():
            output = self.model.decoder(emb)
        return output[0]
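# usage sketch (comment-only; continuous Encodec encoder latents stand in for
# mel features when pretrained_vocos_path == 'facebook/encodec_24khz'):
#
#   codec = EncodecWrapper("facebook/encodec_24khz")
#   emb = codec(torch.randn(1, 24_000))   # (1, latent_dim, frames)
#   wave = codec.decode(emb)              # reconstructed waveform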
from audioldm.audio.stft import TacotronSTFT
from audioldm.variational_autoencoder import AutoencoderKL
from audioldm.utils import default_audioldm_config, get_metadata

def build_pretrained_models(name):
    checkpoint = torch.load(get_metadata()[name]["path"], map_location="cpu")
    scale_factor = checkpoint["state_dict"]["scale_factor"].item()

    # keep only the vae weights, stripping the "first_stage_model." prefix (18 chars)
    vae_state_dict = {k[18:]: v for k, v in checkpoint["state_dict"].items() if "first_stage_model." in k}

    config = default_audioldm_config(name)
    vae_config = config["model"]["params"]["first_stage_config"]["params"]
    vae_config["scale_factor"] = scale_factor

    vae = AutoencoderKL(**vae_config)
    vae.load_state_dict(vae_state_dict)

    fn_STFT = TacotronSTFT(
        config["preprocessing"]["stft"]["filter_length"],
        config["preprocessing"]["stft"]["hop_length"],
        config["preprocessing"]["stft"]["win_length"],
        config["preprocessing"]["mel"]["n_mel_channels"],
        config["preprocessing"]["audio"]["sampling_rate"],
        config["preprocessing"]["mel"]["mel_fmin"],
        config["preprocessing"]["mel"]["mel_fmax"],
    )

    vae.eval()
    fn_STFT.eval()
    return vae, fn_STFT

class VaeWrapper(Module):
    def __init__(self):
        super().__init__()
        vae, stft = build_pretrained_models("audioldm-s-full")
        vae.eval()
        stft.eval()
        stft = stft.cpu()
        self.vae = vae
        for param in self.vae.parameters():
            param.requires_grad = False

    def forward(self, waveform):
        # encoding is handled offline; this wrapper only decodes
        return None

    def decode(self, emb):
        with torch.no_grad():
            b, d, l = emb.shape
            # (b, d, l) latents -> (b, 8, l, 16) spatial latents expected by the vae
            latents = emb.transpose(1, 2).reshape(b, l, 8, 16).transpose(1, 2)
            mel = self.vae.decode_first_stage(latents)
            wave = self.vae.decode_to_waveform(mel)
        return wave
# convolutional positional generating module
# taken from https://github.com/lucidrains/voicebox-pytorch/blob/main/voicebox_pytorch/voicebox_pytorch.py#L203

class DepthwiseConv(Module):
    def __init__(
        self,
        dim,
        *,
        kernel_size,
        groups = None
    ):
        super().__init__()
        assert not divisible_by(kernel_size, 2)
        groups = default(groups, dim) # full depthwise conv by default

        self.dw_conv1d = nn.Sequential(
            nn.Conv1d(dim, dim, kernel_size, groups = groups, padding = kernel_size // 2),
            nn.SiLU()
        )

    def forward(
        self,
        x,
        mask = None
    ):
        if exists(mask):
            x = einx.where('b n, b n d, -> b n d', mask, x, 0.)

        x = rearrange(x, 'b n c -> b c n')
        x = self.dw_conv1d(x)
        out = rearrange(x, 'b c n -> b n c')

        if exists(mask):
            out = einx.where('b n, b n d, -> b n d', mask, out, 0.)

        return out
# adaln zero from DiT paper

class AdaLNZero(Module):
    def __init__(
        self,
        dim,
        dim_condition = None,
        init_bias_value = -2.
    ):
        super().__init__()
        dim_condition = default(dim_condition, dim)
        self.to_gamma = nn.Linear(dim_condition, dim)

        nn.init.zeros_(self.to_gamma.weight)
        nn.init.constant_(self.to_gamma.bias, init_bias_value)

    def forward(self, x, *, condition):
        if condition.ndim == 2:
            condition = rearrange(condition, 'b d -> b 1 d')

        gamma = self.to_gamma(condition).sigmoid()
        return x * gamma
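# usage sketch (comment-only): with zero-init weights and bias -2, the gate
# starts near sigmoid(-2) ~= 0.12, so each gated branch begins mostly closed
#
#   adaln = AdaLNZero(dim = 512)
#   out = adaln(torch.randn(2, 100, 512), condition = torch.randn(2, 512))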
# random projection fourier embedding

class RandomFourierEmbed(Module):
    def __init__(self, dim):
        super().__init__()
        assert divisible_by(dim, 2)
        self.register_buffer('weights', torch.randn(dim // 2))

    def forward(self, x):
        freqs = einx.multiply('i, j -> i j', x, self.weights) * 2 * torch.pi
        fourier_embed, _ = pack((x, freqs.sin(), freqs.cos()), 'b *')
        return fourier_embed
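# usage sketch (comment-only): a batch of scalar times becomes dim + 1 features
# (the raw value packed with dim/2 sines and dim/2 cosines), which is why the
# time mlp below is Linear(dim + 1, dim)
#
#   embed = RandomFourierEmbed(dim = 512)
#   out = embed(torch.rand(4))   # (4, 513)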
# character embedding

class CharacterEmbed(Module):
    def __init__(
        self,
        dim,
        num_embeds = 256,
    ):
        super().__init__()
        self.dim = dim
        self.embed = nn.Embedding(num_embeds + 1, dim) # will just use 0 as the 'filler token'

    def forward(
        self,
        text: Int['b nt'],
        max_seq_len: int,
        **kwargs
    ) -> Float['b n d']:
        text = text + 1 # shift all other token ids up by 1 and use 0 as filler token
        text = text[:, :max_seq_len] # just curtail if character tokens are more than the mel spec tokens, one of the edge cases the paper did not address
        text = pad_to_length(text, max_seq_len, value = 0)
        return self.embed(text)
class InterpolatedCharacterEmbed(Module):
    def __init__(
        self,
        dim,
        num_embeds = 256,
    ):
        super().__init__()
        self.dim = dim
        self.embed = nn.Embedding(num_embeds, dim)

        self.abs_pos_mlp = Sequential(
            Rearrange('... -> ... 1'),
            Linear(1, dim),
            nn.SiLU(),
            Linear(dim, dim)
        )

    def forward(
        self,
        text: Int['b nt'],
        max_seq_len: int,
        mask: Bool['b n'] | None = None
    ) -> Float['b n d']:
        device = text.device

        mask = default(mask, (None,)) # so zip_longest below yields a None per sample when no mask is given

        interp_embeds = []
        interp_abs_positions = []

        for one_text, one_mask in zip_longest(text, mask):
            valid_text = one_text >= 0
            one_text = one_text[valid_text]
            one_text_embed = self.embed(one_text)

            # save the absolute positions

            text_seq_len = one_text.shape[0]

            # determine audio sequence length from mask

            audio_seq_len = max_seq_len
            if exists(one_mask):
                audio_seq_len = one_mask.sum().long().item()

            # interpolate text embedding to audio embedding length

            interp_text_embed = interpolate_1d(one_text_embed, audio_seq_len)
            interp_abs_pos = torch.linspace(0, text_seq_len, audio_seq_len, device = device)

            interp_embeds.append(interp_text_embed)
            interp_abs_positions.append(interp_abs_pos)

        interp_embeds = pad_sequence(interp_embeds)
        interp_abs_positions = pad_sequence(interp_abs_positions)

        interp_embeds = F.pad(interp_embeds, (0, 0, 0, max_seq_len - interp_embeds.shape[-2]))
        interp_abs_positions = pad_to_length(interp_abs_positions, max_seq_len)

        # pass interp absolute positions through mlp for implicit positions

        interp_embeds = interp_embeds + self.abs_pos_mlp(interp_abs_positions)

        if exists(mask):
            interp_embeds = einx.where('b n, b n d, -> b n d', mask, interp_embeds, 0.)

        return interp_embeds
# text audio cross conditioning in multistream setup

class TextAudioCrossCondition(Module):
    def __init__(
        self,
        dim,
        dim_text,
        cond_audio_to_text = True,
    ):
        super().__init__()
        self.text_to_audio = nn.Linear(dim_text + dim, dim, bias = False)
        nn.init.zeros_(self.text_to_audio.weight)

        self.cond_audio_to_text = cond_audio_to_text

        if cond_audio_to_text:
            self.audio_to_text = nn.Linear(dim + dim_text, dim_text, bias = False)
            nn.init.zeros_(self.audio_to_text.weight)

    def forward(
        self,
        audio: Float['b n d'],
        text: Float['b n dt']
    ):
        audio_text, _ = pack((audio, text), 'b n *')

        text_cond = self.text_to_audio(audio_text)
        audio_cond = self.audio_to_text(audio_text) if self.cond_audio_to_text else 0.

        return audio + text_cond, text + audio_cond
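# usage sketch (comment-only): both streams exchange residual conditioning in
# one shot; the zero-init keeps the exchange a no-op at the start of training
#
#   cross = TextAudioCrossCondition(dim = 1024, dim_text = 1280)
#   audio, text = cross(torch.randn(2, 100, 1024), torch.randn(2, 100, 1280))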
# attention and transformer backbone
# for use in both e2tts as well as duration module

class Transformer(Module):
    def __init__(
        self,
        *,
        dim,
        dim_text = None, # will default to half of audio dimension
        depth = 8,
        heads = 8,
        dim_head = 64,
        ff_mult = 4,
        text_depth = None,
        text_heads = None,
        text_dim_head = None,
        text_ff_mult = None,
        cond_on_time = True,
        abs_pos_emb = True,
        max_seq_len = 8192,
        kernel_size = 31,
        dropout = 0.1,
        num_registers = 32,
        attn_kwargs: dict = dict(
            gate_value_heads = True,
            softclamp_logits = True,
        ),
        ff_kwargs: dict = dict(),
        if_text_modules = True,
        if_cross_attn = True,
        if_audio_conv = True,
        if_text_conv = False
    ):
        super().__init__()
        assert divisible_by(depth, 2), 'depth needs to be even'

        # absolute positional embedding

        self.max_seq_len = max_seq_len
        self.abs_pos_emb = nn.Embedding(max_seq_len, dim) if abs_pos_emb else None

        self.dim = dim

        dim_text = default(dim_text, dim // 2)
        self.dim_text = dim_text

        text_heads = default(text_heads, heads)
        text_dim_head = default(text_dim_head, dim_head)
        text_ff_mult = default(text_ff_mult, ff_mult)
        text_depth = default(text_depth, depth)

        assert 1 <= text_depth <= depth, 'must have at least 1 layer of text conditioning, but less than total number of speech layers'

        self.depth = depth
        self.layers = ModuleList([])

        # registers

        self.num_registers = num_registers
        self.registers = nn.Parameter(torch.zeros(num_registers, dim))
        nn.init.normal_(self.registers, std = 0.02)

        if if_text_modules:
            self.text_registers = nn.Parameter(torch.zeros(num_registers, dim_text))
            nn.init.normal_(self.text_registers, std = 0.02)

        # rotary embedding

        self.rotary_emb = RotaryEmbedding(dim_head)
        self.text_rotary_emb = RotaryEmbedding(dim_head)

        # time conditioning
        # will use adaptive rmsnorm

        self.cond_on_time = cond_on_time
        rmsnorm_klass = RMSNorm if not cond_on_time else AdaptiveRMSNorm
        postbranch_klass = Identity if not cond_on_time else partial(AdaLNZero, dim = dim)

        self.time_cond_mlp = Identity()

        if cond_on_time:
            self.time_cond_mlp = Sequential(
                RandomFourierEmbed(dim),
                Linear(dim + 1, dim),
                nn.SiLU()
            )

        for ind in range(depth):
            is_later_half = ind >= (depth // 2)
            has_text = ind < text_depth

            # speech related

            if if_audio_conv:
                speech_conv = DepthwiseConv(dim, kernel_size = kernel_size)

            attn_norm = rmsnorm_klass(dim)
            attn = Attention(dim = dim, heads = heads, dim_head = dim_head, dropout = dropout, **attn_kwargs)
            attn_adaln_zero = postbranch_klass()

            if if_cross_attn:
                attn_norm2 = rmsnorm_klass(dim)
                attn2 = Attention(dim = dim, heads = heads, dim_head = dim_head, dropout = dropout, **attn_kwargs)
                attn_adaln_zero2 = postbranch_klass()

            ff_norm = rmsnorm_klass(dim)
            ff = FeedForward(dim = dim, glu = True, mult = ff_mult, dropout = dropout, **ff_kwargs)
            ff_adaln_zero = postbranch_klass()

            skip_proj = Linear(dim * 2, dim, bias = False) if is_later_half else None

            if if_cross_attn:
                if if_audio_conv:
                    speech_modules = ModuleList([
                        skip_proj,
                        speech_conv,
                        attn_norm,
                        attn,
                        attn_adaln_zero,
                        attn_norm2,
                        attn2,
                        attn_adaln_zero2,
                        ff_norm,
                        ff,
                        ff_adaln_zero,
                    ])
                else:
                    speech_modules = ModuleList([
                        skip_proj,
                        attn_norm,
                        attn,
                        attn_adaln_zero,
                        attn_norm2,
                        attn2,
                        attn_adaln_zero2,
                        ff_norm,
                        ff,
                        ff_adaln_zero,
                    ])
            else:
                if if_audio_conv:
                    speech_modules = ModuleList([
                        skip_proj,
                        speech_conv,
                        attn_norm,
                        attn,
                        attn_adaln_zero,
                        ff_norm,
                        ff,
                        ff_adaln_zero,
                    ])
                else:
                    speech_modules = ModuleList([
                        skip_proj,
                        attn_norm,
                        attn,
                        attn_adaln_zero,
                        ff_norm,
                        ff,
                        ff_adaln_zero,
                    ])

            text_modules = None

            if has_text and if_text_modules:
                # text related

                if if_text_conv:
                    text_conv = DepthwiseConv(dim_text, kernel_size = kernel_size)

                text_attn_norm = RMSNorm(dim_text)
                text_attn = Attention(dim = dim_text, heads = text_heads, dim_head = text_dim_head, dropout = dropout, **attn_kwargs)

                text_ff_norm = RMSNorm(dim_text)
                text_ff = FeedForward(dim = dim_text, glu = True, mult = text_ff_mult, dropout = dropout, **ff_kwargs)

                # cross condition

                is_last = ind == (text_depth - 1)

                cross_condition = TextAudioCrossCondition(dim = dim, dim_text = dim_text, cond_audio_to_text = not is_last)

                if if_text_conv:
                    text_modules = ModuleList([
                        text_conv,
                        text_attn_norm,
                        text_attn,
                        text_ff_norm,
                        text_ff,
                        cross_condition
                    ])
                else:
                    text_modules = ModuleList([
                        text_attn_norm,
                        text_attn,
                        text_ff_norm,
                        text_ff,
                        cross_condition
                    ])

            self.layers.append(ModuleList([
                speech_modules,
                text_modules
            ]))

        self.final_norm = RMSNorm(dim)

        self.if_cross_attn = if_cross_attn
        self.if_audio_conv = if_audio_conv
        self.if_text_conv = if_text_conv

        ####self.a_align = Linear(dim, 512)
        ####self.v_align = Linear(dim_text, 512)
        ####self.contrastive_loss = SupConLoss()
        #self.contrastive_loss = FactorCLSSL(None, [1024, 1280], None)
        self.contrastive_loss = FactorCLSUP(None, [1024, 1280], 6)
    def forward(
        self,
        x: Float['b n d'],
        times: Float['b'] | Float[''] | None = None,
        mask: Bool['b n'] | None = None,
        text_embed: Float['b n dt'] | None = None,
        context: Float['b nc dc'] | None = None,
        context_mask: Float['b nc'] | None = None
    ):
        batch, seq_len, device = *x.shape[:2], x.device

        assert not (exists(times) ^ self.cond_on_time), '`times` must be passed in if `cond_on_time` is set to `True` and vice versa'

        # handle absolute positions if needed

        if exists(self.abs_pos_emb):
            assert seq_len <= self.max_seq_len, f'{seq_len} exceeds the set `max_seq_len` ({self.max_seq_len}) on Transformer'
            seq = torch.arange(seq_len, device = device)
            x = x + self.abs_pos_emb(seq)

        # handle adaptive rmsnorm kwargs

        norm_kwargs = dict()

        if exists(times):
            if times.ndim == 0:
                times = repeat(times, ' -> b', b = batch)

            times = self.time_cond_mlp(times)
            norm_kwargs.update(condition = times)

        # register tokens

        registers = repeat(self.registers, 'r d -> b r d', b = batch)
        x, registers_packed_shape = pack((registers, x), 'b * d')

        if exists(mask):
            mask = F.pad(mask, (self.num_registers, 0), value = True)

        # rotary embedding

        rotary_pos_emb = self.rotary_emb.forward_from_seq_len(x.shape[-2])

        # text related

        if exists(text_embed):
            text_rotary_pos_emb = self.text_rotary_emb.forward_from_seq_len(x.shape[-2])

            text_registers = repeat(self.text_registers, 'r d -> b r d', b = batch)
            text_embed, _ = pack((text_registers, text_embed), 'b * d')

        # skip connection related stuff

        skips = []

        # go through the layers

        loss_contra = 0

        for ind, (speech_modules, text_modules) in enumerate(self.layers):
            layer = ind + 1

            if self.if_cross_attn:
                if self.if_audio_conv:
                    (
                        maybe_skip_proj,
                        speech_conv,
                        attn_norm,
                        attn,
                        maybe_attn_adaln_zero,
                        attn_norm2,
                        attn2,
                        maybe_attn_adaln_zero2,
                        ff_norm,
                        ff,
                        maybe_ff_adaln_zero
                    ) = speech_modules
                else:
                    (
                        maybe_skip_proj,
                        attn_norm,
                        attn,
                        maybe_attn_adaln_zero,
                        attn_norm2,
                        attn2,
                        maybe_attn_adaln_zero2,
                        ff_norm,
                        ff,
                        maybe_ff_adaln_zero
                    ) = speech_modules
            else:
                if self.if_audio_conv:
                    (
                        maybe_skip_proj,
                        speech_conv,
                        attn_norm,
                        attn,
                        maybe_attn_adaln_zero,
                        ff_norm,
                        ff,
                        maybe_ff_adaln_zero
                    ) = speech_modules
                else:
                    (
                        maybe_skip_proj,
                        attn_norm,
                        attn,
                        maybe_attn_adaln_zero,
                        ff_norm,
                        ff,
                        maybe_ff_adaln_zero
                    ) = speech_modules

            # smaller text transformer

            if exists(text_embed) and exists(text_modules):
                if self.if_text_conv:
                    (
                        text_conv,
                        text_attn_norm,
                        text_attn,
                        text_ff_norm,
                        text_ff,
                        cross_condition
                    ) = text_modules
                else:
                    (
                        text_attn_norm,
                        text_attn,
                        text_ff_norm,
                        text_ff,
                        cross_condition
                    ) = text_modules

                if self.if_text_conv:
                    text_embed = text_conv(text_embed, mask = mask) + text_embed

                text_embed = text_attn(text_attn_norm(text_embed), rotary_pos_emb = text_rotary_pos_emb, mask = mask) + text_embed
                text_embed = text_ff(text_ff_norm(text_embed)) + text_embed

                # contrastive loss between the audio and text (video) streams at the first layer,
                # computed only when the batch is large enough to hold the 6 contrastive samples

                if x.shape[0] >= 8:
                    if layer in [1]:
                        #print("contrastive learning", x.shape, text_embed.shape) # [16, 782, 1024] [16, 782, 1280]
                        ####features1 = self.a_align(x)
                        ####features2 = self.v_align(text_embed)
                        features1 = x
                        features2 = text_embed
                        #b1, b2 = random_sample(range(2,8), 2)
                        #t1, t2 = random_sample(range(x.shape[1]), 2)
                        #loss_contra = torch.cosine_similarity(x[b1,t1:t1+1,:], y[b1,t2:t2+1,:]) + torch.cosine_similarity(x[b1,t1:t1+1,:], y[b2,t1:t1+1,:]) - torch.cosine_similarity(x[b1,t1:t1+1,:], y[b1,t1:t1+1,:]) - torch.cosine_similarity(x[b2,t2:t2+1,:], y[b2,t2:t2+1,:])
                        features1 = features1[2:8, 32:, :] # take samples 2..7 and drop the 32 register tokens
                        features2 = features2[2:8, 32:, :]
                        if self.training:
                            ts = random_sample(range(features1.shape[1]), 1)
                            #ts = random_sample(range(features1.shape[1]), 6)
                            #ts = random_sample(range(features1.shape[1]), 14)
                        else:
                            ts = [350]
                            #ts = [100, 200, 300, 400, 500, 600]
                            #ts = [50, 100, 150, 200, 250, 300, 350, 400, 450, 500, 550, 600, 650, 700]
                        features1 = features1[:, ts, :]
                        features2 = features2[:, ts, :]
                        ####loss_contra += self.contrastive_loss(features1, features2, labels=None, mask=None)
                        nf = 1
                        features1 = features1.reshape(6 * nf, -1)
                        features2 = features2.reshape(6 * nf, -1)
                        label = torch.tensor([0.0] * nf + [1.0] * nf + [2.0] * nf + [3.0] * nf + [4.0] * nf + [5.0] * nf, device=device).reshape(6 * nf, 1)
                        loss_contra += self.contrastive_loss(features1, features2, label)
                else:
                    loss_contra = torch.tensor(0.0, device=device)

                x, text_embed = cross_condition(x, text_embed)

            # skip connection logic

            is_first_half = layer <= (self.depth // 2)
            is_later_half = not is_first_half

            if is_first_half:
                skips.append(x)

            if is_later_half:
                skip = skips.pop()
                x = torch.cat((x, skip), dim = -1)
                x = maybe_skip_proj(x)

            # position generating convolution

            if self.if_audio_conv:
                x = speech_conv(x, mask = mask) + x

            # attention and feedforward blocks

            attn_out = attn(attn_norm(x, **norm_kwargs), rotary_pos_emb = rotary_pos_emb, mask = mask)
            x = x + maybe_attn_adaln_zero(attn_out, **norm_kwargs)

            if self.if_cross_attn:
                attn_out = attn2(attn_norm2(x, **norm_kwargs), rotary_pos_emb = rotary_pos_emb, mask = mask, context = context, context_mask = context_mask)
                x = x + maybe_attn_adaln_zero2(attn_out, **norm_kwargs)

            ff_out = ff(ff_norm(x, **norm_kwargs))
            x = x + maybe_ff_adaln_zero(ff_out, **norm_kwargs)

        assert len(skips) == 0

        _, x = unpack(x, registers_packed_shape, 'b * d')

        return self.final_norm(x), loss_contra
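# usage sketch (comment-only, assuming cond_on_time = True and the default 32
# registers; forward returns the normed audio stream plus the contrastive loss):
#
#   model = Transformer(dim = 1024, dim_text = 1280, depth = 8)
#   audio = torch.randn(2, 500, 1024)
#   text = torch.randn(2, 500, 1280)
#   out, loss_contra = model(audio, times = torch.rand(2), text_embed = text)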
from .multibench_model import FactorCLSUP, FactorCLSSL
class SupConLoss(nn.Module):
    """Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf.
    It also supports the unsupervised contrastive loss in SimCLR"""
    def __init__(self, temperature=0.07, contrast_mode='all',
                 base_temperature=0.07):
        super(SupConLoss, self).__init__()
        self.temperature = temperature
        self.contrast_mode = contrast_mode
        self.base_temperature = base_temperature

    def forward(self, features, features2, labels=None, mask=None):
        """Compute loss for model. If both `labels` and `mask` are None,
        it degenerates to SimCLR unsupervised loss:
        https://arxiv.org/pdf/2002.05709.pdf
        Args:
            features: hidden vector of shape [bsz, n_views, ...].
            features2: second-modality hidden vector of the same shape, contrasted against `features`.
            labels: ground truth of shape [bsz].
            mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j
                has the same class as sample i. Can be asymmetric.
        Returns:
            A loss scalar.
        """
        features = F.normalize(features, dim=2)
        features2 = F.normalize(features2, dim=2)

        device = (torch.device('cuda')
                  if features.is_cuda
                  else torch.device('cpu'))

        if len(features.shape) < 3:
            raise ValueError('`features` needs to be [bsz, n_views, ...],'
                             'at least 3 dimensions are required')
        if len(features.shape) > 3:
            features = features.view(features.shape[0], features.shape[1], -1)
            features2 = features2.view(features2.shape[0], features2.shape[1], -1)

        batch_size = features.shape[0]
        if labels is not None and mask is not None:
            raise ValueError('Cannot define both `labels` and `mask`')
        elif labels is None and mask is None:
            mask = torch.eye(batch_size, dtype=torch.float32).to(device)
        elif labels is not None:
            labels = labels.contiguous().view(-1, 1)
            if labels.shape[0] != batch_size:
                raise ValueError('Num of labels does not match num of features')
            mask = torch.eq(labels, labels.T).float().to(device)
        else:
            mask = mask.float().to(device)

        contrast_count = features.shape[1]
        contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)
        contrast_feature2 = torch.cat(torch.unbind(features2, dim=1), dim=0)
        if self.contrast_mode == 'one':
            anchor_feature = features[:, 0]
            anchor_feature2 = features2[:, 0]
            anchor_count = 1
        elif self.contrast_mode == 'all':
            anchor_feature = contrast_feature
            anchor_feature2 = contrast_feature2
            anchor_count = contrast_count
        else:
            raise ValueError('Unknown mode: {}'.format(self.contrast_mode))

        # compute logits
        anchor_dot_contrast = torch.div(
            #torch.matmul(anchor_feature, contrast_feature.T),
            torch.matmul(anchor_feature, contrast_feature2.T),
            self.temperature)
        # for numerical stability
        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
        logits = anchor_dot_contrast - logits_max.detach()

        # tile mask
        mask = mask.repeat(anchor_count, contrast_count)
        # mask-out self-contrast cases
        #logits_mask = torch.scatter(
        #    torch.ones_like(mask),
        #    1,
        #    torch.arange(batch_size * anchor_count).view(-1, 1).to(device),
        #    0
        #)
        # self-contrast pairs are kept here, since anchors and contrasts come from different modalities
        logits_mask = torch.ones_like(mask)
        mask = mask * logits_mask
        #print("logits", logits.shape, logits.min(), logits.max())

        # compute log_prob
        exp_logits = torch.exp(logits) * logits_mask
        #print("exp_logits", exp_logits.shape, exp_logits.min(), exp_logits.max())
        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))
        #print("log_prob", log_prob.shape, log_prob.min(), log_prob.max())

        # compute mean of log-likelihood over positive
        # modified to handle edge cases when there is no positive pair
        # for an anchor point.
        # Edge case e.g.:-
        # features of shape: [4,1,...]
        # labels: [0,1,1,2]
        # loss before mean: [nan, ..., ..., nan]
        mask_pos_pairs = mask.sum(1)
        mask_pos_pairs = torch.where(mask_pos_pairs < 1e-6, 1, mask_pos_pairs)
        mean_log_prob_pos = (mask * log_prob).sum(1) / mask_pos_pairs

        # loss
        loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos
        loss = loss.view(anchor_count, batch_size).mean()

        return loss
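# usage sketch (comment-only; here the two views come from two modalities,
# so each is passed as its own [bsz, n_views, dim] tensor):
#
#   criterion = SupConLoss()
#   f_audio = torch.randn(8, 1, 128)
#   f_video = torch.randn(8, 1, 128)
#   labels = torch.arange(8)
#   loss = criterion(f_audio, f_video, labels=labels)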
# main classes

class DurationPredictor(Module):
    def __init__(
        self,
        transformer: dict | Transformer,
        num_channels = None,
        mel_spec_kwargs: dict = dict(),
        char_embed_kwargs: dict = dict(),
        text_num_embeds = None,
        tokenizer: (
            Literal['char_utf8', 'phoneme_en'] |
            Callable[[list[str]], Int['b nt']]
        ) = 'char_utf8'
    ):
        super().__init__()

        if isinstance(transformer, dict):
            transformer = Transformer(
                **transformer,
                cond_on_time = False
            )

        # mel spec

        self.mel_spec = MelSpec(**mel_spec_kwargs)
        self.num_channels = default(num_channels, self.mel_spec.n_mel_channels)

        self.transformer = transformer

        dim = transformer.dim
        dim_text = transformer.dim_text

        self.dim = dim

        self.proj_in = Linear(self.num_channels, self.dim)

        # tokenizer and text embed

        if callable(tokenizer):
            assert exists(text_num_embeds), '`text_num_embeds` must be given if supplying your own tokenizer encode function'
            self.tokenizer = tokenizer
        elif tokenizer == 'char_utf8':
            text_num_embeds = 256
            self.tokenizer = list_str_to_tensor
        elif tokenizer == 'phoneme_en':
            self.tokenizer, text_num_embeds = get_g2p_en_encode()
        elif tokenizer == 'phoneme_zh':
            self.tokenizer, text_num_embeds = get_g2p_zh_encode()
        else:
            raise ValueError(f'unknown tokenizer string {tokenizer}')

        self.embed_text = CharacterEmbed(dim_text, num_embeds = text_num_embeds, **char_embed_kwargs)

        # to prediction

        self.to_pred = Sequential(
            Linear(dim, 1, bias = False),
            nn.Softplus(),
            Rearrange('... 1 -> ...')
        )

    def forward(
        self,
        x: Float['b n d'] | Float['b nw'],
        *,
        text: Int['b nt'] | list[str] | None = None,
        lens: Int['b'] | None = None,
        return_loss = True
    ):
        # raw wave

        if x.ndim == 2:
            x = self.mel_spec(x)
            x = rearrange(x, 'b d n -> b n d')
            assert x.shape[-1] == self.num_channels

        x = self.proj_in(x)

        batch, seq_len, device = *x.shape[:2], x.device

        # text

        text_embed = None

        if exists(text):
            if isinstance(text, list):
                text = self.tokenizer(text).to(device)

            assert text.shape[0] == batch
            text_embed = self.embed_text(text, seq_len)

        # handle lengths (duration)

        if not exists(lens):
            lens = torch.full((batch,), seq_len, device = device)

        mask = lens_to_mask(lens, length = seq_len)

        # if returning a loss, mask out randomly from an index and have it predict the duration

        if return_loss:
            rand_frac_index = x.new_zeros(batch).uniform_(0, 1)
            rand_index = (rand_frac_index * lens).long()

            seq = torch.arange(seq_len, device = device)
            mask &= einx.less('n, b -> b n', seq, rand_index)

        # attending

        x, _ = self.transformer(
            x,
            mask = mask,
            text_embed = text_embed,
        )

        x = maybe_masked_mean(x, mask)

        pred = self.to_pred(x)

        # return the prediction if not returning loss

        if not return_loss:
            return pred

        # loss

        return F.mse_loss(pred, lens.float())
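# usage sketch (comment-only; during training a random prefix of the sequence
# is kept visible and the model regresses the full length):
#
#   duration_predictor = DurationPredictor(transformer = dict(dim = 512, depth = 2, if_cross_attn = False))
#   wave = torch.randn(2, 24_000)
#   loss = duration_predictor(wave, text = ['hello', 'world'])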
class E2TTS(Module):
    def __init__(
        self,
        transformer: dict | Transformer = None,
        duration_predictor: dict | DurationPredictor | None = None,
        odeint_kwargs: dict = dict(
            #atol = 1e-5,
            #rtol = 1e-5,
            #method = 'midpoint'
            method = "euler"
        ),
        audiocond_drop_prob = 0.30,
        cond_drop_prob = 0.20,
        prompt_drop_prob = 0.10,
        num_channels = None,
        mel_spec_module: Module | None = None,
        char_embed_kwargs: dict = dict(),
        mel_spec_kwargs: dict = dict(),
        frac_lengths_mask: tuple[float, float] = (0.7, 1.),
        audiocond_snr: tuple[float, float] | None = None,
        concat_cond = False,
        interpolated_text = False,
        text_num_embeds: int | None = None,
        tokenizer: (
            Literal['char_utf8', 'phoneme_en', 'phoneme_zh'] |
            Callable[[list[str]], Int['b nt']]
        ) = 'char_utf8',
        use_vocos = True,
        pretrained_vocos_path = 'charactr/vocos-mel-24khz',
        sampling_rate: int | None = None,
        frame_size: int = 320,
        velocity_consistency_weight = -1e-5,
        if_cond_proj_in = True,
        cond_proj_in_bias = True,
        if_embed_text = True,
        if_text_encoder2 = True,
        if_clip_encoder = False,
        video_encoder = "clip_vit"
    ):
        super().__init__()

        if isinstance(transformer, dict):
            transformer = Transformer(
                **transformer,
                cond_on_time = True
            )

        if isinstance(duration_predictor, dict):
            duration_predictor = DurationPredictor(**duration_predictor)

        self.transformer = transformer

        dim = transformer.dim
        dim_text = transformer.dim_text

        self.dim = dim
        self.dim_text = dim_text

        self.frac_lengths_mask = frac_lengths_mask
        self.audiocond_snr = audiocond_snr

        self.duration_predictor = duration_predictor

        # sampling

        self.odeint_kwargs = odeint_kwargs

        # mel spec

        self.mel_spec = default(mel_spec_module, None)
        num_channels = default(num_channels, None)

        self.num_channels = num_channels
        self.sampling_rate = default(sampling_rate, None)
        self.frame_size = frame_size

        # whether to concat condition and project rather than project both and sum

        self.concat_cond = concat_cond

        if concat_cond:
            self.proj_in = nn.Linear(num_channels * 2, dim)
        else:
            self.proj_in = nn.Linear(num_channels, dim)
            self.cond_proj_in = nn.Linear(num_channels, dim, bias=cond_proj_in_bias) if if_cond_proj_in else None

        # to prediction

        self.to_pred = Linear(dim, num_channels)

        # tokenizer and text embed

        if callable(tokenizer):
            assert exists(text_num_embeds), '`text_num_embeds` must be given if supplying your own tokenizer encode function'
            self.tokenizer = tokenizer
        elif tokenizer == 'char_utf8':
            text_num_embeds = 256
            self.tokenizer = list_str_to_tensor
        elif tokenizer == 'phoneme_en':
            self.tokenizer, text_num_embeds = get_g2p_en_encode()
        elif tokenizer == 'phoneme_zh':
            self.tokenizer, text_num_embeds = get_g2p_zh_encode()
        else:
            raise ValueError(f'unknown tokenizer string {tokenizer}')

        self.audiocond_drop_prob = audiocond_drop_prob
        self.cond_drop_prob = cond_drop_prob
        self.prompt_drop_prob = prompt_drop_prob

        # text embedding

        text_embed_klass = CharacterEmbed if not interpolated_text else InterpolatedCharacterEmbed
        self.embed_text = text_embed_klass(dim_text, num_embeds = text_num_embeds, **char_embed_kwargs) if if_embed_text else None

        # weight for velocity consistency

        self.register_buffer('zero', torch.tensor(0.), persistent = False)
        self.velocity_consistency_weight = velocity_consistency_weight

        # default vocos for mel -> audio

        if pretrained_vocos_path == 'charactr/vocos-mel-24khz':
            self.vocos = Vocos.from_pretrained(pretrained_vocos_path) if use_vocos else None
        elif pretrained_vocos_path == 'facebook/encodec_24khz':
            self.vocos = EncodecWrapper("facebook/encodec_24khz") if use_vocos else None
        elif pretrained_vocos_path == 'vae':
            self.vocos = VaeWrapper() if use_vocos else None
        else:
            raise ValueError(f'unknown pretrained_vocos_path {pretrained_vocos_path}')

        if if_text_encoder2:
            self.tokenizer2 = AutoTokenizer.from_pretrained("/ailab-train/speech/zhanghaomin/models/flan-t5-large")
            self.text_encoder2 = T5EncoderModel.from_pretrained("/ailab-train/speech/zhanghaomin/models/flan-t5-large")
            for param in self.text_encoder2.parameters():
                param.requires_grad = False
            self.text_encoder2.eval()

        self.proj_text = None
        if if_clip_encoder:
            if video_encoder == "clip_vit":
                self.image_processor = CLIPImageProcessor()
                #self.image_encoder = CLIPVisionModelWithProjection.from_pretrained("/ailab-train/speech/zhanghaomin/models/IP-Adapter/", subfolder="models/image_encoder")
                self.image_encoder = CLIPVisionModelWithProjection.from_pretrained("/ailab-train/speech/zhanghaomin/models/IP-Adapter/", subfolder="sdxl_models/image_encoder")
            elif video_encoder == "clip_vit2":
                self.image_processor = AutoProcessor.from_pretrained("/ailab-train/speech/zhanghaomin/models/clip-vit-large-patch14-336/")
                self.image_encoder = CLIPVisionModelWithProjection.from_pretrained("/ailab-train/speech/zhanghaomin/models/clip-vit-large-patch14-336/")
            elif video_encoder == "clip_convnext":
                self.image_encoder, _, self.image_processor = open_clip.create_model_and_transforms("hf-hub:laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-soup")
            elif video_encoder == "dinov2":
                self.image_processor = AutoImageProcessor.from_pretrained("/ailab-train/speech/zhanghaomin/models/dinov2-giant/")
                self.image_encoder = AutoModel.from_pretrained("/ailab-train/speech/zhanghaomin/models/dinov2-giant/")
            elif video_encoder == "mixed":
                #pass
                self.image_processor1 = CLIPImageProcessor()
                self.image_encoder1 = CLIPVisionModelWithProjection.from_pretrained("/ailab-train/speech/zhanghaomin/models/IP-Adapter/", subfolder="sdxl_models/image_encoder")
                self.image_processor2 = AutoProcessor.from_pretrained("/ailab-train/speech/zhanghaomin/models/clip-vit-large-patch14-336/")
                self.image_encoder2 = CLIPVisionModelWithProjection.from_pretrained("/ailab-train/speech/zhanghaomin/models/clip-vit-large-patch14-336/")
                self.image_encoder3, _, self.image_processor3 = open_clip.create_model_and_transforms("hf-hub:laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-soup")
                self.image_processor4 = AutoImageProcessor.from_pretrained("/ailab-train/speech/zhanghaomin/models/dinov2-giant/")
                self.image_encoder4 = AutoModel.from_pretrained("/ailab-train/speech/zhanghaomin/models/dinov2-giant/")
            else:
                self.image_processor = None
                self.image_encoder = None

            if video_encoder != "mixed":
                if self.image_encoder is not None:
                    for param in self.image_encoder.parameters():
                        param.requires_grad = False
                    self.image_encoder.eval()
            else:
                for param in self.image_encoder1.parameters():
                    param.requires_grad = False
                self.image_encoder1.eval()
                for param in self.image_encoder2.parameters():
                    param.requires_grad = False
                self.image_encoder2.eval()
                for param in self.image_encoder3.parameters():
                    param.requires_grad = False
                self.image_encoder3.eval()
                for param in self.image_encoder4.parameters():
                    param.requires_grad = False
                self.image_encoder4.eval()

            self.dim_text_raw = 4608
            self.proj_text = Linear(self.dim_text_raw, dim_text)

        self.video_encoder = video_encoder

        if exists(self.vocos):
            for param in self.vocos.parameters():
                param.requires_grad = False
            self.vocos.eval()
    @property
    def device(self):
        return next(self.parameters()).device

    def encode_text(self, prompt):
        device = self.device
        batch = self.tokenizer2(prompt, max_length=self.tokenizer2.model_max_length, padding=True, truncation=True, return_tensors="pt")
        input_ids, attention_mask = batch.input_ids.to(device), batch.attention_mask.to(device)
        with torch.no_grad():
            encoder_hidden_states = self.text_encoder2(input_ids=input_ids, attention_mask=attention_mask)[0]
        boolean_encoder_mask = (attention_mask == 1).to(device)
        return encoder_hidden_states, boolean_encoder_mask
    def encode_video(self, video_paths, l):
        # `l` is the mel length in frames; embeddings are returned at their raw width
        # when a projection exists (proj_text is applied later, in the backbone)
        if self.proj_text is None:
            d = self.dim_text
        else:
            d = self.dim_text_raw
        device = self.device
        b = 20  # micro-batch of frames per encoder forward, to bound memory
        with torch.no_grad():
            video_embeddings = []
            video_lens = []
            for video_path in video_paths:
                if video_path is None:
                    video_embeddings.append(None)
                    video_lens.append(0)
                    continue
                if isinstance(video_path, tuple):
                    video_path, start_sample, max_sample = video_path
                else:
                    start_sample = 0
                    max_sample = None
                # locate the cached .npz features for this clip and encoder type
                if video_path.startswith("/ailab-train/speech/zhanghaomin/VGGSound/"):
                    feature_dirs = {
                        "clip_vit": "/feature/",
                        "clip_vit2": "/feature_clip_vit2/",
                        "clip_convnext": "/feature_clip_convnext/",
                        "dinov2": "/feature_dinov2/",
                        "mixed": "/feature_mixed/",
                    }
                    if self.video_encoder not in feature_dirs:
                        raise Exception("Invalid video_encoder " + self.video_encoder)
                    feature_path = video_path.replace("/video/", feature_dirs[self.video_encoder]).replace(".mp4", ".npz")
                else:
                    feature_suffixes = {
                        "clip_vit": ".generated.npz",
                        "clip_vit2": ".generated.clip_vit2.npz",
                        "clip_convnext": ".generated.clip_convnext.npz",
                        "dinov2": ".generated.dinov2.npz",
                        "mixed": ".generated.mixed.npz",
                    }
                    if self.video_encoder not in feature_suffixes:
                        raise Exception("Invalid video_encoder " + self.video_encoder)
                    feature_path = video_path.replace(".mp4", feature_suffixes[self.video_encoder])
                if not os.path.exists(feature_path):
                    frames, duration = read_frames_with_moviepy(video_path, max_frame_nums=None)
                    if frames is None:
                        video_embeddings.append(None)
                        video_lens.append(0)
                        continue
                    # preprocess frames for the chosen encoder(s)
                    if self.video_encoder in ["clip_vit", "clip_vit2", "dinov2"]:
                        images = self.image_processor(images=frames, return_tensors="pt").to(device)
                    elif self.video_encoder in ["clip_convnext"]:
                        # the open_clip transform works on PIL images, one frame at a time
                        images = []
                        for i in range(frames.shape[0]):
                            images.append(self.image_processor(Image.fromarray(frames[i])).unsqueeze(0))
                        images = torch.cat(images, dim=0).to(device)
                    elif self.video_encoder in ["mixed"]:
                        # `images1` is needed below when the clip_vit cache is missing; this call
                        # was commented out upstream (which would leave it undefined), so it is
                        # restored here on the assumption that `self.image_processor1` is
                        # constructed alongside `image_encoder1` in __init__
                        images1 = self.image_processor1(images=frames, return_tensors="pt").to(device)
                        images2 = self.image_processor2(images=frames, return_tensors="pt").to(device)
                        images4 = self.image_processor4(images=frames, return_tensors="pt").to(device)
                        images3 = []
                        for i in range(frames.shape[0]):
                            images3.append(self.image_processor3(Image.fromarray(frames[i])).unsqueeze(0))
                        images3 = torch.cat(images3, dim=0).to(device)
                    else:
                        raise Exception("Invalid video_encoder " + self.video_encoder)
                    image_embeddings = []
                    if self.video_encoder in ("clip_vit", "clip_vit2"):
                        for i in range(math.ceil(images["pixel_values"].shape[0] / b)):
                            image_embeddings.append(self.image_encoder(pixel_values=images["pixel_values"][i*b: (i+1)*b]).image_embeds.cpu())
                    elif self.video_encoder == "clip_convnext":
                        for i in range(math.ceil(images.shape[0] / b)):
                            image_embeddings.append(self.image_encoder.encode_image(images[i*b: (i+1)*b]).cpu())
                    elif self.video_encoder == "dinov2":
                        for i in range(math.ceil(images["pixel_values"].shape[0] / b)):
                            image_embeddings.append(self.image_encoder(pixel_values=images["pixel_values"][i*b: (i+1)*b]).pooler_output.cpu())
                    elif self.video_encoder == "mixed":
                        # each sub-encoder keeps its own cache; compute only what is missing
                        feature_path1 = feature_path.replace("/feature_mixed/", "/feature/")
                        if not os.path.exists(feature_path1):
                            image_embeddings1 = []
                            for i in range(math.ceil(images1["pixel_values"].shape[0] / b)):
                                image_embeddings1.append(self.image_encoder1(pixel_values=images1["pixel_values"][i*b: (i+1)*b]).image_embeds.cpu())
                            image_embeddings1 = torch.cat(image_embeddings1, dim=0)
                            #np.savez(feature_path1, image_embeddings1, duration)
                        else:
                            data1 = np.load(feature_path1)
                            image_embeddings1 = torch.from_numpy(data1["arr_0"])
                        feature_path2 = feature_path.replace("/feature_mixed/", "/feature_clip_vit2/")
                        if not os.path.exists(feature_path2):
                            image_embeddings2 = []
                            for i in range(math.ceil(images2["pixel_values"].shape[0] / b)):
                                image_embeddings2.append(self.image_encoder2(pixel_values=images2["pixel_values"][i*b: (i+1)*b]).image_embeds.cpu())
                            image_embeddings2 = torch.cat(image_embeddings2, dim=0)
                            np.savez(feature_path2, image_embeddings2, duration)
                        else:
                            data2 = np.load(feature_path2)
                            image_embeddings2 = torch.from_numpy(data2["arr_0"])
                        feature_path3 = feature_path.replace("/feature_mixed/", "/feature_clip_convnext/")
                        if not os.path.exists(feature_path3):
                            image_embeddings3 = []
                            for i in range(math.ceil(images3.shape[0] / b)):
                                image_embeddings3.append(self.image_encoder3.encode_image(images3[i*b: (i+1)*b]).cpu())
                            image_embeddings3 = torch.cat(image_embeddings3, dim=0)
                            np.savez(feature_path3, image_embeddings3, duration)
                        else:
                            data3 = np.load(feature_path3)
                            image_embeddings3 = torch.from_numpy(data3["arr_0"])
                        feature_path4 = feature_path.replace("/feature_mixed/", "/feature_dinov2/")
                        if not os.path.exists(feature_path4):
                            image_embeddings4 = []
                            for i in range(math.ceil(images4["pixel_values"].shape[0] / b)):
                                image_embeddings4.append(self.image_encoder4(pixel_values=images4["pixel_values"][i*b: (i+1)*b]).pooler_output.cpu())
                            image_embeddings4 = torch.cat(image_embeddings4, dim=0)
                            np.savez(feature_path4, image_embeddings4, duration)
                        else:
                            data4 = np.load(feature_path4)
                            image_embeddings4 = torch.from_numpy(data4["arr_0"])
                        # concatenate the four embeddings per frame, truncated to the shortest stream
                        mixed_l = min(image_embeddings1.shape[0], image_embeddings2.shape[0], image_embeddings3.shape[0], image_embeddings4.shape[0])
                        for i in range(mixed_l):
                            image_embeddings.append(torch.cat([image_embeddings1[i:i+1, :], image_embeddings2[i:i+1, :], image_embeddings3[i:i+1, :], image_embeddings4[i:i+1, :]], dim=1))
                    else:
                        raise Exception("Invalid video_encoder " + self.video_encoder)
                    image_embeddings = torch.cat(image_embeddings, dim=0)
                    np.savez(feature_path, image_embeddings, duration)
                else:
                    data = np.load(feature_path)
                    image_embeddings = torch.from_numpy(data["arr_0"])
                    duration = data["arr_1"].item()
                if max_sample is None:
                    max_sample = int(duration * self.sampling_rate)
                # nearest-neighbour resample: for the centre of each mel frame, pick the
                # closest video frame's embedding, capped at the mel length `l`
                interpolated = []
                for i in range(start_sample, max_sample, self.frame_size):
                    j = min(round((i + self.frame_size // 2) / self.sampling_rate / (duration / (image_embeddings.shape[0] - 1))), image_embeddings.shape[0] - 1)
                    interpolated.append(image_embeddings[j:j+1])
                    if len(interpolated) >= l:
                        break
                interpolated = torch.cat(interpolated, dim=0)
                video_embeddings.append(interpolated.unsqueeze(0))
                video_lens.append(interpolated.shape[0])  # number of frames (was .shape[1], the feature dim)
            # pad every clip to the mel length, substituting zeros for missing videos
            max_length = l
            for i in range(len(video_embeddings)):
                if video_embeddings[i] is None:
                    video_embeddings[i] = torch.zeros(1, max_length, d)
                    continue
                if video_embeddings[i].shape[1] < max_length:
                    video_embeddings[i] = torch.cat([video_embeddings[i], torch.zeros(1, max_length - video_embeddings[i].shape[1], d)], 1)
            video_embeddings = torch.cat(video_embeddings, 0)
        return video_embeddings.to(device)
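    # transformer_with_pred_head runs the backbone once. For classifier-free guidance
    # training it independently drops, per sample, the audio infill condition
    # (audiocond_drop_prob), the frame-aligned text/video embedding (cond_drop_prob),
    # and the cross-attention text prompt (prompt_drop_prob); the realized drops can
    # be returned so that a second model (e.g. an EMA copy) can replay them exactly.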
    def transformer_with_pred_head(
        self,
        x: Float['b n d'],
        cond: Float['b n d'] | None = None,
        times: Float['b'] | None = None,
        mask: Bool['b n'] | None = None,
        text: Int['b nt'] | Float['b nt dt'] | None = None,
        prompt = None,
        video_drop_prompt = None,
        audio_drop_prompt = None,
        drop_audio_cond: bool | None = None,
        drop_text_cond: bool | None = None,
        drop_text_prompt: bool | None = None,
        return_drop_conditions = False
    ):
        seq_len = x.shape[-2]
        bs = x.shape[0]
        # per-sample condition dropout; callers may also pass precomputed per-sample
        # lists (e.g. to replay the same drops on an EMA model), which must not be
        # re-wrapped by the list comprehensions below
        if not isinstance(drop_audio_cond, list):
            drop_audio_cond = [default(drop_audio_cond, self.training and random() < self.audiocond_drop_prob) for _ in range(bs)]
        drop_text_cond = default(drop_text_cond, self.training and random() < self.cond_drop_prob)
        if not isinstance(drop_text_prompt, list):
            drop_text_prompt = [default(drop_text_prompt, self.training and random() < self.prompt_drop_prob) for _ in range(bs)]
        if cond is not None:
            for b in range(bs):
                if drop_audio_cond[b]:
                    cond[b] = 0
                if audio_drop_prompt is not None and audio_drop_prompt[b]:
                    cond[b] = 0
            if self.concat_cond:
                # concat condition, given as using voicebox-like scheme
                x = torch.cat((cond, x), dim = -1)
        x = self.proj_in(x)
        if cond is not None and not self.concat_cond:
            # an alternative is to simply sum the condition; seems to work fine
            cond = self.cond_proj_in(cond)
            x = x + cond
        # whether to use a text embedding
        text_embed = None
        if exists(text) and text.ndim == 3:
            # already-embedded conditioning (e.g. video features from encode_video)
            text_embed = text.clone()
            if drop_text_cond:
                for b in range(bs):
                    text_embed[b] = 0
        elif exists(text) and not drop_text_cond:
            text_embed = self.embed_text(text, seq_len, mask = mask)
        # cross-attention context from the text prompt
        context, context_mask = None, None
        if prompt is not None:
            if video_drop_prompt is not None:
                for b in range(bs):
                    if video_drop_prompt[b]:
                        prompt[b] = "the sound of X X"
            context, context_mask = self.encode_text(prompt)
            for b in range(bs):
                if drop_text_prompt[b]:
                    context[b] = 0
                if video_drop_prompt is not None and video_drop_prompt[b]:
                    context[b] = 0
        if self.proj_text is not None and text_embed is not None:
            text_embed = self.proj_text(text_embed)
        # attend
        attended, loss_contra = self.transformer(
            x,
            times = times,
            mask = mask,
            text_embed = text_embed,
            context = context,
            context_mask = context_mask
        )
        pred = self.to_pred(attended)
        if not return_drop_conditions:
            return pred, loss_contra
        return pred, loss_contra, drop_audio_cond, drop_text_cond, drop_text_prompt
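    # classifier-free guidance at sampling time:
    #   out = pred + cfg_strength * (pred - null_pred)
    # where null_pred is computed with all conditioning dropped. Optionally the
    # component of the guidance update parallel to pred is removed (a fraction
    # keep_parallel_frac is kept), following https://arxiv.org/abs/2410.02416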
    def cfg_transformer_with_pred_head(
        self,
        *args,
        cfg_strength: float = 1.,
        remove_parallel_component: bool = True,
        keep_parallel_frac: float = 0.,
        **kwargs,
    ):
        pred, _ = self.transformer_with_pred_head(*args, drop_audio_cond = False, drop_text_cond = False, drop_text_prompt = False, **kwargs)
        if cfg_strength < 1e-5:
            return pred
        # fully unconditional prediction for the guidance direction
        null_pred, _ = self.transformer_with_pred_head(*args, drop_audio_cond = True, drop_text_cond = True, drop_text_prompt = True, **kwargs)
        cfg_update = pred - null_pred
        if remove_parallel_component:
            # https://arxiv.org/abs/2410.02416
            parallel, orthogonal = project(cfg_update, pred)
            cfg_update = orthogonal + parallel * keep_parallel_frac
        return pred + cfg_update * cfg_strength
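    # add_noise corrupts the audio condition at a target SNR: the noise is rescaled by
    #   w = mean|signal| / (mean|noise| * snr)
    # over the conditioned region, so that mean|signal| / mean|w * noise| == snr.
    # Training draws the SNR uniformly from the configured range; validation uses its midpoint.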
    def add_noise(self, signal, mask, val):
        if self.audiocond_snr is None:
            return signal
        if not val:
            snr = np.random.uniform(self.audiocond_snr[0], self.audiocond_snr[1])
        else:
            # deterministic mid-range SNR for validation
            snr = (self.audiocond_snr[0] + self.audiocond_snr[1]) / 2.0
        # mask marks the conditioned region, e.g. [True, ..., False]
        noise = torch.randn_like(signal)
        w = torch.abs(signal[mask]).mean() / (torch.abs(noise[mask]).mean() + 1e-6) / snr
        return signal + noise * w
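    # sample() integrates the learned flow from Gaussian noise to a mel spectrogram
    # with torchdiffeq's odeint, optionally warping the time grid (sway sampling) and
    # applying classifier-free guidance at every step, then vocodes the result.
    # A minimal usage sketch (hypothetical paths and values, assuming a constructed model):
    #
    #   audio = model.sample(
    #       cond = mel,                          # Float['b n d'] mel, or Float['b nw'] raw wave
    #       duration = 1000,                     # target length in mel frames
    #       prompt = ["dog barking"],            # text prompt for cross attention
    #       video_paths = ["/path/clip.mp4"],    # optional frame-aligned conditioning
    #       save_to_filename = "out.wav",
    #   )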
    def sample(
        self,
        cond: Float['b n d'] | Float['b nw'] | None = None,
        *,
        text: Int['b nt'] | list[str] | None = None,
        lens: Int['b'] | None = None,
        duration: int | Int['b'] | None = None,
        steps = 32,
        cfg_strength = 1.,   # E2 TTS used a classifier-free guidance strength of 1.
        remove_parallel_component = True,
        sway_sampling = True,
        max_duration = 4096, # in case the duration predictor goes haywire
        vocoder: Callable[[Float['b d n']], list[Float['_']]] | None = None,
        return_raw_output: bool | None = None,
        save_to_filename: str | None = None,
        prompt = None,
        video_drop_prompt = None,
        audio_drop_prompt = None,
        video_paths = None
    ) -> (
        Float['b n d'],
        list[Float['_']]
    ):
        self.eval()
        # raw wave
        if cond.ndim == 2:
            cond = self.mel_spec(cond)
            cond = rearrange(cond, 'b d n -> b n d')
        assert cond.shape[-1] == self.num_channels
        batch, cond_seq_len, device = *cond.shape[:2], cond.device
        if not exists(lens):
            lens = torch.full((batch,), cond_seq_len, device = device, dtype = torch.long)
        # text, either video features or tokenized strings
        if video_paths is not None:
            text = self.encode_video(video_paths, cond_seq_len)
        elif isinstance(text, list):
            text = self.tokenizer(text).to(device)
            assert text.shape[0] == batch
        # the character-length check only makes sense for token ids, not for
        # 3d video/embedding conditioning (where it would produce the wrong shape)
        if exists(text) and text.ndim == 2:
            text_lens = (text != -1).sum(dim = -1)
            lens = torch.maximum(text_lens, lens) # make sure lengths are at least those of the text characters
        # duration
        cond_mask = lens_to_mask(lens)
        if exists(duration):
            if isinstance(duration, int):
                duration = torch.full((batch,), duration, device = device, dtype = torch.long)
        elif exists(self.duration_predictor):
            duration = self.duration_predictor(cond, text = text, lens = lens, return_loss = False).long()
        duration = torch.maximum(lens, duration) # the duration must at least cover the conditioning
        duration = duration.clamp(max = max_duration)
        assert duration.shape[0] == batch
        max_duration = duration.amax()
        # pad the conditioning and masks out to the target duration
        cond = F.pad(cond, (0, 0, 0, max_duration - cond_seq_len), value = 0.)
        cond_mask = F.pad(cond_mask, (0, max_duration - cond_mask.shape[-1]), value = False)
        cond_mask = rearrange(cond_mask, '... -> ... 1')
        mask = lens_to_mask(duration)
        # neural ode
        def fn(t, x):
            # at each step, conditioning is fixed
            if lens[0] == duration[0]:
                # nothing to infill beyond the conditioning, so drop it entirely
                step_cond = None
            else:
                step_cond = torch.where(cond_mask, self.add_noise(cond, cond_mask, True), torch.zeros_like(cond))
            # predict flow
            return self.cfg_transformer_with_pred_head(
                x,
                step_cond,
                times = t,
                text = text,
                mask = mask,
                prompt = prompt,
                video_drop_prompt = video_drop_prompt,
                audio_drop_prompt = audio_drop_prompt,
                cfg_strength = cfg_strength,
                remove_parallel_component = remove_parallel_component
            )
        y0 = torch.randn_like(cond)
        t = torch.linspace(0, 1, steps, device = self.device)
        if sway_sampling:
            # sway sampling with coefficient -1: t <- t + s * (cos(pi/2 * t) - 1 + t)
            # warps the grid toward 0, spending more integration steps early in the flow
            t = t + -1.0 * (torch.cos(torch.pi / 2 * t) - 1 + t)
        trajectory = odeint(fn, y0, t, **self.odeint_kwargs)
        sampled = trajectory[-1]
        out = sampled
        if lens[0] != duration[0]:
            # paste the known conditioning back over its span
            out = torch.where(cond_mask, cond, out)
        # able to return raw untransformed output, if not using mel rep
        if exists(return_raw_output) and return_raw_output:
            return out
        # take care of transforming mel to audio if `vocoder` is passed in, or if `use_vocos` is turned on
        if exists(vocoder):
            assert not exists(self.vocos), '`use_vocos` should not be turned on if you are passing in a custom `vocoder` on sampling'
            out = rearrange(out, 'b n d -> b d n')
            out = vocoder(out)
        elif exists(self.vocos):
            audio = []
            for mel, one_mask in zip(out, mask):
                one_out = mel[one_mask]
                one_out = rearrange(one_out, 'n d -> 1 d n')
                one_audio = self.vocos.decode(one_out)
                one_audio = rearrange(one_audio, '1 nw -> nw')
                audio.append(one_audio)
            out = audio
        if exists(save_to_filename):
            assert exists(vocoder) or exists(self.vocos)
            assert exists(self.sampling_rate)
            path = Path(save_to_filename)
            parent_path = path.parents[0]
            parent_path.mkdir(exist_ok = True, parents = True)
            for ind, one_audio in enumerate(out):
                one_audio = rearrange(one_audio, 'nw -> 1 nw')
                if len(out) == 1:
                    save_path = str(parent_path / f'{path.name}')
                else:
                    save_path = str(parent_path / f'{ind + 1}.{path.name}')
                torchaudio.save(save_path, one_audio.detach().cpu(), sample_rate = self.sampling_rate)
        return out
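    # forward computes the conditional flow matching loss. With x0 ~ N(0, I) and x1 the
    # target mel, a point on the straight path is w = (1 - t) * x0 + t * x1 and the
    # regression target is the constant velocity flow = x1 - x0; the MSE is taken only
    # over a randomly masked span, so the model learns to infill audio given the rest.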
    def forward(
        self,
        inp: Float['b n d'] | Float['b nw'], # mel or raw wave
        *,
        text: Int['b nt'] | list[str] | None = None,
        times: int | Int['b'] | None = None,
        lens: Int['b'] | None = None,
        velocity_consistency_model: E2TTS | None = None,
        velocity_consistency_delta = 1e-3,
        prompt = None,
        video_drop_prompt = None,
        audio_drop_prompt = None,
        val = False,
        video_paths = None
    ):
        need_velocity_loss = exists(velocity_consistency_model) and self.velocity_consistency_weight > 0.
        # handle raw wave
        if inp.ndim == 2:
            inp = self.mel_spec(inp)
            inp = rearrange(inp, 'b d n -> b n d')
        assert inp.shape[-1] == self.num_channels
        batch, seq_len, dtype, device = *inp.shape[:2], inp.dtype, self.device
        # conditioning: video features take precedence over tokenized text
        if video_paths is not None:
            text = self.encode_video(video_paths, seq_len)
        elif isinstance(text, list):
            text = self.tokenizer(text).to(device)
            assert text.shape[0] == batch
        # lens and mask
        if not exists(lens):
            lens = torch.full((batch,), seq_len, device = device)
        mask = lens_to_mask(lens, length = seq_len)
        # get a random span to mask out for training conditionally
        if not val:
            if self.audiocond_drop_prob > 1.0:
                # audio conditioning disabled entirely: mask the whole sequence
                frac_lengths = torch.zeros((batch,), device = self.device).float().uniform_(1.0, 1.0)
            else:
                frac_lengths = torch.zeros((batch,), device = self.device).float().uniform_(*self.frac_lengths_mask)
        else:
            # deterministic mid-range masking fraction for validation
            frac_lengths = torch.tensor([(0.7 + 1.0) / 2.0] * batch, device = self.device).float()
        rand_span_mask = mask_from_frac_lengths(lens, frac_lengths, max_length = seq_len, val = val)
        if exists(mask):
            rand_span_mask &= mask
        # mel is x1
        x1 = inp
        # x0 is gaussian noise; the seed is fixed during validation for reproducibility,
        # then restored from the clock afterwards
        if val:
            torch.manual_seed(0)
        x0 = torch.randn_like(x1)
        if val:
            torch.manual_seed(int(time.time() * 1000))
        # t is random times from above
        if times is None:
            times = torch.rand((batch,), dtype = dtype, device = self.device)
        else:
            times = torch.tensor((times,) * batch, dtype = dtype, device = self.device)
        t = rearrange(times, 'b -> b 1 1')
        # if need velocity consistency, make sure time does not exceed 1.
        if need_velocity_loss:
            t = t * (1. - velocity_consistency_delta)
        # sample xt (w in the paper) along the straight path, and its target flow
        w = (1. - t) * x0 + t * x1
        flow = x1 - x0
        # only predict what is within the random mask span for infilling
        if self.audiocond_drop_prob > 1.0:
            cond = None
        else:
            cond = einx.where(
                'b n, b n d, b n d -> b n d',
                rand_span_mask,
                torch.zeros_like(x1), self.add_noise(x1, ~rand_span_mask, val)
            )
        # transformer and prediction head; validation disables all random condition dropout
        if not val:
            pred, loss_contra, did_drop_audio_cond, did_drop_text_cond, did_drop_text_prompt = self.transformer_with_pred_head(
                w,
                cond,
                times = times,
                text = text,
                mask = mask,
                prompt = prompt,
                video_drop_prompt = video_drop_prompt,
                audio_drop_prompt = audio_drop_prompt,
                return_drop_conditions = True
            )
        else:
            pred, loss_contra, did_drop_audio_cond, did_drop_text_cond, did_drop_text_prompt = self.transformer_with_pred_head(
                w,
                cond,
                times = times,
                text = text,
                mask = mask,
                prompt = prompt,
                video_drop_prompt = video_drop_prompt,
                audio_drop_prompt = audio_drop_prompt,
                drop_audio_cond = False,
                drop_text_cond = False,
                drop_text_prompt = False,
                return_drop_conditions = True
            )
        # maybe velocity consistency loss
        velocity_loss = self.zero
        if need_velocity_loss:
            t_with_delta = t + velocity_consistency_delta
            w_with_delta = (1. - t_with_delta) * x0 + t_with_delta * x1
            with torch.no_grad():
                # replay the exact same condition drops on the consistency (EMA) model
                ema_pred, _ = velocity_consistency_model.transformer_with_pred_head(
                    w_with_delta,
                    cond,
                    times = times + velocity_consistency_delta,
                    text = text,
                    mask = mask,
                    prompt = prompt,
                    video_drop_prompt = video_drop_prompt,
                    audio_drop_prompt = audio_drop_prompt,
                    drop_audio_cond = did_drop_audio_cond,
                    drop_text_cond = did_drop_text_cond,
                    drop_text_prompt = did_drop_text_prompt
                )
            velocity_loss = F.mse_loss(pred, ema_pred, reduction = 'none')
            velocity_loss = velocity_loss[rand_span_mask].mean()
        # flow matching loss, restricted to the infilled span
        loss = F.mse_loss(pred, flow, reduction = 'none')
        loss = loss[rand_span_mask].mean()
        # total loss and breakdown; the velocity consistency term is currently not used,
        # and the second breakdown slot carries the transformer's weighted contrastive loss
        contra_weight = 1.0
        total_loss = loss + loss_contra * contra_weight
        breakdown = LossBreakdown(loss, loss_contra * contra_weight)
        # return total loss and a bunch of intermediates; x0 + pred approximates the data at t = 1
        return E2TTSReturn(total_loss, cond if cond is not None else w, pred, x0 + pred, breakdown)
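    # A minimal training-step sketch (hypothetical `model`, `mels`, `prompts`, `paths`):
    #
    #   out = model(mels, prompt = prompts, video_paths = paths)
    #   out.loss.backward()
    #
    # out.loss_breakdown.flow is the flow matching term; the second slot of the
    # LossBreakdown namedtuple carries the weighted contrastive loss here.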