# skyreels-a1-talking-head/skyreels_a1/pre_process_lmk3d.py
# Uploaded by multimodalart ("Upload 83 files", commit 38e20ed, verified).
import torch
import cv2
import os
import sys
import numpy as np
import argparse
import math
from PIL import Image
from decord import VideoReader
from skimage.transform import estimate_transform, warp
from insightface.app import FaceAnalysis
from diffusers.utils import load_image
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from skyreels_a1.src.utils.mediapipe_utils import MediaPipeUtils
from skyreels_a1.src.smirk_encoder import SmirkEncoder
from skyreels_a1.src.FLAME.FLAME import FLAME
from skyreels_a1.src.renderer import Renderer
from moviepy.editor import ImageSequenceClip
class FaceAnimationProcessor:
    """Transfer the facial motion of driving frames onto a source face.

    The class wires together a face detector (``FaceAnalysis``), a landmark
    helper (``MediaPipeUtils``), the ``SmirkEncoder`` parameter regressor, the
    ``FLAME`` model and a ``Renderer``.  For every driving frame it produces
    one rendered control image for the source identity.
    """

    # Default landmark indices used by ``compute_landmark_relation`` for the
    # targets 473 and 468.  Stored as tuples so they can never be mutated
    # through a value returned to a caller.
    _REF_INDICES_473 = (362, 382, 381, 380, 374, 373, 390, 249,
                        263, 466, 388, 387, 386, 385, 384, 398)
    _REF_INDICES_468 = (33, 7, 163, 144, 145, 153, 154, 155,
                        133, 173, 157, 158, 159, 160, 161, 246)

    def __init__(self, device='cuda', checkpoint="pretrained_models/smirk/smirk_encoder.pt"):
        """Build all sub-models and load the encoder weights.

        Args:
            device: Torch device the encoder, FLAME model and renderer are
                moved to.
            checkpoint: Path of the checkpoint passed to ``load_checkpoint``.
        """
        self.device = device
        self.app = FaceAnalysis(allowed_modules=['detection'])
        self.app.prepare(ctx_id=0, det_size=(640, 640))
        self.smirk_encoder = SmirkEncoder().to(device)
        self.flame = FLAME(n_shape=300, n_exp=50).to(device)
        self.renderer = Renderer().to(device)
        self.load_checkpoint(checkpoint)

    def load_checkpoint(self, checkpoint):
        """Load the ``smirk_encoder.*`` weights from ``checkpoint``.

        Only entries whose key contains ``smirk_encoder`` are used; the
        ``smirk_encoder.`` prefix is stripped before loading.  The encoder is
        switched to eval mode afterwards.

        NOTE: ``torch.load`` unpickles the file, so only load checkpoints
        from a trusted source.
        """
        # map_location lets a checkpoint saved on GPU be loaded when
        # self.device is a CPU (or a different GPU).
        checkpoint_data = torch.load(checkpoint, map_location=self.device)
        checkpoint_encoder = {
            k.replace('smirk_encoder.', ''): v
            for k, v in checkpoint_data.items()
            if 'smirk_encoder' in k
        }
        self.smirk_encoder.load_state_dict(checkpoint_encoder)
        self.smirk_encoder.eval()

    def face_crop(self, image):
        """Crop a roughly square region around the first detected face.

        The detector box is widened by half its width on each side (clamped
        to the image), then extended vertically so the height matches the new
        width.  If vertical clamping leaves the region wider than tall, the
        horizontal extent is trimmed symmetrically.

        Args:
            image: H x W x C array handed to the face detector.

        Returns:
            Tuple ``(image_crop, x1, y1)`` where ``x1``/``y1`` are the
            top-left corner of the crop in ``image``.

        Raises:
            ValueError: If the detector finds no face.
        """
        height, width, _ = image.shape
        faces = self.app.get(image)
        if not faces:
            raise ValueError('Cannot find a face in the image')
        bbox = faces[0]['bbox']
        w = bbox[2] - bbox[0]
        h = bbox[3] - bbox[1]
        x1 = max(0, int(bbox[0] - w/2))
        x2 = min(width - 1, int(bbox[2] + w/2))
        w_new = x2 - x1
        y_offset = (w_new - h) / 2.
        y1 = max(0, int(bbox[1] - y_offset))
        y2 = min(height, int(bbox[3] + y_offset))
        # Trim the width when the (clamped) height ended up smaller.
        x_comp = int(((x2 - x1) - (y2 - y1)) / 2) if (x2 - x1) > (y2 - y1) else 0
        x1 += x_comp
        x2 -= x_comp
        image_crop = image[y1:y2, x1:x2]
        return image_crop, x1, y1

    def crop_and_resize(self, image, height, width):
        """Fit ``image`` to the ``height``/``width`` aspect ratio and resize it.

        Images that are too wide are center-cropped horizontally; images that
        are too narrow are padded with black columns on both sides.

        Args:
            image: Anything ``np.array`` turns into an H x W x 3 array.
            height: Target height in pixels.
            width: Target width in pixels.

        Returns:
            A ``PIL.Image`` of size ``(width, height)``.
        """
        image = np.array(image)
        image_height, image_width, _ = image.shape
        if image_height / image_width < height / width:
            croped_width = int(image_height / height * width)
            left = (image_width - croped_width) // 2
            image = image[:, left: left+croped_width]
        else:
            pad = int((((width / height) * image_height) - image_width) / 2.)
            padded_image = np.zeros((image_height, image_width + pad * 2, 3), dtype=np.uint8)
            padded_image[:, pad:pad+image_width] = image
            image = padded_image
        return Image.fromarray(image).resize((width, height))

    def rodrigues_to_matrix(self, pose_params):
        """Convert axis-angle vectors to rotation matrices (Rodrigues formula).

        Args:
            pose_params: Tensor of shape ``(B, 3)``; the vector norm is the
                rotation angle and its direction the rotation axis.

        Returns:
            Tensor of shape ``(B, 3, 3)``.
        """
        theta = torch.norm(pose_params, dim=-1, keepdim=True)
        r = pose_params / (theta + 1e-8)
        # Shape (B, 1, 1) so the scalars broadcast per batch element against
        # the (B, 3, 3) matrices.  With shape (B, 1) the broadcast only
        # worked for B == 1.
        cos_theta = torch.cos(theta).unsqueeze(-1)
        sin_theta = torch.sin(theta).unsqueeze(-1)
        # Skew-symmetric cross-product matrix of the unit axis.
        r_x = torch.zeros((pose_params.shape[0], 3, 3),
                          device=pose_params.device, dtype=pose_params.dtype)
        r_x[:, 0, 1] = -r[:, 2]
        r_x[:, 0, 2] = r[:, 1]
        r_x[:, 1, 0] = r[:, 2]
        r_x[:, 1, 2] = -r[:, 0]
        r_x[:, 2, 0] = -r[:, 1]
        r_x[:, 2, 1] = r[:, 0]
        identity = torch.eye(3, device=pose_params.device, dtype=pose_params.dtype).unsqueeze(0)
        R = cos_theta * identity + \
            sin_theta * r_x + \
            ((1 - cos_theta) * r.unsqueeze(-1)) @ r.unsqueeze(-2)
        return R

    def matrix_to_rodrigues(self, R):
        """Convert a rotation matrix back to an axis-angle vector.

        Only the first matrix of the batch (``R[0]``) is converted.

        Args:
            R: Tensor of shape ``(B, 3, 3)``.

        Returns:
            Tensor of shape ``(1, 3)``.
        """
        cos_theta = (torch.trace(R[0]) - 1) / 2
        cos_theta = torch.clamp(cos_theta, -1, 1)
        theta = torch.acos(cos_theta)
        if abs(theta) < 1e-4:
            # (Almost) no rotation: the axis is undefined, return zeros.
            return torch.zeros(1, 3, device=R.device, dtype=R.dtype)
        elif abs(theta - math.pi) < 1e-4:
            # Near pi sin(theta) vanishes, so take the axis from the largest
            # column of R + I instead of the antisymmetric part.
            R_plus_I = R[0] + torch.eye(3, device=R.device, dtype=R.dtype)
            col_norms = torch.norm(R_plus_I, dim=0)
            max_col_idx = torch.argmax(col_norms)
            v = R_plus_I[:, max_col_idx]
            v = v / torch.norm(v)
            return (v * math.pi).unsqueeze(0)
        sin_theta = torch.sin(theta)
        r = torch.zeros(1, 3, device=R.device, dtype=R.dtype)
        r[0, 0] = R[0, 2, 1] - R[0, 1, 2]
        r[0, 1] = R[0, 0, 2] - R[0, 2, 0]
        r[0, 2] = R[0, 1, 0] - R[0, 0, 1]
        r = r / (2 * sin_theta)
        return r * theta

    def crop_face(self, frame, landmarks, scale=1.0, image_size=224):
        """Estimate the similarity transform that crops the face to a square.

        The crop is centred on the landmark bounding box; its side is the
        mean of the box width and height times ``scale``.

        Args:
            frame: Unused; kept so existing callers keep working.
            landmarks: Array with x in column 0 and y in column 1.
            scale: Enlargement factor applied to the landmark box size.
            image_size: Side length in pixels of the target square.

        Returns:
            Tuple ``(tform, tform_original)``: the transform from image
            coordinates to the ``image_size`` square, and a transform
            estimated from the source points onto themselves.
        """
        left = np.min(landmarks[:, 0])
        right = np.max(landmarks[:, 0])
        top = np.min(landmarks[:, 1])
        bottom = np.max(landmarks[:, 1])
        old_size = (right - left + bottom - top) / 2
        center = np.array([right - (right - left) / 2.0, bottom - (bottom - top) / 2.0])
        size = int(old_size * scale)
        half = size / 2
        src_pts = np.array([[center[0] - half, center[1] - half],
                            [center[0] - half, center[1] + half],
                            [center[0] + half, center[1] - half]])
        DST_PTS = np.array([[0, 0], [0, image_size - 1], [image_size - 1, 0]])
        tform = estimate_transform('similarity', src_pts, DST_PTS)
        tform_original = estimate_transform('similarity', src_pts, src_pts)
        return tform, tform_original

    def compute_landmark_relation(self, kpt_mediapipe, target_idx=473, ref_indices=None):
        """Express one landmark relative to the axis spanned by two others.

        The axis runs from ``ref_points[0]`` to ``ref_points[8]``.  The target
        is projected on that axis and on its perpendicular, measured from the
        midpoint, and both values are divided by half the axis length.

        Args:
            kpt_mediapipe: 2-D landmark array indexed by landmark id (x, y).
            target_idx: Index of the landmark to describe.
            ref_indices: Indices of the reference landmarks.  ``None`` selects
                the built-in list used for target 473.

        Returns:
            List ``[relative_xy, target_point, ref_points, ref_indices]``.
        """
        if ref_indices is None:
            # Fresh list per call: the value is returned to the caller and
            # must not alias shared state (the former mutable default did).
            ref_indices = list(self._REF_INDICES_473)
        target_point = kpt_mediapipe[target_idx]
        ref_points = kpt_mediapipe[ref_indices]
        left_corner = ref_points[0]
        right_corner = ref_points[8]
        eye_center = (left_corner + right_corner) / 2
        eye_width_vector = right_corner - left_corner
        eye_width = np.linalg.norm(eye_width_vector)
        eye_direction = eye_width_vector / eye_width
        eye_vertical = np.array([-eye_direction[1], eye_direction[0]])
        target_vector = target_point - eye_center
        x_relative = np.dot(target_vector, eye_direction) / (eye_width/2)
        y_relative = np.dot(target_vector, eye_vertical) / (eye_width/2)
        return [np.array([x_relative, y_relative]), target_point, ref_points, ref_indices]

    def process_source_image(self, image_rgb, input_size=224):
        """Run landmark detection and the encoder on the source image.

        Args:
            image_rgb: RGB image array (converted to BGR for MediaPipe).
            input_size: Side length of the square crop fed to the encoder.

        Returns:
            Tuple ``(source_outputs, tform, image_rgb)``: the encoder outputs
            with ``eye_pose_params`` replaced by the MediaPipe value, the crop
            transform and the unchanged input image.

        Raises:
            ValueError: If no facial landmarks are found.
        """
        image_bgr = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR)
        mediapipe_utils = MediaPipeUtils()
        kpt_mediapipe, _, _, mediapipe_eye_pose = mediapipe_utils.run_mediapipe(image_bgr)
        if kpt_mediapipe is None:
            raise ValueError('Cannot find facial landmarks in the source image')
        kpt_mediapipe = kpt_mediapipe[..., :2]
        tform, _ = self.crop_face(image_rgb, kpt_mediapipe, scale=1.4, image_size=input_size)
        cropped_image = warp(image_rgb, tform.inverse, output_shape=(input_size, input_size),
                             preserve_range=True).astype(np.uint8)
        cropped_image = torch.tensor(cropped_image).permute(2, 0, 1).unsqueeze(0).float() / 255.0
        cropped_image = cropped_image.to(self.device)
        with torch.no_grad():
            source_outputs = self.smirk_encoder(cropped_image)
        source_outputs['eye_pose_params'] = torch.tensor(mediapipe_eye_pose).to(self.device)
        return source_outputs, tform, image_rgb

    def smooth_params(self, data, alpha=0.7):
        """Exponentially smooth a sequence.

        ``out[0] = data[0]`` and
        ``out[i] = alpha * data[i] + (1 - alpha) * out[i - 1]``.

        Args:
            data: Sequence of values supporting ``*`` and ``+``.
            alpha: Weight of the current sample.

        Returns:
            List with the same length as ``data`` (empty for empty input).
        """
        if len(data) == 0:
            return []
        smoothed_data = [data[0]]
        for value in data[1:]:
            smoothed_data.append(alpha * value + (1 - alpha) * smoothed_data[-1])
        return smoothed_data

    def process_driving_img_list(self, img_list, input_size=224):
        """Extract encoder outputs and landmark relations for driving frames.

        Frames on which MediaPipe fails or finds no face are skipped, so all
        returned lists may be shorter than ``img_list`` (they stay aligned
        with each other).

        Args:
            img_list: Iterable of RGB frames.
            input_size: Side length of the square crop fed to the encoder.

        Returns:
            Tuple ``(driving_frames, driving_outputs, driving_tforms,
            weights_473, weights_468)``.  ``driving_frames`` holds the kept
            frames in BGR channel order.
        """
        driving_frames = []
        driving_outputs = []
        driving_tforms = []
        weights_473 = []
        weights_468 = []
        mediapipe_utils = MediaPipeUtils()
        for i, frame in enumerate(img_list):
            frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
            try:
                kpt_mediapipe, mediapipe_exp, mediapipe_pose, mediapipe_eye_pose = \
                    mediapipe_utils.run_mediapipe(frame)
            except Exception as exc:
                # Best effort: skip the frame, but do not swallow
                # KeyboardInterrupt/SystemExit and do report the real error.
                print(f'Warning: landmark extraction failed on frame {i} ({exc}), skipping this frame')
                continue
            if kpt_mediapipe is None:
                print('Warning: No face detected in a frame, skipping this frame')
                continue
            kpt_mediapipe = kpt_mediapipe[..., :2]
            weights_473.append(self.compute_landmark_relation(kpt_mediapipe))
            weights_468.append(self.compute_landmark_relation(
                kpt_mediapipe, target_idx=468, ref_indices=list(self._REF_INDICES_468)))
            tform, _ = self.crop_face(frame, kpt_mediapipe, scale=1.4, image_size=input_size)
            cropped_frame = warp(frame, tform.inverse, output_shape=(input_size, input_size),
                                 preserve_range=True).astype(np.uint8)
            cropped_frame = cv2.cvtColor(cropped_frame, cv2.COLOR_BGR2RGB)
            cropped_frame = torch.tensor(cropped_frame).permute(2, 0, 1).unsqueeze(0).float() / 255.0
            cropped_frame = cropped_frame.to(self.device)
            with torch.no_grad():
                outputs = self.smirk_encoder(cropped_frame)
            outputs['eye_pose_params'] = torch.tensor(mediapipe_eye_pose).to(self.device)
            outputs['mediapipe_exp'] = torch.tensor(mediapipe_exp).to(self.device)
            outputs['mediapipe_pose'] = torch.tensor(mediapipe_pose).to(self.device)
            driving_frames.append(frame)
            driving_outputs.append(outputs)
            driving_tforms.append(tform)
        return driving_frames, driving_outputs, driving_tforms, weights_473, weights_468

    def _render_driving_sequence(self, source_outputs, source_tform, render_shape,
                                 driving_outputs, weights_468=None, weights_473=None):
        """Render one control image per driving frame for the source face.

        Shared implementation of ``preprocess_lmk3d`` and
        ``preprocess_lmk3d_from_coef``.  Expression, jaw, eye-pose and eyelid
        parameters are copied from each driving frame; the head pose is the
        source pose composed with the driving rotation relative to the first
        driving frame (after smoothing the driving poses).

        NOTE: as in the original code, ``source_outputs`` and the dicts in
        ``driving_outputs`` are modified in place.

        Args:
            source_outputs: Encoder outputs of the source image.
            source_tform: Crop transform of the source image.
            render_shape: Shape passed to the renderer as ``shape``.
            driving_outputs: List of per-frame encoder outputs.
            weights_468: Optional per-frame values forwarded to the renderer.
            weights_473: Optional per-frame values forwarded to the renderer.

        Returns:
            List of rendered images, one per entry of ``driving_outputs``
            (empty when ``driving_outputs`` is empty).
        """
        if not driving_outputs:
            return []
        source_pose_init = source_outputs['pose_params'].clone()
        smoothed_poses = self.smooth_params([outputs['pose_params'] for outputs in driving_outputs])
        # Loop invariants: the source rotation and the inverse rotation of
        # the first driving frame (smoothing leaves element 0 unchanged).
        source_matrix = self.rodrigues_to_matrix(source_pose_init)
        driving_matrix_0_inv = torch.inverse(self.rodrigues_to_matrix(smoothed_poses[0]))
        rendered_frames = []
        for i, outputs in enumerate(driving_outputs):
            outputs['pose_params'] = smoothed_poses[i]
            source_outputs['expression_params'] = outputs['expression_params']
            source_outputs['jaw_params'] = outputs['jaw_params']
            source_outputs['eye_pose_params'] = outputs['eye_pose_params']
            source_outputs['eyelid_params'] = outputs['eyelid_params']
            driving_matrix_i = self.rodrigues_to_matrix(outputs['pose_params'])
            relative_rotation = driving_matrix_0_inv @ driving_matrix_i
            new_rotation = source_matrix @ relative_rotation
            source_outputs['pose_params'] = self.matrix_to_rodrigues(new_rotation)
            flame_output = self.flame.forward(source_outputs)
            renderer_output = self.renderer.forward(
                flame_output['vertices'],
                source_outputs['cam'],
                landmarks_fan=flame_output['landmarks_fan'], source_tform=source_tform,
                tform_512=None,
                weights_468=None if weights_468 is None else weights_468[i],
                weights_473=None if weights_473 is None else weights_473[i],
                landmarks_mp=flame_output['landmarks_mp'], shape=render_shape)
            rendered_frames.append(np.copy(renderer_output['rendered_img']))
        return rendered_frames

    def preprocess_lmk3d(self, source_image=None, driving_image_list=None):
        """Render control images from a source image and driving frames.

        Args:
            source_image: RGB source image array.
            driving_image_list: Iterable of RGB driving frames.

        Returns:
            List of rendered images, one per driving frame in which a face
            was found (empty if none was found).

        Raises:
            ValueError: If no landmarks are found in the source image.
        """
        source_outputs, source_tform, image_original = self.process_source_image(source_image)
        _, driving_outputs, _, weights_473, weights_468 = \
            self.process_driving_img_list(driving_image_list)
        return self._render_driving_sequence(
            source_outputs, source_tform, image_original.shape, driving_outputs,
            weights_468=weights_468, weights_473=weights_473)

    def preprocess_lmk3d_from_coef(self, source_outputs, source_tform, render_shape, driving_outputs):
        """Render control images from already extracted coefficients.

        Same as ``preprocess_lmk3d`` but starts from encoder outputs and
        passes no landmark-relation weights to the renderer.

        Args:
            source_outputs: Encoder outputs of the source image (modified in
                place).
            source_tform: Crop transform of the source image.
            render_shape: Shape passed to the renderer as ``shape``.
            driving_outputs: List of per-frame encoder outputs (their
                ``pose_params`` are replaced by the smoothed poses).

        Returns:
            List of rendered images, one per entry of ``driving_outputs``.
        """
        return self._render_driving_sequence(
            source_outputs, source_tform, render_shape, driving_outputs)

    def ensure_even_dimensions(self, frame):
        """Return ``frame`` resized so that width and height are both even.

        Odd dimensions are reduced by one pixel via ``cv2.resize``; frames
        that are already even are returned unchanged.
        """
        height, width = frame.shape[:2]
        new_width = width - (width % 2)
        new_height = height - (height % 2)
        if new_width != width or new_height != height:
            frame = cv2.resize(frame, (new_width, new_height))
        return frame

    def get_global_bbox(self, frames):
        """Compute one square-ish crop box covering the face in all frames.

        The union of the first detected face box of every frame is widened by
        half its width on each side, extended vertically to the same size,
        balanced so width and height agree, and finally clamped to the size
        of ``frames[0]``.

        Args:
            frames: Sequence of images handed to the face detector.

        Returns:
            Tuple ``(x1, y1, x2, y2)`` of ints.

        Raises:
            ValueError: If no face is detected in any frame (this includes an
                empty ``frames``).
        """
        min_x1, min_y1 = float('inf'), float('inf')
        max_x2, max_y2 = 0, 0
        found_face = False
        for frame in frames:
            faces = self.app.get(frame)
            if not faces:
                continue
            found_face = True
            bbox = faces[0]['bbox']
            min_x1 = min(min_x1, bbox[0])
            min_y1 = min(min_y1, bbox[1])
            max_x2 = max(max_x2, bbox[2])
            max_y2 = max(max_y2, bbox[3])
        if not found_face:
            # Without this the infinite sentinels reach int() below and fail
            # with an unhelpful OverflowError.
            raise ValueError('Cannot find a face in any of the frames')
        w = max_x2 - min_x1
        h = max_y2 - min_y1
        x1 = int(min_x1 - w / 2)
        x2 = int(max_x2 + w / 2)
        y_offset = (x2 - x1 - h) / 2
        y1 = int(min_y1 - y_offset)
        y2 = int(max_y2 + y_offset)
        if (x2 - x1) > (y2 - y1):
            x_comp = int(((x2 - x1) - (y2 - y1)) / 2)
            x1 += x_comp
            x2 -= x_comp
        else:
            y_comp = int(((y2 - y1) - (x2 - x1)) / 2)
            y1 += y_comp
            y2 -= y_comp
        x1 = max(0, x1)
        y1 = max(0, y1)
        x2 = min(frames[0].shape[1], x2)
        y2 = min(frames[0].shape[0], y2)
        return int(x1), int(y1), int(x2), int(y2)

    def face_crop_with_global_box(self, image, global_box):
        """Crop ``image`` to ``global_box`` given as ``(x1, y1, x2, y2)``."""
        x1, y1, x2, y2 = global_box
        return image[y1:y2, x1:x2]

    @staticmethod
    def _write_mp4(video_path, samples, fps=14):
        """Encode ``samples`` (list of frames) as an H.264 mp4 file."""
        clip = ImageSequenceClip(samples, fps=fps)
        clip.write_videofile(
            video_path, codec='libx264',
            ffmpeg_params=["-pix_fmt", "yuv420p", "-crf", "23", "-preset", "medium"])

    def process_video(self, source_image_path, driving_video_path, output_path, sample_size=(480, 720)):
        """Create a side-by-side preview video: source | driving | rendered.

        Args:
            source_image_path: Path (or anything ``load_image`` accepts) of
                the source image.
            driving_video_path: Path of the driving video.
            output_path: Path of the mp4 file to write.
            sample_size: ``(height, width)`` the source image is fitted to
                before face cropping.

        Raises:
            ValueError: If the driving video has no frames, or no face is
                found in the source image or in the driving frames.
        """
        image = load_image(source_image_path)
        image = self.crop_and_resize(image, sample_size[0], sample_size[1])
        ref_image, _, _ = self.face_crop(np.array(image))
        face_h, face_w, _ = ref_image.shape

        # Resample the driving video to 12 frames per second.
        vr = VideoReader(driving_video_path)
        fps = vr.get_avg_fps()
        video_length = len(vr)
        if video_length == 0:
            raise ValueError('The driving video contains no frames')
        duration = video_length / fps
        target_times = np.arange(0, duration, 1/12)
        frame_indices = (target_times * fps).astype(np.int32)
        frame_indices = frame_indices[frame_indices < video_length]
        control_frames = vr.get_batch(frame_indices).asnumpy()[:48]
        # NOTE(review): at most 48 frames survive the slice above, so this
        # branch always runs.  When at least 3 frames were read it yields 48
        # frames: the first frame doubled, the second-to-last frame dropped
        # and the last frame repeated.  Behavior kept as-is; confirm the
        # intended frame count.
        if len(control_frames) < 49:
            pad_count = 49 - len(control_frames)
            control_frames = np.concatenate(
                ([control_frames[0]] * 2,
                 control_frames[1:len(control_frames) - 2],
                 [control_frames[-1]] * pad_count),
                axis=0)

        global_box = self.get_global_bbox(control_frames)
        control_frames_crop = [
            self.face_crop_with_global_box(control_frame, global_box)
            for control_frame in control_frames
        ]

        # Same steps as preprocess_lmk3d, but the frames that were actually
        # kept are retained so every rendered image is shown next to the
        # driving frame it came from even when some frames were skipped.
        source_outputs, source_tform, image_original = self.process_source_image(ref_image)
        kept_frames_bgr, driving_outputs, _, weights_473, weights_468 = \
            self.process_driving_img_list(control_frames_crop)
        out_frames = self._render_driving_sequence(
            source_outputs, source_tform, image_original.shape, driving_outputs,
            weights_468=weights_468, weights_473=weights_473)
        if not out_frames:
            raise ValueError('Cannot find a face in any driving frame')

        concat_frames = []
        for kept_frame_bgr, rendered in zip(kept_frames_bgr, out_frames):
            # process_driving_img_list stores its frames in BGR order.
            driving_rgb = cv2.cvtColor(kept_frame_bgr, cv2.COLOR_BGR2RGB)
            driving_frame = cv2.resize(driving_rgb, (face_w, face_h))
            out_frame = cv2.resize(rendered, (face_w, face_h))
            concat_frame = np.concatenate([ref_image, driving_frame, out_frame], axis=1)
            concat_frames.append(self.ensure_even_dimensions(concat_frame))
        self._write_mp4(output_path, concat_frames, fps=12)
if __name__ == "__main__":
    # Command-line entry point: read the three paths, build the processor
    # with the SMIRK checkpoint and write the preview video.
    cli_parser = argparse.ArgumentParser(description="Process video and image for face animation.")
    option_table = (
        ('--source_image', "assets/ref_images/1.png", 'Path to the source image.'),
        ('--driving_video', "assets/driving_video/1.mp4", 'Path to the driving video.'),
        ('--output_path', "./output.mp4", 'Path to save the output video.'),
    )
    for flag, default_value, help_text in option_table:
        cli_parser.add_argument(flag, type=str, default=default_value, help=help_text)
    cli_args = cli_parser.parse_args()

    FaceAnimationProcessor(checkpoint='pretrained_models/smirk/SMIRK_em1.pt').process_video(
        cli_args.source_image,
        cli_args.driving_video,
        cli_args.output_path,
    )