Hjgugugjhuhjggg committed
Commit 7b127f2 · verified · Parent: 9a9e67b

Update sadtalker_utils.py

Files changed (1)
  1. sadtalker_utils.py +887 -867
sadtalker_utils.py CHANGED
@@ -1,867 +1,887 @@
1
- import os
2
- import shutil
3
- import uuid
4
- import cv2
5
- import numpy as np
6
- import torch
7
- import torch.nn as nn
8
- import torch.nn.functional as F
9
- import yaml
10
- from PIL import Image
11
- from skimage import img_as_ubyte, transform
12
- import safetensors
13
- import librosa
14
- from pydub import AudioSegment
15
- import imageio
16
- from scipy import signal
17
- from scipy.io import loadmat, savemat, wavfile
18
- import glob
19
- import tempfile
20
- from tqdm import tqdm
21
- import math
22
- import torchaudio
23
- import urllib.request
24
- from safetensors.torch import load_file, save_file
25
-
26
- REALESRGAN_URL = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth"
27
- CODEFORMER_URL = "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth"
28
- RESTOREFORMER_URL = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/RestoreFormer.pth"
29
- GFPGAN_URL = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth"
30
- kp_url = "https://huggingface.co/usyd-community/vitpose-base-simple/resolve/main/model.safetensors"
31
- kp_file = "kp_detector.safetensors"
32
- aud_url = "https://huggingface.co/vinthony/SadTalker/resolve/main/auido2pose_00140-model.pth"
33
- aud_file = "auido2pose_00140-model.pth"
34
- wav_url = "https://huggingface.co/facebook/wav2vec2-base/resolve/main/pytorch_model.bin"
35
- wav_file = "wav2vec2.pth"
36
- gen_url = "https://huggingface.co/vinthony/SadTalker/resolve/main/wav2lip.pth"
37
- gen_file = "generator.pth"
38
- mapx_url = "https://huggingface.co/vinthony/SadTalker/resolve/main/mapping_00229-model.pth.tar"
39
- mapx_file = "mapping.pth"
40
- den_url = "https://huggingface.co/KwaiVGI/LivePortrait/resolve/main/liveportrait/base_models/motion_extractor.pth"
41
- den_file = "dense_motion.pth"
42
-
43
-
44
- def download_model(url, filename, checkpoint_dir):
45
- if not os.path.exists(os.path.join(checkpoint_dir, filename)):
46
- print(f"Downloading {filename}...")
47
- os.makedirs(checkpoint_dir, exist_ok=True)
48
- urllib.request.urlretrieve(url, os.path.join(checkpoint_dir, filename))
49
- print(f"{filename} downloaded.")
50
- else:
51
- print(f"{filename} already exists.")
52
-
53
-
54
- def mp3_to_wav_util(mp3_filename, wav_filename, frame_rate):
55
- AudioSegment.from_file(mp3_filename).set_frame_rate(frame_rate).export(wav_filename, format="wav")
56
-
57
-
58
- def load_wav_util(path, sr):
59
- return librosa.load(path, sr=sr)[0]
60
-
61
-
62
- def save_wav_util(wav, path, sr):
63
- wav *= 32767 / max(0.01, np.max(np.abs(wav)))
64
- wavfile.write(path, sr, wav.astype(np.int16))
65
-
66
-
67
- class OcclusionAwareKPDetector(nn.Module):
68
-
69
- def __init__(self, kp_channels, num_kp, num_dilation_blocks, dropout_rate):
70
- super(OcclusionAwareKPDetector, self).__init__()
71
- self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
72
- self.bn1 = nn.BatchNorm2d(64)
73
- self.relu = nn.ReLU()
74
- self.conv2 = nn.Conv2d(64, num_kp, kernel_size=3, padding=1)
75
-
76
- def forward(self, x):
77
- x = self.relu(self.bn1(self.conv1(x)))
78
- x = self.conv2(x)
79
- kp = {'value': x.view(x.size(0), -1)}
80
- return kp
81
-
82
-
83
- class Wav2Vec2Model(nn.Module):
84
-
85
- def __init__(self):
86
- super(Wav2Vec2Model, self).__init__()
87
- self.conv = nn.Conv1d(1, 64, kernel_size=10, stride=5, padding=5)
88
- self.bn = nn.BatchNorm1d(64)
89
- self.relu = nn.ReLU()
90
- self.fc = nn.Linear(64, 2048)
91
-
92
- def forward(self, audio):
93
- x = audio.unsqueeze(1)
94
- x = self.relu(self.bn(self.conv(x)))
95
- x = torch.mean(x, dim=-1)
96
- x = self.fc(x)
97
- return x
98
-
99
-
100
- class AudioCoeffsPredictor(nn.Module):
101
-
102
- def __init__(self, input_dim, output_dim):
103
- super(AudioCoeffsPredictor, self).__init__()
104
- self.linear = nn.Linear(input_dim, output_dim)
105
-
106
- def forward(self, audio_embedding):
107
- return self.linear(audio_embedding)
108
-
109
-
110
- class MappingNet(nn.Module):
111
-
112
- def __init__(self, num_coeffs, num_layers, hidden_dim):
113
- super(MappingNet, self).__init__()
114
- layers = []
115
- input_dim = num_coeffs * 2
116
- for _ in range(num_layers):
117
- layers.append(nn.Linear(input_dim, hidden_dim))
118
- layers.append(nn.ReLU())
119
- input_dim = hidden_dim
120
- layers.append(nn.Linear(hidden_dim, num_coeffs))
121
- self.net = nn.Sequential(*layers)
122
-
123
- def forward(self, x):
124
- return self.net(x)
125
-
126
-
127
- class DenseMotionNetwork(nn.Module):
128
-
129
- def __init__(self, num_kp, num_channels, block_expansion, num_blocks, max_features):
130
- super(DenseMotionNetwork, self).__init__()
131
- self.conv1 = nn.Conv2d(num_channels, max_features, kernel_size=3, padding=1)
132
- self.relu = nn.ReLU()
133
- self.conv2 = nn.Conv2d(max_features, num_channels, kernel_size=3, padding=1)
134
-
135
- def forward(self, kp_source, kp_driving, jacobian):
136
- x = self.relu(self.conv1(kp_source))
137
- x = self.conv2(x)
138
- sparse_motion = {'dense_motion': x}
139
- return sparse_motion
140
-
141
-
142
- class Hourglass(nn.Module):
143
-
144
- def __init__(self, block_expansion, num_blocks, max_features, num_channels, kp_size, num_deform_blocks):
145
- super(Hourglass, self).__init__()
146
- self.encoder = nn.Sequential(nn.Conv2d(num_channels, max_features, kernel_size=7, stride=2, padding=3),
147
- nn.BatchNorm2d(max_features), nn.ReLU())
148
- self.decoder = nn.Sequential(
149
- nn.ConvTranspose2d(max_features, num_channels, kernel_size=4, stride=2, padding=1), nn.Tanh())
150
-
151
- def forward(self, source_image, kp_driving, **kwargs):
152
- x = self.encoder(source_image)
153
- x = self.decoder(x)
154
- B, C, H, W = x.size()
155
- video = []
156
- for _ in range(10):
157
- frame = (x[0].cpu().detach().numpy().transpose(1, 2, 0) * 127.5 + 127.5).clip(0, 255).astype(
158
- np.uint8)
159
- video.append(frame)
160
- return video
161
-
162
-
163
- class Face3DHelper:
164
-
165
- def __init__(self, local_pca_path, device):
166
- self.local_pca_path = local_pca_path
167
- self.device = device
168
-
169
- def run(self, source_image):
170
- h, w, _ = source_image.shape
171
- x_min = w // 4
172
- y_min = h // 4
173
- x_max = x_min + w // 2
174
- y_max = y_min + h // 2
175
- return [x_min, y_min, x_max, y_max]
176
-
177
-
178
- class Face3DHelperOld(Face3DHelper):
179
-
180
- def __init__(self, local_pca_path, device):
181
- super(Face3DHelperOld, self).__init__(local_pca_path, device)
182
-
183
-
184
- class MouthDetector:
185
-
186
- def __init__(self):
187
- pass
188
-
189
- def detect(self, image):
190
- h, w = image.shape[:2]
191
- return (w // 2, h // 2)
192
-
193
-
194
- class KeypointNorm(nn.Module):
195
-
196
- def __init__(self, device):
197
- super(KeypointNorm, self).__init__()
198
- self.device = device
199
-
200
- def forward(self, kp_driving):
201
- return kp_driving
202
-
203
-
204
- def save_video_with_watermark(video_frames, audio_path, output_path):
205
- H, W, _ = video_frames[0].shape
206
- out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), 25, (W, H))
207
- for frame in video_frames:
208
- out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
209
- out.release()
210
-
211
-
212
- def paste_pic(video_path, source_image_crop, crop_info, audio_path, output_path):
213
- shutil.copy(video_path, output_path)
214
-
215
-
216
- class TTSTalker:
217
-
218
- def __init__(self):
219
- self.device = "cuda" if torch.cuda.is_available() else "cpu"
220
- self.tts_model = None
221
-
222
- def load_model(self):
223
- self.tts_model = self
224
-
225
- def tokenizer(self, text):
226
- return [ord(c) for c in text]
227
-
228
- def __call__(self, input_tokens):
229
- return torch.zeros(1, 16000, device=self.device)
230
-
231
- def test(self, text, lang='en'):
232
- if self.tts_model is None:
233
- self.load_model()
234
- output_path = os.path.join('./results', str(uuid.uuid4()) + '.wav')
235
- os.makedirs('./results', exist_ok=True)
236
- tokens = self.tokenizer(text)
237
- input_tokens = torch.tensor([tokens], dtype=torch.long).to(self.device)
238
- with torch.no_grad():
239
- audio_output = self(input_tokens)
240
- torchaudio.save(output_path, audio_output.cpu(), 16000)
241
- return output_path
242
-
243
-
244
- class SadTalker:
245
-
246
- def __init__(self, checkpoint_path='checkpoints', config_path='src/config', size=256, preprocess='crop',
247
- old_version=False):
248
- self.device = "cuda" if torch.cuda.is_available() else "cpu"
249
- self.cfg = self.get_cfg_defaults()
250
- self.merge_from_file(os.path.join(config_path, 'sadtalker_config.yaml'))
251
- self.cfg['MODEL']['CHECKPOINTS_DIR'] = checkpoint_path
252
- self.cfg['MODEL']['CONFIG_DIR'] = config_path
253
- self.cfg['MODEL']['DEVICE'] = self.device
254
- self.cfg['INPUT_IMAGE'] = {}
255
- self.cfg['INPUT_IMAGE']['SOURCE_IMAGE'] = 'None'
256
- self.cfg['INPUT_IMAGE']['DRIVEN_AUDIO'] = 'None'
257
- self.cfg['INPUT_IMAGE']['PREPROCESS'] = preprocess
258
- self.cfg['INPUT_IMAGE']['SIZE'] = size
259
- self.cfg['INPUT_IMAGE']['OLD_VERSION'] = old_version
260
-
261
- download_model(kp_url, kp_file, checkpoint_path)
262
- download_model(aud_url, aud_file, checkpoint_path)
263
- download_model(wav_url, wav_file, checkpoint_path)
264
- download_model(gen_url, gen_file, checkpoint_path)
265
- download_model(mapx_url, mapx_file, checkpoint_path)
266
- download_model(den_url, den_file, checkpoint_path)
267
- download_model(GFPGAN_URL, 'GFPGANv1.4.pth', checkpoint_path)
268
- download_model(REALESRGAN_URL, 'RealESRGAN_x2plus.pth', checkpoint_path)
269
-
270
- self.sadtalker_model = SadTalkerModel(self.cfg, device_id=[0])
271
-
272
- def get_cfg_defaults(self):
273
- return {
274
- 'MODEL': {
275
- 'CHECKPOINTS_DIR': '',
276
- 'CONFIG_DIR': '',
277
- 'DEVICE': self.device,
278
- 'SCALE': 64,
279
- 'NUM_VOXEL_FRAMES': 8,
280
- 'NUM_MOTION_FRAMES': 10,
281
- 'MAX_FEATURES': 256,
282
- 'DRIVEN_AUDIO_SAMPLE_RATE': 16000,
283
- 'VIDEO_FPS': 25,
284
- 'OUTPUT_VIDEO_FPS': None,
285
- 'OUTPUT_AUDIO_SAMPLE_RATE': None,
286
- 'USE_ENHANCER': False,
287
- 'ENHANCER_NAME': '',
288
- 'BG_UPSAMPLER': None,
289
- 'IS_HALF': False
290
- },
291
- 'INPUT_IMAGE': {}
292
- }
293
-
294
- def merge_from_file(self, filepath):
295
- if os.path.exists(filepath):
296
- with open(filepath, 'r') as f:
297
- cfg_from_file = yaml.safe_load(f)
298
- self.cfg.update(cfg_from_file)
299
-
300
- def test(self, source_image, driven_audio, preprocess='crop', still_mode=False, use_enhancer=False,
301
- batch_size=1, size=256, pose_style=0, exp_scale=1.0, use_ref_video=False, ref_video=None,
302
- ref_info=None, use_idle_mode=False, length_of_audio=0, use_blink=True, result_dir='./results/',
303
- tts_text=None, tts_lang='en'):
304
- self.sadtalker_model.test(source_image, driven_audio, preprocess, still_mode, use_enhancer, batch_size, size,
305
- pose_style, exp_scale, use_ref_video, ref_video, ref_info, use_idle_mode,
306
- length_of_audio, use_blink, result_dir, tts_text, tts_lang)
307
- return self.sadtalker_model.save_result()
308
-
309
-
310
- class SadTalkerModel:
311
-
312
- def __init__(self, sadtalker_cfg, device_id=[0]):
313
- self.cfg = sadtalker_cfg
314
- self.device = sadtalker_cfg['MODEL'].get('DEVICE', 'cpu')
315
- self.sadtalker = SadTalkerInnerModel(sadtalker_cfg, device_id)
316
- self.preprocesser = self.sadtalker.preprocesser
317
- self.kp_extractor = self.sadtalker.kp_extractor
318
- self.generator = self.sadtalker.generator
319
- self.mapping = self.sadtalker.mapping
320
- self.he_estimator = self.sadtalker.he_estimator
321
- self.audio_to_coeff = self.sadtalker.audio_to_coeff
322
- self.animate_from_coeff = self.sadtalker.animate_from_coeff
323
- self.face_enhancer = self.sadtalker.face_enhancer
324
-
325
- def test(self, source_image, driven_audio, preprocess='crop', still_mode=False, use_enhancer=False,
326
- batch_size=1, size=256, pose_style=0, exp_scale=1.0, use_ref_video=False, ref_video=None,
327
- ref_info=None, use_idle_mode=False, length_of_audio=0, use_blink=True, result_dir='./results/',
328
- tts_text=None, tts_lang='en', jitter_amount=10, jitter_source_image=False):
329
- self.inner_test = SadTalkerInner(self, source_image, driven_audio, preprocess, still_mode, use_enhancer,
330
- batch_size, size, pose_style, exp_scale, use_ref_video, ref_video, ref_info,
331
- use_idle_mode, length_of_audio, use_blink, result_dir, tts_text, tts_lang,
332
- jitter_amount, jitter_source_image)
333
- return self.inner_test.test()
334
-
335
- def save_result(self):
336
- return self.inner_test.save_result()
337
-
338
-
339
- class SadTalkerInner:
340
-
341
- def __init__(self, sadtalker_model, source_image, driven_audio, preprocess, still_mode, use_enhancer,
342
- batch_size, size, pose_style, exp_scale, use_ref_video, ref_video, ref_info, use_idle_mode,
343
- length_of_audio, use_blink, result_dir, tts_text, tts_lang, jitter_amount, jitter_source_image):
344
- self.sadtalker_model = sadtalker_model
345
- self.source_image = source_image
346
- self.driven_audio = driven_audio
347
- self.preprocess = preprocess
348
- self.still_mode = still_mode
349
- self.use_enhancer = use_enhancer
350
- self.batch_size = batch_size
351
- self.size = size
352
- self.pose_style = pose_style
353
- self.exp_scale = exp_scale
354
- self.use_ref_video = use_ref_video
355
- self.ref_video = ref_video
356
- self.ref_info = ref_info
357
- self.use_idle_mode = use_idle_mode
358
- self.length_of_audio = length_of_audio
359
- self.use_blink = use_blink
360
- self.result_dir = result_dir
361
- self.tts_text = tts_text
362
- self.tts_lang = tts_lang
363
- self.jitter_amount = jitter_amount
364
- self.jitter_source_image = jitter_source_image
365
- self.device = self.sadtalker_model.device
366
- self.output_path = None
367
-
368
- def get_test_data(self):
369
- proc = self.sadtalker_model.preprocesser
370
- if self.tts_text is not None:
371
- temp_dir = tempfile.mkdtemp()
372
- audio_path = os.path.join(temp_dir, 'audio.wav')
373
- tts = TTSTalker()
374
- shutil.copy(tts.test(self.tts_text, self.tts_lang), audio_path)  # copy generated speech to the expected temp path
375
- self.driven_audio = audio_path
376
- source_image_pil = Image.open(self.source_image).convert('RGB')
377
- if self.jitter_source_image:
378
- jitter_dx = np.random.randint(-self.jitter_amount, self.jitter_amount + 1)
379
- jitter_dy = np.random.randint(-self.jitter_amount, self.jitter_amount + 1)
380
- source_image_pil = Image.fromarray(
381
- np.roll(np.roll(np.array(source_image_pil), jitter_dx, axis=1), jitter_dy, axis=0))
382
- source_image_tensor, crop_info, cropped_image = proc.crop(source_image_pil, self.preprocess, self.size)
383
- if self.still_mode:  # idle mode is handled in the elif branch below
384
- ref_pose_coeff = proc.generate_still_pose(self.pose_style)
385
- ref_expression_coeff = proc.generate_still_expression(self.exp_scale)
386
- elif self.use_idle_mode:
387
- ref_pose_coeff = proc.generate_idles_pose(self.length_of_audio, self.pose_style)
388
- ref_expression_coeff = proc.generate_idles_expression(self.length_of_audio)
389
- else:
390
- ref_pose_coeff = None
391
- ref_expression_coeff = None
392
- audio_tensor, audio_sample_rate = proc.process_audio(self.driven_audio,
393
- self.sadtalker_model.cfg['MODEL']['DRIVEN_AUDIO_SAMPLE_RATE'])
394
- batch = {
395
- 'source_image': source_image_tensor.unsqueeze(0).to(self.device),
396
- 'audio': audio_tensor.unsqueeze(0).to(self.device),
397
- 'ref_pose_coeff': ref_pose_coeff,
398
- 'ref_expression_coeff': ref_expression_coeff,
399
- 'source_image_crop': cropped_image,
400
- 'crop_info': crop_info,
401
- 'use_blink': self.use_blink,
402
- 'pose_style': self.pose_style,
403
- 'exp_scale': self.exp_scale,
404
- 'ref_video': self.ref_video,
405
- 'use_ref_video': self.use_ref_video,
406
- 'ref_info': self.ref_info,
407
- }
408
- return batch, audio_sample_rate
409
-
410
- def run_inference(self, batch):
411
- kp_extractor = self.sadtalker_model.kp_extractor
412
- generator = self.sadtalker_model.generator
413
- mapping = self.sadtalker_model.mapping
414
- he_estimator = self.sadtalker_model.he_estimator
415
- audio_to_coeff = self.sadtalker_model.audio_to_coeff
416
- animate_from_coeff = self.sadtalker_model.animate_from_coeff
417
- proc = self.sadtalker_model.preprocesser
418
- with torch.no_grad():
419
- kp_source = kp_extractor(batch['source_image'])
420
- if self.still_mode:
421
- ref_pose_coeff = batch['ref_pose_coeff']
422
- ref_expression_coeff = batch['ref_expression_coeff']
423
- pose_coeff = audio_to_coeff.get_pose_coeff(batch['audio'], ref_pose_coeff)
424
- expression_coeff = audio_to_coeff.get_exp_coeff(batch['audio'], ref_expression_coeff)
425
- elif self.use_idle_mode:
426
- ref_pose_coeff = batch['ref_pose_coeff']
427
- ref_expression_coeff = batch['ref_expression_coeff']
428
- pose_coeff = audio_to_coeff.get_pose_coeff(batch['audio'], ref_pose_coeff)
429
- expression_coeff = audio_to_coeff.get_exp_coeff(batch['audio'], ref_expression_coeff)
430
- else:
431
- if self.use_ref_video:
432
- kp_ref = kp_extractor(batch['source_image'])
433
- pose_coeff = audio_to_coeff.get_pose_coeff(batch['audio'], kp_ref=kp_ref,
434
- use_ref_info=batch['ref_info'])
435
- else:
436
- pose_coeff = audio_to_coeff.get_pose_coeff(batch['audio'])
437
- expression_coeff = audio_to_coeff.get_exp_coeff(batch['audio'])
438
- coeff = {'pose_coeff': pose_coeff, 'expression_coeff': expression_coeff}
439
- if self.use_blink:
440
- coeff['blink_coeff'] = audio_to_coeff.get_blink_coeff(batch['audio'])
441
- else:
442
- coeff['blink_coeff'] = None
443
- kp_driving = audio_to_coeff(batch['audio'])[0]
444
- kp_norm = animate_from_coeff.normalize_kp(kp_driving)
445
- coeff['kp_driving'] = kp_norm
446
- coeff['jacobian'] = [torch.eye(2).unsqueeze(0).unsqueeze(0).to(self.device)] * 4
447
- face_enhancer = self.sadtalker_model.face_enhancer if self.use_enhancer else None
448
- output_video = animate_from_coeff.generate(batch['source_image'], kp_source, coeff, generator, mapping,
449
- he_estimator, batch['audio'], batch['source_image_crop'],
450
- face_enhancer=face_enhancer)
451
- return output_video
452
-
453
- def post_processing(self, output_video, audio_sample_rate, batch):
454
- proc = self.sadtalker_model.preprocesser
455
- base_name = os.path.splitext(os.path.basename(batch['source_image_crop']))[0]
456
- audio_name = os.path.splitext(os.path.basename(self.driven_audio))[0]
457
- output_video_path = os.path.join(self.result_dir, base_name + '_' + audio_name + '.mp4')
458
- self.output_path = output_video_path
459
- video_fps = self.sadtalker_model.cfg['MODEL']['VIDEO_FPS'] if self.sadtalker_model.cfg['MODEL'][
460
- 'OUTPUT_VIDEO_FPS'] is None else \
461
- self.sadtalker_model.cfg['MODEL']['OUTPUT_VIDEO_FPS']
462
- audio_output_sample_rate = self.sadtalker_model.cfg['MODEL']['DRIVEN_AUDIO_SAMPLE_RATE'] if \
463
- self.sadtalker_model.cfg['MODEL']['OUTPUT_AUDIO_SAMPLE_RATE'] is None else \
464
- self.sadtalker_model.cfg['MODEL']['OUTPUT_AUDIO_SAMPLE_RATE']
465
- if self.use_enhancer:
466
- enhanced_path = os.path.join(self.result_dir, base_name + '_' + audio_name + '_enhanced.mp4')
467
- save_video_with_watermark(output_video, self.driven_audio, enhanced_path)
468
- paste_pic(enhanced_path, batch['source_image_crop'], batch['crop_info'], self.driven_audio,
469
- output_video_path)
470
- os.remove(enhanced_path)
471
- else:
472
- save_video_with_watermark(output_video, self.driven_audio, output_video_path)
473
- if self.tts_text is not None:
474
- shutil.rmtree(os.path.dirname(self.driven_audio))
475
-
476
- def save_result(self):
477
- return self.output_path
478
-
479
- def __call__(self):
480
- return self.output_path
481
-
482
- def test(self):
483
- batch, audio_sample_rate = self.get_test_data()
484
- output_video = self.run_inference(batch)
485
- self.post_processing(output_video, audio_sample_rate, batch)
486
- return self.save_result()
487
-
488
-
489
- class SadTalkerInnerModel:
490
-
491
- def __init__(self, sadtalker_cfg, device_id=[0]):
492
- self.cfg = sadtalker_cfg
493
- self.device = sadtalker_cfg['MODEL'].get('DEVICE', 'cpu')
494
- self.preprocesser = Preprocesser(sadtalker_cfg, self.device)
495
- self.kp_extractor = KeyPointExtractor(sadtalker_cfg, self.device)
496
- self.audio_to_coeff = Audio2Coeff(sadtalker_cfg, self.device)
497
- self.animate_from_coeff = AnimateFromCoeff(sadtalker_cfg, self.device)
498
- self.face_enhancer = FaceEnhancer(sadtalker_cfg, self.device) if sadtalker_cfg['MODEL'][
499
- 'USE_ENHANCER'] else None
500
- self.generator = Generator(sadtalker_cfg, self.device)
501
- self.mapping = Mapping(sadtalker_cfg, self.device)
502
- self.he_estimator = OcclusionAwareDenseMotion(sadtalker_cfg, self.device)
503
-
504
-
505
- class Preprocesser:
506
-
507
- def __init__(self, sadtalker_cfg, device):
508
- self.cfg = sadtalker_cfg
509
- self.device = device
510
- if self.cfg['INPUT_IMAGE'].get('OLD_VERSION', False):
511
- self.face3d_helper = Face3DHelperOld(self.cfg['INPUT_IMAGE'].get('LOCAL_PCA_PATH', ''), device)
512
- else:
513
- self.face3d_helper = Face3DHelper(self.cfg['INPUT_IMAGE'].get('LOCAL_PCA_PATH', ''), device)
514
- self.mouth_detector = MouthDetector()
515
-
516
- def crop(self, source_image_pil, preprocess_type, size=256):
517
- source_image = np.array(source_image_pil)
518
- face_info = self.face3d_helper.run(source_image)
519
- if face_info is None:
520
- raise Exception("No face detected")
521
- x_min, y_min, x_max, y_max = face_info[:4]
522
- old_size = (x_max - x_min, y_max - y_min)
523
- x_center = (x_max + x_min) / 2
524
- y_center = (y_max + y_min) / 2
525
- if preprocess_type == 'crop':
526
- face_size = max(x_max - x_min, y_max - y_min)
527
- x_min = int(x_center - face_size / 2)
528
- y_min = int(y_center - face_size / 2)
529
- x_max = int(x_center + face_size / 2)
530
- y_max = int(y_center + face_size / 2)
531
- else:
532
- x_min -= int((x_max - x_min) * 0.1)
533
- y_min -= int((y_max - y_min) * 0.1)
534
- x_max += int((x_max - x_min) * 0.1)
535
- y_max += int((y_max - y_min) * 0.1)
536
- h, w = source_image.shape[:2]
537
- x_min = max(0, x_min)
538
- y_min = max(0, y_min)
539
- x_max = min(w, x_max)
540
- y_max = min(h, y_max)
541
- cropped_image = source_image[y_min:y_max, x_min:x_max]
542
- cropped_image_pil = Image.fromarray(cropped_image)
543
- if size is not None and size != 0:
544
- cropped_image_pil = cropped_image_pil.resize((size, size), Image.Resampling.LANCZOS)
545
- source_image_tensor = self.img2tensor(cropped_image_pil)
546
- return source_image_tensor, [[y_min, y_max], [x_min, x_max], old_size, cropped_image_pil.size], os.path.basename(
547
- self.cfg['INPUT_IMAGE'].get('SOURCE_IMAGE', ''))
548
-
549
- def img2tensor(self, img):
550
- img = np.array(img).astype(np.float32) / 255.0
551
- img = np.transpose(img, (2, 0, 1))
552
- return torch.FloatTensor(img)
553
-
554
- def video_to_tensor(self, video, device):
555
- video_tensor_list = []
556
- import torchvision.transforms as transforms
557
- transform_func = transforms.ToTensor()
558
- for frame in video:
559
- frame_pil = Image.fromarray(frame)
560
- frame_tensor = transform_func(frame_pil).unsqueeze(0).to(device)
561
- video_tensor_list.append(frame_tensor)
562
- video_tensor = torch.cat(video_tensor_list, dim=0)
563
- return video_tensor
564
-
565
- def process_audio(self, audio_path, sample_rate):
566
- wav = load_wav_util(audio_path, sample_rate)
567
- wav_tensor = torch.FloatTensor(wav).unsqueeze(0)
568
- return wav_tensor, sample_rate
569
-
570
- def generate_still_pose(self, pose_style):
571
- ref_pose_coeff = torch.zeros((1, 64), dtype=torch.float32).to(self.device)
572
- ref_pose_coeff[:, :3] = torch.tensor([0, 0, pose_style * 0.3], dtype=torch.float32)
573
- return ref_pose_coeff
574
-
575
- def generate_still_expression(self, exp_scale):
576
- ref_expression_coeff = torch.zeros((1, 64), dtype=torch.float32).to(self.device)
577
- ref_expression_coeff[:, :3] = torch.tensor([0, 0, exp_scale * 0.3], dtype=torch.float32)
578
- return ref_expression_coeff
579
-
580
- def generate_idles_pose(self, length_of_audio, pose_style):
581
- num_frames = int(length_of_audio * self.cfg['MODEL']['VIDEO_FPS'])
582
- ref_pose_coeff = torch.zeros((num_frames, 64), dtype=torch.float32).to(self.device)
583
- start_pose = self.generate_still_pose(pose_style)
584
- end_pose = self.generate_still_pose(pose_style)
585
- for frame_idx in range(num_frames):
586
- alpha = frame_idx / num_frames
587
- ref_pose_coeff[frame_idx] = (1 - alpha) * start_pose + alpha * end_pose
588
- return ref_pose_coeff
589
-
590
- def generate_idles_expression(self, length_of_audio):
591
- num_frames = int(length_of_audio * self.cfg['MODEL']['VIDEO_FPS'])
592
- ref_expression_coeff = torch.zeros((num_frames, 64), dtype=torch.float32).to(self.device)
593
- start_exp = self.generate_still_expression(1.0)
594
- end_exp = self.generate_still_expression(1.0)
595
- for frame_idx in range(num_frames):
596
- alpha = frame_idx / num_frames
597
- ref_expression_coeff[frame_idx] = (1 - alpha) * start_exp + alpha * end_exp
598
- return ref_expression_coeff
599
-
600
-
601
- class KeyPointExtractor(nn.Module):
602
-
603
- def __init__(self, sadtalker_cfg, device):
604
- super(KeyPointExtractor, self).__init__()
605
- self.kp_extractor = OcclusionAwareKPDetector(kp_channels=sadtalker_cfg['MODEL']['NUM_MOTION_FRAMES'],
606
- num_kp=10,
607
- num_dilation_blocks=2,
608
- dropout_rate=0.1).to(device)
609
- checkpoint_path = os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'], 'kp_detector.safetensors')
610
- self.load_kp_detector(checkpoint_path, device)
611
-
612
- def load_kp_detector(self, checkpoint_path, device):
613
- if os.path.exists(checkpoint_path):
614
- if checkpoint_path.endswith('safetensors'):
615
- checkpoint = safetensors.torch.load_file(checkpoint_path, device=device)
616
- else:
617
- checkpoint = torch.load(checkpoint_path, map_location=device)
618
- self.kp_extractor.load_state_dict(checkpoint.get('kp_detector', {}))
619
- else:
620
- raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}")
621
-
622
- def forward(self, x):
623
- kp = self.kp_extractor(x)
624
- return kp
625
-
626
-
627
- class Audio2Coeff(nn.Module):
628
-
629
- def __init__(self, sadtalker_cfg, device):
630
- super(Audio2Coeff, self).__init__()
631
- self.audio_model = Wav2Vec2Model().to(device)
632
- checkpoint_path = os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'], 'wav2vec2.pth')
633
- self.load_audio_model(checkpoint_path, device)
634
- self.pose_mapper = AudioCoeffsPredictor(2048, 64).to(device)
635
- self.exp_mapper = AudioCoeffsPredictor(2048, 64).to(device)
636
- self.blink_mapper = AudioCoeffsPredictor(2048, 1).to(device)
637
- mapping_checkpoint = os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'], 'auido2pose_00140-model.pth')  # matches aud_file downloaded above
638
- self.load_mapping_model(mapping_checkpoint, device)
639
-
640
- def load_audio_model(self, checkpoint_path, device):
641
- if os.path.exists(checkpoint_path):
642
- if checkpoint_path.endswith('safetensors'):
643
- checkpoint = safetensors.torch.load_file(checkpoint_path, device=device)
644
- else:
645
- checkpoint = torch.load(checkpoint_path, map_location=device)
646
- self.audio_model.load_state_dict(checkpoint.get("wav2vec2", {}))
647
- else:
648
- raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}")
649
-
650
- def load_mapping_model(self, checkpoint_path, device):
651
- if os.path.exists(checkpoint_path):
652
- if checkpoint_path.endswith('safetensors'):
653
- checkpoint = safetensors.torch.load_file(checkpoint_path, device=device)
654
- else:
655
- checkpoint = torch.load(checkpoint_path, map_location=device)
656
- self.pose_mapper.load_state_dict(checkpoint.get("pose_predictor", {}))
657
- self.exp_mapper.load_state_dict(checkpoint.get("exp_predictor", {}))
658
- self.blink_mapper.load_state_dict(checkpoint.get("blink_predictor", {}))
659
- else:
660
- raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}")
661
-
662
- def get_pose_coeff(self, audio_tensor, ref_pose_coeff=None, kp_ref=None, use_ref_info=''):
663
- audio_embedding = self.audio_model(audio_tensor)
664
- pose_coeff = self.pose_mapper(audio_embedding)
665
- if ref_pose_coeff is not None:
666
- pose_coeff = ref_pose_coeff
667
- if kp_ref is not None and use_ref_info == 'pose':
668
- ref_pose_6d = kp_ref['value'][:, :6]
669
- pose_coeff[:, :6] = self.mean_std_normalize(ref_pose_6d).mean(dim=1)
670
- return pose_coeff
671
-
672
- def get_exp_coeff(self, audio_tensor, ref_expression_coeff=None):
673
- audio_embedding = self.audio_model(audio_tensor)
674
- expression_coeff = self.exp_mapper(audio_embedding)
675
- if ref_expression_coeff is not None:
676
- expression_coeff = ref_expression_coeff
677
- return expression_coeff
678
-
679
- def get_blink_coeff(self, audio_tensor):
680
- audio_embedding = self.audio_model(audio_tensor)
681
- blink_coeff = self.blink_mapper(audio_embedding)
682
- return blink_coeff
683
-
684
- def forward(self, audio):
685
- audio_embedding = self.audio_model(audio)
686
- pose_coeff, expression_coeff, blink_coeff = self.pose_mapper(audio_embedding), self.exp_mapper(
687
- audio_embedding), self.blink_mapper(audio_embedding)
688
- return pose_coeff, expression_coeff, blink_coeff
689
-
690
- def mean_std_normalize(self, coeff):
691
- mean = coeff.mean(dim=1, keepdim=True)
692
- std = coeff.std(dim=1, keepdim=True)
693
- return (coeff - mean) / std
694
-
695
-
696
- class AnimateFromCoeff(nn.Module):
697
-
698
- def __init__(self, sadtalker_cfg, device):
699
- super(AnimateFromCoeff, self).__init__()
700
- self.generator = Generator(sadtalker_cfg, device)
701
- self.mapping = Mapping(sadtalker_cfg, device)
702
- self.kp_norm = KeypointNorm(device=device)
703
- self.he_estimator = OcclusionAwareDenseMotion(sadtalker_cfg, device)
704
-
705
- def normalize_kp(self, kp_driving):
706
- return self.kp_norm(kp_driving)
707
-
708
- def generate(self, source_image, kp_source, coeff, generator, mapping, he_estimator, audio, source_image_crop,
709
- face_enhancer=None):
710
- kp_driving = coeff['kp_driving']
711
- jacobian = coeff['jacobian']
712
- pose_coeff = coeff['pose_coeff']
713
- expression_coeff = coeff['expression_coeff']
714
- blink_coeff = coeff['blink_coeff']
715
- with torch.no_grad():
716
- if blink_coeff is not None:
717
- sparse_motion = he_estimator(kp_source, kp_driving, jacobian)
718
- dense_motion = sparse_motion['dense_motion']
719
- video_deocclusion = generator(source_image, dense_motion, bg_param={'mask': None, 'color': None})
720
- face_3d = mapping(expression_coeff, pose_coeff, blink_coeff)
721
- video_3d = generator(source_image, dense_motion, bg_param={'mask': None, 'color': None},
722
- face_3d_param=face_3d)
723
- video_output = video_deocclusion['video_no_reocclusion'] + video_3d['video_3d']
724
- video_output = self.make_animation(video_output)
725
- else:
726
- sparse_motion = he_estimator(kp_source, kp_driving, jacobian)
727
- dense_motion = sparse_motion['dense_motion']
728
- face_3d = mapping(expression_coeff, pose_coeff)
729
- video_3d = generator(source_image, dense_motion, bg_param={'mask': None, 'color': None},
730
- face_3d_param=face_3d)
731
- video_output = video_3d['video_3d']
732
- video_output = self.make_animation(video_output)
733
- if face_enhancer is not None:
734
- video_output_enhanced = []
735
- for frame in tqdm(video_output, 'Face enhancer running'):
736
- pil_image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
737
- enhanced_image = face_enhancer.enhance(np.array(pil_image))[0]
738
- video_output_enhanced.append(cv2.cvtColor(enhanced_image, cv2.COLOR_BGR2RGB))
739
- video_output = video_output_enhanced
740
- return video_output
741
-
742
- def make_animation(self, video_array):
743
- H, W, _ = video_array[0].shape
744
- out = cv2.VideoWriter('./tmp.mp4', cv2.VideoWriter_fourcc(*'mp4v'), 25, (W, H))
745
- for img in video_array:
746
- out.write(cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
747
- out.release()
748
- video = imageio.mimread('./tmp.mp4')
749
- os.remove('./tmp.mp4')
750
- return video
751
-
752
-
753
- class Generator(nn.Module):
754
-
755
- def __init__(self, sadtalker_cfg, device):
756
- super(Generator, self).__init__()
757
- self.generator = Hourglass(block_expansion=sadtalker_cfg['MODEL']['SCALE'],
758
- num_blocks=sadtalker_cfg['MODEL']['NUM_VOXEL_FRAMES'],
759
- max_features=sadtalker_cfg['MODEL']['MAX_FEATURES'],
760
- num_channels=3,
761
- kp_size=10,
762
- num_deform_blocks=sadtalker_cfg['MODEL']['NUM_MOTION_FRAMES']).to(device)
763
- checkpoint_path = os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'], 'generator.pth')
764
- self.load_generator(checkpoint_path, device)
765
-
766
- def load_generator(self, checkpoint_path, device):
767
- if os.path.exists(checkpoint_path):
768
- if checkpoint_path.endswith('safetensors'):
769
- checkpoint = safetensors.torch.load_file(checkpoint_path, device=device)
770
- else:
771
- checkpoint = torch.load(checkpoint_path, map_location=device)
772
- self.generator.load_state_dict(checkpoint.get('generator', {}))
773
- else:
774
- raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}")
775
-
776
- def forward(self, source_image, dense_motion, bg_param, face_3d_param=None):
777
- if face_3d_param is not None:
778
- video_3d = self.generator(source_image, kp_driving=dense_motion, bg_param=bg_param,
779
- face_3d_param=face_3d_param)
780
- else:
781
- video_3d = self.generator(source_image, kp_driving=dense_motion, bg_param=bg_param)
782
- return {'video_3d': video_3d, 'video_no_reocclusion': video_3d}
783
-
784
-
785
- class Mapping(nn.Module):
786
-
787
- def __init__(self, sadtalker_cfg, device):
788
- super(Mapping, self).__init__()
789
- self.mapping_net = MappingNet(num_coeffs=64, num_layers=3, hidden_dim=128).to(device)
790
- checkpoint_path = os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'], 'mapping.pth')
791
- self.load_mapping_net(checkpoint_path, device)
792
- self.f_3d_mean = torch.zeros(1, 64, device=device)
793
-
794
- def load_mapping_net(self, checkpoint_path, device):
795
- if os.path.exists(checkpoint_path):
796
- if checkpoint_path.endswith('safetensors'):
797
- checkpoint = safetensors.torch.load_file(checkpoint_path, device=device)
798
- else:
799
- checkpoint = torch.load(checkpoint_path, map_location=device)
800
- self.mapping_net.load_state_dict(checkpoint.get('mapping', {}))
801
- else:
802
- raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}")
803
-
804
- def forward(self, expression_coeff, pose_coeff, blink_coeff=None):
805
- coeff = torch.cat([expression_coeff, pose_coeff], dim=1)
806
- face_3d = self.mapping_net(coeff) + self.f_3d_mean
807
- if blink_coeff is not None:
808
- face_3d[:, -1:] = blink_coeff
809
- return face_3d
810
-
811
-
812
- class OcclusionAwareDenseMotion(nn.Module):
813
-
814
- def __init__(self, sadtalker_cfg, device):
815
- super(OcclusionAwareDenseMotion, self).__init__()
816
- self.dense_motion_network = DenseMotionNetwork(num_kp=10,
817
- num_channels=3,
818
- block_expansion=sadtalker_cfg['MODEL']['SCALE'],
819
- num_blocks=sadtalker_cfg['MODEL']['NUM_MOTION_FRAMES'] - 1,
820
- max_features=sadtalker_cfg['MODEL']['MAX_FEATURES']).to(device)
821
- checkpoint_path = os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'], 'dense_motion.pth')
822
- self.load_dense_motion_network(checkpoint_path, device)
823
-
824
- def load_dense_motion_network(self, checkpoint_path, device):
825
- if os.path.exists(checkpoint_path):
826
- if checkpoint_path.endswith('safetensors'):
827
- checkpoint = safetensors.torch.load_file(checkpoint_path, device=device)
828
- else:
829
- checkpoint = torch.load(checkpoint_path, map_location=device)
830
- self.dense_motion_network.load_state_dict(checkpoint.get('dense_motion', {}))
831
- else:
832
- raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}")
833
-
834
- def forward(self, kp_source, kp_driving, jacobian):
835
- sparse_motion = self.dense_motion_network(kp_source, kp_driving, jacobian)
836
- return sparse_motion
837
-
838
-
839
- class FaceEnhancer(nn.Module):
840
-
841
- def __init__(self, sadtalker_cfg, device):
842
- super(FaceEnhancer, self).__init__()
843
- enhancer_name = sadtalker_cfg['MODEL']['ENHANCER_NAME']
844
- bg_upsampler = sadtalker_cfg['MODEL']['BG_UPSAMPLER']
845
- if enhancer_name == 'gfpgan':
846
- from gfpgan import GFPGANer
847
- self.face_enhancer = GFPGANer(model_path=os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'], 'GFPGANv1.4.pth'),
848
- upscale=1,
849
- arch='clean',
850
- channel_multiplier=2,
851
- bg_upsampler=bg_upsampler)
852
- elif enhancer_name == 'realesrgan':
853
- from realesrgan import RealESRGANer
854
- half = False if device == 'cpu' else sadtalker_cfg['MODEL']['IS_HALF']
855
- self.face_enhancer = RealESRGANer(scale=2,
856
- model_path=os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'],
857
- 'RealESRGAN_x2plus.pth'),
858
- tile=0,
859
- tile_pad=10,
860
- pre_pad=0,
861
- half=half,
862
- device=device)
863
- else:
864
- self.face_enhancer = None
865
-
866
- def forward(self, x):
867
- return self.face_enhancer.enhance(x, outscale=1)[0]
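
For reference, a minimal usage sketch of the SadTalker entry point defined in this file, following the SadTalker.__init__ and SadTalker.test signatures above. The image and audio paths are placeholder assumptions, not assets referenced by this commit; note that constructing SadTalker downloads the checkpoint files listed at the top of the module.

# Minimal usage sketch (illustrative only; 'examples/face.png' and
# 'examples/speech.wav' are assumed placeholder inputs).
from sadtalker_utils import SadTalker

talker = SadTalker(checkpoint_path='checkpoints', config_path='src/config',
                   size=256, preprocess='crop')
output_mp4 = talker.test(source_image='examples/face.png',
                         driven_audio='examples/speech.wav',
                         still_mode=True,
                         use_enhancer=False,
                         result_dir='./results/')
print(output_mp4)  # path of the rendered talking-head video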
1
+ import os
2
+ import shutil
3
+ import uuid
4
+ import cv2
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ import yaml
10
+ from PIL import Image
11
+ from skimage import img_as_ubyte, transform
12
+ import safetensors
13
+ import librosa
14
+ from pydub import AudioSegment
15
+ import imageio
16
+ from scipy import signal
17
+ from scipy.io import loadmat, savemat, wavfile
18
+ import glob
19
+ import tempfile
20
+ from tqdm import tqdm
21
+ import math
22
+ import torchaudio
23
+ import urllib.request
24
+ from safetensors.torch import load_file, save_file
25
+
26
+ REALESRGAN_URL = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth"
27
+ CODEFORMER_URL = "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth"
28
+ RESTOREFORMER_URL = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/RestoreFormer.pth"
29
+ GFPGAN_URL = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth"
30
+ kp_url = "https://huggingface.co/usyd-community/vitpose-base-simple/resolve/main/model.safetensors"
31
+ kp_file = "kp_detector.safetensors"
32
+ aud_url = "https://huggingface.co/vinthony/SadTalker/resolve/main/auido2pose_00140-model.pth"
33
+ aud_file = "auido2pose_00140-model.pth"
34
+ wav_url = "https://huggingface.co/facebook/wav2vec2-base/resolve/main/pytorch_model.bin"
35
+ wav_file = "wav2vec2.pth"
36
+ gen_url = "https://huggingface.co/vinthony/SadTalker/resolve/main/wav2lip.pth"
37
+ gen_file = "generator.pth"
38
+ mapx_url = "https://huggingface.co/vinthony/SadTalker/resolve/main/mapping_00229-model.pth.tar"
39
+ mapx_file = "mapping.pth"
40
+ den_url = "https://huggingface.co/KwaiVGI/LivePortrait/resolve/main/liveportrait/base_models/motion_extractor.pth"
41
+ den_file = "dense_motion.pth"
42
+
43
+
44
+ def download_model(url, filename, checkpoint_dir):
45
+ if not os.path.exists(os.path.join(checkpoint_dir, filename)):
46
+ print(f"Downloading {filename}...")
47
+ os.makedirs(checkpoint_dir, exist_ok=True)
48
+ urllib.request.urlretrieve(url, os.path.join(checkpoint_dir, filename))
49
+ print(f"{filename} downloaded.")
50
+ else:
51
+ print(f"{filename} already exists.")
52
+
53
+
54
+ def mp3_to_wav_util(mp3_filename, wav_filename, frame_rate):
55
+ AudioSegment.from_file(mp3_filename).set_frame_rate(frame_rate).export(wav_filename, format="wav")
56
+
57
+
58
+ def load_wav_util(path, sr):
59
+ return librosa.core.load(path, sr=sr)[0]
60
+
61
+
62
+ def save_wav_util(wav, path, sr):
63
+ wav *= 32767 / max(0.01, np.max(np.abs(wav)))
64
+ wavfile.write(path, sr, wav.astype(np.int16))
65
+
66
+
67
+ class OcclusionAwareKPDetector(nn.Module):
68
+
69
+ def __init__(self, kp_channels, num_kp, num_dilation_blocks, dropout_rate):
70
+ super(OcclusionAwareKPDetector, self).__init__()
71
+ self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
72
+ self.bn1 = nn.BatchNorm2d(64)
73
+ self.relu = nn.ReLU()
74
+ self.conv2 = nn.Conv2d(64, num_kp, kernel_size=3, padding=1)
75
+
76
+ def forward(self, x):
77
+ x = self.relu(self.bn1(self.conv1(x)))
78
+ x = self.conv2(x)
79
+ kp = {'value': x.view(x.size(0), -1)}
80
+ return kp
81
+
82
+
83
+ class Wav2Vec2Model(nn.Module):
84
+
85
+ def __init__(self):
86
+ super(Wav2Vec2Model, self).__init__()
87
+ self.conv = nn.Conv1d(1, 64, kernel_size=10, stride=5, padding=5)
88
+ self.bn = nn.BatchNorm1d(64)
89
+ self.relu = nn.ReLU()
90
+ self.fc = nn.Linear(64, 2048)
91
+
92
+ def forward(self, audio):
93
+ x = audio.unsqueeze(1)
94
+ x = self.relu(self.bn(self.conv(x)))
95
+ x = torch.mean(x, dim=-1)
96
+ x = self.fc(x)
97
+ return x
98
+
99
+
100
+ class AudioCoeffsPredictor(nn.Module):
101
+
102
+ def __init__(self, input_dim, output_dim):
103
+ super(AudioCoeffsPredictor, self).__init__()
104
+ self.linear = nn.Linear(input_dim, output_dim)
105
+
106
+ def forward(self, audio_embedding):
107
+ return self.linear(audio_embedding)
108
+
109
+
110
+ class MappingNet(nn.Module):
111
+
112
+ def __init__(self, num_coeffs, num_layers, hidden_dim):
113
+ super(MappingNet, self).__init__()
114
+ layers = []
115
+ input_dim = num_coeffs * 2
116
+ for _ in range(num_layers):
117
+ layers.append(nn.Linear(input_dim, hidden_dim))
118
+ layers.append(nn.ReLU())
119
+ input_dim = hidden_dim
120
+ layers.append(nn.Linear(hidden_dim, num_coeffs))
121
+ self.net = nn.Sequential(*layers)
122
+
123
+ def forward(self, x):
124
+ return self.net(x)
125
+
126
+
127
+ class DenseMotionNetwork(nn.Module):
128
+
129
+ def __init__(self, num_kp, num_channels, block_expansion, num_blocks, max_features):
130
+ super(DenseMotionNetwork, self).__init__()
131
+ self.conv1 = nn.Conv2d(num_channels, max_features, kernel_size=3, padding=1)
132
+ self.relu = nn.ReLU()
133
+ self.conv2 = nn.Conv2d(max_features, num_channels, kernel_size=3, padding=1)
134
+
135
+ def forward(self, kp_source, kp_driving, jacobian):
136
+ x = self.relu(self.conv1(kp_source))
137
+ x = self.conv2(x)
138
+ sparse_motion = {'dense_motion': x}
139
+ return sparse_motion
140
+
141
+
142
+ class Hourglass(nn.Module):
143
+
144
+ def __init__(self, block_expansion, num_blocks, max_features, num_channels, kp_size, num_deform_blocks):
145
+ super(Hourglass, self).__init__()
146
+ self.encoder = nn.Sequential(nn.Conv2d(num_channels, max_features, kernel_size=7, stride=2, padding=3),
147
+ nn.BatchNorm2d(max_features), nn.ReLU())
148
+ self.decoder = nn.Sequential(
149
+ nn.ConvTranspose2d(max_features, num_channels, kernel_size=4, stride=2, padding=1), nn.Tanh())
150
+
151
+ def forward(self, source_image, kp_driving, **kwargs):
152
+ x = self.encoder(source_image)
153
+ x = self.decoder(x)
154
+ B, C, H, W = x.size()
155
+ video = []
156
+ for _ in range(10):
157
+ frame = (x[0].cpu().detach().numpy().transpose(1, 2, 0) * 127.5 + 127.5).clip(0, 255).astype(
158
+ np.uint8)
159
+ video.append(frame)
160
+ return video
161
+
162
+
163
+ class Face3DHelper:
164
+
165
+ def __init__(self, local_pca_path, device):
166
+ self.local_pca_path = local_pca_path
167
+ self.device = device
168
+
169
+ def run(self, source_image):
170
+ h, w, _ = source_image.shape
171
+ x_min = w // 4
172
+ y_min = h // 4
173
+ x_max = x_min + w // 2
174
+ y_max = y_min + h // 2
175
+ return [x_min, y_min, x_max, y_max]
176
+
177
+
178
+ class Face3DHelperOld(Face3DHelper):
179
+
180
+ def __init__(self, local_pca_path, device):
181
+ super(Face3DHelperOld, self).__init__(local_pca_path, device)
182
+
183
+
184
+ class MouthDetector:
185
+
186
+ def __init__(self):
187
+ pass
188
+
189
+ def detect(self, image):
190
+ h, w = image.shape[:2]
191
+ return (w // 2, h // 2)
192
+
193
+
194
+ class KeypointNorm(nn.Module):
195
+
196
+ def __init__(self, device):
197
+ super(KeypointNorm, self).__init__()
198
+ self.device = device
199
+
200
+ def forward(self, kp_driving):
201
+ return kp_driving
202
+
203
+
204
+ def save_video_with_watermark(video_frames, audio_path, output_path):
205
+ H, W, _ = video_frames[0].shape
206
+ out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), 25, (W, H))
207
+ for frame in video_frames:
208
+ out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
209
+ out.release()
210
+
211
+
212
+ def paste_pic(video_path, source_image_crop, crop_info, audio_path, output_path):
213
+ shutil.copy(video_path, output_path)
214
+
215
+
216
+ class TTSTalker:
217
+
218
+ def __init__(self):
219
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
220
+ self.tts_model = None
221
+
222
+ def load_model(self):
223
+ self.tts_model = self
224
+
225
+ def tokenizer(self, text):
226
+ return [ord(c) for c in text]
227
+
228
+ def __call__(self, input_tokens):
229
+ return torch.zeros(1, 16000, device=self.device)
230
+
231
+ def test(self, text, lang='en'):
232
+ if self.tts_model is None:
233
+ self.load_model()
234
+ output_path = os.path.join('./results', str(uuid.uuid4()) + '.wav')
235
+ os.makedirs('./results', exist_ok=True)
236
+ tokens = self.tokenizer(text)
237
+ input_tokens = torch.tensor([tokens], dtype=torch.long).to(self.device)
238
+ with torch.no_grad():
239
+ audio_output = self(input_tokens)
240
+ torchaudio.save(output_path, audio_output.cpu(), 16000)
241
+ return output_path
242
+
243
+
244
+ class SadTalker:
245
+
246
+ def __init__(self, checkpoint_path='checkpoints', config_path='src/config', size=256, preprocess='crop',
247
+ old_version=False):
248
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
249
+ self.cfg = self.get_cfg_defaults()
250
+ self.merge_from_file(os.path.join(config_path, 'sadtalker_config.yaml'))
251
+ self.cfg['MODEL']['CHECKPOINTS_DIR'] = checkpoint_path
252
+ self.cfg['MODEL']['CONFIG_DIR'] = config_path
253
+ self.cfg['MODEL']['DEVICE'] = self.device
254
+ self.cfg['INPUT_IMAGE'] = {}
255
+ self.cfg['INPUT_IMAGE']['SOURCE_IMAGE'] = 'None'
256
+ self.cfg['INPUT_IMAGE']['DRIVEN_AUDIO'] = 'None'
257
+ self.cfg['INPUT_IMAGE']['PREPROCESS'] = preprocess
258
+ self.cfg['INPUT_IMAGE']['SIZE'] = size
259
+ self.cfg['INPUT_IMAGE']['OLD_VERSION'] = old_version
260
+
261
+ download_model(kp_url, kp_file, checkpoint_path)
262
+ download_model(aud_url, aud_file, checkpoint_path)
263
+ download_model(wav_url, wav_file, checkpoint_path)
264
+ download_model(gen_url, gen_file, checkpoint_path)
265
+ download_model(mapx_url, mapx_file, checkpoint_path)
266
+ download_model(den_url, den_file, checkpoint_path)
267
+ download_model(GFPGAN_URL, 'GFPGANv1.4.pth', checkpoint_path)
268
+ download_model(REALESRGAN_URL, 'RealESRGAN_x2plus.pth', checkpoint_path)
269
+
270
+ self.sadtalker_model = SadTalkerModel(self.cfg, device_id=[0])
271
+
272
+ def get_cfg_defaults(self):
273
+ return {
274
+ 'MODEL': {
275
+ 'CHECKPOINTS_DIR': '',
276
+ 'CONFIG_DIR': '',
277
+ 'DEVICE': self.device,
278
+ 'SCALE': 64,
279
+ 'NUM_VOXEL_FRAMES': 8,
280
+ 'NUM_MOTION_FRAMES': 10,
281
+ 'MAX_FEATURES': 256,
282
+ 'DRIVEN_AUDIO_SAMPLE_RATE': 16000,
283
+ 'VIDEO_FPS': 25,
284
+ 'OUTPUT_VIDEO_FPS': None,
285
+ 'OUTPUT_AUDIO_SAMPLE_RATE': None,
286
+ 'USE_ENHANCER': False,
287
+ 'ENHANCER_NAME': '',
288
+ 'BG_UPSAMPLER': None,
289
+ 'IS_HALF': False
290
+ },
291
+ 'INPUT_IMAGE': {}
292
+ }
293
+
294
+ def merge_from_file(self, filepath):
295
+ if os.path.exists(filepath):
296
+ with open(filepath, 'r') as f:
297
+ cfg_from_file = yaml.safe_load(f)
298
+ self.cfg.update(cfg_from_file)
299
+
300
+ def test(self, source_image, driven_audio, preprocess='crop', still_mode=False, use_enhancer=False,
301
+ batch_size=1, size=256, pose_style=0, exp_scale=1.0, use_ref_video=False, ref_video=None,
302
+ ref_info=None, use_idle_mode=False, length_of_audio=0, use_blink=True, result_dir='./results/',
303
+ tts_text=None, tts_lang='en'):
304
+ self.sadtalker_model.test(source_image, driven_audio, preprocess, still_mode, use_enhancer, batch_size, size,
305
+ pose_style, exp_scale, use_ref_video, ref_video, ref_info, use_idle_mode,
306
+ length_of_audio, use_blink, result_dir, tts_text, tts_lang)
307
+ return self.sadtalker_model.save_result()
308
+
309
+
310
+ class SadTalkerModel:
311
+
312
+ def __init__(self, sadtalker_cfg, device_id=[0]):
313
+ self.cfg = sadtalker_cfg
314
+ self.device = sadtalker_cfg['MODEL'].get('DEVICE', 'cpu')
315
+ self.sadtalker = SadTalkerInnerModel(sadtalker_cfg, device_id)
316
+ self.preprocesser = self.sadtalker.preprocesser
317
+ self.kp_extractor = self.sadtalker.kp_extractor
318
+ self.generator = self.sadtalker.generator
319
+ self.mapping = self.sadtalker.mapping
320
+ self.he_estimator = self.sadtalker.he_estimator
321
+ self.audio_to_coeff = self.sadtalker.audio_to_coeff
322
+ self.animate_from_coeff = self.sadtalker.animate_from_coeff
323
+ self.face_enhancer = self.sadtalker.face_enhancer
324
+
325
+ def test(self, source_image, driven_audio, preprocess='crop', still_mode=False, use_enhancer=False,
326
+ batch_size=1, size=256, pose_style=0, exp_scale=1.0, use_ref_video=False, ref_video=None,
327
+ ref_info=None, use_idle_mode=False, length_of_audio=0, use_blink=True, result_dir='./results/',
328
+ tts_text=None, tts_lang='en', jitter_amount=10, jitter_source_image=False):
329
+ self.inner_test = SadTalkerInner(self, source_image, driven_audio, preprocess, still_mode, use_enhancer,
330
+ batch_size, size, pose_style, exp_scale, use_ref_video, ref_video, ref_info,
331
+ use_idle_mode, length_of_audio, use_blink, result_dir, tts_text, tts_lang,
332
+ jitter_amount, jitter_source_image)
333
+ return self.inner_test.test()
334
+
335
+ def save_result(self):
336
+ return self.inner_test.save_result()
337
+
338
+
339
+ class SadTalkerInner:
340
+
341
+ def __init__(self, sadtalker_model, source_image, driven_audio, preprocess, still_mode, use_enhancer,
342
+ batch_size, size, pose_style, exp_scale, use_ref_video, ref_video, ref_info, use_idle_mode,
343
+ length_of_audio, use_blink, result_dir, tts_text, tts_lang, jitter_amount, jitter_source_image):
344
+ self.sadtalker_model = sadtalker_model
345
+ self.source_image = source_image
346
+ self.driven_audio = driven_audio
347
+ self.preprocess = preprocess
348
+ self.still_mode = still_mode
349
+ self.use_enhancer = use_enhancer
350
+ self.batch_size = batch_size
351
+ self.size = size
352
+ self.pose_style = pose_style
353
+ self.exp_scale = exp_scale
354
+ self.use_ref_video = use_ref_video
355
+ self.ref_video = ref_video
356
+ self.ref_info = ref_info
357
+ self.use_idle_mode = use_idle_mode
358
+ self.length_of_audio = length_of_audio
359
+ self.use_blink = use_blink
360
+ self.result_dir = result_dir
361
+ self.tts_text = tts_text
362
+ self.tts_lang = tts_lang
363
+ self.jitter_amount = jitter_amount
364
+ self.jitter_source_image = jitter_source_image
365
+ self.device = self.sadtalker_model.device
366
+ self.output_path = None
367
+
368
+ def get_test_data(self):
369
+ proc = self.sadtalker_model.preprocesser
370
+ if self.tts_text is not None:
371
+ temp_dir = tempfile.mkdtemp()
372
+ audio_path = os.path.join(temp_dir, 'audio.wav')
373
+ tts = TTSTalker()
374
+ tts.test(self.tts_text, self.tts_lang)
375
+ self.driven_audio = audio_path
376
+ source_image_pil = Image.open(self.source_image).convert('RGB')
377
+ if self.jitter_source_image:
378
+ jitter_dx = np.random.randint(-self.jitter_amount, self.jitter_amount + 1)
379
+ jitter_dy = np.random.randint(-self.jitter_amount, self.jitter_amount + 1)
380
+ source_image_pil = Image.fromarray(
381
+ np.roll(np.roll(np.array(source_image_pil), jitter_dx, axis=1), jitter_dy, axis=0))
382
+ source_image_tensor, crop_info, cropped_image = proc.crop(source_image_pil, self.preprocess, self.size)
383
+ if self.still_mode or self.use_idle_mode:
384
+ ref_pose_coeff = proc.generate_still_pose(self.pose_style)
385
+ ref_expression_coeff = proc.generate_still_expression(self.exp_scale)
386
+ elif self.use_idle_mode:
387
+ ref_pose_coeff = proc.generate_idles_pose(self.length_of_audio, self.pose_style)
388
+ ref_expression_coeff = proc.generate_idles_expression(self.length_of_audio)
389
+ else:
390
+ ref_pose_coeff = None
391
+ ref_expression_coeff = None
392
+ audio_tensor, audio_sample_rate = proc.process_audio(self.driven_audio,
393
+ self.sadtalker_model.cfg['MODEL']['DRIVEN_AUDIO_SAMPLE_RATE'])
394
+ batch = {
395
+ 'source_image': source_image_tensor.unsqueeze(0).to(self.device),
396
+ 'audio': audio_tensor.unsqueeze(0).to(self.device),
397
+ 'ref_pose_coeff': ref_pose_coeff,
398
+ 'ref_expression_coeff': ref_expression_coeff,
399
+ 'source_image_crop': cropped_image,
400
+ 'crop_info': crop_info,
401
+ 'use_blink': self.use_blink,
402
+ 'pose_style': self.pose_style,
403
+ 'exp_scale': self.exp_scale,
404
+ 'ref_video': self.ref_video,
405
+ 'use_ref_video': self.use_ref_video,
406
+ 'ref_info': self.ref_info,
407
+ }
408
+ return batch, audio_sample_rate
409
+
410
+ def run_inference(self, batch):
411
+ kp_extractor = self.sadtalker_model.kp_extractor
412
+ generator = self.sadtalker_model.generator
413
+ mapping = self.sadtalker_model.mapping
414
+ he_estimator = self.sadtalker_model.he_estimator
415
+ audio_to_coeff = self.sadtalker_model.audio_to_coeff
416
+ animate_from_coeff = self.sadtalker_model.animate_from_coeff
417
+ proc = self.sadtalker_model.preprocesser
418
+ with torch.no_grad():
419
+ kp_source = kp_extractor(batch['source_image'])
420
+ if self.still_mode or self.use_idle_mode:
421
+ ref_pose_coeff = batch['ref_pose_coeff']
422
+ ref_expression_coeff = batch['ref_expression_coeff']
423
+ pose_coeff = audio_to_coeff.get_pose_coeff(batch['audio'], ref_pose_coeff)
424
+ expression_coeff = audio_to_coeff.get_exp_coeff(batch['audio'], ref_expression_coeff)
425
+ elif self.use_idle_mode:
426
+ ref_pose_coeff = batch['ref_pose_coeff']
427
+ ref_expression_coeff = batch['ref_expression_coeff']
428
+ pose_coeff = audio_to_coeff.get_pose_coeff(batch['audio'], ref_pose_coeff)
429
+ expression_coeff = audio_to_coeff.get_exp_coeff(batch['audio'], ref_expression_coeff)
430
+ else:
431
+ if self.use_ref_video:
432
+ kp_ref = kp_extractor(batch['source_image'])
433
+ pose_coeff = audio_to_coeff.get_pose_coeff(batch['audio'], kp_ref=kp_ref,
434
+ use_ref_info=batch['ref_info'])
435
+ else:
436
+ pose_coeff = audio_to_coeff.get_pose_coeff(batch['audio'])
437
+ expression_coeff = audio_to_coeff.get_exp_coeff(batch['audio'])
438
+ coeff = {'pose_coeff': pose_coeff, 'expression_coeff': expression_coeff}
439
+ if self.use_blink:
440
+ coeff['blink_coeff'] = audio_to_coeff.get_blink_coeff(batch['audio'])
441
+ else:
442
+ coeff['blink_coeff'] = None
443
+ kp_driving = audio_to_coeff(batch['audio'])[0]
444
+ kp_norm = animate_from_coeff.normalize_kp(kp_driving)
445
+ coeff['kp_driving'] = kp_norm
446
+ coeff['jacobian'] = [torch.eye(2).unsqueeze(0).unsqueeze(0).to(self.device)] * 4
447
+ face_enhancer = self.sadtalker_model.face_enhancer if self.use_enhancer else None
448
+ output_video = animate_from_coeff.generate(batch['source_image'], kp_source, coeff, generator, mapping,
449
+ he_estimator, batch['audio'], batch['source_image_crop'],
450
+ face_enhancer=face_enhancer)
451
+ return output_video
452
+
+     def post_processing(self, output_video, audio_sample_rate, batch):
+         proc = self.sadtalker_model.preprocesser
+         base_name = os.path.splitext(os.path.basename(batch['source_image_crop']))[0]
+         audio_name = os.path.splitext(os.path.basename(self.driven_audio))[0]
+         output_video_path = os.path.join(self.result_dir, base_name + '_' + audio_name + '.mp4')
+         self.output_path = output_video_path
+         model_cfg = self.sadtalker_model.cfg['MODEL']
+         video_fps = model_cfg['VIDEO_FPS'] if model_cfg['OUTPUT_VIDEO_FPS'] is None else model_cfg['OUTPUT_VIDEO_FPS']
+         audio_output_sample_rate = (model_cfg['DRIVEN_AUDIO_SAMPLE_RATE']
+                                     if model_cfg['OUTPUT_AUDIO_SAMPLE_RATE'] is None
+                                     else model_cfg['OUTPUT_AUDIO_SAMPLE_RATE'])
+         if self.use_enhancer:
+             enhanced_path = os.path.join(self.result_dir, base_name + '_' + audio_name + '_enhanced.mp4')
+             save_video_with_watermark(output_video, self.driven_audio, enhanced_path)
+             paste_pic(enhanced_path, batch['source_image_crop'], batch['crop_info'], self.driven_audio,
+                       output_video_path)
+             os.remove(enhanced_path)
+         else:
+             save_video_with_watermark(output_video, self.driven_audio, output_video_path)
+         if self.tts_text is not None:
+             shutil.rmtree(os.path.dirname(self.driven_audio))
+
+     def save_result(self):
+         return self.output_path
+
+     def __call__(self):
+         return self.output_path
+
+     def test(self):
+         batch, audio_sample_rate = self.get_test_data()
+         output_video = self.run_inference(batch)
+         self.post_processing(output_video, audio_sample_rate, batch)
+         return self.save_result()
+
+
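+ # A minimal usage sketch of the pipeline class above (illustrative only; the
+ # class and its constructor arguments are defined earlier in this file):
+ #
+ #     pipeline = ...                 # an instance of the pipeline class above
+ #     result_path = pipeline.test()  # get_test_data -> run_inference -> post_processing
+ #     print(result_path)             # path of the rendered .mp4
+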
+ class SadTalkerInnerModel:
+
+     def __init__(self, sadtalker_cfg, device_id=[0]):
+         self.cfg = sadtalker_cfg
+         self.device = sadtalker_cfg['MODEL'].get('DEVICE', 'cpu')
+         self.preprocesser = Preprocesser(sadtalker_cfg, self.device)
+         self.kp_extractor = KeyPointExtractor(sadtalker_cfg, self.device)
+         self.audio_to_coeff = Audio2Coeff(sadtalker_cfg, self.device)
+         self.animate_from_coeff = AnimateFromCoeff(sadtalker_cfg, self.device)
+         self.face_enhancer = FaceEnhancer(sadtalker_cfg, self.device) if sadtalker_cfg['MODEL']['USE_ENHANCER'] else None
+         self.generator = Generator(sadtalker_cfg, self.device)
+         self.mapping = Mapping(sadtalker_cfg, self.device)
+         self.he_estimator = OcclusionAwareDenseMotion(sadtalker_cfg, self.device)
+
+
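+ # Illustrative configuration covering the keys read in this module; the values
+ # shown are placeholders, not the project's shipped defaults:
+ #
+ #     MODEL:
+ #       DEVICE: cpu
+ #       CHECKPOINTS_DIR: ./checkpoints
+ #       USE_ENHANCER: false
+ #       ENHANCER_NAME: gfpgan          # or 'realesrgan'
+ #       BG_UPSAMPLER: null
+ #       IS_HALF: false
+ #       VIDEO_FPS: 25
+ #       OUTPUT_VIDEO_FPS: null
+ #       DRIVEN_AUDIO_SAMPLE_RATE: 16000
+ #       OUTPUT_AUDIO_SAMPLE_RATE: null
+ #       SCALE: 64
+ #       NUM_VOXEL_FRAMES: 5
+ #       NUM_MOTION_FRAMES: 10
+ #       MAX_FEATURES: 512
+ #     INPUT_IMAGE:
+ #       OLD_VERSION: false
+ #       LOCAL_PCA_PATH: ''
+ #       SOURCE_IMAGE: ''
+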
+ class Preprocesser:
+
+     def __init__(self, sadtalker_cfg, device):
+         self.cfg = sadtalker_cfg
+         self.device = device
+         if self.cfg['INPUT_IMAGE'].get('OLD_VERSION', False):
+             self.face3d_helper = Face3DHelperOld(self.cfg['INPUT_IMAGE'].get('LOCAL_PCA_PATH', ''), device)
+         else:
+             self.face3d_helper = Face3DHelper(self.cfg['INPUT_IMAGE'].get('LOCAL_PCA_PATH', ''), device)
+         self.mouth_detector = MouthDetector()
+
+     def crop(self, source_image_pil, preprocess_type, size=256):
+         source_image = np.array(source_image_pil)
+         face_info = self.face3d_helper.run(source_image)
+         if face_info is None:
+             raise Exception("No face detected")
+         x_min, y_min, x_max, y_max = face_info[:4]
+         old_size = (x_max - x_min, y_max - y_min)
+         x_center = (x_max + x_min) / 2
+         y_center = (y_max + y_min) / 2
+         if preprocess_type == 'crop':
+             face_size = max(x_max - x_min, y_max - y_min)
+             x_min = int(x_center - face_size / 2)
+             y_min = int(y_center - face_size / 2)
+             x_max = int(x_center + face_size / 2)
+             y_max = int(y_center + face_size / 2)
+         else:
+             # Pad the detected box by 10% of its original width/height on each side.
+             pad_x = int((x_max - x_min) * 0.1)
+             pad_y = int((y_max - y_min) * 0.1)
+             x_min = int(x_min) - pad_x
+             y_min = int(y_min) - pad_y
+             x_max = int(x_max) + pad_x
+             y_max = int(y_max) + pad_y
+         h, w = source_image.shape[:2]
+         x_min = max(0, x_min)
+         y_min = max(0, y_min)
+         x_max = min(w, x_max)
+         y_max = min(h, y_max)
+         cropped_image = source_image[y_min:y_max, x_min:x_max]
+         cropped_image_pil = Image.fromarray(cropped_image)
+         if size is not None and size != 0:
+             cropped_image_pil = cropped_image_pil.resize((size, size), Image.Resampling.LANCZOS)
+         source_image_tensor = self.img2tensor(cropped_image_pil)
+         return source_image_tensor, [[y_min, y_max], [x_min, x_max], old_size, cropped_image_pil.size], os.path.basename(
+             self.cfg['INPUT_IMAGE'].get('SOURCE_IMAGE', ''))
+
+     def img2tensor(self, img):
+         img = np.array(img).astype(np.float32) / 255.0
+         img = np.transpose(img, (2, 0, 1))
+         return torch.FloatTensor(img)
+
+     def video_to_tensor(self, video, device):
+         video_tensor_list = []
+         import torchvision.transforms as transforms
+         transform_func = transforms.ToTensor()
+         for frame in video:
+             frame_pil = Image.fromarray(frame)
+             frame_tensor = transform_func(frame_pil).unsqueeze(0).to(device)
+             video_tensor_list.append(frame_tensor)
+         video_tensor = torch.cat(video_tensor_list, dim=0)
+         return video_tensor
+
+     def process_audio(self, audio_path, sample_rate):
+         wav = load_wav_util(audio_path, sample_rate)
+         wav_tensor = torch.FloatTensor(wav).unsqueeze(0)
+         return wav_tensor, sample_rate
+
+     def generate_still_pose(self, pose_style):
+         ref_pose_coeff = torch.zeros((1, 64), dtype=torch.float32).to(self.device)
+         ref_pose_coeff[:, :3] = torch.tensor([0, 0, pose_style * 0.3], dtype=torch.float32)
+         return ref_pose_coeff
+
+     def generate_still_expression(self, exp_scale):
+         ref_expression_coeff = torch.zeros((1, 64), dtype=torch.float32).to(self.device)
+         ref_expression_coeff[:, :3] = torch.tensor([0, 0, exp_scale * 0.3], dtype=torch.float32)
+         return ref_expression_coeff
+
+     def generate_idles_pose(self, length_of_audio, pose_style):
+         num_frames = int(length_of_audio * self.cfg['MODEL']['VIDEO_FPS'])
+         ref_pose_coeff = torch.zeros((num_frames, 64), dtype=torch.float32).to(self.device)
+         # Squeeze to shape (64,) so a single frame row can be assigned directly.
+         start_pose = self.generate_still_pose(pose_style).squeeze(0)
+         end_pose = self.generate_still_pose(pose_style).squeeze(0)
+         for frame_idx in range(num_frames):
+             alpha = frame_idx / num_frames
+             ref_pose_coeff[frame_idx] = (1 - alpha) * start_pose + alpha * end_pose
+         return ref_pose_coeff
+
+     def generate_idles_expression(self, length_of_audio):
+         num_frames = int(length_of_audio * self.cfg['MODEL']['VIDEO_FPS'])
+         ref_expression_coeff = torch.zeros((num_frames, 64), dtype=torch.float32).to(self.device)
+         start_exp = self.generate_still_expression(1.0).squeeze(0)
+         end_exp = self.generate_still_expression(1.0).squeeze(0)
+         for frame_idx in range(num_frames):
+             alpha = frame_idx / num_frames
+             ref_expression_coeff[frame_idx] = (1 - alpha) * start_exp + alpha * end_exp
+         return ref_expression_coeff
+
+
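+ # Preprocesser.crop() returns (image_tensor, crop_info, source_basename); the
+ # crop_info list ([[y_min, y_max], [x_min, x_max], old_size, cropped_size]) is
+ # presumably what reaches paste_pic() as batch['crop_info'] during post-processing.
+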
+ class KeyPointExtractor(nn.Module):
+
+     def __init__(self, sadtalker_cfg, device):
+         super(KeyPointExtractor, self).__init__()
+         self.kp_extractor = OcclusionAwareKPDetector(kp_channels=sadtalker_cfg['MODEL']['NUM_MOTION_FRAMES'],
+                                                      num_kp=10,
+                                                      num_dilation_blocks=2,
+                                                      dropout_rate=0.1).to(device)
+         checkpoint_path = os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'], 'kp_detector.safetensors')
+         self.load_kp_detector(checkpoint_path, device)
+
+     def load_kp_detector(self, checkpoint_path, device):
+         if os.path.exists(checkpoint_path):
+             if checkpoint_path.endswith('safetensors'):
+                 checkpoint = safetensors.torch.load_file(checkpoint_path, device=device)
+             else:
+                 checkpoint = torch.load(checkpoint_path, map_location=device)
+             try:
+                 self.kp_extractor.load_state_dict(checkpoint.get('kp_detector', {}))
+             except RuntimeError as e:
+                 print(f"Error loading kp_detector state_dict: {e}")
+                 print("Trying to load state_dict without prefix 'kp_detector.'")
+                 self.kp_extractor.load_state_dict(checkpoint, strict=False)
+         else:
+             raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}")
+
+     def forward(self, x):
+         kp = self.kp_extractor(x)
+         return kp
+
+
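+ # NOTE: a .safetensors checkpoint is a flat name->tensor dict, so the
+ # checkpoint.get('kp_detector', {}) lookup above will typically come back empty
+ # and loading falls through to the strict=False path.
+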
+ class Audio2Coeff(nn.Module):
+
+     def __init__(self, sadtalker_cfg, device):
+         super(Audio2Coeff, self).__init__()
+         self.audio_model = Wav2Vec2Model().to(device)
+         checkpoint_path = os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'], 'wav2vec2.pth')
+         self.load_audio_model(checkpoint_path, device)
+         self.pose_mapper = AudioCoeffsPredictor(2048, 64).to(device)
+         self.exp_mapper = AudioCoeffsPredictor(2048, 64).to(device)
+         self.blink_mapper = AudioCoeffsPredictor(2048, 1).to(device)
+         # The filename matches the downloaded checkpoint, which carries the
+         # upstream "auido2pose" spelling.
+         mapping_checkpoint = os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'], 'auido2pose_00140-model.pth')
+         self.load_mapping_model(mapping_checkpoint, device)
+
+     def load_audio_model(self, checkpoint_path, device):
+         if os.path.exists(checkpoint_path):
+             if checkpoint_path.endswith('safetensors'):
+                 checkpoint = safetensors.torch.load_file(checkpoint_path, device=device)
+             else:
+                 checkpoint = torch.load(checkpoint_path, map_location=device)
+             self.audio_model.load_state_dict(checkpoint.get("wav2vec2", {}))
+         else:
+             raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}")
+
+     def load_mapping_model(self, checkpoint_path, device):
+         if os.path.exists(checkpoint_path):
+             if checkpoint_path.endswith('safetensors'):
+                 checkpoint = safetensors.torch.load_file(checkpoint_path, device=device)
+             else:
+                 checkpoint = torch.load(checkpoint_path, map_location=device)
+             self.pose_mapper.load_state_dict(checkpoint.get("pose_predictor", {}))
+             self.exp_mapper.load_state_dict(checkpoint.get("exp_predictor", {}))
+             self.blink_mapper.load_state_dict(checkpoint.get("blink_predictor", {}))
+         else:
+             raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}")
+
+     def get_pose_coeff(self, audio_tensor, ref_pose_coeff=None, kp_ref=None, use_ref_info=''):
+         audio_embedding = self.audio_model(audio_tensor)
+         pose_coeff = self.pose_mapper(audio_embedding)
+         if ref_pose_coeff is not None:
+             pose_coeff = ref_pose_coeff
+         if kp_ref is not None and use_ref_info == 'pose':
+             ref_pose_6d = kp_ref['value'][:, :6]
+             pose_coeff[:, :6] = self.mean_std_normalize(ref_pose_6d).mean(dim=1)
+         return pose_coeff
+
+     def get_exp_coeff(self, audio_tensor, ref_expression_coeff=None):
+         audio_embedding = self.audio_model(audio_tensor)
+         expression_coeff = self.exp_mapper(audio_embedding)
+         if ref_expression_coeff is not None:
+             expression_coeff = ref_expression_coeff
+         return expression_coeff
+
+     def get_blink_coeff(self, audio_tensor):
+         audio_embedding = self.audio_model(audio_tensor)
+         blink_coeff = self.blink_mapper(audio_embedding)
+         return blink_coeff
+
+     def forward(self, audio):
+         audio_embedding = self.audio_model(audio)
+         pose_coeff = self.pose_mapper(audio_embedding)
+         expression_coeff = self.exp_mapper(audio_embedding)
+         blink_coeff = self.blink_mapper(audio_embedding)
+         return pose_coeff, expression_coeff, blink_coeff
+
+     def mean_std_normalize(self, coeff):
+         mean = coeff.mean(dim=1, keepdim=True)
+         std = coeff.std(dim=1, keepdim=True)
+         return (coeff - mean) / std
+
+
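+ # Per the AudioCoeffsPredictor heads above, Audio2Coeff emits a 64-d pose vector,
+ # a 64-d expression vector and a 1-d blink value for each audio embedding;
+ # Mapping below concatenates expression and pose (a 128-d input) and, given the
+ # 64-d f_3d_mean offset it adds, maps them back to 64 face coefficients.
+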
+ class AnimateFromCoeff(nn.Module):
+
+     def __init__(self, sadtalker_cfg, device):
+         super(AnimateFromCoeff, self).__init__()
+         self.generator = Generator(sadtalker_cfg, device)
+         self.mapping = Mapping(sadtalker_cfg, device)
+         self.kp_norm = KeypointNorm(device=device)
+         self.he_estimator = OcclusionAwareDenseMotion(sadtalker_cfg, device)
+
+     def normalize_kp(self, kp_driving):
+         return self.kp_norm(kp_driving)
+
+     def generate(self, source_image, kp_source, coeff, generator, mapping, he_estimator, audio, source_image_crop,
+                  face_enhancer=None):
+         kp_driving = coeff['kp_driving']
+         jacobian = coeff['jacobian']
+         pose_coeff = coeff['pose_coeff']
+         expression_coeff = coeff['expression_coeff']
+         blink_coeff = coeff['blink_coeff']
+         with torch.no_grad():
+             if blink_coeff is not None:
+                 sparse_motion = he_estimator(kp_source, kp_driving, jacobian)
+                 dense_motion = sparse_motion['dense_motion']
+                 video_deocclusion = generator(source_image, dense_motion, bg_param={'mask': None, 'color': None})
+                 face_3d = mapping(expression_coeff, pose_coeff, blink_coeff)
+                 video_3d = generator(source_image, dense_motion, bg_param={'mask': None, 'color': None},
+                                      face_3d_param=face_3d)
+                 video_output = video_deocclusion['video_no_reocclusion'] + video_3d['video_3d']
+                 video_output = self.make_animation(video_output)
+             else:
+                 sparse_motion = he_estimator(kp_source, kp_driving, jacobian)
+                 dense_motion = sparse_motion['dense_motion']
+                 face_3d = mapping(expression_coeff, pose_coeff)
+                 video_3d = generator(source_image, dense_motion, bg_param={'mask': None, 'color': None},
+                                      face_3d_param=face_3d)
+                 video_output = video_3d['video_3d']
+                 video_output = self.make_animation(video_output)
+             if face_enhancer is not None:
+                 video_output_enhanced = []
+                 for frame in tqdm(video_output, 'Face enhancer running'):
+                     pil_image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
+                     # FaceEnhancer is an nn.Module wrapper; calling it runs forward(),
+                     # which already returns the enhanced image.
+                     enhanced_image = face_enhancer(np.array(pil_image))
+                     video_output_enhanced.append(cv2.cvtColor(enhanced_image, cv2.COLOR_BGR2RGB))
+                 video_output = video_output_enhanced
+         return video_output
+
+     def make_animation(self, video_array):
+         H, W, _ = video_array[0].shape
+         out = cv2.VideoWriter('./tmp.mp4', cv2.VideoWriter_fourcc(*'mp4v'), 25, (W, H))
+         for img in video_array:
+             out.write(cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
+         out.release()
+         video = imageio.mimread('./tmp.mp4')
+         os.remove('./tmp.mp4')
+         return video
+
+
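+ # make_animation() round-trips frames through an MP4 file on disk; reading it
+ # back with imageio requires an ffmpeg-capable backend (e.g. the imageio-ffmpeg
+ # plugin). A sketch of the same round-trip with a collision-safe temporary path,
+ # using the uuid and tempfile modules already imported at the top of this file:
+ #
+ #     tmp_path = os.path.join(tempfile.gettempdir(), uuid.uuid4().hex + '.mp4')
+ #     out = cv2.VideoWriter(tmp_path, cv2.VideoWriter_fourcc(*'mp4v'), 25, (W, H))
+ #     ...                                   # write frames as in make_animation()
+ #     frames = imageio.mimread(tmp_path, memtest=False)
+ #     os.remove(tmp_path)
+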
+ class Generator(nn.Module):
+
+     def __init__(self, sadtalker_cfg, device):
+         super(Generator, self).__init__()
+         self.generator = Hourglass(block_expansion=sadtalker_cfg['MODEL']['SCALE'],
+                                    num_blocks=sadtalker_cfg['MODEL']['NUM_VOXEL_FRAMES'],
+                                    max_features=sadtalker_cfg['MODEL']['MAX_FEATURES'],
+                                    num_channels=3,
+                                    kp_size=10,
+                                    num_deform_blocks=sadtalker_cfg['MODEL']['NUM_MOTION_FRAMES']).to(device)
+         checkpoint_path = os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'], 'generator.pth')
+         self.load_generator(checkpoint_path, device)
+
+     def load_generator(self, checkpoint_path, device):
+         if os.path.exists(checkpoint_path):
+             if checkpoint_path.endswith('safetensors'):
+                 checkpoint = safetensors.torch.load_file(checkpoint_path, device=device)
+             else:
+                 checkpoint = torch.load(checkpoint_path, map_location=device)
+             self.generator.load_state_dict(checkpoint.get('generator', {}))
+         else:
+             raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}")
+
+     def forward(self, source_image, dense_motion, bg_param, face_3d_param=None):
+         if face_3d_param is not None:
+             video_3d = self.generator(source_image, kp_driving=dense_motion, bg_param=bg_param,
+                                       face_3d_param=face_3d_param)
+         else:
+             video_3d = self.generator(source_image, kp_driving=dense_motion, bg_param=bg_param)
+         return {'video_3d': video_3d, 'video_no_reocclusion': video_3d}
+
+
+ class Mapping(nn.Module):
+
+     def __init__(self, sadtalker_cfg, device):
+         super(Mapping, self).__init__()
+         self.mapping_net = MappingNet(num_coeffs=64, num_layers=3, hidden_dim=128).to(device)
+         checkpoint_path = os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'], 'mapping.pth')
+         self.load_mapping_net(checkpoint_path, device)
+         self.f_3d_mean = torch.zeros(1, 64, device=device)
+
+     def load_mapping_net(self, checkpoint_path, device):
+         if os.path.exists(checkpoint_path):
+             if checkpoint_path.endswith('safetensors'):
+                 checkpoint = safetensors.torch.load_file(checkpoint_path, device=device)
+             else:
+                 checkpoint = torch.load(checkpoint_path, map_location=device)
+             self.mapping_net.load_state_dict(checkpoint.get('mapping', {}))
+         else:
+             raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}")
+
+     def forward(self, expression_coeff, pose_coeff, blink_coeff=None):
+         coeff = torch.cat([expression_coeff, pose_coeff], dim=1)
+         face_3d = self.mapping_net(coeff) + self.f_3d_mean
+         if blink_coeff is not None:
+             face_3d[:, -1:] = blink_coeff
+         return face_3d
+
+
+ class OcclusionAwareDenseMotion(nn.Module):
+
+     def __init__(self, sadtalker_cfg, device):
+         super(OcclusionAwareDenseMotion, self).__init__()
+         self.dense_motion_network = DenseMotionNetwork(num_kp=10,
+                                                        num_channels=3,
+                                                        block_expansion=sadtalker_cfg['MODEL']['SCALE'],
+                                                        num_blocks=sadtalker_cfg['MODEL']['NUM_MOTION_FRAMES'] - 1,
+                                                        max_features=sadtalker_cfg['MODEL']['MAX_FEATURES']).to(device)
+         checkpoint_path = os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'], 'dense_motion.pth')
+         self.load_dense_motion_network(checkpoint_path, device)
+
+     def load_dense_motion_network(self, checkpoint_path, device):
+         if os.path.exists(checkpoint_path):
+             if checkpoint_path.endswith('safetensors'):
+                 checkpoint = safetensors.torch.load_file(checkpoint_path, device=device)
+             else:
+                 checkpoint = torch.load(checkpoint_path, map_location=device)
+             self.dense_motion_network.load_state_dict(checkpoint.get('dense_motion', {}))
+         else:
+             raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}")
+
+     def forward(self, kp_source, kp_driving, jacobian):
+         sparse_motion = self.dense_motion_network(kp_source, kp_driving, jacobian)
+         return sparse_motion
+
+
+ class FaceEnhancer(nn.Module):
+
+     def __init__(self, sadtalker_cfg, device):
+         super(FaceEnhancer, self).__init__()
+         enhancer_name = sadtalker_cfg['MODEL']['ENHANCER_NAME']
+         bg_upsampler = sadtalker_cfg['MODEL']['BG_UPSAMPLER']
+         if enhancer_name == 'gfpgan':
+             from gfpgan import GFPGANer
+             self.face_enhancer = GFPGANer(model_path=os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'],
+                                                                   'GFPGANv1.4.pth'),
+                                           upscale=1,
+                                           arch='clean',
+                                           channel_multiplier=2,
+                                           bg_upsampler=bg_upsampler)
+         elif enhancer_name == 'realesrgan':
+             from realesrgan import RealESRGANer
+             half = False if device == 'cpu' else sadtalker_cfg['MODEL']['IS_HALF']
+             self.face_enhancer = RealESRGANer(scale=2,
+                                               model_path=os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'],
+                                                                       'RealESRGAN_x2plus.pth'),
+                                               tile=0,
+                                               tile_pad=10,
+                                               pre_pad=0,
+                                               half=half,
+                                               device=device)
+         else:
+             self.face_enhancer = None
+
+     def forward(self, x):
+         if self.face_enhancer is None:
+             return x
+         return self.face_enhancer.enhance(x, outscale=1)[0]
+
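+ # NOTE: GFPGANer.enhance() and RealESRGANer.enhance() appear to expose different
+ # interfaces (GFPGAN's variant takes no `outscale` argument and returns a
+ # (cropped_faces, restored_faces, restored_img) tuple, while Real-ESRGAN returns
+ # (output, img_mode)), so the shared call in forward() above matches the
+ # Real-ESRGAN path; the GFPGAN branch would need its own call signature.
+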
+ def load_models():
+     checkpoint_path = './checkpoints'
+     config_path = './src/config'
+     size = 256
+     preprocess = 'crop'
+     old_version = False
+     sadtalker_instance = SadTalker(checkpoint_path, config_path, size, preprocess, old_version)
+     print("SadTalker models loaded successfully!")
+     return sadtalker_instance
+
+
+ if __name__ == '__main__':
+     sadtalker_instance = load_models()