Update sadtalker_utils.py
sadtalker_utils.py  CHANGED  (+887 -867)
@@ -1,867 +1,887 @@
Removed (previous revision, 867 lines): essentially identical to the new revision up through `KeyPointExtractor.load_kp_detector`; from there the old file trailed off into blank lines and stray fragments (`return`, `def`, `dense_motion =`, `self.`) where the remaining method and class bodies should have been.
1 |
+
import os
|
2 |
+
import shutil
|
3 |
+
import uuid
|
4 |
+
import cv2
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
import torch.nn.functional as F
|
9 |
+
import yaml
|
10 |
+
from PIL import Image
|
11 |
+
from skimage import img_as_ubyte, transform
|
12 |
+
import safetensors
|
13 |
+
import librosa
|
14 |
+
from pydub import AudioSegment
|
15 |
+
import imageio
|
16 |
+
from scipy import signal
|
17 |
+
from scipy.io import loadmat, savemat, wavfile
|
18 |
+
import glob
|
19 |
+
import tempfile
|
20 |
+
from tqdm import tqdm
|
21 |
+
import math
|
22 |
+
import torchaudio
|
23 |
+
import urllib.request
|
24 |
+
from safetensors.torch import load_file, save_file
|
25 |
+
|
26 |
+
REALESRGAN_URL = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth"
|
27 |
+
CODEFORMER_URL = "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth"
|
28 |
+
RESTOREFORMER_URL = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/RestoreFormer.pth"
|
29 |
+
GFPGAN_URL = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth"
|
30 |
+
kp_url = "https://huggingface.co/usyd-community/vitpose-base-simple/resolve/main/model.safetensors"
|
31 |
+
kp_file = "kp_detector.safetensors"
|
32 |
+
aud_url = "https://huggingface.co/vinthony/SadTalker/resolve/main/auido2pose_00140-model.pth"
|
33 |
+
aud_file = "auido2pose_00140-model.pth"
|
34 |
+
wav_url = "https://huggingface.co/facebook/wav2vec2-base/resolve/main/pytorch_model.bin"
|
35 |
+
wav_file = "wav2vec2.pth"
|
36 |
+
gen_url = "https://huggingface.co/vinthony/SadTalker/resolve/main/wav2lip.pth"
|
37 |
+
gen_file = "generator.pth"
|
38 |
+
mapx_url = "https://huggingface.co/vinthony/SadTalker/resolve/main/mapping_00229-model.pth.tar"
|
39 |
+
mapx_file = "mapping.pth"
|
40 |
+
den_url = "https://huggingface.co/KwaiVGI/LivePortrait/resolve/main/liveportrait/base_models/motion_extractor.pth"
|
41 |
+
den_file = "dense_motion.pth"
|
42 |
+
|
43 |
+
|
44 |
+
def download_model(url, filename, checkpoint_dir):
|
45 |
+
if not os.path.exists(os.path.join(checkpoint_dir, filename)):
|
46 |
+
print(f"Downloading {filename}...")
|
47 |
+
os.makedirs(checkpoint_dir, exist_ok=True)
|
48 |
+
urllib.request.urlretrieve(url, os.path.join(checkpoint_dir, filename))
|
49 |
+
print(f"{filename} downloaded.")
|
50 |
+
else:
|
51 |
+
print(f"{filename} already exists.")
|
52 |
+
|
53 |
+
|
54 |
+
def mp3_to_wav_util(mp3_filename, wav_filename, frame_rate):
|
55 |
+
AudioSegment.from_file(mp3_filename).set_frame_rate(frame_rate).export(wav_filename, format="wav")
|
56 |
+
|
57 |
+
|
58 |
+
def load_wav_util(path, sr):
|
59 |
+
return librosa.core.load(path, sr=sr)[0]
|
60 |
+
|
61 |
+
|
62 |
+
def save_wav_util(wav, path, sr):
|
63 |
+
wav *= 32767 / max(0.01, np.max(np.abs(wav)))
|
64 |
+
wavfile.write(path, sr, wav.astype(np.int16))
|
65 |
+
|
66 |
+
|
67 |
+
class OcclusionAwareKPDetector(nn.Module):
|
68 |
+
|
69 |
+
def __init__(self, kp_channels, num_kp, num_dilation_blocks, dropout_rate):
|
70 |
+
super(OcclusionAwareKPDetector, self).__init__()
|
71 |
+
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
|
72 |
+
self.bn1 = nn.BatchNorm2d(64)
|
73 |
+
self.relu = nn.ReLU()
|
74 |
+
self.conv2 = nn.Conv2d(64, num_kp, kernel_size=3, padding=1)
|
75 |
+
|
76 |
+
def forward(self, x):
|
77 |
+
x = self.relu(self.bn1(self.conv1(x)))
|
78 |
+
x = self.conv2(x)
|
79 |
+
kp = {'value': x.view(x.size(0), -1)}
|
80 |
+
return kp
|
81 |
+
|
82 |
+
|
83 |
+
class Wav2Vec2Model(nn.Module):
|
84 |
+
|
85 |
+
def __init__(self):
|
86 |
+
super(Wav2Vec2Model, self).__init__()
|
87 |
+
self.conv = nn.Conv1d(1, 64, kernel_size=10, stride=5, padding=5)
|
88 |
+
self.bn = nn.BatchNorm1d(64)
|
89 |
+
self.relu = nn.ReLU()
|
90 |
+
self.fc = nn.Linear(64, 2048)
|
91 |
+
|
92 |
+
def forward(self, audio):
|
93 |
+
x = audio.unsqueeze(1)
|
94 |
+
x = self.relu(self.bn(self.conv(x)))
|
95 |
+
x = torch.mean(x, dim=-1)
|
96 |
+
x = self.fc(x)
|
97 |
+
return x
|
98 |
+
|
99 |
+
|
100 |
+
class AudioCoeffsPredictor(nn.Module):
|
101 |
+
|
102 |
+
def __init__(self, input_dim, output_dim):
|
103 |
+
super(AudioCoeffsPredictor, self).__init__()
|
104 |
+
self.linear = nn.Linear(input_dim, output_dim)
|
105 |
+
|
106 |
+
def forward(self, audio_embedding):
|
107 |
+
return self.linear(audio_embedding)
|
108 |
+
|
109 |
+
|
110 |
+
class MappingNet(nn.Module):
|
111 |
+
|
112 |
+
def __init__(self, num_coeffs, num_layers, hidden_dim):
|
113 |
+
super(MappingNet, self).__init__()
|
114 |
+
layers = []
|
115 |
+
input_dim = num_coeffs * 2
|
116 |
+
for _ in range(num_layers):
|
117 |
+
layers.append(nn.Linear(input_dim, hidden_dim))
|
118 |
+
layers.append(nn.ReLU())
|
119 |
+
input_dim = hidden_dim
|
120 |
+
layers.append(nn.Linear(hidden_dim, num_coeffs))
|
121 |
+
self.net = nn.Sequential(*layers)
|
122 |
+
|
123 |
+
def forward(self, x):
|
124 |
+
return self.net(x)
|
125 |
+
|
126 |
+
|
127 |
+
class DenseMotionNetwork(nn.Module):
|
128 |
+
|
129 |
+
def __init__(self, num_kp, num_channels, block_expansion, num_blocks, max_features):
|
130 |
+
super(DenseMotionNetwork, self).__init__()
|
131 |
+
self.conv1 = nn.Conv2d(num_channels, max_features, kernel_size=3, padding=1)
|
132 |
+
self.relu = nn.ReLU()
|
133 |
+
self.conv2 = nn.Conv2d(max_features, num_channels, kernel_size=3, padding=1)
|
134 |
+
|
135 |
+
def forward(self, kp_source, kp_driving, jacobian):
|
136 |
+
x = self.relu(self.conv1(kp_source))
|
137 |
+
x = self.conv2(x)
|
138 |
+
sparse_motion = {'dense_motion': x}
|
139 |
+
return sparse_motion
|
140 |
+
|
141 |
+
|
142 |
+
class Hourglass(nn.Module):
|
143 |
+
|
144 |
+
def __init__(self, block_expansion, num_blocks, max_features, num_channels, kp_size, num_deform_blocks):
|
145 |
+
super(Hourglass, self).__init__()
|
146 |
+
self.encoder = nn.Sequential(nn.Conv2d(num_channels, max_features, kernel_size=7, stride=2, padding=3),
|
147 |
+
nn.BatchNorm2d(max_features), nn.ReLU())
|
148 |
+
self.decoder = nn.Sequential(
|
149 |
+
nn.ConvTranspose2d(max_features, num_channels, kernel_size=4, stride=2, padding=1), nn.Tanh())
|
150 |
+
|
151 |
+
def forward(self, source_image, kp_driving, **kwargs):
|
152 |
+
x = self.encoder(source_image)
|
153 |
+
x = self.decoder(x)
|
154 |
+
B, C, H, W = x.size()
|
155 |
+
video = []
|
156 |
+
for _ in range(10):
|
157 |
+
frame = (x[0].cpu().detach().numpy().transpose(1, 2, 0) * 127.5 + 127.5).clip(0, 255).astype(
|
158 |
+
np.uint8)
|
159 |
+
video.append(frame)
|
160 |
+
return video
|
161 |
+
|
162 |
+
|
163 |
+
class Face3DHelper:
|
164 |
+
|
165 |
+
def __init__(self, local_pca_path, device):
|
166 |
+
self.local_pca_path = local_pca_path
|
167 |
+
self.device = device
|
168 |
+
|
169 |
+
def run(self, source_image):
|
170 |
+
h, w, _ = source_image.shape
|
171 |
+
x_min = w // 4
|
172 |
+
y_min = h // 4
|
173 |
+
x_max = x_min + w // 2
|
174 |
+
y_max = y_min + h // 2
|
175 |
+
return [x_min, y_min, x_max, y_max]
|
176 |
+
|
177 |
+
|
178 |
+
class Face3DHelperOld(Face3DHelper):
|
179 |
+
|
180 |
+
def __init__(self, local_pca_path, device):
|
181 |
+
super(Face3DHelperOld, self).__init__(local_pca_path, device)
|
182 |
+
|
183 |
+
|
184 |
+
class MouthDetector:
|
185 |
+
|
186 |
+
def __init__(self):
|
187 |
+
pass
|
188 |
+
|
189 |
+
def detect(self, image):
|
190 |
+
h, w = image.shape[:2]
|
191 |
+
return (w // 2, h // 2)
|
192 |
+
|
193 |
+
|
194 |
+
class KeypointNorm(nn.Module):
|
195 |
+
|
196 |
+
def __init__(self, device):
|
197 |
+
super(KeypointNorm, self).__init__()
|
198 |
+
self.device = device
|
199 |
+
|
200 |
+
def forward(self, kp_driving):
|
201 |
+
return kp_driving
|
202 |
+
|
203 |
+
|
204 |
+
def save_video_with_watermark(video_frames, audio_path, output_path):
|
205 |
+
H, W, _ = video_frames[0].shape
|
206 |
+
out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), 25, (W, H))
|
207 |
+
for frame in video_frames:
|
208 |
+
out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
|
209 |
+
out.release()
|
210 |
+
|
211 |
+
|
212 |
+
def paste_pic(video_path, source_image_crop, crop_info, audio_path, output_path):
|
213 |
+
shutil.copy(video_path, output_path)
|
214 |
+
|
215 |
+
|
216 |
+
class TTSTalker:
|
217 |
+
|
218 |
+
def __init__(self):
|
219 |
+
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
220 |
+
self.tts_model = None
|
221 |
+
|
222 |
+
def load_model(self):
|
223 |
+
self.tts_model = self
|
224 |
+
|
225 |
+
def tokenizer(self, text):
|
226 |
+
return [ord(c) for c in text]
|
227 |
+
|
228 |
+
def __call__(self, input_tokens):
|
229 |
+
return torch.zeros(1, 16000, device=self.device)
|
230 |
+
|
231 |
+
def test(self, text, lang='en'):
|
232 |
+
if self.tts_model is None:
|
233 |
+
self.load_model()
|
234 |
+
output_path = os.path.join('./results', str(uuid.uuid4()) + '.wav')
|
235 |
+
os.makedirs('./results', exist_ok=True)
|
236 |
+
tokens = self.tokenizer(text)
|
237 |
+
input_tokens = torch.tensor([tokens], dtype=torch.long).to(self.device)
|
238 |
+
with torch.no_grad():
|
239 |
+
audio_output = self(input_tokens)
|
240 |
+
torchaudio.save(output_path, audio_output.cpu(), 16000)
|
241 |
+
return output_path
|
242 |
+
|
243 |
+
|
244 |
+
class SadTalker:
|
245 |
+
|
246 |
+
def __init__(self, checkpoint_path='checkpoints', config_path='src/config', size=256, preprocess='crop',
|
247 |
+
old_version=False):
|
248 |
+
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
249 |
+
self.cfg = self.get_cfg_defaults()
|
250 |
+
self.merge_from_file(os.path.join(config_path, 'sadtalker_config.yaml'))
|
251 |
+
self.cfg['MODEL']['CHECKPOINTS_DIR'] = checkpoint_path
|
252 |
+
self.cfg['MODEL']['CONFIG_DIR'] = config_path
|
253 |
+
self.cfg['MODEL']['DEVICE'] = self.device
|
254 |
+
self.cfg['INPUT_IMAGE'] = {}
|
255 |
+
self.cfg['INPUT_IMAGE']['SOURCE_IMAGE'] = 'None'
|
256 |
+
self.cfg['INPUT_IMAGE']['DRIVEN_AUDIO'] = 'None'
|
257 |
+
self.cfg['INPUT_IMAGE']['PREPROCESS'] = preprocess
|
258 |
+
self.cfg['INPUT_IMAGE']['SIZE'] = size
|
259 |
+
self.cfg['INPUT_IMAGE']['OLD_VERSION'] = old_version
|
260 |
+
|
261 |
+
download_model(kp_url, kp_file, checkpoint_path)
|
262 |
+
download_model(aud_url, aud_file, checkpoint_path)
|
263 |
+
download_model(wav_url, wav_file, checkpoint_path)
|
264 |
+
download_model(gen_url, gen_file, checkpoint_path)
|
265 |
+
download_model(mapx_url, mapx_file, checkpoint_path)
|
266 |
+
download_model(den_url, den_file, checkpoint_path)
|
267 |
+
download_model(GFPGAN_URL, 'GFPGANv1.4.pth', checkpoint_path)
|
268 |
+
download_model(REALESRGAN_URL, 'RealESRGAN_x2plus.pth', checkpoint_path)
|
269 |
+
|
270 |
+
self.sadtalker_model = SadTalkerModel(self.cfg, device_id=[0])
|
271 |
+
|
272 |
+
def get_cfg_defaults(self):
|
273 |
+
return {
|
274 |
+
'MODEL': {
|
275 |
+
'CHECKPOINTS_DIR': '',
|
276 |
+
'CONFIG_DIR': '',
|
277 |
+
'DEVICE': self.device,
|
278 |
+
'SCALE': 64,
|
279 |
+
'NUM_VOXEL_FRAMES': 8,
|
280 |
+
'NUM_MOTION_FRAMES': 10,
|
281 |
+
'MAX_FEATURES': 256,
|
282 |
+
'DRIVEN_AUDIO_SAMPLE_RATE': 16000,
|
283 |
+
'VIDEO_FPS': 25,
|
284 |
+
'OUTPUT_VIDEO_FPS': None,
|
285 |
+
'OUTPUT_AUDIO_SAMPLE_RATE': None,
|
286 |
+
'USE_ENHANCER': False,
|
287 |
+
'ENHANCER_NAME': '',
|
288 |
+
'BG_UPSAMPLER': None,
|
289 |
+
'IS_HALF': False
|
290 |
+
},
|
291 |
+
'INPUT_IMAGE': {}
|
292 |
+
}
|
293 |
+
|
294 |
+
def merge_from_file(self, filepath):
|
295 |
+
if os.path.exists(filepath):
|
296 |
+
with open(filepath, 'r') as f:
|
297 |
+
cfg_from_file = yaml.safe_load(f)
|
298 |
+
self.cfg.update(cfg_from_file)
|
299 |
+
|
300 |
+
def test(self, source_image, driven_audio, preprocess='crop', still_mode=False, use_enhancer=False,
|
301 |
+
batch_size=1, size=256, pose_style=0, exp_scale=1.0, use_ref_video=False, ref_video=None,
|
302 |
+
ref_info=None, use_idle_mode=False, length_of_audio=0, use_blink=True, result_dir='./results/',
|
303 |
+
tts_text=None, tts_lang='en'):
|
304 |
+
self.sadtalker_model.test(source_image, driven_audio, preprocess, still_mode, use_enhancer, batch_size, size,
|
305 |
+
pose_style, exp_scale, use_ref_video, ref_video, ref_info, use_idle_mode,
|
306 |
+
length_of_audio, use_blink, result_dir, tts_text, tts_lang)
|
307 |
+
return self.sadtalker_model.save_result()
|
308 |
+
|
309 |
+
|
310 |
+
class SadTalkerModel:
|
311 |
+
|
312 |
+
def __init__(self, sadtalker_cfg, device_id=[0]):
|
313 |
+
self.cfg = sadtalker_cfg
|
314 |
+
self.device = sadtalker_cfg['MODEL'].get('DEVICE', 'cpu')
|
315 |
+
self.sadtalker = SadTalkerInnerModel(sadtalker_cfg, device_id)
|
316 |
+
self.preprocesser = self.sadtalker.preprocesser
|
317 |
+
self.kp_extractor = self.sadtalker.kp_extractor
|
318 |
+
self.generator = self.sadtalker.generator
|
319 |
+
self.mapping = self.sadtalker.mapping
|
320 |
+
self.he_estimator = self.sadtalker.he_estimator
|
321 |
+
self.audio_to_coeff = self.sadtalker.audio_to_coeff
|
322 |
+
self.animate_from_coeff = self.sadtalker.animate_from_coeff
|
323 |
+
self.face_enhancer = self.sadtalker.face_enhancer
|
324 |
+
|
325 |
+
def test(self, source_image, driven_audio, preprocess='crop', still_mode=False, use_enhancer=False,
|
326 |
+
batch_size=1, size=256, pose_style=0, exp_scale=1.0, use_ref_video=False, ref_video=None,
|
327 |
+
ref_info=None, use_idle_mode=False, length_of_audio=0, use_blink=True, result_dir='./results/',
|
328 |
+
tts_text=None, tts_lang='en', jitter_amount=10, jitter_source_image=False):
|
329 |
+
self.inner_test = SadTalkerInner(self, source_image, driven_audio, preprocess, still_mode, use_enhancer,
|
330 |
+
batch_size, size, pose_style, exp_scale, use_ref_video, ref_video, ref_info,
|
331 |
+
use_idle_mode, length_of_audio, use_blink, result_dir, tts_text, tts_lang,
|
332 |
+
jitter_amount, jitter_source_image)
|
333 |
+
return self.inner_test.test()
|
334 |
+
|
335 |
+
def save_result(self):
|
336 |
+
return self.inner_test.save_result()
|
337 |
+
|
338 |
+
|
339 |
+
class SadTalkerInner:
|
340 |
+
|
341 |
+
def __init__(self, sadtalker_model, source_image, driven_audio, preprocess, still_mode, use_enhancer,
|
342 |
+
batch_size, size, pose_style, exp_scale, use_ref_video, ref_video, ref_info, use_idle_mode,
|
343 |
+
length_of_audio, use_blink, result_dir, tts_text, tts_lang, jitter_amount, jitter_source_image):
|
344 |
+
self.sadtalker_model = sadtalker_model
|
345 |
+
self.source_image = source_image
|
346 |
+
self.driven_audio = driven_audio
|
347 |
+
self.preprocess = preprocess
|
348 |
+
self.still_mode = still_mode
|
349 |
+
self.use_enhancer = use_enhancer
|
350 |
+
self.batch_size = batch_size
|
351 |
+
self.size = size
|
352 |
+
self.pose_style = pose_style
|
353 |
+
self.exp_scale = exp_scale
|
354 |
+
self.use_ref_video = use_ref_video
|
355 |
+
self.ref_video = ref_video
|
356 |
+
self.ref_info = ref_info
|
357 |
+
self.use_idle_mode = use_idle_mode
|
358 |
+
self.length_of_audio = length_of_audio
|
359 |
+
self.use_blink = use_blink
|
360 |
+
self.result_dir = result_dir
|
361 |
+
self.tts_text = tts_text
|
362 |
+
self.tts_lang = tts_lang
|
363 |
+
self.jitter_amount = jitter_amount
|
364 |
+
self.jitter_source_image = jitter_source_image
|
365 |
+
self.device = self.sadtalker_model.device
|
366 |
+
self.output_path = None
|
367 |
+
|
368 |
+
def get_test_data(self):
|
369 |
+
proc = self.sadtalker_model.preprocesser
|
370 |
+
if self.tts_text is not None:
|
371 |
+
temp_dir = tempfile.mkdtemp()
|
372 |
+
audio_path = os.path.join(temp_dir, 'audio.wav')
|
373 |
+
tts = TTSTalker()
|
374 |
+
tts.test(self.tts_text, self.tts_lang)
|
375 |
+
self.driven_audio = audio_path
|
376 |
+
source_image_pil = Image.open(self.source_image).convert('RGB')
|
377 |
+
if self.jitter_source_image:
|
378 |
+
jitter_dx = np.random.randint(-self.jitter_amount, self.jitter_amount + 1)
|
379 |
+
jitter_dy = np.random.randint(-self.jitter_amount, self.jitter_amount + 1)
|
380 |
+
source_image_pil = Image.fromarray(
|
381 |
+
np.roll(np.roll(np.array(source_image_pil), jitter_dx, axis=1), jitter_dy, axis=0))
|
382 |
+
source_image_tensor, crop_info, cropped_image = proc.crop(source_image_pil, self.preprocess, self.size)
|
383 |
+
if self.still_mode or self.use_idle_mode:
|
384 |
+
ref_pose_coeff = proc.generate_still_pose(self.pose_style)
|
385 |
+
ref_expression_coeff = proc.generate_still_expression(self.exp_scale)
|
386 |
+
elif self.use_idle_mode:
|
387 |
+
ref_pose_coeff = proc.generate_idles_pose(self.length_of_audio, self.pose_style)
|
388 |
+
ref_expression_coeff = proc.generate_idles_expression(self.length_of_audio)
|
389 |
+
else:
|
390 |
+
ref_pose_coeff = None
|
391 |
+
ref_expression_coeff = None
|
392 |
+
audio_tensor, audio_sample_rate = proc.process_audio(self.driven_audio,
|
393 |
+
self.sadtalker_model.cfg['MODEL']['DRIVEN_AUDIO_SAMPLE_RATE'])
|
394 |
+
batch = {
|
395 |
+
'source_image': source_image_tensor.unsqueeze(0).to(self.device),
|
396 |
+
'audio': audio_tensor.unsqueeze(0).to(self.device),
|
397 |
+
'ref_pose_coeff': ref_pose_coeff,
|
398 |
+
'ref_expression_coeff': ref_expression_coeff,
|
399 |
+
'source_image_crop': cropped_image,
|
400 |
+
'crop_info': crop_info,
|
401 |
+
'use_blink': self.use_blink,
|
402 |
+
'pose_style': self.pose_style,
|
403 |
+
'exp_scale': self.exp_scale,
|
404 |
+
'ref_video': self.ref_video,
|
405 |
+
'use_ref_video': self.use_ref_video,
|
406 |
+
'ref_info': self.ref_info,
|
407 |
+
}
|
408 |
+
return batch, audio_sample_rate
|
409 |
+
|
410 |
+
def run_inference(self, batch):
|
411 |
+
kp_extractor = self.sadtalker_model.kp_extractor
|
412 |
+
generator = self.sadtalker_model.generator
|
413 |
+
mapping = self.sadtalker_model.mapping
|
414 |
+
he_estimator = self.sadtalker_model.he_estimator
|
415 |
+
audio_to_coeff = self.sadtalker_model.audio_to_coeff
|
416 |
+
animate_from_coeff = self.sadtalker_model.animate_from_coeff
|
417 |
+
proc = self.sadtalker_model.preprocesser
|
418 |
+
with torch.no_grad():
|
419 |
+
kp_source = kp_extractor(batch['source_image'])
|
420 |
+
if self.still_mode or self.use_idle_mode:
|
421 |
+
ref_pose_coeff = batch['ref_pose_coeff']
|
422 |
+
ref_expression_coeff = batch['ref_expression_coeff']
|
423 |
+
pose_coeff = audio_to_coeff.get_pose_coeff(batch['audio'], ref_pose_coeff)
|
424 |
+
expression_coeff = audio_to_coeff.get_exp_coeff(batch['audio'], ref_expression_coeff)
|
425 |
+
elif self.use_idle_mode:
|
426 |
+
ref_pose_coeff = batch['ref_pose_coeff']
|
427 |
+
ref_expression_coeff = batch['ref_expression_coeff']
|
428 |
+
pose_coeff = audio_to_coeff.get_pose_coeff(batch['audio'], ref_pose_coeff)
|
429 |
+
expression_coeff = audio_to_coeff.get_exp_coeff(batch['audio'], ref_expression_coeff)
|
430 |
+
else:
|
431 |
+
if self.use_ref_video:
|
432 |
+
kp_ref = kp_extractor(batch['source_image'])
|
433 |
+
pose_coeff = audio_to_coeff.get_pose_coeff(batch['audio'], kp_ref=kp_ref,
|
434 |
+
use_ref_info=batch['ref_info'])
|
435 |
+
else:
|
436 |
+
pose_coeff = audio_to_coeff.get_pose_coeff(batch['audio'])
|
437 |
+
expression_coeff = audio_to_coeff.get_exp_coeff(batch['audio'])
|
438 |
+
coeff = {'pose_coeff': pose_coeff, 'expression_coeff': expression_coeff}
|
439 |
+
if self.use_blink:
|
440 |
+
coeff['blink_coeff'] = audio_to_coeff.get_blink_coeff(batch['audio'])
|
441 |
+
else:
|
442 |
+
coeff['blink_coeff'] = None
|
443 |
+
kp_driving = audio_to_coeff(batch['audio'])[0]
|
444 |
+
kp_norm = animate_from_coeff.normalize_kp(kp_driving)
|
445 |
+
coeff['kp_driving'] = kp_norm
|
446 |
+
coeff['jacobian'] = [torch.eye(2).unsqueeze(0).unsqueeze(0).to(self.device)] * 4
|
447 |
+
face_enhancer = self.sadtalker_model.face_enhancer if self.use_enhancer else None
|
448 |
+
output_video = animate_from_coeff.generate(batch['source_image'], kp_source, coeff, generator, mapping,
|
449 |
+
he_estimator, batch['audio'], batch['source_image_crop'],
|
450 |
+
face_enhancer=face_enhancer)
|
451 |
+
return output_video
|
452 |
+
|
453 |
+
def post_processing(self, output_video, audio_sample_rate, batch):
|
454 |
+
proc = self.sadtalker_model.preprocesser
|
455 |
+
base_name = os.path.splitext(os.path.basename(batch['source_image_crop']))[0]
|
456 |
+
audio_name = os.path.splitext(os.path.basename(self.driven_audio))[0]
|
457 |
+
output_video_path = os.path.join(self.result_dir, base_name + '_' + audio_name + '.mp4')
|
458 |
+
self.output_path = output_video_path
|
459 |
+
video_fps = self.sadtalker_model.cfg['MODEL']['VIDEO_FPS'] if self.sadtalker_model.cfg['MODEL'][
|
460 |
+
'OUTPUT_VIDEO_FPS'] is None else \
|
461 |
+
self.sadtalker_model.cfg['MODEL']['OUTPUT_VIDEO_FPS']
|
462 |
+
audio_output_sample_rate = self.sadtalker_model.cfg['MODEL']['DRIVEN_AUDIO_SAMPLE_RATE'] if \
|
463 |
+
self.sadtalker_model.cfg['MODEL']['OUTPUT_AUDIO_SAMPLE_RATE'] is None else \
|
464 |
+
self.sadtalker_model.cfg['MODEL']['OUTPUT_AUDIO_SAMPLE_RATE']
|
465 |
+
if self.use_enhancer:
|
466 |
+
enhanced_path = os.path.join(self.result_dir, base_name + '_' + audio_name + '_enhanced.mp4')
|
467 |
+
save_video_with_watermark(output_video, self.driven_audio, enhanced_path)
|
468 |
+
paste_pic(enhanced_path, batch['source_image_crop'], batch['crop_info'], self.driven_audio,
|
469 |
+
output_video_path)
|
470 |
+
os.remove(enhanced_path)
|
471 |
+
else:
|
472 |
+
save_video_with_watermark(output_video, self.driven_audio, output_video_path)
|
473 |
+
if self.tts_text is not None:
|
474 |
+
shutil.rmtree(os.path.dirname(self.driven_audio))
|
475 |
+
|
476 |
+
def save_result(self):
|
477 |
+
return self.output_path
|
478 |
+
|
479 |
+
def __call__(self):
|
480 |
+
return self.output_path
|
481 |
+
|
482 |
+
def test(self):
|
483 |
+
batch, audio_sample_rate = self.get_test_data()
|
484 |
+
output_video = self.run_inference(batch)
|
485 |
+
self.post_processing(output_video, audio_sample_rate, batch)
|
486 |
+
return self.save_result()
|
487 |
+
|
488 |
+
|
489 |
+
class SadTalkerInnerModel:
|
490 |
+
|
491 |
+
def __init__(self, sadtalker_cfg, device_id=[0]):
|
492 |
+
self.cfg = sadtalker_cfg
|
493 |
+
self.device = sadtalker_cfg['MODEL'].get('DEVICE', 'cpu')
|
494 |
+
self.preprocesser = Preprocesser(sadtalker_cfg, self.device)
|
495 |
+
self.kp_extractor = KeyPointExtractor(sadtalker_cfg, self.device)
|
496 |
+
self.audio_to_coeff = Audio2Coeff(sadtalker_cfg, self.device)
|
497 |
+
self.animate_from_coeff = AnimateFromCoeff(sadtalker_cfg, self.device)
|
498 |
+
self.face_enhancer = FaceEnhancer(sadtalker_cfg, self.device) if sadtalker_cfg['MODEL'][
|
499 |
+
'USE_ENHANCER'] else None
|
500 |
+
self.generator = Generator(sadtalker_cfg, self.device)
|
501 |
+
self.mapping = Mapping(sadtalker_cfg, self.device)
|
502 |
+
self.he_estimator = OcclusionAwareDenseMotion(sadtalker_cfg, self.device)
|
503 |
+
|
504 |
+
|
505 |
+
class Preprocesser:
|
506 |
+
|
507 |
+
def __init__(self, sadtalker_cfg, device):
|
508 |
+
self.cfg = sadtalker_cfg
|
509 |
+
self.device = device
|
510 |
+
if self.cfg['INPUT_IMAGE'].get('OLD_VERSION', False):
|
511 |
+
self.face3d_helper = Face3DHelperOld(self.cfg['INPUT_IMAGE'].get('LOCAL_PCA_PATH', ''), device)
|
512 |
+
else:
|
513 |
+
self.face3d_helper = Face3DHelper(self.cfg['INPUT_IMAGE'].get('LOCAL_PCA_PATH', ''), device)
|
514 |
+
self.mouth_detector = MouthDetector()
|
515 |
+
|
516 |
+
def crop(self, source_image_pil, preprocess_type, size=256):
|
517 |
+
source_image = np.array(source_image_pil)
|
518 |
+
face_info = self.face3d_helper.run(source_image)
|
519 |
+
if face_info is None:
|
520 |
+
raise Exception("No face detected")
|
521 |
+
x_min, y_min, x_max, y_max = face_info[:4]
|
522 |
+
old_size = (x_max - x_min, y_max - y_min)
|
523 |
+
x_center = (x_max + x_min) / 2
|
524 |
+
y_center = (y_max + y_min) / 2
|
525 |
+
if preprocess_type == 'crop':
|
526 |
+
face_size = max(x_max - x_min, y_max - y_min)
|
527 |
+
x_min = int(x_center - face_size / 2)
|
528 |
+
y_min = int(y_center - face_size / 2)
|
529 |
+
x_max = int(x_center + face_size / 2)
|
530 |
+
y_max = int(y_center + face_size / 2)
|
531 |
+
else:
|
532 |
+
x_min -= int((x_max - x_min) * 0.1)
|
533 |
+
y_min -= int((y_max - y_min) * 0.1)
|
534 |
+
x_max += int((x_max - x_min) * 0.1)
|
535 |
+
y_max += int((y_max - y_min) * 0.1)
|
536 |
+
h, w = source_image.shape[:2]
|
537 |
+
x_min = max(0, x_min)
|
538 |
+
y_min = max(0, y_min)
|
539 |
+
x_max = min(w, x_max)
|
540 |
+
y_max = min(h, y_max)
|
541 |
+
cropped_image = source_image[y_min:y_max, x_min:x_max]
|
542 |
+
cropped_image_pil = Image.fromarray(cropped_image)
|
543 |
+
if size is not None and size != 0:
|
544 |
+
cropped_image_pil = cropped_image_pil.resize((size, size), Image.Resampling.LANCZOS)
|
545 |
+
source_image_tensor = self.img2tensor(cropped_image_pil)
|
546 |
+
return source_image_tensor, [[y_min, y_max], [x_min, x_max], old_size, cropped_image_pil.size], os.path.basename(
|
547 |
+
self.cfg['INPUT_IMAGE'].get('SOURCE_IMAGE', ''))
|
548 |
+
|
549 |
+
def img2tensor(self, img):
|
550 |
+
img = np.array(img).astype(np.float32) / 255.0
|
551 |
+
img = np.transpose(img, (2, 0, 1))
|
552 |
+
return torch.FloatTensor(img)
|
553 |
+
|
554 |
+
def video_to_tensor(self, video, device):
|
555 |
+
video_tensor_list = []
|
556 |
+
import torchvision.transforms as transforms
|
557 |
+
transform_func = transforms.ToTensor()
|
558 |
+
for frame in video:
|
559 |
+
frame_pil = Image.fromarray(frame)
|
560 |
+
frame_tensor = transform_func(frame_pil).unsqueeze(0).to(device)
|
561 |
+
video_tensor_list.append(frame_tensor)
|
562 |
+
video_tensor = torch.cat(video_tensor_list, dim=0)
|
563 |
+
return video_tensor
|
564 |
+
|
565 |
+
def process_audio(self, audio_path, sample_rate):
|
566 |
+
wav = load_wav_util(audio_path, sample_rate)
|
567 |
+
wav_tensor = torch.FloatTensor(wav).unsqueeze(0)
|
568 |
+
return wav_tensor, sample_rate
|
569 |
+
|
570 |
+

    def generate_still_pose(self, pose_style):
        ref_pose_coeff = torch.zeros((1, 64), dtype=torch.float32).to(self.device)
        ref_pose_coeff[:, :3] = torch.tensor([0, 0, pose_style * 0.3], dtype=torch.float32)
        return ref_pose_coeff

    def generate_still_expression(self, exp_scale):
        ref_expression_coeff = torch.zeros((1, 64), dtype=torch.float32).to(self.device)
        ref_expression_coeff[:, :3] = torch.tensor([0, 0, exp_scale * 0.3], dtype=torch.float32)
        return ref_expression_coeff

    def generate_idles_pose(self, length_of_audio, pose_style):
        num_frames = int(length_of_audio * self.cfg['MODEL']['VIDEO_FPS'])
        ref_pose_coeff = torch.zeros((num_frames, 64), dtype=torch.float32).to(self.device)
        start_pose = self.generate_still_pose(pose_style)
        end_pose = self.generate_still_pose(pose_style)
        for frame_idx in range(num_frames):
            alpha = frame_idx / num_frames
            # squeeze the (1, 64) still poses so the blend matches the (64,) row being written
            ref_pose_coeff[frame_idx] = ((1 - alpha) * start_pose + alpha * end_pose).squeeze(0)
        return ref_pose_coeff

    def generate_idles_expression(self, length_of_audio):
        num_frames = int(length_of_audio * self.cfg['MODEL']['VIDEO_FPS'])
        ref_expression_coeff = torch.zeros((num_frames, 64), dtype=torch.float32).to(self.device)
        start_exp = self.generate_still_expression(1.0)
        end_exp = self.generate_still_expression(1.0)
        for frame_idx in range(num_frames):
            alpha = frame_idx / num_frames
            ref_expression_coeff[frame_idx] = ((1 - alpha) * start_exp + alpha * end_exp).squeeze(0)
        return ref_expression_coeff
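

# Illustrative sketch of how Preprocesser is typically driven before the
# audio-to-coefficient stage. Assumptions: `cfg` follows the structure shown in
# _EXAMPLE_SADTALKER_CFG, and 'face.png' / 'speech.wav' are placeholder inputs.
# Defined as a function so nothing runs at import time.
def _example_preprocess(cfg, device='cpu'):
    preprocesser = Preprocesser(cfg, device)
    image = Image.open('face.png').convert('RGB')
    # crop() returns the cropped image tensor, crop metadata, and the source image name
    image_tensor, crop_info, image_name = preprocesser.crop(image, preprocess_type='crop', size=256)
    # process_audio() loads the waveform at the requested sample rate as a (1, T) tensor
    audio_tensor, sample_rate = preprocesser.process_audio('speech.wav', sample_rate=16000)
    return image_tensor, crop_info, audio_tensor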


class KeyPointExtractor(nn.Module):

    def __init__(self, sadtalker_cfg, device):
        super(KeyPointExtractor, self).__init__()
        self.kp_extractor = OcclusionAwareKPDetector(kp_channels=sadtalker_cfg['MODEL']['NUM_MOTION_FRAMES'],
                                                     num_kp=10,
                                                     num_dilation_blocks=2,
                                                     dropout_rate=0.1).to(device)
        checkpoint_path = os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'], 'kp_detector.safetensors')
        self.load_kp_detector(checkpoint_path, device)

    def load_kp_detector(self, checkpoint_path, device):
        if os.path.exists(checkpoint_path):
            if checkpoint_path.endswith('safetensors'):
                checkpoint = safetensors.torch.load_file(checkpoint_path, device=device)
            else:
                checkpoint = torch.load(checkpoint_path, map_location=device)
            try:
                self.kp_extractor.load_state_dict(checkpoint.get('kp_detector', {}))
            except RuntimeError as e:
                print(f"Error loading kp_detector state_dict: {e}")
                print("Trying to load state_dict without prefix 'kp_detector.'")
                self.kp_extractor.load_state_dict(checkpoint, strict=False)
        else:
            raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}")

    def forward(self, x):
        kp = self.kp_extractor(x)
        return kp
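

# Illustrative sketch: a dummy forward pass through KeyPointExtractor. Assumptions:
# the kp_detector checkpoint is already present in cfg['MODEL']['CHECKPOINTS_DIR'],
# and the underlying detector accepts a batch of RGB images shaped (B, 3, H, W);
# the 256x256 resolution below is only an example.
def _example_keypoints(cfg, device='cpu'):
    kp_extractor = KeyPointExtractor(cfg, device)
    dummy_image = torch.randn(1, 3, 256, 256, device=device)
    with torch.no_grad():
        return kp_extractor(dummy_image)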


class Audio2Coeff(nn.Module):

    def __init__(self, sadtalker_cfg, device):
        super(Audio2Coeff, self).__init__()
        self.audio_model = Wav2Vec2Model().to(device)
        checkpoint_path = os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'], 'wav2vec2.pth')
        self.load_audio_model(checkpoint_path, device)
        self.pose_mapper = AudioCoeffsPredictor(2048, 64).to(device)
        self.exp_mapper = AudioCoeffsPredictor(2048, 64).to(device)
        self.blink_mapper = AudioCoeffsPredictor(2048, 1).to(device)
        # NOTE: this filename must match the pose/expression checkpoint actually placed
        # in CHECKPOINTS_DIR by the download step.
        mapping_checkpoint = os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'], 'audio2pose_00140-model.pth')
        self.load_mapping_model(mapping_checkpoint, device)

    def load_audio_model(self, checkpoint_path, device):
        if os.path.exists(checkpoint_path):
            if checkpoint_path.endswith('safetensors'):
                checkpoint = safetensors.torch.load_file(checkpoint_path, device=device)
            else:
                checkpoint = torch.load(checkpoint_path, map_location=device)
            self.audio_model.load_state_dict(checkpoint.get("wav2vec2", {}))
        else:
            raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}")

    def load_mapping_model(self, checkpoint_path, device):
        if os.path.exists(checkpoint_path):
            if checkpoint_path.endswith('safetensors'):
                checkpoint = safetensors.torch.load_file(checkpoint_path, device=device)
            else:
                checkpoint = torch.load(checkpoint_path, map_location=device)
            self.pose_mapper.load_state_dict(checkpoint.get("pose_predictor", {}))
            self.exp_mapper.load_state_dict(checkpoint.get("exp_predictor", {}))
            self.blink_mapper.load_state_dict(checkpoint.get("blink_predictor", {}))
        else:
            raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}")

    def get_pose_coeff(self, audio_tensor, ref_pose_coeff=None, kp_ref=None, use_ref_info=''):
        audio_embedding = self.audio_model(audio_tensor)
        pose_coeff = self.pose_mapper(audio_embedding)
        if ref_pose_coeff is not None:
            pose_coeff = ref_pose_coeff
        if kp_ref is not None and use_ref_info == 'pose':
            ref_pose_6d = kp_ref['value'][:, :6]
            pose_coeff[:, :6] = self.mean_std_normalize(ref_pose_6d).mean(dim=1)
        return pose_coeff

    def get_exp_coeff(self, audio_tensor, ref_expression_coeff=None):
        audio_embedding = self.audio_model(audio_tensor)
        expression_coeff = self.exp_mapper(audio_embedding)
        if ref_expression_coeff is not None:
            expression_coeff = ref_expression_coeff
        return expression_coeff

    def get_blink_coeff(self, audio_tensor):
        audio_embedding = self.audio_model(audio_tensor)
        blink_coeff = self.blink_mapper(audio_embedding)
        return blink_coeff

    def forward(self, audio):
        audio_embedding = self.audio_model(audio)
        pose_coeff = self.pose_mapper(audio_embedding)
        expression_coeff = self.exp_mapper(audio_embedding)
        blink_coeff = self.blink_mapper(audio_embedding)
        return pose_coeff, expression_coeff, blink_coeff

    def mean_std_normalize(self, coeff):
        mean = coeff.mean(dim=1, keepdim=True)
        std = coeff.std(dim=1, keepdim=True)
        return (coeff - mean) / std
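

# Illustrative sketch of the audio-to-coefficient stage: the forward pass maps an
# audio tensor to pose, expression, and blink coefficients. Assumptions: the wav2vec2
# and pose/expression checkpoints are present in cfg['MODEL']['CHECKPOINTS_DIR'], and
# `audio_tensor` comes from Preprocesser.process_audio().
def _example_audio_to_coeff(cfg, audio_tensor, device='cpu'):
    audio_to_coeff = Audio2Coeff(cfg, device)
    with torch.no_grad():
        pose_coeff, expression_coeff, blink_coeff = audio_to_coeff(audio_tensor.to(device))
    return pose_coeff, expression_coeff, blink_coeff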


class AnimateFromCoeff(nn.Module):

    def __init__(self, sadtalker_cfg, device):
        super(AnimateFromCoeff, self).__init__()
        self.generator = Generator(sadtalker_cfg, device)
        self.mapping = Mapping(sadtalker_cfg, device)
        self.kp_norm = KeypointNorm(device=device)
        self.he_estimator = OcclusionAwareDenseMotion(sadtalker_cfg, device)

    def normalize_kp(self, kp_driving):
        return self.kp_norm(kp_driving)

    def generate(self, source_image, kp_source, coeff, generator, mapping, he_estimator, audio, source_image_crop,
                 face_enhancer=None):
        kp_driving = coeff['kp_driving']
        jacobian = coeff['jacobian']
        pose_coeff = coeff['pose_coeff']
        expression_coeff = coeff['expression_coeff']
        blink_coeff = coeff['blink_coeff']
        with torch.no_grad():
            if blink_coeff is not None:
                sparse_motion = he_estimator(kp_source, kp_driving, jacobian)
                dense_motion = sparse_motion['dense_motion']
                video_deocclusion = generator(source_image, dense_motion, bg_param={'mask': None, 'color': None})
                face_3d = mapping(expression_coeff, pose_coeff, blink_coeff)
                video_3d = generator(source_image, dense_motion, bg_param={'mask': None, 'color': None},
                                     face_3d_param=face_3d)
                video_output = video_deocclusion['video_no_reocclusion'] + video_3d['video_3d']
                video_output = self.make_animation(video_output)
            else:
                sparse_motion = he_estimator(kp_source, kp_driving, jacobian)
                dense_motion = sparse_motion['dense_motion']
                face_3d = mapping(expression_coeff, pose_coeff)
                video_3d = generator(source_image, dense_motion, bg_param={'mask': None, 'color': None},
                                     face_3d_param=face_3d)
                video_output = video_3d['video_3d']
                video_output = self.make_animation(video_output)
        if face_enhancer is not None:
            video_output_enhanced = []
            for frame in tqdm(video_output, 'Face enhancer running'):
                # FaceEnhancer.enhance expects and returns a BGR frame.
                frame_bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
                enhanced_image = face_enhancer.enhance(frame_bgr)
                video_output_enhanced.append(cv2.cvtColor(enhanced_image, cv2.COLOR_BGR2RGB))
            video_output = video_output_enhanced
        return video_output

    def make_animation(self, video_array):
        # Round-trip the frames through a temporary mp4 so the returned frames match
        # what the final video writer will produce.
        H, W, _ = video_array[0].shape
        out = cv2.VideoWriter('./tmp.mp4', cv2.VideoWriter_fourcc(*'mp4v'), 25, (W, H))
        for img in video_array:
            out.write(cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
        out.release()
        video = imageio.mimread('./tmp.mp4')
        os.remove('./tmp.mp4')
        return video
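

# Illustrative sketch of the `coeff` dictionary consumed by AnimateFromCoeff.generate().
# The key names are taken from the method above; the tensors themselves would come from
# KeyPointExtractor, the dense-motion estimator, and Audio2Coeff (blink_coeff may be None).
def _example_coeff_dict(kp_driving, jacobian, pose_coeff, expression_coeff, blink_coeff=None):
    return {
        'kp_driving': kp_driving,
        'jacobian': jacobian,
        'pose_coeff': pose_coeff,
        'expression_coeff': expression_coeff,
        'blink_coeff': blink_coeff,
    }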


class Generator(nn.Module):

    def __init__(self, sadtalker_cfg, device):
        super(Generator, self).__init__()
        self.generator = Hourglass(block_expansion=sadtalker_cfg['MODEL']['SCALE'],
                                   num_blocks=sadtalker_cfg['MODEL']['NUM_VOXEL_FRAMES'],
                                   max_features=sadtalker_cfg['MODEL']['MAX_FEATURES'],
                                   num_channels=3,
                                   kp_size=10,
                                   num_deform_blocks=sadtalker_cfg['MODEL']['NUM_MOTION_FRAMES']).to(device)
        checkpoint_path = os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'], 'generator.pth')
        self.load_generator(checkpoint_path, device)

    def load_generator(self, checkpoint_path, device):
        if os.path.exists(checkpoint_path):
            if checkpoint_path.endswith('safetensors'):
                checkpoint = safetensors.torch.load_file(checkpoint_path, device=device)
            else:
                checkpoint = torch.load(checkpoint_path, map_location=device)
            self.generator.load_state_dict(checkpoint.get('generator', {}))
        else:
            raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}")

    def forward(self, source_image, dense_motion, bg_param, face_3d_param=None):
        if face_3d_param is not None:
            video_3d = self.generator(source_image, kp_driving=dense_motion, bg_param=bg_param,
                                      face_3d_param=face_3d_param)
        else:
            video_3d = self.generator(source_image, kp_driving=dense_motion, bg_param=bg_param)
        return {'video_3d': video_3d, 'video_no_reocclusion': video_3d}


class Mapping(nn.Module):

    def __init__(self, sadtalker_cfg, device):
        super(Mapping, self).__init__()
        self.mapping_net = MappingNet(num_coeffs=64, num_layers=3, hidden_dim=128).to(device)
        checkpoint_path = os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'], 'mapping.pth')
        self.load_mapping_net(checkpoint_path, device)
        self.f_3d_mean = torch.zeros(1, 64, device=device)

    def load_mapping_net(self, checkpoint_path, device):
        if os.path.exists(checkpoint_path):
            if checkpoint_path.endswith('safetensors'):
                checkpoint = safetensors.torch.load_file(checkpoint_path, device=device)
            else:
                checkpoint = torch.load(checkpoint_path, map_location=device)
            self.mapping_net.load_state_dict(checkpoint.get('mapping', {}))
        else:
            raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}")

    def forward(self, expression_coeff, pose_coeff, blink_coeff=None):
        coeff = torch.cat([expression_coeff, pose_coeff], dim=1)
        face_3d = self.mapping_net(coeff) + self.f_3d_mean
        if blink_coeff is not None:
            face_3d[:, -1:] = blink_coeff
        return face_3d


class OcclusionAwareDenseMotion(nn.Module):

    def __init__(self, sadtalker_cfg, device):
        super(OcclusionAwareDenseMotion, self).__init__()
        self.dense_motion_network = DenseMotionNetwork(num_kp=10,
                                                       num_channels=3,
                                                       block_expansion=sadtalker_cfg['MODEL']['SCALE'],
                                                       num_blocks=sadtalker_cfg['MODEL']['NUM_MOTION_FRAMES'] - 1,
                                                       max_features=sadtalker_cfg['MODEL']['MAX_FEATURES']).to(device)
        checkpoint_path = os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'], 'dense_motion.pth')
        self.load_dense_motion_network(checkpoint_path, device)

    def load_dense_motion_network(self, checkpoint_path, device):
        if os.path.exists(checkpoint_path):
            if checkpoint_path.endswith('safetensors'):
                checkpoint = safetensors.torch.load_file(checkpoint_path, device=device)
            else:
                checkpoint = torch.load(checkpoint_path, map_location=device)
            self.dense_motion_network.load_state_dict(checkpoint.get('dense_motion', {}))
        else:
            raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}")

    def forward(self, kp_source, kp_driving, jacobian):
        sparse_motion = self.dense_motion_network(kp_source, kp_driving, jacobian)
        return sparse_motion


class FaceEnhancer(nn.Module):

    def __init__(self, sadtalker_cfg, device):
        super(FaceEnhancer, self).__init__()
        self.enhancer_name = sadtalker_cfg['MODEL']['ENHANCER_NAME']
        bg_upsampler = sadtalker_cfg['MODEL']['BG_UPSAMPLER']
        if self.enhancer_name == 'gfpgan':
            from gfpgan import GFPGANer
            self.face_enhancer = GFPGANer(model_path=os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'], 'GFPGANv1.4.pth'),
                                          upscale=1,
                                          arch='clean',
                                          channel_multiplier=2,
                                          bg_upsampler=bg_upsampler)
        elif self.enhancer_name == 'realesrgan':
            from realesrgan import RealESRGANer
            half = False if device == 'cpu' else sadtalker_cfg['MODEL']['IS_HALF']
            self.face_enhancer = RealESRGANer(scale=2,
                                              model_path=os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'],
                                                                      'RealESRGAN_x2plus.pth'),
                                              tile=0,
                                              tile_pad=10,
                                              pre_pad=0,
                                              half=half,
                                              device=device)
        else:
            self.face_enhancer = None

    def enhance(self, img):
        # Return the enhanced BGR image; GFPGAN and Real-ESRGAN use different return
        # conventions, so normalize them here.
        if self.face_enhancer is None:
            return img
        if self.enhancer_name == 'gfpgan':
            # GFPGANer.enhance returns (cropped_faces, restored_faces, restored_img)
            _, _, restored_img = self.face_enhancer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True)
            return restored_img
        # RealESRGANer.enhance returns (output, img_mode)
        output, _ = self.face_enhancer.enhance(img, outscale=1)
        return output

    def forward(self, x):
        return self.enhance(x)
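

# Illustrative sketch: enhancing a single frame with the wrapper above. Assumptions:
# the GFPGAN / Real-ESRGAN weights referenced in the constructor are present in
# cfg['MODEL']['CHECKPOINTS_DIR'], and `frame_bgr` is a uint8 HxWx3 BGR numpy array.
def _example_enhance_frame(cfg, frame_bgr, device='cpu'):
    enhancer = FaceEnhancer(cfg, device)
    return enhancer.enhance(frame_bgr)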


def load_models():
    checkpoint_path = './checkpoints'
    config_path = './src/config'
    size = 256
    preprocess = 'crop'
    old_version = False

    sadtalker_instance = SadTalker(checkpoint_path, config_path, size, preprocess, old_version)
    print("SadTalker models loaded successfully!")
    return sadtalker_instance


if __name__ == '__main__':
    sadtalker_instance = load_models()