from matplotlib import pyplot as plt
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
from itertools import product
import numpy as np
import kornia
import torch

import vis_render

# Place tensors on the GPU when torch can see one; otherwise use the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

def pack_scene_graph(sg):
    """Flatten a list of per-object dicts into a single 2-D tensor.

    Each object's tensor values are concatenated in the dict's iteration
    (insertion) order to form one row, and the rows are stacked in list
    order. The column layout is presumably the one `unpack_scene_graph`
    reads back (type, position, color, dimensions, is_present) -- it depends
    entirely on the key order of the incoming dicts; verify against callers.
    """
    rows = []
    for obj in sg:
        fields = [value for value in obj.values()]
        rows.append(torch.cat(fields))
    return torch.stack(rows)
def unpack_scene_graph(sg):
    """Decode packed object rows back into a list of per-object dicts.

    Column layout read here:
        0      type flag (> .5 -> "plane", otherwise "sphere")
        1-3    position
        4-6    color, clipped to [0, 1]
        7-9    dimensions
        10     is_present
    Rows whose LAST column is not > .5 are skipped.
    """
    decoded = []
    for row in sg:
        if not (row[[-1]] > .5):
            continue
        kind = "plane" if row[[0]].item() > .5 else "sphere"
        decoded.append({
            "type": kind,
            "position": row[[1, 2, 3]],
            "color": row[[4, 5, 6]].clip(0, 1),
            "dimensions": row[[7, 8, 9]],
            "is_present": row[[10]],
        })
    return decoded

# Corner indices (into the 8 corners produced by itertools.product over
# (-x, +x) x (-y, +y) x (-z, +z)) for the six quad faces of a box.
_BOX_FACE_CORNERS = (
    (0, 1, 5, 4),  # -y face
    (4, 6, 7, 5),  # +x face
    (7, 6, 2, 3),  # +y face
    (2, 0, 1, 3),  # -x face
    (5, 7, 3, 1),  # +z face
    (0, 2, 6, 4),  # -z face
)


def _cameras_to_matrices(cameras_):
    """Return camera matrices, expanding 6-D parameter rows into 4x4 poses.

    A trailing size of 6 is read as (translation[3], axis-angle[3]); the
    parameters are doubled before use. Any other input is returned unchanged
    (presumably already (N, 4, 4) matrices -- TODO confirm with callers).
    """
    if cameras_.size(-1) != 6:
        return cameras_
    params = cameras_ * 2
    mats = torch.eye(4)[None].repeat(len(params), 1, 1)
    mats[:, :3, :3] = kornia.geometry.conversions.axis_angle_to_rotation_matrix(params[..., 3:])
    mats[:, :3, -1] = params[..., :3]
    return mats


def _checkerboard_color_fn(color):
    """Build a plane colour function alternating `color` and white tiles.

    `color` is bound here as an argument rather than read from a loop
    variable, so each plane keeps its own colour.
    """
    def color_fn(M):
        tile_x = (M[..., [0]] * 2).int() % 2
        tile_z = (M[..., [2]] * 2).int() % 2
        return torch.where(tile_x == tile_z, color.to(device), torch.ones_like(color).to(device))
    return color_fn


def _render_preview(objects, cameras):
    """Render `objects` with vis_render; fall back to a blank 64x64x3 image.

    Rendering is best-effort: any ordinary exception yields the blank image,
    but a warning is emitted so failures are no longer silent.
    """
    if not objects:
        return torch.zeros(64, 64, 3)
    try:
        scene = [
            {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in obj.items()}
            for obj in objects
        ]
        return vis_render.render(1, scene, cameras.to(device))[-2].clip(-1, 1)
    except Exception as err:
        import warnings
        warnings.warn(
            f"vis_render.render failed ({err!r}); returning a blank image",
            RuntimeWarning,
            stacklevel=2,
        )
        return torch.zeros(64, 64, 3)


def _draw_box(ax, obj):
    """Draw one object's axis-aligned box (dimensions scaled by 10) on `ax`."""
    # detach().cpu() so CUDA / grad-tracking tensors convert to numpy cleanly.
    center = obj["center"].detach().cpu().numpy()
    half = (obj["dimensions"] * 10).detach().cpu().numpy() / 2
    corners = np.array(list(product(*[(-h, h) for h in half]))) + center
    faces = [[corners[i] for i in face] for face in _BOX_FACE_CORNERS]
    ax.add_collection3d(Poly3DCollection(faces, facecolors=obj["color"].tolist(), linewidths=1, edgecolors='black', alpha=0.1))


def _draw_cameras(ax, cameras):
    """Scatter each camera's translation column (matrix column 3) as a red dot."""
    positions = cameras[..., :3, 3].reshape(-1, 3).detach().cpu().numpy()
    ax.scatter(positions[:, 0], positions[:, 1], positions[:, 2], c='red', marker='o')


def _figure_to_rgb(fig):
    """Rasterise `fig` and return an (H, W, 3) uint8 copy of its pixels."""
    fig.canvas.draw()
    # buffer_rgba() is sized in physical pixels (HiDPI-safe); tostring_rgb()
    # was deprecated in Matplotlib 3.8 and removed in 3.10.
    return np.asarray(fig.canvas.buffer_rgba())[..., :3].copy()


def plot_3d_bounding_boxes(boxes, cameras_):
    """Render a packed scene graph and draw its bounding boxes in 3-D.

    Args:
        boxes: packed scene-graph tensor decoded with ``unpack_scene_graph``
            (one row per object). A 3-D tensor is treated as a batch and each
            element is processed together with the matching ``cameras_`` entry.
        cameras_: camera parameters. A trailing size of 6 is read as
            (translation, axis-angle), doubled, and converted to 4x4
            matrices; anything else is used as camera matrices directly.

    Returns:
        ``(image, render_img)`` where ``image`` is an (H, W, 3) uint8
        matplotlib rendering of the boxes and camera positions, and
        ``render_img`` is the ``vis_render`` output clipped to [-1, 1], or a
        64x64x3 zero tensor when the scene is empty or rendering fails.
        For batched input, a list of such tuples.
    """
    if len(boxes.shape) == 3:
        return [plot_3d_bounding_boxes(x, y) for x, y in zip(boxes, cameras_)]

    objects = unpack_scene_graph(boxes)
    cameras = _cameras_to_matrices(cameras_)

    # Attach the per-primitive fields that are passed on to vis_render.render.
    for obj in objects:
        if obj["type"] == "plane":
            obj["normal"] = torch.tensor([0., 1., 0.]).to(device)
            obj["color_fn"] = _checkerboard_color_fn(obj["color"])
        else:
            obj["radius"] = obj["dimensions"].mean() * 5
        obj["center"] = obj["position"]

    render_img = _render_preview(objects, cameras)

    fig = plt.figure()
    try:
        ax = fig.add_subplot(111, projection='3d')
        for obj in objects:
            _draw_box(ax, obj)
        # Cameras do not depend on the boxes: draw them once, even for empty scenes.
        _draw_cameras(ax, cameras)

        ax.set_xlim3d(-5, 5)
        ax.set_ylim3d(-5, 5)
        ax.set_zlim3d(-5, 5)
        ax.set_xlabel('X')
        ax.set_ylabel('Y')
        ax.set_zlabel('Z')
        ax.view_init(elev=-30., azim=90)

        image = _figure_to_rgb(fig)
    finally:
        # Close this specific figure (plt.close() alone closes the *current* one).
        plt.close(fig)
    return image, render_img

# Example usage

# List of bounding boxes (example)
