import torch
import numpy as np
# Pinhole camera intrinsics: unit focal lengths, principal point at the origin.
f_x, f_y = 1.0, 1.0
c_x, c_y = 0, 0
K = torch.tensor(
    [
        [f_x, 0, c_x],
        [0, f_y, c_y],
        [0, 0, 1],
    ]
)  # Intrinsic matrix (the float focal lengths give it torch's default float dtype)
d = 1.0  # Depth scaling factor

import matplotlib.pyplot as plt

import torch
import plotly.graph_objects as go
import torch

import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

def generate_grid(depth_map):
    """Return the pixel-coordinate grid for an image shaped like ``depth_map``.

    Args:
        depth_map: 2-D array/tensor of shape (H, W). Only its shape and (if it
            has one) its device are used.

    Returns:
        float32 tensor of shape (H, W, 2) with ``grid[v, u] == (u, v)``. The
        last axis holds (x = column index, y = row index). The tensor is
        created on ``depth_map``'s device so later arithmetic with the depth
        map does not mix devices.
    """
    h, w = depth_map.shape
    device = getattr(depth_map, "device", None)
    # indexing="ij": the first output varies along rows (y), the second along
    # columns (x). Passing it explicitly also avoids torch's deprecation
    # warning for calls that omit the argument.
    y, x = torch.meshgrid(
        torch.arange(h, device=device, dtype=torch.float32),
        torch.arange(w, device=device, dtype=torch.float32),
        indexing="ij",
    )
    return torch.stack([x, y], dim=-1)

def depth_map_to_point_cloud(depth_map, K, d, *, filter_invalid=True):
    """Back-project a depth map into 3-D camera-space points.

    Each pixel (u, v) with depth z becomes ``z * d * K^-1 @ [u, v, 1]``.

    Args:
        depth_map: tensor of shape (H, W) with per-pixel depth.
        K: 3x3 camera intrinsic matrix (tensor or array-like). It is converted
            to the pixel grid's dtype/device before inversion, so a float64,
            numpy or differently-placed K also works.
        d: scalar multiplier applied to every depth value.
        filter_invalid: if True (default), drop pixels whose depth is <= 0 and
            return a flat (N, 3) tensor. If False, return the full (H, W, 3)
            grid. That grid stays pixel-aligned with (H, W) images and masks,
            so it can be passed to ``filter_points_by_mask``.

    Returns:
        An (N, 3) tensor of valid points in row-major pixel order when
        ``filter_invalid`` is True, otherwise an (H, W, 3) tensor.
    """
    grid = generate_grid(depth_map)
    K_inv = torch.inverse(torch.as_tensor(K, dtype=grid.dtype, device=grid.device))
    grid_homogeneous = torch.cat([grid, torch.ones_like(grid[..., :1])], dim=-1)
    # (3, 3) @ (H, W, 3, 1) broadcasts to (H, W, 3, 1): K^-1 applied per pixel.
    rays = torch.matmul(K_inv, grid_homogeneous.unsqueeze(-1)).squeeze(-1)

    # Scale each ray by its depth (and the global factor d) to get 3-D points.
    points = rays * depth_map.unsqueeze(-1) * d

    if not filter_invalid:
        return points
    # Drop pixels with non-positive depth; boolean indexing flattens to (N, 3).
    return points[depth_map > 0]

def filter_points_by_mask(points, mask):
    """Select the points whose corresponding mask entry is nonzero/True.

    Args:
        points: tensor whose elements form consecutive xyz triples, e.g. an
            (H, W, 3) per-pixel point grid. It is flattened to (N, 3).
        mask: tensor with exactly N elements, e.g. (H, W). Entry i decides
            whether flattened point i is kept.

    Returns:
        A (M, 3) tensor of the kept points, in flattened order.

    Raises:
        IndexError: if the mask does not have exactly one entry per point.
            This happens, for example, when ``points`` came from
            ``depth_map_to_point_cloud`` with pixels of depth <= 0 removed.
    """
    # reshape (unlike view) also accepts non-contiguous tensors.
    points_flat = points.reshape(-1, 3)
    mask_flat = mask.flatten().bool()
    if mask_flat.numel() != points_flat.shape[0]:
        raise IndexError(
            f"mask has {mask_flat.numel()} entries but points has "
            f"{points_flat.shape[0]} xyz triples; they must be pixel-aligned"
        )
    return points_flat[mask_flat]

def compute_bounding_box(points):
    """Return the axis-aligned bounding box of a set of points.

    Args:
        points: (N, 3) tensor of xyz points.

    Returns:
        ``(min_vals, max_vals)``: per-axis minimum and maximum over the N
        points. For an empty point set (e.g. a mask that selected nothing)
        both are ``torch.zeros(3)``, a degenerate box at the origin, instead
        of an error.

    Raises:
        Any error from ``points.min``/``points.max`` on non-empty input is
        propagated (previously it was silently replaced by a zero box).
    """
    if points.numel() == 0:
        # Empty masks are expected; fall back to a degenerate box.
        return torch.zeros(3), torch.zeros(3)
    return points.min(dim=0).values, points.max(dim=0).values

def visualize_point_cloud_with_bbox(points, mask_points_all, *, image=None, stride=4):
    """Show an interactive Plotly 3-D scatter of a point cloud with red boxes.

    Args:
        points: tensor of xyz points, (N, 3) or (H, W, 3). It is flattened to
            (N, 3) before subsampling.
        mask_points_all: iterable of (M_i, 3) point tensors. An axis-aligned
            bounding box (via ``compute_bounding_box``) is drawn around each.
        image: (H, W, C) RGB(A) image used to colour the points. It is
            flattened row-major, so pixel i colours point i. If omitted, the
            module-level global ``img`` is used, as before.
        stride: plot every ``stride``-th point (and colour) to keep the figure
            responsive.

    NOTE(review): colours are only correct when ``points`` has one entry per
    pixel in the same order as ``image``. ``depth_map_to_point_cloud`` drops
    pixels with depth <= 0 by default, which breaks that alignment.
    """
    if image is None:
        image = img  # legacy fallback: module-level global (NameError if unset)

    # Plotly expects array-likes, not torch tensors.
    xyz = points.detach().cpu().numpy().reshape(-1, 3)[::stride]

    rgb = np.asarray(image)
    rgb = rgb.reshape(-1, rgb.shape[-1])[::stride, :3]
    if np.issubdtype(rgb.dtype, np.floating) and rgb.size and rgb.max() <= 1.0:
        # Float images (e.g. PNGs read by plt.imread) are in [0, 1].
        rgb = rgb * 255.0
    # Clip so each channel formats as exactly two hex digits.
    rgb = np.clip(rgb, 0, 255).astype(int)
    colors_hex = ['#%02x%02x%02x' % (r, g, b) for r, g, b in rgb.tolist()]

    fig = go.Figure(data=[go.Scatter3d(
        x=xyz[:, 0], y=xyz[:, 1], z=xyz[:, 2], mode='markers',
        marker=dict(size=2, color=colors_hex, opacity=0.8),
    )])

    # Corner indices into the 8-corner list built below.
    edges_idx = [(0, 1), (1, 3), (3, 2), (2, 0),  # z = min face
                 (4, 5), (5, 7), (7, 6), (6, 4),  # z = max face
                 (0, 4), (1, 5), (2, 6), (3, 7)]  # edges joining the faces

    for mask_points in mask_points_all:
        min_vals, max_vals = compute_bounding_box(mask_points)
        lo = [float(v) for v in min_vals]
        hi = [float(v) for v in max_vals]
        corners = [
            (lo[0], lo[1], lo[2]), (hi[0], lo[1], lo[2]),
            (lo[0], hi[1], lo[2]), (hi[0], hi[1], lo[2]),
            (lo[0], lo[1], hi[2]), (hi[0], lo[1], hi[2]),
            (lo[0], hi[1], hi[2]), (hi[0], hi[1], hi[2]),
        ]
        for start, end in edges_idx:
            (x0, y0, z0), (x1, y1, z1) = corners[start], corners[end]
            fig.add_trace(go.Scatter3d(
                x=[x0, x1], y=[y0, y1], z=[z0, z1],
                mode='lines', line=dict(color='red', width=3),
                showlegend=False,  # otherwise every edge gets a legend entry
            ))

    # Set once, after the loop, so the titles apply even with no boxes.
    fig.update_layout(scene=dict(
        xaxis_title='X', yaxis_title='Y', zaxis_title='Z'
    ))
    fig.show()


def add_bounding_box(ax, min_vals, max_vals):
    """Draw the 12 edges of an axis-aligned box on a Matplotlib 3-D axis.

    Args:
        ax: axis exposing ``plot(xs, ys, zs, **kwargs)``, e.g. a 3-D Axes.
        min_vals: indexable (x, y, z) lower corner of the box.
        max_vals: indexable (x, y, z) upper corner of the box.
    """
    bounds = (min_vals, max_vals)
    # Corner number is xb + 2*yb + 4*zb, where each bit selects min (0) or
    # max (1) on that axis; the loop nesting enumerates corners 0..7 in order.
    corners = [
        (bounds[xb][0], bounds[yb][1], bounds[zb][2])
        for zb in (0, 1)
        for yb in (0, 1)
        for xb in (0, 1)
    ]
    segments = (
        (0, 1), (1, 3), (3, 2), (2, 0),  # face at z = min
        (4, 5), (5, 7), (7, 6), (6, 4),  # face at z = max
        (0, 4), (1, 5), (2, 6), (3, 7),  # edges joining the two faces
    )
    for a, b in segments:
        xs, ys, zs = ([p, q] for p, q in zip(corners[a], corners[b]))
        ax.plot(xs, ys, zs, color='red', linewidth=2, alpha=.4)

def visualize_colored_point_cloud_matplotlib(points, colors, mask_points_all=None, min_vals=None, max_vals=None, *, bbox_scale=100):
    """Render a coloured 3-D point cloud with bounding boxes to an RGB array.

    Args:
        points: tensor of xyz points, (N, 3) (or anything whose last axis is xyz).
        colors: per-point colours, passed unchanged as ``c=`` to ``Axes3D.scatter``.
        mask_points_all: optional iterable of (M_i, 3) point tensors. A box
            (via ``compute_bounding_box``) is drawn around each one.
        min_vals, max_vals: used only when ``mask_points_all`` is None. They
            are parallel sequences of precomputed box corners, each multiplied
            by ``bbox_scale`` before drawing. If either is None, no boxes are
            drawn.
        bbox_scale: factor applied to ``min_vals``/``max_vals`` (default 100,
            the value previously hard-coded).

    Returns:
        uint8 numpy array of shape (height, width, 3) holding the rendered
        figure.
    """
    # detach/cpu so tensors that require grad or live on a GPU also work.
    pts = points.detach().cpu().numpy()
    x, y, z = pts[..., 0], pts[..., 1], pts[..., 2]

    fig = plt.figure(figsize=(10, 8))
    try:
        ax = fig.add_subplot(111, projection='3d')
        ax.scatter(x, y, z, c=colors, s=1)

        if mask_points_all is not None:
            for mask_points in mask_points_all:
                box_min, box_max = compute_bounding_box(mask_points)
                add_bounding_box(ax, box_min, box_max)
        elif min_vals is not None and max_vals is not None:
            for min_val, max_val in zip(min_vals, max_vals):
                add_bounding_box(ax, min_val * bbox_scale, max_val * bbox_scale)

        ax.view_init(elev=-60., azim=-90)

        ax.set_xlabel('X')
        ax.set_ylabel('Y')
        ax.set_zlabel('Z')

        fig.canvas.draw()
        # tostring_rgb() was deprecated in Matplotlib 3.8 and removed in 3.10.
        # buffer_rgba() also carries the true (HiDPI-aware) pixel shape, so no
        # manual reshape via get_width_height() is needed. Copy before the
        # figure (which owns the buffer) is closed.
        image = np.asarray(fig.canvas.buffer_rgba())[..., :3].copy()
    finally:
        # Close this specific figure, even on error, so figures do not pile up.
        plt.close(fig)
    return image

# Disabled example / debugging driver: the `if 0:` body never executes. It uses
# `depth_map`, `masks` and `img`, which are not defined in this file; the
# commented-out loading lines below show how they were presumably obtained.
if 0:
    # Presumed inputs (commented out): depth map and masks from a .pt file,
    # plus the matching RGB image.
    #depth_map, masks = torch.load("hydrant_exp.pt",map_location="cpu")
    #img=plt.imread("hydrantrgb.jpg")
    # Example usage:
    
    point_cloud = depth_map_to_point_cloud(depth_map, K, d)

    # Get points corresponding to the mask
    # NOTE(review): depth_map_to_point_cloud removes pixels with depth <= 0,
    # while each `m` is presumably a per-pixel mask. filter_points_by_mask
    # flattens both, so they only line up if every depth is > 0 — confirm.
    for mask in masks:
        mask_points = [filter_points_by_mask(point_cloud, m) for m in mask]
        # NOTE(review): the rendered image returned by this call is discarded,
        # and `img` is passed as per-point scatter colours. It presumably needs
        # flattening (and /255 for uint8) to match point_cloud — confirm.
        visualize_colored_point_cloud_matplotlib(point_cloud, img, mask_points)
    # `zz` is not defined in this file, so evaluating it raises NameError. This
    # appears to be a deliberate stop that keeps the Plotly view below from
    # running.
    zz

    # Visualize point cloud with bounding box around masked points
    # (only `mask_points` from the last loop iteration would be used here).
    visualize_point_cloud_with_bbox(point_cloud, mask_points)
