import os, shutil, tempfile, urllib.request, uuid, cv2, numpy as np, torch, torch.nn as nn, torch.nn.functional as F, yaml, safetensors, librosa, imageio
from PIL import Image
from skimage import img_as_ubyte, transform
from scipy.io import loadmat, wavfile
from tqdm import tqdm
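# NOTE: several names used below are assumed to be provided elsewhere in the repository and are not
# defined in this file: the checkpoint filenames/URLs (kp_file, kp_url, ...), TTSTalker, Face3DHelper,
# MouthDetector, OcclusionAwareKPDetector, Wav2Vec2Model, AudioCoeffsPredictor, Hourglass, MappingNet,
# DenseMotionNetwork, KeypointNorm, load_state_dict_robust, load_wav_util, save_video_with_watermark, paste_pic.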
class SadTalker():
    def __init__(self, checkpoint_path='checkpoints', config_path='src/config', size=256, preprocess='crop', old_version=False):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.cfg = self.get_cfg_defaults()
        self.merge_from_file(os.path.join(config_path, 'sadtalker_config.yaml'))
        self.cfg['MODEL']['CHECKPOINTS_DIR'] = checkpoint_path
        self.cfg['MODEL']['CONFIG_DIR'] = config_path
        self.cfg['MODEL']['DEVICE'] = self.device
        self.cfg['INPUT_IMAGE'] = {}
        self.cfg['INPUT_IMAGE']['SOURCE_IMAGE'] = 'None'
        self.cfg['INPUT_IMAGE']['DRIVEN_AUDIO'] = 'None'
        self.cfg['INPUT_IMAGE']['PREPROCESS'] = preprocess
        self.cfg['INPUT_IMAGE']['SIZE'] = size
        self.cfg['INPUT_IMAGE']['OLD_VERSION'] = old_version
        for filename, url in [(kp_file, kp_url), (aud_file, aud_url), (wav_file, wav_url), (gen_file, gen_url), (mapx_file, mapx_url), (den_file, den_url), ('GFPGANv1.4.pth', GFPGAN_URL), ('RealESRGAN_x2plus.pth', REALESRGAN_URL)]: download_model(url, filename, checkpoint_path)
        self.sadtalker_model = SadTalkerModel(self.cfg, device_id=[0])
    def get_cfg_defaults(self):
        return {'MODEL': {'CHECKPOINTS_DIR': '', 'CONFIG_DIR': '', 'DEVICE': self.device, 'SCALE': 64, 'NUM_VOXEL_FRAMES': 8, 'NUM_MOTION_FRAMES': 10, 'MAX_FEATURES': 256, 'DRIVEN_AUDIO_SAMPLE_RATE': 16000, 'VIDEO_FPS': 25, 'OUTPUT_VIDEO_FPS': None, 'OUTPUT_AUDIO_SAMPLE_RATE': None, 'USE_ENHANCER': False, 'ENHANCER_NAME': '', 'BG_UPSAMPLER': None, 'IS_HALF': False}, 'INPUT_IMAGE': {}}
    def merge_from_file(self, filepath):
        if os.path.exists(filepath):
            with open(filepath, 'r') as f: cfg_from_file = yaml.safe_load(f); self.cfg.update(cfg_from_file)
    def test(self, source_image, driven_audio, preprocess='crop', still_mode=False, use_enhancer=False, batch_size=1, size=256, pose_style=0, exp_scale=1.0, use_ref_video=False, ref_video=None, ref_info=None, use_idle_mode=False, length_of_audio=0, use_blink=True, result_dir='./results/', tts_text=None, tts_lang='en'):
        self.sadtalker_model.test(source_image, driven_audio, preprocess, still_mode, use_enhancer, batch_size, size, pose_style, exp_scale, use_ref_video, ref_video, ref_info, use_idle_mode, length_of_audio, use_blink, result_dir, tts_text, tts_lang); return self.sadtalker_model.save_result()
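# Thin wrapper that owns the shared sub-models (keypoint extractor, generator, mapping network,
# dense-motion estimator, audio-to-coefficient network, optional face enhancer) and hands each
# test() call off to a single-run SadTalkerInner instance.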
class SadTalkerModel():
    def __init__(self, sadtalker_cfg, device_id=[0]):
        self.cfg = sadtalker_cfg; self.device = sadtalker_cfg['MODEL'].get('DEVICE', 'cpu')
        self.sadtalker = SadTalkerInnerModel(sadtalker_cfg, device_id)
        self.preprocesser = self.sadtalker.preprocesser
        self.kp_extractor = self.sadtalker.kp_extractor; self.generator = self.sadtalker.generator
        self.mapping = self.sadtalker.mapping; self.he_estimator = self.sadtalker.he_estimator
        self.audio_to_coeff = self.sadtalker.audio_to_coeff; self.animate_from_coeff = self.sadtalker.animate_from_coeff; self.face_enhancer = self.sadtalker.face_enhancer
    def test(self, source_image, driven_audio, preprocess='crop', still_mode=False, use_enhancer=False, batch_size=1, size=256, pose_style=0, exp_scale=1.0, use_ref_video=False, ref_video=None, ref_info=None, use_idle_mode=False, length_of_audio=0, use_blink=True, result_dir='./results/', tts_text=None, tts_lang='en', jitter_amount=10, jitter_source_image=False):
        self.inner_test = SadTalkerInner(self, source_image, driven_audio, preprocess, still_mode, use_enhancer, batch_size, size, pose_style, exp_scale, use_ref_video, ref_video, ref_info, use_idle_mode, length_of_audio, use_blink, result_dir, tts_text, tts_lang, jitter_amount, jitter_source_image); return self.inner_test.test()
    def save_result(self):
        return self.inner_test.save_result()
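# One inference run: get_test_data() prepares the cropped source image and audio batch,
# run_inference() maps audio to motion coefficients and renders the frames, post_processing()
# muxes audio and writes the final mp4, and save_result() returns its path.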
class SadTalkerInner():
    def __init__(self, sadtalker_model, source_image, driven_audio, preprocess, still_mode, use_enhancer, batch_size, size, pose_style, exp_scale, use_ref_video, ref_video, ref_info, use_idle_mode, length_of_audio, use_blink, result_dir, tts_text, tts_lang, jitter_amount, jitter_source_image):
        self.sadtalker_model = sadtalker_model; self.source_image = source_image; self.driven_audio = driven_audio
        self.preprocess = preprocess; self.still_mode = still_mode; self.use_enhancer = use_enhancer
        self.batch_size = batch_size; self.size = size; self.pose_style = pose_style; self.exp_scale = exp_scale
        self.use_ref_video = use_ref_video; self.ref_video = ref_video; self.ref_info = ref_info
        self.use_idle_mode = use_idle_mode; self.length_of_audio = length_of_audio; self.use_blink = use_blink
        self.result_dir = result_dir; self.tts_text = tts_text; self.tts_lang = tts_lang
        self.jitter_amount = jitter_amount; self.jitter_source_image = jitter_source_image; self.device = self.sadtalker_model.device; self.output_path = None
    def get_test_data(self):
        proc = self.sadtalker_model.preprocesser
        # When tts_text is given, synthesize the driving audio; the TTSTalker backend is assumed to write its output to audio_path.
        if self.tts_text is not None: temp_dir = tempfile.mkdtemp(); audio_path = os.path.join(temp_dir, 'audio.wav'); tts = TTSTalker(); tts.test(self.tts_text, self.tts_lang); self.driven_audio = audio_path
        source_image_pil = Image.open(self.source_image).convert('RGB')
        if self.jitter_source_image: jitter_dx = np.random.randint(-self.jitter_amount, self.jitter_amount + 1); jitter_dy = np.random.randint(-self.jitter_amount, self.jitter_amount + 1); source_image_pil = Image.fromarray(np.roll(np.roll(np.array(source_image_pil), jitter_dx, axis=1), jitter_dy, axis=0))
        source_image_tensor, crop_info, cropped_image = proc.crop(source_image_pil, self.preprocess, self.size)
        if self.still_mode or self.use_idle_mode: ref_pose_coeff = proc.generate_still_pose(self.pose_style); ref_expression_coeff = proc.generate_still_expression(self.exp_scale)
        else: ref_pose_coeff = None; ref_expression_coeff = None
        audio_tensor, audio_sample_rate = proc.process_audio(self.driven_audio, self.sadtalker_model.cfg['MODEL']['DRIVEN_AUDIO_SAMPLE_RATE'])
        batch = {'source_image': source_image_tensor.unsqueeze(0).to(self.device), 'audio': audio_tensor.unsqueeze(0).to(self.device), 'ref_pose_coeff': ref_pose_coeff, 'ref_expression_coeff': ref_expression_coeff, 'source_image_crop': cropped_image, 'crop_info': crop_info, 'use_blink': self.use_blink, 'pose_style': self.pose_style, 'exp_scale': self.exp_scale, 'ref_video': self.ref_video, 'use_ref_video': self.use_ref_video, 'ref_info': self.ref_info}
        return batch, audio_sample_rate
    def run_inference(self, batch):
        kp_extractor, generator, mapping, he_estimator, audio_to_coeff, animate_from_coeff, face_enhancer = self.sadtalker_model.kp_extractor, self.sadtalker_model.generator, self.sadtalker_model.mapping, self.sadtalker_model.he_estimator, self.sadtalker_model.audio_to_coeff, self.sadtalker_model.animate_from_coeff, self.sadtalker_model.face_enhancer
        with torch.no_grad():
            kp_source = kp_extractor(batch['source_image'])
            if self.still_mode or self.use_idle_mode: pose_coeff = audio_to_coeff.get_pose_coeff(batch['audio'], batch['ref_pose_coeff']); expression_coeff = audio_to_coeff.get_exp_coeff(batch['audio'], batch['ref_expression_coeff'])
            else:
                if self.use_ref_video: kp_ref = kp_extractor(batch['source_image']); pose_coeff = audio_to_coeff.get_pose_coeff(batch['audio'], kp_ref=kp_ref, use_ref_info=batch['ref_info'])
                else: pose_coeff = audio_to_coeff.get_pose_coeff(batch['audio'])
                expression_coeff = audio_to_coeff.get_exp_coeff(batch['audio'])
            coeff = {'pose_coeff': pose_coeff, 'expression_coeff': expression_coeff}
            if self.use_blink: coeff['blink_coeff'] = audio_to_coeff.get_blink_coeff(batch['audio'])
            else: coeff['blink_coeff'] = None
            kp_driving = audio_to_coeff(batch['audio'])[0]; kp_norm = animate_from_coeff.normalize_kp(kp_driving); coeff['kp_driving'] = kp_norm; coeff['jacobian'] = [torch.eye(2).unsqueeze(0).unsqueeze(0).to(self.device)] * 4
            output_video = animate_from_coeff.generate(batch['source_image'], kp_source, coeff, generator, mapping, he_estimator, batch['audio'], batch['source_image_crop'], face_enhancer=face_enhancer)
        return output_video
    def post_processing(self, output_video, audio_sample_rate, batch):
        proc = self.sadtalker_model.preprocesser; base_name = os.path.splitext(os.path.basename(batch['source_image_crop']))[0]; audio_name = os.path.splitext(os.path.basename(self.driven_audio))[0]
        os.makedirs(self.result_dir, exist_ok=True)
        output_video_path = os.path.join(self.result_dir, base_name + '_' + audio_name + '.mp4'); self.output_path = output_video_path
        video_fps = self.sadtalker_model.cfg['MODEL']['VIDEO_FPS'] if self.sadtalker_model.cfg['MODEL']['OUTPUT_VIDEO_FPS'] is None else self.sadtalker_model.cfg['MODEL']['OUTPUT_VIDEO_FPS']
        audio_output_sample_rate = self.sadtalker_model.cfg['MODEL']['DRIVEN_AUDIO_SAMPLE_RATE'] if self.sadtalker_model.cfg['MODEL']['OUTPUT_AUDIO_SAMPLE_RATE'] is None else self.sadtalker_model.cfg['MODEL']['OUTPUT_AUDIO_SAMPLE_RATE']
        if self.use_enhancer: enhanced_path = os.path.join(self.result_dir, base_name + '_' + audio_name + '_enhanced.mp4'); save_video_with_watermark(output_video, self.driven_audio, enhanced_path); paste_pic(enhanced_path, batch['source_image_crop'], batch['crop_info'], self.driven_audio, output_video_path); os.remove(enhanced_path)
        else: save_video_with_watermark(output_video, self.driven_audio, output_video_path)
        if self.tts_text is not None: shutil.rmtree(os.path.dirname(self.driven_audio))
    def save_result(self):
        return self.output_path
    def __call__(self):
        return self.output_path
    def test(self):
        batch, audio_sample_rate = self.get_test_data(); output_video = self.run_inference(batch); self.post_processing(output_video, audio_sample_rate, batch); return self.save_result()
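# Container that instantiates every sub-network once (preprocessing, keypoint extraction,
# audio-to-coefficient mapping, animation, optional enhancement) so they can be shared across runs.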
class SadTalkerInnerModel():
    def __init__(self, sadtalker_cfg, device_id=[0]):
        self.cfg = sadtalker_cfg; self.device = sadtalker_cfg['MODEL'].get('DEVICE', 'cpu')
        self.preprocesser = Preprocesser(sadtalker_cfg, self.device); self.kp_extractor = KeyPointExtractor(sadtalker_cfg, self.device)
        self.audio_to_coeff = Audio2Coeff(sadtalker_cfg, self.device); self.animate_from_coeff = AnimateFromCoeff(sadtalker_cfg, self.device)
        self.face_enhancer = FaceEnhancer(sadtalker_cfg, self.device) if sadtalker_cfg['MODEL']['USE_ENHANCER'] else None
        self.generator = Generator(sadtalker_cfg, self.device); self.mapping = Mapping(sadtalker_cfg, self.device); self.he_estimator = OcclusionAwareDenseMotion(sadtalker_cfg, self.device)
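# Face cropping, audio loading, and reference-coefficient generation for still/idle modes.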
class Preprocesser():
    def __init__(self, sadtalker_cfg, device):
        self.cfg = sadtalker_cfg; self.device = device
        self.face3d_helper = Face3DHelper(self.cfg['INPUT_IMAGE'].get('LOCAL_PCA_PATH', ''), device); self.mouth_detector = MouthDetector()
    def crop(self, source_image_pil, preprocess_type, size=256):
        source_image = np.array(source_image_pil); face_info = self.face3d_helper.run(source_image)
        if face_info is None: raise Exception("No face detected")
        x_min, y_min, x_max, y_max = face_info[:4]; old_size = (x_max - x_min, y_max - y_min); x_center = (x_max + x_min) / 2; y_center = (y_max + y_min) / 2
        if preprocess_type == 'crop': face_size = max(x_max - x_min, y_max - y_min); x_min = int(x_center - face_size / 2); y_min = int(y_center - face_size / 2); x_max = int(x_center + face_size / 2); y_max = int(y_center + face_size / 2)
        else: pad_x = int((x_max - x_min) * 0.1); pad_y = int((y_max - y_min) * 0.1); x_min = int(x_min - pad_x); y_min = int(y_min - pad_y); x_max = int(x_max + pad_x); y_max = int(y_max + pad_y)
        h, w = source_image.shape[:2]; x_min = max(0, x_min); y_min = max(0, y_min); x_max = min(w, x_max); y_max = min(h, y_max)
        cropped_image = source_image[y_min:y_max, x_min:x_max]; cropped_image_pil = Image.fromarray(cropped_image)
        if size is not None and size != 0: cropped_image_pil = cropped_image_pil.resize((size, size), Image.Resampling.LANCZOS)
        source_image_tensor = self.img2tensor(cropped_image_pil); return source_image_tensor, [[y_min, y_max], [x_min, x_max], old_size, cropped_image_pil.size], os.path.basename(self.cfg['INPUT_IMAGE'].get('SOURCE_IMAGE', ''))
    def img2tensor(self, img):
        img = np.array(img).astype(np.float32) / 255.0; img = np.transpose(img, (2, 0, 1)); return torch.FloatTensor(img)
    def video_to_tensor(self, video, device): return 0  # placeholder; not called by the current pipeline
    def process_audio(self, audio_path, sample_rate): wav = load_wav_util(audio_path, sample_rate); wav_tensor = torch.FloatTensor(wav).unsqueeze(0); return wav_tensor, sample_rate
    def generate_still_pose(self, pose_style): ref_pose_coeff = torch.zeros((1, 64), dtype=torch.float32).to(self.device); ref_pose_coeff[:, :3] = torch.tensor([0, 0, pose_style * 0.3], dtype=torch.float32); return ref_pose_coeff
    def generate_still_expression(self, exp_scale): ref_expression_coeff = torch.zeros((1, 64), dtype=torch.float32).to(self.device); ref_expression_coeff[:, :3] = torch.tensor([0, 0, exp_scale * 0.3], dtype=torch.float32); return ref_expression_coeff
    def generate_idles_pose(self, length_of_audio, pose_style): return 0  # placeholder; not called by the current pipeline
    def generate_idles_expression(self, length_of_audio): return 0  # placeholder; not called by the current pipeline
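# Facial keypoint detection for the source image; weights are loaded from kp_detector.safetensors.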
class KeyPointExtractor(nn.Module):
    def __init__(self, sadtalker_cfg, device):
        super(KeyPointExtractor, self).__init__(); self.kp_extractor = OcclusionAwareKPDetector(kp_channels=sadtalker_cfg['MODEL']['NUM_MOTION_FRAMES'], num_kp=10, num_dilation_blocks=2, dropout_rate=0.1).to(device)
        checkpoint_path = os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'], 'kp_detector.safetensors'); load_state_dict_robust(self.kp_extractor, checkpoint_path, device, model_name='kp_detector')
    def forward(self, x): kp = self.kp_extractor(x); return kp
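# Maps a wav2vec2 audio embedding to pose, expression, and blink coefficients.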
class Audio2Coeff(nn.Module):
    def __init__(self, sadtalker_cfg, device):
        super(Audio2Coeff, self).__init__(); self.audio_model = Wav2Vec2Model().to(device)
        checkpoint_path = os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'], 'wav2vec2.pth'); load_state_dict_robust(self.audio_model, checkpoint_path, device, model_name='wav2vec2')
        self.pose_mapper = AudioCoeffsPredictor(2048, 64).to(device); self.exp_mapper = AudioCoeffsPredictor(2048, 64).to(device); self.blink_mapper = AudioCoeffsPredictor(2048, 1).to(device)
        mapping_checkpoint = os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'], 'auido2pose_00140-model.pth'); load_state_dict_robust(self, mapping_checkpoint, device)
    def get_pose_coeff(self, audio_tensor, ref_pose_coeff=None, kp_ref=None, use_ref_info=''):
        audio_embedding = self.audio_model(audio_tensor); pose_coeff = self.pose_mapper(audio_embedding)
        if ref_pose_coeff is not None: pose_coeff = ref_pose_coeff
        if kp_ref is not None and use_ref_info == 'pose': ref_pose_6d = kp_ref['value'][:, :6]; pose_coeff[:, :6] = self.mean_std_normalize(ref_pose_6d).mean(dim=1)
        return pose_coeff
    def get_exp_coeff(self, audio_tensor, ref_expression_coeff=None):
        audio_embedding = self.audio_model(audio_tensor); expression_coeff = self.exp_mapper(audio_embedding)
        if ref_expression_coeff is not None: expression_coeff = ref_expression_coeff
        return expression_coeff
    def get_blink_coeff(self, audio_tensor): audio_embedding = self.audio_model(audio_tensor); blink_coeff = self.blink_mapper(audio_embedding); return blink_coeff
    def forward(self, audio): audio_embedding = self.audio_model(audio); pose_coeff, expression_coeff, blink_coeff = self.pose_mapper(audio_embedding), self.exp_mapper(audio_embedding), self.blink_mapper(audio_embedding); return pose_coeff, expression_coeff, blink_coeff
    def mean_std_normalize(self, coeff): mean = coeff.mean(dim=1, keepdim=True); std = coeff.std(dim=1, keepdim=True); return (coeff - mean) / std
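# Turns predicted coefficients into frames: maps coefficients to 3D face parameters, estimates dense
# motion from source/driving keypoints, renders with the generator, and optionally enhances each frame.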
class AnimateFromCoeff(nn.Module):
    def __init__(self, sadtalker_cfg, device):
        super(AnimateFromCoeff, self).__init__(); self.generator = Generator(sadtalker_cfg, device); self.mapping = Mapping(sadtalker_cfg, device); self.kp_norm = KeypointNorm(device=device); self.he_estimator = OcclusionAwareDenseMotion(sadtalker_cfg, device)
    def normalize_kp(self, kp_driving): return self.kp_norm(kp_driving)
    def generate(self, source_image, kp_source, coeff, generator, mapping, he_estimator, audio, source_image_crop, face_enhancer=None):
        kp_driving, jacobian, pose_coeff, expression_coeff, blink_coeff = coeff['kp_driving'], coeff['jacobian'], coeff['pose_coeff'], coeff['expression_coeff'], coeff['blink_coeff']
        face_3d = mapping(expression_coeff, pose_coeff, blink_coeff) if blink_coeff is not None else mapping(expression_coeff, pose_coeff); sparse_motion = he_estimator(kp_source, kp_driving, jacobian)
        dense_motion = sparse_motion['dense_motion']; video_deocclusion = generator(source_image, dense_motion, bg_param={'mask': None, 'color': None})
        video_3d = generator(source_image, dense_motion, bg_param={'mask': None, 'color': None}, face_3d_param=face_3d); video_output = video_deocclusion['video_no_reocclusion'] + video_3d['video_3d']
        if face_enhancer is not None:
            video_output_enhanced = []
            for frame in tqdm(video_output, 'Face enhancer running'):
                pil_image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
                enhanced_image = face_enhancer.forward(np.array(pil_image))
                video_output_enhanced.append(cv2.cvtColor(enhanced_image, cv2.COLOR_BGR2RGB))
            video_output = video_output_enhanced
        return video_output
    def make_animation(self, video_array):
        H, W, _ = video_array[0].shape
        out = cv2.VideoWriter('./tmp.mp4', cv2.VideoWriter_fourcc(*'mp4v'), 25, (W, H))
        for img in video_array: out.write(cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
        out.release()
        video = imageio.mimread('./tmp.mp4'); os.remove('./tmp.mp4')
        return video
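# Hourglass-based renderer that produces frames from the source image and the predicted dense motion.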
class Generator(nn.Module):
    def __init__(self, sadtalker_cfg, device):
        super(Generator, self).__init__(); self.generator = Hourglass(block_expansion=sadtalker_cfg['MODEL']['SCALE'], num_blocks=sadtalker_cfg['MODEL']['NUM_VOXEL_FRAMES'], max_features=sadtalker_cfg['MODEL']['MAX_FEATURES'], num_channels=3, kp_size=10, num_deform_blocks=sadtalker_cfg['MODEL']['NUM_MOTION_FRAMES']).to(device)
        checkpoint_path = os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'], 'generator.pth'); load_state_dict_robust(self.generator, checkpoint_path, device, model_name='generator')
    def forward(self, source_image, dense_motion, bg_param, face_3d_param=None): video_3d = self.generator(source_image, kp_driving=dense_motion, bg_param=bg_param, face_3d_param=face_3d_param); return {'video_3d': video_3d, 'video_no_reocclusion': video_3d}
class Mapping(nn.Module):
    def __init__(self, sadtalker_cfg, device):
        super(Mapping, self).__init__(); self.mapping_net = MappingNet(num_coeffs=64, num_layers=3, hidden_dim=128).to(device)
        checkpoint_path = os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'], 'mapping.pth'); load_state_dict_robust(self.mapping_net, checkpoint_path, device, model_name='mapping')
        self.f_3d_mean = torch.zeros(1, 64, device=device)
    def forward(self, expression_coeff, pose_coeff, blink_coeff=None):
        coeff = torch.cat([expression_coeff, pose_coeff], dim=1); face_3d = self.mapping_net(coeff) + self.f_3d_mean
        if blink_coeff is not None: face_3d[:, -1:] = blink_coeff
        return face_3d
class OcclusionAwareDenseMotion(nn.Module):
    def __init__(self, sadtalker_cfg, device):
        super(OcclusionAwareDenseMotion, self).__init__(); self.dense_motion_network = DenseMotionNetwork(num_kp=10, num_channels=3, block_expansion=sadtalker_cfg['MODEL']['SCALE'], num_blocks=sadtalker_cfg['MODEL']['NUM_MOTION_FRAMES'] - 1, max_features=sadtalker_cfg['MODEL']['MAX_FEATURES']).to(device)
        checkpoint_path = os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'], 'dense_motion.pth'); load_state_dict_robust(self.dense_motion_network, checkpoint_path, device, model_name='dense_motion')
    def forward(self, kp_source, kp_driving, jacobian): sparse_motion = self.dense_motion_network(kp_source, kp_driving, jacobian); return sparse_motion
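# Optional per-frame restoration with GFPGAN or RealESRGAN, selected via MODEL.ENHANCER_NAME in the config.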
class FaceEnhancer(nn.Module):
    def __init__(self, sadtalker_cfg, device):
        super(FaceEnhancer, self).__init__(); enhancer_name = sadtalker_cfg['MODEL']['ENHANCER_NAME']; bg_upsampler = sadtalker_cfg['MODEL']['BG_UPSAMPLER']
        if enhancer_name == 'gfpgan': from gfpgan import GFPGANer; self.face_enhancer = GFPGANer(model_path=os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'], 'GFPGANv1.4.pth'), upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=bg_upsampler)
        elif enhancer_name == 'realesrgan': from realesrgan import RealESRGANer; half = False if device == 'cpu' else sadtalker_cfg['MODEL']['IS_HALF']; self.face_enhancer = RealESRGANer(scale=2, model_path=os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'], 'RealESRGAN_x2plus.pth'), tile=0, tile_pad=10, pre_pad=0, half=half, device=device)
        else: self.face_enhancer = None
    def forward(self, x): return self.face_enhancer.enhance(x, outscale=1)[0] if self.face_enhancer else x
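# Checkpoint download and top-level model loading helpers.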
def download_model(url, filename, checkpoint_dir):
    if not os.path.exists(os.path.join(checkpoint_dir, filename)): print(f"Downloading {filename}..."); os.makedirs(checkpoint_dir, exist_ok=True); urllib.request.urlretrieve(url, os.path.join(checkpoint_dir, filename)); print(f"{filename} downloaded.")
    else: print(f"{filename} already exists.")
def load_models():
    checkpoint_path = './checkpoints'; config_path = './src/config'; size = 256; preprocess = 'crop'; old_version = False
    sadtalker_instance = SadTalker(checkpoint_path, config_path, size, preprocess, old_version); print("SadTalker models loaded successfully!"); return sadtalker_instance
if __name__ == '__main__': sadtalker_instance = load_models()
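# Example usage (a minimal sketch; 'examples/source.png' and 'examples/driven.wav' are hypothetical
# paths, not files shipped with this script):
#
#     sadtalker_instance = load_models()
#     video_path = sadtalker_instance.test(
#         source_image='examples/source.png',
#         driven_audio='examples/driven.wav',
#         preprocess='crop', still_mode=False, use_enhancer=False,
#         batch_size=1, size=256, result_dir='./results/')
#     print(f"Generated video: {video_path}")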