Update sadtalker_utils.py

sadtalker_utils.py  CHANGED  (+63 -130)
@@ -17,7 +17,7 @@ from scipy import signal
 from scipy.io import loadmat, savemat, wavfile
 import glob
 import tempfile
-
+from tqdm import tqdm
 import math
 import torchaudio
 import urllib.request
@@ -64,6 +64,34 @@ def save_wav_util(wav, path, sr):
     wavfile.write(path, sr, wav.astype(np.int16))
 
 
+def load_state_dict_robust(model, checkpoint_path, device, model_name="model"):
+    if not os.path.exists(checkpoint_path):
+        raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}")
+    if checkpoint_path.endswith('safetensors'):
+        checkpoint = safetensors.torch.load_file(checkpoint_path, device=device)
+    else:
+        checkpoint = torch.load(checkpoint_path, map_location=device)
+
+    state_dict = checkpoint.get(model_name, checkpoint)
+    try:
+        model.load_state_dict(state_dict)
+    except RuntimeError as e:
+        print(f"Error loading {model_name} state_dict: {e}")
+        print(f"Trying to load state_dict with key mapping for {model_name}.")
+        model_state_dict = model.state_dict()
+        mapped_state_dict = {}
+        for key, value in state_dict.items():
+            if key in model_state_dict and model_state_dict[key].shape == value.shape:
+                mapped_state_dict[key] = value
+            else:
+                print(f"Skipping key {key} due to shape mismatch or missing in model.")
+        missing_keys, unexpected_keys = model.load_state_dict(mapped_state_dict, strict=False)
+        if missing_keys or unexpected_keys:
+            print(f"Missing keys: {missing_keys}")
+            print(f"Unexpected keys: {unexpected_keys}")
+        print(f"Successfully loaded {model_name} state_dict with key mapping.")
+
+
 class OcclusionAwareKPDetector(nn.Module):
 
     def __init__(self, kp_channels, num_kp, num_dilation_blocks, dropout_rate):
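For orientation, a small sketch (not part of the commit) of how the new helper behaves with the two checkpoint layouts it handles, a nested dict keyed by model_name and a flat state dict; TinyNet and the /tmp paths are illustrative stand-ins:

    # Toy demonstration; assumes load_state_dict_robust from above is in scope.
    # TinyNet is a hypothetical stand-in for a real sub-model.
    import torch
    import torch.nn as nn

    class TinyNet(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(4, 2)

    torch.save({'kp_detector': TinyNet().state_dict()}, '/tmp/nested.pth')  # nested layout
    torch.save(TinyNet().state_dict(), '/tmp/flat.pth')                     # flat layout

    # Nested layout: the 'kp_detector' entry is picked out via model_name.
    load_state_dict_robust(TinyNet(), '/tmp/nested.pth', 'cpu', model_name='kp_detector')
    # Flat layout: no such key, so checkpoint.get() falls back to the whole dict.
    load_state_dict_robust(TinyNet(), '/tmp/flat.pth', 'cpu')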
@@ -175,12 +203,6 @@ class Face3DHelper:
         return [x_min, y_min, x_max, y_max]
 
 
-class Face3DHelperOld(Face3DHelper):
-
-    def __init__(self, local_pca_path, device):
-        super(Face3DHelperOld, self).__init__(local_pca_path, device)
-
-
 class MouthDetector:
 
     def __init__(self):
@@ -258,14 +280,12 @@ class SadTalker:
         self.cfg['INPUT_IMAGE']['SIZE'] = size
         self.cfg['INPUT_IMAGE']['OLD_VERSION'] = old_version
 
-
-
-
-
-
-
-        download_model(GFPGAN_URL, 'GFPGANv1.4.pth', checkpoint_path)
-        download_model(REALESRGAN_URL, 'RealESRGAN_x2plus.pth', checkpoint_path)
+        for filename, url in [
+            (kp_file, kp_url), (aud_file, aud_url), (wav_file, wav_url), (gen_file, gen_url),
+            (mapx_file, mapx_url), (den_file, den_url), ('GFPGANv1.4.pth', GFPGAN_URL),
+            ('RealESRGAN_x2plus.pth', REALESRGAN_URL)
+        ]:
+            download_model(url, filename, checkpoint_path)
 
         self.sadtalker_model = SadTalkerModel(self.cfg, device_id=[0])
 
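The loop relies on the file's download_model(url, filename, checkpoint_path) helper, which is defined elsewhere in sadtalker_utils.py and not shown in this diff. A minimal sketch of what such a helper could look like under that assumed signature, not the actual implementation:

    # Hypothetical sketch only; the real download_model lives elsewhere in
    # sadtalker_utils.py. This just illustrates the assumed
    # (url, filename, checkpoint_dir) contract used by the loop above.
    import os
    import urllib.request

    def download_model(url, filename, checkpoint_dir):
        os.makedirs(checkpoint_dir, exist_ok=True)
        dest = os.path.join(checkpoint_dir, filename)
        if not os.path.exists(dest):  # skip checkpoints that are already cached
            print(f"Downloading {filename}...")
            urllib.request.urlretrieve(url, dest)
        return dest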
@@ -383,9 +403,6 @@ class SadTalkerInner:
         if self.still_mode or self.use_idle_mode:
             ref_pose_coeff = proc.generate_still_pose(self.pose_style)
             ref_expression_coeff = proc.generate_still_expression(self.exp_scale)
-        elif self.use_idle_mode:
-            ref_pose_coeff = proc.generate_idles_pose(self.length_of_audio, self.pose_style)
-            ref_expression_coeff = proc.generate_idles_expression(self.length_of_audio)
         else:
             ref_pose_coeff = None
             ref_expression_coeff = None
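The removed elif was dead code: `self.use_idle_mode` is already covered by the `if self.still_mode or self.use_idle_mode:` branch above it, so the `generate_idles_pose` and `generate_idles_expression` calls could never be reached.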
@@ -414,7 +431,7 @@ class SadTalkerInner:
         he_estimator = self.sadtalker_model.he_estimator
         audio_to_coeff = self.sadtalker_model.audio_to_coeff
         animate_from_coeff = self.sadtalker_model.animate_from_coeff
-
+        face_enhancer = self.sadtalker_model.face_enhancer if self.use_enhancer else None
         with torch.no_grad():
             kp_source = kp_extractor(batch['source_image'])
             if self.still_mode or self.use_idle_mode:
@@ -444,7 +461,6 @@ class SadTalkerInner:
             kp_norm = animate_from_coeff.normalize_kp(kp_driving)
             coeff['kp_driving'] = kp_norm
             coeff['jacobian'] = [torch.eye(2).unsqueeze(0).unsqueeze(0).to(self.device)] * 4
-            face_enhancer = self.sadtalker_model.face_enhancer if self.use_enhancer else None
             output_video = animate_from_coeff.generate(batch['source_image'], kp_source, coeff, generator, mapping,
                                                        he_estimator, batch['audio'], batch['source_image_crop'],
                                                        face_enhancer=face_enhancer)
@@ -507,10 +523,7 @@ class Preprocesser:
     def __init__(self, sadtalker_cfg, device):
         self.cfg = sadtalker_cfg
         self.device = device
-
-        self.face3d_helper = Face3DHelperOld(self.cfg['INPUT_IMAGE'].get('LOCAL_PCA_PATH', ''), device)
-        else:
-            self.face3d_helper = Face3DHelper(self.cfg['INPUT_IMAGE'].get('LOCAL_PCA_PATH', ''), device)
+        self.face3d_helper = Face3DHelper(self.cfg['INPUT_IMAGE'].get('LOCAL_PCA_PATH', ''), device)
         self.mouth_detector = MouthDetector()
 
     def crop(self, source_image_pil, preprocess_type, size=256):
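The removed lines above are reproduced as they appeared in the file: the pre-commit version contained a dangling `else:` with no matching `if` here, and, further down, stray `self.` fragments and an unterminated string literal. Those fragments made the module unimportable, which is what this commit fixes by replacing the per-class loader boilerplate with calls to load_state_dict_robust.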
@@ -607,23 +620,7 @@ class KeyPointExtractor(nn.Module):
                                                  num_dilation_blocks=2,
                                                  dropout_rate=0.1).to(device)
         checkpoint_path = os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'], 'kp_detector.safetensors')
-        self.
-
-    def load_kp_detector(self, checkpoint_path, device):
-        if os.path.exists(checkpoint_path):
-            if checkpoint_path.endswith('safetensors'):
-                checkpoint = safetensors.torch.load_file(checkpoint_path, device=device)
-            else:
-                checkpoint = torch.load(checkpoint_path, map_location=device)
-            try:
-                self.kp_extractor.load_state_dict(checkpoint.get('kp_detector', {}))
-            except RuntimeError as e:
-                print(f"Error loading kp_detector state_dict: {e}")
-                print("Trying to load state_dict without prefix 'kp_detector.'")
-                self.kp_extractor.load_state_dict(checkpoint, strict=False)
-
-        else:
-            raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}")
+        load_state_dict_robust(self.kp_extractor, checkpoint_path, device, model_name='kp_detector')
 
     def forward(self, x):
        kp = self.kp_extractor(x)
@@ -636,34 +633,12 @@ class Audio2Coeff(nn.Module):
         super(Audio2Coeff, self).__init__()
         self.audio_model = Wav2Vec2Model().to(device)
         checkpoint_path = os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'], 'wav2vec2.pth')
-        self.
+        load_state_dict_robust(self.audio_model, checkpoint_path, device, model_name='wav2vec2')
         self.pose_mapper = AudioCoeffsPredictor(2048, 64).to(device)
         self.exp_mapper = AudioCoeffsPredictor(2048, 64).to(device)
         self.blink_mapper = AudioCoeffsPredictor(2048, 1).to(device)
-        mapping_checkpoint = os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'], '
-        self
-
-    def load_audio_model(self, checkpoint_path, device):
-        if os.path.exists(checkpoint_path):
-            if checkpoint_path.endswith('safetensors'):
-                checkpoint = safetensors.torch.load_file(checkpoint_path, device=device)
-            else:
-                checkpoint = torch.load(checkpoint_path, map_location=device)
-            self.audio_model.load_state_dict(checkpoint.get("wav2vec2", {}))
-        else:
-            raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}")
-
-    def load_mapping_model(self, checkpoint_path, device):
-        if os.path.exists(checkpoint_path):
-            if checkpoint_path.endswith('safetensors'):
-                checkpoint = safetensors.torch.load_file(checkpoint_path, device=device)
-            else:
-                checkpoint = torch.load(checkpoint_path, map_location=device)
-            self.pose_mapper.load_state_dict(checkpoint.get("pose_predictor", {}))
-            self.exp_mapper.load_state_dict(checkpoint.get("exp_predictor", {}))
-            self.blink_mapper.load_state_dict(checkpoint.get("blink_predictor", {}))
-        else:
-            raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}")
+        mapping_checkpoint = os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'], 'auido2pose_00140-model.pth')
+        load_state_dict_robust(self, mapping_checkpoint, device)
 
     def get_pose_coeff(self, audio_tensor, ref_pose_coeff=None, kp_ref=None, use_ref_info=''):
         audio_embedding = self.audio_model(audio_tensor)
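Note that the combined mapping checkpoint is loaded into `self`, the whole Audio2Coeff module, so its keys need submodule prefixes (pose_mapper., exp_mapper., blink_mapper.). Since the module also owns audio_model, a strict load of a mapping-only checkpoint raises a RuntimeError for the missing audio_model.* keys, which lands in the helper's non-strict key-mapping fallback. A toy illustration of the expected key layout (the Parent class is illustrative, not from the file):

    # Keys of a parent module's state_dict carry submodule-name prefixes,
    # which is the layout the combined checkpoint must match.
    import torch.nn as nn

    class Parent(nn.Module):
        def __init__(self):
            super().__init__()
            self.pose_mapper = nn.Linear(8, 4)
            self.exp_mapper = nn.Linear(8, 4)

    print(list(Parent().state_dict()))
    # ['pose_mapper.weight', 'pose_mapper.bias', 'exp_mapper.weight', 'exp_mapper.bias']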
@@ -718,31 +693,19 @@ class AnimateFromCoeff(nn.Module):
         pose_coeff = coeff['pose_coeff']
         expression_coeff = coeff['expression_coeff']
         blink_coeff = coeff['blink_coeff']
-
-
-
-
-
-
-
-
-
-
-
-
-
-        face_3d = mapping(expression_coeff, pose_coeff)
-        video_3d = generator(source_image, dense_motion, bg_param={'mask': None, 'color': None},
-                             face_3d_param=face_3d)
-        video_output = video_3d['video_3d']
-        video_output = self.make_animation(video_output)
-        if face_enhancer is not None:
-            video_output_enhanced = []
-            for frame in tqdm(video_output, 'Face enhancer running'):
-                pil_image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
-                enhanced_image = face_enhancer.enhance(np.array(pil_image))[0]
-                video_output_enhanced.append(cv2.cvtColor(enhanced_image, cv2.COLOR_BGR2RGB))
-            video_output = video_output_enhanced
+        face_3d = mapping(expression_coeff, pose_coeff, blink_coeff) if blink_coeff is not None else mapping(expression_coeff, pose_coeff)
+        sparse_motion = he_estimator(kp_source, kp_driving, jacobian)
+        dense_motion = sparse_motion['dense_motion']
+        video_deocclusion = generator(source_image, dense_motion, bg_param={'mask': None, 'color': None})
+        video_3d = generator(source_image, dense_motion, bg_param={'mask': None, 'color': None}, face_3d_param=face_3d)
+        video_output = video_deocclusion['video_no_reocclusion'] + video_3d['video_3d']
+        if face_enhancer is not None:
+            video_output_enhanced = []
+            for frame in tqdm(video_output, 'Face enhancer running'):
+                pil_image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
+                enhanced_image = face_enhancer.forward(np.array(pil_image))
+                video_output_enhanced.append(cv2.cvtColor(enhanced_image, cv2.COLOR_BGR2RGB))
+            video_output = video_output_enhanced
         return video_output
 
     def make_animation(self, video_array):
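The rewritten loop calls face_enhancer.forward(...) on what appears to be the FaceEnhancer wrapper defined near the end of this diff, which returns the frame unchanged when no enhancer is configured, rather than calling .enhance() on a raw enhancer object directly.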
@@ -767,24 +730,10 @@ class Generator(nn.Module):
                                   kp_size=10,
                                   num_deform_blocks=sadtalker_cfg['MODEL']['NUM_MOTION_FRAMES']).to(device)
         checkpoint_path = os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'], 'generator.pth')
-        self.
-
-    def load_generator(self, checkpoint_path, device):
-        if os.path.exists(checkpoint_path):
-            if checkpoint_path.endswith('safetensors'):
-                checkpoint = safetensors.torch.load_file(checkpoint_path, device=device)
-            else:
-                checkpoint = torch.load(checkpoint_path, map_location=device)
-            self.generator.load_state_dict(checkpoint.get('generator', {}))
-        else:
-            raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}")
+        load_state_dict_robust(self.generator, checkpoint_path, device, model_name='generator')
 
     def forward(self, source_image, dense_motion, bg_param, face_3d_param=None):
-
-        video_3d = self.generator(source_image, kp_driving=dense_motion, bg_param=bg_param,
-                                  face_3d_param=face_3d_param)
-        else:
-            video_3d = self.generator(source_image, kp_driving=dense_motion, bg_param=bg_param)
+        video_3d = self.generator(source_image, kp_driving=dense_motion, bg_param=bg_param, face_3d_param=face_3d_param)
         return {'video_3d': video_3d, 'video_no_reocclusion': video_3d}
 
 
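With the dangling if/else removed, forward() now always forwards face_3d_param; callers that previously would have hit the else branch pass face_3d_param=None, which the underlying generator presumably treats as "no 3D face parameters".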
@@ -794,19 +743,9 @@ class Mapping(nn.Module):
         super(Mapping, self).__init__()
         self.mapping_net = MappingNet(num_coeffs=64, num_layers=3, hidden_dim=128).to(device)
         checkpoint_path = os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'], 'mapping.pth')
-        self.
+        load_state_dict_robust(self.mapping_net, checkpoint_path, device, model_name='mapping')
         self.f_3d_mean = torch.zeros(1, 64, device=device)
 
-    def load_mapping_net(self, checkpoint_path, device):
-        if os.path.exists(checkpoint_path):
-            if checkpoint_path.endswith('safetensors'):
-                checkpoint = safetensors.torch.load_file(checkpoint_path, device=device)
-            else:
-                checkpoint = torch.load(checkpoint_path, map_location=device)
-            self.mapping_net.load_state_dict(checkpoint.get('mapping', {}))
-        else:
-            raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}")
-
     def forward(self, expression_coeff, pose_coeff, blink_coeff=None):
         coeff = torch.cat([expression_coeff, pose_coeff], dim=1)
         face_3d = self.mapping_net(coeff) + self.f_3d_mean
@@ -825,17 +764,7 @@ class OcclusionAwareDenseMotion(nn.Module):
                                  num_blocks=sadtalker_cfg['MODEL']['NUM_MOTION_FRAMES'] - 1,
                                  max_features=sadtalker_cfg['MODEL']['MAX_FEATURES']).to(device)
         checkpoint_path = os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'], 'dense_motion.pth')
-        self.
-
-    def load_dense_motion_network(self, checkpoint_path, device):
-        if os.path.exists(checkpoint_path):
-            if checkpoint_path.endswith('safetensors'):
-                checkpoint = safetensors.torch.load_file(checkpoint_path, device=device)
-            else:
-                checkpoint = torch.load(checkpoint_path, map_location=device)
-            self.dense_motion_network.load_state_dict(checkpoint.get('dense_motion', {}))
-        else:
-            raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}")
+        load_state_dict_robust(self.dense_motion_network, checkpoint_path, device, model_name='dense_motion')
 
     def forward(self, kp_source, kp_driving, jacobian):
         sparse_motion = self.dense_motion_network(kp_source, kp_driving, jacobian)
@@ -870,7 +799,10 @@ class FaceEnhancer(nn.Module):
         self.face_enhancer = None
 
     def forward(self, x):
-
+        if self.face_enhancer:
+            return self.face_enhancer.enhance(x, outscale=1)[0]
+        return x
+
 
 def load_models():
     checkpoint_path = './checkpoints'
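A note on the new forward(): enhance(x, outscale=1) matches the RealESRGANer API, whose enhance() returns an (image, img_mode) tuple, so [0] is the restored frame. GFPGANer's enhance() takes no outscale argument and returns a three-tuple whose first element is the cropped faces, so as written this wrapper appears to fit the RealESRGAN path.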
@@ -883,5 +815,6 @@ def load_models():
     print("SadTalker models loaded successfully!")
     return sadtalker_instance
 
+
 if __name__ == '__main__':
     sadtalker_instance = load_models()