Hjgugugjhuhjggg committed on
Commit 64bd56c · verified · 1 Parent(s): 7b127f2

Update sadtalker_utils.py

Files changed (1)
  1. sadtalker_utils.py +63 -130
sadtalker_utils.py CHANGED
@@ -17,7 +17,7 @@ from scipy import signal
 from scipy.io import loadmat, savemat, wavfile
 import glob
 import tempfile
-from tqdm import tqdm
+import tqdm
 import math
 import torchaudio
 import urllib.request
@@ -64,6 +64,34 @@ def save_wav_util(wav, path, sr):
     wavfile.write(path, sr, wav.astype(np.int16))
 
 
+def load_state_dict_robust(model, checkpoint_path, device, model_name="model"):
+    if not os.path.exists(checkpoint_path):
+        raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}")
+    if checkpoint_path.endswith('safetensors'):
+        checkpoint = safetensors.torch.load_file(checkpoint_path, device=device)
+    else:
+        checkpoint = torch.load(checkpoint_path, map_location=device)
+
+    state_dict = checkpoint.get(model_name, checkpoint)
+    try:
+        model.load_state_dict(state_dict)
+    except RuntimeError as e:
+        print(f"Error loading {model_name} state_dict: {e}")
+        print(f"Trying to load state_dict with key mapping for {model_name}.")
+        model_state_dict = model.state_dict()
+        mapped_state_dict = {}
+        for key, value in state_dict.items():
+            if key in model_state_dict and model_state_dict[key].shape == value.shape:
+                mapped_state_dict[key] = value
+            else:
+                print(f"Skipping key {key} due to shape mismatch or missing in model.")
+        missing_keys, unexpected_keys = model.load_state_dict(mapped_state_dict, strict=False)
+        if missing_keys or unexpected_keys:
+            print(f"Missing keys: {missing_keys}")
+            print(f"Unexpected keys: {unexpected_keys}")
+        print(f"Successfully loaded {model_name} state_dict with key mapping.")
+
+
 class OcclusionAwareKPDetector(nn.Module):
 
     def __init__(self, kp_channels, num_kp, num_dilation_blocks, dropout_rate):
@@ -175,12 +203,6 @@ class Face3DHelper:
         return [x_min, y_min, x_max, y_max]
 
 
-class Face3DHelperOld(Face3DHelper):
-
-    def __init__(self, local_pca_path, device):
-        super(Face3DHelperOld, self).__init__(local_pca_path, device)
-
-
 class MouthDetector:
 
     def __init__(self):
@@ -258,14 +280,12 @@ class SadTalker:
         self.cfg['INPUT_IMAGE']['SIZE'] = size
         self.cfg['INPUT_IMAGE']['OLD_VERSION'] = old_version
 
-        download_model(kp_url, kp_file, checkpoint_path)
-        download_model(aud_url, aud_file, checkpoint_path)
-        download_model(wav_url, wav_file, checkpoint_path)
-        download_model(gen_url, gen_file, checkpoint_path)
-        download_model(mapx_url, mapx_file, checkpoint_path)
-        download_model(den_url, den_file, checkpoint_path)
-        download_model(GFPGAN_URL, 'GFPGANv1.4.pth', checkpoint_path)
-        download_model(REALESRGAN_URL, 'RealESRGAN_x2plus.pth', checkpoint_path)
+        for filename, url in [
+            (kp_file, kp_url), (aud_file, aud_url), (wav_file, wav_url), (gen_file, gen_url),
+            (mapx_file, mapx_url), (den_file, den_url), ('GFPGANv1.4.pth', GFPGAN_URL),
+            ('RealESRGAN_x2plus.pth', REALESRGAN_URL)
+        ]:
+            download_model(url, filename, checkpoint_path)
 
         self.sadtalker_model = SadTalkerModel(self.cfg, device_id=[0])
 
@@ -383,9 +403,6 @@ class SadTalkerInner:
         if self.still_mode or self.use_idle_mode:
             ref_pose_coeff = proc.generate_still_pose(self.pose_style)
             ref_expression_coeff = proc.generate_still_expression(self.exp_scale)
-        elif self.use_idle_mode:
-            ref_pose_coeff = proc.generate_idles_pose(self.length_of_audio, self.pose_style)
-            ref_expression_coeff = proc.generate_idles_expression(self.length_of_audio)
         else:
             ref_pose_coeff = None
             ref_expression_coeff = None
@@ -414,7 +431,7 @@ class SadTalkerInner:
         he_estimator = self.sadtalker_model.he_estimator
         audio_to_coeff = self.sadtalker_model.audio_to_coeff
         animate_from_coeff = self.sadtalker_model.animate_from_coeff
-        proc = self.sadtalker_model.preprocesser
+        face_enhancer = self.sadtalker_model.face_enhancer if self.use_enhancer else None
         with torch.no_grad():
             kp_source = kp_extractor(batch['source_image'])
             if self.still_mode or self.use_idle_mode:
@@ -444,7 +461,6 @@
             kp_norm = animate_from_coeff.normalize_kp(kp_driving)
             coeff['kp_driving'] = kp_norm
             coeff['jacobian'] = [torch.eye(2).unsqueeze(0).unsqueeze(0).to(self.device)] * 4
-            face_enhancer = self.sadtalker_model.face_enhancer if self.use_enhancer else None
             output_video = animate_from_coeff.generate(batch['source_image'], kp_source, coeff, generator, mapping,
                                                        he_estimator, batch['audio'], batch['source_image_crop'],
                                                        face_enhancer=face_enhancer)
@@ -507,10 +523,7 @@ class Preprocesser:
     def __init__(self, sadtalker_cfg, device):
         self.cfg = sadtalker_cfg
         self.device = device
-        if self.cfg['INPUT_IMAGE'].get('OLD_VERSION', False):
-            self.face3d_helper = Face3DHelperOld(self.cfg['INPUT_IMAGE'].get('LOCAL_PCA_PATH', ''), device)
-        else:
-            self.face3d_helper = Face3DHelper(self.cfg['INPUT_IMAGE'].get('LOCAL_PCA_PATH', ''), device)
+        self.face3d_helper = Face3DHelper(self.cfg['INPUT_IMAGE'].get('LOCAL_PCA_PATH', ''), device)
         self.mouth_detector = MouthDetector()
 
     def crop(self, source_image_pil, preprocess_type, size=256):
@@ -607,23 +620,7 @@ class KeyPointExtractor(nn.Module):
                                                        num_dilation_blocks=2,
                                                        dropout_rate=0.1).to(device)
         checkpoint_path = os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'], 'kp_detector.safetensors')
-        self.load_kp_detector(checkpoint_path, device)
-
-    def load_kp_detector(self, checkpoint_path, device):
-        if os.path.exists(checkpoint_path):
-            if checkpoint_path.endswith('safetensors'):
-                checkpoint = safetensors.torch.load_file(checkpoint_path, device=device)
-            else:
-                checkpoint = torch.load(checkpoint_path, map_location=device)
-            try:
-                self.kp_extractor.load_state_dict(checkpoint.get('kp_detector', {}))
-            except RuntimeError as e:
-                print(f"Error loading kp_detector state_dict: {e}")
-                print("Trying to load state_dict without prefix 'kp_detector.'")
-                self.kp_extractor.load_state_dict(checkpoint, strict=False)
-
-        else:
-            raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}")
+        load_state_dict_robust(self.kp_extractor, checkpoint_path, device, model_name='kp_detector')
 
     def forward(self, x):
         kp = self.kp_extractor(x)
@@ -636,34 +633,12 @@ class Audio2Coeff(nn.Module):
         super(Audio2Coeff, self).__init__()
         self.audio_model = Wav2Vec2Model().to(device)
         checkpoint_path = os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'], 'wav2vec2.pth')
-        self.load_audio_model(checkpoint_path, device)
+        load_state_dict_robust(self.audio_model, checkpoint_path, device, model_name='wav2vec2')
         self.pose_mapper = AudioCoeffsPredictor(2048, 64).to(device)
         self.exp_mapper = AudioCoeffsPredictor(2048, 64).to(device)
         self.blink_mapper = AudioCoeffsPredictor(2048, 1).to(device)
-        mapping_checkpoint = os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'], 'audio2pose_00140-model.pth')
-        self.load_mapping_model(mapping_checkpoint, device)
-
-    def load_audio_model(self, checkpoint_path, device):
-        if os.path.exists(checkpoint_path):
-            if checkpoint_path.endswith('safetensors'):
-                checkpoint = safetensors.torch.load_file(checkpoint_path, device=device)
-            else:
-                checkpoint = torch.load(checkpoint_path, map_location=device)
-            self.audio_model.load_state_dict(checkpoint.get("wav2vec2", {}))
-        else:
-            raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}")
-
-    def load_mapping_model(self, checkpoint_path, device):
-        if os.path.exists(checkpoint_path):
-            if checkpoint_path.endswith('safetensors'):
-                checkpoint = safetensors.torch.load_file(checkpoint_path, device=device)
-            else:
-                checkpoint = torch.load(checkpoint_path, map_location=device)
-            self.pose_mapper.load_state_dict(checkpoint.get("pose_predictor", {}))
-            self.exp_mapper.load_state_dict(checkpoint.get("exp_predictor", {}))
-            self.blink_mapper.load_state_dict(checkpoint.get("blink_predictor", {}))
-        else:
-            raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}")
+        mapping_checkpoint = os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'], 'auido2pose_00140-model.pth')
+        load_state_dict_robust(self, mapping_checkpoint, device)
 
     def get_pose_coeff(self, audio_tensor, ref_pose_coeff=None, kp_ref=None, use_ref_info=''):
         audio_embedding = self.audio_model(audio_tensor)
@@ -718,31 +693,19 @@ class AnimateFromCoeff(nn.Module):
         pose_coeff = coeff['pose_coeff']
         expression_coeff = coeff['expression_coeff']
         blink_coeff = coeff['blink_coeff']
-        with torch.no_grad():
-            if blink_coeff is not None:
-                sparse_motion = he_estimator(kp_source, kp_driving, jacobian)
-                dense_motion = sparse_motion['dense_motion']
-                video_deocclusion = generator(source_image, dense_motion, bg_param={'mask': None, 'color': None})
-                face_3d = mapping(expression_coeff, pose_coeff, blink_coeff)
-                video_3d = generator(source_image, dense_motion, bg_param={'mask': None, 'color': None},
-                                     face_3d_param=face_3d)
-                video_output = video_deocclusion['video_no_reocclusion'] + video_3d['video_3d']
-                video_output = self.make_animation(video_output)
-            else:
-                sparse_motion = he_estimator(kp_source, kp_driving, jacobian)
-                dense_motion = sparse_motion['dense_motion']
-                face_3d = mapping(expression_coeff, pose_coeff)
-                video_3d = generator(source_image, dense_motion, bg_param={'mask': None, 'color': None},
-                                     face_3d_param=face_3d)
-                video_output = video_3d['video_3d']
-                video_output = self.make_animation(video_output)
-            if face_enhancer is not None:
-                video_output_enhanced = []
-                for frame in tqdm(video_output, 'Face enhancer running'):
-                    pil_image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
-                    enhanced_image = face_enhancer.enhance(np.array(pil_image))[0]
-                    video_output_enhanced.append(cv2.cvtColor(enhanced_image, cv2.COLOR_BGR2RGB))
-                video_output = video_output_enhanced
+        face_3d = mapping(expression_coeff, pose_coeff, blink_coeff) if blink_coeff is not None else mapping(expression_coeff, pose_coeff)
+        sparse_motion = he_estimator(kp_source, kp_driving, jacobian)
+        dense_motion = sparse_motion['dense_motion']
+        video_deocclusion = generator(source_image, dense_motion, bg_param={'mask': None, 'color': None})
+        video_3d = generator(source_image, dense_motion, bg_param={'mask': None, 'color': None}, face_3d_param=face_3d)
+        video_output = video_deocclusion['video_no_reocclusion'] + video_3d['video_3d']
+        if face_enhancer is not None:
+            video_output_enhanced = []
+            for frame in tqdm(video_output, 'Face enhancer running'):
+                pil_image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
+                enhanced_image = face_enhancer.forward(np.array(pil_image))
+                video_output_enhanced.append(cv2.cvtColor(enhanced_image, cv2.COLOR_BGR2RGB))
+            video_output = video_output_enhanced
         return video_output
 
     def make_animation(self, video_array):
@@ -767,24 +730,10 @@ class Generator(nn.Module):
                                      kp_size=10,
                                      num_deform_blocks=sadtalker_cfg['MODEL']['NUM_MOTION_FRAMES']).to(device)
         checkpoint_path = os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'], 'generator.pth')
-        self.load_generator(checkpoint_path, device)
-
-    def load_generator(self, checkpoint_path, device):
-        if os.path.exists(checkpoint_path):
-            if checkpoint_path.endswith('safetensors'):
-                checkpoint = safetensors.torch.load_file(checkpoint_path, device=device)
-            else:
-                checkpoint = torch.load(checkpoint_path, map_location=device)
-            self.generator.load_state_dict(checkpoint.get('generator', {}))
-        else:
-            raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}")
+        load_state_dict_robust(self.generator, checkpoint_path, device, model_name='generator')
 
     def forward(self, source_image, dense_motion, bg_param, face_3d_param=None):
-        if face_3d_param is not None:
-            video_3d = self.generator(source_image, kp_driving=dense_motion, bg_param=bg_param,
-                                      face_3d_param=face_3d_param)
-        else:
-            video_3d = self.generator(source_image, kp_driving=dense_motion, bg_param=bg_param)
+        video_3d = self.generator(source_image, kp_driving=dense_motion, bg_param=bg_param, face_3d_param=face_3d_param)
         return {'video_3d': video_3d, 'video_no_reocclusion': video_3d}
 
 
@@ -794,19 +743,9 @@ class Mapping(nn.Module):
         super(Mapping, self).__init__()
         self.mapping_net = MappingNet(num_coeffs=64, num_layers=3, hidden_dim=128).to(device)
         checkpoint_path = os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'], 'mapping.pth')
-        self.load_mapping_net(checkpoint_path, device)
+        load_state_dict_robust(self.mapping_net, checkpoint_path, device, model_name='mapping')
         self.f_3d_mean = torch.zeros(1, 64, device=device)
 
-    def load_mapping_net(self, checkpoint_path, device):
-        if os.path.exists(checkpoint_path):
-            if checkpoint_path.endswith('safetensors'):
-                checkpoint = safetensors.torch.load_file(checkpoint_path, device=device)
-            else:
-                checkpoint = torch.load(checkpoint_path, map_location=device)
-            self.mapping_net.load_state_dict(checkpoint.get('mapping', {}))
-        else:
-            raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}")
-
     def forward(self, expression_coeff, pose_coeff, blink_coeff=None):
         coeff = torch.cat([expression_coeff, pose_coeff], dim=1)
         face_3d = self.mapping_net(coeff) + self.f_3d_mean
@@ -825,17 +764,7 @@ class OcclusionAwareDenseMotion(nn.Module):
                                                          num_blocks=sadtalker_cfg['MODEL']['NUM_MOTION_FRAMES'] - 1,
                                                          max_features=sadtalker_cfg['MODEL']['MAX_FEATURES']).to(device)
         checkpoint_path = os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'], 'dense_motion.pth')
-        self.load_dense_motion_network(checkpoint_path, device)
-
-    def load_dense_motion_network(self, checkpoint_path, device):
-        if os.path.exists(checkpoint_path):
-            if checkpoint_path.endswith('safetensors'):
-                checkpoint = safetensors.torch.load_file(checkpoint_path, device=device)
-            else:
-                checkpoint = torch.load(checkpoint_path, map_location=device)
-            self.dense_motion_network.load_state_dict(checkpoint.get('dense_motion', {}))
-        else:
-            raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}")
+        load_state_dict_robust(self.dense_motion_network, checkpoint_path, device, model_name='dense_motion')
 
     def forward(self, kp_source, kp_driving, jacobian):
         sparse_motion = self.dense_motion_network(kp_source, kp_driving, jacobian)
@@ -870,7 +799,10 @@ class FaceEnhancer(nn.Module):
             self.face_enhancer = None
 
     def forward(self, x):
-        return self.face_enhancer.enhance(x, outscale=1)[0]
+        if self.face_enhancer:
+            return self.face_enhancer.enhance(x, outscale=1)[0]
+        return x
+
 
 def load_models():
     checkpoint_path = './checkpoints'
@@ -883,5 +815,6 @@ def load_models():
     print("SadTalker models loaded successfully!")
     return sadtalker_instance
 
+
 if __name__ == '__main__':
     sadtalker_instance = load_models()
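
Note on the new load_state_dict_robust helper: it consolidates the near-identical per-class loaders removed above (load_kp_detector, load_audio_model, load_mapping_model, load_generator, load_mapping_net, load_dense_motion_network). A minimal usage sketch follows; the kp_channels and num_kp values and the cpu device are illustrative assumptions, not values taken from this repository's config:

# Hypothetical usage of load_state_dict_robust; argument values are illustrative.
detector = OcclusionAwareKPDetector(kp_channels=3, num_kp=10,
                                    num_dilation_blocks=2, dropout_rate=0.1).to('cpu')
load_state_dict_robust(detector, './checkpoints/kp_detector.safetensors',
                       device='cpu', model_name='kp_detector')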
 
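Note on the download loop in SadTalker.__init__: it relies on a download_model(url, filename, checkpoint_path) helper defined elsewhere in sadtalker_utils.py and not shown in this diff. A rough sketch of what such a helper could look like, using the urllib.request import already present at the top of the file; this is an assumption for illustration, not the repository's actual implementation:

# Hypothetical sketch of download_model(url, filename, checkpoint_path);
# skips the download when the target file already exists.
def download_model(url, filename, checkpoint_path):
    os.makedirs(checkpoint_path, exist_ok=True)
    target = os.path.join(checkpoint_path, filename)
    if not os.path.exists(target):
        print(f"Downloading {filename} from {url} ...")
        urllib.request.urlretrieve(url, target)
    return target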