import os
import sys
import pickle

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import cv2
import chumpy as ch
from pytorch3d.structures import Meshes
from pytorch3d.io import load_obj
from pytorch3d.renderer.mesh import rasterize_meshes
from mediapipe.framework.formats import landmark_pb2

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from skyreels_a1.src.utils.mediapipe_utils import face_vertices, vertex_normals, batch_orth_proj
from skyreels_a1.src.media_pipe.draw_util import FaceMeshVisualizer


def keep_vertices_and_update_faces(faces, vertices_to_keep):
    """Keep the specified vertices of a mesh and re-index its faces accordingly."""
    if isinstance(vertices_to_keep, (list, np.ndarray)):
        vertices_to_keep = torch.tensor(vertices_to_keep, dtype=torch.long)
    vertices_to_keep = torch.unique(vertices_to_keep)

    max_vertex_index = faces.max().long().item() + 1
    mask = torch.zeros(max_vertex_index, dtype=torch.bool)
    mask[vertices_to_keep] = True

    # Old index -> new index; -1 marks vertices that are dropped.
    new_vertex_indices = torch.full((max_vertex_index,), -1, dtype=torch.long)
    new_vertex_indices[mask] = torch.arange(len(vertices_to_keep))

    # Keep only the faces whose three corners all survive, then re-index them.
    valid_faces_mask = (new_vertex_indices[faces] != -1).all(dim=1)
    filtered_faces = faces[valid_faces_mask]
    updated_faces = new_vertex_indices[filtered_faces]
    return updated_faces


def predict_landmark_position(ref_points, relative_coords):
    """Predict the new eyeball position from reference points and relative coordinates."""
    left_corner = ref_points[0]
    right_corner = ref_points[8]
    eye_center = (left_corner + right_corner) / 2
    eye_width_vector = right_corner - left_corner
    eye_width = np.linalg.norm(eye_width_vector)
    eye_direction = eye_width_vector / eye_width
    eye_vertical = np.array([-eye_direction[1], eye_direction[0]])
    predicted_pos = eye_center + \
        (eye_width / 2) * relative_coords[0] * eye_direction + \
        (eye_width / 2) * relative_coords[1] * eye_vertical
    return predicted_pos


def mesh_points_by_barycentric_coordinates(mesh_verts, mesh_faces, lmk_face_idx, lmk_b_coords):
    """Evaluate 3D landmark points on a mesh from a barycentric landmark embedding."""
    dif1 = ch.vstack([
        (mesh_verts[mesh_faces[lmk_face_idx], 0] * lmk_b_coords).sum(axis=1),
        (mesh_verts[mesh_faces[lmk_face_idx], 1] * lmk_b_coords).sum(axis=1),
        (mesh_verts[mesh_faces[lmk_face_idx], 2] * lmk_b_coords).sum(axis=1)
    ]).T
    return dif1
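
# The chumpy-based helper above mirrors the FLAME landmark-embedding code. Since it is
# only ever called in this module with plain NumPy arrays, a NumPy-only equivalent is
# sketched below for reference (assumption of this sketch: no chumpy autodiff is needed).
# It is not used by the Renderer class.
def mesh_points_by_barycentric_coordinates_np(mesh_verts, mesh_faces, lmk_face_idx, lmk_b_coords):
    # Corner vertices of each embedding face: (n_landmarks, 3 corners, 3 coords).
    corners = mesh_verts[mesh_faces[lmk_face_idx]]
    # Weight each corner by its barycentric coordinate and sum over the corners.
    return (corners * lmk_b_coords[..., None]).sum(axis=1)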
class Renderer(nn.Module):
    def __init__(self, render_full_head=False, obj_filename='pretrained_models/FLAME/head_template.obj'):
        super(Renderer, self).__init__()
        self.image_size = 224
        self.mediapipe_landmark_embedding = np.load("pretrained_models/smirk/mediapipe_landmark_embedding.npz")
        self.vis = FaceMeshVisualizer(forehead_edge=False)

        verts, faces, aux = load_obj(obj_filename)
        uvcoords = aux.verts_uvs[None, ...]        # (N, V, 2)
        uvfaces = faces.textures_idx[None, ...]    # (N, F, 3)
        faces = faces.verts_idx[None, ...]
        self.render_full_head = render_full_head

        red_color = torch.tensor([255, 0, 0])[None, None, :].float() / 255.
        transparent_color = torch.tensor([0, 0, 0])[None, None, :].float()
        colors = transparent_color.repeat(1, 5023, 1)  # FLAME template has 5023 vertices

        flame_masks = pickle.load(
            open('pretrained_models/FLAME/FLAME_masks.pkl', 'rb'), encoding='latin1')
        self.flame_masks = flame_masks

        self.register_buffer('faces', faces)
        face_colors = face_vertices(colors, faces)
        self.register_buffer('face_colors', face_colors)

        self.register_buffer('raw_uvcoords', uvcoords)
        uvcoords = torch.cat([uvcoords, uvcoords[:, :, 0:1] * 0. + 1.], -1)  # [bz, ntv, 3]
        uvcoords = uvcoords * 2 - 1
        uvcoords[..., 1] = -uvcoords[..., 1]
        face_uvcoords = face_vertices(uvcoords, uvfaces)
        self.register_buffer('uvcoords', uvcoords)
        self.register_buffer('uvfaces', uvfaces)
        self.register_buffer('face_uvcoords', face_uvcoords)

        # Constant factors of the order-2 spherical-harmonics basis used in add_SHlight.
        pi = np.pi
        constant_factor = torch.tensor(
            [1 / np.sqrt(4 * pi),
             ((2 * pi) / 3) * (np.sqrt(3 / (4 * pi))),
             ((2 * pi) / 3) * (np.sqrt(3 / (4 * pi))),
             ((2 * pi) / 3) * (np.sqrt(3 / (4 * pi))),
             (pi / 4) * (3) * (np.sqrt(5 / (12 * pi))),
             (pi / 4) * (3) * (np.sqrt(5 / (12 * pi))),
             (pi / 4) * (3) * (np.sqrt(5 / (12 * pi))),
             (pi / 4) * (3 / 2) * (np.sqrt(5 / (12 * pi))),
             (pi / 4) * (1 / 2) * (np.sqrt(5 / (4 * pi)))]).float()
        self.register_buffer('constant_factor', constant_factor)

    def forward(self, vertices, cam_params, source_tform=None, tform_512=None,
                weights_468=None, weights_473=None, shape=None, **landmarks):
        transformed_vertices = batch_orth_proj(vertices, cam_params)
        transformed_vertices[:, :, 1:] = -transformed_vertices[:, :, 1:]

        transformed_landmarks = {}
        for key in landmarks.keys():
            transformed_landmarks[key] = batch_orth_proj(landmarks[key], cam_params)
            transformed_landmarks[key][:, :, 1:] = -transformed_landmarks[key][:, :, 1:]
            transformed_landmarks[key] = transformed_landmarks[key][..., :2]

        if weights_468 is None:
            rendered_img = self.render_with_pulid_in_vertices(
                vertices, transformed_vertices, source_tform, tform_512, shape)
        else:
            rendered_img = self.render(
                vertices, transformed_vertices, source_tform, tform_512, weights_468, weights_473, shape)

        outputs = {
            'rendered_img': rendered_img,
            'transformed_vertices': transformed_vertices
        }
        outputs.update(transformed_landmarks)
        return outputs

    def _calculate_eye_landmarks(self, landmark_list_pixlevel, weights_468, weights_473, source_tform):
        # weights layout: [np.array([x_relative, y_relative]), target_point, ref_points].
        # Intended to compute the eye landmarks from the current new_landmarks and the
        # target_point via a mapping transform. Not implemented here; render() below
        # performs the equivalent computation with an affine fit.
        pass
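    def _estimate_pupil_position_sketch(self, ref_points_src, ref_points_dst, pupil_src):
        """
        Hypothetical helper, not called by the original pipeline: a minimal sketch of
        the mapping described in the stub above. It fits an affine transform from the
        reference eye contour to its current position and applies it to the pupil
        point, mirroring the cv2.estimateAffine2D / cv2.transform logic used in render().
        """
        M, _ = cv2.estimateAffine2D(
            np.asarray(ref_points_src, dtype=np.float32),
            np.asarray(ref_points_dst, dtype=np.float32))
        pupil = np.asarray(pupil_src, dtype=np.float32).reshape(1, 1, 2)
        return cv2.transform(pupil, M).reshape(-1)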
    def render_with_pulid_in_vertices(self, vertices, transformed_vertices, source_tform, tform_512, shape):
        batch_size = vertices.shape[0]

        # Rasterizer near/far planes are 0/100, so move the mesh so that min z > 0.
        transformed_vertices[:, :, 2] = transformed_vertices[:, :, 2] + 10

        # Only vertex colors are used as attributes.
        colors = self.face_colors.expand(batch_size, -1, -1, -1)

        # Load the mediapipe landmark embedding data.
        lmk_b_coords = self.mediapipe_landmark_embedding['lmk_b_coords']
        lmk_face_idx = self.mediapipe_landmark_embedding['lmk_face_idx']

        # Compute v_selected from the barycentric landmark embedding.
        v_selected = mesh_points_by_barycentric_coordinates(
            transformed_vertices.detach().cpu().numpy()[0],
            self.faces.detach().cpu().numpy()[0],
            lmk_face_idx, lmk_b_coords)
        # Append 8 extra eye-contour positions taken from transformed_vertices
        # (right eye: [4543, 4511, 4479, 4575], left eye: [3997, 3965, 3933, 4020]);
        # vertices 4597 / 4051 are handled separately below as the eye centers.
        v_selected = np.concatenate([
            v_selected,
            transformed_vertices.detach().cpu().numpy()[0][[4543, 4511, 4479, 4575]],
            transformed_vertices.detach().cpu().numpy()[0][[3997, 3965, 3933, 4020]]
        ], axis=0)
        v_selected_tensor = torch.tensor(
            np.array(v_selected), dtype=torch.float32).to(transformed_vertices.device)

        new_landmarks = landmark_pb2.NormalizedLandmarkList()
        for v in v_selected_tensor:
            # Map v from NDC to crop-image coordinates, then back to the source image.
            img_x = (v[0] + 1) * 0.5 * self.image_size
            img_y = (v[1] + 1) * 0.5 * self.image_size
            point = np.array([img_x.cpu().numpy(), img_y.cpu().numpy(), 1.0])
            croped_point = np.dot(source_tform.inverse.params, point)
            # original_point = np.dot(tform_512.inverse.params, point)
            landmark = new_landmarks.landmark.add()
            landmark.x = croped_point[0] / shape[1]
            landmark.y = croped_point[1] / shape[0]
            landmark.z = 1.0

        # Map the eye-center vertices to source-image coordinates.
        right_eye_x = (transformed_vertices[0, 4597, 0] + 1) * 0.5 * self.image_size
        right_eye_y = (transformed_vertices[0, 4597, 1] + 1) * 0.5 * self.image_size
        right_eye_point = np.array([right_eye_x.cpu().numpy(), right_eye_y.cpu().numpy(), 1.0])
        right_eye_original = np.dot(source_tform.inverse.params, right_eye_point)
        right_eye_landmarks = right_eye_original[:2]

        left_eye_x = (transformed_vertices[0, 4051, 0] + 1) * 0.5 * self.image_size
        left_eye_y = (transformed_vertices[0, 4051, 1] + 1) * 0.5 * self.image_size
        left_eye_point = np.array([left_eye_x.cpu().numpy(), left_eye_y.cpu().numpy(), 1.0])
        left_eye_original = np.dot(source_tform.inverse.params, left_eye_point)
        left_eye_landmarks = left_eye_original[:2]

        image_new = np.zeros([shape[0], shape[1], 3], dtype=np.uint8)
        self.vis.mp_drawing.draw_landmarks(
            image=image_new,
            landmark_list=new_landmarks,
            connections=self.vis.face_connection_spec.keys(),
            landmark_drawing_spec=None,
            connection_drawing_spec=self.vis.face_connection_spec)

        # Paint the eye-center points directly as small patches.
        left_point = (int(left_eye_landmarks[0]), int(left_eye_landmarks[1]))
        right_point = (int(right_eye_landmarks[0]), int(right_eye_landmarks[1]))
        # Left eye point: 3x3 patch (RGB order).
        image_new[left_point[1] - 1:left_point[1] + 2, left_point[0] - 1:left_point[0] + 2] = [180, 200, 10]
        # Right eye point: 3x3 patch.
        image_new[right_point[1] - 1:right_point[1] + 2, right_point[0] - 1:right_point[0] + 2] = [10, 200, 180]

        landmark_58 = new_landmarks.landmark[57]  # 0-based index, so 57 is the 58th landmark
        x = int(landmark_58.x * shape[1])
        y = int(landmark_58.y * shape[0])
        image_new[y - 2:y + 3, x - 2:x + 3] = [255, 255, 255]  # 5x5 white patch

        return np.copy(image_new)
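    # Both rendering paths map a vertex from the orthographic NDC range [-1, 1] to the
    # 224x224 crop, then back to the source image through the inverse crop transform.
    # The helper below is a hypothetical illustration of that chain (not called by the
    # original code); it assumes source_tform.inverse.params is the 3x3 matrix of the
    # crop transform (e.g. a skimage SimilarityTransform).
    def _project_to_source_image_sketch(self, v, source_tform):
        img_x = (float(v[0]) + 1.0) * 0.5 * self.image_size    # NDC x -> crop pixel x
        img_y = (float(v[1]) + 1.0) * 0.5 * self.image_size    # NDC y -> crop pixel y
        point = np.array([img_x, img_y, 1.0])                  # homogeneous coordinates
        return np.dot(source_tform.inverse.params, point)[:2]  # crop -> source image pixels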
    def render(self, vertices, transformed_vertices, source_tform, tform_512, weights_468, weights_473, shape):
        # Move the mesh so that min z > 0 for the rasterizer (Z-axis offset).
        transformed_vertices[:, :, 2] += 10

        v_selected = self._calculate_landmark_points(transformed_vertices)
        v_selected_tensor = torch.tensor(
            v_selected, dtype=torch.float32, device=transformed_vertices.device)  # torch.Size([113, 3])
        new_landmarks, landmark_list_pixlevel = self._create_landmark_list(v_selected_tensor, source_tform, shape)

        # Compute the eye landmarks from weights_468 (left) and weights_473 (right):
        # weights_*[1] is the pupil target point, weights_*[2] the reference eye contour
        # (shape (16, 2)), and weights_*[3] the mediapipe indices of that contour.
        left_eye_point_indices = weights_468[3]
        right_eye_point_indices = weights_473[3]
        # Map each mediapipe index to its position in the visualizer's index_mapping.
        left_eye_point_indices = [self.vis.index_mapping.index(idx) for idx in left_eye_point_indices]
        right_eye_point_indices = [self.vis.index_mapping.index(idx) for idx in right_eye_point_indices]
        left_eye_point = [landmark_list_pixlevel[idx] for idx in left_eye_point_indices]
        right_eye_point = [landmark_list_pixlevel[idx] for idx in right_eye_point_indices]

        # Fit an affine transform from each reference eye contour to its current position.
        M_affine_left, _ = cv2.estimateAffine2D(
            np.array(weights_468[2], dtype=np.float32), np.array(left_eye_point, dtype=np.float32))
        M_affine_right, _ = cv2.estimateAffine2D(
            np.array(weights_473[2], dtype=np.float32), np.array(right_eye_point, dtype=np.float32))

        # Map the pupil points through the fitted transforms.
        pupil_left_eye = cv2.transform(weights_468[1].reshape(1, 1, 2), M_affine_left).reshape(-1)
        pupil_right_eye = cv2.transform(weights_473[1].reshape(1, 1, 2), M_affine_right).reshape(-1)

        # Alternative eye-point sources (kept for reference):
        # left_eye_point, right_eye_point = self._calculate_eye_landmarks(landmark_list_pixlevel, weights_468, weights_473, source_tform)
        # left_eye_point, right_eye_point = self._process_eye_landmarks(transformed_vertices, source_tform)
        # return self._generate_final_image(new_landmarks, left_eye_point, right_eye_point, shape)

        return self._generate_final_image(new_landmarks, pupil_left_eye, pupil_right_eye, shape)

    def _calculate_landmark_points(self, transformed_vertices):
        lmk_b_coords = self.mediapipe_landmark_embedding['lmk_b_coords']
        lmk_face_idx = self.mediapipe_landmark_embedding['lmk_face_idx']

        base_points = mesh_points_by_barycentric_coordinates(
            transformed_vertices.detach().cpu().numpy()[0],
            self.faces.detach().cpu().numpy()[0],
            lmk_face_idx, lmk_b_coords
        )

        RIGHT_EYE_INDICES = [4543, 4511, 4479, 4575]
        LEFT_EYE_INDICES = [3997, 3965, 3933, 4020]
        return np.concatenate([
            base_points,
            transformed_vertices.detach().cpu().numpy()[0][RIGHT_EYE_INDICES],
            transformed_vertices.detach().cpu().numpy()[0][LEFT_EYE_INDICES]
        ], axis=0)

    def _create_landmark_list(self, vertices, transform, shape):
        landmark_list = landmark_pb2.NormalizedLandmarkList()
        landmark_list_pixlevel = []

        for v in vertices:
            img_x = (v[0] + 1) * 0.5 * self.image_size
            img_y = (v[1] + 1) * 0.5 * self.image_size
            projected = np.dot(transform.inverse.params, [img_x.cpu().numpy(), img_y.cpu().numpy(), 1.0])
            landmark_list_pixlevel.append((projected[0], projected[1]))

            landmark = landmark_list.landmark.add()
            landmark.x = projected[0] / shape[1]
            landmark.y = projected[1] / shape[0]
            landmark.z = 1.0

        return landmark_list, landmark_list_pixlevel

    def _process_eye_landmarks(self, vertices, transform):
        def project_eye_point(vertex_idx):
            x = (vertices[0, vertex_idx, 0] + 1) * 0.5 * self.image_size
            y = (vertices[0, vertex_idx, 1] + 1) * 0.5 * self.image_size
            projected = np.dot(transform.inverse.params, [x.cpu().numpy(), y.cpu().numpy(), 1.0])
            return (int(projected[0]), int(projected[1]))

        return (
            project_eye_point(4051),  # Left eye index
            project_eye_point(4597)   # Right eye index
        )
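    # Hypothetical micro-optimization sketch (not used by the original code): render()
    # resolves each mediapipe index with list.index(), an O(n) scan per lookup; building
    # a dict once gives the same mapping with O(1) lookups.
    def _mediapipe_index_lookup_sketch(self):
        return {mp_idx: pos for pos, mp_idx in enumerate(self.vis.index_mapping)}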
    def _generate_final_image(self, landmarks, left_eye, right_eye, shape):
        image = np.zeros([shape[0], shape[1], 3], dtype=np.uint8)
        self.vis.mp_drawing.draw_landmarks(
            image=image,
            landmark_list=landmarks,
            connections=self.vis.face_connection_spec.keys(),
            landmark_drawing_spec=None,
            connection_drawing_spec=self.vis.face_connection_spec
        )
        self._draw_eye_markers(image, np.array(left_eye, dtype=np.int32), np.array(right_eye, dtype=np.int32))
        self._draw_landmark_58(image, landmarks, shape)
        return np.copy(image)

    def _draw_eye_markers(self, image, left_eye, right_eye):
        # Paint a 3x3 patch at each eye/pupil point.
        y, x = left_eye[1] - 1, left_eye[0] - 1
        image[y:y + 3, x:x + 3] = [10, 200, 250]
        y, x = right_eye[1] - 1, right_eye[0] - 1
        image[y:y + 3, x:x + 3] = [250, 200, 10]

    def _draw_landmark_58(self, image, landmarks, shape):
        # Mark the 58th landmark (0-based index 57) with a 5x5 white patch.
        if len(landmarks.landmark) > 57:
            point = landmarks.landmark[57]
            x = int(point.x * shape[1])
            y = int(point.y * shape[0])
            image[y - 2:y + 3, x - 2:x + 3] = [255, 255, 255]

    def rasterize(self, vertices, faces, attributes=None, h=None, w=None):
        fixed_vertices = vertices.clone()
        fixed_vertices[..., :2] = -fixed_vertices[..., :2]

        if h is None and w is None:
            image_size = self.image_size
        else:
            image_size = [h, w]
            if h > w:
                fixed_vertices[..., 1] = fixed_vertices[..., 1] * h / w
            else:
                fixed_vertices[..., 0] = fixed_vertices[..., 0] * w / h

        meshes_screen = Meshes(verts=fixed_vertices.float(), faces=faces.long())
        pix_to_face, zbuf, bary_coords, dists = rasterize_meshes(
            meshes_screen,
            image_size=image_size,
            blur_radius=0.0,
            faces_per_pixel=1,
            bin_size=None,
            max_faces_per_bin=None,
            perspective_correct=False,
        )
        vismask = (pix_to_face > -1).float()

        # Interpolate the per-face attributes with the barycentric coordinates.
        D = attributes.shape[-1]
        attributes = attributes.clone()
        attributes = attributes.view(attributes.shape[0] * attributes.shape[1], 3, attributes.shape[-1])
        N, H, W, K, _ = bary_coords.shape
        mask = pix_to_face == -1
        pix_to_face = pix_to_face.clone()
        pix_to_face[mask] = 0
        idx = pix_to_face.view(N * H * W * K, 1, 1).expand(N * H * W * K, 3, D)
        pixel_face_vals = attributes.gather(0, idx).view(N, H, W, K, 3, D)
        pixel_vals = (bary_coords[..., None] * pixel_face_vals).sum(dim=-2)
        pixel_vals[mask] = 0  # Replace masked values in output.
        pixel_vals = pixel_vals[:, :, :, 0].permute(0, 3, 1, 2)
        pixel_vals = torch.cat([pixel_vals, vismask[:, :, :, 0][:, None, :, :]], dim=1)
        return pixel_vals

    def add_SHlight(self, normal_images, sh_coeff):
        '''
        sh_coeff: [bz, 9, 3]
        '''
        N = normal_images
        sh = torch.stack([
            N[:, 0] * 0. + 1., N[:, 0], N[:, 1],
            N[:, 2], N[:, 0] * N[:, 1], N[:, 0] * N[:, 2],
            N[:, 1] * N[:, 2], N[:, 0] ** 2 - N[:, 1] ** 2, 3 * (N[:, 2] ** 2) - 1
        ], 1)  # [bz, 9, h, w]
        sh = sh * self.constant_factor[None, :, None, None]
        # Sum over the 9 SH bands: [bz, 9, 3, h, w] -> [bz, 3, h, w]
        shading = torch.sum(sh_coeff[:, :, :, None, None] * sh[:, :, None, :, :], 1)
        return shading
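    # Shape note for add_SHlight (inferred from the code above): normal_images is
    # expected as per-pixel normals of shape [bz, 3, h, w], sh_coeff as [bz, 9, 3] RGB
    # spherical-harmonic coefficients, and the returned shading has shape [bz, 3, h, w].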
    def add_pointlight(self, vertices, normals, lights):
        '''
        vertices: [bz, nv, 3]
        lights: [bz, nlight, 6]
        returns:
            shading: [bz, nv, 3]
        '''
        light_positions = lights[:, :, :3]
        light_intensities = lights[:, :, 3:]
        directions_to_lights = F.normalize(light_positions[:, :, None, :] - vertices[:, None, :, :], dim=3)
        normals_dot_lights = (normals[:, None, :, :] * directions_to_lights).sum(dim=3)
        shading = normals_dot_lights[:, :, :, None] * light_intensities[:, :, None, :]
        return shading.mean(1)

    def add_directionlight(self, normals, lights):
        '''
        normals: [bz, nv, 3]
        lights: [bz, nlight, 6]
        returns:
            shading: [bz, nv, 3]
        '''
        light_direction = lights[:, :, :3]
        light_intensities = lights[:, :, 3:]
        directions_to_lights = F.normalize(
            light_direction[:, :, None, :].expand(-1, -1, normals.shape[1], -1), dim=3)
        # Lambertian term, clamped to [0, 1].
        normals_dot_lights = torch.clamp((normals[:, None, :, :] * directions_to_lights).sum(dim=3), 0., 1.)
        shading = normals_dot_lights[:, :, :, None] * light_intensities[:, :, None, :]
        return shading.mean(1)

    def render_multiface(self, vertices, transformed_vertices, faces):
        batch_size = vertices.shape[0]

        # Five fixed lights around the face, used as directions by add_directionlight.
        light_positions = torch.tensor(
            [
                [-1, -1, -1],
                [+1, -1, -1],
                [-1, +1, -1],
                [+1, +1, -1],
                [0, 0, -1]
            ]
        )[None, :, :].expand(batch_size, -1, -1).float()
        light_intensities = torch.ones_like(light_positions).float() * 1.7
        lights = torch.cat((light_positions, light_intensities), 2).to(vertices.device)

        # Move the mesh so that min z > 0 for the rasterizer.
        transformed_vertices[:, :, 2] = transformed_vertices[:, :, 2] + 10

        normals = vertex_normals(vertices, faces)
        face_normals = face_vertices(normals, faces)

        colors = torch.tensor([180, 180, 180])[None, None, :].repeat(
            1, transformed_vertices.shape[1] + 1, 1).float() / 255.
        colors = colors.to(vertices.device)
        face_colors = face_vertices(colors, faces[0].unsqueeze(0))
        colors = face_colors.expand(batch_size, -1, -1, -1)

        attributes = torch.cat([colors, face_normals], -1)
        rendering = self.rasterize(transformed_vertices, faces, attributes)

        albedo_images = rendering[:, :3, :, :]
        normal_images = rendering[:, 3:6, :, :]

        shading = self.add_directionlight(
            normal_images.permute(0, 2, 3, 1).reshape([batch_size, -1, 3]), lights)
        shading_images = shading.reshape(
            [batch_size, albedo_images.shape[2], albedo_images.shape[3], 3]).permute(0, 3, 1, 2).contiguous()
        shaded_images = albedo_images * shading_images
        return shaded_images
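
if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the original module. Assumptions: the
    # pretrained FLAME/smirk assets exist at the default paths, the driving mesh has
    # FLAME's 5023 vertices, and cam_params follows the [scale, tx, ty] orthographic
    # convention expected by batch_orth_proj. The identity transform below is only a
    # stand-in for the real face-crop transform, which needs to expose .inverse.params
    # as a 3x3 matrix (e.g. a skimage SimilarityTransform).
    from types import SimpleNamespace

    identity_tform = SimpleNamespace(inverse=SimpleNamespace(params=np.eye(3)))
    renderer = Renderer()
    vertices = torch.zeros(1, 5023, 3)            # placeholder FLAME vertices
    cam_params = torch.tensor([[8.0, 0.0, 0.0]])  # assumed [scale, tx, ty]
    outputs = renderer(vertices, cam_params,
                       source_tform=identity_tform,
                       shape=(512, 512))
    print(outputs['rendered_img'].shape)          # expected: (512, 512, 3)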