import numpy as np
import cv2
from cv2 import aruco 
import time,math
import random
from itertools import product, combinations
import os
from tqdm import tqdm
import pybullet as p
import pybullet_data
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from PIL import Image
import kornia
import torch
np.set_printoptions(precision=2)
import sys
sys.path.append(".")
import geometry
#seed=47;np.random.seed(seed);random.seed(seed) # 46 is no obj base, 47 is all but 1 visible

# PyBullet scene setup: a headless (DIRECT) physics client with the stock
# pybullet_data assets on the search path.
physicsClient = p.connect(p.DIRECT)
p.setAdditionalSearchPath(pybullet_data.getDataPath())
print(pybullet_data.getDataPath())
p.setGravity(0, 0, -9.8)

# Robot to render. `data.urdfs` is presumably a list of URDF paths (TODO confirm);
# alternative stock asset: p.loadURDF("franka_panda/panda.urdf")
import data
urdf = data.urdfs[1]
robot = rob_id = p.loadURDF(urdf)

n_joints = num_joints = p.getNumJoints(rob_id)

def hom(x):
    """Append a homogeneous coordinate of 1 to every point.

    Args:
        x: Array of shape ``[..., D]`` (typically ``[N, 3]`` points).

    Returns:
        Array of shape ``[..., D + 1]`` whose last entry along the final axis
        is 1. ``np.ones`` defaults to float64, so integer inputs are promoted
        to float exactly as before.
    """
    return np.concatenate([x, np.ones(x.shape[:-1] + (1,))], axis=-1)

# Settings for rendering each link of the robot in isolation; these renders
# form the canonical source images.
WIDTH = HEIGHT = 64  # 256 gives a higher-resolution render
FOV = 60  # camera field of view, degrees
fov_rad = np.deg2rad(FOV)
UP = [0, 0, 1]  # world +z is the camera's up direction

# Restrict the per-link transforms to the links reported as having visuals.
curr_links, links_with_visual = geometry.get_all_link_transforms(robot)
curr_links = dict((idx, curr_links[idx]) for idx in links_with_visual)

# OpenGL camera axes (x right, y up, looking along -z) -> OpenCV camera axes
# (x right, y down, looking along +z). This axis flip is its own inverse.
Tc = np.diag([1.0, -1.0, -1.0, 1.0])
# Intrinsics depend only on the fixed image size and FOV, so compute them once.
fx, fy, cx, cy, camera_matrix = geometry.get_camera_intrinsics(WIDTH, HEIGHT, FOV)

# One [HEIGHT, WIDTH, 6] tensor per visual link: RGB scaled to [0, 1], then the
# XYZ position of each pixel in the world frame (the robot is loaded at the
# origin, so this is treated as the robot-base frame). Background XYZ is 0.
link_imgs = []
for link_idx in sorted(curr_links):  # may include the base link (-1)
    print(link_idx)

    # 1) Hide every link, then make only this one visible.
    for i in range(-1, n_joints):
        p.changeVisualShape(robot, i, rgbaColor=[1, 1, 1, 0])
    p.changeVisualShape(robot, link_idx, rgbaColor=[1, 1, 1, 1])

    # 2) Bounding-sphere radius of the link's world-space AABB.
    aabb_min, aabb_max = (np.asarray(v) for v in p.getAABB(robot, link_idx))
    radius = np.linalg.norm((aabb_max - aabb_min) * 0.5)

    # 3) Camera distance: radius / sin(fov/2) would fit the sphere exactly;
    #    the 0.8 factor moves the camera closer for a tighter crop.
    distance = 0.8 * radius / np.sin(fov_rad * 0.5)

    # 4) Camera looking at the AABB centre from the -y side.
    center = (aabb_min + aabb_max) * 0.5
    eye = center + np.array([0, -distance, 0])
    viewM = p.computeViewMatrix(eye.tolist(), center.tolist(), UP)
    near, far = 0.01, distance * 2
    projM = p.computeProjectionMatrixFOV(FOV, WIDTH / HEIGHT, near, far)

    # 5) Render. Reshape explicitly: without numpy support, PyBullet returns
    #    the buffers as flat lists rather than arrays.
    _, _, rgb, depth, seg = p.getCameraImage(
        WIDTH, HEIGHT, viewM, projM,
        flags=p.ER_SEGMENTATION_MASK_OBJECT_AND_LINKINDEX,
    )
    rgb = np.asarray(rgb, dtype=np.uint8).reshape(HEIGHT, WIDTH, 4)
    depth = np.asarray(depth).reshape(HEIGHT, WIDTH)
    fg = np.asarray(seg).reshape(-1) != -1  # -1 marks pixels with no object

    # 6) Linearise the [0, 1] depth buffer to metric depth, back-project to
    #    OpenCV-camera points, then move them into the world frame.
    depth = (far * near) / (far - (far - near) * depth)
    points_cam = geometry.depth_to_point_cloud(depth, fx, fy, cx, cy)
    # viewM is PyBullet's column-major OpenGL world->camera matrix.
    world2cam_cv = Tc @ np.array(viewM).reshape(4, 4, order="F")
    cam2world_cv = np.linalg.inv(world2cam_cv)
    points_rob_base = np.einsum("ij,nj->ni", cam2world_cv, hom(points_cam))[:, :3]
    # Zero the background AFTER the transform: zeroing in the camera frame
    # would map those pixels to the camera centre instead of the origin.
    points_rob_base[~fg] = 0

    link_imgs.append(torch.cat((
        torch.from_numpy(rgb)[..., :3] / 255,
        torch.from_numpy(points_rob_base).unflatten(0, (HEIGHT, WIDTH)),
    ), -1))

# Saved as [num_visual_links, HEIGHT, WIDTH, 6], named after the URDF's parent
# directory; create the output directory if it does not exist yet.
out_dir = "/data/cameron/robot_calibration_testing"
os.makedirs(out_dir, exist_ok=True)
robot_name = urdf.split("/")[-2]
torch.save(torch.stack(link_imgs), os.path.join(out_dir, "%s.pt" % robot_name))


