#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
This example demonstrates scene optimization with the PyTorch3D
pulsar interface. For this, a reference image has been pre-generated
(you can find it at `../../tests/pulsar/reference/examples_TestRenderer_test_smallopt.png`).
The scene is initialized with random spheres. Gradient-based
optimization is used to converge towards a faithful
scene representation.
"""
import logging
import math

import cv2
import imageio
import numpy as np
import torch
from pytorch3d.renderer.cameras import PerspectiveCameras
from pytorch3d.renderer.points import (
    PointsRasterizationSettings,
    PointsRasterizer,
    PulsarPointsRenderer,
)
from pytorch3d.structures.pointclouds import Pointclouds
from torch import nn, optim


# Module-level logger; `cli` reports progress and the per-step loss through it.
LOGGER = logging.getLogger(__name__)
# Number of spheres in the scene (leading dimension of the point tensors).
N_POINTS = 10_000
# Rendered image size in pixels, passed to the camera and rasterizer settings.
WIDTH = 1_000
HEIGHT = 1_000
# Device on which the scene tensors and the reference image are placed.
# Hard-coded to CUDA, so a CUDA device is presumably required to run this example.
DEVICE = torch.device("cuda")


class SceneModel(nn.Module):
    """
    A simple scene model to demonstrate use of pulsar in PyTorch modules.

    The scene is parameterized with sphere locations (vert_pos), channel
    content (vert_col) and radii (vert_rad); all three are registered as
    trainable parameters. The camera is a fixed `PerspectiveCameras` object
    (self.cameras) with identity rotation and zero translation. A `cam_params`
    buffer is registered as well, but it is not read anywhere in this class
    (presumably kept for parity with the native pulsar example - verify before
    relying on its layout).

    The forward method of the model renders this scene description. Any
    of these parameters could instead be passed as inputs to the forward
    method and come from a different model.
    """

    def __init__(self):
        """
        Create the randomly initialized spheres, the camera and the renderer.

        All tensors are allocated directly on DEVICE, so the model is
        consistent on its own and does not depend on a later `.to(DEVICE)`.
        """
        super().__init__()
        self.gamma = 1.0
        # Points. The seed is fixed so that every run starts from the same scene.
        torch.manual_seed(1)
        # x and y end up in [-5, 5), z in [25, 35).
        vert_pos = torch.rand(N_POINTS, 3, dtype=torch.float32, device=DEVICE) * 10.0
        vert_pos[:, 2] += 25.0
        vert_pos[:, :2] -= 5.0
        self.register_parameter("vert_pos", nn.Parameter(vert_pos, requires_grad=True))
        # Every sphere starts with the value 0.5 in all three channels.
        self.register_parameter(
            "vert_col",
            nn.Parameter(
                torch.ones(N_POINTS, 3, dtype=torch.float32, device=DEVICE) * 0.5,
                requires_grad=True,
            ),
        )
        # Every sphere starts with radius 0.3. Created on DEVICE like the other
        # parameters (it used to be the only one allocated on the CPU).
        self.register_parameter(
            "vert_rad",
            nn.Parameter(
                torch.ones(N_POINTS, dtype=torch.float32, device=DEVICE) * 0.3,
                requires_grad=True,
            ),
        )
        # NOTE(review): this buffer is never read in this class; the values
        # look like position (3), rotation (3), focal length and sensor width
        # of the native pulsar camera - confirm before using it.
        self.register_buffer(
            "cam_params",
            torch.tensor(
                [0.0, 0.0, 0.0, 0.0, math.pi, 0.0, 5.0, 2.0],
                dtype=torch.float32,
                device=DEVICE,
            ),
        )
        self.cameras = PerspectiveCameras(
            # The focal length must be double the size for PyTorch3D because of the NDC
            # coordinates spanning a range of two - and they must be normalized by the
            # sensor width (see the pulsar example). This means we need here
            # 5.0 * 2.0 / 2.0 to get the equivalent results as in pulsar.
            focal_length=5.0,
            R=torch.eye(3, dtype=torch.float32, device=DEVICE)[None, ...],
            T=torch.zeros((1, 3), dtype=torch.float32, device=DEVICE),
            image_size=((HEIGHT, WIDTH),),
            device=DEVICE,
        )
        # The radius parameter itself is handed to the rasterizer settings, so
        # the settings refer to the same parameter object that is optimized.
        raster_settings = PointsRasterizationSettings(
            image_size=(HEIGHT, WIDTH),
            radius=self.vert_rad,
        )
        rasterizer = PointsRasterizer(
            cameras=self.cameras, raster_settings=raster_settings
        )
        self.renderer = PulsarPointsRenderer(rasterizer=rasterizer, n_track=32)

    def forward(self):
        """
        Render the current scene description.

        Returns:
            The first entry of the renderer output, i.e. the result for the
            single point cloud in the batch. The caller in this file treats it
            as an image that is compared element-wise with the reference
            (presumably of shape (HEIGHT, WIDTH, 3) - verify against the
            renderer documentation).
        """
        # The Pointclouds object creates copies of its arguments - that's why
        # we have to create a new object in every forward step.
        pcl = Pointclouds(
            points=self.vert_pos[None, ...], features=self.vert_col[None, ...]
        )
        return self.renderer(
            pcl,
            gamma=(self.gamma,),
            zfar=(45.0,),
            znear=(1.0,),
            radius_world=True,
            # White background.
            bg_col=torch.ones((3,), dtype=torch.float32, device=DEVICE),
        )[0]


def cli(
    n_steps: int = 500,
    ref_path: str = (
        "../../tests/pulsar/reference/examples_TestRenderer_test_smallopt.png"
    ),
) -> None:
    """
    Scene optimization example using pulsar and the unified PyTorch3D interface.

    Loads the reference image, builds a `SceneModel` and runs plain SGD on the
    summed squared difference between the rendering and the reference. The
    current rendering and a 50/50 overlay with the reference are shown in
    OpenCV windows at every step.

    Args:
        n_steps: number of optimization steps to run.
        ref_path: path of the reference image; a relative path is resolved
            against the current working directory.
    """
    LOGGER.info("Loading reference...")
    # Load reference. The image is mirrored along its width axis before use;
    # `.copy()` makes the flipped view contiguous, which `torch.from_numpy`
    # needs (it does not accept negative strides). Values are scaled to [0, 1].
    ref_np = imageio.imread(ref_path)[:, ::-1, :].copy()
    ref = (torch.from_numpy(ref_np).to(torch.float32) / 255.0).to(DEVICE)
    # Set up model.
    model = SceneModel().to(DEVICE)
    # Optimizer. Each parameter group gets its own learning rate.
    optimizer = optim.SGD(
        [
            {"params": [model.vert_col], "lr": 1e0},
            {"params": [model.vert_rad], "lr": 5e-3},
            {"params": [model.vert_pos], "lr": 1e-2},
        ]
    )
    LOGGER.info("Optimizing...")
    # Optimize. The windows opened by `cv2.imshow` are closed in the `finally`
    # clause, also when the loop is interrupted by an exception.
    try:
        for i in range(n_steps):
            optimizer.zero_grad()
            result = model()
            # Visualize. A detached tensor is used so that the overlay blend
            # does not add nodes to the autograd graph.
            result_vis = result.detach()
            result_im = (result_vis.cpu().numpy() * 255).astype(np.uint8)
            # Reverse the channel order for display (OpenCV expects BGR).
            cv2.imshow("opt", result_im[:, :, ::-1])
            # `putText` draws into the array, hence the contiguous copy.
            overlay_img = np.ascontiguousarray(
                ((result_vis * 0.5 + ref * 0.5).cpu().numpy() * 255).astype(
                    np.uint8
                )[:, :, ::-1]
            )
            overlay_img = cv2.putText(
                overlay_img,
                "Step %d" % i,
                (10, 40),
                cv2.FONT_HERSHEY_SIMPLEX,
                1,
                (0, 0, 0),
                2,
                cv2.LINE_AA,
                False,
            )
            cv2.imshow("overlay", overlay_img)
            cv2.waitKey(1)
            # Update.
            loss = ((result - ref) ** 2).sum()
            LOGGER.info("loss %d: %f", i, loss.item())
            loss.backward()
            optimizer.step()
            # Cleanup. Parameters are modified in place without gradient
            # tracking, so the optimizer keeps working on the same tensors.
            with torch.no_grad():
                # Keep the channel values in the valid range.
                model.vert_col.clamp_(0.0, 1.0)
                # Remove points: spheres whose radius fell below the threshold
                # are moved far away and their radius is pinned to a tiny
                # positive value.
                too_small = model.vert_rad < 0.001
                model.vert_pos[too_small, :] = -1000.0
                model.vert_rad[too_small] = 0.0001
                # Also remove spheres whose color is within an L1 distance of
                # 0.2 of white, the background color used by the model.
                dist_to_white = (model.vert_col - 1.0).abs().sum(dim=1)
                model.vert_pos[dist_to_white <= 0.2] = -1000.0
    finally:
        cv2.destroyAllWindows()
    LOGGER.info("Done.")


if __name__ == "__main__":
    # Run the example only when executed as a script; INFO level makes the
    # progress and per-step loss messages emitted by `cli` visible.
    logging.basicConfig(level=logging.INFO)
    cli()