# sadtalker_utils.py
import os, shutil, uuid, tempfile, urllib.request, cv2, numpy as np, torch, torch.nn as nn, torch.nn.functional as F, yaml, safetensors, librosa, imageio
from PIL import Image
from skimage import img_as_ubyte, transform
from scipy.io import loadmat, wavfile
from tqdm import tqdm
# NOTE: the network building blocks and helpers referenced below (OcclusionAwareKPDetector, Wav2Vec2Model,
# AudioCoeffsPredictor, Hourglass, MappingNet, DenseMotionNetwork, KeypointNorm, Face3DHelper, MouthDetector,
# TTSTalker, load_wav_util, save_video_with_watermark, paste_pic, load_state_dict_robust, and the checkpoint
# file/url constants used in SadTalker.__init__) are expected to be provided by the accompanying modules of
# this upload; they are not defined in this file.
class SadTalker():
def __init__(self, checkpoint_path='checkpoints', config_path='src/config', size=256, preprocess='crop', old_version=False):
self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.cfg = self.get_cfg_defaults()
self.merge_from_file(os.path.join(config_path, 'sadtalker_config.yaml'))
self.cfg['MODEL']['CHECKPOINTS_DIR'] = checkpoint_path
self.cfg['MODEL']['CONFIG_DIR'] = config_path
self.cfg['MODEL']['DEVICE'] = self.device
self.cfg['INPUT_IMAGE'] = {}
self.cfg['INPUT_IMAGE']['SOURCE_IMAGE'] = 'None'
self.cfg['INPUT_IMAGE']['DRIVEN_AUDIO'] = 'None'
self.cfg['INPUT_IMAGE']['PREPROCESS'] = preprocess
self.cfg['INPUT_IMAGE']['SIZE'] = size
self.cfg['INPUT_IMAGE']['OLD_VERSION'] = old_version
        for filename, url in [(kp_file, kp_url), (aud_file, aud_url), (wav_file, wav_url), (gen_file, gen_url), (mapx_file, mapx_url), (den_file, den_url), ('GFPGANv1.4.pth', GFPGAN_URL), ('RealESRGAN_x2plus.pth', REALESRGAN_URL)]: download_model(url, filename, checkpoint_path)
self.sadtalker_model = SadTalkerModel(self.cfg, device_id=[0])
def get_cfg_defaults(self):
return {'MODEL': {'CHECKPOINTS_DIR': '', 'CONFIG_DIR': '', 'DEVICE': self.device, 'SCALE': 64, 'NUM_VOXEL_FRAMES': 8, 'NUM_MOTION_FRAMES': 10, 'MAX_FEATURES': 256, 'DRIVEN_AUDIO_SAMPLE_RATE': 16000, 'VIDEO_FPS': 25, 'OUTPUT_VIDEO_FPS': None, 'OUTPUT_AUDIO_SAMPLE_RATE': None, 'USE_ENHANCER': False, 'ENHANCER_NAME': '', 'BG_UPSAMPLER': None, 'IS_HALF': False}, 'INPUT_IMAGE': {}}
def merge_from_file(self, filepath):
if os.path.exists(filepath):
with open(filepath, 'r') as f: cfg_from_file = yaml.safe_load(f); self.cfg.update(cfg_from_file)
def test(self, source_image, driven_audio, preprocess='crop', still_mode=False, use_enhancer=False, batch_size=1, size=256, pose_style=0, exp_scale=1.0, use_ref_video=False, ref_video=None, ref_info=None, use_idle_mode=False, length_of_audio=0, use_blink=True, result_dir='./results/', tts_text=None, tts_lang='en'):
self.sadtalker_model.test(source_image, driven_audio, preprocess, still_mode, use_enhancer, batch_size, size, pose_style, exp_scale, use_ref_video, ref_video, ref_info, use_idle_mode, length_of_audio, use_blink, result_dir, tts_text, tts_lang); return self.sadtalker_model.save_result()
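# `SadTalker.merge_from_file` performs a shallow `dict.update`, so a top-level key present in
# sadtalker_config.yaml replaces the corresponding default wholesale. Illustrative only:
#
#   defaults = {'MODEL': {'SCALE': 64, 'VIDEO_FPS': 25}, 'INPUT_IMAGE': {}}
#   loaded = {'MODEL': {'VIDEO_FPS': 30}}
#   defaults.update(loaded)   # defaults['MODEL'] is now {'VIDEO_FPS': 30}; 'SCALE' is gone
#
# A config file should therefore repeat every key of a section it overrides; CHECKPOINTS_DIR, CONFIG_DIR
# and DEVICE are re-set after the merge in __init__, so those three survive regardless.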
class SadTalkerModel():
def __init__(self, sadtalker_cfg, device_id=[0]):
self.cfg = sadtalker_cfg; self.device = sadtalker_cfg['MODEL'].get('DEVICE', 'cpu')
self.sadtalker = SadTalkerInnerModel(sadtalker_cfg, device_id)
self.preprocesser = self.sadtalker.preprocesser
self.kp_extractor = self.sadtalker.kp_extractor; self.generator = self.sadtalker.generator
self.mapping = self.sadtalker.mapping; self.he_estimator = self.sadtalker.he_estimator
self.audio_to_coeff = self.sadtalker.audio_to_coeff; self.animate_from_coeff = self.sadtalker.animate_from_coeff; self.face_enhancer = self.sadtalker.face_enhancer
def test(self, source_image, driven_audio, preprocess='crop', still_mode=False, use_enhancer=False, batch_size=1, size=256, pose_style=0, exp_scale=1.0, use_ref_video=False, ref_video=None, ref_info=None, use_idle_mode=False, length_of_audio=0, use_blink=True, result_dir='./results/', tts_text=None, tts_lang='en', jitter_amount=10, jitter_source_image=False):
self.inner_test = SadTalkerInner(self, source_image, driven_audio, preprocess, still_mode, use_enhancer, batch_size, size, pose_style, exp_scale, use_ref_video, ref_video, ref_info, use_idle_mode, length_of_audio, use_blink, result_dir, tts_text, tts_lang, jitter_amount, jitter_source_image); return self.inner_test.test()
def save_result(self):
return self.inner_test.save_result()
class SadTalkerInner():
def __init__(self, sadtalker_model, source_image, driven_audio, preprocess, still_mode, use_enhancer, batch_size, size, pose_style, exp_scale, use_ref_video, ref_video, ref_info, use_idle_mode, length_of_audio, use_blink, result_dir, tts_text, tts_lang, jitter_amount, jitter_source_image):
self.sadtalker_model = sadtalker_model; self.source_image = source_image; self.driven_audio = driven_audio
self.preprocess = preprocess; self.still_mode = still_mode; self.use_enhancer = use_enhancer
self.batch_size = batch_size; self.size = size; self.pose_style = pose_style; self.exp_scale = exp_scale
self.use_ref_video = use_ref_video; self.ref_video = ref_video; self.ref_info = ref_info
self.use_idle_mode = use_idle_mode; self.length_of_audio = length_of_audio; self.use_blink = use_blink
self.result_dir = result_dir; self.tts_text = tts_text; self.tts_lang = tts_lang
self.jitter_amount = jitter_amount; self.jitter_source_image = jitter_source_image; self.device = self.sadtalker_model.device; self.output_path = None
def get_test_data(self):
proc = self.sadtalker_model.preprocesser
        if self.tts_text is not None:
            # Synthesize the driving audio from text. TTSTalker comes from an accompanying module; it is
            # assumed here to produce the speech for (tts_text, tts_lang) at the temporary wav path below.
            temp_dir = tempfile.mkdtemp()
            audio_path = os.path.join(temp_dir, 'audio.wav')
            tts = TTSTalker()
            tts.test(self.tts_text, self.tts_lang)
            self.driven_audio = audio_path
source_image_pil = Image.open(self.source_image).convert('RGB')
if self.jitter_source_image: jitter_dx = np.random.randint(-self.jitter_amount, self.jitter_amount + 1); jitter_dy = np.random.randint(-self.jitter_amount, self.jitter_amount + 1); source_image_pil = Image.fromarray(np.roll(np.roll(np.array(source_image_pil), jitter_dx, axis=1), jitter_dy, axis=0))
source_image_tensor, crop_info, cropped_image = proc.crop(source_image_pil, self.preprocess, self.size)
if self.still_mode or self.use_idle_mode: ref_pose_coeff = proc.generate_still_pose(self.pose_style); ref_expression_coeff = proc.generate_still_expression(self.exp_scale)
else: ref_pose_coeff = None; ref_expression_coeff = None
audio_tensor, audio_sample_rate = proc.process_audio(self.driven_audio, self.sadtalker_model.cfg['MODEL']['DRIVEN_AUDIO_SAMPLE_RATE'])
batch = {'source_image': source_image_tensor.unsqueeze(0).to(self.device), 'audio': audio_tensor.unsqueeze(0).to(self.device), 'ref_pose_coeff': ref_pose_coeff, 'ref_expression_coeff': ref_expression_coeff, 'source_image_crop': cropped_image, 'crop_info': crop_info, 'use_blink': self.use_blink, 'pose_style': self.pose_style, 'exp_scale': self.exp_scale, 'ref_video': self.ref_video, 'use_ref_video': self.use_ref_video, 'ref_info': self.ref_info}
return batch, audio_sample_rate
def run_inference(self, batch):
kp_extractor, generator, mapping, he_estimator, audio_to_coeff, animate_from_coeff, face_enhancer = self.sadtalker_model.kp_extractor, self.sadtalker_model.generator, self.sadtalker_model.mapping, self.sadtalker_model.he_estimator, self.sadtalker_model.audio_to_coeff, self.sadtalker_model.animate_from_coeff, self.sadtalker_model.face_enhancer
with torch.no_grad():
kp_source = kp_extractor(batch['source_image'])
            if self.still_mode or self.use_idle_mode:
                pose_coeff = audio_to_coeff.get_pose_coeff(batch['audio'], batch['ref_pose_coeff'])
                expression_coeff = audio_to_coeff.get_exp_coeff(batch['audio'], batch['ref_expression_coeff'])
else:
if self.use_ref_video: kp_ref = kp_extractor(batch['source_image']); pose_coeff = audio_to_coeff.get_pose_coeff(batch['audio'], kp_ref=kp_ref, use_ref_info=batch['ref_info'])
else: pose_coeff = audio_to_coeff.get_pose_coeff(batch['audio'])
expression_coeff = audio_to_coeff.get_exp_coeff(batch['audio'])
coeff = {'pose_coeff': pose_coeff, 'expression_coeff': expression_coeff}
if self.use_blink: coeff['blink_coeff'] = audio_to_coeff.get_blink_coeff(batch['audio'])
else: coeff['blink_coeff'] = None
kp_driving = audio_to_coeff(batch['audio'])[0]; kp_norm = animate_from_coeff.normalize_kp(kp_driving); coeff['kp_driving'] = kp_norm; coeff['jacobian'] = [torch.eye(2).unsqueeze(0).unsqueeze(0).to(self.device)] * 4
output_video = animate_from_coeff.generate(batch['source_image'], kp_source, coeff, generator, mapping, he_estimator, batch['audio'], batch['source_image_crop'], face_enhancer=face_enhancer)
return output_video
def post_processing(self, output_video, audio_sample_rate, batch):
proc = self.sadtalker_model.preprocesser; base_name = os.path.splitext(os.path.basename(batch['source_image_crop']))[0]; audio_name = os.path.splitext(os.path.basename(self.driven_audio))[0]
output_video_path = os.path.join(self.result_dir, base_name + '_' + audio_name + '.mp4'); self.output_path = output_video_path
video_fps = self.sadtalker_model.cfg['MODEL']['VIDEO_FPS'] if self.sadtalker_model.cfg['MODEL']['OUTPUT_VIDEO_FPS'] is None else self.sadtalker_model.cfg['MODEL']['OUTPUT_VIDEO_FPS']
audio_output_sample_rate = self.sadtalker_model.cfg['MODEL']['DRIVEN_AUDIO_SAMPLE_RATE'] if self.sadtalker_model.cfg['MODEL']['OUTPUT_AUDIO_SAMPLE_RATE'] is None else self.sadtalker_model.cfg['MODEL']['OUTPUT_AUDIO_SAMPLE_RATE']
if self.use_enhancer: enhanced_path = os.path.join(self.result_dir, base_name + '_' + audio_name + '_enhanced.mp4'); save_video_with_watermark(output_video, self.driven_audio, enhanced_path); paste_pic(enhanced_path, batch['source_image_crop'], batch['crop_info'], self.driven_audio, output_video_path); os.remove(enhanced_path)
else: save_video_with_watermark(output_video, self.driven_audio, output_video_path)
if self.tts_text is not None: shutil.rmtree(os.path.dirname(self.driven_audio))
def save_result(self):
return self.output_path
def __call__(self):
return self.output_path
def test(self):
batch, audio_sample_rate = self.get_test_data(); output_video = self.run_inference(batch); self.post_processing(output_video, audio_sample_rate, batch); return self.save_result()
class SadTalkerInnerModel():
def __init__(self, sadtalker_cfg, device_id=[0]):
self.cfg = sadtalker_cfg; self.device = sadtalker_cfg['MODEL'].get('DEVICE', 'cpu')
self.preprocesser = Preprocesser(sadtalker_cfg, self.device); self.kp_extractor = KeyPointExtractor(sadtalker_cfg, self.device)
self.audio_to_coeff = Audio2Coeff(sadtalker_cfg, self.device); self.animate_from_coeff = AnimateFromCoeff(sadtalker_cfg, self.device)
self.face_enhancer = FaceEnhancer(sadtalker_cfg, self.device) if sadtalker_cfg['MODEL']['USE_ENHANCER'] else None
self.generator = Generator(sadtalker_cfg, self.device); self.mapping = Mapping(sadtalker_cfg, self.device); self.he_estimator = OcclusionAwareDenseMotion(sadtalker_cfg, self.device)
class Preprocesser():
def __init__(self, sadtalker_cfg, device):
self.cfg = sadtalker_cfg; self.device = device
self.face3d_helper = Face3DHelper(self.cfg['INPUT_IMAGE'].get('LOCAL_PCA_PATH', ''), device); self.mouth_detector = MouthDetector()
def crop(self, source_image_pil, preprocess_type, size=256):
source_image = np.array(source_image_pil); face_info = self.face3d_helper.run(source_image)
if face_info is None: raise Exception("No face detected")
x_min, y_min, x_max, y_max = face_info[:4]; old_size = (x_max - x_min, y_max - y_min); x_center = (x_max + x_min) / 2; y_center = (y_max + y_min) / 2
if preprocess_type == 'crop': face_size = max(x_max - x_min, y_max - y_min); x_min = int(x_center - face_size / 2); y_min = int(y_center - face_size / 2); x_max = int(x_center + face_size / 2); y_max = int(y_center + face_size / 2)
        else:
            # Expand the detected box by 10% of its original width/height on each side, keeping integer coords.
            pad_x = int((x_max - x_min) * 0.1); pad_y = int((y_max - y_min) * 0.1)
            x_min = int(x_min - pad_x); y_min = int(y_min - pad_y); x_max = int(x_max + pad_x); y_max = int(y_max + pad_y)
        h, w = source_image.shape[:2]; x_min = max(0, x_min); y_min = max(0, y_min); x_max = min(w, x_max); y_max = min(h, y_max)
cropped_image = source_image[y_min:y_max, x_min:x_max]; cropped_image_pil = Image.fromarray(cropped_image)
if size is not None and size != 0: cropped_image_pil = cropped_image_pil.resize((size, size), Image.Resampling.LANCZOS)
source_image_tensor = self.img2tensor(cropped_image_pil); return source_image_tensor, [[y_min, y_max], [x_min, x_max], old_size, cropped_image_pil.size], os.path.basename(self.cfg['INPUT_IMAGE'].get('SOURCE_IMAGE', ''))
def img2tensor(self, img):
img = np.array(img).astype(np.float32) / 255.0; img = np.transpose(img, (2, 0, 1)); return torch.FloatTensor(img)
def video_to_tensor(self, video, device): return 0
def process_audio(self, audio_path, sample_rate): wav = load_wav_util(audio_path, sample_rate); wav_tensor = torch.FloatTensor(wav).unsqueeze(0); return wav_tensor, sample_rate
def generate_still_pose(self, pose_style): ref_pose_coeff = torch.zeros((1, 64), dtype=torch.float32).to(self.device); ref_pose_coeff[:, :3] = torch.tensor([0, 0, pose_style * 0.3], dtype=torch.float32); return ref_pose_coeff
def generate_still_expression(self, exp_scale): ref_expression_coeff = torch.zeros((1, 64), dtype=torch.float32).to(self.device); ref_expression_coeff[:, :3] = torch.tensor([0, 0, exp_scale * 0.3], dtype=torch.float32); return ref_expression_coeff
def generate_idles_pose(self, length_of_audio, pose_style): return 0
def generate_idles_expression(self, length_of_audio): return 0
class KeyPointExtractor(nn.Module):
def __init__(self, sadtalker_cfg, device):
super(KeyPointExtractor, self).__init__(); self.kp_extractor = OcclusionAwareKPDetector(kp_channels=sadtalker_cfg['MODEL']['NUM_MOTION_FRAMES'], num_kp=10, num_dilation_blocks=2, dropout_rate=0.1).to(device)
checkpoint_path = os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'], 'kp_detector.safetensors'); load_state_dict_robust(self.kp_extractor, checkpoint_path, device, model_name='kp_detector')
def forward(self, x): kp = self.kp_extractor(x); return kp
class Audio2Coeff(nn.Module):
def __init__(self, sadtalker_cfg, device):
super(Audio2Coeff, self).__init__(); self.audio_model = Wav2Vec2Model().to(device)
checkpoint_path = os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'], 'wav2vec2.pth'); load_state_dict_robust(self.audio_model, checkpoint_path, device, model_name='wav2vec2')
self.pose_mapper = AudioCoeffsPredictor(2048, 64).to(device); self.exp_mapper = AudioCoeffsPredictor(2048, 64).to(device); self.blink_mapper = AudioCoeffsPredictor(2048, 1).to(device)
mapping_checkpoint = os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'], 'auido2pose_00140-model.pth'); load_state_dict_robust(self, mapping_checkpoint, device)
    def get_pose_coeff(self, audio_tensor, ref_pose_coeff=None, kp_ref=None, use_ref_info=''):
        audio_embedding = self.audio_model(audio_tensor)
        pose_coeff = self.pose_mapper(audio_embedding)
        if ref_pose_coeff is not None: pose_coeff = ref_pose_coeff
        if kp_ref is not None and use_ref_info == 'pose': ref_pose_6d = kp_ref['value'][:, :6]; pose_coeff[:, :6] = self.mean_std_normalize(ref_pose_6d).mean(dim=1)
        return pose_coeff
    def get_exp_coeff(self, audio_tensor, ref_expression_coeff=None):
        audio_embedding = self.audio_model(audio_tensor)
        expression_coeff = self.exp_mapper(audio_embedding)
        if ref_expression_coeff is not None: expression_coeff = ref_expression_coeff
        return expression_coeff
def get_blink_coeff(self, audio_tensor): audio_embedding = self.audio_model(audio_tensor); blink_coeff = self.blink_mapper(audio_embedding); return blink_coeff
def forward(self, audio): audio_embedding = self.audio_model(audio); pose_coeff, expression_coeff, blink_coeff = self.pose_mapper(audio_embedding), self.exp_mapper(audio_embedding), self.blink_mapper(audio_embedding); return pose_coeff, expression_coeff, blink_coeff
def mean_std_normalize(self, coeff): mean = coeff.mean(dim=1, keepdim=True); std = coeff.std(dim=1, keepdim=True); return (coeff - mean) / std
class AnimateFromCoeff(nn.Module):
def __init__(self, sadtalker_cfg, device):
super(AnimateFromCoeff, self).__init__(); self.generator = Generator(sadtalker_cfg, device); self.mapping = Mapping(sadtalker_cfg, device); self.kp_norm = KeypointNorm(device=device); self.he_estimator = OcclusionAwareDenseMotion(sadtalker_cfg, device)
def normalize_kp(self, kp_driving): return self.kp_norm(kp_driving)
def generate(self, source_image, kp_source, coeff, generator, mapping, he_estimator, audio, source_image_crop, face_enhancer=None):
kp_driving, jacobian, pose_coeff, expression_coeff, blink_coeff = coeff['kp_driving'], coeff['jacobian'], coeff['pose_coeff'], coeff['expression_coeff'], coeff['blink_coeff']
face_3d = mapping(expression_coeff, pose_coeff, blink_coeff) if blink_coeff is not None else mapping(expression_coeff, pose_coeff); sparse_motion = he_estimator(kp_source, kp_driving, jacobian)
dense_motion = sparse_motion['dense_motion']; video_deocclusion = generator(source_image, dense_motion, bg_param={'mask': None, 'color': None})
video_3d = generator(source_image, dense_motion, bg_param={'mask': None, 'color': None}, face_3d_param=face_3d); video_output = video_deocclusion['video_no_reocclusion'] + video_3d['video_3d']
        if face_enhancer is not None:
            video_output_enhanced = []
            for frame in tqdm(video_output, 'Face enhancer running'):
                pil_image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
                enhanced_image = face_enhancer.forward(np.array(pil_image))
                video_output_enhanced.append(cv2.cvtColor(enhanced_image, cv2.COLOR_BGR2RGB))
            video_output = video_output_enhanced
return video_output
    def make_animation(self, video_array):
        # Round-trip the frames through a temporary mp4 file and return them as a list of RGB frames.
        H, W, _ = video_array[0].shape
        out = cv2.VideoWriter('./tmp.mp4', cv2.VideoWriter_fourcc(*'mp4v'), 25, (W, H))
        for img in video_array: out.write(cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
        out.release()
        video = imageio.mimread('./tmp.mp4')
        os.remove('./tmp.mp4')
        return video
class Generator(nn.Module):
def __init__(self, sadtalker_cfg, device):
super(Generator, self).__init__(); self.generator = Hourglass(block_expansion=sadtalker_cfg['MODEL']['SCALE'], num_blocks=sadtalker_cfg['MODEL']['NUM_VOXEL_FRAMES'], max_features=sadtalker_cfg['MODEL']['MAX_FEATURES'], num_channels=3, kp_size=10, num_deform_blocks=sadtalker_cfg['MODEL']['NUM_MOTION_FRAMES']).to(device)
checkpoint_path = os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'], 'generator.pth'); load_state_dict_robust(self.generator, checkpoint_path, device, model_name='generator')
def forward(self, source_image, dense_motion, bg_param, face_3d_param=None): video_3d = self.generator(source_image, kp_driving=dense_motion, bg_param=bg_param, face_3d_param=face_3d_param); return {'video_3d': video_3d, 'video_no_reocclusion': video_3d}
class Mapping(nn.Module):
def __init__(self, sadtalker_cfg, device):
super(Mapping, self).__init__(); self.mapping_net = MappingNet(num_coeffs=64, num_layers=3, hidden_dim=128).to(device)
checkpoint_path = os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'], 'mapping.pth'); load_state_dict_robust(self.mapping_net, checkpoint_path, device, model_name='mapping')
self.f_3d_mean = torch.zeros(1, 64, device=device)
    def forward(self, expression_coeff, pose_coeff, blink_coeff=None):
        coeff = torch.cat([expression_coeff, pose_coeff], dim=1)
        face_3d = self.mapping_net(coeff) + self.f_3d_mean
        if blink_coeff is not None: face_3d[:, -1:] = blink_coeff
        return face_3d
class OcclusionAwareDenseMotion(nn.Module):
def __init__(self, sadtalker_cfg, device):
super(OcclusionAwareDenseMotion, self).__init__(); self.dense_motion_network = DenseMotionNetwork(num_kp=10, num_channels=3, block_expansion=sadtalker_cfg['MODEL']['SCALE'], num_blocks=sadtalker_cfg['MODEL']['NUM_MOTION_FRAMES'] - 1, max_features=sadtalker_cfg['MODEL']['MAX_FEATURES']).to(device)
checkpoint_path = os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'], 'dense_motion.pth'); load_state_dict_robust(self.dense_motion_network, checkpoint_path, device, model_name='dense_motion')
def forward(self, kp_source, kp_driving, jacobian): sparse_motion = self.dense_motion_network(kp_source, kp_driving, jacobian); return sparse_motion
class FaceEnhancer(nn.Module):
def __init__(self, sadtalker_cfg, device):
        super(FaceEnhancer, self).__init__(); enhancer_name = sadtalker_cfg['MODEL']['ENHANCER_NAME']; bg_upsampler = sadtalker_cfg['MODEL']['BG_UPSAMPLER']; self.enhancer_name = enhancer_name
if enhancer_name == 'gfpgan': from gfpgan import GFPGANer; self.face_enhancer = GFPGANer(model_path=os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'], 'GFPGANv1.4.pth'), upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=bg_upsampler)
elif enhancer_name == 'realesrgan': from realesrgan import RealESRGANer; half = False if device == 'cpu' else sadtalker_cfg['MODEL']['IS_HALF']; self.face_enhancer = RealESRGANer(scale=2, model_path=os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'], 'RealESRGAN_x2plus.pth'), tile=0, tile_pad=10, pre_pad=0, half=half, device=device)
else: self.face_enhancer = None
    def forward(self, x):
        # GFPGANer.enhance returns (cropped_faces, restored_faces, restored_img) and takes no `outscale`;
        # RealESRGANer.enhance returns (output, img_mode).
        if self.face_enhancer is None: return x
        if self.enhancer_name == 'gfpgan': return self.face_enhancer.enhance(x, has_aligned=False, only_center_face=False, paste_back=True)[2]
        return self.face_enhancer.enhance(x, outscale=1)[0]
def download_model(url, filename, checkpoint_dir):
if not os.path.exists(os.path.join(checkpoint_dir, filename)): print(f"Downloading {filename}..."); os.makedirs(checkpoint_dir, exist_ok=True); urllib.request.urlretrieve(url, os.path.join(checkpoint_dir, filename)); print(f"{filename} downloaded.")
else: print(f"{filename} already exists.")
def load_models():
checkpoint_path = './checkpoints'; config_path = './src/config'; size = 256; preprocess = 'crop'; old_version = False
sadtalker_instance = SadTalker(checkpoint_path, config_path, size, preprocess, old_version); print("SadTalker models loaded successfully!"); return sadtalker_instance
if __name__ == '__main__': sadtalker_instance = load_models()
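# Example usage (a sketch; the example paths below are placeholders, not files shipped with this module):
#
#   sadtalker = load_models()
#   result_path = sadtalker.test(
#       source_image='examples/source_image.png',
#       driven_audio='examples/driven_audio.wav',
#       preprocess='crop', still_mode=False, use_enhancer=False,
#       batch_size=1, size=256, result_dir='./results/')
#   print(result_path)  # path of the rendered talking-head video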