import numpy as np
import cv2
from cv2 import aruco 
import time,math
import random
from itertools import product, combinations
import os
from tqdm import tqdm
import pybullet as p
import pybullet_data
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from PIL import Image
import kornia
import torch
np.set_printoptions(precision=2)
import sys
sys.path.append("/home/cameronsmith/misc/")
#seed=47;np.random.seed(seed);random.seed(seed) # 46 is no obj base, 47 is all but 1 visible
from pyb_tmpfns import *

# --- PyBullet scene setup ----------------------------------------------------
# Headless client (DIRECT mode: no GUI window).
physicsClient = p.connect(p.DIRECT)
_pyb_data_dir = pybullet_data.getDataPath()
# Make bundled assets such as "franka_panda/panda.urdf" resolvable by loadURDF.
p.setAdditionalSearchPath(_pyb_data_dir)
print(_pyb_data_dir)
p.setGravity(0, 0, -9.8)
rob_id = p.loadURDF("franka_panda/panda.urdf")
robot = rob_id          # alias: code may refer to the body id by either name
num_joints = p.getNumJoints(rob_id)
n_joints = num_joints   # alias: code may refer to the joint count by either name

def hom(x):
    """Append a homogeneous coordinate of 1 to every point in *x*.

    Args:
        x: array-like of shape [..., D] (typically [N, D] points).

    Returns:
        ndarray of shape [..., D + 1] whose last component is 1. The ones are
        float64, so the result dtype follows NumPy promotion with float64.
    """
    x = np.asarray(x)
    ones = np.ones(x.shape[:-1] + (1,))
    return np.concatenate([x, ones], axis=-1)

# --- Render each robot link in isolation -------------------------------------
# For every link that has a visual shape (plus the base, index -1): hide all
# other links, frame the link with a camera placed on its -y side, render
# RGB / depth / segmentation, and back-project the depth into camera-frame
# points. Results are saved to "link_imgs.pt" as a list of per-link dicts.
WIDTH, HEIGHT = 256, 256
FOV = 60  # degrees; passed to computeProjectionMatrixFOV
fov_rad = np.deg2rad(FOV)
UP = [0, 0, 1]
# Camera distance as a fraction of the distance at which the link's AABB
# bounding sphere would exactly fit the FOV; < 1 zooms in (the sphere is a
# conservative bound on the link geometry).
FIT_SCALE = 0.8

vis_data = p.getVisualShapeData(robot)
links_with_visual = {item[1] for item in vis_data}  # item[1] = link index
link_imgs = []

# Loop-invariant camera quantities.
# Tc flips y and z: OpenGL camera frame (y up, looking down -z) -> OpenCV
# camera frame (y down, looking down +z). Tc is its own inverse.
Tc = np.diag([1.0, -1.0, -1.0, 1.0])
fx, fy, cx, cy, camera_matrix = get_camera_intrinsics(WIDTH, HEIGHT, FOV)
# ER_SEGMENTATION_MASK_OBJECT_AND_LINKINDEX encodes each pixel as
# objectUniqueId + ((linkIndex + 1) << 24); background is -1.
_LINK_ID_SHIFT = 24
_OBJ_ID_MASK = (1 << _LINK_ID_SHIFT) - 1

for link_idx in sorted(links_with_visual | {-1}):  # include base = -1
    # 1) Hide every link, then show only this one.
    for i in range(-1, n_joints):
        p.changeVisualShape(robot, i, rgbaColor=[1, 1, 1, 0])
    p.changeVisualShape(robot, link_idx, rgbaColor=[1, 1, 1, 1])

    # 2) Frame the link: camera on the -y side of the AABB centre, at a
    #    distance derived from the AABB's half-diagonal.
    aabb_min, aabb_max = (np.array(v) for v in p.getAABB(robot, link_idx))
    half_extents = (aabb_max - aabb_min) * 0.5
    radius = np.linalg.norm(half_extents)
    distance = FIT_SCALE * radius / np.sin(fov_rad * 0.5)
    center = (aabb_min + aabb_max) * 0.5
    eye = center + np.array([0, -distance, 0])
    viewM = p.computeViewMatrix(eye.tolist(), center.tolist(), UP)
    near, far = 0.01, distance * 2
    projM = p.computeProjectionMatrixFOV(FOV, WIDTH / HEIGHT, near, far)

    # 3) Render. Depending on how pybullet was built, the buffers come back as
    #    numpy arrays or flat sequences; normalise both to image-shaped arrays.
    _, _, rgb, depth, seg = p.getCameraImage(
        WIDTH, HEIGHT, viewM, projM,
        flags=p.ER_SEGMENTATION_MASK_OBJECT_AND_LINKINDEX)
    rgb = np.asarray(rgb, dtype=np.uint8).reshape(HEIGHT, WIDTH, 4)
    depth = np.asarray(depth).reshape(HEIGHT, WIDTH)
    seg = np.asarray(seg).reshape(HEIGHT, WIDTH)

    # Foreground = pixels of *this* link of *this* robot. Decoding the id
    # (instead of testing seg != -1) keeps other links out of the mask even
    # if the renderer still rasterises the alpha-0 links.
    seg_obj = seg & _OBJ_ID_MASK
    seg_link = (seg >> _LINK_ID_SHIFT) - 1
    fg = ((seg_obj == robot) & (seg_link == link_idx)).reshape(-1)

    # 4) Lift the depth image into camera-frame 3D points.
    # viewM is world->camera in OpenGL convention, stored column-major
    # (hence order='F'); Tc converts it to the OpenCV convention.
    world2cam_cv = Tc @ np.array(viewM).reshape(4, 4, order='F')
    cam2world_cv = np.linalg.inv(world2cam_cv)
    # Convert the non-linear [0, 1] depth buffer to linear eye-space depth.
    depth = (far * near) / (far - (far - near) * depth)
    points_cam = depth_to_point_cloud(depth, fx, fy, cx, cy)
    # Assumes points_cam is indexed per flattened pixel (like fg) — TODO confirm.
    points_cam[~fg] = 0

    # 5) Collect results. "link_idx" and "cam2world" are additive keys.
    link_imgs.append({
        "rgb": rgb,
        "fg": fg,
        "points_cam": points_cam,
        "link_idx": link_idx,
        # NOTE(review): maps points_cam to world only if depth_to_point_cloud
        # uses the OpenCV camera convention (z forward) — confirm.
        "cam2world": cam2world_cv,
    })

# NOTE: the robot is left with only the last rendered link visible.
torch.save(link_imgs, "link_imgs.pt")
