import os

import pytorch_lightning as pl
import torch.optim as optim
import cv2
import torch
import hydra
import numpy as np
from kaolin.ops.mesh import index_vertices_by_faces
import trimesh

from lib.model.v2a import V2A
from lib.model.body_model_params import BodyModelParams
from lib.model.deformer import SMPLDeformer
from lib.model.loss import Loss
from lib.utils.meshing import generate_mesh
from lib.model.deformer import skinning
from lib.utils import utils


class V2AModel(pl.LightningModule):
    """Lightning wrapper around the ``V2A`` model.

    Besides the network itself, the module owns a ``BodyModelParams`` object
    holding per-frame SMPL parameters (betas, global orientation, body pose,
    translation). Those parameters are initialised from ``.npy`` files and
    optimised jointly with the network, at a tenth of its learning rate.
    """

    def __init__(self, opt) -> None:
        """Build the network, the body-model parameters and the loss.

        Args:
            opt: configuration object; this constructor reads
                ``opt.dataset.metainfo`` (``start_frame``, ``end_frame``,
                ``data_dir``, ``gender``), ``opt.model`` and
                ``opt.model.loss``.
        """
        super().__init__()

        self.opt = opt
        metainfo = opt.dataset.metainfo
        num_training_frames = metainfo.end_frame - metainfo.start_frame

        self.betas_path = os.path.join(
            hydra.utils.to_absolute_path('..'),
            'data',
            metainfo.data_dir,
            'mean_shape.npy',
        )
        self.gender = metainfo.gender
        self.model = V2A(opt.model, self.betas_path, self.gender, num_training_frames)
        self.start_frame = metainfo.start_frame
        self.end_frame = metainfo.end_frame
        self.training_modules = ["model"]

        self.training_indices = list(range(self.start_frame, self.end_frame))
        self.body_model_params = BodyModelParams(num_training_frames, model_type='smpl')
        self.load_body_model_params()

        # Parameters are initialised with requires_grad=False by
        # load_body_model_params(); switch every one of them on for training.
        for param_name in self.body_model_params.param_names:
            self.body_model_params.set_requires_grad(param_name, requires_grad=True)
        self.training_modules += ['body_model_params']

        self.loss = Loss(opt.model.loss)

    def load_body_model_params(self):
        """Initialise ``self.body_model_params`` from the dataset's ``.npy`` files.

        Files are read from ``../data/<data_dir>`` (resolved through hydra):
        ``mean_shape.npy`` -> ``betas``; ``poses.npy`` -> ``global_orient``
        (first 3 columns) and ``body_pose`` (remaining columns);
        ``normalize_trans.npy`` -> ``transl``. Pose and translation arrays are
        indexed with ``self.training_indices``.
        """
        # Every known parameter name starts with an empty list so that names
        # not filled in below are still passed to init_parameters().
        body_model_params = {param_name: [] for param_name in self.body_model_params.param_names}

        data_root = os.path.join('../data', self.opt.dataset.metainfo.data_dir)
        data_root = hydra.utils.to_absolute_path(data_root)

        # poses.npy feeds two parameters; read it from disk only once.
        poses = np.load(os.path.join(data_root, 'poses.npy'))[self.training_indices]

        body_model_params['betas'] = torch.tensor(
            np.load(os.path.join(data_root, 'mean_shape.npy'))[None], dtype=torch.float32)
        body_model_params['global_orient'] = torch.tensor(poses[:, :3], dtype=torch.float32)
        body_model_params['body_pose'] = torch.tensor(poses[:, 3:], dtype=torch.float32)
        body_model_params['transl'] = torch.tensor(
            np.load(os.path.join(data_root, 'normalize_trans.npy'))[self.training_indices],
            dtype=torch.float32)

        for param_name, value in body_model_params.items():
            self.body_model_params.init_parameters(param_name, value, requires_grad=False)

    def configure_optimizers(self):
        """Create an Adam optimizer with two parameter groups and a MultiStepLR scheduler.

        The network uses ``opt.model.learning_rate``; the body-model
        parameters use one tenth of it.

        Returns:
            ``([optimizer], [scheduler])``.
        """
        base_lr = self.opt.model.learning_rate
        params = [
            {'params': self.model.parameters(), 'lr': base_lr},
            {'params': self.body_model_params.parameters(), 'lr': base_lr * 0.1},
        ]
        self.optimizer = optim.Adam(params, lr=base_lr, eps=1e-8)
        self.scheduler = optim.lr_scheduler.MultiStepLR(
            self.optimizer,
            milestones=self.opt.model.sched_milestones,
            gamma=self.opt.model.sched_factor)
        return [self.optimizer], [self.scheduler]

    def training_step(self, batch):
        """Run one optimisation step and log every entry of the loss dict.

        Args:
            batch: ``(inputs, targets)``; ``inputs["idx"]`` selects the
                body-model parameters used for this batch.

        Returns:
            The ``"loss"`` entry of the dict returned by ``self.loss``.
        """
        inputs, targets = batch

        batch_idx = inputs["idx"]

        # Feed the model the current (learnable) SMPL parameters.
        body_model_params = self.body_model_params(batch_idx)
        inputs['smpl_pose'] = torch.cat(
            (body_model_params['global_orient'], body_model_params['body_pose']), dim=1)
        inputs['smpl_shape'] = body_model_params['betas']
        inputs['smpl_trans'] = body_model_params['transl']

        inputs['current_epoch'] = self.current_epoch
        model_outputs = self.model(inputs)

        loss_output = self.loss(model_outputs, targets)
        # All terms, including the total "loss", are logged the same way.
        for k, v in loss_output.items():
            self.log(k, v.item(), prog_bar=True, on_step=True)
        return loss_output["loss"]

    def training_epoch_end(self, outputs) -> None:
        """Refresh the cached canonical mesh every 20 epochs (skipping epoch 0)."""
        # Canonical mesh update every 20 epochs
        if self.current_epoch != 0 and self.current_epoch % 20 == 0:
            # Zero pose condition. NOTE(review): .cuda() is hard-coded here,
            # so this path requires a CUDA device.
            cond = {'smpl': torch.zeros(1, 69).float().cuda()}
            mesh_canonical = generate_mesh(
                lambda x: self.query_oc(x, cond),
                self.model.smpl_server.verts_c[0],
                point_batch=10000,
                res_up=2)
            device = self.model.smpl_v_cano.device
            self.model.mesh_v_cano = torch.tensor(
                mesh_canonical.vertices[None], device=device).float()
            self.model.mesh_f_cano = torch.tensor(
                mesh_canonical.faces.astype(np.int64), device=device)
            self.model.mesh_face_vertices = index_vertices_by_faces(
                self.model.mesh_v_cano, self.model.mesh_f_cano)
        return super().training_epoch_end(outputs)

    def query_oc(self, x, cond):
        """Evaluate the implicit network on points ``x`` and return ``{'sdf': ...}``.

        ``x`` is flattened to ``(-1, 3)``; channel 0 of the network output is
        returned reshaped to ``(-1, 1)``.
        """
        x = x.reshape(-1, 3)
        mnfld_pred = self.model.implicit_network(x, cond)[:, :, 0].reshape(-1, 1)
        return {'sdf': mnfld_pred}

    def query_wc(self, x):
        """Return the deformer's ``query_weights`` for points ``x`` flattened to ``(-1, 3)``."""
        x = x.reshape(-1, 3)
        w = self.model.deformer.query_weights(x)
        return w

    def query_od(self, x, cond, smpl_tfs, smpl_verts):
        """Evaluate the implicit network on points mapped through the inverse deformer.

        ``x`` is flattened to ``(-1, 3)``, passed through
        ``self.model.deformer.forward(..., inverse=True)``, and the first
        channel of the network output is returned as ``{'sdf': ...}``.
        """
        x = x.reshape(-1, 3)
        x_c, _ = self.model.deformer.forward(
            x, smpl_tfs, return_weights=False, inverse=True, smpl_verts=smpl_verts)
        output = self.model.implicit_network(x_c, cond)[0]
        sdf = output[:, 0:1]
        return {'sdf': sdf}

    def get_deformed_mesh_fast_mode(self, verts, smpl_tfs):
        """Skin ``verts`` with the deformer's weights and ``smpl_tfs``.

        Returns:
            The deformed vertices as a NumPy array (first batch element).
        """
        # NOTE(review): hard-coded .cuda(); requires a CUDA device.
        verts = torch.tensor(verts).cuda().float()
        weights = self.model.deformer.query_weights(verts)
        verts_deformed = skinning(verts.unsqueeze(0), weights, smpl_tfs).data.cpu().numpy()[0]
        return verts_deformed

    def validation_step(self, batch, *args, **kwargs):
        """Render one validation image in pixel chunks and extract a canonical mesh.

        Args:
            batch: ``(inputs, targets)``; ``inputs['image_id']`` selects the
                body-model parameters.

        Returns:
            Dict with ``canonical_mesh``, the merged ``rgb_values``,
            ``normal_values`` and ``fg_rgb_values``, plus every entry of
            ``targets``.
        """
        output = {}
        inputs, targets = batch
        inputs['current_epoch'] = self.current_epoch
        self.model.eval()

        body_model_params = self.body_model_params(inputs['image_id'])
        inputs['smpl_pose'] = torch.cat(
            (body_model_params['global_orient'], body_model_params['body_pose']), dim=1)
        inputs['smpl_shape'] = body_model_params['betas']
        inputs['smpl_trans'] = body_model_params['transl']

        # Pose condition: everything after the first 3 pose columns, scaled by 1/pi.
        cond = {'smpl': inputs["smpl_pose"][:, 3:] / np.pi}
        mesh_canonical = generate_mesh(
            lambda x: self.query_oc(x, cond),
            self.model.smpl_server.verts_c[0],
            point_batch=10000,
            res_up=3)
        mesh_canonical = trimesh.Trimesh(mesh_canonical.vertices, mesh_canonical.faces)

        output.update({
            'canonical_mesh': mesh_canonical
        })

        # Never ask for more pixels per chunk than the image contains.
        split = utils.split_input(
            inputs,
            targets["total_pixels"][0],
            n_pixels=min(targets['pixel_per_batch'],
                         targets["img_size"][0] * targets["img_size"][1]))

        res = []
        for s in split:
            out = self.model(s)

            # Detach tensor outputs; non-tensor entries are left untouched.
            # (Only existing keys are reassigned, so iterating is safe.)
            for k, v in out.items():
                if torch.is_tensor(v):
                    out[k] = v.detach()

            res.append({
                'rgb_values': out['rgb_values'].detach(),
                'normal_values': out['normal_values'].detach(),
                'fg_rgb_values': out['fg_rgb_values'].detach(),
            })
        batch_size = targets['rgb'].shape[0]

        model_outputs = utils.merge_output(res, targets["total_pixels"][0], batch_size)

        output.update({
            "rgb_values": model_outputs["rgb_values"].detach().clone(),
            "normal_values": model_outputs["normal_values"].detach().clone(),
            "fg_rgb_values": model_outputs["fg_rgb_values"].detach().clone(),
            **targets,
        })

        return output

    def validation_step_end(self, batch_parts):
        """Pass the validation step output through unchanged."""
        return batch_parts

    def validation_epoch_end(self, outputs) -> None:
        """Write the validation renderings and the canonical mesh to disk.

        Outputs go to ``rendering/``, ``normal/`` and ``fg_rendering/``
        (relative to the current working directory), named after the current
        epoch. Ground truth, when present, is stacked in front of the
        prediction along the first image axis.
        """
        img_size = outputs[0]["img_size"]

        rgb_pred = torch.cat([output["rgb_values"] for output in outputs], dim=0)
        rgb_pred = rgb_pred.reshape(*img_size, -1)

        fg_rgb_pred = torch.cat([output["fg_rgb_values"] for output in outputs], dim=0)
        fg_rgb_pred = fg_rgb_pred.reshape(*img_size, -1)

        normal_pred = torch.cat([output["normal_values"] for output in outputs], dim=0)
        # Shift/scale by (x + 1) / 2 before the 8-bit conversion below.
        normal_pred = (normal_pred.reshape(*img_size, -1) + 1) / 2

        rgb_gt = torch.cat([output["rgb"] for output in outputs], dim=1).squeeze(0)
        rgb_gt = rgb_gt.reshape(*img_size, -1)

        if 'normal' in outputs[0].keys():
            normal_gt = torch.cat([output["normal"] for output in outputs], dim=1).squeeze(0)
            normal_gt = (normal_gt.reshape(*img_size, -1) + 1) / 2
            normal = torch.cat([normal_gt, normal_pred], dim=0).cpu().numpy()
        else:
            normal = torch.cat([normal_pred], dim=0).cpu().numpy()

        rgb = torch.cat([rgb_gt, rgb_pred], dim=0).cpu().numpy()
        rgb = (rgb * 255).astype(np.uint8)

        fg_rgb = torch.cat([fg_rgb_pred], dim=0).cpu().numpy()
        fg_rgb = (fg_rgb * 255).astype(np.uint8)

        normal = (normal * 255).astype(np.uint8)

        os.makedirs("rendering", exist_ok=True)
        os.makedirs("normal", exist_ok=True)
        os.makedirs('fg_rendering', exist_ok=True)

        canonical_mesh = outputs[0]['canonical_mesh']
        canonical_mesh.export(f"rendering/{self.current_epoch}.ply")

        # The channel axis is reversed before writing (cv2.imwrite stores
        # images in BGR order; the arrays are presumably RGB).
        cv2.imwrite(f"rendering/{self.current_epoch}.png", rgb[:, :, ::-1])
        cv2.imwrite(f"normal/{self.current_epoch}.png", normal[:, :, ::-1])
        cv2.imwrite(f"fg_rendering/{self.current_epoch}.png", fg_rgb[:, :, ::-1])

    def test_step(self, batch, *args, **kwargs):
        """Render one test frame and export its meshes and images.

        Args:
            batch: ``(inputs, targets, pixel_per_batch, total_pixels, idx)``.
                ``inputs["smpl_params"]`` is split along dim 1 into chunks of
                size 1, 3, 72 and 10; only the first chunk (``scale``) is
                used, the rest is replaced by the learned body-model
                parameters selected by ``inputs['idx']``.

        Writes ``test_mesh/``, ``test_mask/``, ``test_rendering/``,
        ``test_normal/`` and ``test_fg_rendering/`` files named after ``idx``
        (zero-padded to 4 digits), relative to the current working directory.
        """
        inputs, targets, pixel_per_batch, total_pixels, idx = batch
        # Ceiling division: number of pixel chunks needed for the image.
        num_splits = (total_pixels + pixel_per_batch - 1) // pixel_per_batch
        results = []

        # Converted once; used in every output file name below.
        frame_id = int(idx.cpu().numpy())

        # Only `scale` is kept from the dataset-provided SMPL parameters.
        scale, *_ = torch.split(inputs["smpl_params"], [1, 3, 72, 10], dim=1)

        body_model_params = self.body_model_params(inputs['idx'])
        betas = body_model_params['betas']
        smpl_shape = betas if betas.dim() == 2 else betas.unsqueeze(0)
        smpl_trans = body_model_params['transl']
        smpl_pose = torch.cat(
            (body_model_params['global_orient'], body_model_params['body_pose']), dim=1)

        smpl_outputs = self.model.smpl_server(scale, smpl_trans, smpl_pose, smpl_shape)
        smpl_tfs = smpl_outputs['smpl_tfs']
        cond = {'smpl': smpl_pose[:, 3:] / np.pi}

        mesh_canonical = generate_mesh(
            lambda x: self.query_oc(x, cond),
            self.model.smpl_server.verts_c[0],
            point_batch=10000,
            res_up=4)

        # The mesh is deformed with a deformer built with K=7 (see
        # SMPLDeformer for the meaning of K) ...
        self.model.deformer = SMPLDeformer(betas=np.load(self.betas_path), gender=self.gender, K=7)
        verts_deformed = self.get_deformed_mesh_fast_mode(mesh_canonical.vertices, smpl_tfs)
        mesh_deformed = trimesh.Trimesh(
            vertices=verts_deformed, faces=mesh_canonical.faces, process=False)

        os.makedirs("test_mask", exist_ok=True)
        os.makedirs("test_rendering", exist_ok=True)
        os.makedirs("test_fg_rendering", exist_ok=True)
        os.makedirs("test_normal", exist_ok=True)
        os.makedirs("test_mesh", exist_ok=True)

        mesh_canonical.export(f"test_mesh/{frame_id:04d}_canonical.ply")
        mesh_deformed.export(f"test_mesh/{frame_id:04d}_deformed.ply")

        # ... and rendering below uses a freshly built default-K deformer.
        self.model.deformer = SMPLDeformer(betas=np.load(self.betas_path), gender=self.gender)

        for i in range(num_splits):
            indices = list(range(i * pixel_per_batch,
                                 min((i + 1) * pixel_per_batch, total_pixels)))
            # The SMPL entries come from the learned body-model parameters;
            # they do not depend on the chunk, so they are computed once above.
            batch_inputs = {"uv": inputs["uv"][:, indices],
                            "intrinsics": inputs['intrinsics'],
                            "pose": inputs['pose'],
                            "smpl_params": inputs["smpl_params"],
                            "smpl_pose": smpl_pose,
                            "smpl_shape": betas,
                            "smpl_trans": smpl_trans,
                            "idx": inputs["idx"]}

            batch_targets = {"rgb": targets["rgb"][:, indices].detach().clone() if 'rgb' in targets.keys() else None,
                             "img_size": targets["img_size"]}

            with torch.no_grad():
                model_outputs = self.model(batch_inputs)
            results.append({"rgb_values": model_outputs["rgb_values"].detach().clone(),
                            "fg_rgb_values": model_outputs["fg_rgb_values"].detach().clone(),
                            "normal_values": model_outputs["normal_values"].detach().clone(),
                            "acc_map": model_outputs["acc_map"].detach().clone(),
                            **batch_targets})

        img_size = results[0]["img_size"]

        rgb_pred = torch.cat([result["rgb_values"] for result in results], dim=0)
        rgb_pred = rgb_pred.reshape(*img_size, -1)

        fg_rgb_pred = torch.cat([result["fg_rgb_values"] for result in results], dim=0)
        fg_rgb_pred = fg_rgb_pred.reshape(*img_size, -1)

        normal_pred = torch.cat([result["normal_values"] for result in results], dim=0)
        normal_pred = (normal_pred.reshape(*img_size, -1) + 1) / 2

        pred_mask = torch.cat([result["acc_map"] for result in results], dim=0)
        pred_mask = pred_mask.reshape(*img_size, -1)

        if results[0]['rgb'] is not None:
            rgb_gt = torch.cat([result["rgb"] for result in results], dim=1).squeeze(0)
            rgb_gt = rgb_gt.reshape(*img_size, -1)
            rgb = torch.cat([rgb_gt, rgb_pred], dim=0).cpu().numpy()
        else:
            rgb = torch.cat([rgb_pred], dim=0).cpu().numpy()

        # batch_targets above only carries 'rgb' and 'img_size', so this
        # branch is taken only if a 'normal' target is added there.
        if 'normal' in results[0].keys():
            normal_gt = torch.cat([result["normal"] for result in results], dim=1).squeeze(0)
            normal_gt = (normal_gt.reshape(*img_size, -1) + 1) / 2
            normal = torch.cat([normal_gt, normal_pred], dim=0).cpu().numpy()
        else:
            normal = torch.cat([normal_pred], dim=0).cpu().numpy()

        rgb = (rgb * 255).astype(np.uint8)

        fg_rgb = torch.cat([fg_rgb_pred], dim=0).cpu().numpy()
        fg_rgb = (fg_rgb * 255).astype(np.uint8)

        normal = (normal * 255).astype(np.uint8)

        cv2.imwrite(f"test_mask/{frame_id:04d}.png", pred_mask.cpu().numpy() * 255)
        # Channel axis reversed before writing (cv2.imwrite stores BGR;
        # the arrays are presumably RGB).
        cv2.imwrite(f"test_rendering/{frame_id:04d}.png", rgb[:, :, ::-1])
        cv2.imwrite(f"test_normal/{frame_id:04d}.png", normal[:, :, ::-1])
        cv2.imwrite(f"test_fg_rendering/{frame_id:04d}.png", fg_rgb[:, :, ::-1])