import os
import shutil
import uuid
import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import yaml
from PIL import Image
from skimage import img_as_ubyte, transform
import safetensors
import librosa
from pydub import AudioSegment
import imageio
from scipy import signal
from scipy.io import loadmat, savemat, wavfile
import glob
import tempfile
from tqdm import tqdm  # tqdm is called as a function below, so import the function rather than the module
import math
import torchaudio
import urllib.request
from safetensors.torch import load_file, save_file

# Pretrained weights: face enhancer/upscaler checkpoints and the SadTalker sub-models.
REALESRGAN_URL = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth"
CODEFORMER_URL = "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth"
RESTOREFORMER_URL = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/RestoreFormer.pth"
GFPGAN_URL = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth"

kp_url = "https://huggingface.co/usyd-community/vitpose-base-simple/resolve/main/model.safetensors"
kp_file = "kp_detector.safetensors"
aud_url = "https://huggingface.co/vinthony/SadTalker/resolve/main/auido2pose_00140-model.pth"
aud_file = "auido2pose_00140-model.pth"
wav_url = "https://huggingface.co/facebook/wav2vec2-base/resolve/main/pytorch_model.bin"
wav_file = "wav2vec2.pth"
gen_url = "https://huggingface.co/vinthony/SadTalker/resolve/main/wav2lip.pth"
gen_file = "generator.pth"
mapx_url = "https://huggingface.co/vinthony/SadTalker/resolve/main/mapping_00229-model.pth.tar"
mapx_file = "mapping.pth"
den_url = "https://huggingface.co/KwaiVGI/LivePortrait/resolve/main/liveportrait/base_models/motion_extractor.pth"
den_file = "dense_motion.pth"


def download_model(url, filename, checkpoint_dir):
    """Download a checkpoint into checkpoint_dir unless it is already present."""
    target = os.path.join(checkpoint_dir, filename)
    if not os.path.exists(target):
        print(f"Downloading {filename}...")
        os.makedirs(checkpoint_dir, exist_ok=True)
        urllib.request.urlretrieve(url, target)
        print(f"{filename} downloaded.")
    else:
        print(f"{filename} already exists.")


def mp3_to_wav_util(mp3_filename, wav_filename, frame_rate):
    AudioSegment.from_file(mp3_filename).set_frame_rate(frame_rate).export(wav_filename, format="wav")


def load_wav_util(path, sr):
    return librosa.core.load(path, sr=sr)[0]


def save_wav_util(wav, path, sr):
    wav *= 32767 / max(0.01, np.max(np.abs(wav)))
    wavfile.write(path, sr, wav.astype(np.int16))


def load_state_dict_robust(model, checkpoint_path, device, model_name="model"):
    """Load a checkpoint, falling back to a key-by-key, shape-checked partial load."""
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}")
    if checkpoint_path.endswith('safetensors'):
        checkpoint = safetensors.torch.load_file(checkpoint_path, device=device)
    else:
        checkpoint = torch.load(checkpoint_path, map_location=device)
    state_dict = checkpoint.get(model_name, checkpoint)
    try:
        model.load_state_dict(state_dict)
    except RuntimeError as e:
        print(f"Error loading {model_name} state_dict: {e}")
        print(f"Trying to load state_dict with key mapping for {model_name}.")
        model_state_dict = model.state_dict()
        mapped_state_dict = {}
        for key, value in state_dict.items():
            if key in model_state_dict and model_state_dict[key].shape == value.shape:
                mapped_state_dict[key] = value
            else:
                print(f"Skipping key {key} due to shape mismatch or missing in model.")
        missing_keys, unexpected_keys = model.load_state_dict(mapped_state_dict, strict=False)
        if missing_keys or unexpected_keys:
            print(f"Missing keys: {missing_keys}")
            print(f"Unexpected keys: {unexpected_keys}")
        print(f"Successfully loaded {model_name} state_dict with key mapping.")

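# Illustrative usage of the helpers above (not executed here): fetch a checkpoint and
# load it into a model, falling back to a partial load on key/shape mismatch. The
# module and the model_name key below are placeholders chosen for the sketch.
#
#   download_model(GFPGAN_URL, 'GFPGANv1.4.pth', './checkpoints')
#   my_model = nn.Linear(10, 10)  # placeholder module
#   load_state_dict_robust(my_model, './checkpoints/GFPGANv1.4.pth', 'cpu', model_name='model')
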
class OcclusionAwareKPDetector(nn.Module):
    """Minimal keypoint detector stub: two conv layers producing a flattened keypoint map."""

    def __init__(self, kp_channels, num_kp, num_dilation_blocks, dropout_rate):
        super(OcclusionAwareKPDetector, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(64, num_kp, kernel_size=3, padding=1)

    def forward(self, x):
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.conv2(x)
        return {'value': x.view(x.size(0), -1)}


class Wav2Vec2Model(nn.Module):
    """Lightweight stand-in for a wav2vec2-style audio encoder producing a 2048-d embedding."""

    def __init__(self):
        super(Wav2Vec2Model, self).__init__()
        self.conv = nn.Conv1d(1, 64, kernel_size=10, stride=5, padding=5)
        self.bn = nn.BatchNorm1d(64)
        self.relu = nn.ReLU()
        self.fc = nn.Linear(64, 2048)

    def forward(self, audio):
        x = audio.unsqueeze(1)
        x = self.relu(self.bn(self.conv(x)))
        x = torch.mean(x, dim=-1)
        return self.fc(x)


class AudioCoeffsPredictor(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(AudioCoeffsPredictor, self).__init__()
        self.linear = nn.Linear(input_dim, output_dim)

    def forward(self, audio_embedding):
        return self.linear(audio_embedding)


class MappingNet(nn.Module):
    """MLP that maps concatenated (expression, pose) coefficients to 3D face coefficients."""

    def __init__(self, num_coeffs, num_layers, hidden_dim):
        super(MappingNet, self).__init__()
        layers = []
        input_dim = num_coeffs * 2
        for _ in range(num_layers):
            layers.append(nn.Linear(input_dim, hidden_dim))
            layers.append(nn.ReLU())
            input_dim = hidden_dim
        layers.append(nn.Linear(hidden_dim, num_coeffs))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)


class DenseMotionNetwork(nn.Module):
    def __init__(self, num_kp, num_channels, block_expansion, num_blocks, max_features):
        super(DenseMotionNetwork, self).__init__()
        self.conv1 = nn.Conv2d(num_channels, max_features, kernel_size=3, padding=1)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(max_features, num_channels, kernel_size=3, padding=1)

    def forward(self, kp_source, kp_driving, jacobian):
        # Note: this stub's conv layers expect an image-like (B, num_channels, H, W) tensor
        # for kp_source; kp_driving and jacobian are currently unused.
        x = self.relu(self.conv1(kp_source))
        x = self.conv2(x)
        return {'dense_motion': x}


class Hourglass(nn.Module):
    """Encoder-decoder generator stub that emits a short list of identical frames."""

    def __init__(self, block_expansion, num_blocks, max_features, num_channels, kp_size, num_deform_blocks):
        super(Hourglass, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(num_channels, max_features, kernel_size=7, stride=2, padding=3),
            nn.BatchNorm2d(max_features),
            nn.ReLU())
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(max_features, num_channels, kernel_size=4, stride=2, padding=1),
            nn.Tanh())

    def forward(self, source_image, kp_driving, **kwargs):
        x = self.encoder(source_image)
        x = self.decoder(x)
        video = []
        for _ in range(10):
            frame = (x[0].cpu().detach().numpy().transpose(1, 2, 0) * 127.5 + 127.5).clip(0, 255).astype(np.uint8)
            video.append(frame)
        return video


class Face3DHelper:
    """Placeholder face detector: returns a centered box covering half of the image."""

    def __init__(self, local_pca_path, device):
        self.local_pca_path = local_pca_path
        self.device = device

    def run(self, source_image):
        h, w, _ = source_image.shape
        x_min = w // 4
        y_min = h // 4
        x_max = x_min + w // 2
        y_max = y_min + h // 2
        return [x_min, y_min, x_max, y_max]


class MouthDetector:
    def __init__(self):
        pass

    def detect(self, image):
        h, w = image.shape[:2]
        return (w // 2, h // 2)


class KeypointNorm(nn.Module):
    def __init__(self, device):
        super(KeypointNorm, self).__init__()
        self.device = device

    def forward(self, kp_driving):
        return kp_driving


def save_video_with_watermark(video_frames, audio_path, output_path):
    H, W, _ = video_frames[0].shape
    out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), 25, (W, H))
    for frame in video_frames:
        out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
    out.release()

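# Note: save_video_with_watermark above writes a silent video; audio_path is accepted but
# not muxed in. A typical way to attach the driving audio afterwards is an ffmpeg call
# like the hedged sketch below (requires ffmpeg on PATH; file names are placeholders):
#
#   import subprocess
#   subprocess.run(['ffmpeg', '-y', '-i', 'silent.mp4', '-i', 'audio.wav',
#                   '-c:v', 'copy', '-c:a', 'aac', '-shortest', 'with_audio.mp4'],
#                  check=True)
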
def paste_pic(video_path, source_image_crop, crop_info, audio_path, output_path):
    # Stub: the full pipeline would paste the generated face back into the original
    # frame; here the cropped-face video is simply copied to the output path.
    shutil.copy(video_path, output_path)


class TTSTalker:
    """Placeholder text-to-speech front end that emits one second of silence."""

    def __init__(self):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.tts_model = None

    def load_model(self):
        self.tts_model = self

    def tokenizer(self, text):
        return [ord(c) for c in text]

    def __call__(self, input_tokens):
        return torch.zeros(1, 16000, device=self.device)

    def test(self, text, lang='en'):
        if self.tts_model is None:
            self.load_model()
        output_path = os.path.join('./results', str(uuid.uuid4()) + '.wav')
        os.makedirs('./results', exist_ok=True)
        tokens = self.tokenizer(text)
        input_tokens = torch.tensor([tokens], dtype=torch.long).to(self.device)
        with torch.no_grad():
            audio_output = self(input_tokens)
        torchaudio.save(output_path, audio_output.cpu(), 16000)
        return output_path


class SadTalker:
    def __init__(self, checkpoint_path='checkpoints', config_path='src/config', size=256,
                 preprocess='crop', old_version=False):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.cfg = self.get_cfg_defaults()
        self.merge_from_file(os.path.join(config_path, 'sadtalker_config.yaml'))
        self.cfg['MODEL']['CHECKPOINTS_DIR'] = checkpoint_path
        self.cfg['MODEL']['CONFIG_DIR'] = config_path
        self.cfg['MODEL']['DEVICE'] = self.device
        self.cfg['INPUT_IMAGE'] = {}
        self.cfg['INPUT_IMAGE']['SOURCE_IMAGE'] = 'None'
        self.cfg['INPUT_IMAGE']['DRIVEN_AUDIO'] = 'None'
        self.cfg['INPUT_IMAGE']['PREPROCESS'] = preprocess
        self.cfg['INPUT_IMAGE']['SIZE'] = size
        self.cfg['INPUT_IMAGE']['OLD_VERSION'] = old_version
        # Fetch every checkpoint the pipeline needs before building the sub-models.
        for filename, url in [
            (kp_file, kp_url),
            (aud_file, aud_url),
            (wav_file, wav_url),
            (gen_file, gen_url),
            (mapx_file, mapx_url),
            (den_file, den_url),
            ('GFPGANv1.4.pth', GFPGAN_URL),
            ('RealESRGAN_x2plus.pth', REALESRGAN_URL),
        ]:
            download_model(url, filename, checkpoint_path)
        self.sadtalker_model = SadTalkerModel(self.cfg, device_id=[0])

    def get_cfg_defaults(self):
        return {
            'MODEL': {
                'CHECKPOINTS_DIR': '',
                'CONFIG_DIR': '',
                'DEVICE': self.device,
                'SCALE': 64,
                'NUM_VOXEL_FRAMES': 8,
                'NUM_MOTION_FRAMES': 10,
                'MAX_FEATURES': 256,
                'DRIVEN_AUDIO_SAMPLE_RATE': 16000,
                'VIDEO_FPS': 25,
                'OUTPUT_VIDEO_FPS': None,
                'OUTPUT_AUDIO_SAMPLE_RATE': None,
                'USE_ENHANCER': False,
                'ENHANCER_NAME': '',
                'BG_UPSAMPLER': None,
                'IS_HALF': False
            },
            'INPUT_IMAGE': {}
        }

    def merge_from_file(self, filepath):
        if os.path.exists(filepath):
            with open(filepath, 'r') as f:
                cfg_from_file = yaml.safe_load(f)
            if cfg_from_file:
                self.cfg.update(cfg_from_file)

    def test(self, source_image, driven_audio, preprocess='crop', still_mode=False, use_enhancer=False,
             batch_size=1, size=256, pose_style=0, exp_scale=1.0, use_ref_video=False, ref_video=None,
             ref_info=None, use_idle_mode=False, length_of_audio=0, use_blink=True, result_dir='./results/',
             tts_text=None, tts_lang='en'):
        self.sadtalker_model.test(source_image, driven_audio, preprocess, still_mode, use_enhancer, batch_size,
                                  size, pose_style, exp_scale, use_ref_video, ref_video, ref_info, use_idle_mode,
                                  length_of_audio, use_blink, result_dir, tts_text, tts_lang)
        return self.sadtalker_model.save_result()


class SadTalkerModel:
    def __init__(self, sadtalker_cfg, device_id=[0]):
        self.cfg = sadtalker_cfg
        self.device = sadtalker_cfg['MODEL'].get('DEVICE', 'cpu')
        self.sadtalker = SadTalkerInnerModel(sadtalker_cfg, device_id)
        self.preprocesser = self.sadtalker.preprocesser
        self.kp_extractor = self.sadtalker.kp_extractor
        self.generator = self.sadtalker.generator
        self.mapping = self.sadtalker.mapping
        self.he_estimator = self.sadtalker.he_estimator
        self.audio_to_coeff = self.sadtalker.audio_to_coeff
        self.animate_from_coeff = self.sadtalker.animate_from_coeff
        self.face_enhancer = self.sadtalker.face_enhancer

    def test(self, source_image, driven_audio, preprocess='crop', still_mode=False, use_enhancer=False,
             batch_size=1, size=256, pose_style=0, exp_scale=1.0, use_ref_video=False, ref_video=None,
             ref_info=None, use_idle_mode=False, length_of_audio=0, use_blink=True, result_dir='./results/',
             tts_text=None, tts_lang='en', jitter_amount=10, jitter_source_image=False):
        self.inner_test = SadTalkerInner(self, source_image, driven_audio, preprocess, still_mode, use_enhancer,
                                         batch_size, size, pose_style, exp_scale, use_ref_video, ref_video,
                                         ref_info, use_idle_mode, length_of_audio, use_blink, result_dir,
                                         tts_text, tts_lang, jitter_amount, jitter_source_image)
        return self.inner_test.test()

    def save_result(self):
        return self.inner_test.save_result()


class SadTalkerInner:
    """Runs one end-to-end generation: preprocessing, inference, and post-processing."""

    def __init__(self, sadtalker_model, source_image, driven_audio, preprocess, still_mode, use_enhancer,
                 batch_size, size, pose_style, exp_scale, use_ref_video, ref_video, ref_info, use_idle_mode,
                 length_of_audio, use_blink, result_dir, tts_text, tts_lang, jitter_amount, jitter_source_image):
        self.sadtalker_model = sadtalker_model
        self.source_image = source_image
        self.driven_audio = driven_audio
        self.preprocess = preprocess
        self.still_mode = still_mode
        self.use_enhancer = use_enhancer
        self.batch_size = batch_size
        self.size = size
        self.pose_style = pose_style
        self.exp_scale = exp_scale
        self.use_ref_video = use_ref_video
        self.ref_video = ref_video
        self.ref_info = ref_info
        self.use_idle_mode = use_idle_mode
        self.length_of_audio = length_of_audio
        self.use_blink = use_blink
        self.result_dir = result_dir
        self.tts_text = tts_text
        self.tts_lang = tts_lang
        self.jitter_amount = jitter_amount
        self.jitter_source_image = jitter_source_image
        self.device = self.sadtalker_model.device
        self.output_path = None

    def get_test_data(self):
        proc = self.sadtalker_model.preprocesser
        if self.tts_text is not None:
            # Synthesize the driving audio from text and stage it in a temporary directory.
            temp_dir = tempfile.mkdtemp()
            audio_path = os.path.join(temp_dir, 'audio.wav')
            tts = TTSTalker()
            generated_audio = tts.test(self.tts_text, self.tts_lang)
            shutil.move(generated_audio, audio_path)
            self.driven_audio = audio_path
        source_image_pil = Image.open(self.source_image).convert('RGB')
        if self.jitter_source_image:
            jitter_dx = np.random.randint(-self.jitter_amount, self.jitter_amount + 1)
            jitter_dy = np.random.randint(-self.jitter_amount, self.jitter_amount + 1)
            source_image_pil = Image.fromarray(
                np.roll(np.roll(np.array(source_image_pil), jitter_dx, axis=1), jitter_dy, axis=0))
        source_image_tensor, crop_info, cropped_image = proc.crop(source_image_pil, self.preprocess, self.size)
        if self.still_mode or self.use_idle_mode:
            ref_pose_coeff = proc.generate_still_pose(self.pose_style)
            ref_expression_coeff = proc.generate_still_expression(self.exp_scale)
        else:
            ref_pose_coeff = None
            ref_expression_coeff = None
        audio_tensor, audio_sample_rate = proc.process_audio(
            self.driven_audio, self.sadtalker_model.cfg['MODEL']['DRIVEN_AUDIO_SAMPLE_RATE'])
        batch = {
            'source_image': source_image_tensor.unsqueeze(0).to(self.device),
            'audio': audio_tensor.unsqueeze(0).to(self.device),
            'ref_pose_coeff': ref_pose_coeff,
            'ref_expression_coeff': ref_expression_coeff,
            'source_image_crop': cropped_image,
            'crop_info': crop_info,
            'use_blink': self.use_blink,
            'pose_style': self.pose_style,
            'exp_scale': self.exp_scale,
            'ref_video': self.ref_video,
            'use_ref_video': self.use_ref_video,
            'ref_info': self.ref_info,
        }
        return batch, audio_sample_rate

    def run_inference(self, batch):
        kp_extractor = self.sadtalker_model.kp_extractor
        generator = self.sadtalker_model.generator
        mapping = self.sadtalker_model.mapping
        he_estimator = self.sadtalker_model.he_estimator
        audio_to_coeff = self.sadtalker_model.audio_to_coeff
        animate_from_coeff = self.sadtalker_model.animate_from_coeff
        face_enhancer = self.sadtalker_model.face_enhancer if self.use_enhancer else None
        with torch.no_grad():
            kp_source = kp_extractor(batch['source_image'])
            if self.still_mode or self.use_idle_mode:
                # Still and idle modes both drive the head with the precomputed reference coefficients.
                pose_coeff = audio_to_coeff.get_pose_coeff(batch['audio'], batch['ref_pose_coeff'])
                expression_coeff = audio_to_coeff.get_exp_coeff(batch['audio'], batch['ref_expression_coeff'])
            else:
                if self.use_ref_video:
                    kp_ref = kp_extractor(batch['source_image'])
                    pose_coeff = audio_to_coeff.get_pose_coeff(batch['audio'], kp_ref=kp_ref,
                                                               use_ref_info=batch['ref_info'])
                else:
                    pose_coeff = audio_to_coeff.get_pose_coeff(batch['audio'])
                expression_coeff = audio_to_coeff.get_exp_coeff(batch['audio'])
            coeff = {'pose_coeff': pose_coeff, 'expression_coeff': expression_coeff}
            if self.use_blink:
                coeff['blink_coeff'] = audio_to_coeff.get_blink_coeff(batch['audio'])
            else:
                coeff['blink_coeff'] = None
            kp_driving = audio_to_coeff(batch['audio'])[0]
            kp_norm = animate_from_coeff.normalize_kp(kp_driving)
            coeff['kp_driving'] = kp_norm
            coeff['jacobian'] = [torch.eye(2).unsqueeze(0).unsqueeze(0).to(self.device)] * 4
            output_video = animate_from_coeff.generate(batch['source_image'], kp_source, coeff, generator,
                                                       mapping, he_estimator, batch['audio'],
                                                       batch['source_image_crop'], face_enhancer=face_enhancer)
        return output_video

    def post_processing(self, output_video, audio_sample_rate, batch):
        base_name = os.path.splitext(os.path.basename(batch['source_image_crop']))[0]
        audio_name = os.path.splitext(os.path.basename(self.driven_audio))[0]
        os.makedirs(self.result_dir, exist_ok=True)
        output_video_path = os.path.join(self.result_dir, base_name + '_' + audio_name + '.mp4')
        self.output_path = output_video_path
        cfg_model = self.sadtalker_model.cfg['MODEL']
        video_fps = cfg_model['VIDEO_FPS'] if cfg_model['OUTPUT_VIDEO_FPS'] is None else cfg_model['OUTPUT_VIDEO_FPS']
        audio_output_sample_rate = (cfg_model['DRIVEN_AUDIO_SAMPLE_RATE']
                                    if cfg_model['OUTPUT_AUDIO_SAMPLE_RATE'] is None
                                    else cfg_model['OUTPUT_AUDIO_SAMPLE_RATE'])
        if self.use_enhancer:
            enhanced_path = os.path.join(self.result_dir, base_name + '_' + audio_name + '_enhanced.mp4')
            save_video_with_watermark(output_video, self.driven_audio, enhanced_path)
            paste_pic(enhanced_path, batch['source_image_crop'], batch['crop_info'], self.driven_audio,
                      output_video_path)
            os.remove(enhanced_path)
        else:
            save_video_with_watermark(output_video, self.driven_audio, output_video_path)
        if self.tts_text is not None:
            # Clean up the temporary directory that held the synthesized audio.
            shutil.rmtree(os.path.dirname(self.driven_audio))

    def save_result(self):
        return self.output_path

    def __call__(self):
        return self.output_path

    def test(self):
        batch, audio_sample_rate = self.get_test_data()
        output_video = self.run_inference(batch)
        self.post_processing(output_video, audio_sample_rate, batch)
        return self.save_result()


class SadTalkerInnerModel:
    """Constructs all sub-modules used by the pipeline."""

    def __init__(self, sadtalker_cfg, device_id=[0]):
        self.cfg = sadtalker_cfg
        self.device = sadtalker_cfg['MODEL'].get('DEVICE', 'cpu')
        self.preprocesser = Preprocesser(sadtalker_cfg, self.device)
        self.kp_extractor = KeyPointExtractor(sadtalker_cfg, self.device)
        self.audio_to_coeff = Audio2Coeff(sadtalker_cfg, self.device)
        self.animate_from_coeff = AnimateFromCoeff(sadtalker_cfg, self.device)
        self.face_enhancer = (FaceEnhancer(sadtalker_cfg, self.device)
                              if sadtalker_cfg['MODEL']['USE_ENHANCER'] else None)
        self.generator = Generator(sadtalker_cfg, self.device)
        self.mapping = Mapping(sadtalker_cfg, self.device)
        self.he_estimator = OcclusionAwareDenseMotion(sadtalker_cfg, self.device)


class Preprocesser:
    """Face cropping, image-to-tensor conversion, and audio loading."""

    def __init__(self, sadtalker_cfg, device):
        self.cfg = sadtalker_cfg
        self.device = device
        self.face3d_helper = Face3DHelper(self.cfg['INPUT_IMAGE'].get('LOCAL_PCA_PATH', ''), device)
        self.mouth_detector = MouthDetector()

    def crop(self, source_image_pil, preprocess_type, size=256):
        source_image = np.array(source_image_pil)
        face_info = self.face3d_helper.run(source_image)
        if face_info is None:
            raise Exception("No face detected")
        x_min, y_min, x_max, y_max = face_info[:4]
        old_size = (x_max - x_min, y_max - y_min)
        x_center = (x_max + x_min) / 2
        y_center = (y_max + y_min) / 2
        if preprocess_type == 'crop':
            # Square crop centered on the detected face.
            face_size = max(x_max - x_min, y_max - y_min)
            x_min = int(x_center - face_size / 2)
            y_min = int(y_center - face_size / 2)
            x_max = int(x_center + face_size / 2)
            y_max = int(y_center + face_size / 2)
        else:
            # Expand the detected box by roughly 10% on each side.
            x_min -= int((x_max - x_min) * 0.1)
            y_min -= int((y_max - y_min) * 0.1)
            x_max += int((x_max - x_min) * 0.1)
            y_max += int((y_max - y_min) * 0.1)
        h, w = source_image.shape[:2]
        x_min = max(0, x_min)
        y_min = max(0, y_min)
        x_max = min(w, x_max)
        y_max = min(h, y_max)
        cropped_image = source_image[y_min:y_max, x_min:x_max]
        cropped_image_pil = Image.fromarray(cropped_image)
        if size is not None and size != 0:
            cropped_image_pil = cropped_image_pil.resize((size, size), Image.Resampling.LANCZOS)
        source_image_tensor = self.img2tensor(cropped_image_pil)
        crop_info = [[y_min, y_max], [x_min, x_max], old_size, cropped_image_pil.size]
        # The third return value is the source-image basename; downstream it is only used
        # to name the output file.
        return source_image_tensor, crop_info, os.path.basename(self.cfg['INPUT_IMAGE'].get('SOURCE_IMAGE', ''))

    def img2tensor(self, img):
        img = np.array(img).astype(np.float32) / 255.0
        img = np.transpose(img, (2, 0, 1))
        return torch.FloatTensor(img)

    def video_to_tensor(self, video, device):
        import torchvision.transforms as transforms
        transform_func = transforms.ToTensor()
        video_tensor_list = []
        for frame in video:
            frame_pil = Image.fromarray(frame)
            frame_tensor = transform_func(frame_pil).unsqueeze(0).to(device)
            video_tensor_list.append(frame_tensor)
        return torch.cat(video_tensor_list, dim=0)

    def process_audio(self, audio_path, sample_rate):
        # Return a 1-D waveform tensor; the caller adds the batch dimension.
        wav = load_wav_util(audio_path, sample_rate)
        wav_tensor = torch.FloatTensor(wav)
        return wav_tensor, sample_rate

    def generate_still_pose(self, pose_style):
        ref_pose_coeff = torch.zeros((1, 64), dtype=torch.float32).to(self.device)
        ref_pose_coeff[:, :3] = torch.tensor([0, 0, pose_style * 0.3], dtype=torch.float32)
        return ref_pose_coeff

    def generate_still_expression(self, exp_scale):
        ref_expression_coeff = torch.zeros((1, 64), dtype=torch.float32).to(self.device)
        ref_expression_coeff[:, :3] = torch.tensor([0, 0, exp_scale * 0.3], dtype=torch.float32)
        return ref_expression_coeff

    def generate_idles_pose(self, length_of_audio, pose_style):
        num_frames = int(length_of_audio * self.cfg['MODEL']['VIDEO_FPS'])
        ref_pose_coeff = torch.zeros((num_frames, 64), dtype=torch.float32).to(self.device)
        start_pose = self.generate_still_pose(pose_style)
        end_pose = self.generate_still_pose(pose_style)
        for frame_idx in range(num_frames):
            alpha = frame_idx / num_frames
            ref_pose_coeff[frame_idx] = (1 - alpha) * start_pose[0] + alpha * end_pose[0]
        return ref_pose_coeff

    def generate_idles_expression(self, length_of_audio):
        num_frames = int(length_of_audio * self.cfg['MODEL']['VIDEO_FPS'])
        ref_expression_coeff = torch.zeros((num_frames, 64), dtype=torch.float32).to(self.device)
        start_exp = self.generate_still_expression(1.0)
        end_exp = self.generate_still_expression(1.0)
        for frame_idx in range(num_frames):
            alpha = frame_idx / num_frames
            ref_expression_coeff[frame_idx] = (1 - alpha) * start_exp[0] + alpha * end_exp[0]
        return ref_expression_coeff


class KeyPointExtractor(nn.Module):
    def __init__(self, sadtalker_cfg, device):
        super(KeyPointExtractor, self).__init__()
        self.kp_extractor = OcclusionAwareKPDetector(kp_channels=sadtalker_cfg['MODEL']['NUM_MOTION_FRAMES'],
                                                     num_kp=10,
                                                     num_dilation_blocks=2,
                                                     dropout_rate=0.1).to(device)
        checkpoint_path = os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'], 'kp_detector.safetensors')
        load_state_dict_robust(self.kp_extractor, checkpoint_path, device, model_name='kp_detector')

    def forward(self, x):
        return self.kp_extractor(x)


class Audio2Coeff(nn.Module):
    """Maps an audio embedding to pose, expression, and blink coefficients."""

    def __init__(self, sadtalker_cfg, device):
        super(Audio2Coeff, self).__init__()
        self.audio_model = Wav2Vec2Model().to(device)
        checkpoint_path = os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'], 'wav2vec2.pth')
        load_state_dict_robust(self.audio_model, checkpoint_path, device, model_name='wav2vec2')
        self.pose_mapper = AudioCoeffsPredictor(2048, 64).to(device)
        self.exp_mapper = AudioCoeffsPredictor(2048, 64).to(device)
        self.blink_mapper = AudioCoeffsPredictor(2048, 1).to(device)
        mapping_checkpoint = os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'], 'auido2pose_00140-model.pth')
        load_state_dict_robust(self, mapping_checkpoint, device)

    def get_pose_coeff(self, audio_tensor, ref_pose_coeff=None, kp_ref=None, use_ref_info=''):
        audio_embedding = self.audio_model(audio_tensor)
        pose_coeff = self.pose_mapper(audio_embedding)
        if ref_pose_coeff is not None:
            pose_coeff = ref_pose_coeff
        if kp_ref is not None and use_ref_info == 'pose':
            ref_pose_6d = kp_ref['value'][:, :6]
            pose_coeff[:, :6] = self.mean_std_normalize(ref_pose_6d).mean(dim=1)
        return pose_coeff

    def get_exp_coeff(self, audio_tensor, ref_expression_coeff=None):
        audio_embedding = self.audio_model(audio_tensor)
        expression_coeff = self.exp_mapper(audio_embedding)
        if ref_expression_coeff is not None:
            expression_coeff = ref_expression_coeff
        return expression_coeff

    def get_blink_coeff(self, audio_tensor):
        audio_embedding = self.audio_model(audio_tensor)
        return self.blink_mapper(audio_embedding)

    def forward(self, audio):
        audio_embedding = self.audio_model(audio)
        pose_coeff = self.pose_mapper(audio_embedding)
        expression_coeff = self.exp_mapper(audio_embedding)
        blink_coeff = self.blink_mapper(audio_embedding)
        return pose_coeff, expression_coeff, blink_coeff

    def mean_std_normalize(self, coeff):
        mean = coeff.mean(dim=1, keepdim=True)
        std = coeff.std(dim=1, keepdim=True)
        return (coeff - mean) / std


class AnimateFromCoeff(nn.Module):
    def __init__(self, sadtalker_cfg, device):
        super(AnimateFromCoeff, self).__init__()
        self.generator = Generator(sadtalker_cfg, device)
        self.mapping = Mapping(sadtalker_cfg, device)
        self.kp_norm = KeypointNorm(device=device)
        self.he_estimator = OcclusionAwareDenseMotion(sadtalker_cfg, device)

    def normalize_kp(self, kp_driving):
        return self.kp_norm(kp_driving)

    def generate(self, source_image, kp_source, coeff, generator, mapping, he_estimator, audio,
                 source_image_crop, face_enhancer=None):
        kp_driving = coeff['kp_driving']
        jacobian = coeff['jacobian']
        pose_coeff = coeff['pose_coeff']
        expression_coeff = coeff['expression_coeff']
        blink_coeff = coeff['blink_coeff']
        if blink_coeff is not None:
            face_3d = mapping(expression_coeff, pose_coeff, blink_coeff)
        else:
            face_3d = mapping(expression_coeff, pose_coeff)
        sparse_motion = he_estimator(kp_source, kp_driving, jacobian)
        dense_motion = sparse_motion['dense_motion']
        video_deocclusion = generator(source_image, dense_motion, bg_param={'mask': None, 'color': None})
        video_3d = generator(source_image, dense_motion, bg_param={'mask': None, 'color': None},
                             face_3d_param=face_3d)
        video_output = video_deocclusion['video_no_reocclusion'] + video_3d['video_3d']
        if face_enhancer is not None:
            video_output_enhanced = []
            for frame in tqdm(video_output, 'Face enhancer running'):
                frame_bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
                enhanced_image = face_enhancer.forward(frame_bgr)
                video_output_enhanced.append(cv2.cvtColor(enhanced_image, cv2.COLOR_BGR2RGB))
            video_output = video_output_enhanced
        return video_output

    def make_animation(self, video_array):
        H, W, _ = video_array[0].shape
        out = cv2.VideoWriter('./tmp.mp4', cv2.VideoWriter_fourcc(*'mp4v'), 25, (W, H))
        for img in video_array:
            out.write(cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
        out.release()
        video = imageio.mimread('./tmp.mp4')
        os.remove('./tmp.mp4')
        return video


class Generator(nn.Module):
    def __init__(self, sadtalker_cfg, device):
        super(Generator, self).__init__()
        self.generator = Hourglass(block_expansion=sadtalker_cfg['MODEL']['SCALE'],
                                   num_blocks=sadtalker_cfg['MODEL']['NUM_VOXEL_FRAMES'],
                                   max_features=sadtalker_cfg['MODEL']['MAX_FEATURES'],
                                   num_channels=3,
                                   kp_size=10,
                                   num_deform_blocks=sadtalker_cfg['MODEL']['NUM_MOTION_FRAMES']).to(device)
        checkpoint_path = os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'], 'generator.pth')
        load_state_dict_robust(self.generator, checkpoint_path, device, model_name='generator')

    def forward(self, source_image, dense_motion, bg_param, face_3d_param=None):
        video_3d = self.generator(source_image, kp_driving=dense_motion, bg_param=bg_param,
                                  face_3d_param=face_3d_param)
        return {'video_3d': video_3d, 'video_no_reocclusion': video_3d}


class Mapping(nn.Module):
    def __init__(self, sadtalker_cfg, device):
        super(Mapping, self).__init__()
        self.mapping_net = MappingNet(num_coeffs=64, num_layers=3, hidden_dim=128).to(device)
        checkpoint_path = os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'], 'mapping.pth')
        load_state_dict_robust(self.mapping_net, checkpoint_path, device, model_name='mapping')
        self.f_3d_mean = torch.zeros(1, 64, device=device)

    def forward(self, expression_coeff, pose_coeff, blink_coeff=None):
        coeff = torch.cat([expression_coeff, pose_coeff], dim=1)
        face_3d = self.mapping_net(coeff) + self.f_3d_mean
        if blink_coeff is not None:
            face_3d[:, -1:] = blink_coeff
        return face_3d


class OcclusionAwareDenseMotion(nn.Module):
    def __init__(self, sadtalker_cfg, device):
        super(OcclusionAwareDenseMotion, self).__init__()
        self.dense_motion_network = DenseMotionNetwork(num_kp=10,
                                                       num_channels=3,
                                                       block_expansion=sadtalker_cfg['MODEL']['SCALE'],
                                                       num_blocks=sadtalker_cfg['MODEL']['NUM_MOTION_FRAMES'] - 1,
                                                       max_features=sadtalker_cfg['MODEL']['MAX_FEATURES']).to(device)
        checkpoint_path = os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'], 'dense_motion.pth')
        load_state_dict_robust(self.dense_motion_network, checkpoint_path, device, model_name='dense_motion')

    def forward(self, kp_source, kp_driving, jacobian):
        return self.dense_motion_network(kp_source, kp_driving, jacobian)


class FaceEnhancer(nn.Module):
    """Optional post-processing with GFPGAN or Real-ESRGAN."""

    def __init__(self, sadtalker_cfg, device):
        super(FaceEnhancer, self).__init__()
        self.enhancer_name = sadtalker_cfg['MODEL']['ENHANCER_NAME']
        bg_upsampler = sadtalker_cfg['MODEL']['BG_UPSAMPLER']
        if self.enhancer_name == 'gfpgan':
            from gfpgan import GFPGANer
            self.face_enhancer = GFPGANer(
                model_path=os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'], 'GFPGANv1.4.pth'),
                upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=bg_upsampler)
        elif self.enhancer_name == 'realesrgan':
            from realesrgan import RealESRGANer
            half = False if device == 'cpu' else sadtalker_cfg['MODEL']['IS_HALF']
            self.face_enhancer = RealESRGANer(
                scale=2,
                model_path=os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'], 'RealESRGAN_x2plus.pth'),
                tile=0, tile_pad=10, pre_pad=0, half=half, device=device)
        else:
            self.face_enhancer = None

    def forward(self, x):
        if self.face_enhancer is None:
            return x
        if self.enhancer_name == 'gfpgan':
            # GFPGANer.enhance returns (cropped_faces, restored_faces, restored_img).
            _, _, restored_img = self.face_enhancer.enhance(x, has_aligned=False, only_center_face=False,
                                                            paste_back=True)
            return restored_img
        # RealESRGANer.enhance returns (output, img_mode).
        return self.face_enhancer.enhance(x, outscale=1)[0]


def load_models():
    checkpoint_path = './checkpoints'
    config_path = './src/config'
    size = 256
    preprocess = 'crop'
    old_version = False
    sadtalker_instance = SadTalker(checkpoint_path, config_path, size, preprocess, old_version)
    print("SadTalker models loaded successfully!")
    return sadtalker_instance


if __name__ == '__main__':
    sadtalker_instance = load_models()
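
# Illustrative end-to-end usage (the file paths are placeholders and must point to real
# files; the required checkpoints are downloaded into ./checkpoints on first construction):
#
#   sadtalker = load_models()
#   result_path = sadtalker.test(source_image='examples/source.png',
#                                driven_audio='examples/driving.wav',
#                                preprocess='crop',
#                                still_mode=True,
#                                use_enhancer=False,
#                                result_dir='./results/')
#   print('Talking-head video written to:', result_path)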