import os
import shutil
import uuid
import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import yaml
from PIL import Image
from skimage import img_as_ubyte, transform
import safetensors
import librosa
from pydub import AudioSegment
import imageio
from scipy import signal
from scipy.io import loadmat, savemat, wavfile
import glob
import tempfile
from tqdm import tqdm
import math
import torchaudio
import urllib.request
from safetensors.torch import load_file, save_file
# Download URLs for the pretrained weights fetched on first run.
REALESRGAN_URL = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth"
CODEFORMER_URL = "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth"
RESTOREFORMER_URL = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/RestoreFormer.pth"
GFPGAN_URL = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth"
kp_url = "https://huggingface.co/usyd-community/vitpose-base-simple/resolve/main/model.safetensors"
kp_file = "kp_detector.safetensors"
aud_url = "https://huggingface.co/vinthony/SadTalker/resolve/main/auido2pose_00140-model.pth"
aud_file = "auido2pose_00140-model.pth"
wav_url = "https://huggingface.co/facebook/wav2vec2-base/resolve/main/pytorch_model.bin"
wav_file = "wav2vec2.pth"
gen_url = "https://huggingface.co/vinthony/SadTalker/resolve/main/wav2lip.pth"
gen_file = "generator.pth"
mapx_url = "https://huggingface.co/vinthony/SadTalker/resolve/main/mapping_00229-model.pth.tar"
mapx_file = "mapping.pth"
den_url = "https://huggingface.co/KwaiVGI/LivePortrait/resolve/main/liveportrait/base_models/motion_extractor.pth"
den_file = "dense_motion.pth"
def download_model(url, filename, checkpoint_dir):
    if not os.path.exists(os.path.join(checkpoint_dir, filename)):
        print(f"Downloading {filename}...")
        os.makedirs(checkpoint_dir, exist_ok=True)
        urllib.request.urlretrieve(url, os.path.join(checkpoint_dir, filename))
        print(f"{filename} downloaded.")
    else:
        print(f"{filename} already exists.")
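# Usage sketch (not called anywhere in this module): fetch a single checkpoint into a
# local directory. The directory name 'checkpoints' matches the default used by the
# SadTalker class further below.
def _example_download_gfpgan(checkpoint_dir='checkpoints'):
    download_model(GFPGAN_URL, 'GFPGANv1.4.pth', checkpoint_dir)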
def mp3_to_wav_util(mp3_filename, wav_filename, frame_rate):
    AudioSegment.from_file(mp3_filename).set_frame_rate(frame_rate).export(wav_filename, format="wav")


def load_wav_util(path, sr):
    return librosa.core.load(path, sr=sr)[0]


def save_wav_util(wav, path, sr):
    wav *= 32767 / max(0.01, np.max(np.abs(wav)))
    wavfile.write(path, sr, wav.astype(np.int16))
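# Hedged round-trip sketch for the audio helpers above: load a wav at 16 kHz and write
# it back as 16-bit PCM. 'input.wav' and 'roundtrip.wav' are hypothetical paths.
def _example_audio_roundtrip(path='input.wav', sr=16000):
    wav = load_wav_util(path, sr)
    save_wav_util(wav, 'roundtrip.wav', sr)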
def load_state_dict_robust(model, checkpoint_path, device, model_name="model"):
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}")
    if checkpoint_path.endswith('safetensors'):
        checkpoint = safetensors.torch.load_file(checkpoint_path, device=device)
    else:
        checkpoint = torch.load(checkpoint_path, map_location=device)
    state_dict = checkpoint.get(model_name, checkpoint)
    try:
        model.load_state_dict(state_dict)
    except RuntimeError as e:
        print(f"Error loading {model_name} state_dict: {e}")
        print(f"Trying to load state_dict with key mapping for {model_name}.")
        model_state_dict = model.state_dict()
        mapped_state_dict = {}
        for key, value in state_dict.items():
            if key in model_state_dict and model_state_dict[key].shape == value.shape:
                mapped_state_dict[key] = value
            else:
                print(f"Skipping key {key} due to shape mismatch or missing in model.")
        missing_keys, unexpected_keys = model.load_state_dict(mapped_state_dict, strict=False)
        if missing_keys or unexpected_keys:
            print(f"Missing keys: {missing_keys}")
            print(f"Unexpected keys: {unexpected_keys}")
        print(f"Successfully loaded {model_name} state_dict with key mapping.")
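# Minimal usage sketch for load_state_dict_robust. It round-trips a tiny module through
# torch.save so the example does not depend on any of the downloaded checkpoints;
# '_demo_linear.pth' is a hypothetical file name.
def _example_load_state_dict_robust(checkpoint_dir='checkpoints'):
    demo = nn.Linear(4, 2)
    os.makedirs(checkpoint_dir, exist_ok=True)
    demo_path = os.path.join(checkpoint_dir, '_demo_linear.pth')
    torch.save({'model': demo.state_dict()}, demo_path)
    load_state_dict_robust(demo, demo_path, device='cpu', model_name='model')
    return demo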
class OcclusionAwareKPDetector(nn.Module):
    def __init__(self, kp_channels, num_kp, num_dilation_blocks, dropout_rate):
        super(OcclusionAwareKPDetector, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(64, num_kp, kernel_size=3, padding=1)

    def forward(self, x):
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.conv2(x)
        kp = {'value': x.view(x.size(0), -1)}
        return kp
class Wav2Vec2Model(nn.Module):
    def __init__(self):
        super(Wav2Vec2Model, self).__init__()
        self.conv = nn.Conv1d(1, 64, kernel_size=10, stride=5, padding=5)
        self.bn = nn.BatchNorm1d(64)
        self.relu = nn.ReLU()
        self.fc = nn.Linear(64, 2048)

    def forward(self, audio):
        x = audio.unsqueeze(1)
        x = self.relu(self.bn(self.conv(x)))
        x = torch.mean(x, dim=-1)
        x = self.fc(x)
        return x
class AudioCoeffsPredictor(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(AudioCoeffsPredictor, self).__init__()
        self.linear = nn.Linear(input_dim, output_dim)

    def forward(self, audio_embedding):
        return self.linear(audio_embedding)
class MappingNet(nn.Module):
    def __init__(self, num_coeffs, num_layers, hidden_dim):
        super(MappingNet, self).__init__()
        layers = []
        input_dim = num_coeffs * 2
        for _ in range(num_layers):
            layers.append(nn.Linear(input_dim, hidden_dim))
            layers.append(nn.ReLU())
            input_dim = hidden_dim
        layers.append(nn.Linear(hidden_dim, num_coeffs))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)
class DenseMotionNetwork(nn.Module):
    def __init__(self, num_kp, num_channels, block_expansion, num_blocks, max_features):
        super(DenseMotionNetwork, self).__init__()
        self.conv1 = nn.Conv2d(num_channels, max_features, kernel_size=3, padding=1)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(max_features, num_channels, kernel_size=3, padding=1)

    def forward(self, kp_source, kp_driving, jacobian):
        # In this simplified network the first argument must be an image-sized tensor of
        # shape (B, num_channels, H, W); the keypoints and jacobian are not used.
        x = self.relu(self.conv1(kp_source))
        x = self.conv2(x)
        sparse_motion = {'dense_motion': x}
        return sparse_motion
class Hourglass(nn.Module):
    def __init__(self, block_expansion, num_blocks, max_features, num_channels, kp_size, num_deform_blocks):
        super(Hourglass, self).__init__()
        self.encoder = nn.Sequential(nn.Conv2d(num_channels, max_features, kernel_size=7, stride=2, padding=3),
                                     nn.BatchNorm2d(max_features), nn.ReLU())
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(max_features, num_channels, kernel_size=4, stride=2, padding=1), nn.Tanh())

    def forward(self, source_image, kp_driving, **kwargs):
        x = self.encoder(source_image)
        x = self.decoder(x)
        B, C, H, W = x.size()
        video = []
        for _ in range(10):
            frame = (x[0].cpu().detach().numpy().transpose(1, 2, 0) * 127.5 + 127.5).clip(0, 255).astype(np.uint8)
            video.append(frame)
        return video
class Face3DHelper:
    def __init__(self, local_pca_path, device):
        self.local_pca_path = local_pca_path
        self.device = device

    def run(self, source_image):
        h, w, _ = source_image.shape
        x_min = w // 4
        y_min = h // 4
        x_max = x_min + w // 2
        y_max = y_min + h // 2
        return [x_min, y_min, x_max, y_max]
class MouthDetector:
    def __init__(self):
        pass

    def detect(self, image):
        h, w = image.shape[:2]
        return (w // 2, h // 2)


class KeypointNorm(nn.Module):
    def __init__(self, device):
        super(KeypointNorm, self).__init__()
        self.device = device

    def forward(self, kp_driving):
        return kp_driving
def save_video_with_watermark(video_frames, audio_path, output_path):
    # Simplified writer: encodes the frames at 25 fps; the audio track is not muxed in.
    H, W, _ = video_frames[0].shape
    out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), 25, (W, H))
    for frame in video_frames:
        out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
    out.release()


def paste_pic(video_path, source_image_crop, crop_info, audio_path, output_path):
    # Simplified paste-back: just copies the rendered video to the output path.
    shutil.copy(video_path, output_path)
class TTSTalker:
    # Placeholder TTS: tokenizes the text but synthesizes one second of silence at 16 kHz.
    def __init__(self):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.tts_model = None

    def load_model(self):
        self.tts_model = self

    def tokenizer(self, text):
        return [ord(c) for c in text]

    def __call__(self, input_tokens):
        return torch.zeros(1, 16000, device=self.device)

    def test(self, text, lang='en'):
        if self.tts_model is None:
            self.load_model()
        output_path = os.path.join('./results', str(uuid.uuid4()) + '.wav')
        os.makedirs('./results', exist_ok=True)
        tokens = self.tokenizer(text)
        input_tokens = torch.tensor([tokens], dtype=torch.long).to(self.device)
        with torch.no_grad():
            audio_output = self(input_tokens)
        torchaudio.save(output_path, audio_output.cpu(), 16000)
        return output_path
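# Hedged usage sketch for the placeholder TTS above: it writes a one-second silent wav
# under ./results and returns its path.
def _example_tts(text='Hello world'):
    tts = TTSTalker()
    wav_path = tts.test(text, lang='en')
    print(f"TTS placeholder wrote: {wav_path}")
    return wav_path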
class SadTalker:
    def __init__(self, checkpoint_path='checkpoints', config_path='src/config', size=256, preprocess='crop',
                 old_version=False):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.cfg = self.get_cfg_defaults()
        self.merge_from_file(os.path.join(config_path, 'sadtalker_config.yaml'))
        self.cfg['MODEL']['CHECKPOINTS_DIR'] = checkpoint_path
        self.cfg['MODEL']['CONFIG_DIR'] = config_path
        self.cfg['MODEL']['DEVICE'] = self.device
        self.cfg['INPUT_IMAGE'] = {}
        self.cfg['INPUT_IMAGE']['SOURCE_IMAGE'] = 'None'
        self.cfg['INPUT_IMAGE']['DRIVEN_AUDIO'] = 'None'
        self.cfg['INPUT_IMAGE']['PREPROCESS'] = preprocess
        self.cfg['INPUT_IMAGE']['SIZE'] = size
        self.cfg['INPUT_IMAGE']['OLD_VERSION'] = old_version
        for filename, url in [
            (kp_file, kp_url), (aud_file, aud_url), (wav_file, wav_url), (gen_file, gen_url),
            (mapx_file, mapx_url), (den_file, den_url), ('GFPGANv1.4.pth', GFPGAN_URL),
            ('RealESRGAN_x2plus.pth', REALESRGAN_URL)
        ]:
            download_model(url, filename, checkpoint_path)
        self.sadtalker_model = SadTalkerModel(self.cfg, device_id=[0])
    def get_cfg_defaults(self):
        return {
            'MODEL': {
                'CHECKPOINTS_DIR': '',
                'CONFIG_DIR': '',
                'DEVICE': self.device,
                'SCALE': 64,
                'NUM_VOXEL_FRAMES': 8,
                'NUM_MOTION_FRAMES': 10,
                'MAX_FEATURES': 256,
                'DRIVEN_AUDIO_SAMPLE_RATE': 16000,
                'VIDEO_FPS': 25,
                'OUTPUT_VIDEO_FPS': None,
                'OUTPUT_AUDIO_SAMPLE_RATE': None,
                'USE_ENHANCER': False,
                'ENHANCER_NAME': '',
                'BG_UPSAMPLER': None,
                'IS_HALF': False
            },
            'INPUT_IMAGE': {}
        }
    def merge_from_file(self, filepath):
        if os.path.exists(filepath):
            with open(filepath, 'r') as f:
                cfg_from_file = yaml.safe_load(f)
            if cfg_from_file:
                self.cfg.update(cfg_from_file)
    def test(self, source_image, driven_audio, preprocess='crop', still_mode=False, use_enhancer=False,
             batch_size=1, size=256, pose_style=0, exp_scale=1.0, use_ref_video=False, ref_video=None,
             ref_info=None, use_idle_mode=False, length_of_audio=0, use_blink=True, result_dir='./results/',
             tts_text=None, tts_lang='en'):
        self.sadtalker_model.test(source_image, driven_audio, preprocess, still_mode, use_enhancer, batch_size, size,
                                  pose_style, exp_scale, use_ref_video, ref_video, ref_info, use_idle_mode,
                                  length_of_audio, use_blink, result_dir, tts_text, tts_lang)
        return self.sadtalker_model.save_result()
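# Hedged end-to-end sketch of the high-level API defined above. The image and audio
# paths are hypothetical; SadTalker.__init__ downloads all required checkpoints on
# first use, which can take a while.
def _example_sadtalker_api(image_path='examples/face.png', audio_path='examples/speech.wav'):
    talker = SadTalker(checkpoint_path='checkpoints', config_path='src/config', size=256)
    return talker.test(source_image=image_path, driven_audio=audio_path, still_mode=True)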
class SadTalkerModel:
    def __init__(self, sadtalker_cfg, device_id=[0]):
        self.cfg = sadtalker_cfg
        self.device = sadtalker_cfg['MODEL'].get('DEVICE', 'cpu')
        self.sadtalker = SadTalkerInnerModel(sadtalker_cfg, device_id)
        self.preprocesser = self.sadtalker.preprocesser
        self.kp_extractor = self.sadtalker.kp_extractor
        self.generator = self.sadtalker.generator
        self.mapping = self.sadtalker.mapping
        self.he_estimator = self.sadtalker.he_estimator
        self.audio_to_coeff = self.sadtalker.audio_to_coeff
        self.animate_from_coeff = self.sadtalker.animate_from_coeff
        self.face_enhancer = self.sadtalker.face_enhancer

    def test(self, source_image, driven_audio, preprocess='crop', still_mode=False, use_enhancer=False,
             batch_size=1, size=256, pose_style=0, exp_scale=1.0, use_ref_video=False, ref_video=None,
             ref_info=None, use_idle_mode=False, length_of_audio=0, use_blink=True, result_dir='./results/',
             tts_text=None, tts_lang='en', jitter_amount=10, jitter_source_image=False):
        self.inner_test = SadTalkerInner(self, source_image, driven_audio, preprocess, still_mode, use_enhancer,
                                         batch_size, size, pose_style, exp_scale, use_ref_video, ref_video, ref_info,
                                         use_idle_mode, length_of_audio, use_blink, result_dir, tts_text, tts_lang,
                                         jitter_amount, jitter_source_image)
        return self.inner_test.test()

    def save_result(self):
        return self.inner_test.save_result()
class SadTalkerInner:
    def __init__(self, sadtalker_model, source_image, driven_audio, preprocess, still_mode, use_enhancer,
                 batch_size, size, pose_style, exp_scale, use_ref_video, ref_video, ref_info, use_idle_mode,
                 length_of_audio, use_blink, result_dir, tts_text, tts_lang, jitter_amount, jitter_source_image):
        self.sadtalker_model = sadtalker_model
        self.source_image = source_image
        self.driven_audio = driven_audio
        self.preprocess = preprocess
        self.still_mode = still_mode
        self.use_enhancer = use_enhancer
        self.batch_size = batch_size
        self.size = size
        self.pose_style = pose_style
        self.exp_scale = exp_scale
        self.use_ref_video = use_ref_video
        self.ref_video = ref_video
        self.ref_info = ref_info
        self.use_idle_mode = use_idle_mode
        self.length_of_audio = length_of_audio
        self.use_blink = use_blink
        self.result_dir = result_dir
        self.tts_text = tts_text
        self.tts_lang = tts_lang
        self.jitter_amount = jitter_amount
        self.jitter_source_image = jitter_source_image
        self.device = self.sadtalker_model.device
        self.output_path = None
    def get_test_data(self):
        proc = self.sadtalker_model.preprocesser
        if self.tts_text is not None:
            temp_dir = tempfile.mkdtemp()
            audio_path = os.path.join(temp_dir, 'audio.wav')
            tts = TTSTalker()
            # TTSTalker.test writes its wav under ./results; copy it into the temp dir so
            # the temp-dir cleanup in post_processing removes it afterwards.
            generated_wav = tts.test(self.tts_text, self.tts_lang)
            shutil.copy(generated_wav, audio_path)
            self.driven_audio = audio_path
        source_image_pil = Image.open(self.source_image).convert('RGB')
        if self.jitter_source_image:
            jitter_dx = np.random.randint(-self.jitter_amount, self.jitter_amount + 1)
            jitter_dy = np.random.randint(-self.jitter_amount, self.jitter_amount + 1)
            source_image_pil = Image.fromarray(
                np.roll(np.roll(np.array(source_image_pil), jitter_dx, axis=1), jitter_dy, axis=0))
        source_image_tensor, crop_info, cropped_image = proc.crop(source_image_pil, self.preprocess, self.size)
        if self.still_mode or self.use_idle_mode:
            ref_pose_coeff = proc.generate_still_pose(self.pose_style)
            ref_expression_coeff = proc.generate_still_expression(self.exp_scale)
        else:
            ref_pose_coeff = None
            ref_expression_coeff = None
        audio_tensor, audio_sample_rate = proc.process_audio(
            self.driven_audio, self.sadtalker_model.cfg['MODEL']['DRIVEN_AUDIO_SAMPLE_RATE'])
        batch = {
            'source_image': source_image_tensor.unsqueeze(0).to(self.device),
            'audio': audio_tensor.to(self.device),  # process_audio already returns a (1, num_samples) batch
            'ref_pose_coeff': ref_pose_coeff,
            'ref_expression_coeff': ref_expression_coeff,
            'source_image_crop': cropped_image,
            'crop_info': crop_info,
            'use_blink': self.use_blink,
            'pose_style': self.pose_style,
            'exp_scale': self.exp_scale,
            'ref_video': self.ref_video,
            'use_ref_video': self.use_ref_video,
            'ref_info': self.ref_info,
        }
        return batch, audio_sample_rate
    def run_inference(self, batch):
        kp_extractor = self.sadtalker_model.kp_extractor
        generator = self.sadtalker_model.generator
        mapping = self.sadtalker_model.mapping
        he_estimator = self.sadtalker_model.he_estimator
        audio_to_coeff = self.sadtalker_model.audio_to_coeff
        animate_from_coeff = self.sadtalker_model.animate_from_coeff
        face_enhancer = self.sadtalker_model.face_enhancer if self.use_enhancer else None
        with torch.no_grad():
            kp_source = kp_extractor(batch['source_image'])
            if self.still_mode or self.use_idle_mode:
                ref_pose_coeff = batch['ref_pose_coeff']
                ref_expression_coeff = batch['ref_expression_coeff']
                pose_coeff = audio_to_coeff.get_pose_coeff(batch['audio'], ref_pose_coeff)
                expression_coeff = audio_to_coeff.get_exp_coeff(batch['audio'], ref_expression_coeff)
            else:
                if self.use_ref_video:
                    kp_ref = kp_extractor(batch['source_image'])
                    pose_coeff = audio_to_coeff.get_pose_coeff(batch['audio'], kp_ref=kp_ref,
                                                               use_ref_info=batch['ref_info'])
                else:
                    pose_coeff = audio_to_coeff.get_pose_coeff(batch['audio'])
                expression_coeff = audio_to_coeff.get_exp_coeff(batch['audio'])
            coeff = {'pose_coeff': pose_coeff, 'expression_coeff': expression_coeff}
            if self.use_blink:
                coeff['blink_coeff'] = audio_to_coeff.get_blink_coeff(batch['audio'])
            else:
                coeff['blink_coeff'] = None
            kp_driving = audio_to_coeff(batch['audio'])[0]
            kp_norm = animate_from_coeff.normalize_kp(kp_driving)
            coeff['kp_driving'] = kp_norm
            coeff['jacobian'] = [torch.eye(2).unsqueeze(0).unsqueeze(0).to(self.device)] * 4
            output_video = animate_from_coeff.generate(batch['source_image'], kp_source, coeff, generator, mapping,
                                                       he_estimator, batch['audio'], batch['source_image_crop'],
                                                       face_enhancer=face_enhancer)
        return output_video
    def post_processing(self, output_video, audio_sample_rate, batch):
        proc = self.sadtalker_model.preprocesser
        os.makedirs(self.result_dir, exist_ok=True)
        base_name = os.path.splitext(os.path.basename(batch['source_image_crop']))[0]
        audio_name = os.path.splitext(os.path.basename(self.driven_audio))[0]
        output_video_path = os.path.join(self.result_dir, base_name + '_' + audio_name + '.mp4')
        self.output_path = output_video_path
        cfg_model = self.sadtalker_model.cfg['MODEL']
        video_fps = cfg_model['VIDEO_FPS'] if cfg_model['OUTPUT_VIDEO_FPS'] is None else cfg_model['OUTPUT_VIDEO_FPS']
        audio_output_sample_rate = cfg_model['DRIVEN_AUDIO_SAMPLE_RATE'] if cfg_model['OUTPUT_AUDIO_SAMPLE_RATE'] is None \
            else cfg_model['OUTPUT_AUDIO_SAMPLE_RATE']
        if self.use_enhancer:
            enhanced_path = os.path.join(self.result_dir, base_name + '_' + audio_name + '_enhanced.mp4')
            save_video_with_watermark(output_video, self.driven_audio, enhanced_path)
            paste_pic(enhanced_path, batch['source_image_crop'], batch['crop_info'], self.driven_audio,
                      output_video_path)
            os.remove(enhanced_path)
        else:
            save_video_with_watermark(output_video, self.driven_audio, output_video_path)
        if self.tts_text is not None:
            shutil.rmtree(os.path.dirname(self.driven_audio))

    def save_result(self):
        return self.output_path

    def __call__(self):
        return self.output_path
    def test(self):
        batch, audio_sample_rate = self.get_test_data()
        output_video = self.run_inference(batch)
        self.post_processing(output_video, audio_sample_rate, batch)
        return self.save_result()
class SadTalkerInnerModel:
    def __init__(self, sadtalker_cfg, device_id=[0]):
        self.cfg = sadtalker_cfg
        self.device = sadtalker_cfg['MODEL'].get('DEVICE', 'cpu')
        self.preprocesser = Preprocesser(sadtalker_cfg, self.device)
        self.kp_extractor = KeyPointExtractor(sadtalker_cfg, self.device)
        self.audio_to_coeff = Audio2Coeff(sadtalker_cfg, self.device)
        self.animate_from_coeff = AnimateFromCoeff(sadtalker_cfg, self.device)
        self.face_enhancer = FaceEnhancer(sadtalker_cfg, self.device) if sadtalker_cfg['MODEL']['USE_ENHANCER'] else None
        self.generator = Generator(sadtalker_cfg, self.device)
        self.mapping = Mapping(sadtalker_cfg, self.device)
        self.he_estimator = OcclusionAwareDenseMotion(sadtalker_cfg, self.device)
class Preprocesser:
    def __init__(self, sadtalker_cfg, device):
        self.cfg = sadtalker_cfg
        self.device = device
        self.face3d_helper = Face3DHelper(self.cfg['INPUT_IMAGE'].get('LOCAL_PCA_PATH', ''), device)
        self.mouth_detector = MouthDetector()
    def crop(self, source_image_pil, preprocess_type, size=256):
        source_image = np.array(source_image_pil)
        face_info = self.face3d_helper.run(source_image)
        if face_info is None:
            raise Exception("No face detected")
        x_min, y_min, x_max, y_max = face_info[:4]
        old_size = (x_max - x_min, y_max - y_min)
        x_center = (x_max + x_min) / 2
        y_center = (y_max + y_min) / 2
        if preprocess_type == 'crop':
            face_size = max(x_max - x_min, y_max - y_min)
            x_min = int(x_center - face_size / 2)
            y_min = int(y_center - face_size / 2)
            x_max = int(x_center + face_size / 2)
            y_max = int(y_center + face_size / 2)
        else:
            # Expand the box by 10% on each side, computed from the original box size.
            pad_x = int((x_max - x_min) * 0.1)
            pad_y = int((y_max - y_min) * 0.1)
            x_min -= pad_x
            y_min -= pad_y
            x_max += pad_x
            y_max += pad_y
        h, w = source_image.shape[:2]
        x_min = max(0, x_min)
        y_min = max(0, y_min)
        x_max = min(w, x_max)
        y_max = min(h, y_max)
        cropped_image = source_image[y_min:y_max, x_min:x_max]
        cropped_image_pil = Image.fromarray(cropped_image)
        if size is not None and size != 0:
            cropped_image_pil = cropped_image_pil.resize((size, size), Image.Resampling.LANCZOS)
        source_image_tensor = self.img2tensor(cropped_image_pil)
        crop_info = [[y_min, y_max], [x_min, x_max], old_size, cropped_image_pil.size]
        return source_image_tensor, crop_info, os.path.basename(self.cfg['INPUT_IMAGE'].get('SOURCE_IMAGE', ''))
    def img2tensor(self, img):
        img = np.array(img).astype(np.float32) / 255.0
        img = np.transpose(img, (2, 0, 1))
        return torch.FloatTensor(img)

    def video_to_tensor(self, video, device):
        import torchvision.transforms as transforms
        transform_func = transforms.ToTensor()
        video_tensor_list = []
        for frame in video:
            frame_pil = Image.fromarray(frame)
            frame_tensor = transform_func(frame_pil).unsqueeze(0).to(device)
            video_tensor_list.append(frame_tensor)
        video_tensor = torch.cat(video_tensor_list, dim=0)
        return video_tensor

    def process_audio(self, audio_path, sample_rate):
        wav = load_wav_util(audio_path, sample_rate)
        wav_tensor = torch.FloatTensor(wav).unsqueeze(0)
        return wav_tensor, sample_rate
    def generate_still_pose(self, pose_style):
        ref_pose_coeff = torch.zeros((1, 64), dtype=torch.float32).to(self.device)
        ref_pose_coeff[:, :3] = torch.tensor([0, 0, pose_style * 0.3], dtype=torch.float32)
        return ref_pose_coeff

    def generate_still_expression(self, exp_scale):
        ref_expression_coeff = torch.zeros((1, 64), dtype=torch.float32).to(self.device)
        ref_expression_coeff[:, :3] = torch.tensor([0, 0, exp_scale * 0.3], dtype=torch.float32)
        return ref_expression_coeff

    def generate_idles_pose(self, length_of_audio, pose_style):
        num_frames = int(length_of_audio * self.cfg['MODEL']['VIDEO_FPS'])
        ref_pose_coeff = torch.zeros((num_frames, 64), dtype=torch.float32).to(self.device)
        start_pose = self.generate_still_pose(pose_style).squeeze(0)
        end_pose = self.generate_still_pose(pose_style).squeeze(0)
        for frame_idx in range(num_frames):
            alpha = frame_idx / num_frames
            ref_pose_coeff[frame_idx] = (1 - alpha) * start_pose + alpha * end_pose
        return ref_pose_coeff

    def generate_idles_expression(self, length_of_audio):
        num_frames = int(length_of_audio * self.cfg['MODEL']['VIDEO_FPS'])
        ref_expression_coeff = torch.zeros((num_frames, 64), dtype=torch.float32).to(self.device)
        start_exp = self.generate_still_expression(1.0).squeeze(0)
        end_exp = self.generate_still_expression(1.0).squeeze(0)
        for frame_idx in range(num_frames):
            alpha = frame_idx / num_frames
            ref_expression_coeff[frame_idx] = (1 - alpha) * start_exp + alpha * end_exp
        return ref_expression_coeff
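# Hedged sketch of the Preprocesser in isolation: crop a PIL image and inspect the
# resulting tensor shape. 'examples/face.png' is a hypothetical path and the config
# dict only fills the keys Preprocesser actually reads.
def _example_preprocess(image_path='examples/face.png'):
    cfg = {'MODEL': {'VIDEO_FPS': 25}, 'INPUT_IMAGE': {'SOURCE_IMAGE': image_path}}
    proc = Preprocesser(cfg, device='cpu')
    tensor, crop_info, name = proc.crop(Image.open(image_path).convert('RGB'), 'crop', size=256)
    print(tensor.shape)  # expected: torch.Size([3, 256, 256])
    return tensor, crop_info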
class KeyPointExtractor(nn.Module):
    def __init__(self, sadtalker_cfg, device):
        super(KeyPointExtractor, self).__init__()
        self.kp_extractor = OcclusionAwareKPDetector(kp_channels=sadtalker_cfg['MODEL']['NUM_MOTION_FRAMES'],
                                                     num_kp=10,
                                                     num_dilation_blocks=2,
                                                     dropout_rate=0.1).to(device)
        checkpoint_path = os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'], 'kp_detector.safetensors')
        load_state_dict_robust(self.kp_extractor, checkpoint_path, device, model_name='kp_detector')

    def forward(self, x):
        kp = self.kp_extractor(x)
        return kp
class Audio2Coeff(nn.Module):
    def __init__(self, sadtalker_cfg, device):
        super(Audio2Coeff, self).__init__()
        self.audio_model = Wav2Vec2Model().to(device)
        checkpoint_path = os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'], 'wav2vec2.pth')
        load_state_dict_robust(self.audio_model, checkpoint_path, device, model_name='wav2vec2')
        self.pose_mapper = AudioCoeffsPredictor(2048, 64).to(device)
        self.exp_mapper = AudioCoeffsPredictor(2048, 64).to(device)
        self.blink_mapper = AudioCoeffsPredictor(2048, 1).to(device)
        mapping_checkpoint = os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'], 'auido2pose_00140-model.pth')
        load_state_dict_robust(self, mapping_checkpoint, device)

    def get_pose_coeff(self, audio_tensor, ref_pose_coeff=None, kp_ref=None, use_ref_info=''):
        audio_embedding = self.audio_model(audio_tensor)
        pose_coeff = self.pose_mapper(audio_embedding)
        if ref_pose_coeff is not None:
            pose_coeff = ref_pose_coeff
        if kp_ref is not None and use_ref_info == 'pose':
            ref_pose_6d = kp_ref['value'][:, :6]
            pose_coeff[:, :6] = self.mean_std_normalize(ref_pose_6d).mean(dim=1)
        return pose_coeff

    def get_exp_coeff(self, audio_tensor, ref_expression_coeff=None):
        audio_embedding = self.audio_model(audio_tensor)
        expression_coeff = self.exp_mapper(audio_embedding)
        if ref_expression_coeff is not None:
            expression_coeff = ref_expression_coeff
        return expression_coeff

    def get_blink_coeff(self, audio_tensor):
        audio_embedding = self.audio_model(audio_tensor)
        blink_coeff = self.blink_mapper(audio_embedding)
        return blink_coeff

    def forward(self, audio):
        audio_embedding = self.audio_model(audio)
        pose_coeff = self.pose_mapper(audio_embedding)
        expression_coeff = self.exp_mapper(audio_embedding)
        blink_coeff = self.blink_mapper(audio_embedding)
        return pose_coeff, expression_coeff, blink_coeff

    def mean_std_normalize(self, coeff):
        mean = coeff.mean(dim=1, keepdim=True)
        std = coeff.std(dim=1, keepdim=True)
        return (coeff - mean) / (std + 1e-8)
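# Hedged shape sketch for Audio2Coeff: given a config whose MODEL.CHECKPOINTS_DIR points
# at already-downloaded checkpoints, a raw waveform batch of shape (B, num_samples) maps
# to 64-d pose/expression coefficients and a 1-d blink coefficient per item.
def _example_audio2coeff(sadtalker_cfg):
    a2c = Audio2Coeff(sadtalker_cfg, device='cpu')
    wav = torch.zeros(1, 16000)  # one second of silence at 16 kHz
    pose, exp, blink = a2c(wav)
    print(pose.shape, exp.shape, blink.shape)  # (1, 64), (1, 64), (1, 1)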
class AnimateFromCoeff(nn.Module):
    def __init__(self, sadtalker_cfg, device):
        super(AnimateFromCoeff, self).__init__()
        self.generator = Generator(sadtalker_cfg, device)
        self.mapping = Mapping(sadtalker_cfg, device)
        self.kp_norm = KeypointNorm(device=device)
        self.he_estimator = OcclusionAwareDenseMotion(sadtalker_cfg, device)

    def normalize_kp(self, kp_driving):
        return self.kp_norm(kp_driving)

    def generate(self, source_image, kp_source, coeff, generator, mapping, he_estimator, audio, source_image_crop,
                 face_enhancer=None):
        kp_driving = coeff['kp_driving']
        jacobian = coeff['jacobian']
        pose_coeff = coeff['pose_coeff']
        expression_coeff = coeff['expression_coeff']
        blink_coeff = coeff['blink_coeff']
        face_3d = mapping(expression_coeff, pose_coeff, blink_coeff) if blink_coeff is not None \
            else mapping(expression_coeff, pose_coeff)
        # The simplified dense-motion network convolves an image-sized tensor, so it is fed
        # the source image rather than the keypoint dict.
        sparse_motion = he_estimator(source_image, kp_driving, jacobian)
        dense_motion = sparse_motion['dense_motion']
        video_deocclusion = generator(source_image, dense_motion, bg_param={'mask': None, 'color': None})
        video_3d = generator(source_image, dense_motion, bg_param={'mask': None, 'color': None}, face_3d_param=face_3d)
        video_output = video_deocclusion['video_no_reocclusion'] + video_3d['video_3d']
        if face_enhancer is not None:
            video_output_enhanced = []
            for frame in tqdm(video_output, 'Face enhancer running'):
                pil_image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
                enhanced_image = face_enhancer.forward(np.array(pil_image))
                video_output_enhanced.append(cv2.cvtColor(enhanced_image, cv2.COLOR_BGR2RGB))
            video_output = video_output_enhanced
        return video_output

    def make_animation(self, video_array):
        H, W, _ = video_array[0].shape
        out = cv2.VideoWriter('./tmp.mp4', cv2.VideoWriter_fourcc(*'mp4v'), 25, (W, H))
        for img in video_array:
            out.write(cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
        out.release()
        video = imageio.mimread('./tmp.mp4')
        os.remove('./tmp.mp4')
        return video
class Generator(nn.Module):
    def __init__(self, sadtalker_cfg, device):
        super(Generator, self).__init__()
        self.generator = Hourglass(block_expansion=sadtalker_cfg['MODEL']['SCALE'],
                                   num_blocks=sadtalker_cfg['MODEL']['NUM_VOXEL_FRAMES'],
                                   max_features=sadtalker_cfg['MODEL']['MAX_FEATURES'],
                                   num_channels=3,
                                   kp_size=10,
                                   num_deform_blocks=sadtalker_cfg['MODEL']['NUM_MOTION_FRAMES']).to(device)
        checkpoint_path = os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'], 'generator.pth')
        load_state_dict_robust(self.generator, checkpoint_path, device, model_name='generator')

    def forward(self, source_image, dense_motion, bg_param, face_3d_param=None):
        video_3d = self.generator(source_image, kp_driving=dense_motion, bg_param=bg_param, face_3d_param=face_3d_param)
        return {'video_3d': video_3d, 'video_no_reocclusion': video_3d}
class Mapping(nn.Module):
    def __init__(self, sadtalker_cfg, device):
        super(Mapping, self).__init__()
        self.mapping_net = MappingNet(num_coeffs=64, num_layers=3, hidden_dim=128).to(device)
        checkpoint_path = os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'], 'mapping.pth')
        load_state_dict_robust(self.mapping_net, checkpoint_path, device, model_name='mapping')
        self.f_3d_mean = torch.zeros(1, 64, device=device)

    def forward(self, expression_coeff, pose_coeff, blink_coeff=None):
        coeff = torch.cat([expression_coeff, pose_coeff], dim=1)
        face_3d = self.mapping_net(coeff) + self.f_3d_mean
        if blink_coeff is not None:
            face_3d[:, -1:] = blink_coeff
        return face_3d
class OcclusionAwareDenseMotion(nn.Module):
    def __init__(self, sadtalker_cfg, device):
        super(OcclusionAwareDenseMotion, self).__init__()
        self.dense_motion_network = DenseMotionNetwork(num_kp=10,
                                                       num_channels=3,
                                                       block_expansion=sadtalker_cfg['MODEL']['SCALE'],
                                                       num_blocks=sadtalker_cfg['MODEL']['NUM_MOTION_FRAMES'] - 1,
                                                       max_features=sadtalker_cfg['MODEL']['MAX_FEATURES']).to(device)
        checkpoint_path = os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'], 'dense_motion.pth')
        load_state_dict_robust(self.dense_motion_network, checkpoint_path, device, model_name='dense_motion')

    def forward(self, kp_source, kp_driving, jacobian):
        sparse_motion = self.dense_motion_network(kp_source, kp_driving, jacobian)
        return sparse_motion
class FaceEnhancer(nn.Module):
    def __init__(self, sadtalker_cfg, device):
        super(FaceEnhancer, self).__init__()
        self.enhancer_name = sadtalker_cfg['MODEL']['ENHANCER_NAME']
        bg_upsampler = sadtalker_cfg['MODEL']['BG_UPSAMPLER']
        if self.enhancer_name == 'gfpgan':
            from gfpgan import GFPGANer
            self.face_enhancer = GFPGANer(model_path=os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'], 'GFPGANv1.4.pth'),
                                          upscale=1,
                                          arch='clean',
                                          channel_multiplier=2,
                                          bg_upsampler=bg_upsampler)
        elif self.enhancer_name == 'realesrgan':
            from realesrgan import RealESRGANer
            half = False if device == 'cpu' else sadtalker_cfg['MODEL']['IS_HALF']
            self.face_enhancer = RealESRGANer(scale=2,
                                              model_path=os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'],
                                                                      'RealESRGAN_x2plus.pth'),
                                              tile=0,
                                              tile_pad=10,
                                              pre_pad=0,
                                              half=half,
                                              device=device)
        else:
            self.face_enhancer = None

    def forward(self, x):
        if self.face_enhancer is None:
            return x
        if self.enhancer_name == 'gfpgan':
            # GFPGANer.enhance returns (cropped_faces, restored_faces, restored_img).
            _, _, restored_img = self.face_enhancer.enhance(x, has_aligned=False, only_center_face=False,
                                                            paste_back=True)
            return restored_img
        # RealESRGANer.enhance returns (output, img_mode).
        return self.face_enhancer.enhance(x, outscale=1)[0]
def load_models():
    checkpoint_path = './checkpoints'
    config_path = './src/config'
    size = 256
    preprocess = 'crop'
    old_version = False
    sadtalker_instance = SadTalker(checkpoint_path, config_path, size, preprocess, old_version)
    print("SadTalker models loaded successfully!")
    return sadtalker_instance


if __name__ == '__main__':
    sadtalker_instance = load_models()
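    # Hedged usage sketch: the paths below are hypothetical placeholders; when they exist,
    # this runs the full pipeline and prints where the rendered video was written.
    demo_image = 'examples/source_image.png'
    demo_audio = 'examples/driven_audio.wav'
    if os.path.exists(demo_image) and os.path.exists(demo_audio):
        result_path = sadtalker_instance.test(source_image=demo_image, driven_audio=demo_audio,
                                              preprocess='crop', still_mode=True, size=256)
        print(f"Result video written to: {result_path}")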