import cv2
import numpy as np
import matplotlib.pyplot as plt
import open3d as o3d
import plotly.graph_objects as go




# print(pcd)
# skip = 100   # Skip every n points

# fig = plt.figure()
# ax = fig.add_subplot(111, projection='3d')
# point_range = range(0, pcd.shape[0], skip) # skip points to prevent crash
# ax.scatter(pcd[point_range, 0],   # x
#            pcd[point_range, 1],   # y
#            pcd[point_range, 2],   # z
#            c=pcd[point_range, 2], # height data for color
#            cmap='spectral',
#            marker="x")
# ax.axis('scaled')  # {equal, scaled}
# plt.show()

# pcd_o3d = o3d.geometry.PointCloud()  # create point cloud object
# pcd_o3d.points = o3d.utility.Vector3dVector(pcd)  # set pcd_np as the point cloud points
# # Visualize:
# o3d.visualization.draw_geometries([pcd_o3d])


class PointCloudGenerator:
    """Back-project depth images into 3-D point clouds.

    The instance stores pinhole intrinsics (focal lengths ``fx``/``fy`` and
    principal point ``cx``/``cy``) that are used by the two NumPy
    conversion methods. ``generate_point_cloud`` instead delegates to Open3D
    using Open3D's built-in ``PrimeSenseDefault`` intrinsics.
    """

    def __init__(self):
        # Depth camera intrinsics used by the NumPy back-projection methods.
        self.fx_depth = 5.8262448167737955e+02
        self.fy_depth = 5.8269103270988637e+02
        self.cx_depth = 3.1304475870804731e+02
        self.cy_depth = 2.3844389626620386e+02

    def conver_to_point_cloud_v1(self, depth_img):
        """Reference (pure-Python loop) back-projection of ``depth_img``.

        For the pixel at row ``i`` and column ``j`` with depth ``z``::

            x = (j - cx) * z / fx
            y = (i - cy) * z / fy

        Args:
            depth_img: 2-D array of depth values indexed ``[row][col]``.

        Returns:
            A list of ``[x, y, z]`` lists in row-major pixel order.
        """
        pcd = []
        height, width = depth_img.shape
        for i in range(height):
            for j in range(width):
                z = depth_img[i][j]
                x = (j - self.cx_depth) * z / self.fx_depth
                y = (i - self.cy_depth) * z / self.fy_depth
                pcd.append([x, y, z])

        return pcd

    def conver_to_point_cloud(self, depth_img):
        """Vectorised back-projection; matches ``conver_to_point_cloud_v1``.

        Args:
            depth_img: 2-D array-like of depth values indexed ``[row, col]``.

        Returns:
            NumPy array of shape ``(height * width, 3)`` holding ``[x, y, z]``
            rows in row-major pixel order.
        """
        depth = np.asarray(depth_img)
        height, width = depth.shape
        length = height * width

        # Pixel coordinates in row-major order: jj is the column index (varies
        # fastest), ii is the row index.
        jj = np.tile(np.arange(width), height)
        ii = np.repeat(np.arange(height), width)

        z = depth.reshape(length)
        # x pairs the column index with cx/fx; y pairs the row index with cy/fy.
        x = (jj - self.cx_depth) * z / self.fx_depth
        y = (ii - self.cy_depth) * z / self.fy_depth
        return np.stack([x, y, z], axis=-1)

    @staticmethod
    def _prepare_depth(depth_img, normalize=False):
        """Return ``depth_img`` as a contiguous array Open3D accepts as depth.

        Args:
            depth_img: 2-D array-like of depth values.
            normalize: If True, linearly rescale values to ``[0, 255]``. A
                constant image (zero range) maps to all zeros instead of NaN.

        Returns:
            A C-contiguous array. uint16 and float32 inputs keep their dtype;
            everything else (including float64 produced by normalisation or by
            dividing an 8-bit image) is cast to float32, because Open3D's
            depth images only support uint16 or float32 single-channel data.
        """
        depth = np.asarray(depth_img)

        if normalize:
            depth = depth.astype(np.float64)
            depth_min = depth.min()
            depth_range = depth.max() - depth_min
            if depth_range > 0:
                depth = 255 * ((depth - depth_min) / depth_range)
            else:
                depth = np.zeros_like(depth)

        if depth.dtype not in (np.dtype(np.uint16), np.dtype(np.float32)):
            depth = depth.astype(np.float32)

        return np.ascontiguousarray(depth)

    def generate_point_cloud(self, depth_img, normalize=False):
        """Build an Open3D point cloud from a single-channel depth image.

        Args:
            depth_img: 2-D array-like of depth values.
            normalize: If True, rescale depths to ``[0, 255]`` before
                conversion.

        Returns:
            ``open3d.geometry.PointCloud`` produced by
            ``create_from_depth_image`` with Open3D's ``PrimeSenseDefault``
            intrinsics (the ``*_depth`` attributes are NOT used here).
        """
        depth_image = o3d.geometry.Image(self._prepare_depth(depth_img, normalize))

        # NOTE(review): PrimeSenseDefault describes a 640x480 sensor; if the
        # input has another size or the instance intrinsics should be used,
        # build the intrinsic with PinholeCameraIntrinsic.set_intrinsics(...).
        intrinsic = o3d.camera.PinholeCameraIntrinsic(
            o3d.camera.PinholeCameraIntrinsicParameters.PrimeSenseDefault
        )
        return o3d.geometry.PointCloud.create_from_depth_image(depth_image, intrinsic)
    
# def display_pcd(pcd_data, use_matplotlib=True):

#     if use_matplotlib:
#         fig = plt.figure()
#         ax = fig.add_subplot(111, projection='3d')

#     for data, clr in pcd_data:    
#         # points = np.array(data)
#         points = np.asarray(data.points)
#         skip = 5 
#         point_range = range(0, points.shape[0], skip) # skip points to prevent crash

#         if use_matplotlib:            
#             ax.scatter(points[point_range, 0], points[point_range, 1], points[point_range, 2]*100, c=list(clr).append(1), marker='o')  
        
#         # if not use_matplotlib:
#         #     pcd_o3d = o3d.geometry.PointCloud()  # create point cloud object
#         #     pcd_o3d.points = o3d.utility.Vector3dVector(points)  # set pcd_np as the point cloud points
#         #     # Visualize:
#         #     o3d.visualization.draw_geometries([pcd_o3d])

#     if use_matplotlib:
#         ax.set_xlabel('X Label')
#         ax.set_ylabel('Y Label')
#         ax.set_zlabel('Z Label')
#         ax.view_init(elev=-90, azim=0, roll=-90) 
#         # plt.show()
#         return fig

#     if not use_matplotlib:
#         o3d.visualization.draw_geometries([pcd_o3d])

def display_pcd(pcd_data):
    fig = go.Figure()

    for data, clr in pcd_data:
        points = np.asarray(data.points)
        skip = 1
        point_range = range(0, points.shape[0], skip)

        fig.add_trace(go.Scatter3d(
            x=points[point_range, 0],
            y=points[point_range, 1],
            z=points[point_range, 2]*100,
            mode='markers',
            marker=dict(
                size=1,
                color='rgb'+str(clr),
                opacity=1
            )
        ))

    fig.update_layout(
        scene=dict(
            xaxis_title='X Label',
            yaxis_title='Y Label',
            zaxis_title='Z Label',
            camera=dict(
                eye=dict(x=0, y=0, z=-1),
                # up=dict(x=0, y=0, z=1),
            )
        )
    )

    return fig
    
if __name__ == "__main__":
    depth_img_path = "assets/images/depth_map_p1.png"
    # Flag 0 (cv2.IMREAD_GRAYSCALE) loads a single-channel 8-bit image.
    depth_img = cv2.imread(depth_img_path, 0)
    if depth_img is None:
        # cv2.imread reports a missing/unreadable file by returning None.
        raise FileNotFoundError(f"Could not read depth image: {depth_img_path}")

    # Scale to [0, 1]. Open3D depth images must be uint16 or float32, and
    # plain division yields float64, so cast explicitly.
    depth_img = (depth_img / 255).astype(np.float32)

    point_cloud_gen = PointCloudGenerator()
    pcd = point_cloud_gen.generate_point_cloud(depth_img)

    # display_pcd takes (point_cloud, rgb_color) pairs and returns a figure.
    fig = display_pcd([(pcd, (0, 0, 255))])
    fig.show()