# sadtalker_utils.py
import os
import shutil
import uuid
import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import yaml
from PIL import Image
from skimage import img_as_ubyte, transform
import safetensors
import librosa
from pydub import AudioSegment
import imageio
from scipy import signal
from scipy.io import loadmat, savemat, wavfile
import glob
import tempfile
import tqdm
import math
import torchaudio
import urllib.request
from safetensors.torch import load_file, save_file
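# Pretrained checkpoint sources. Several of the URLs below point at publicly hosted
# weights (ViTPose, wav2vec2-base, Wav2Lip, LivePortrait) that serve as stand-ins for
# the simplified modules defined in this file; load_state_dict_robust() copies only
# the keys whose names and shapes match and reports the rest.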
REALESRGAN_URL = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth"
CODEFORMER_URL = "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth"
RESTOREFORMER_URL = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/RestoreFormer.pth"
GFPGAN_URL = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth"
kp_url = "https://huggingface.co/usyd-community/vitpose-base-simple/resolve/main/model.safetensors"
kp_file = "kp_detector.safetensors"
aud_url = "https://huggingface.co/vinthony/SadTalker/resolve/main/auido2pose_00140-model.pth"
aud_file = "auido2pose_00140-model.pth"
wav_url = "https://huggingface.co/facebook/wav2vec2-base/resolve/main/pytorch_model.bin"
wav_file = "wav2vec2.pth"
gen_url = "https://huggingface.co/vinthony/SadTalker/resolve/main/wav2lip.pth"
gen_file = "generator.pth"
mapx_url = "https://huggingface.co/vinthony/SadTalker/resolve/main/mapping_00229-model.pth.tar"
mapx_file = "mapping.pth"
den_url = "https://huggingface.co/KwaiVGI/LivePortrait/resolve/main/liveportrait/base_models/motion_extractor.pth"
den_file = "dense_motion.pth"
def download_model(url, filename, checkpoint_dir):
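    """Download `url` into `checkpoint_dir/filename` if it is not already present."""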
if not os.path.exists(os.path.join(checkpoint_dir, filename)):
print(f"Downloading {filename}...")
os.makedirs(checkpoint_dir, exist_ok=True)
urllib.request.urlretrieve(url, os.path.join(checkpoint_dir, filename))
print(f"{filename} downloaded.")
else:
print(f"{filename} already exists.")
def mp3_to_wav_util(mp3_filename, wav_filename, frame_rate):
AudioSegment.from_file(mp3_filename).set_frame_rate(frame_rate).export(wav_filename, format="wav")
def load_wav_util(path, sr):
return librosa.core.load(path, sr=sr)[0]
def save_wav_util(wav, path, sr):
wav *= 32767 / max(0.01, np.max(np.abs(wav)))
wavfile.write(path, sr, wav.astype(np.int16))
def load_state_dict_robust(model, checkpoint_path, device, model_name="model"):
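    """Load a checkpoint (.safetensors or torch) into `model`.

    If a strict load fails, fall back to copying only the keys whose names and
    shapes match the model, and report the keys that were skipped.
    """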
if not os.path.exists(checkpoint_path):
raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}")
if checkpoint_path.endswith('safetensors'):
checkpoint = safetensors.torch.load_file(checkpoint_path, device=device)
else:
checkpoint = torch.load(checkpoint_path, map_location=device)
state_dict = checkpoint.get(model_name, checkpoint)
try:
model.load_state_dict(state_dict)
except RuntimeError as e:
print(f"Error loading {model_name} state_dict: {e}")
print(f"Trying to load state_dict with key mapping for {model_name}.")
model_state_dict = model.state_dict()
mapped_state_dict = {}
for key, value in state_dict.items():
if key in model_state_dict and model_state_dict[key].shape == value.shape:
mapped_state_dict[key] = value
else:
print(f"Skipping key {key} due to shape mismatch or missing in model.")
missing_keys, unexpected_keys = model.load_state_dict(mapped_state_dict, strict=False)
if missing_keys or unexpected_keys:
print(f"Missing keys: {missing_keys}")
print(f"Unexpected keys: {unexpected_keys}")
print(f"Successfully loaded {model_name} state_dict with key mapping.")
class OcclusionAwareKPDetector(nn.Module):
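    """Much-reduced keypoint detector: two conv layers that emit a flattened
    keypoint 'value' tensor rather than a full occlusion-aware detector."""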
def __init__(self, kp_channels, num_kp, num_dilation_blocks, dropout_rate):
super(OcclusionAwareKPDetector, self).__init__()
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
self.bn1 = nn.BatchNorm2d(64)
self.relu = nn.ReLU()
self.conv2 = nn.Conv2d(64, num_kp, kernel_size=3, padding=1)
def forward(self, x):
x = self.relu(self.bn1(self.conv1(x)))
x = self.conv2(x)
kp = {'value': x.view(x.size(0), -1)}
return kp
class Wav2Vec2Model(nn.Module):
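    """Tiny audio encoder (Conv1d + mean pooling + Linear) standing in for a real
    wav2vec 2.0 model; maps a raw waveform batch to 2048-d embeddings."""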
def __init__(self):
super(Wav2Vec2Model, self).__init__()
self.conv = nn.Conv1d(1, 64, kernel_size=10, stride=5, padding=5)
self.bn = nn.BatchNorm1d(64)
self.relu = nn.ReLU()
self.fc = nn.Linear(64, 2048)
def forward(self, audio):
x = audio.unsqueeze(1)
x = self.relu(self.bn(self.conv(x)))
x = torch.mean(x, dim=-1)
x = self.fc(x)
return x
class AudioCoeffsPredictor(nn.Module):
def __init__(self, input_dim, output_dim):
super(AudioCoeffsPredictor, self).__init__()
self.linear = nn.Linear(input_dim, output_dim)
def forward(self, audio_embedding):
return self.linear(audio_embedding)
class MappingNet(nn.Module):
def __init__(self, num_coeffs, num_layers, hidden_dim):
super(MappingNet, self).__init__()
layers = []
input_dim = num_coeffs * 2
for _ in range(num_layers):
layers.append(nn.Linear(input_dim, hidden_dim))
layers.append(nn.ReLU())
input_dim = hidden_dim
layers.append(nn.Linear(hidden_dim, num_coeffs))
self.net = nn.Sequential(*layers)
def forward(self, x):
return self.net(x)
class DenseMotionNetwork(nn.Module):
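    """Toy dense-motion module: two convolutions over an image-like 3-channel input,
    returning the result under the 'dense_motion' key."""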
def __init__(self, num_kp, num_channels, block_expansion, num_blocks, max_features):
super(DenseMotionNetwork, self).__init__()
self.conv1 = nn.Conv2d(num_channels, max_features, kernel_size=3, padding=1)
self.relu = nn.ReLU()
self.conv2 = nn.Conv2d(max_features, num_channels, kernel_size=3, padding=1)
def forward(self, kp_source, kp_driving, jacobian):
x = self.relu(self.conv1(kp_source))
x = self.conv2(x)
sparse_motion = {'dense_motion': x}
return sparse_motion
class Hourglass(nn.Module):
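    """Simplified encoder/decoder generator; the forward pass ignores the driving
    keypoints and decodes the source image into a list of 10 identical frames."""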
def __init__(self, block_expansion, num_blocks, max_features, num_channels, kp_size, num_deform_blocks):
super(Hourglass, self).__init__()
self.encoder = nn.Sequential(nn.Conv2d(num_channels, max_features, kernel_size=7, stride=2, padding=3),
nn.BatchNorm2d(max_features), nn.ReLU())
self.decoder = nn.Sequential(
nn.ConvTranspose2d(max_features, num_channels, kernel_size=4, stride=2, padding=1), nn.Tanh())
def forward(self, source_image, kp_driving, **kwargs):
x = self.encoder(source_image)
x = self.decoder(x)
B, C, H, W = x.size()
video = []
for _ in range(10):
frame = (x[0].cpu().detach().numpy().transpose(1, 2, 0) * 127.5 + 127.5).clip(0, 255).astype(
np.uint8)
video.append(frame)
return video
class Face3DHelper:
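    """Placeholder face locator: returns a fixed box covering the central half of
    the image instead of running a detector or 3DMM fit."""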
def __init__(self, local_pca_path, device):
self.local_pca_path = local_pca_path
self.device = device
def run(self, source_image):
h, w, _ = source_image.shape
x_min = w // 4
y_min = h // 4
x_max = x_min + w // 2
y_max = y_min + h // 2
return [x_min, y_min, x_max, y_max]
class MouthDetector:
def __init__(self):
pass
def detect(self, image):
h, w = image.shape[:2]
return (w // 2, h // 2)
class KeypointNorm(nn.Module):
def __init__(self, device):
super(KeypointNorm, self).__init__()
self.device = device
def forward(self, kp_driving):
return kp_driving
def save_video_with_watermark(video_frames, audio_path, output_path):
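    """Write `video_frames` to `output_path` at 25 fps using OpenCV.

    Note: despite the name, no watermark is drawn and `audio_path` is not muxed in.
    """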
H, W, _ = video_frames[0].shape
out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), 25, (W, H))
for frame in video_frames:
out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
out.release()
def paste_pic(video_path, source_image_crop, crop_info, audio_path, output_path):
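    """Placeholder for the original paste-back step: simply copies the video file."""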
shutil.copy(video_path, output_path)
class TTSTalker:
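    """Placeholder text-to-speech front end that writes one second of silence at
    16 kHz to ./results/<uuid>.wav and returns the path."""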
def __init__(self):
self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.tts_model = None
def load_model(self):
self.tts_model = self
def tokenizer(self, text):
return [ord(c) for c in text]
def __call__(self, input_tokens):
return torch.zeros(1, 16000, device=self.device)
def test(self, text, lang='en'):
if self.tts_model is None:
self.load_model()
output_path = os.path.join('./results', str(uuid.uuid4()) + '.wav')
os.makedirs('./results', exist_ok=True)
tokens = self.tokenizer(text)
input_tokens = torch.tensor([tokens], dtype=torch.long).to(self.device)
with torch.no_grad():
audio_output = self(input_tokens)
torchaudio.save(output_path, audio_output.cpu(), 16000)
return output_path
class SadTalker:
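    """Top-level wrapper: builds the config, downloads the checkpoints, and
    delegates inference to SadTalkerModel."""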
def __init__(self, checkpoint_path='checkpoints', config_path='src/config', size=256, preprocess='crop',
old_version=False):
self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.cfg = self.get_cfg_defaults()
self.merge_from_file(os.path.join(config_path, 'sadtalker_config.yaml'))
self.cfg['MODEL']['CHECKPOINTS_DIR'] = checkpoint_path
self.cfg['MODEL']['CONFIG_DIR'] = config_path
self.cfg['MODEL']['DEVICE'] = self.device
self.cfg['INPUT_IMAGE'] = {}
self.cfg['INPUT_IMAGE']['SOURCE_IMAGE'] = 'None'
self.cfg['INPUT_IMAGE']['DRIVEN_AUDIO'] = 'None'
self.cfg['INPUT_IMAGE']['PREPROCESS'] = preprocess
self.cfg['INPUT_IMAGE']['SIZE'] = size
self.cfg['INPUT_IMAGE']['OLD_VERSION'] = old_version
for filename, url in [
(kp_file, kp_url), (aud_file, aud_url), (wav_file, wav_url), (gen_file, gen_url),
(mapx_file, mapx_url), (den_file, den_url), ('GFPGANv1.4.pth', GFPGAN_URL),
('RealESRGAN_x2plus.pth', REALESRGAN_URL)
]:
download_model(url, filename, checkpoint_path)
self.sadtalker_model = SadTalkerModel(self.cfg, device_id=[0])
def get_cfg_defaults(self):
return {
'MODEL': {
'CHECKPOINTS_DIR': '',
'CONFIG_DIR': '',
'DEVICE': self.device,
'SCALE': 64,
'NUM_VOXEL_FRAMES': 8,
'NUM_MOTION_FRAMES': 10,
'MAX_FEATURES': 256,
'DRIVEN_AUDIO_SAMPLE_RATE': 16000,
'VIDEO_FPS': 25,
'OUTPUT_VIDEO_FPS': None,
'OUTPUT_AUDIO_SAMPLE_RATE': None,
'USE_ENHANCER': False,
'ENHANCER_NAME': '',
'BG_UPSAMPLER': None,
'IS_HALF': False
},
'INPUT_IMAGE': {}
}
def merge_from_file(self, filepath):
if os.path.exists(filepath):
with open(filepath, 'r') as f:
cfg_from_file = yaml.safe_load(f)
self.cfg.update(cfg_from_file)
def test(self, source_image, driven_audio, preprocess='crop', still_mode=False, use_enhancer=False,
batch_size=1, size=256, pose_style=0, exp_scale=1.0, use_ref_video=False, ref_video=None,
ref_info=None, use_idle_mode=False, length_of_audio=0, use_blink=True, result_dir='./results/',
tts_text=None, tts_lang='en'):
self.sadtalker_model.test(source_image, driven_audio, preprocess, still_mode, use_enhancer, batch_size, size,
pose_style, exp_scale, use_ref_video, ref_video, ref_info, use_idle_mode,
length_of_audio, use_blink, result_dir, tts_text, tts_lang)
return self.sadtalker_model.save_result()
class SadTalkerModel:
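    """Thin facade over SadTalkerInnerModel exposing test() and save_result()."""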
def __init__(self, sadtalker_cfg, device_id=[0]):
self.cfg = sadtalker_cfg
self.device = sadtalker_cfg['MODEL'].get('DEVICE', 'cpu')
self.sadtalker = SadTalkerInnerModel(sadtalker_cfg, device_id)
self.preprocesser = self.sadtalker.preprocesser
self.kp_extractor = self.sadtalker.kp_extractor
self.generator = self.sadtalker.generator
self.mapping = self.sadtalker.mapping
self.he_estimator = self.sadtalker.he_estimator
self.audio_to_coeff = self.sadtalker.audio_to_coeff
self.animate_from_coeff = self.sadtalker.animate_from_coeff
self.face_enhancer = self.sadtalker.face_enhancer
def test(self, source_image, driven_audio, preprocess='crop', still_mode=False, use_enhancer=False,
batch_size=1, size=256, pose_style=0, exp_scale=1.0, use_ref_video=False, ref_video=None,
ref_info=None, use_idle_mode=False, length_of_audio=0, use_blink=True, result_dir='./results/',
tts_text=None, tts_lang='en', jitter_amount=10, jitter_source_image=False):
self.inner_test = SadTalkerInner(self, source_image, driven_audio, preprocess, still_mode, use_enhancer,
batch_size, size, pose_style, exp_scale, use_ref_video, ref_video, ref_info,
use_idle_mode, length_of_audio, use_blink, result_dir, tts_text, tts_lang,
jitter_amount, jitter_source_image)
return self.inner_test.test()
def save_result(self):
return self.inner_test.save_result()
class SadTalkerInner:
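    """Holds the arguments of a single run and drives the
    preprocess -> inference -> post-processing pipeline."""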
def __init__(self, sadtalker_model, source_image, driven_audio, preprocess, still_mode, use_enhancer,
batch_size, size, pose_style, exp_scale, use_ref_video, ref_video, ref_info, use_idle_mode,
length_of_audio, use_blink, result_dir, tts_text, tts_lang, jitter_amount, jitter_source_image):
self.sadtalker_model = sadtalker_model
self.source_image = source_image
self.driven_audio = driven_audio
self.preprocess = preprocess
self.still_mode = still_mode
self.use_enhancer = use_enhancer
self.batch_size = batch_size
self.size = size
self.pose_style = pose_style
self.exp_scale = exp_scale
self.use_ref_video = use_ref_video
self.ref_video = ref_video
self.ref_info = ref_info
self.use_idle_mode = use_idle_mode
self.length_of_audio = length_of_audio
self.use_blink = use_blink
self.result_dir = result_dir
self.tts_text = tts_text
self.tts_lang = tts_lang
self.jitter_amount = jitter_amount
self.jitter_source_image = jitter_source_image
self.device = self.sadtalker_model.device
self.output_path = None
def get_test_data(self):
proc = self.sadtalker_model.preprocesser
        if self.tts_text is not None:
            temp_dir = tempfile.mkdtemp()
            audio_path = os.path.join(temp_dir, 'audio.wav')
            tts = TTSTalker()
            # TTSTalker.test writes the synthesized audio and returns its path;
            # copy it into the temp dir so post_processing can clean it up later.
            generated_audio_path = tts.test(self.tts_text, self.tts_lang)
            shutil.copy(generated_audio_path, audio_path)
            self.driven_audio = audio_path
source_image_pil = Image.open(self.source_image).convert('RGB')
if self.jitter_source_image:
jitter_dx = np.random.randint(-self.jitter_amount, self.jitter_amount + 1)
jitter_dy = np.random.randint(-self.jitter_amount, self.jitter_amount + 1)
source_image_pil = Image.fromarray(
np.roll(np.roll(np.array(source_image_pil), jitter_dx, axis=1), jitter_dy, axis=0))
source_image_tensor, crop_info, cropped_image = proc.crop(source_image_pil, self.preprocess, self.size)
if self.still_mode or self.use_idle_mode:
ref_pose_coeff = proc.generate_still_pose(self.pose_style)
ref_expression_coeff = proc.generate_still_expression(self.exp_scale)
else:
ref_pose_coeff = None
ref_expression_coeff = None
audio_tensor, audio_sample_rate = proc.process_audio(self.driven_audio,
self.sadtalker_model.cfg['MODEL']['DRIVEN_AUDIO_SAMPLE_RATE'])
batch = {
'source_image': source_image_tensor.unsqueeze(0).to(self.device),
            'audio': audio_tensor.to(self.device),  # already (1, T); Wav2Vec2Model adds the channel dim itself
'ref_pose_coeff': ref_pose_coeff,
'ref_expression_coeff': ref_expression_coeff,
'source_image_crop': cropped_image,
'crop_info': crop_info,
'use_blink': self.use_blink,
'pose_style': self.pose_style,
'exp_scale': self.exp_scale,
'ref_video': self.ref_video,
'use_ref_video': self.use_ref_video,
'ref_info': self.ref_info,
}
return batch, audio_sample_rate
def run_inference(self, batch):
kp_extractor = self.sadtalker_model.kp_extractor
generator = self.sadtalker_model.generator
mapping = self.sadtalker_model.mapping
he_estimator = self.sadtalker_model.he_estimator
audio_to_coeff = self.sadtalker_model.audio_to_coeff
animate_from_coeff = self.sadtalker_model.animate_from_coeff
face_enhancer = self.sadtalker_model.face_enhancer if self.use_enhancer else None
with torch.no_grad():
kp_source = kp_extractor(batch['source_image'])
            if self.still_mode or self.use_idle_mode:
                # Still and idle modes share the same reference-coefficient path.
                ref_pose_coeff = batch['ref_pose_coeff']
                ref_expression_coeff = batch['ref_expression_coeff']
                pose_coeff = audio_to_coeff.get_pose_coeff(batch['audio'], ref_pose_coeff)
                expression_coeff = audio_to_coeff.get_exp_coeff(batch['audio'], ref_expression_coeff)
            else:
if self.use_ref_video:
kp_ref = kp_extractor(batch['source_image'])
pose_coeff = audio_to_coeff.get_pose_coeff(batch['audio'], kp_ref=kp_ref,
use_ref_info=batch['ref_info'])
else:
pose_coeff = audio_to_coeff.get_pose_coeff(batch['audio'])
expression_coeff = audio_to_coeff.get_exp_coeff(batch['audio'])
coeff = {'pose_coeff': pose_coeff, 'expression_coeff': expression_coeff}
if self.use_blink:
coeff['blink_coeff'] = audio_to_coeff.get_blink_coeff(batch['audio'])
else:
coeff['blink_coeff'] = None
kp_driving = audio_to_coeff(batch['audio'])[0]
kp_norm = animate_from_coeff.normalize_kp(kp_driving)
coeff['kp_driving'] = kp_norm
coeff['jacobian'] = [torch.eye(2).unsqueeze(0).unsqueeze(0).to(self.device)] * 4
output_video = animate_from_coeff.generate(batch['source_image'], kp_source, coeff, generator, mapping,
he_estimator, batch['audio'], batch['source_image_crop'],
face_enhancer=face_enhancer)
return output_video
def post_processing(self, output_video, audio_sample_rate, batch):
proc = self.sadtalker_model.preprocesser
        base_name = os.path.splitext(os.path.basename(str(self.source_image)))[0]
audio_name = os.path.splitext(os.path.basename(self.driven_audio))[0]
output_video_path = os.path.join(self.result_dir, base_name + '_' + audio_name + '.mp4')
self.output_path = output_video_path
        cfg_model = self.sadtalker_model.cfg['MODEL']
        video_fps = cfg_model['VIDEO_FPS'] if cfg_model['OUTPUT_VIDEO_FPS'] is None else cfg_model['OUTPUT_VIDEO_FPS']
        audio_output_sample_rate = (cfg_model['DRIVEN_AUDIO_SAMPLE_RATE']
                                    if cfg_model['OUTPUT_AUDIO_SAMPLE_RATE'] is None
                                    else cfg_model['OUTPUT_AUDIO_SAMPLE_RATE'])
if self.use_enhancer:
enhanced_path = os.path.join(self.result_dir, base_name + '_' + audio_name + '_enhanced.mp4')
save_video_with_watermark(output_video, self.driven_audio, enhanced_path)
paste_pic(enhanced_path, batch['source_image_crop'], batch['crop_info'], self.driven_audio,
output_video_path)
os.remove(enhanced_path)
else:
save_video_with_watermark(output_video, self.driven_audio, output_video_path)
if self.tts_text is not None:
shutil.rmtree(os.path.dirname(self.driven_audio))
def save_result(self):
return self.output_path
def __call__(self):
return self.output_path
def test(self):
batch, audio_sample_rate = self.get_test_data()
output_video = self.run_inference(batch)
self.post_processing(output_video, audio_sample_rate, batch)
return self.save_result()
class SadTalkerInnerModel:
def __init__(self, sadtalker_cfg, device_id=[0]):
self.cfg = sadtalker_cfg
self.device = sadtalker_cfg['MODEL'].get('DEVICE', 'cpu')
self.preprocesser = Preprocesser(sadtalker_cfg, self.device)
self.kp_extractor = KeyPointExtractor(sadtalker_cfg, self.device)
self.audio_to_coeff = Audio2Coeff(sadtalker_cfg, self.device)
self.animate_from_coeff = AnimateFromCoeff(sadtalker_cfg, self.device)
self.face_enhancer = FaceEnhancer(sadtalker_cfg, self.device) if sadtalker_cfg['MODEL'][
'USE_ENHANCER'] else None
self.generator = Generator(sadtalker_cfg, self.device)
self.mapping = Mapping(sadtalker_cfg, self.device)
self.he_estimator = OcclusionAwareDenseMotion(sadtalker_cfg, self.device)
class Preprocesser:
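    """Crops and tensorizes the source image, loads the driving audio, and builds
    still/idle pose and expression coefficients."""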
def __init__(self, sadtalker_cfg, device):
self.cfg = sadtalker_cfg
self.device = device
self.face3d_helper = Face3DHelper(self.cfg['INPUT_IMAGE'].get('LOCAL_PCA_PATH', ''), device)
self.mouth_detector = MouthDetector()
def crop(self, source_image_pil, preprocess_type, size=256):
source_image = np.array(source_image_pil)
face_info = self.face3d_helper.run(source_image)
if face_info is None:
raise Exception("No face detected")
x_min, y_min, x_max, y_max = face_info[:4]
old_size = (x_max - x_min, y_max - y_min)
x_center = (x_max + x_min) / 2
y_center = (y_max + y_min) / 2
if preprocess_type == 'crop':
face_size = max(x_max - x_min, y_max - y_min)
x_min = int(x_center - face_size / 2)
y_min = int(y_center - face_size / 2)
x_max = int(x_center + face_size / 2)
y_max = int(y_center + face_size / 2)
else:
            # Expand the detected box symmetrically by 10% on each side.
            pad_x = int((x_max - x_min) * 0.1)
            pad_y = int((y_max - y_min) * 0.1)
            x_min -= pad_x
            y_min -= pad_y
            x_max += pad_x
            y_max += pad_y
h, w = source_image.shape[:2]
x_min = max(0, x_min)
y_min = max(0, y_min)
x_max = min(w, x_max)
y_max = min(h, y_max)
cropped_image = source_image[y_min:y_max, x_min:x_max]
cropped_image_pil = Image.fromarray(cropped_image)
if size is not None and size != 0:
cropped_image_pil = cropped_image_pil.resize((size, size), Image.Resampling.LANCZOS)
source_image_tensor = self.img2tensor(cropped_image_pil)
        return source_image_tensor, [[y_min, y_max], [x_min, x_max], old_size,
                                     cropped_image_pil.size], cropped_image_pil
def img2tensor(self, img):
img = np.array(img).astype(np.float32) / 255.0
img = np.transpose(img, (2, 0, 1))
return torch.FloatTensor(img)
def video_to_tensor(self, video, device):
video_tensor_list = []
import torchvision.transforms as transforms
transform_func = transforms.ToTensor()
for frame in video:
frame_pil = Image.fromarray(frame)
frame_tensor = transform_func(frame_pil).unsqueeze(0).to(device)
video_tensor_list.append(frame_tensor)
video_tensor = torch.cat(video_tensor_list, dim=0)
return video_tensor
def process_audio(self, audio_path, sample_rate):
wav = load_wav_util(audio_path, sample_rate)
wav_tensor = torch.FloatTensor(wav).unsqueeze(0)
return wav_tensor, sample_rate
def generate_still_pose(self, pose_style):
ref_pose_coeff = torch.zeros((1, 64), dtype=torch.float32).to(self.device)
ref_pose_coeff[:, :3] = torch.tensor([0, 0, pose_style * 0.3], dtype=torch.float32)
return ref_pose_coeff
def generate_still_expression(self, exp_scale):
ref_expression_coeff = torch.zeros((1, 64), dtype=torch.float32).to(self.device)
ref_expression_coeff[:, :3] = torch.tensor([0, 0, exp_scale * 0.3], dtype=torch.float32)
return ref_expression_coeff
def generate_idles_pose(self, length_of_audio, pose_style):
num_frames = int(length_of_audio * self.cfg['MODEL']['VIDEO_FPS'])
ref_pose_coeff = torch.zeros((num_frames, 64), dtype=torch.float32).to(self.device)
start_pose = self.generate_still_pose(pose_style)
end_pose = self.generate_still_pose(pose_style)
for frame_idx in range(num_frames):
alpha = frame_idx / num_frames
ref_pose_coeff[frame_idx] = (1 - alpha) * start_pose + alpha * end_pose
return ref_pose_coeff
def generate_idles_expression(self, length_of_audio):
num_frames = int(length_of_audio * self.cfg['MODEL']['VIDEO_FPS'])
ref_expression_coeff = torch.zeros((num_frames, 64), dtype=torch.float32).to(self.device)
start_exp = self.generate_still_expression(1.0)
end_exp = self.generate_still_expression(1.0)
for frame_idx in range(num_frames):
alpha = frame_idx / num_frames
ref_expression_coeff[frame_idx] = (1 - alpha) * start_exp + alpha * end_exp
return ref_expression_coeff
class KeyPointExtractor(nn.Module):
def __init__(self, sadtalker_cfg, device):
super(KeyPointExtractor, self).__init__()
self.kp_extractor = OcclusionAwareKPDetector(kp_channels=sadtalker_cfg['MODEL']['NUM_MOTION_FRAMES'],
num_kp=10,
num_dilation_blocks=2,
dropout_rate=0.1).to(device)
checkpoint_path = os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'], 'kp_detector.safetensors')
load_state_dict_robust(self.kp_extractor, checkpoint_path, device, model_name='kp_detector')
def forward(self, x):
kp = self.kp_extractor(x)
return kp
class Audio2Coeff(nn.Module):
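    """Predicts pose, expression, and blink coefficients from audio via three
    linear heads on top of the Wav2Vec2Model encoder."""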
def __init__(self, sadtalker_cfg, device):
super(Audio2Coeff, self).__init__()
self.audio_model = Wav2Vec2Model().to(device)
checkpoint_path = os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'], 'wav2vec2.pth')
load_state_dict_robust(self.audio_model, checkpoint_path, device, model_name='wav2vec2')
self.pose_mapper = AudioCoeffsPredictor(2048, 64).to(device)
self.exp_mapper = AudioCoeffsPredictor(2048, 64).to(device)
self.blink_mapper = AudioCoeffsPredictor(2048, 1).to(device)
mapping_checkpoint = os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'], 'auido2pose_00140-model.pth')
load_state_dict_robust(self, mapping_checkpoint, device)
def get_pose_coeff(self, audio_tensor, ref_pose_coeff=None, kp_ref=None, use_ref_info=''):
audio_embedding = self.audio_model(audio_tensor)
pose_coeff = self.pose_mapper(audio_embedding)
if ref_pose_coeff is not None:
pose_coeff = ref_pose_coeff
if kp_ref is not None and use_ref_info == 'pose':
ref_pose_6d = kp_ref['value'][:, :6]
pose_coeff[:, :6] = self.mean_std_normalize(ref_pose_6d).mean(dim=1)
return pose_coeff
def get_exp_coeff(self, audio_tensor, ref_expression_coeff=None):
audio_embedding = self.audio_model(audio_tensor)
expression_coeff = self.exp_mapper(audio_embedding)
if ref_expression_coeff is not None:
expression_coeff = ref_expression_coeff
return expression_coeff
def get_blink_coeff(self, audio_tensor):
audio_embedding = self.audio_model(audio_tensor)
blink_coeff = self.blink_mapper(audio_embedding)
return blink_coeff
def forward(self, audio):
audio_embedding = self.audio_model(audio)
pose_coeff, expression_coeff, blink_coeff = self.pose_mapper(audio_embedding), self.exp_mapper(
audio_embedding), self.blink_mapper(audio_embedding)
return pose_coeff, expression_coeff, blink_coeff
def mean_std_normalize(self, coeff):
mean = coeff.mean(dim=1, keepdim=True)
std = coeff.std(dim=1, keepdim=True)
return (coeff - mean) / std
class AnimateFromCoeff(nn.Module):
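    """Renders output frames from the source image, keypoints, and predicted
    coefficients, optionally enhancing each frame with the face enhancer."""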
def __init__(self, sadtalker_cfg, device):
super(AnimateFromCoeff, self).__init__()
self.generator = Generator(sadtalker_cfg, device)
self.mapping = Mapping(sadtalker_cfg, device)
self.kp_norm = KeypointNorm(device=device)
self.he_estimator = OcclusionAwareDenseMotion(sadtalker_cfg, device)
def normalize_kp(self, kp_driving):
return self.kp_norm(kp_driving)
def generate(self, source_image, kp_source, coeff, generator, mapping, he_estimator, audio, source_image_crop,
face_enhancer=None):
kp_driving = coeff['kp_driving']
jacobian = coeff['jacobian']
pose_coeff = coeff['pose_coeff']
expression_coeff = coeff['expression_coeff']
blink_coeff = coeff['blink_coeff']
face_3d = mapping(expression_coeff, pose_coeff, blink_coeff) if blink_coeff is not None else mapping(expression_coeff, pose_coeff)
        # The simplified DenseMotionNetwork convolves a 3-channel spatial tensor, so it
        # is fed the source image rather than the keypoint dict from kp_extractor.
        sparse_motion = he_estimator(source_image, kp_driving, jacobian)
dense_motion = sparse_motion['dense_motion']
video_deocclusion = generator(source_image, dense_motion, bg_param={'mask': None, 'color': None})
video_3d = generator(source_image, dense_motion, bg_param={'mask': None, 'color': None}, face_3d_param=face_3d)
video_output = video_deocclusion['video_no_reocclusion'] + video_3d['video_3d']
if face_enhancer is not None:
video_output_enhanced = []
            for frame in tqdm.tqdm(video_output, desc='Face enhancer running'):
pil_image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
enhanced_image = face_enhancer.forward(np.array(pil_image))
video_output_enhanced.append(cv2.cvtColor(enhanced_image, cv2.COLOR_BGR2RGB))
video_output = video_output_enhanced
return video_output
def make_animation(self, video_array):
H, W, _ = video_array[0].shape
out = cv2.VideoWriter('./tmp.mp4', cv2.VideoWriter_fourcc(*'mp4v'), 25, (W, H))
for img in video_array:
out.write(cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
out.release()
video = imageio.mimread('./tmp.mp4')
os.remove('./tmp.mp4')
return video
class Generator(nn.Module):
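    """Wraps the Hourglass generator; returns the same frame list under the
    'video_3d' and 'video_no_reocclusion' keys."""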
def __init__(self, sadtalker_cfg, device):
super(Generator, self).__init__()
self.generator = Hourglass(block_expansion=sadtalker_cfg['MODEL']['SCALE'],
num_blocks=sadtalker_cfg['MODEL']['NUM_VOXEL_FRAMES'],
max_features=sadtalker_cfg['MODEL']['MAX_FEATURES'],
num_channels=3,
kp_size=10,
num_deform_blocks=sadtalker_cfg['MODEL']['NUM_MOTION_FRAMES']).to(device)
checkpoint_path = os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'], 'generator.pth')
load_state_dict_robust(self.generator, checkpoint_path, device, model_name='generator')
def forward(self, source_image, dense_motion, bg_param, face_3d_param=None):
video_3d = self.generator(source_image, kp_driving=dense_motion, bg_param=bg_param, face_3d_param=face_3d_param)
return {'video_3d': video_3d, 'video_no_reocclusion': video_3d}
class Mapping(nn.Module):
def __init__(self, sadtalker_cfg, device):
super(Mapping, self).__init__()
self.mapping_net = MappingNet(num_coeffs=64, num_layers=3, hidden_dim=128).to(device)
checkpoint_path = os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'], 'mapping.pth')
load_state_dict_robust(self.mapping_net, checkpoint_path, device, model_name='mapping')
self.f_3d_mean = torch.zeros(1, 64, device=device)
def forward(self, expression_coeff, pose_coeff, blink_coeff=None):
coeff = torch.cat([expression_coeff, pose_coeff], dim=1)
face_3d = self.mapping_net(coeff) + self.f_3d_mean
if blink_coeff is not None:
face_3d[:, -1:] = blink_coeff
return face_3d
class OcclusionAwareDenseMotion(nn.Module):
def __init__(self, sadtalker_cfg, device):
super(OcclusionAwareDenseMotion, self).__init__()
self.dense_motion_network = DenseMotionNetwork(num_kp=10,
num_channels=3,
block_expansion=sadtalker_cfg['MODEL']['SCALE'],
num_blocks=sadtalker_cfg['MODEL']['NUM_MOTION_FRAMES'] - 1,
max_features=sadtalker_cfg['MODEL']['MAX_FEATURES']).to(device)
checkpoint_path = os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'], 'dense_motion.pth')
load_state_dict_robust(self.dense_motion_network, checkpoint_path, device, model_name='dense_motion')
def forward(self, kp_source, kp_driving, jacobian):
sparse_motion = self.dense_motion_network(kp_source, kp_driving, jacobian)
return sparse_motion
class FaceEnhancer(nn.Module):
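    """Optional face restoration wrapper around GFPGAN or Real-ESRGAN, selected by
    MODEL.ENHANCER_NAME; acts as a pass-through when neither is configured."""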
def __init__(self, sadtalker_cfg, device):
super(FaceEnhancer, self).__init__()
        enhancer_name = sadtalker_cfg['MODEL']['ENHANCER_NAME']
        self.enhancer_name = enhancer_name
bg_upsampler = sadtalker_cfg['MODEL']['BG_UPSAMPLER']
if enhancer_name == 'gfpgan':
from gfpgan import GFPGANer
self.face_enhancer = GFPGANer(model_path=os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'], 'GFPGANv1.4.pth'),
upscale=1,
arch='clean',
channel_multiplier=2,
bg_upsampler=bg_upsampler)
elif enhancer_name == 'realesrgan':
from realesrgan import RealESRGANer
half = False if device == 'cpu' else sadtalker_cfg['MODEL']['IS_HALF']
self.face_enhancer = RealESRGANer(scale=2,
model_path=os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'],
'RealESRGAN_x2plus.pth'),
tile=0,
tile_pad=10,
pre_pad=0,
half=half,
device=device)
else:
self.face_enhancer = None
    def forward(self, x):
        if self.face_enhancer is None:
            return x
        if self.enhancer_name == 'gfpgan':
            # GFPGANer.enhance returns (cropped_faces, restored_faces, restored_img)
            return self.face_enhancer.enhance(x, paste_back=True)[2]
        # RealESRGANer.enhance returns (output, img_mode); keep the upscaled image
        return self.face_enhancer.enhance(x, outscale=1)[0]
def load_models():
checkpoint_path = './checkpoints'
config_path = './src/config'
size = 256
preprocess = 'crop'
old_version = False
sadtalker_instance = SadTalker(checkpoint_path, config_path, size, preprocess, old_version)
print("SadTalker models loaded successfully!")
return sadtalker_instance
if __name__ == '__main__':
sadtalker_instance = load_models()
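    # Hypothetical usage sketch (the paths below are placeholders, so it is left
    # commented out rather than executed):
    # result_path = sadtalker_instance.test(
    #     source_image='examples/source_image.png',
    #     driven_audio='examples/driven_audio.wav',
    #     preprocess='crop', still_mode=False, use_enhancer=False,
    #     batch_size=1, size=256, result_dir='./results/')
    # print('Result saved to:', result_path)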