import os

import cv2
import hydra
import numpy as np
import pytorch_lightning as pl
import torch
import torch.optim as optim
import trimesh
from kaolin.ops.mesh import index_vertices_by_faces

from lib.model.v2a import V2A
from lib.model.body_model_params import BodyModelParams
from lib.model.deformer import SMPLDeformer, skinning
from lib.model.loss import Loss
from lib.utils import utils
from lib.utils.meshing import generate_mesh

class V2AModel(pl.LightningModule):
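    """Lightning module for training and evaluating the V2A avatar model.

    Jointly optimizes the implicit networks and the per-frame SMPL body
    parameters, and periodically extracts and exports meshes and renderings.
    """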
    def __init__(self, opt) -> None:
        super().__init__()

        self.opt = opt
        num_training_frames = opt.dataset.metainfo.end_frame - opt.dataset.metainfo.start_frame
        self.betas_path = os.path.join(hydra.utils.to_absolute_path('..'), 'data', opt.dataset.metainfo.data_dir, 'mean_shape.npy')
        self.gender = opt.dataset.metainfo.gender
        self.model = V2A(opt.model, self.betas_path, self.gender, num_training_frames)
        self.start_frame = opt.dataset.metainfo.start_frame
        self.end_frame = opt.dataset.metainfo.end_frame
        self.training_modules = ["model"]

        self.training_indices = list(range(self.start_frame, self.end_frame))
        self.body_model_params = BodyModelParams(num_training_frames, model_type='smpl')
        self.load_body_model_params()
        for param_name in self.body_model_params.param_names:
            self.body_model_params.set_requires_grad(param_name, requires_grad=True)
        self.training_modules += ['body_model_params']

        self.loss = Loss(opt.model.loss)

    def load_body_model_params(self):
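        """Initialize the per-frame SMPL parameters (shape, pose, translation) from the preprocessed .npy files."""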
        body_model_params = {param_name: [] for param_name in self.body_model_params.param_names}
        data_root = os.path.join('../data', self.opt.dataset.metainfo.data_dir)
        data_root = hydra.utils.to_absolute_path(data_root)

        # Load the poses once: the first 3 values are the global orientation,
        # the remaining 69 are the body pose.
        poses = np.load(os.path.join(data_root, 'poses.npy'))[self.training_indices]
        body_model_params['betas'] = torch.tensor(np.load(os.path.join(data_root, 'mean_shape.npy'))[None], dtype=torch.float32)
        body_model_params['global_orient'] = torch.tensor(poses[:, :3], dtype=torch.float32)
        body_model_params['body_pose'] = torch.tensor(poses[:, 3:], dtype=torch.float32)
        body_model_params['transl'] = torch.tensor(np.load(os.path.join(data_root, 'normalize_trans.npy'))[self.training_indices], dtype=torch.float32)

        for param_name in body_model_params.keys():
            self.body_model_params.init_parameters(param_name, body_model_params[param_name], requires_grad=False)

    def configure_optimizers(self):
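        """Use Adam with a 10x lower learning rate for the SMPL parameters than for the network weights."""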
        params = [{'params': self.model.parameters(), 'lr': self.opt.model.learning_rate}]
        params.append({'params': self.body_model_params.parameters(), 'lr': self.opt.model.learning_rate * 0.1})
        self.optimizer = optim.Adam(params, lr=self.opt.model.learning_rate, eps=1e-8)
        self.scheduler = optim.lr_scheduler.MultiStepLR(
            self.optimizer, milestones=self.opt.model.sched_milestones, gamma=self.opt.model.sched_factor)
        return [self.optimizer], [self.scheduler]

    def training_step(self, batch):
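        """Fetch the optimized per-frame SMPL parameters, run the model, and log all loss terms."""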
        inputs, targets = batch

        batch_idx = inputs["idx"]

        # Replace the dataset SMPL parameters with the optimized per-frame ones.
        body_model_params = self.body_model_params(batch_idx)
        inputs['smpl_pose'] = torch.cat((body_model_params['global_orient'], body_model_params['body_pose']), dim=1)
        inputs['smpl_shape'] = body_model_params['betas']
        inputs['smpl_trans'] = body_model_params['transl']

        inputs['current_epoch'] = self.current_epoch
        model_outputs = self.model(inputs)

        loss_output = self.loss(model_outputs, targets)
        for k, v in loss_output.items():
            self.log(k, v.item(), prog_bar=True, on_step=True)
        return loss_output["loss"]

    def training_epoch_end(self, outputs) -> None:
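        """Every 20 epochs, re-extract the canonical mesh and cache it on the model."""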
        if self.current_epoch != 0 and self.current_epoch % 20 == 0:
            # Zero body pose as the canonical condition.
            cond = {'smpl': torch.zeros(1, 69).float().cuda()}
            mesh_canonical = generate_mesh(lambda x: self.query_oc(x, cond), self.model.smpl_server.verts_c[0], point_batch=10000, res_up=2)
            self.model.mesh_v_cano = torch.tensor(mesh_canonical.vertices[None], device=self.model.smpl_v_cano.device).float()
            self.model.mesh_f_cano = torch.tensor(mesh_canonical.faces.astype(np.int64), device=self.model.smpl_v_cano.device)
            self.model.mesh_face_vertices = index_vertices_by_faces(self.model.mesh_v_cano, self.model.mesh_f_cano)
        return super().training_epoch_end(outputs)

    def query_oc(self, x, cond):
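        """Query the implicit network at canonical points x; returns the predicted SDF."""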
        x = x.reshape(-1, 3)
        mnfld_pred = self.model.implicit_network(x, cond)[:, :, 0].reshape(-1, 1)
        return {'sdf': mnfld_pred}

    def query_wc(self, x):
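        """Query the skinning weights of canonical points x."""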
        x = x.reshape(-1, 3)
        return self.model.deformer.query_weights(x)

    def query_od(self, x, cond, smpl_tfs, smpl_verts):
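        """Query the SDF of deformed points x by first warping them back to canonical space."""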
        x = x.reshape(-1, 3)
        x_c, _ = self.model.deformer.forward(x, smpl_tfs, return_weights=False, inverse=True, smpl_verts=smpl_verts)
        output = self.model.implicit_network(x_c, cond)[0]
        sdf = output[:, 0:1]

        return {'sdf': sdf}

    def get_deformed_mesh_fast_mode(self, verts, smpl_tfs):
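        """Deform canonical vertices to the posed space via linear blend skinning with queried weights."""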
        verts = torch.tensor(verts).cuda().float()
        weights = self.model.deformer.query_weights(verts)
        verts_deformed = skinning(verts.unsqueeze(0), weights, smpl_tfs).data.cpu().numpy()[0]
        return verts_deformed

    def validation_step(self, batch, *args, **kwargs):
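        """Extract the canonical mesh and render the validation image in pixel chunks."""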
        output = {}
        inputs, targets = batch
        inputs['current_epoch'] = self.current_epoch
        self.model.eval()

        body_model_params = self.body_model_params(inputs['image_id'])
        inputs['smpl_pose'] = torch.cat((body_model_params['global_orient'], body_model_params['body_pose']), dim=1)
        inputs['smpl_shape'] = body_model_params['betas']
        inputs['smpl_trans'] = body_model_params['transl']

        cond = {'smpl': inputs["smpl_pose"][:, 3:] / np.pi}
        mesh_canonical = generate_mesh(lambda x: self.query_oc(x, cond), self.model.smpl_server.verts_c[0], point_batch=10000, res_up=3)

        mesh_canonical = trimesh.Trimesh(mesh_canonical.vertices, mesh_canonical.faces)

        output.update({
            'canonical_mesh': mesh_canonical
        })

        split = utils.split_input(inputs, targets["total_pixels"][0],
                                  n_pixels=min(targets['pixel_per_batch'], targets["img_size"][0] * targets["img_size"][1]))

        res = []
        for s in split:
            out = self.model(s)

            # Detach tensors so intermediate graphs are not kept alive.
            for k, v in out.items():
                out[k] = v.detach() if isinstance(v, torch.Tensor) else v

            res.append({
                'rgb_values': out['rgb_values'],
                'normal_values': out['normal_values'],
                'fg_rgb_values': out['fg_rgb_values'],
            })
        batch_size = targets['rgb'].shape[0]

        model_outputs = utils.merge_output(res, targets["total_pixels"][0], batch_size)

        output.update({
            "rgb_values": model_outputs["rgb_values"].detach().clone(),
            "normal_values": model_outputs["normal_values"].detach().clone(),
            "fg_rgb_values": model_outputs["fg_rgb_values"].detach().clone(),
            **targets,
        })

        return output

    def validation_step_end(self, batch_parts):
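        """Pass per-device outputs through unchanged."""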
        return batch_parts

    def validation_epoch_end(self, outputs) -> None:
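        """Assemble the chunked renders into full images and write renderings, normals, and the canonical mesh to disk."""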
        img_size = outputs[0]["img_size"]

        rgb_pred = torch.cat([output["rgb_values"] for output in outputs], dim=0)
        rgb_pred = rgb_pred.reshape(*img_size, -1)

        fg_rgb_pred = torch.cat([output["fg_rgb_values"] for output in outputs], dim=0)
        fg_rgb_pred = fg_rgb_pred.reshape(*img_size, -1)

        # Map normals from [-1, 1] to [0, 1] for visualization.
        normal_pred = torch.cat([output["normal_values"] for output in outputs], dim=0)
        normal_pred = (normal_pred.reshape(*img_size, -1) + 1) / 2

        rgb_gt = torch.cat([output["rgb"] for output in outputs], dim=1).squeeze(0)
        rgb_gt = rgb_gt.reshape(*img_size, -1)
        if 'normal' in outputs[0].keys():
            normal_gt = torch.cat([output["normal"] for output in outputs], dim=1).squeeze(0)
            normal_gt = (normal_gt.reshape(*img_size, -1) + 1) / 2
            normal = torch.cat([normal_gt, normal_pred], dim=0).cpu().numpy()
        else:
            normal = normal_pred.cpu().numpy()

        # Stack ground truth above the prediction for side-by-side comparison.
        rgb = torch.cat([rgb_gt, rgb_pred], dim=0).cpu().numpy()
        rgb = (rgb * 255).astype(np.uint8)

        fg_rgb = fg_rgb_pred.cpu().numpy()
        fg_rgb = (fg_rgb * 255).astype(np.uint8)

        normal = (normal * 255).astype(np.uint8)

        os.makedirs("rendering", exist_ok=True)
        os.makedirs("normal", exist_ok=True)
        os.makedirs("fg_rendering", exist_ok=True)

        canonical_mesh = outputs[0]['canonical_mesh']
        canonical_mesh.export(f"rendering/{self.current_epoch}.ply")

        # OpenCV expects BGR, hence the channel flip.
        cv2.imwrite(f"rendering/{self.current_epoch}.png", rgb[:, :, ::-1])
        cv2.imwrite(f"normal/{self.current_epoch}.png", normal[:, :, ::-1])
        cv2.imwrite(f"fg_rendering/{self.current_epoch}.png", fg_rgb[:, :, ::-1])

    def test_step(self, batch, *args, **kwargs):
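        """Render one test frame in pixel chunks and export the canonical and deformed meshes along with the renderings."""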
        inputs, targets, pixel_per_batch, total_pixels, idx = batch
        num_splits = (total_pixels + pixel_per_batch - 1) // pixel_per_batch
        results = []

        # Only the scale is kept from the dataset SMPL parameters; pose, shape and
        # translation are overridden by the optimized per-frame parameters below.
        scale, smpl_trans, smpl_pose, smpl_shape = torch.split(inputs["smpl_params"], [1, 3, 72, 10], dim=1)

        body_model_params = self.body_model_params(inputs['idx'])
        smpl_shape = body_model_params['betas'] if body_model_params['betas'].dim() == 2 else body_model_params['betas'].unsqueeze(0)
        smpl_trans = body_model_params['transl']
        smpl_pose = torch.cat((body_model_params['global_orient'], body_model_params['body_pose']), dim=1)

        smpl_outputs = self.model.smpl_server(scale, smpl_trans, smpl_pose, smpl_shape)
        smpl_tfs = smpl_outputs['smpl_tfs']
        cond = {'smpl': smpl_pose[:, 3:] / np.pi}

        mesh_canonical = generate_mesh(lambda x: self.query_oc(x, cond), self.model.smpl_server.verts_c[0], point_batch=10000, res_up=4)
        # Use a larger K for smoother skinning weights when deforming the extracted mesh.
        self.model.deformer = SMPLDeformer(betas=np.load(self.betas_path), gender=self.gender, K=7)
        verts_deformed = self.get_deformed_mesh_fast_mode(mesh_canonical.vertices, smpl_tfs)
        mesh_deformed = trimesh.Trimesh(vertices=verts_deformed, faces=mesh_canonical.faces, process=False)

        os.makedirs("test_mask", exist_ok=True)
        os.makedirs("test_rendering", exist_ok=True)
        os.makedirs("test_fg_rendering", exist_ok=True)
        os.makedirs("test_normal", exist_ok=True)
        os.makedirs("test_mesh", exist_ok=True)

        frame_id = int(idx.cpu().numpy())
        mesh_canonical.export(f"test_mesh/{frame_id:04d}_canonical.ply")
        mesh_deformed.export(f"test_mesh/{frame_id:04d}_deformed.ply")
        # Restore the default deformer for rendering.
        self.model.deformer = SMPLDeformer(betas=np.load(self.betas_path), gender=self.gender)
        for i in range(num_splits):
            indices = list(range(i * pixel_per_batch, min((i + 1) * pixel_per_batch, total_pixels)))
            batch_inputs = {"uv": inputs["uv"][:, indices],
                            "intrinsics": inputs['intrinsics'],
                            "pose": inputs['pose'],
                            "smpl_params": inputs["smpl_params"],
                            "smpl_pose": inputs["smpl_params"][:, 4:76],
                            "smpl_shape": inputs["smpl_params"][:, 76:],
                            "smpl_trans": inputs["smpl_params"][:, 1:4],
                            "idx": inputs["idx"] if 'idx' in inputs.keys() else None}

            # Override the dataset SMPL parameters with the optimized per-frame ones.
            body_model_params = self.body_model_params(inputs['idx'])
            batch_inputs.update({'smpl_pose': torch.cat((body_model_params['global_orient'], body_model_params['body_pose']), dim=1)})
            batch_inputs.update({'smpl_shape': body_model_params['betas']})
            batch_inputs.update({'smpl_trans': body_model_params['transl']})

            batch_targets = {"rgb": targets["rgb"][:, indices].detach().clone() if 'rgb' in targets.keys() else None,
                             "img_size": targets["img_size"]}

            with torch.no_grad():
                model_outputs = self.model(batch_inputs)
            results.append({"rgb_values": model_outputs["rgb_values"].detach().clone(),
                            "fg_rgb_values": model_outputs["fg_rgb_values"].detach().clone(),
                            "normal_values": model_outputs["normal_values"].detach().clone(),
                            "acc_map": model_outputs["acc_map"].detach().clone(),
                            **batch_targets})

        img_size = results[0]["img_size"]
        rgb_pred = torch.cat([result["rgb_values"] for result in results], dim=0)
        rgb_pred = rgb_pred.reshape(*img_size, -1)

        fg_rgb_pred = torch.cat([result["fg_rgb_values"] for result in results], dim=0)
        fg_rgb_pred = fg_rgb_pred.reshape(*img_size, -1)

        normal_pred = torch.cat([result["normal_values"] for result in results], dim=0)
        normal_pred = (normal_pred.reshape(*img_size, -1) + 1) / 2

        pred_mask = torch.cat([result["acc_map"] for result in results], dim=0)
        pred_mask = pred_mask.reshape(*img_size, -1)

        # Stack ground truth above the prediction when it is available.
        if results[0]['rgb'] is not None:
            rgb_gt = torch.cat([result["rgb"] for result in results], dim=1).squeeze(0)
            rgb_gt = rgb_gt.reshape(*img_size, -1)
            rgb = torch.cat([rgb_gt, rgb_pred], dim=0).cpu().numpy()
        else:
            rgb = rgb_pred.cpu().numpy()
        if 'normal' in results[0].keys():
            normal_gt = torch.cat([result["normal"] for result in results], dim=1).squeeze(0)
            normal_gt = (normal_gt.reshape(*img_size, -1) + 1) / 2
            normal = torch.cat([normal_gt, normal_pred], dim=0).cpu().numpy()
        else:
            normal = normal_pred.cpu().numpy()

        rgb = (rgb * 255).astype(np.uint8)

        fg_rgb = fg_rgb_pred.cpu().numpy()
        fg_rgb = (fg_rgb * 255).astype(np.uint8)

        normal = (normal * 255).astype(np.uint8)

        cv2.imwrite(f"test_mask/{frame_id:04d}.png", (pred_mask.cpu().numpy() * 255).astype(np.uint8))
        cv2.imwrite(f"test_rendering/{frame_id:04d}.png", rgb[:, :, ::-1])
        cv2.imwrite(f"test_normal/{frame_id:04d}.png", normal[:, :, ::-1])
        cv2.imwrite(f"test_fg_rendering/{frame_id:04d}.png", fg_rgb[:, :, ::-1])