"""
ein notation:
b - batch
n - sequence
nt - text sequence
nw - raw wave length
d - dimension
dt - dimension text
df - dimension frames
nc - context sequence
dc - context dimension
"""
from __future__ import annotations
from pathlib import Path
from random import random
from functools import partial
from itertools import zip_longest
from collections import namedtuple
from typing import Literal, Callable
import jaxtyping
from beartype import beartype
import torch
import torch.nn.functional as F
from torch import nn, tensor, Tensor, from_numpy
from torch.nn import Module, ModuleList, Sequential, Linear
from torch.nn.utils.rnn import pad_sequence
import torchaudio
from torchaudio.functional import DB_to_amplitude
from torchdiffeq import odeint
import einx
from einops.layers.torch import Rearrange
from einops import rearrange, repeat, reduce, pack, unpack
from x_transformers import (
Attention,
FeedForward,
RMSNorm,
AdaptiveRMSNorm,
)
from x_transformers.x_transformers import RotaryEmbedding
import sys
sys.path.insert(0, "/zhanghaomin/codes3/vocos-main/")
from vocos import Vocos
from transformers import AutoTokenizer
from transformers import T5EncoderModel
from transformers import EncodecModel, AutoProcessor
sys.path.insert(0, "./src/audeo/")
import Video2RollNet
import torchvision.transforms as transforms
####transform = transforms.Compose([lambda x: x.resize((900,100)),
#### lambda x: np.reshape(x,(100,900,1)),
#### lambda x: np.transpose(x,[2,0,1]),
#### lambda x: x/255.])
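# Preprocess a PIL keyboard-crop frame for Video2RollNet: resize to 100x900 (W x H),
# then convert to a (1, 100, 900) float array scaled to [0, 1]. numpy (np) is imported
# further below; the lambdas only run at call time, so the late import is safe.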
transform = transforms.Compose([lambda x: x.resize((100,900)),
lambda x: np.reshape(x,(900,100,1)),
lambda x: np.transpose(x,[2,1,0]),
lambda x: x/255.])
####NOTES = 51
####NOTTE_MIN = 15
####NOTE_MAX = 65
NOTES = 88
NOTTE_MIN = 0  # 15
NOTE_MAX = 87  # 72
import os
import math
import traceback
import numpy as np
from moviepy.editor import AudioFileClip, VideoFileClip
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
#import open_clip
from transformers import AutoImageProcessor, AutoModel
from PIL import Image
import time
import warnings
warnings.filterwarnings("ignore")
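# Zero-center a waveform and scale it so the peak magnitude of its first channel is 0.5.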
def normalize_wav(waveform):
waveform = waveform - torch.mean(waveform)
waveform = waveform / (torch.max(torch.abs(waveform[0, :])) + 1e-8)
return waveform * 0.5
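# Decode all frames of a video with moviepy, optionally subsampling to max_frame_nums
# evenly spaced frames; returns (frames, duration), or (None, None) on failure.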
def read_frames_with_moviepy(video_path, max_frame_nums=None):
try:
clip = VideoFileClip(video_path)
duration = clip.duration
frames = []
for frame in clip.iter_frames():
frames.append(frame)
    except Exception:
print("Error read_frames_with_moviepy", video_path)
traceback.print_exc()
return None, None
if max_frame_nums is not None:
frames_idx = np.linspace(0, len(frames) - 1, max_frame_nums, dtype=int)
return np.array(frames)[frames_idx, ...], duration
else:
return np.array(frames), duration
pad_sequence = partial(pad_sequence, batch_first = True)
# constants
class TorchTyping:
def __init__(self, abstract_dtype):
self.abstract_dtype = abstract_dtype
def __getitem__(self, shapes: str):
return self.abstract_dtype[Tensor, shapes]
Float = TorchTyping(jaxtyping.Float)
Int = TorchTyping(jaxtyping.Int)
Bool = TorchTyping(jaxtyping.Bool)
# named tuples
LossBreakdown = namedtuple('LossBreakdown', ['flow', 'velocity_consistency', 'a', 'b'])
E2TTSReturn = namedtuple('E2TTS', ['loss', 'cond', 'pred_flow', 'pred_data', 'loss_breakdown'])
# helpers
def exists(v):
return v is not None
def default(v, d):
return v if exists(v) else d
def divisible_by(num, den):
return (num % den) == 0
def pack_one_with_inverse(x, pattern):
packed, packed_shape = pack([x], pattern)
def inverse(x, inverse_pattern = None):
inverse_pattern = default(inverse_pattern, pattern)
return unpack(x, packed_shape, inverse_pattern)[0]
return packed, inverse
class Identity(Module):
def forward(self, x, **kwargs):
return x
# tensor helpers
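# project: decompose x into its components parallel and orthogonal to y (flattened per
# batch); computed in float64 for numerical stability.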
def project(x, y):
x, inverse = pack_one_with_inverse(x, 'b *')
y, _ = pack_one_with_inverse(y, 'b *')
dtype = x.dtype
x, y = x.double(), y.double()
unit = F.normalize(y, dim = -1)
parallel = (x * unit).sum(dim = -1, keepdim = True) * unit
orthogonal = x - parallel
return inverse(parallel).to(dtype), inverse(orthogonal).to(dtype)
# simple utf-8 tokenizer, since paper went character based
def list_str_to_tensor(
text: list[str],
padding_value = -1
) -> Int['b nt']:
list_tensors = [tensor([*bytes(t, 'UTF-8')]) for t in text]
    padded_tensor = pad_sequence(list_tensors, padding_value = padding_value)
return padded_tensor
# simple english phoneme-based tokenizer
from g2p_en import G2p
import jieba
from pypinyin import lazy_pinyin, Style
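# English grapheme-to-phoneme tokenizer built on g2p_en, with a few punctuation/space
# symbols appended to the phoneme vocabulary. Returns (encode, vocab_size).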
def get_g2p_en_encode():
g2p = G2p()
# used by @lucasnewman successfully here
# https://github.com/lucasnewman/e2-tts-pytorch/blob/ljspeech-test/e2_tts_pytorch/e2_tts.py
phoneme_to_index = g2p.p2idx
num_phonemes = len(phoneme_to_index)
extended_chars = [' ', ',', '.', '-', '!', '?', '\'', '"', '...', '..', '. .', '. . .', '. . . .', '. . . . .', '. ...', '... .', '.. ..']
num_extended_chars = len(extended_chars)
extended_chars_dict = {p: (num_phonemes + i) for i, p in enumerate(extended_chars)}
phoneme_to_index = {**phoneme_to_index, **extended_chars_dict}
def encode(
text: list[str],
padding_value = -1
) -> Int['b nt']:
phonemes = [g2p(t) for t in text]
list_tensors = [tensor([phoneme_to_index[p] for p in one_phoneme]) for one_phoneme in phonemes]
        padded_tensor = pad_sequence(list_tensors, padding_value = padding_value)
return padded_tensor
return encode, (num_phonemes + num_extended_chars)
def all_en(word):
res = word.replace("'", "").encode('utf-8').isalpha()
return res
def all_ch(word):
res = True
for w in word:
if not '\u4e00' <= w <= '\u9fff':
res = False
return res
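# Chinese (and mixed Chinese/English) tokenizer: jieba word segmentation, pypinyin TONE3
# pinyin with tone sandhi for Chinese words, per-character letters for English, plus a
# small punctuation set; texts containing unsupported characters are dropped with a
# warning. Returns (encode, vocab_size).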
def get_g2p_zh_encode():
puncs = [',', '。', '?', '、']
pinyins = ['a', 'a1', 'ai1', 'ai2', 'ai3', 'ai4', 'an1', 'an3', 'an4', 'ang1', 'ang2', 'ang4', 'ao1', 'ao2', 'ao3', 'ao4', 'ba', 'ba1', 'ba2', 'ba3', 'ba4', 'bai1', 'bai2', 'bai3', 'bai4', 'ban1', 'ban2', 'ban3', 'ban4', 'bang1', 'bang2', 'bang3', 'bang4', 'bao1', 'bao2', 'bao3', 'bao4', 'bei', 'bei1', 'bei2', 'bei3', 'bei4', 'ben1', 'ben2', 'ben3', 'ben4', 'beng1', 'beng2', 'beng4', 'bi1', 'bi2', 'bi3', 'bi4', 'bian1', 'bian2', 'bian3', 'bian4', 'biao1', 'biao2', 'biao3', 'bie1', 'bie2', 'bie3', 'bie4', 'bin1', 'bin4', 'bing1', 'bing2', 'bing3', 'bing4', 'bo', 'bo1', 'bo2', 'bo3', 'bo4', 'bu2', 'bu3', 'bu4', 'ca1', 'cai1', 'cai2', 'cai3', 'cai4', 'can1', 'can2', 'can3', 'can4', 'cang1', 'cang2', 'cao1', 'cao2', 'cao3', 'ce4', 'cen1', 'cen2', 'ceng1', 'ceng2', 'ceng4', 'cha1', 'cha2', 'cha3', 'cha4', 'chai1', 'chai2', 'chan1', 'chan2', 'chan3', 'chan4', 'chang1', 'chang2', 'chang3', 'chang4', 'chao1', 'chao2', 'chao3', 'che1', 'che2', 'che3', 'che4', 'chen1', 'chen2', 'chen3', 'chen4', 'cheng1', 'cheng2', 'cheng3', 'cheng4', 'chi1', 'chi2', 'chi3', 'chi4', 'chong1', 'chong2', 'chong3', 'chong4', 'chou1', 'chou2', 'chou3', 'chou4', 'chu1', 'chu2', 'chu3', 'chu4', 'chua1', 'chuai1', 'chuai2', 'chuai3', 'chuai4', 'chuan1', 'chuan2', 'chuan3', 'chuan4', 'chuang1', 'chuang2', 'chuang3', 'chuang4', 'chui1', 'chui2', 'chun1', 'chun2', 'chun3', 'chuo1', 'chuo4', 'ci1', 'ci2', 'ci3', 'ci4', 'cong1', 'cong2', 'cou4', 'cu1', 'cu4', 'cuan1', 'cuan2', 'cuan4', 'cui1', 'cui3', 'cui4', 'cun1', 'cun2', 'cun4', 'cuo1', 'cuo2', 'cuo4', 'da', 'da1', 'da2', 'da3', 'da4', 'dai1', 'dai3', 'dai4', 'dan1', 'dan2', 'dan3', 'dan4', 'dang1', 'dang2', 'dang3', 'dang4', 'dao1', 'dao2', 'dao3', 'dao4', 'de', 'de1', 'de2', 'dei3', 'den4', 'deng1', 'deng2', 'deng3', 'deng4', 'di1', 'di2', 'di3', 'di4', 'dia3', 'dian1', 'dian2', 'dian3', 'dian4', 'diao1', 'diao3', 'diao4', 'die1', 'die2', 'ding1', 'ding2', 'ding3', 'ding4', 'diu1', 'dong1', 'dong3', 'dong4', 'dou1', 'dou2', 'dou3', 'dou4', 'du1', 'du2', 'du3', 'du4', 'duan1', 'duan2', 'duan3', 'duan4', 'dui1', 'dui4', 'dun1', 'dun3', 'dun4', 'duo1', 'duo2', 'duo3', 'duo4', 'e1', 'e2', 'e3', 'e4', 'ei2', 'en1', 'en4', 'er', 'er2', 'er3', 'er4', 'fa1', 'fa2', 'fa3', 'fa4', 'fan1', 'fan2', 'fan3', 'fan4', 'fang1', 'fang2', 'fang3', 'fang4', 'fei1', 'fei2', 'fei3', 'fei4', 'fen1', 'fen2', 'fen3', 'fen4', 'feng1', 'feng2', 'feng3', 'feng4', 'fo2', 'fou2', 'fou3', 'fu1', 'fu2', 'fu3', 'fu4', 'ga1', 'ga2', 'ga4', 'gai1', 'gai3', 'gai4', 'gan1', 'gan2', 'gan3', 'gan4', 'gang1', 'gang2', 'gang3', 'gang4', 'gao1', 'gao2', 'gao3', 'gao4', 'ge1', 'ge2', 'ge3', 'ge4', 'gei2', 'gei3', 'gen1', 'gen2', 'gen3', 'gen4', 'geng1', 'geng3', 'geng4', 'gong1', 'gong3', 'gong4', 'gou1', 'gou2', 'gou3', 'gou4', 'gu', 'gu1', 'gu2', 'gu3', 'gu4', 'gua1', 'gua2', 'gua3', 'gua4', 'guai1', 'guai2', 'guai3', 'guai4', 'guan1', 'guan2', 'guan3', 'guan4', 'guang1', 'guang2', 'guang3', 'guang4', 'gui1', 'gui2', 'gui3', 'gui4', 'gun3', 'gun4', 'guo1', 'guo2', 'guo3', 'guo4', 'ha1', 'ha2', 'ha3', 'hai1', 'hai2', 'hai3', 'hai4', 'han1', 'han2', 'han3', 'han4', 'hang1', 'hang2', 'hang4', 'hao1', 'hao2', 'hao3', 'hao4', 'he1', 'he2', 'he4', 'hei1', 'hen2', 'hen3', 'hen4', 'heng1', 'heng2', 'heng4', 'hong1', 'hong2', 'hong3', 'hong4', 'hou1', 'hou2', 'hou3', 'hou4', 'hu1', 'hu2', 'hu3', 'hu4', 'hua1', 'hua2', 'hua4', 'huai2', 'huai4', 'huan1', 'huan2', 'huan3', 'huan4', 'huang1', 'huang2', 'huang3', 'huang4', 'hui1', 'hui2', 'hui3', 'hui4', 'hun1', 'hun2', 'hun4', 'huo', 'huo1', 'huo2', 'huo3', 'huo4', 'ji1', 
'ji2', 'ji3', 'ji4', 'jia', 'jia1', 'jia2', 'jia3', 'jia4', 'jian1', 'jian2', 'jian3', 'jian4', 'jiang1', 'jiang2', 'jiang3', 'jiang4', 'jiao1', 'jiao2', 'jiao3', 'jiao4', 'jie1', 'jie2', 'jie3', 'jie4', 'jin1', 'jin2', 'jin3', 'jin4', 'jing1', 'jing2', 'jing3', 'jing4', 'jiong3', 'jiu1', 'jiu2', 'jiu3', 'jiu4', 'ju1', 'ju2', 'ju3', 'ju4', 'juan1', 'juan2', 'juan3', 'juan4', 'jue1', 'jue2', 'jue4', 'jun1', 'jun4', 'ka1', 'ka2', 'ka3', 'kai1', 'kai2', 'kai3', 'kai4', 'kan1', 'kan2', 'kan3', 'kan4', 'kang1', 'kang2', 'kang4', 'kao2', 'kao3', 'kao4', 'ke1', 'ke2', 'ke3', 'ke4', 'ken3', 'keng1', 'kong1', 'kong3', 'kong4', 'kou1', 'kou2', 'kou3', 'kou4', 'ku1', 'ku2', 'ku3', 'ku4', 'kua1', 'kua3', 'kua4', 'kuai3', 'kuai4', 'kuan1', 'kuan2', 'kuan3', 'kuang1', 'kuang2', 'kuang4', 'kui1', 'kui2', 'kui3', 'kui4', 'kun1', 'kun3', 'kun4', 'kuo4', 'la', 'la1', 'la2', 'la3', 'la4', 'lai2', 'lai4', 'lan2', 'lan3', 'lan4', 'lang1', 'lang2', 'lang3', 'lang4', 'lao1', 'lao2', 'lao3', 'lao4', 'le', 'le1', 'le4', 'lei', 'lei1', 'lei2', 'lei3', 'lei4', 'leng1', 'leng2', 'leng3', 'leng4', 'li', 'li1', 'li2', 'li3', 'li4', 'lia3', 'lian2', 'lian3', 'lian4', 'liang2', 'liang3', 'liang4', 'liao1', 'liao2', 'liao3', 'liao4', 'lie1', 'lie2', 'lie3', 'lie4', 'lin1', 'lin2', 'lin3', 'lin4', 'ling2', 'ling3', 'ling4', 'liu1', 'liu2', 'liu3', 'liu4', 'long1', 'long2', 'long3', 'long4', 'lou1', 'lou2', 'lou3', 'lou4', 'lu1', 'lu2', 'lu3', 'lu4', 'luan2', 'luan3', 'luan4', 'lun1', 'lun2', 'lun4', 'luo1', 'luo2', 'luo3', 'luo4', 'lv2', 'lv3', 'lv4', 'lve3', 'lve4', 'ma', 'ma1', 'ma2', 'ma3', 'ma4', 'mai2', 'mai3', 'mai4', 'man2', 'man3', 'man4', 'mang2', 'mang3', 'mao1', 'mao2', 'mao3', 'mao4', 'me', 'mei2', 'mei3', 'mei4', 'men', 'men1', 'men2', 'men4', 'meng1', 'meng2', 'meng3', 'meng4', 'mi1', 'mi2', 'mi3', 'mi4', 'mian2', 'mian3', 'mian4', 'miao1', 'miao2', 'miao3', 'miao4', 'mie1', 'mie4', 'min2', 'min3', 'ming2', 'ming3', 'ming4', 'miu4', 'mo1', 'mo2', 'mo3', 'mo4', 'mou1', 'mou2', 'mou3', 'mu2', 'mu3', 'mu4', 'n2', 'na1', 'na2', 'na3', 'na4', 'nai2', 'nai3', 'nai4', 'nan1', 'nan2', 'nan3', 'nan4', 'nang1', 'nang2', 'nao1', 'nao2', 'nao3', 'nao4', 'ne', 'ne2', 'ne4', 'nei3', 'nei4', 'nen4', 'neng2', 'ni1', 'ni2', 'ni3', 'ni4', 'nian1', 'nian2', 'nian3', 'nian4', 'niang2', 'niang4', 'niao2', 'niao3', 'niao4', 'nie1', 'nie4', 'nin2', 'ning2', 'ning3', 'ning4', 'niu1', 'niu2', 'niu3', 'niu4', 'nong2', 'nong4', 'nou4', 'nu2', 'nu3', 'nu4', 'nuan3', 'nuo2', 'nuo4', 'nv2', 'nv3', 'nve4', 'o1', 'o2', 'ou1', 'ou3', 'ou4', 'pa1', 'pa2', 'pa4', 'pai1', 'pai2', 'pai3', 'pai4', 'pan1', 'pan2', 'pan4', 'pang1', 'pang2', 'pang4', 'pao1', 'pao2', 'pao3', 'pao4', 'pei1', 'pei2', 'pei4', 'pen1', 'pen2', 'pen4', 'peng1', 'peng2', 'peng3', 'peng4', 'pi1', 'pi2', 'pi3', 'pi4', 'pian1', 'pian2', 'pian4', 'piao1', 'piao2', 'piao3', 'piao4', 'pie1', 'pie2', 'pie3', 'pin1', 'pin2', 'pin3', 'pin4', 'ping1', 'ping2', 'po1', 'po2', 'po3', 'po4', 'pou1', 'pu1', 'pu2', 'pu3', 'pu4', 'qi1', 'qi2', 'qi3', 'qi4', 'qia1', 'qia3', 'qia4', 'qian1', 'qian2', 'qian3', 'qian4', 'qiang1', 'qiang2', 'qiang3', 'qiang4', 'qiao1', 'qiao2', 'qiao3', 'qiao4', 'qie1', 'qie2', 'qie3', 'qie4', 'qin1', 'qin2', 'qin3', 'qin4', 'qing1', 'qing2', 'qing3', 'qing4', 'qiong1', 'qiong2', 'qiu1', 'qiu2', 'qiu3', 'qu1', 'qu2', 'qu3', 'qu4', 'quan1', 'quan2', 'quan3', 'quan4', 'que1', 'que2', 'que4', 'qun2', 'ran2', 'ran3', 'rang1', 'rang2', 'rang3', 'rang4', 'rao2', 'rao3', 'rao4', 're2', 're3', 're4', 'ren2', 'ren3', 'ren4', 'reng1', 'reng2', 'ri4', 'rong1', 'rong2', 
'rong3', 'rou2', 'rou4', 'ru2', 'ru3', 'ru4', 'ruan2', 'ruan3', 'rui3', 'rui4', 'run4', 'ruo4', 'sa1', 'sa2', 'sa3', 'sa4', 'sai1', 'sai4', 'san1', 'san2', 'san3', 'san4', 'sang1', 'sang3', 'sang4', 'sao1', 'sao2', 'sao3', 'sao4', 'se4', 'sen1', 'seng1', 'sha1', 'sha2', 'sha3', 'sha4', 'shai1', 'shai2', 'shai3', 'shai4', 'shan1', 'shan3', 'shan4', 'shang', 'shang1', 'shang3', 'shang4', 'shao1', 'shao2', 'shao3', 'shao4', 'she1', 'she2', 'she3', 'she4', 'shei2', 'shen1', 'shen2', 'shen3', 'shen4', 'sheng1', 'sheng2', 'sheng3', 'sheng4', 'shi', 'shi1', 'shi2', 'shi3', 'shi4', 'shou1', 'shou2', 'shou3', 'shou4', 'shu1', 'shu2', 'shu3', 'shu4', 'shua1', 'shua2', 'shua3', 'shua4', 'shuai1', 'shuai3', 'shuai4', 'shuan1', 'shuan4', 'shuang1', 'shuang3', 'shui2', 'shui3', 'shui4', 'shun3', 'shun4', 'shuo1', 'shuo4', 'si1', 'si2', 'si3', 'si4', 'song1', 'song3', 'song4', 'sou1', 'sou3', 'sou4', 'su1', 'su2', 'su4', 'suan1', 'suan4', 'sui1', 'sui2', 'sui3', 'sui4', 'sun1', 'sun3', 'suo', 'suo1', 'suo2', 'suo3', 'ta1', 'ta3', 'ta4', 'tai1', 'tai2', 'tai4', 'tan1', 'tan2', 'tan3', 'tan4', 'tang1', 'tang2', 'tang3', 'tang4', 'tao1', 'tao2', 'tao3', 'tao4', 'te4', 'teng2', 'ti1', 'ti2', 'ti3', 'ti4', 'tian1', 'tian2', 'tian3', 'tiao1', 'tiao2', 'tiao3', 'tiao4', 'tie1', 'tie2', 'tie3', 'tie4', 'ting1', 'ting2', 'ting3', 'tong1', 'tong2', 'tong3', 'tong4', 'tou', 'tou1', 'tou2', 'tou4', 'tu1', 'tu2', 'tu3', 'tu4', 'tuan1', 'tuan2', 'tui1', 'tui2', 'tui3', 'tui4', 'tun1', 'tun2', 'tun4', 'tuo1', 'tuo2', 'tuo3', 'tuo4', 'wa', 'wa1', 'wa2', 'wa3', 'wa4', 'wai1', 'wai3', 'wai4', 'wan1', 'wan2', 'wan3', 'wan4', 'wang1', 'wang2', 'wang3', 'wang4', 'wei1', 'wei2', 'wei3', 'wei4', 'wen1', 'wen2', 'wen3', 'wen4', 'weng1', 'weng4', 'wo1', 'wo3', 'wo4', 'wu1', 'wu2', 'wu3', 'wu4', 'xi1', 'xi2', 'xi3', 'xi4', 'xia1', 'xia2', 'xia4', 'xian1', 'xian2', 'xian3', 'xian4', 'xiang1', 'xiang2', 'xiang3', 'xiang4', 'xiao1', 'xiao2', 'xiao3', 'xiao4', 'xie1', 'xie2', 'xie3', 'xie4', 'xin1', 'xin2', 'xin4', 'xing1', 'xing2', 'xing3', 'xing4', 'xiong1', 'xiong2', 'xiu1', 'xiu3', 'xiu4', 'xu', 'xu1', 'xu2', 'xu3', 'xu4', 'xuan1', 'xuan2', 'xuan3', 'xuan4', 'xue1', 'xue2', 'xue3', 'xue4', 'xun1', 'xun2', 'xun4', 'ya', 'ya1', 'ya2', 'ya3', 'ya4', 'yan1', 'yan2', 'yan3', 'yan4', 'yang1', 'yang2', 'yang3', 'yang4', 'yao1', 'yao2', 'yao3', 'yao4', 'ye1', 'ye2', 'ye3', 'ye4', 'yi1', 'yi2', 'yi3', 'yi4', 'yin1', 'yin2', 'yin3', 'yin4', 'ying1', 'ying2', 'ying3', 'ying4', 'yo1', 'yong1', 'yong3', 'yong4', 'you1', 'you2', 'you3', 'you4', 'yu1', 'yu2', 'yu3', 'yu4', 'yuan1', 'yuan2', 'yuan3', 'yuan4', 'yue1', 'yue4', 'yun1', 'yun2', 'yun3', 'yun4', 'za1', 'za2', 'za3', 'zai1', 'zai3', 'zai4', 'zan1', 'zan2', 'zan3', 'zan4', 'zang1', 'zang4', 'zao1', 'zao2', 'zao3', 'zao4', 'ze2', 'ze4', 'zei2', 'zen3', 'zeng1', 'zeng4', 'zha1', 'zha2', 'zha3', 'zha4', 'zhai1', 'zhai2', 'zhai3', 'zhai4', 'zhan1', 'zhan2', 'zhan3', 'zhan4', 'zhang1', 'zhang2', 'zhang3', 'zhang4', 'zhao1', 'zhao2', 'zhao3', 'zhao4', 'zhe', 'zhe1', 'zhe2', 'zhe3', 'zhe4', 'zhen1', 'zhen2', 'zhen3', 'zhen4', 'zheng1', 'zheng2', 'zheng3', 'zheng4', 'zhi1', 'zhi2', 'zhi3', 'zhi4', 'zhong1', 'zhong2', 'zhong3', 'zhong4', 'zhou1', 'zhou2', 'zhou3', 'zhou4', 'zhu1', 'zhu2', 'zhu3', 'zhu4', 'zhua1', 'zhua2', 'zhua3', 'zhuai1', 'zhuai3', 'zhuai4', 'zhuan1', 'zhuan2', 'zhuan3', 'zhuan4', 'zhuang1', 'zhuang4', 'zhui1', 'zhui4', 'zhun1', 'zhun2', 'zhun3', 'zhuo1', 'zhuo2', 'zi', 'zi1', 'zi2', 'zi3', 'zi4', 'zong1', 'zong2', 'zong3', 'zong4', 'zou1', 'zou2', 'zou3', 'zou4', 'zu1', 
'zu2', 'zu3', 'zuan1', 'zuan3', 'zuan4', 'zui2', 'zui3', 'zui4', 'zun1', 'zuo1', 'zuo2', 'zuo3', 'zuo4']
ens = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', "'", ' ']
ens_U = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z']
phoneme_to_index = {}
num_phonemes = 0
for index, punc in enumerate(puncs):
phoneme_to_index[punc] = index + num_phonemes
num_phonemes += len(puncs)
for index, pinyin in enumerate(pinyins):
phoneme_to_index[pinyin] = index + num_phonemes
num_phonemes += len(pinyins)
for index, en in enumerate(ens):
phoneme_to_index[en] = index + num_phonemes
for index, en in enumerate(ens_U):
phoneme_to_index[en] = index + num_phonemes
num_phonemes += len(ens)
#print(num_phonemes, phoneme_to_index) # 1342
def encode(
text: list[str],
padding_value = -1
) -> Int['b nt']:
phonemes = []
for t in text:
one_phoneme = []
brk = False
for word in jieba.cut(t):
if all_ch(word):
seg = lazy_pinyin(word, style=Style.TONE3, tone_sandhi=True)
one_phoneme.extend(seg)
elif all_en(word):
for seg in word:
one_phoneme.append(seg)
elif word in [",", "。", "?", "、", "'", " "]:
one_phoneme.append(word)
else:
for ch in word:
if all_ch(ch):
seg = lazy_pinyin(ch, style=Style.TONE3, tone_sandhi=True)
one_phoneme.extend(seg)
elif all_en(ch):
for seg in ch:
one_phoneme.append(seg)
else:
brk = True
break
if brk:
break
if not brk:
phonemes.append(one_phoneme)
else:
print("Error Tokenized", t, list(jieba.cut(t)))
list_tensors = [tensor([phoneme_to_index[p] for p in one_phoneme]) for one_phoneme in phonemes]
        padded_tensor = pad_sequence(list_tensors, padding_value = padding_value)
return padded_tensor
return encode, num_phonemes
# tensor helpers
def log(t, eps = 1e-5):
return t.clamp(min = eps).log()
def lens_to_mask(
t: Int['b'],
length: int | None = None
) -> Bool['b n']:
if not exists(length):
length = t.amax()
seq = torch.arange(length, device = t.device)
return einx.less('n, b -> b n', seq, t)
def mask_from_start_end_indices(
seq_len: Int['b'],
start: Int['b'],
end: Int['b']
):
max_seq_len = seq_len.max().item()
seq = torch.arange(max_seq_len, device = start.device).long()
return einx.greater_equal('n, b -> b n', seq, start) & einx.less('n, b -> b n', seq, end)
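# Sample a contiguous boolean mask covering frac_lengths of each sequence; the start is
# random during training and deterministic (rand = 0.5) when val=True, and the result is
# optionally padded to max_length.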
def mask_from_frac_lengths(
seq_len: Int['b'],
frac_lengths: Float['b'],
max_length: int | None = None,
val = False
):
lengths = (frac_lengths * seq_len).long()
max_start = seq_len - lengths
if not val:
rand = torch.rand_like(frac_lengths)
else:
rand = torch.tensor([0.5]*frac_lengths.shape[0], device=frac_lengths.device).float()
start = (max_start * rand).long().clamp(min = 0)
end = start + lengths
out = mask_from_start_end_indices(seq_len, start, end)
if exists(max_length):
out = pad_to_length(out, max_length)
return out
def maybe_masked_mean(
t: Float['b n d'],
mask: Bool['b n'] | None = None
) -> Float['b d']:
if not exists(mask):
return t.mean(dim = 1)
t = einx.where('b n, b n d, -> b n d', mask, t, 0.)
num = reduce(t, 'b n d -> b d', 'sum')
den = reduce(mask.float(), 'b n -> b', 'sum')
return einx.divide('b d, b -> b d', num, den.clamp(min = 1.))
def pad_to_length(
t: Tensor,
length: int,
value = None
):
seq_len = t.shape[-1]
if length > seq_len:
t = F.pad(t, (0, length - seq_len), value = value)
return t[..., :length]
def interpolate_1d(
x: Tensor,
length: int,
mode = 'bilinear'
):
x = rearrange(x, 'n d -> 1 d n 1')
x = F.interpolate(x, (length, 1), mode = mode)
return rearrange(x, '1 d n 1 -> n d')
# to mel spec
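# MelSpec: log-compressed mel spectrogram front end (torchaudio MelSpectrogram followed
# by a clamped log); lazily moves itself to the input's device.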
class MelSpec(Module):
def __init__(
self,
filter_length = 1024,
hop_length = 256,
win_length = 1024,
n_mel_channels = 100,
sampling_rate = 24_000,
normalize = False,
power = 1,
norm = None,
center = True,
):
super().__init__()
self.n_mel_channels = n_mel_channels
self.sampling_rate = sampling_rate
self.mel_stft = torchaudio.transforms.MelSpectrogram(
sample_rate = sampling_rate,
n_fft = filter_length,
win_length = win_length,
hop_length = hop_length,
n_mels = n_mel_channels,
power = power,
center = center,
normalized = normalize,
norm = norm,
)
self.register_buffer('dummy', tensor(0), persistent = False)
def forward(self, inp):
if len(inp.shape) == 3:
inp = rearrange(inp, 'b 1 nw -> b nw')
assert len(inp.shape) == 2
if self.dummy.device != inp.device:
self.to(inp.device)
mel = self.mel_stft(inp)
mel = log(mel)
return mel
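# Frozen EnCodec wrapper: forward() returns the continuous encoder latents of a waveform,
# decode() reconstructs audio from latents.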
class EncodecWrapper(Module):
def __init__(self, path):
super().__init__()
self.model = EncodecModel.from_pretrained(path)
self.processor = AutoProcessor.from_pretrained(path)
for param in self.model.parameters():
param.requires_grad = False
self.model.eval()
def forward(self, waveform):
with torch.no_grad():
inputs = self.processor(raw_audio=waveform[0], sampling_rate=self.processor.sampling_rate, return_tensors="pt")
emb = self.model.encoder(inputs.input_values)
return emb
def decode(self, emb):
with torch.no_grad():
output = self.model.decoder(emb)
return output[0]
from audioldm.audio.stft import TacotronSTFT
from audioldm.variational_autoencoder import AutoencoderKL
from audioldm.utils import default_audioldm_config, get_metadata
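# Load the frozen AudioLDM VAE and Tacotron STFT for the given checkpoint name; the
# scale_factor stored in the checkpoint is injected into the VAE config.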
def build_pretrained_models(name):
checkpoint = torch.load(get_metadata()[name]["path"], map_location="cpu")
scale_factor = checkpoint["state_dict"]["scale_factor"].item()
vae_state_dict = {k[18:]: v for k, v in checkpoint["state_dict"].items() if "first_stage_model." in k}
config = default_audioldm_config(name)
vae_config = config["model"]["params"]["first_stage_config"]["params"]
vae_config["scale_factor"] = scale_factor
vae = AutoencoderKL(**vae_config)
vae.load_state_dict(vae_state_dict)
fn_STFT = TacotronSTFT(
config["preprocessing"]["stft"]["filter_length"],
config["preprocessing"]["stft"]["hop_length"],
config["preprocessing"]["stft"]["win_length"],
config["preprocessing"]["mel"]["n_mel_channels"],
config["preprocessing"]["audio"]["sampling_rate"],
config["preprocessing"]["mel"]["mel_fmin"],
config["preprocessing"]["mel"]["mel_fmax"],
)
vae.eval()
fn_STFT.eval()
return vae, fn_STFT
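# Frozen AudioLDM VAE wrapper: only decode() is implemented; (b, d, l) latents are
# reshaped to (b, 8, l, 16), decoded to mel, then vocoded to a waveform.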
class VaeWrapper(Module):
def __init__(self):
super().__init__()
vae, stft = build_pretrained_models("audioldm-s-full")
vae.eval()
stft.eval()
stft = stft.cpu()
self.vae = vae
for param in self.vae.parameters():
param.requires_grad = False
def forward(self, waveform):
return None
def decode(self, emb):
with torch.no_grad():
b, d, l = emb.shape
latents = emb.transpose(1,2).reshape(b, l, 8, 16).transpose(1,2)
mel = self.vae.decode_first_stage(latents)
wave = self.vae.decode_to_waveform(mel)
return wave
# convolutional positional generating module
# taken from https://github.com/lucidrains/voicebox-pytorch/blob/main/voicebox_pytorch/voicebox_pytorch.py#L203
class DepthwiseConv(Module):
def __init__(
self,
dim,
*,
kernel_size,
groups = None
):
super().__init__()
assert not divisible_by(kernel_size, 2)
groups = default(groups, dim) # full depthwise conv by default
self.dw_conv1d = nn.Sequential(
nn.Conv1d(dim, dim, kernel_size, groups = groups, padding = kernel_size // 2),
nn.SiLU()
)
def forward(
self,
x,
mask = None
):
if exists(mask):
x = einx.where('b n, b n d, -> b n d', mask, x, 0.)
x = rearrange(x, 'b n c -> b c n')
x = self.dw_conv1d(x)
out = rearrange(x, 'b c n -> b n c')
if exists(mask):
out = einx.where('b n, b n d, -> b n d', mask, out, 0.)
return out
# adaln zero from DiT paper
class AdaLNZero(Module):
def __init__(
self,
dim,
dim_condition = None,
init_bias_value = -2.
):
super().__init__()
dim_condition = default(dim_condition, dim)
self.to_gamma = nn.Linear(dim_condition, dim)
nn.init.zeros_(self.to_gamma.weight)
nn.init.constant_(self.to_gamma.bias, init_bias_value)
def forward(self, x, *, condition):
if condition.ndim == 2:
condition = rearrange(condition, 'b d -> b 1 d')
gamma = self.to_gamma(condition).sigmoid()
return x * gamma
# random projection fourier embedding
class RandomFourierEmbed(Module):
def __init__(self, dim):
super().__init__()
assert divisible_by(dim, 2)
self.register_buffer('weights', torch.randn(dim // 2))
def forward(self, x):
freqs = einx.multiply('i, j -> i j', x, self.weights) * 2 * torch.pi
fourier_embed, _ = pack((x, freqs.sin(), freqs.cos()), 'b *')
return fourier_embed
# character embedding
class CharacterEmbed(Module):
def __init__(
self,
dim,
num_embeds = 256,
):
super().__init__()
self.dim = dim
self.embed = nn.Embedding(num_embeds + 1, dim) # will just use 0 as the 'filler token'
def forward(
self,
text: Int['b nt'],
max_seq_len: int,
**kwargs
) -> Float['b n d']:
text = text + 1 # shift all other token ids up by 1 and use 0 as filler token
text = text[:, :max_seq_len] # just curtail if character tokens are more than the mel spec tokens, one of the edge cases the paper did not address
text = pad_to_length(text, max_seq_len, value = 0)
return self.embed(text)
class InterpolatedCharacterEmbed(Module):
def __init__(
self,
dim,
num_embeds = 256,
):
super().__init__()
self.dim = dim
self.embed = nn.Embedding(num_embeds, dim)
self.abs_pos_mlp = Sequential(
Rearrange('... -> ... 1'),
Linear(1, dim),
nn.SiLU(),
Linear(dim, dim)
)
def forward(
self,
text: Int['b nt'],
max_seq_len: int,
mask: Bool['b n'] | None = None
) -> Float['b n d']:
device = text.device
mask = default(mask, (None,))
interp_embeds = []
interp_abs_positions = []
for one_text, one_mask in zip_longest(text, mask):
valid_text = one_text >= 0
one_text = one_text[valid_text]
one_text_embed = self.embed(one_text)
# save the absolute positions
text_seq_len = one_text.shape[0]
# determine audio sequence length from mask
audio_seq_len = max_seq_len
if exists(one_mask):
audio_seq_len = one_mask.sum().long().item()
# interpolate text embedding to audio embedding length
interp_text_embed = interpolate_1d(one_text_embed, audio_seq_len)
interp_abs_pos = torch.linspace(0, text_seq_len, audio_seq_len, device = device)
interp_embeds.append(interp_text_embed)
interp_abs_positions.append(interp_abs_pos)
interp_embeds = pad_sequence(interp_embeds)
interp_abs_positions = pad_sequence(interp_abs_positions)
interp_embeds = F.pad(interp_embeds, (0, 0, 0, max_seq_len - interp_embeds.shape[-2]))
interp_abs_positions = pad_to_length(interp_abs_positions, max_seq_len)
# pass interp absolute positions through mlp for implicit positions
interp_embeds = interp_embeds + self.abs_pos_mlp(interp_abs_positions)
if exists(mask):
interp_embeds = einx.where('b n, b n d, -> b n d', mask, interp_embeds, 0.)
return interp_embeds
# text audio cross conditioning in multistream setup
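# extended here with a third stream of video-frame features: concatenated (audio, text, frames)
# condition the audio stream, while audio conditions the text and frames streams except at the
# final text layer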
class TextAudioCrossCondition(Module):
def __init__(
self,
dim,
dim_text,
dim_frames,
cond_audio_to_text = True,
):
super().__init__()
#self.text_to_audio = nn.Linear(dim_text + dim, dim, bias = False)
self.text_frames_to_audio = nn.Linear(dim + dim_text + dim_frames, dim, bias = False)
nn.init.zeros_(self.text_frames_to_audio.weight)
self.cond_audio_to_text = cond_audio_to_text
if cond_audio_to_text:
self.audio_to_text = nn.Linear(dim + dim_text, dim_text, bias = False)
nn.init.zeros_(self.audio_to_text.weight)
self.audio_to_frames = nn.Linear(dim + dim_frames, dim_frames, bias = False)
nn.init.zeros_(self.audio_to_frames.weight)
def forward(
self,
audio: Float['b n d'],
text: Float['b n dt'],
frames: Float['b n df'],
):
#audio_text, _ = pack((audio, text), 'b n *')
audio_text_frames, _ = pack((audio, text, frames), 'b n *')
audio_text, _ = pack((audio, text), 'b n *')
audio_frames, _ = pack((audio, frames), 'b n *')
#text_cond = self.text_to_audio(audio_text)
text_cond = self.text_frames_to_audio(audio_text_frames)
audio_cond = self.audio_to_text(audio_text) if self.cond_audio_to_text else 0.
audio_cond2 = self.audio_to_frames(audio_frames) if self.cond_audio_to_text else 0.
return audio + text_cond, text + audio_cond, frames + audio_cond2
# attention and transformer backbone
# for use in both e2tts as well as duration module
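# The audio stream is a UNet-style transformer (skip connections between the first and second
# halves) with optional depthwise-conv positional layers, adaptive RMSNorm time conditioning,
# register tokens, and optional cross-attention to an external context sequence. Smaller
# parallel streams process text and frame embeddings and cross-condition the audio stream in
# each of the first `text_depth` layers.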
class Transformer(Module):
@beartype
def __init__(
self,
*,
dim,
dim_text = None, # will default to half of audio dimension
dim_frames = 512,
depth = 8,
heads = 8,
dim_head = 64,
ff_mult = 4,
text_depth = None,
text_heads = None,
text_dim_head = None,
text_ff_mult = None,
cond_on_time = True,
abs_pos_emb = True,
max_seq_len = 8192,
kernel_size = 31,
dropout = 0.1,
num_registers = 32,
attn_kwargs: dict = dict(
gate_value_heads = True,
softclamp_logits = True,
),
ff_kwargs: dict = dict(),
if_text_modules = True,
if_cross_attn = True,
if_audio_conv = True,
if_text_conv = False
):
super().__init__()
assert divisible_by(depth, 2), 'depth needs to be even'
# absolute positional embedding
self.max_seq_len = max_seq_len
self.abs_pos_emb = nn.Embedding(max_seq_len, dim) if abs_pos_emb else None
self.dim = dim
dim_text = default(dim_text, dim // 2)
self.dim_text = dim_text
self.dim_frames = dim_frames
text_heads = default(text_heads, heads)
text_dim_head = default(text_dim_head, dim_head)
text_ff_mult = default(text_ff_mult, ff_mult)
text_depth = default(text_depth, depth)
assert 1 <= text_depth <= depth, 'must have at least 1 layer of text conditioning, but less than total number of speech layers'
self.depth = depth
self.layers = ModuleList([])
# registers
self.num_registers = num_registers
self.registers = nn.Parameter(torch.zeros(num_registers, dim))
nn.init.normal_(self.registers, std = 0.02)
if if_text_modules:
self.text_registers = nn.Parameter(torch.zeros(num_registers, dim_text))
nn.init.normal_(self.text_registers, std = 0.02)
self.frames_registers = nn.Parameter(torch.zeros(num_registers, dim_frames))
nn.init.normal_(self.frames_registers, std = 0.02)
# rotary embedding
self.rotary_emb = RotaryEmbedding(dim_head)
self.text_rotary_emb = RotaryEmbedding(dim_head)
self.frames_rotary_emb = RotaryEmbedding(dim_head)
# time conditioning
# will use adaptive rmsnorm
self.cond_on_time = cond_on_time
rmsnorm_klass = RMSNorm if not cond_on_time else AdaptiveRMSNorm
postbranch_klass = Identity if not cond_on_time else partial(AdaLNZero, dim = dim)
self.time_cond_mlp = Identity()
if cond_on_time:
self.time_cond_mlp = Sequential(
RandomFourierEmbed(dim),
Linear(dim + 1, dim),
nn.SiLU()
)
for ind in range(depth):
is_later_half = ind >= (depth // 2)
has_text = ind < text_depth
# speech related
if if_audio_conv:
speech_conv = DepthwiseConv(dim, kernel_size = kernel_size)
attn_norm = rmsnorm_klass(dim)
attn = Attention(dim = dim, heads = heads, dim_head = dim_head, dropout = dropout, **attn_kwargs)
attn_adaln_zero = postbranch_klass()
if if_cross_attn:
attn_norm2 = rmsnorm_klass(dim)
attn2 = Attention(dim = dim, heads = heads, dim_head = dim_head, dropout = dropout, **attn_kwargs)
attn_adaln_zero2 = postbranch_klass()
ff_norm = rmsnorm_klass(dim)
ff = FeedForward(dim = dim, glu = True, mult = ff_mult, dropout = dropout, **ff_kwargs)
ff_adaln_zero = postbranch_klass()
skip_proj = Linear(dim * 2, dim, bias = False) if is_later_half else None
if if_cross_attn:
if if_audio_conv:
speech_modules = ModuleList([
skip_proj,
speech_conv,
attn_norm,
attn,
attn_adaln_zero,
attn_norm2,
attn2,
attn_adaln_zero2,
ff_norm,
ff,
ff_adaln_zero,
])
else:
speech_modules = ModuleList([
skip_proj,
attn_norm,
attn,
attn_adaln_zero,
attn_norm2,
attn2,
attn_adaln_zero2,
ff_norm,
ff,
ff_adaln_zero,
])
else:
if if_audio_conv:
speech_modules = ModuleList([
skip_proj,
speech_conv,
attn_norm,
attn,
attn_adaln_zero,
ff_norm,
ff,
ff_adaln_zero,
])
else:
speech_modules = ModuleList([
skip_proj,
attn_norm,
attn,
attn_adaln_zero,
ff_norm,
ff,
ff_adaln_zero,
])
text_modules = None
if has_text and if_text_modules:
# text related
if if_text_conv:
text_conv = DepthwiseConv(dim_text, kernel_size = kernel_size)
text_attn_norm = RMSNorm(dim_text)
text_attn = Attention(dim = dim_text, heads = text_heads, dim_head = text_dim_head, dropout = dropout, **attn_kwargs)
text_ff_norm = RMSNorm(dim_text)
text_ff = FeedForward(dim = dim_text, glu = True, mult = text_ff_mult, dropout = dropout, **ff_kwargs)
# cross condition
is_last = ind == (text_depth - 1)
cross_condition = TextAudioCrossCondition(dim = dim, dim_text = dim_text, dim_frames = dim_frames, cond_audio_to_text = not is_last)
if if_text_conv:
text_modules = ModuleList([
text_conv,
text_attn_norm,
text_attn,
text_ff_norm,
text_ff,
cross_condition
])
else:
text_modules = ModuleList([
text_attn_norm,
text_attn,
text_ff_norm,
text_ff,
cross_condition
])
if True:
frames_conv = DepthwiseConv(dim_frames, kernel_size = kernel_size)
frames_attn_norm = RMSNorm(dim_frames)
frames_attn = Attention(dim = dim_frames, heads = 8, dim_head = 64, dropout = dropout, **attn_kwargs)
frames_ff_norm = RMSNorm(dim_frames)
frames_ff = FeedForward(dim = dim_frames, glu = True, mult = 4, dropout = dropout, **ff_kwargs)
# cross condition
frames_modules = ModuleList([
frames_conv,
frames_attn_norm,
frames_attn,
frames_ff_norm,
frames_ff
])
self.layers.append(ModuleList([
speech_modules,
text_modules,
frames_modules
]))
self.final_norm = RMSNorm(dim)
self.if_cross_attn = if_cross_attn
self.if_audio_conv = if_audio_conv
self.if_text_conv = if_text_conv
def forward(
self,
x: Float['b n d'],
times: Float['b'] | Float[''] | None = None,
mask: Bool['b n'] | None = None,
text_embed: Float['b n dt'] | None = None,
frames_embed: Float['b n df'] | None = None,
context: Float['b nc dc'] | None = None,
context_mask: Float['b nc'] | None = None
):
batch, seq_len, device = *x.shape[:2], x.device
assert not (exists(times) ^ self.cond_on_time), '`times` must be passed in if `cond_on_time` is set to `True` and vice versa'
# handle absolute positions if needed
if exists(self.abs_pos_emb):
assert seq_len <= self.max_seq_len, f'{seq_len} exceeds the set `max_seq_len` ({self.max_seq_len}) on Transformer'
seq = torch.arange(seq_len, device = device)
x = x + self.abs_pos_emb(seq)
# handle adaptive rmsnorm kwargs
norm_kwargs = dict()
if exists(times):
if times.ndim == 0:
times = repeat(times, ' -> b', b = batch)
times = self.time_cond_mlp(times)
norm_kwargs.update(condition = times)
# register tokens
registers = repeat(self.registers, 'r d -> b r d', b = batch)
x, registers_packed_shape = pack((registers, x), 'b * d')
if exists(mask):
mask = F.pad(mask, (self.num_registers, 0), value = True)
# rotary embedding
rotary_pos_emb = self.rotary_emb.forward_from_seq_len(x.shape[-2])
# text related
if exists(text_embed):
text_rotary_pos_emb = self.text_rotary_emb.forward_from_seq_len(x.shape[-2])
text_registers = repeat(self.text_registers, 'r d -> b r d', b = batch)
text_embed, _ = pack((text_registers, text_embed), 'b * d')
if exists(frames_embed):
frames_rotary_pos_emb = self.frames_rotary_emb.forward_from_seq_len(x.shape[-2])
frames_registers = repeat(self.frames_registers, 'r d -> b r d', b = batch)
frames_embed, _ = pack((frames_registers, frames_embed), 'b * d')
# skip connection related stuff
skips = []
# go through the layers
for ind, (speech_modules, text_modules, frames_modules) in enumerate(self.layers):
layer = ind + 1
if self.if_cross_attn:
if self.if_audio_conv:
(
maybe_skip_proj,
speech_conv,
attn_norm,
attn,
maybe_attn_adaln_zero,
attn_norm2,
attn2,
maybe_attn_adaln_zero2,
ff_norm,
ff,
maybe_ff_adaln_zero
) = speech_modules
else:
(
maybe_skip_proj,
attn_norm,
attn,
maybe_attn_adaln_zero,
attn_norm2,
attn2,
maybe_attn_adaln_zero2,
ff_norm,
ff,
maybe_ff_adaln_zero
) = speech_modules
else:
if self.if_audio_conv:
(
maybe_skip_proj,
speech_conv,
attn_norm,
attn,
maybe_attn_adaln_zero,
ff_norm,
ff,
maybe_ff_adaln_zero
) = speech_modules
else:
(
maybe_skip_proj,
attn_norm,
attn,
maybe_attn_adaln_zero,
ff_norm,
ff,
maybe_ff_adaln_zero
) = speech_modules
# smaller text transformer
if exists(text_embed) and exists(text_modules):
if self.if_text_conv:
(
text_conv,
text_attn_norm,
text_attn,
text_ff_norm,
text_ff,
cross_condition
) = text_modules
else:
(
text_attn_norm,
text_attn,
text_ff_norm,
text_ff,
cross_condition
) = text_modules
if self.if_text_conv:
text_embed = text_conv(text_embed, mask = mask) + text_embed
text_embed = text_attn(text_attn_norm(text_embed), rotary_pos_emb = text_rotary_pos_emb, mask = mask) + text_embed
text_embed = text_ff(text_ff_norm(text_embed)) + text_embed
# frames transformer
(
frames_conv,
frames_attn_norm,
frames_attn,
frames_ff_norm,
frames_ff
) = frames_modules
frames_embed = frames_conv(frames_embed, mask = mask) + frames_embed
frames_embed = frames_attn(frames_attn_norm(frames_embed), rotary_pos_emb = frames_rotary_pos_emb, mask = mask) + frames_embed
frames_embed = frames_ff(frames_ff_norm(frames_embed)) + frames_embed
# cross condition
x, text_embed, frames_embed = cross_condition(x, text_embed, frames_embed)
# skip connection logic
is_first_half = layer <= (self.depth // 2)
is_later_half = not is_first_half
if is_first_half:
skips.append(x)
if is_later_half:
skip = skips.pop()
x = torch.cat((x, skip), dim = -1)
x = maybe_skip_proj(x)
# position generating convolution
if self.if_audio_conv:
x = speech_conv(x, mask = mask) + x
# attention and feedforward blocks
attn_out = attn(attn_norm(x, **norm_kwargs), rotary_pos_emb = rotary_pos_emb, mask = mask)
x = x + maybe_attn_adaln_zero(attn_out, **norm_kwargs)
if self.if_cross_attn:
attn_out = attn2(attn_norm2(x, **norm_kwargs), rotary_pos_emb = rotary_pos_emb, mask = mask, context = context, context_mask = context_mask)
x = x + maybe_attn_adaln_zero2(attn_out, **norm_kwargs)
ff_out = ff(ff_norm(x, **norm_kwargs))
x = x + maybe_ff_adaln_zero(ff_out, **norm_kwargs)
assert len(skips) == 0
_, x = unpack(x, registers_packed_shape, 'b * d')
return self.final_norm(x)
# main classes
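# DurationPredictor: embeds mel (or raw wave) plus text, keeps only a random-length prefix
# during training, mean-pools the transformer output, and regresses the full sequence length
# with a Softplus head (MSE loss against the true length).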
class DurationPredictor(Module):
@beartype
def __init__(
self,
transformer: dict | Transformer,
num_channels = None,
mel_spec_kwargs: dict = dict(),
char_embed_kwargs: dict = dict(),
text_num_embeds = None,
tokenizer: (
Literal['char_utf8', 'phoneme_en'] |
Callable[[list[str]], Int['b nt']]
) = 'char_utf8'
):
super().__init__()
if isinstance(transformer, dict):
transformer = Transformer(
**transformer,
cond_on_time = False
)
# mel spec
self.mel_spec = MelSpec(**mel_spec_kwargs)
self.num_channels = default(num_channels, self.mel_spec.n_mel_channels)
self.transformer = transformer
dim = transformer.dim
dim_text = transformer.dim_text
self.dim = dim
self.proj_in = Linear(self.num_channels, self.dim)
# tokenizer and text embed
if callable(tokenizer):
assert exists(text_num_embeds), '`text_num_embeds` must be given if supplying your own tokenizer encode function'
self.tokenizer = tokenizer
elif tokenizer == 'char_utf8':
text_num_embeds = 256
self.tokenizer = list_str_to_tensor
elif tokenizer == 'phoneme_en':
self.tokenizer, text_num_embeds = get_g2p_en_encode()
elif tokenizer == 'phoneme_zh':
self.tokenizer, text_num_embeds = get_g2p_zh_encode()
else:
raise ValueError(f'unknown tokenizer string {tokenizer}')
self.embed_text = CharacterEmbed(dim_text, num_embeds = text_num_embeds, **char_embed_kwargs)
# to prediction
self.to_pred = Sequential(
Linear(dim, 1, bias = False),
nn.Softplus(),
Rearrange('... 1 -> ...')
)
def forward(
self,
x: Float['b n d'] | Float['b nw'],
*,
text: Int['b nt'] | list[str] | None = None,
lens: Int['b'] | None = None,
return_loss = True
):
# raw wave
if x.ndim == 2:
x = self.mel_spec(x)
x = rearrange(x, 'b d n -> b n d')
            assert x.shape[-1] == self.num_channels
x = self.proj_in(x)
batch, seq_len, device = *x.shape[:2], x.device
# text
text_embed = None
if exists(text):
if isinstance(text, list):
                text = self.tokenizer(text).to(device)
assert text.shape[0] == batch
text_embed = self.embed_text(text, seq_len)
# handle lengths (duration)
if not exists(lens):
lens = torch.full((batch,), seq_len, device = device)
mask = lens_to_mask(lens, length = seq_len)
# if returning a loss, mask out randomly from an index and have it predict the duration
if return_loss:
rand_frac_index = x.new_zeros(batch).uniform_(0, 1)
rand_index = (rand_frac_index * lens).long()
seq = torch.arange(seq_len, device = device)
mask &= einx.less('n, b -> b n', seq, rand_index)
# attending
x = self.transformer(
x,
mask = mask,
text_embed = text_embed,
)
x = maybe_masked_mean(x, mask)
pred = self.to_pred(x)
# return the prediction if not returning loss
if not return_loss:
return pred
# loss
return F.mse_loss(pred, lens.float())
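# E2TTS: conditional flow-matching generator. Audio latents are produced from a masked audio
# condition, a frozen FLAN-T5 text prompt, and video-derived features (Audeo's Video2RollNet
# for keyboard videos, or CLIP/DINOv2 frame embeddings); sampling integrates the learned
# velocity field with torchdiffeq's odeint (Euler by default).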
class E2TTS(Module):
@beartype
def __init__(
self,
transformer: dict | Transformer = None,
duration_predictor: dict | DurationPredictor | None = None,
odeint_kwargs: dict = dict(
#atol = 1e-5,
#rtol = 1e-5,
#method = 'midpoint'
method = "euler"
),
audiocond_drop_prob = 0.30,
cond_drop_prob = 0.20,
prompt_drop_prob = 0.10,
num_channels = None,
mel_spec_module: Module | None = None,
char_embed_kwargs: dict = dict(),
mel_spec_kwargs: dict = dict(),
frac_lengths_mask: tuple[float, float] = (0.7, 1.),
audiocond_snr: tuple[float, float] | None = None,
concat_cond = False,
interpolated_text = False,
text_num_embeds: int | None = None,
tokenizer: (
Literal['char_utf8', 'phoneme_en', 'phoneme_zh'] |
Callable[[list[str]], Int['b nt']]
) = 'char_utf8',
use_vocos = True,
pretrained_vocos_path = 'charactr/vocos-mel-24khz',
sampling_rate: int | None = None,
frame_size: int = 320,
#### dpo
velocity_consistency_weight = -1e-5,
#### dpo
if_cond_proj_in = True,
cond_proj_in_bias = True,
if_embed_text = True,
if_text_encoder2 = True,
if_clip_encoder = False,
video_encoder = "clip_vit"
):
super().__init__()
if isinstance(transformer, dict):
transformer = Transformer(
**transformer,
cond_on_time = True
)
if isinstance(duration_predictor, dict):
duration_predictor = DurationPredictor(**duration_predictor)
self.transformer = transformer
dim = transformer.dim
dim_text = transformer.dim_text
dim_frames = transformer.dim_frames
self.dim = dim
self.dim_text = dim_text
self.frac_lengths_mask = frac_lengths_mask
self.audiocond_snr = audiocond_snr
self.duration_predictor = duration_predictor
# sampling
self.odeint_kwargs = odeint_kwargs
# mel spec
self.mel_spec = default(mel_spec_module, None)
num_channels = default(num_channels, None)
self.num_channels = num_channels
self.sampling_rate = default(sampling_rate, None)
self.frame_size = frame_size
# whether to concat condition and project rather than project both and sum
self.concat_cond = concat_cond
if concat_cond:
self.proj_in = nn.Linear(num_channels * 2, dim)
else:
self.proj_in = nn.Linear(num_channels, dim)
self.cond_proj_in = nn.Linear(num_channels, dim, bias=cond_proj_in_bias) if if_cond_proj_in else None
#self.cond_proj_in = nn.Linear(NOTES, dim, bias=cond_proj_in_bias) if if_cond_proj_in else None
# to prediction
self.to_pred = Linear(dim, num_channels)
# tokenizer and text embed
if callable(tokenizer):
assert exists(text_num_embeds), '`text_num_embeds` must be given if supplying your own tokenizer encode function'
self.tokenizer = tokenizer
elif tokenizer == 'char_utf8':
text_num_embeds = 256
self.tokenizer = list_str_to_tensor
elif tokenizer == 'phoneme_en':
self.tokenizer, text_num_embeds = get_g2p_en_encode()
elif tokenizer == 'phoneme_zh':
self.tokenizer, text_num_embeds = get_g2p_zh_encode()
else:
raise ValueError(f'unknown tokenizer string {tokenizer}')
self.audiocond_drop_prob = audiocond_drop_prob
self.cond_drop_prob = cond_drop_prob
self.prompt_drop_prob = prompt_drop_prob
# text embedding
text_embed_klass = CharacterEmbed if not interpolated_text else InterpolatedCharacterEmbed
self.embed_text = text_embed_klass(dim_text, num_embeds = text_num_embeds, **char_embed_kwargs) if if_embed_text else None
# weight for velocity consistency
self.register_buffer('zero', torch.tensor(0.), persistent = False)
self.velocity_consistency_weight = velocity_consistency_weight
# default vocos for mel -> audio
#if pretrained_vocos_path == 'charactr/vocos-mel-24khz':
# self.vocos = Vocos.from_pretrained(pretrained_vocos_path) if use_vocos else None
#elif pretrained_vocos_path == 'facebook/encodec_24khz':
# self.vocos = EncodecWrapper("facebook/encodec_24khz") if use_vocos else None
#elif pretrained_vocos_path == 'vae':
# self.vocos = VaeWrapper() if use_vocos else None
if if_text_encoder2:
self.tokenizer2 = AutoTokenizer.from_pretrained("./ckpts/flan-t5-large")
self.text_encoder2 = T5EncoderModel.from_pretrained("./ckpts/flan-t5-large")
for param in self.text_encoder2.parameters():
param.requires_grad = False
self.text_encoder2.eval()
self.proj_text = None
self.proj_frames = Linear(NOTES, dim_frames)
if if_clip_encoder:
if video_encoder == "clip_vit":
####pass
self.image_processor = CLIPImageProcessor()
#self.image_encoder = CLIPVisionModelWithProjection.from_pretrained("/ailab-train2/speech/zhanghaomin/models/IP-Adapter/", subfolder="models/image_encoder")
self.image_encoder = CLIPVisionModelWithProjection.from_pretrained("./ckpts/IP-Adapter/", subfolder="sdxl_models/image_encoder")
elif video_encoder == "clip_vit2":
self.image_processor = AutoProcessor.from_pretrained("/ailab-train2/speech/zhanghaomin/models/clip-vit-large-patch14-336/")
self.image_encoder = CLIPVisionModelWithProjection.from_pretrained("/ailab-train2/speech/zhanghaomin/models/clip-vit-large-patch14-336/")
elif video_encoder == "clip_convnext":
self.image_encoder, _, self.image_processor = open_clip.create_model_and_transforms("hf-hub:laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-soup")
elif video_encoder == "dinov2":
self.image_processor = AutoImageProcessor.from_pretrained("/ailab-train2/speech/zhanghaomin/models/dinov2-giant/")
self.image_encoder = AutoModel.from_pretrained("/ailab-train2/speech/zhanghaomin/models/dinov2-giant/")
elif video_encoder == "mixed":
pass
#self.image_processor1 = CLIPImageProcessor()
#self.image_encoder1 = CLIPVisionModelWithProjection.from_pretrained("/ailab-train2/speech/zhanghaomin/models/IP-Adapter/", subfolder="sdxl_models/image_encoder")
#self.image_processor2 = AutoProcessor.from_pretrained("/ailab-train2/speech/zhanghaomin/models/clip-vit-large-patch14-336/")
#self.image_encoder2 = CLIPVisionModelWithProjection.from_pretrained("/ailab-train2/speech/zhanghaomin/models/clip-vit-large-patch14-336/")
#self.image_encoder3, _, self.image_processor3 = open_clip.create_model_and_transforms("hf-hub:laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-soup")
#self.image_processor4 = AutoImageProcessor.from_pretrained("/ailab-train2/speech/zhanghaomin/models/dinov2-giant/")
#self.image_encoder4 = AutoModel.from_pretrained("/ailab-train2/speech/zhanghaomin/models/dinov2-giant/")
else:
self.image_processor = None
self.image_encoder = None
if video_encoder != "mixed":
####pass
for param in self.image_encoder.parameters():
param.requires_grad = False
self.image_encoder.eval()
else:
#for param in self.image_encoder1.parameters():
# param.requires_grad = False
#self.image_encoder1.eval()
#for param in self.image_encoder2.parameters():
# param.requires_grad = False
#self.image_encoder2.eval()
#for param in self.image_encoder3.parameters():
# param.requires_grad = False
#self.image_encoder3.eval()
#for param in self.image_encoder4.parameters():
# param.requires_grad = False
#self.image_encoder4.eval()
self.dim_text_raw = 4608
self.proj_text = Linear(self.dim_text_raw, dim_text)
self.video_encoder = video_encoder
#for param in self.vocos.parameters():
# param.requires_grad = False
#self.vocos.eval()
########self.conv1 = nn.Conv3d(6, 64, kernel_size=(3, 3, 3), padding=(1, 1, 1))
########self.pool1 = nn.Conv3d(64, 64, kernel_size=(1, 2, 2), stride=(1, 2, 2))
####self.conv1 = nn.Conv3d(3, 16, kernel_size=(3, 3, 3), padding=(1, 1, 1))
####self.pool1 = nn.Conv3d(16, 16, kernel_size=(1, 2, 2), stride=(1, 2, 2), padding=(0,0,0))
####
########self.conv2 = nn.Conv3d(64, 128, kernel_size=(3, 3, 3), padding=(1, 1, 1))
########self.pool2 = nn.Conv3d(128, 128, kernel_size=(2, 2, 2), stride=(2, 2, 2))
####self.conv2 = nn.Conv3d(16, 32, kernel_size=(3, 3, 3), padding=(1, 1, 1))
####self.pool2 = nn.Conv3d(32, 32, kernel_size=(2, 2, 2), stride=(1, 2, 2), padding=(0,0,0))
####
########self.conv3a = nn.Conv3d(128, 256, kernel_size=(3, 3, 3), padding=(1, 1, 1))
########self.conv3b = nn.Conv3d(256, 256, kernel_size=(3, 3, 3), padding=(1, 1, 1))
########self.pool3 = nn.Conv3d(256, 256, kernel_size=(1, 2, 2), stride=(1, 2, 2))
####self.conv3a = nn.Conv3d(32, 64, kernel_size=(3, 3, 3), padding=(1, 1, 1))
####self.conv3b = nn.Conv3d(64, 64, kernel_size=(3, 3, 3), padding=(1, 1, 1))
####self.pool3 = nn.Conv3d(64, 64, kernel_size=(1, 2, 2), stride=(1, 2, 2), padding=(0,0,0))
####
########self.conv4a = nn.Conv3d(256, 512, kernel_size=(3, 3, 3), padding=(1, 1, 1))
########self.conv4b = nn.Conv3d(512, 512, kernel_size=(3, 3, 3), padding=(1, 1, 1))
########self.pool4 = nn.Conv3d(512, 512, kernel_size=(1, 2, 2), stride=(1, 2, 2))
#####self.conv4a = nn.Conv3d(64, 64, kernel_size=(3, 3, 3), padding=(1, 1, 1))
#####self.conv4b = nn.Conv3d(64, 64, kernel_size=(3, 3, 3), padding=(1, 1, 1))
######self.pool4 = nn.Conv3d(64, 32, kernel_size=(1, 2, 2), stride=(1, 2, 2), padding=(0,0,0))
#####self.pool4 = nn.ConvTranspose3d(64, 32, kernel_size=(3, 3, 3), padding=(1, 1, 1))
####
########self.conv5a = nn.Conv3d(512, 512, kernel_size=(3, 3, 3), padding=(1, 1, 1))
########self.conv5b = nn.Conv3d(512, 512, kernel_size=(3, 3, 3), padding=(1, 1, 1))
########self.pool5 = nn.ConvTranspose3d(512, 128, kernel_size=(5, 3, 3), stride=(5, 1, 1), padding=(0, 1, 1))
####self.conv5a = nn.Conv3d(64, 64, kernel_size=(3, 3, 3), padding=(1, 1, 1))
####self.conv5b = nn.Conv3d(64, 64, kernel_size=(3, 3, 3), padding=(1, 1, 1))
####self.pool5 = nn.ConvTranspose3d(64, 32, kernel_size=(3, 3, 3), stride=(3, 1, 1), padding=(0, 1, 1))
#####self.pool5 = nn.Conv3d(256, 256, kernel_size=(1, 2, 2), stride=(1, 2, 2))
####
####self.relu = nn.ReLU()
####self.final_activation = nn.Sigmoid()
####self.dropout = nn.Dropout(p=0.50)
########self.fc5 = nn.Linear(51200, NOTES)
####self.fc5 = nn.Linear(65536, 208)
####self.fc6 = nn.Linear(208, NOTES)
####
#####self.rnn = nn.RNN(NOTES, NOTES, 1)
#####self.fc7 = nn.Linear(NOTES, NOTES)
####
#####self.bn1 = nn.BatchNorm3d(16)
#####self.bn2 = nn.BatchNorm3d(32)
#####self.bn3 = nn.BatchNorm3d(64)
#####self.bn4 = nn.BatchNorm3d(64)
#####self.bn5 = nn.BatchNorm3d(64)
#####self.bn6 = nn.BatchNorm3d(64)
#####self.bn7 = nn.BatchNorm1d(208)
self.video2roll_net = Video2RollNet.resnet18(num_classes=NOTES)
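    # Turn keyboard-video frames into per-frame note probabilities: each frame is stacked with
    # its two neighbours on either side (5-channel input to Audeo's Video2RollNet), the sigmoid
    # note logits are repeated 5x and mean-pooled in pairs (2.5x the video frame rate), then
    # cropped or zero-padded to length l.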
def encode_frames(self, x, l):
#print("x input", x.shape, l) # [1, 1, 251, 100, 900]
b, c, t, w, h = x.shape
assert(c == 1)
x_all = []
for i in range(t):
frames = []
for j in [-2, -1, 0, 1, 2]:
f = min(max(i+j, 0), t-1)
frames.append(x[:,:,f:f+1,:,:])
frames = torch.cat(frames, dim=2) # [b, 1, 5, w, h]
x_all.append(frames)
x = torch.cat(x_all, dim=1).reshape(b*t, 5, w, h) # [b*t, 5, w, h]
#print("x", x.shape, l) # [251, 5, 100, 900]
x = self.video2roll_net(x)
x = nn.Sigmoid()(x)
#print("x output", x.shape) # [251, 51]
####video_multi
####x = x.reshape(b, t, 1, NOTES).repeat(1,1,3,1).reshape(b, t*3, NOTES)
t5 = (t*5//2)*2
x = x.reshape(b, t, 1, NOTES).repeat(1,1,5,1).reshape(b, t*5, NOTES)[:,:t5,:].reshape(b, t5//2, 2, NOTES).mean(2)
b, d, _ = x.shape
#print("encode_frames", x.shape, l)
if d > l:
x = x[:,:l,:]
elif d < l:
x = torch.cat((x, torch.zeros(b,l-d,NOTES,device=x.device)), 1)
        return x
####def encode_frames(self, x, l):
#### x = x[:,:3,:,...]
#### #print("x", x.shape) # [2, 6, 301, 320, 320] # [2, 3, 251, 128, 1024]
#### x = self.conv1(x)
#### #x = self.bn1(x)
#### x = self.relu(x)
#### x = self.dropout(x)
#### #print("conv1", x.shape) # [2, 64, 301, 320, 320] # [2, 16, 251, 128, 1024]
#### x = self.pool1(x)
#### #print("pool1", x.shape) # [2, 64, 301, 160, 160] # [2, 16, 251, 64, 512]
####
#### x = self.conv2(x)
#### #x = self.bn2(x)
#### x = self.relu(x)
#### x = self.dropout(x)
#### #print("conv2", x.shape) # [2, 128, 301, 160, 160] # [2, 32, 250, 64, 512]
#### x = self.pool2(x)
#### #x = self.relu(x)
#### #x = self.dropout(x)
#### #print("pool2", x.shape) # [2, 128, 150, 80, 80] # [2, 32, 250, 32, 256]
####
#### x = self.conv3a(x)
#### #x = self.bn3(x)
#### x = self.relu(x)
#### x = self.dropout(x)
#### x = self.conv3b(x)
#### #x = self.bn4(x)
#### x = self.relu(x)
#### x = self.dropout(x)
#### #print("conv3", x.shape) # [2, 256, 150, 80, 80] # [2, 64, 250, 32, 256]
#### x = self.pool3(x)
#### #x = self.relu(x)
#### #x = self.dropout(x)
#### #print("pool3", x.shape) # [2, 256, 150, 40, 40] # [2, 64, 250, 16, 128]
####
#### #x = self.conv4a(x)
#### #x = self.relu(x)
#### ##x = self.dropout(x)
#### #x = self.conv4b(x)
#### #x = self.relu(x)
#### ##x = self.dropout(x)
#### ###print("conv4", x.shape) # [2, 512, 150, 40, 40] # [2, 64, 250, 16, 128]
#### #x = self.pool4(x)
#### ##print("pool4", x.shape) # [2, 512, 150, 20, 20] # [2, 32, 250, 8, 64]
####
#### x = self.conv5a(x)
#### #x = self.bn5(x)
#### x = self.relu(x)
#### x = self.dropout(x)
#### x = self.conv5b(x)
#### #x = self.bn6(x)
#### x = self.relu(x)
#### x = self.dropout(x)
#### #print("conv5", x.shape) # [2, 512, 150, 20, 20] # [2, 64, 250, 16, 128]
#### x = self.pool5(x)
#### #x = self.relu(x)
#### #x = self.dropout(x)
#### #print("pool5", x.shape) # [2, 128, 750, 20, 20] # [2, 32, 750/250, 16, 128]
####
#### b, c, d, w, h = x.shape
#### x = x.permute(0,2,3,4,1).reshape(b,d,w*h*c)
#### x = self.fc5(x)
#### #x = x.reshape(b,208,d)
#### #x = self.bn7(x)
#### #x = x.reshape(b,d,208)
#### x = self.relu(x)
#### x = self.dropout(x)
#### x = self.fc6(x)
####
#### #x = self.relu(x)
#### #x, _ = self.rnn(x)
#### #x = self.fc7(x)
####
#### x = self.final_activation(x)
####
#### #x = x.reshape(b,d,1,NOTES).repeat(1,1,3,1).reshape(b,d*3,NOTES)
#### #d = d * 3
####
#### #print("encode_frames", x.shape, l)
#### if d > l:
#### x = x[:,:l,:]
#### elif d < l:
#### x = torch.cat((x, torch.zeros(b,l-d,NOTES,device=x.device)), 1)
#### return x
@property
def device(self):
return next(self.parameters()).device
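    # Encode text prompts with the frozen FLAN-T5 encoder; returns (hidden_states, boolean_attention_mask).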
def encode_text(self, prompt):
device = self.device
batch = self.tokenizer2(prompt, max_length=self.tokenizer2.model_max_length, padding=True, truncation=True, return_tensors="pt")
input_ids, attention_mask = batch.input_ids.to(device), batch.attention_mask.to(device)
with torch.no_grad():
encoder_hidden_states = self.text_encoder2(input_ids=input_ids, attention_mask=attention_mask)[0]
boolean_encoder_mask = (attention_mask == 1).to(device)
return encoder_hidden_states, boolean_encoder_mask
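    # For each video path, load a cached .npz feature file if present; otherwise decode frames
    # with moviepy and run them through the selected image encoder (CLIP ViT, CLIP ConvNeXt,
    # DINOv2, or mixed) in mini-batches of 20 frames, caching the computed features as .npz.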
def encode_video(self, video_paths, l):
if self.proj_text is None:
d = self.dim_text
else:
d = self.dim_text_raw
device = self.device
b = 20
with torch.no_grad():
video_embeddings = []
video_lens = []
for video_path in video_paths:
if video_path is None:
video_embeddings.append(None)
video_lens.append(0)
continue
if isinstance(video_path, tuple):
video_path, start_sample, max_sample = video_path
else:
start_sample = 0
max_sample = None
if video_path.startswith("/ailab-train2/speech/zhanghaomin/VGGSound/"):
if self.video_encoder == "clip_vit":
feature_path = video_path.replace("/video/", "/feature/").replace(".mp4", ".npz")
elif self.video_encoder == "clip_vit2":
feature_path = video_path.replace("/video/", "/feature_clip_vit2/").replace(".mp4", ".npz")
elif self.video_encoder == "clip_convnext":
feature_path = video_path.replace("/video/", "/feature_clip_convnext/").replace(".mp4", ".npz")
elif self.video_encoder == "dinov2":
feature_path = video_path.replace("/video/", "/feature_dinov2/").replace(".mp4", ".npz")
elif self.video_encoder == "mixed":
feature_path = video_path.replace("/video/", "/feature_mixed/").replace(".mp4", ".npz")
else:
raise Exception("Invalid video_encoder " + self.video_encoder)
else:
if self.video_encoder == "clip_vit":
feature_path = video_path.replace(".mp4", ".generated.npz")
elif self.video_encoder == "clip_vit2":
feature_path = video_path.replace(".mp4", ".generated.clip_vit2.npz")
elif self.video_encoder == "clip_convnext":
feature_path = video_path.replace(".mp4", ".generated.clip_convnext.npz")
elif self.video_encoder == "dinov2":
feature_path = video_path.replace(".mp4", ".generated.dinov2.npz")
elif self.video_encoder == "mixed":
feature_path = video_path.replace(".mp4", ".generated.mixed.npz")
else:
raise Exception("Invalid video_encoder " + self.video_encoder)
if not os.path.exists(feature_path):
#print("video not exist", video_path)
frames, duration = read_frames_with_moviepy(video_path, max_frame_nums=None)
if frames is None:
video_embeddings.append(None)
video_lens.append(0)
continue
if self.video_encoder in ["clip_vit", "clip_vit2", "dinov2"]:
images = self.image_processor(images=frames, return_tensors="pt").to(device)
#print("images", images["pixel_values"].shape, images["pixel_values"].max(), images["pixel_values"].min(), torch.abs(images["pixel_values"]).mean())
elif self.video_encoder in ["clip_convnext"]:
images = []
for i in range(frames.shape[0]):
images.append(self.image_processor(Image.fromarray(frames[i])).unsqueeze(0))
images = torch.cat(images, dim=0).to(device)
#print("images", images.shape, images.max(), images.min(), torch.abs(images).mean())
elif self.video_encoder in ["mixed"]:
images1 = self.image_processor1(images=frames, return_tensors="pt").to(device) # required by the clip_vit fallback below when its cached features are missing
images2 = self.image_processor2(images=frames, return_tensors="pt").to(device)
images4 = self.image_processor4(images=frames, return_tensors="pt").to(device)
images3 = []
for i in range(frames.shape[0]):
images3.append(self.image_processor3(Image.fromarray(frames[i])).unsqueeze(0))
images3 = torch.cat(images3, dim=0).to(device)
else:
raise Exception("Invalid video_encoder " + self.video_encoder)
image_embeddings = []
if self.video_encoder == "clip_vit":
for i in range(math.ceil(images["pixel_values"].shape[0] / b)):
image_embeddings.append(self.image_encoder(pixel_values=images["pixel_values"][i*b: (i+1)*b]).image_embeds.cpu())
elif self.video_encoder == "clip_vit2":
for i in range(math.ceil(images["pixel_values"].shape[0] / b)):
image_embeddings.append(self.image_encoder(pixel_values=images["pixel_values"][i*b: (i+1)*b]).image_embeds.cpu())
elif self.video_encoder == "clip_convnext":
for i in range(math.ceil(images.shape[0] / b)):
image_embeddings.append(self.image_encoder.encode_image(images[i*b: (i+1)*b]).cpu())
elif self.video_encoder == "dinov2":
for i in range(math.ceil(images["pixel_values"].shape[0] / b)):
image_embeddings.append(self.image_encoder(pixel_values=images["pixel_values"][i*b: (i+1)*b]).pooler_output.cpu())
elif self.video_encoder == "mixed":
feature_path1 = feature_path.replace("/feature_mixed/", "/feature/")
if not os.path.exists(feature_path1):
image_embeddings1 = []
for i in range(math.ceil(images1["pixel_values"].shape[0] / b)):
image_embeddings1.append(self.image_encoder1(pixel_values=images1["pixel_values"][i*b: (i+1)*b]).image_embeds.cpu())
image_embeddings1 = torch.cat(image_embeddings1, dim=0)
#np.savez(feature_path1, image_embeddings1, duration)
else:
data1 = np.load(feature_path1)
image_embeddings1 = torch.from_numpy(data1["arr_0"])
feature_path2 = feature_path.replace("/feature_mixed/", "/feature_clip_vit2/")
if not os.path.exists(feature_path2):
image_embeddings2 = []
for i in range(math.ceil(images2["pixel_values"].shape[0] / b)):
image_embeddings2.append(self.image_encoder2(pixel_values=images2["pixel_values"][i*b: (i+1)*b]).image_embeds.cpu())
image_embeddings2 = torch.cat(image_embeddings2, dim=0)
np.savez(feature_path2, image_embeddings2, duration)
else:
data2 = np.load(feature_path2)
image_embeddings2 = torch.from_numpy(data2["arr_0"])
feature_path3 = feature_path.replace("/feature_mixed/", "/feature_clip_convnext/")
if not os.path.exists(feature_path3):
image_embeddings3 = []
for i in range(math.ceil(images3.shape[0] / b)):
image_embeddings3.append(self.image_encoder3.encode_image(images3[i*b: (i+1)*b]).cpu())
image_embeddings3 = torch.cat(image_embeddings3, dim=0)
np.savez(feature_path3, image_embeddings3, duration)
else:
data3 = np.load(feature_path3)
image_embeddings3 = torch.from_numpy(data3["arr_0"])
feature_path4 = feature_path.replace("/feature_mixed/", "/feature_dinov2/")
if not os.path.exists(feature_path4):
image_embeddings4 = []
for i in range(math.ceil(images4["pixel_values"].shape[0] / b)):
image_embeddings4.append(self.image_encoder4(pixel_values=images4["pixel_values"][i*b: (i+1)*b]).pooler_output.cpu())
image_embeddings4 = torch.cat(image_embeddings4, dim=0)
np.savez(feature_path4, image_embeddings4, duration)
else:
data4 = np.load(feature_path4)
image_embeddings4 = torch.from_numpy(data4["arr_0"])
mixed_l = min([image_embeddings1.shape[0], image_embeddings2.shape[0], image_embeddings3.shape[0], image_embeddings4.shape[0]])
for i in range(mixed_l):
image_embeddings.append(torch.cat([image_embeddings1[i:i+1,:], image_embeddings2[i:i+1,:], image_embeddings3[i:i+1,:], image_embeddings4[i:i+1,:]], dim=1))
else:
raise Exception("Invalid video_encoder " + self.video_encoder)
image_embeddings = torch.cat(image_embeddings, dim=0)
#print("image_embeddings", image_embeddings.shape, image_embeddings.max(), image_embeddings.min(), torch.abs(image_embeddings).mean())
np.savez(feature_path, image_embeddings, duration)
else:
#print("video exist", feature_path)
data = np.load(feature_path)
image_embeddings = torch.from_numpy(data["arr_0"])
#print("image_embeddings", image_embeddings.shape, image_embeddings.max(), image_embeddings.min(), torch.abs(image_embeddings).mean())
duration = data["arr_1"].item()
if max_sample is None:
max_sample = int(duration * self.sampling_rate)
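# Nearest-frame resampling from the video frame rate to the mel frame rate: the centre of each
# frame_size-sample hop is mapped to the closest of the image_embeddings.shape[0] video frames
# spread evenly over `duration` seconds, stopping once l frames have been collected.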
interpolated = []
for i in range(start_sample, max_sample, self.frame_size):
j = min(round((i+self.frame_size//2) / self.sampling_rate / (duration / (image_embeddings.shape[0] - 1))), image_embeddings.shape[0] - 1)
interpolated.append(image_embeddings[j:j+1])
if len(interpolated) >= l:
break
interpolated = torch.cat(interpolated, dim=0)
#ll = list(range(start_sample, max_sample, self.frame_size))
#print("encode_video l", len(ll), l, round((ll[-1]+self.frame_size//2) / self.sampling_rate / (duration / (image_embeddings.shape[0] - 1))), image_embeddings.shape[0] - 1)
#print("encode_video one", video_path, duration, image_embeddings.shape, interpolated.shape, l)
video_embeddings.append(interpolated.unsqueeze(0))
video_lens.append(interpolated.shape[0]) # number of mel-rate frames for this clip
max_length = l # every clip is padded/truncated to the requested mel length
for i in range(len(video_embeddings)):
if video_embeddings[i] is None:
video_embeddings[i] = torch.zeros(1, max_length, d)
continue
if video_embeddings[i].shape[1] < max_length:
video_embeddings[i] = torch.cat([video_embeddings[i], torch.zeros(1, max_length-video_embeddings[i].shape[1], d)], 1)
video_embeddings = torch.cat(video_embeddings, 0)
#print("encode_video", l, video_embeddings.shape, video_lens)
return video_embeddings.to(device)
@staticmethod
def encode_video_frames(video_paths, l):
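# Builds the raw pixel input for the Video2Roll piano-transcription encoder. For clips from the
# piano dataset it loads (or creates and caches) grayscale 100x900 keyboard frames, resamples them
# to roughly one frame per video_multi mel frames, and pads everything to a common length; the
# ground-truth MIDI rolls are loaded instead when the video encoder itself is not being trained.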
#### skip video frames
train_video_encoder = True
if not train_video_encoder:
midi_gts = []
for video_path in video_paths:
if video_path is None:
#midi_gts.append(None)
continue
if isinstance(video_path, tuple):
video_path, start_sample, max_sample = video_path
else:
start_sample = 0
max_sample = None
####if video_path.startswith("/ailab-train2/speech/zhanghaomin/scps/instruments/"):
if "/piano_2h_cropped2_cuts/" in video_path:
pass
else:
#midi_gts.append(None)
continue
####midi_gt
midi_gt = torch.from_numpy(np.load(video_path.replace(".mp4", ".3.npy")).astype(np.float32))[:,NOTTE_MIN:NOTE_MAX+1]
#print("midi_gt", midi_gt.shape, midi_gt.max(), midi_gt.min(), torch.abs(midi_gt).mean())
midi_gts.append(midi_gt.unsqueeze(0))
if len(midi_gts) == 0:
return None, None
max_length = l
for i in range(len(midi_gts)):
if midi_gts[i] is None:
midi_gts[i] = torch.zeros(1, max_length, NOTES)
continue
if midi_gts[i].shape[1] < max_length:
midi_gts[i] = torch.cat([midi_gts[i], torch.zeros(1, max_length-midi_gts[i].shape[1], NOTES)], 1)
elif midi_gts[i].shape[1] > max_length:
midi_gts[i] = midi_gts[i][:, :max_length, :]
midi_gts = torch.cat(midi_gts, 0)
video_frames = 1.0 # placeholder: no pixel frames are returned in this branch, only the MIDI targets
#print("encode_video_frames", l, midi_gts.shape, midi_gts.sum())
return video_frames, midi_gts
video_frames = []
video_lens = []
midi_gts = []
for video_path in video_paths:
if video_path is None:
#video_frames.append(None)
video_lens.append(0)
#midi_gts.append(None)
continue
if isinstance(video_path, tuple):
video_path, start_sample, max_sample = video_path
else:
start_sample = 0
max_sample = None
####if video_path.startswith("/ailab-train2/speech/zhanghaomin/scps/instruments/"):
if "/piano_2h_cropped2_cuts/" in video_path:
frames_raw_path = video_path.replace(".mp4", ".generated_frames_raw.2.npz")
if not os.path.exists(frames_raw_path):
frames, duration = read_frames_with_moviepy(video_path, max_frame_nums=None)
if frames is None:
#video_frames.append(None)
video_lens.append(0)
#midi_gts.append(None)
continue
#print("raw image size", frames.shape, video_path)
frames_resized = []
for i in range(frames.shape[0]):
########frames_resized.append(np.asarray(Image.fromarray(frames[i]).resize((320, 320))))
####frames_resized.append(np.asarray(Image.fromarray(frames[i]).resize((1024, 128))))
input_img = Image.fromarray(frames[i]).convert('L')
binarr = np.array(input_img)
input_img = Image.fromarray(binarr.astype(np.uint8))
frames_resized.append(transform(input_img))
####frames_raw = np.array(frames_resized)
frames_raw = np.concatenate(frames_resized).astype(np.float32)[...,np.newaxis]
np.savez(frames_raw_path, frames_raw, duration)
else:
data = np.load(frames_raw_path)
frames_raw = data["arr_0"]
duration = data["arr_1"].item()
####frames_raw = frames_raw.astype(np.float32) / 255.0
#v_frames_raw = frames_raw[1:,:,:,:] - frames_raw[:-1,:,:,:]
#v_frames_raw = np.concatenate((np.zeros((1,v_frames_raw.shape[1],v_frames_raw.shape[2],v_frames_raw.shape[3]), dtype=np.float32), v_frames_raw), axis=0)
##print("v_frames_raw", v_frames_raw.shape, v_frames_raw.max(), v_frames_raw.min(), np.abs(v_frames_raw).mean(), np.abs(v_frames_raw[0,:,:,:]).mean())
#frames_raw = np.concatenate((frames_raw, v_frames_raw), axis=3)
frames_raw = torch.from_numpy(frames_raw)
#print("frames_raw", frames_raw.shape, frames_raw.max(), frames_raw.min(), torch.abs(frames_raw).mean(), "image_embeddings", image_embeddings.shape, image_embeddings.max(), image_embeddings.min(), torch.abs(image_embeddings).mean())
else:
#video_frames.append(None)
video_lens.append(0)
#midi_gts.append(None)
continue
#print("frames_raw", frames_raw.shape, l)
if max_sample is None:
max_sample = int(duration * 24000)
####video_multi = 3.0
video_multi = 2.5
interpolated_frames_raw = []
frame_size_video = int(video_multi*320)
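# With the hard-coded 24 kHz audio and the 320-sample hop (presumably the mel hop size),
# frame_size_video = 2.5 * 320 = 800 samples, i.e. one sampled video frame every 800 / 24000 s,
# about 30 fps.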
for i in range(start_sample, max_sample+frame_size_video, frame_size_video):
j = min(round(i / 24000 / (duration / (frames_raw.shape[0] - 0))), frames_raw.shape[0] - 1)
#print(j)
interpolated_frames_raw.append(frames_raw[j:j+1])
if len(interpolated_frames_raw) >= math.floor(l/video_multi)+1:
#print("break", len(interpolated_frames_raw), l, frames_raw.shape, j)
break
interpolated_frames_raw = torch.cat(interpolated_frames_raw, dim=0)
####v_interpolated_frames_raw = interpolated_frames_raw[1:,:,:,:] - interpolated_frames_raw[:-1,:,:,:]
####v_interpolated_frames_raw = torch.cat((torch.zeros(1,v_interpolated_frames_raw.shape[1],v_interpolated_frames_raw.shape[2],v_interpolated_frames_raw.shape[3]), v_interpolated_frames_raw), 0)
#####print("v_interpolated_frames_raw", v_interpolated_frames_raw.shape, v_interpolated_frames_raw.max(), v_interpolated_frames_raw.min(), torch.abs(v_interpolated_frames_raw).mean(), torch.abs(v_interpolated_frames_raw[0,:,:,:]).mean())
####interpolated_frames_raw = torch.cat((interpolated_frames_raw, v_interpolated_frames_raw), 3)
video_frames.append(interpolated_frames_raw.unsqueeze(0))
video_lens.append(interpolated_frames_raw.shape[0])
####midi_gt
####midi_gt = torch.from_numpy(np.load(video_path.replace(".mp4", ".3.npy")).astype(np.float32))[:,NOTTE_MIN:NOTE_MAX+1]
#####print("midi_gt", midi_gt.shape, midi_gt.max(), midi_gt.min(), torch.abs(midi_gt).mean())
####midi_gts.append(midi_gt.unsqueeze(0))
midi_gts.append(None)
if len(video_frames) == 0:
return None, None
max_length = l # MIDI targets are padded/truncated to the requested mel length
max_length_video = max(math.floor(l/video_multi)+1, max(video_lens))
for i in range(len(video_frames)):
if video_frames[i] is None:
########video_frames[i] = torch.zeros(1, max_length_video, 320, 320, 6)
####video_frames[i] = torch.zeros(1, max_length_video, 128, 1024, 6)
video_frames[i] = torch.zeros(1, max_length_video, 100, 900, 1)
continue
if video_frames[i].shape[1] < max_length_video:
########video_frames[i] = torch.cat([video_frames[i], torch.zeros(1, max_length_video-video_frames[i].shape[1], 320, 320, 6)], 1)
####video_frames[i] = torch.cat([video_frames[i], torch.zeros(1, max_length_video-video_frames[i].shape[1], 128, 1024, 6)], 1)
video_frames[i] = torch.cat([video_frames[i], torch.zeros(1, max_length_video-video_frames[i].shape[1], 100, 900, 1)], 1)
video_frames = torch.cat(video_frames, 0)
video_frames = video_frames.permute(0,4,1,2,3)
for i in range(len(midi_gts)):
if midi_gts[i] is None:
midi_gts[i] = torch.zeros(1, max_length, NOTES)
continue
if midi_gts[i].shape[1] < max_length:
midi_gts[i] = torch.cat([midi_gts[i], torch.zeros(1, max_length-midi_gts[i].shape[1], NOTES)], 1)
elif midi_gts[i].shape[1] > max_length:
midi_gts[i] = midi_gts[i][:, :max_length, :]
midi_gts = torch.cat(midi_gts, 0)
#print("encode_video_frames", l, video_frames.shape, video_lens, midi_gts.shape, midi_gts.sum())
return video_frames, midi_gts
def transformer_with_pred_head(
self,
x: Float['b n d'],
cond: Float['b n d'] | None = None,
times: Float['b'] | None = None,
mask: Bool['b n'] | None = None,
text: Int['b nt'] | Float['b nt dt'] | None = None,
frames_embed: Float['b nf df'] | None = None,
prompt = None,
video_drop_prompt = None,
audio_drop_prompt = None,
drop_audio_cond: bool | None = None,
drop_text_cond: bool | None = None,
drop_text_prompt: bool | None = None,
return_drop_conditions = False
):
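# Single flow-prediction step through the backbone transformer. The (optionally noised) mel
# condition is either concatenated with or summed into x, the per-frame MIDI/visual embedding and
# the T5 prompt encoding are passed as extra conditioning, and each condition can be dropped
# independently (per sample for the audio and prompt) to enable classifier-free guidance later.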
seq_len = x.shape[-2]
bs = x.shape[0]
drop_audio_cond = [default(drop_audio_cond, self.training and random() < self.audiocond_drop_prob) for _ in range(bs)]
drop_text_cond = default(drop_text_cond, self.training and random() < self.cond_drop_prob)
drop_text_prompt = [default(drop_text_prompt, self.training and random() < self.prompt_drop_prob) for _ in range(bs)]
if cond is not None:
for b in range(bs):
if drop_audio_cond[b]:
cond[b] = 0
if audio_drop_prompt is not None and audio_drop_prompt[b]:
cond[b] = 0
if cond is not None:
if self.concat_cond:
# concat condition, given as using voicebox-like scheme
x = torch.cat((cond, x), dim = -1)
x = self.proj_in(x)
if cond is not None:
if not self.concat_cond:
# an alternative is to simply sum the condition
# seems to work fine
cond = self.cond_proj_in(cond)
x = x + cond
# whether to use a text embedding
text_embed = None
if exists(text) and len(text.shape) == 3:
text_embed = text.clone()
if drop_text_cond:
for b in range(bs):
text_embed[b] = 0
elif exists(text) and not drop_text_cond:
text_embed = self.embed_text(text, seq_len, mask = mask)
context, context_mask = None, None
if prompt is not None:
#for b in range(bs):
# if drop_text_prompt[b]:
# prompt[b] = ""
if video_drop_prompt is not None:
for b in range(bs):
if video_drop_prompt[b]:
prompt[b] = "the sound of X X"
context, context_mask = self.encode_text(prompt)
for b in range(bs):
if drop_text_prompt[b]:
context[b] = 0
if video_drop_prompt is not None and video_drop_prompt[b]:
context[b] = 0
#print("cross attention", context.shape, context_mask.shape, x.shape, mask.shape, text_embed.shape if text_embed is not None else None, torch.mean(torch.abs(text_embed), dim=(1,2)))
#print("video_drop_prompt", prompt, video_drop_prompt, context.shape, torch.mean(torch.abs(context), dim=(1,2)))
#print("audio_drop_prompt", audio_drop_prompt, cond.shape, torch.mean(torch.abs(cond), dim=(1,2)))
if self.proj_text is not None:
text_embed = self.proj_text(text_embed)
frames_embed = self.proj_frames(frames_embed)
# attend
attended = self.transformer(
x,
times = times,
mask = mask,
text_embed = text_embed,
frames_embed = frames_embed,
context = context,
context_mask = context_mask
)
pred = self.to_pred(attended)
if not return_drop_conditions:
return pred
return pred, drop_audio_cond, drop_text_cond, drop_text_prompt
def cfg_transformer_with_pred_head(
self,
*args,
cfg_strength: float = 1.,
remove_parallel_component: bool = True,
keep_parallel_frac: float = 0.,
**kwargs,
):
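# Classifier-free guidance: run one pass with all conditions kept and one with all of them
# dropped, then extrapolate, out = pred + cfg_strength * (pred - null_pred); optionally the
# component of the update parallel to pred is removed (see the reference below), which is
# intended to reduce over-saturation at high guidance strengths.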
pred = self.transformer_with_pred_head(*args, drop_audio_cond = False, drop_text_cond = False, drop_text_prompt = False, **kwargs)
if cfg_strength < 1e-5:
return pred
null_pred = self.transformer_with_pred_head(*args, drop_audio_cond = True, drop_text_cond = True, drop_text_prompt = True, **kwargs)
cfg_update = pred - null_pred
if remove_parallel_component:
# https://arxiv.org/abs/2410.02416
parallel, orthogonal = project(cfg_update, pred)
cfg_update = orthogonal + parallel * keep_parallel_frac
return pred + cfg_update * cfg_strength
def add_noise(self, signal, mask, val):
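# Adds white noise to the audio prompt region at an SNR drawn from self.audiocond_snr (fixed to
# the midpoint during validation); the noise is scaled by the mean absolute level of the signal
# inside `mask`, so the requested SNR is met approximately in linear amplitude terms, not dB.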
if self.audiocond_snr is None:
return signal
if not val:
snr = np.random.uniform(self.audiocond_snr[0], self.audiocond_snr[1])
else:
snr = (self.audiocond_snr[0] + self.audiocond_snr[1]) / 2.0
#print("add_noise", self.audiocond_snr, snr, signal.shape, mask) # [True, ..., False]
noise = torch.randn_like(signal)
w = torch.abs(signal[mask]).mean() / (torch.abs(noise[mask]).mean() + 1e-6) / snr
return signal + noise * w
@torch.no_grad()
def sample(
self,
cond: Float['b n d'] | Float['b nw'] | None = None,
*,
text: Int['b nt'] | list[str] | None = None,
lens: Int['b'] | None = None,
duration: int | Int['b'] | None = None,
steps = 32,
cfg_strength = 1., # they used a classifier free guidance strength of 1.
remove_parallel_component = True,
sway_sampling = True,
max_duration = 4096, # in case the duration predictor goes haywire
vocoder: Callable[[Float['b d n']], list[Float['_']]] | None = None,
return_raw_output: bool | None = None,
save_to_filename: str | None = None,
prompt = None,
video_drop_prompt = None,
audio_drop_prompt = None,
video_paths = None,
frames = None,
midis = None
) -> (
Float['b n d'],
list[Float['_']]
):
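# Inference: encode the mel/visual/MIDI conditions, start from Gaussian noise y0, integrate the
# learned flow field from t = 0 to 1 over `steps` (optionally sway-shifted) time points with
# torchdiffeq.odeint, overwrite the conditioned region with the original mel, and finally decode
# to waveform with Vocos or a user-supplied vocoder.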
self.eval()
# raw wave
if cond.ndim == 2:
cond = self.mel_spec(cond)
cond = rearrange(cond, 'b d n -> b n d')
assert cond.shape[-1] == self.num_channels
batch, cond_seq_len, device = *cond.shape[:2], cond.device
if frames is None:
frames_embed = torch.zeros(batch, cond_seq_len, NOTES, device=device)
else:
#### sampling settings
train_video_encoder = True
if train_video_encoder:
frames_embed = self.encode_frames(frames, cond_seq_len)
else:
frames_embed = midis
if frames_embed.shape[1] < cond_seq_len:
frames_embed = torch.cat([frames_embed, torch.zeros(frames_embed.shape[0], cond_seq_len-frames_embed.shape[1], NOTES, device=frames_embed.device)], 1) # pad with zeros matching the batch size and device
elif frames_embed.shape[1] > cond_seq_len:
frames_embed = frames_embed[:, :cond_seq_len, :]
#x0 = torch.zeros(batch, cond_seq_len, 128, device=device)
print("frames_embed midis cond", frames_embed.shape if frames_embed is not None and not isinstance(frames_embed, float) else frames_embed, frames_embed.sum() if frames_embed is not None and not isinstance(frames_embed, float) else frames_embed, midis.shape if midis is not None else midis, midis.sum() if midis is not None else midis, cond.shape if cond is not None else cond, cond.sum() if cond is not None else cond)
if not exists(lens):
lens = torch.full((batch,), cond_seq_len, device = device, dtype = torch.long)
if video_paths is not None:
text = self.encode_video(video_paths, cond_seq_len)
# text
elif isinstance(text, list):
text = self.tokenizer(text).to(device)
assert text.shape[0] == batch
if exists(text):
text_lens = (text != -1).sum(dim = -1)
lens = torch.maximum(text_lens, lens) # make sure lengths are at least those of the text characters
# duration
cond_mask = lens_to_mask(lens)
if exists(duration):
if isinstance(duration, int):
duration = torch.full((batch,), duration, device = device, dtype = torch.long)
elif exists(self.duration_predictor):
duration = self.duration_predictor(cond, text = text, lens = lens, return_loss = False).long()
duration = torch.maximum(lens, duration) # generated duration is never shorter than the conditioning length
duration = duration.clamp(max = max_duration)
assert duration.shape[0] == batch
max_duration = duration.amax()
cond = F.pad(cond, (0, 0, 0, max_duration - cond_seq_len), value = 0.)
cond_mask = F.pad(cond_mask, (0, max_duration - cond_mask.shape[-1]), value = False)
cond_mask = rearrange(cond_mask, '... -> ... 1')
mask = lens_to_mask(duration)
#print("mask", duration, mask, mask.shape, lens, cond_mask, cond_mask.shape, text)
# neural ode
def fn(t, x):
# at each step, conditioning is fixed
if lens[0] == duration[0]:
print("No cond", lens, duration)
step_cond = None
else:
step_cond = torch.where(cond_mask, self.add_noise(cond, cond_mask, True), torch.zeros_like(cond))
#step_cond = cond
# predict flow
return self.cfg_transformer_with_pred_head(
x,
step_cond,
times = t,
text = text,
frames_embed = frames_embed,
mask = mask,
prompt = prompt,
video_drop_prompt = video_drop_prompt,
audio_drop_prompt = audio_drop_prompt,
cfg_strength = cfg_strength,
remove_parallel_component = remove_parallel_component
)
####torch.manual_seed(0)
y0 = torch.randn_like(cond)
#y0 = torch.randn_like(x0)
t = torch.linspace(0, 1, steps, device = self.device)
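# Sway sampling (as in F5-TTS-style samplers): t <- t - (cos(pi/2 * t) - 1 + t) = 1 - cos(pi/2 * t),
# which bends the uniform schedule towards small t so more ODE steps are spent early in generation.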
if sway_sampling:
t = t + -1.0 * (torch.cos(torch.pi / 2 * t) - 1 + t)
#print("@@@@", t)
trajectory = odeint(fn, y0, t, **self.odeint_kwargs)
sampled = trajectory[-1]
out = sampled
if lens[0] != duration[0]:
out = torch.where(cond_mask, cond, out)
# able to return raw untransformed output, if not using mel rep
if exists(return_raw_output) and return_raw_output:
return out
# take care of transforming mel to audio if `vocoder` is passed in, or if `use_vocos` is turned on
if exists(vocoder):
assert not exists(self.vocos), '`use_vocos` should not be turned on if you are passing in a custom `vocoder` on sampling'
out = rearrange(out, 'b n d -> b d n')
out = vocoder(out)
elif exists(self.vocos):
audio = []
for mel, one_mask in zip(out, mask):
#one_out = DB_to_amplitude(mel[one_mask], ref = 1., power = 0.5)
one_out = mel[one_mask]
one_out = rearrange(one_out, 'n d -> 1 d n')
one_audio = self.vocos.decode(one_out)
one_audio = rearrange(one_audio, '1 nw -> nw')
audio.append(one_audio)
out = audio
if exists(save_to_filename):
assert exists(vocoder) or exists(self.vocos)
assert exists(self.sampling_rate)
path = Path(save_to_filename)
parent_path = path.parents[0]
parent_path.mkdir(exist_ok = True, parents = True)
for ind, one_audio in enumerate(out):
one_audio = rearrange(one_audio, 'nw -> 1 nw')
if len(out) == 1:
save_path = str(parent_path / f'{path.name}')
else:
save_path = str(parent_path / f'{ind + 1}.{path.name}')
torchaudio.save(save_path, one_audio.detach().cpu(), sample_rate = self.sampling_rate)
return out
def forward(
self,
inp: Float['b n d'] | Float['b nw'], # mel or raw wave
*,
text: Int['b nt'] | list[str] | None = None,
times: int | Int['b'] | None = None,
lens: Int['b'] | None = None,
velocity_consistency_model: E2TTS | None = None,
velocity_consistency_delta = 1e-3,
prompt = None,
video_drop_prompt=None,
audio_drop_prompt=None,
val = False,
video_paths=None,
frames=None,
midis=None
):
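# Training/validation step. Builds the flow-matching pair (x0 noise, x1 mel), a random in-fill
# span to condition on, and the video/MIDI conditioning, then combines: flow-matching MSE over the
# masked span, a weighted MIDI regression loss (when the video encoder is trained end to end), and
# an optional velocity-consistency term plus a DPO-style term that is currently commented out.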
need_velocity_loss = exists(velocity_consistency_model) and self.velocity_consistency_weight > 0.
# handle raw wave
if inp.ndim == 2:
inp = self.mel_spec(inp)
inp = rearrange(inp, 'b d n -> b n d')
assert inp.shape[-1] == self.num_channels
batch, seq_len, dtype, device = *inp.shape[:2], inp.dtype, self.device
if video_paths is not None:
text = self.encode_video(video_paths, seq_len)
# handle text as string
elif isinstance(text, list):
text = self.tokenizer(text).to(device)
#print("text tokenized", text[0])
assert text.shape[0] == batch
# lens and mask
if not exists(lens):
lens = torch.full((batch,), seq_len, device = device)
mask = lens_to_mask(lens, length = seq_len)
# get a random span to mask out for training conditionally
if not val:
if self.audiocond_drop_prob > 1.0: # a drop prob above 1 is used as a flag to disable the audio in-fill condition entirely
frac_lengths = torch.zeros((batch,), device = self.device).float().uniform_(1.0,1.0)
else:
frac_lengths = torch.zeros((batch,), device = self.device).float().uniform_(*self.frac_lengths_mask)
else:
frac_lengths = torch.tensor([(0.7+1.0)/2.0]*batch, device = self.device).float()
rand_span_mask = mask_from_frac_lengths(lens, frac_lengths, max_length = seq_len, val = val)
if exists(mask):
rand_span_mask &= mask
# mel is x1
x1 = inp
# main conditional flow training logic
# just ~5 loc
# x0 is gaussian noise
if val:
torch.manual_seed(0)
x0 = torch.randn_like(x1)
if val:
torch.manual_seed(int(time.time()*1000))
# t is random times from above
if times is None:
times = torch.rand((batch,), dtype = dtype, device = self.device)
else:
times = torch.tensor((times,)*batch, dtype = dtype, device = self.device)
t = rearrange(times, 'b -> b 1 1')
# if need velocity consistency, make sure time does not exceed 1.
if need_velocity_loss:
t = t * (1. - velocity_consistency_delta)
# sample xt (w in the paper)
w = (1. - t) * x0 + t * x1
flow = x1 - x0
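# Conditional flow matching with a straight path: w_t = (1 - t) * x0 + t * x1, so the target
# velocity d w_t / dt = x1 - x0 is constant along the path and the network regresses it directly.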
# only predict what is within the random mask span for infilling
if self.audiocond_drop_prob > 1.0:
cond = None
else:
cond = einx.where(
'b n, b n d, b n d -> b n d',
rand_span_mask,
torch.zeros_like(x1), self.add_noise(x1, ~rand_span_mask, val)
)
#### training settings
train_video_encoder = True
train_v2a = True
use_midi_gt = False
train_video_encoder = train_video_encoder
####train_v2a = train_v2a or val
#print("train_video_encoder", train_video_encoder, use_midi_gt, train_v2a)
####
if frames is None:
frames_embed = torch.zeros(batch, seq_len, NOTES, device=device)
midis = torch.zeros(batch, seq_len, NOTES, device=device)
else:
if train_video_encoder:
frames_embed = self.encode_frames(frames, seq_len)
else:
frames_embed = midis
#print("frames_embed midis cond", frames_embed.shape if frames_embed is not None and not isinstance(frames_embed, float) else frames_embed, frames_embed.sum() if frames_embed is not None and not isinstance(frames_embed, float) else frames_embed, midis.shape, midis.sum(), cond.shape if cond is not None else cond, cond.sum() if cond is not None else cond, x1.shape)
if train_video_encoder:
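# MIDI supervision for the video encoder: weighted MSE between the predicted roll (frames_embed)
# and the ground-truth roll; the weight |midis - 0.10| presumably emphasizes bins away from a low
# background activation level of about 0.1 (interpretation, not stated in the source).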
#lw = 1.0
lw = torch.abs(midis-0.10)
#lw = torch.max(torch.abs(midis-0.20), torch.tensor(0.20))
loss_midi = F.mse_loss(frames_embed, midis, reduction = 'none') * lw
#loss_midi = nn.BCELoss(reduction = 'none')(frames_embed, midis) * lw
#print("loss_midi", loss_midi.shape, mask.shape, mask, rand_span_mask.shape, rand_span_mask)
loss_midi = loss_midi[mask[-frames_embed.shape[0]:,...]].mean()
nb, nt, nf = frames_embed.shape # local shape names; b and t already denote the batch size and the flow time tensor
frames_embed_t = frames_embed[:,:(nt//3)*3,:].reshape(nb,nt//3,3,nf).mean(dim=2)
midis_t = midis[:,:(nt//3)*3,:].reshape(nb,nt//3,3,nf).mean(dim=2)
mask_t = mask[-frames_embed.shape[0]:,:(nt//3)*3].reshape(nb,nt//3,3).to(torch.float32).mean(dim=2) >= 0.99
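# Frame-level detection metrics on the 3-frame-pooled rolls: predictions are thresholded at 0.4,
# targets at 0.5. Note that `acc` below is tp / (tp + fp + fn), an intersection-over-union style
# score rather than standard accuracy.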
tp = ((frames_embed_t>=0.4)*(midis_t>=0.5)).to(torch.float)[mask_t[-frames_embed_t.shape[0]:,...]].sum()
fp = ((frames_embed_t>=0.4)*(midis_t<0.5)).to(torch.float)[mask_t[-frames_embed_t.shape[0]:,...]].sum()
fn = ((frames_embed_t<0.4)*(midis_t>=0.5)).to(torch.float)[mask_t[-frames_embed_t.shape[0]:,...]].sum()
tn = ((frames_embed_t<0.4)*(midis_t<0.5)).to(torch.float)[mask_t[-frames_embed_t.shape[0]:,...]].sum()
#print("tp fp fn tn", tp, fp, fn, tn)
pre = tp / (tp + fp) if (tp + fp) != 0 else torch.tensor(0.0, device=device)
rec = tp / (tp + fn) if (tp + fn) != 0 else torch.tensor(0.0, device=device)
f1 = 2 * tp / (2 * tp + fp + fn) if (2 * tp + fp + fn) != 0 else torch.tensor(0.0, device=device)
acc = tp / (tp + fp + fn) if (tp + fp + fn) != 0 else torch.tensor(0.0, device=device)
else:
loss_midi = torch.tensor(0.0, device=device)
tp = torch.tensor(0.0, device=device)
fp = torch.tensor(0.0, device=device)
fn = torch.tensor(0.0, device=device)
tn = torch.tensor(0.0, device=device)
pre = torch.tensor(0.0, device=device)
rec = torch.tensor(0.0, device=device)
f1 = torch.tensor(0.0, device=device)
acc = torch.tensor(0.0, device=device)
#if train_video_encoder: # loss_midi_zeros * 100.0 # 0.2131/0.1856 # 0.2819/0.2417 # 2.451/
# loss_midi_zeros = F.mse_loss(torch.zeros_like(midis), midis, reduction = 'none')
# loss_midi_zeros = loss_midi_zeros[mask[-frames_embed.shape[0]:,...]].mean()
#else:
# loss_midi_zeros = torch.tensor(0.0, device=device)
if train_v2a:
if use_midi_gt:
frames_embed = midis
if frames_embed.shape[0] < x1.shape[0]:
frames_embed = torch.cat((torch.zeros(x1.shape[0]-frames_embed.shape[0],frames_embed.shape[1],frames_embed.shape[2],device=frames_embed.device), frames_embed), 0)
# transformer and prediction head
if not val:
pred, did_drop_audio_cond, did_drop_text_cond, did_drop_text_prompt = self.transformer_with_pred_head(
w,
cond,
times = times,
text = text,
frames_embed = frames_embed,
mask = mask,
prompt = prompt,
video_drop_prompt = video_drop_prompt,
audio_drop_prompt = audio_drop_prompt,
return_drop_conditions = True
)
else:
pred, did_drop_audio_cond, did_drop_text_cond, did_drop_text_prompt = self.transformer_with_pred_head(
w,
cond,
times = times,
text = text,
frames_embed = frames_embed,
mask = mask,
prompt = prompt,
video_drop_prompt = video_drop_prompt,
audio_drop_prompt = audio_drop_prompt,
drop_audio_cond = False,
drop_text_cond = False,
drop_text_prompt = False,
return_drop_conditions = True
)
# maybe velocity consistency loss
velocity_loss = self.zero
if need_velocity_loss:
#t_with_delta = t + velocity_consistency_delta
#w_with_delta = (1. - t_with_delta) * x0 + t_with_delta * x1
with torch.no_grad():
ema_pred = velocity_consistency_model.transformer_with_pred_head(
w, #w_with_delta,
cond,
times = times, #times + velocity_consistency_delta,
text = text,
frames_embed = frames_embed,
mask = mask,
prompt = prompt,
video_drop_prompt = video_drop_prompt,
audio_drop_prompt = audio_drop_prompt,
drop_audio_cond = did_drop_audio_cond,
drop_text_cond = did_drop_text_cond,
drop_text_prompt = did_drop_text_prompt
)
#velocity_loss = F.mse_loss(pred, ema_pred, reduction = 'none')
velocity_loss = F.mse_loss(ema_pred, flow, reduction = 'none')
velocity_loss = (velocity_loss.mean(-1)*rand_span_mask).mean(-1) #.mean()
ref_losses = velocity_loss[-2:, ...]
ref_losses_w, ref_losses_l = ref_losses.chunk(2)
raw_ref_loss = 0.5 * (ref_losses_w.mean() + ref_losses_l.mean())
ref_diff = ref_losses_w - ref_losses_l
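# The last two items in the batch are treated as a preferred / rejected (w / l) pair: their
# reference losses under the EMA model feed the DPO-style objective that is sketched (and
# currently commented out) further down.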
else:
ref_losses_w, ref_losses_l = 0, 0
# flow matching loss
loss = F.mse_loss(pred, flow, reduction = 'none')
#print("loss", loss.shape, loss, "rand_span_mask", rand_span_mask.shape, rand_span_mask, "loss[rand_span_mask]", loss[rand_span_mask].shape, loss[rand_span_mask])
#### dpo
loss = loss[rand_span_mask].mean()
loss_dpo = torch.tensor(0.0, device=device)
####if val:
#### loss = loss[rand_span_mask].mean()
#### loss_dpo = torch.tensor(0.0, device=device)
#### model_losses_w, model_losses_l = 0, 0
####else:
#### loss_fm = loss[rand_span_mask].mean()
#### loss = (loss.mean(-1)*rand_span_mask).mean(-1) #.mean()
####
#### model_losses = loss[-2:, ...]
#### model_losses_w, model_losses_l = model_losses.chunk(2)
#### raw_model_loss = 0.5 * (model_losses_w.mean() + model_losses_l.mean())
#### model_diff = model_losses_w - model_losses_l
####
#### scale_term = -1
#### inside_term = scale_term * (model_diff - ref_diff)
#### loss_dpo = -1 * F.logsigmoid(inside_term).mean()
#### loss = loss_fm
#### dpo
else:
pred = torch.zeros_like(x0)
loss = torch.tensor(0.0, device=device)
# total loss and get breakdown
#midi_w = 100.0
midi_w = 10.0
#total_loss = loss
#total_loss = loss + loss_midi * midi_w
total_loss = loss + loss_midi * midi_w + loss_dpo
####breakdown = LossBreakdown(loss, loss_midi * midi_w, pre, rec)
breakdown = LossBreakdown(pre, rec, f1, acc)
#breakdown = LossBreakdown(tp, fp, fn, tn)
#### dpo
print("loss", loss, loss_midi * midi_w)
#print("loss", loss, loss_midi * midi_w, loss_dpo, model_losses_w, model_losses_l, ref_losses_w, ref_losses_l)
#### dpo
# return total loss and bunch of intermediates
return E2TTSReturn(total_loss, cond if cond is not None else w, pred, x0 + pred, breakdown)