# Note for a future DAVIS dataloader: temporally consistent depth estimator: https://github.com/yu-li/TCMonoDepth
# Note on a cool idea: skip downloading the data and stream it straight from YouTube instead: https://gist.github.com/Mxhmovd/41e7690114e7ddad8bcd761a76272cc3
import matplotlib.pyplot as plt; 
import cv2
import os
import statistics 
import multiprocessing as mp
import torch.nn.functional as F
import torch
import random
import imageio
import numpy as np
from glob import glob
from collections import defaultdict
from pdb import set_trace as pdb
from itertools import combinations
from random import choice
import matplotlib.pyplot as plt
import imageio.v3 as iio

from torchvision import transforms

import sys

from glob import glob
import os
import gzip
import json
import numpy as np

import torchvision.transforms as transforms
from PIL import Image

# Custom function to add Gaussian noise to tensor
def add_noise_tensor(image, noise_factor=0.1):
    """Return ``image`` perturbed by zero-mean Gaussian noise, clamped to [0, 1].

    Args:
        image: Tensor to perturb. The noise is drawn with ``torch.randn_like``,
            so it matches the input's shape, dtype and device.
        noise_factor: Multiplier applied to the standard-normal noise.

    Returns:
        A new tensor; ``image`` itself is left untouched.
    """
    perturbation = torch.randn_like(image).mul(noise_factor)
    # Keep the result inside the valid pixel range.
    return (image + perturbation).clamp(min=0.0, max=1.0)

# Define the augmentations
# Geometric augmentations are currently disabled; the previous pipeline is kept below for reference.
#augmentation = transforms.Compose([ transforms.RandomRotation(degrees=20), transforms.RandomResizedCrop(size=(256, 256), scale=(0.8, 1)), ])
# Empty torchvision pipeline (no transforms are listed).
augmentation = transforms.Compose([  ])
from einops import rearrange, repeat
# Channels-last flatten: (..., c, x, y) -> (..., x*y, c).
ch_sec = lambda x: rearrange(x,"... c x y -> ... (x y) c")
# Torch homogeneous-coordinate helper: appends a slice of ones along dim `i` (default: the last dim).
# NOTE: the name `hom` is re-bound further down in this file to a NumPy version that takes a single argument.
hom = lambda x, i=-1: torch.cat((x, torch.ones_like(x.unbind(i)[0].unsqueeze(i))), i)
import albumentations as A
# Candidate background images, globbed once at import time from a machine-specific path.
# glob returns an empty list when nothing matches (e.g. the directory is absent).
bg_paths=glob("/data/cameron/pytorch_vision_datasets/data_large/*/*/*.jpg")
def format_render_sample(data_,use_rand_bg=False,use_color_aug=False):
    """Post-process one rendered sample into un-batched, 256x256 training tensors.

    ``data_`` is modified in place: a foreground mask is stored under ``"fg"``
    and a background channel (the complement of that mask) is prepended to
    ``"segs"``; ``"targ_img"`` is overwritten when an augmentation is enabled.

    Args:
        data_: Mapping of tensors, each with a leading batch dimension. Must
            contain ``"segs"`` and ``"targ_img"``. Presumably ``"segs"`` is a
            boolean (1, L, H, W) mask stack and ``"targ_img"`` a float
            (1, 3, H, W) image in [0, 1], as written by the renderer in this
            file -- confirm for other data sources.
        use_rand_bg: If True, composite the image over a random picture drawn
            from the module-level ``bg_paths`` (skipped when the picked picture
            is not a 3-dimensional array, e.g. grayscale).
        use_color_aug: If True, run the albumentations color pipeline on the
            whole target image.

    Returns:
        A new dict with the batch dimension removed (``v[0]``). Entries whose
        last dimension is larger than 63 are first resized to 256x256 with
        ``F.interpolate`` (default mode) and cast back to their original dtype.

    Raises:
        IndexError: If ``use_rand_bg`` is True but ``bg_paths`` is empty.
    """
    fg = data_["segs"].max(dim=1)[0][:,None]
    data_["fg"] = fg
    data_["segs"] = torch.cat((~fg,data_["segs"]),1)

    # Color augmentation. The pipeline is only built when it is actually used.
    if use_color_aug:
        transform = A.Compose([A.RandomBrightnessContrast(p=0.2),A.MotionBlur(p=0.2),A.MedianBlur(blur_limit=3, p=0.1),A.Blur(blur_limit=3, p=0.1), A.HueSaturationValue(p=0.3),A.RGBShift(r_shift_limit=70, g_shift_limit=70, b_shift_limit=70, p=0.9),A.Blur(p=.3),A.ChannelShuffle(p=.8)])
        # albumentations works on HWC uint8 numpy images; convert there and back.
        img_uint8 = (data_["targ_img"]*255).to(torch.uint8)[0].permute(1,2,0).numpy()
        data_["targ_img"] = torch.from_numpy(transform(image=img_uint8)["image"]).float().permute(2,0,1)[None]/255

    # Random background. The image is only read from disk when requested, so
    # the default path neither pays for the I/O nor fails on an empty bg_paths.
    if use_rand_bg:
        rand_bg = plt.imread(random.choice(bg_paths))
        if len(rand_bg.shape)==3:
            bg = F.interpolate(torch.from_numpy(rand_bg).permute(2,0,1)[None].float()/255,data_["targ_img"].shape[-2:])
            fg_f = fg.float()
            data_["targ_img"] = data_["targ_img"]*fg_f + bg*(1-fg_f)

    out = {}
    for key, value in data_.items():
        if value.size(-1)>63:
            # Image-like entry: resize, then drop the batch dimension.
            out[key] = F.interpolate(value.float(),(256,256))[0].to(value.dtype)
        else:
            out[key] = value[0]
    return out

class PyBulletFolder(torch.utils.data.Dataset):
    """Dataset of pre-rendered PyBullet samples stored as ``*.pt`` files in one folder.

    The last 100 files (in sorted filename order) form the validation split and
    the remaining files form the training split. Indexing ignores ``idx``:
    every access loads a randomly chosen file from the split and passes it
    through ``format_render_sample`` with its default options.
    """

    # Keys a valid sample contains; a file with a different number of entries is rejected.
    _EXPECTED_KEYS = ('pointmap_cam', 'pointmap_link', 'cam2world_cv', 'joint_states', 'curr_links', 'segs', 'targ_img', 'rob_id')
    # Upper bound on consecutive failed loads before giving up.
    _MAX_LOAD_ATTEMPTS = 100

    def __init__( self, path=".",val=False):
        """Collect the ``*.pt`` files under ``path`` and keep the requested split.

        Args:
            path: Directory containing the sample files.
            val: If True keep the last 100 files, otherwise everything before them.
        """
        # Sorted so that the train/val split is the same on every run and machine
        # (glob order is otherwise filesystem-dependent).
        self.scene_paths=sorted(glob(path+"/*.pt"))
        self.val=val
        if val: self.scene_paths=self.scene_paths[-100:]
        else:self.scene_paths=self.scene_paths[:-100]

    def __len__(self):
        """Return a fixed nominal length; samples are drawn at random, not by index."""
        return 100000000

    def __getitem__(self, idx):
        """Load and format a random sample; ``idx`` is ignored.

        Raises:
            RuntimeError: If the split is empty, or if ``_MAX_LOAD_ATTEMPTS``
                randomly chosen files in a row could not be loaded or did not
                have the expected number of entries.
        """
        if not self.scene_paths:
            raise RuntimeError("PyBulletFolder has no scene files to sample from")
        for _ in range(self._MAX_LOAD_ATTEMPTS):
            scene_path = random.choice(self.scene_paths)
            # NOTE(review): torch.load unpickles the file; only point this dataset at trusted data.
            try: data_=torch.load(scene_path)
            except Exception:
                # Unreadable / partially written file: try another one.
                print("weird bad load")
                continue
            if len(data_.keys())!=len(self._EXPECTED_KEYS):
                print("weird bad load")
                continue
            return format_render_sample(data_)
        raise RuntimeError(f"PyBulletFolder: {self._MAX_LOAD_ATTEMPTS} consecutive sample loads failed")
class PointCloudFolder(torch.utils.data.Dataset):
    """Dataset of rendered samples stored as ``*target.png`` images plus ``.npy`` companions.

    For every ``<prefix>target.png`` the companions ``<prefix>transfs.npy``,
    ``<prefix>segs.npy``, ``<prefix>camera_pointmap.npy`` and
    ``<prefix>obj_pointmap.npy`` are read. The last 100 images (in sorted
    filename order) form the validation split. Indexing ignores ``idx`` and
    returns a randomly chosen sample from the split.
    """

    def __init__( self, path=".",val=False):
        """Collect the ``*target.png`` files under ``path`` and keep the requested split.

        Args:
            path: Directory containing the sample files.
            val: If True keep the last 100 files, otherwise everything before them.
        """
        # Sorted so that the train/val split does not depend on filesystem order.
        self.scene_paths=sorted(glob(path+"/*target.png"))
        self.val=val
        if val: self.scene_paths=self.scene_paths[-100:]
        else:self.scene_paths=self.scene_paths[:-100]

    @staticmethod
    def _companion_path(target_path, kind):
        """Map ``<prefix>target.png`` to ``<prefix><kind>.npy``.

        Only the trailing ``target.png`` is rewritten, so a directory or file
        prefix that happens to contain "target" or ".png" is left intact.
        """
        return target_path[:-len("target.png")] + kind + ".npy"

    def __len__(self):
        """Return a fixed nominal length; samples are drawn at random, not by index."""
        return 100000000

    def __getitem__(self, idx):
        """Load a random sample; ``idx`` is ignored.

        Returns:
            Dict with ``"cam_pointmap"`` and ``"obj_pointmap"`` (loaded arrays
            permuted from (H, W, C) to (C, H, W)), ``"obj_transfs"``,
            ``"segs"`` and ``"targ_img"`` (float (1, 3, H, W) image in [0, 1]).
        """
        path=random.choice(self.scene_paths)
        transfs=np.load(self._companion_path(path,"transfs"))
        segs=np.load(self._companion_path(path,"segs"))
        # NOTE(review): the two pointmap files are read crosswise -- "obj_pointmap"
        # is filled from *camera_pointmap.npy and "cam_pointmap" from
        # *obj_pointmap.npy. This mapping is preserved from the original code;
        # confirm against how the files were written before changing it.
        obj_pointmap=np.load(self._companion_path(path,"camera_pointmap"))
        cam_pointmap=np.load(self._companion_path(path,"obj_pointmap"))
        # Context manager so the image file handle is released right away.
        with Image.open(path) as img:
            rgb = np.array(img)[...,:3]  # drop alpha if present
        img_target = torch.from_numpy(rgb).permute(2,0,1)[None].float()/255
        datas={
               "cam_pointmap":  torch.from_numpy(cam_pointmap).permute(2,0,1),
               "obj_pointmap":  torch.from_numpy(obj_pointmap).permute(2,0,1),
               "obj_transfs":torch.from_numpy(transfs),
               "segs":torch.from_numpy(segs),
               "targ_img":    img_target
              }

        return datas
import pybullet as p
import geometry
from tqdm import tqdm
# Render resolution in pixels (square image).
width=height = 512#1200, 1200
# NumPy homogeneous-coordinate helper: appends a column of ones to an (N, k) array.
# NOTE: this re-binds the module-level name `hom`, replacing the torch version defined earlier in this file.
hom = lambda x: np.concatenate([x, np.ones((x.shape[0], 1))], axis=1)  # [N, 4]
# Field of view handed to both geometry.get_camera_intrinsics and p.computeProjectionMatrixFOV
# (PyBullet documents this argument in degrees).
fov = 60
# NOTE(review): `aspect` is not referenced by the code visible in this file; the projection below recomputes width/height.
aspect = width / height
# Near / far clipping planes; also used to linearize the depth buffer in format_pybullet_render.
near, far = 0.1, 5.0
# Pinhole intrinsics from the project-local geometry module.
fx, fy, cx, cy, camera_matrix = geometry.get_camera_intrinsics(width,height,fov)
# diag(1, -1, -1, 1): flips the Y and Z axes. Used below to convert between PyBullet's view matrix
# and the `*_cv` transform -- presumably the OpenGL <-> OpenCV camera convention change.
Tc = np.array([[1,  0,  0,  0], [0,  -1,  0,  0], [0,  0,  -1,  0], [0,  0,  0,  1]]).reshape(4,4)
# Projection matrix shared by every render.
proj_matrix_ = p.computeProjectionMatrixFOV(fov, width/height, near, far)
def format_pybullet_render(joint_states=None,cam2world_cv=None,rob_id=0): # todo parallelize with threading
    """Render one view of robot ``rob_id`` and package it as batched tensors.

    The robot is posed, a camera view is chosen, the scene is rendered with
    PyBullet, the depth buffer is lifted to a point map, and every robot pixel
    is expressed in the frame of the link it belongs to.

    If fewer than 10 robot pixels are visible the render is rejected and the
    inputs that were left as ``None`` are re-sampled (up to 100 attempts).

    Args:
        joint_states: Joint positions to apply, one per joint index starting at
            0. ``None`` samples each joint uniformly inside its limits; joints
            whose lower limit is not below the upper limit are left unchanged
            and recorded as ``lower``.
        cam2world_cv: 4x4 camera transform. ``None`` samples a view with
            ``geometry.sample_camera_positions``.
            NOTE(review): despite the name, this code treats the matrix as
            world->camera: ``Tc @ cam2world_cv`` is used as the view matrix
            and its inverse maps camera points into the robot base frame.
        rob_id: PyBullet body unique id of the robot.

    Returns:
        ``defaultdict`` of tensors, each with a leading batch dimension of 1:
        ``"pointmap_cam"`` and ``"pointmap_link"`` (1, 3, H, W),
        ``"cam2world_cv"`` (1, 4, 4), ``"joint_states"``, ``"curr_links"``,
        ``"segs"`` (boolean per-link masks), ``"targ_img"`` (1, 3, H, W,
        float32 in [0, 1]) and ``"rob_id"``.

    Raises:
        RuntimeError: If the robot is not visible and nothing can be
            re-sampled (both inputs were given), or if no visible render was
            obtained within the attempt budget.
    """
    max_attempts = 100       # upper bound on re-sampling when the robot is out of view
    min_visible_pixels = 10  # renders with fewer robot pixels are rejected

    num_joints = p.getNumJoints(rob_id)
    # Remember what the caller left open: only those inputs are re-sampled on a
    # retry. Re-rendering with the already-sampled values would reproduce the
    # exact same rejected image.
    sample_joints = joint_states is None
    sample_camera = cam2world_cv is None

    for _ in range(max_attempts):
        # 1. Pose the robot
        if sample_joints:
            joint_states=[]
            for i in range(num_joints):
                info = p.getJointInfo(rob_id, i)
                lower, upper = info[8], info[9]
                # only sample if there is actually a range
                if lower < upper:
                    rand_angle = random.uniform(lower, upper)
                    p.resetJointState(rob_id, i, rand_angle)
                    joint_states.append(rand_angle)
                else: joint_states.append(lower)
        else:
            for joint_i, joint_angle in enumerate(joint_states): p.resetJointState(rob_id, joint_i, joint_angle)
        curr_links = geometry.get_link_transforms(p,rob_id)

        # 2. Choose the camera view
        if sample_camera:
            radius = 1.5#np.random.uniform(1, 2)
            camera_position, look_at_position, roll = geometry.sample_camera_positions(radius)
            view_matrix_ = p.computeViewMatrix(camera_position, look_at_position, roll) # TODO randomize roll here, not correct to assume exactly up
            # PyBullet returns the view matrix as a flat column-major list.
            cam2world_cv = np.linalg.inv(Tc) @ np.array(view_matrix_).reshape(4, 4, order='F')
        else:
            world2cam_opengl = Tc @ cam2world_cv
            view_matrix_ = world2cam_opengl.flatten(order='F').tolist()

        # 3. Render image
        image_arr = p.getCameraImage(width=width, height=height, viewMatrix=view_matrix_,
                                    projectionMatrix=proj_matrix_,flags=p.ER_SEGMENTATION_MASK_OBJECT_AND_LINKINDEX)

        # The segmentation value packs the body id in the low 24 bits and a link code above them.
        seg_mask = image_arr[4]
        object_ids = seg_mask & ((1 << 24) - 1)
        link_index = seg_mask >> 24
        # Keep only pixels of the requested robot body; everything else becomes -1.
        link_index = np.where(object_ids==rob_id,link_index,-np.ones_like(link_index))
        # NOTE(review): PyBullet documents the packed link code as linkIndex+1; check that the
        # indices used here line up with the ordering of geometry.get_link_transforms.
        segs=[link_index==i for i in range(num_joints)]

        if sum(x.sum() for x in segs)>=min_visible_pixels: break
        print("skipping because no visible robot")
        if not (sample_joints or sample_camera):
            raise RuntimeError("robot is not visible from the given camera with the given joint states")
    else:
        raise RuntimeError(f"no render with a visible robot after {max_attempts} attempts")

    target_rgb = np.reshape(image_arr[2], (height, width, 4))[:, :, :3].astype(np.float32) / 255.0  # Ignore alpha

    # 4. Lift image into robot-centric camera frame coordinates
    depth  = (far * near) / (far - (far - near) * image_arr[3])  # depth buffer -> metric depth
    points_cam = geometry.depth_to_point_cloud(depth, fx, fy, cx, cy)
    points_rob_base = np.einsum('ij,nj->ni', np.linalg.inv(cam2world_cv) , hom(points_cam))[:, :3] # rob_base is at 0,0,0

    # 5. Transform pointmap to per-link coordinate frame
    # Each robot pixel is mapped through the inverse of its link transform so it is local to
    # that link; non-robot pixels keep their base-frame coordinates.
    points_rob_base_h = hom(points_rob_base)  # computed once, reused for every link
    pointmap_img = np.copy(points_rob_base_h[:,:3])
    for joint_i in range(num_joints):
        link_pixels = segs[joint_i].reshape(-1)
        pointmap_img[link_pixels]=(np.linalg.inv(curr_links[joint_i]) @ points_rob_base_h[link_pixels].T).T[:, :3]

    datas=defaultdict(list)
    datas["pointmap_cam" ].append( torch.from_numpy(points_cam.reshape(height,width,3)).permute(2,0,1) )
    datas["pointmap_link"].append( torch.from_numpy(pointmap_img.reshape(height,width,3)).permute(2,0,1) )
    datas["cam2world_cv" ].append( torch.from_numpy(cam2world_cv)    )
    datas["joint_states" ].append( torch.from_numpy(np.array(joint_states))    )
    datas["curr_links"   ].append( torch.from_numpy(np.stack(curr_links))        )
    datas["segs"         ].append( torch.from_numpy(np.stack(segs))             )
    datas["targ_img"     ].append( torch.from_numpy(target_rgb).permute(2,0,1)  )
    datas["rob_id"       ].append( torch.tensor([rob_id])  )

    # Stack the single-element lists to add the leading batch dimension.
    for k,v in datas.items():datas[k]=torch.stack(v)
    return datas

    #physicsClient = p.connect(p.DIRECT);p.setGravity(0, 0, -9.8)
    #p.setAdditionalSearchPath(pybullet_data.getDataPath())
    #rob_id = p.loadURDF("franka_panda/panda.urdf")
    #for scene_i in tqdm(range(1000000)): torch.save(get_pybullet_render(None,None,rob_id),datadir+"%05d.pt"%scene_i) # todo use N workers

    #datadir="/data/cameron/robot_calibration_testing/link_centric/"
    #physicsClient = p.connect(p.DIRECT);p.setGravity(0, 0, -9.8)
    #p.setAdditionalSearchPath(pybullet_data.getDataPath())
    #rob_id = p.loadURDF("franka_panda/panda.urdf")
    #for scene_i in tqdm(range(1000000)): torch.save(get_pybullet_render(None,None,rob_id),datadir+"%05d.pt"%scene_i) # todo use N workers

import pybullet as p
import pybullet_data
import multiprocessing as mp
from tqdm import tqdm

def worker(worker_id, num_workers, datadir, num_scenes, progress_queue):
    """Render and save this worker's share of the scenes.

    Scene indices ``worker_id, worker_id + num_workers, ...`` below
    ``num_scenes`` are rendered with ``format_pybullet_render`` and saved to
    ``{datadir}/{index:05d}.pt``. After each save, ``1`` is put on
    ``progress_queue`` so the parent can track progress.

    Args:
        worker_id: Index of this worker in ``range(num_workers)``.
        num_workers: Total number of workers sharing the job.
        datadir: Output directory (must already exist).
        num_scenes: Total number of scenes across all workers.
        progress_queue: Queue receiving one item per finished scene.
    """
    # A forked child starts with a copy of the parent's NumPy global RNG state, so without
    # reseeding every worker would draw the same np.random sequence. Reseed both generators
    # from OS entropy. (Presumably the geometry helpers sample with np.random -- harmless if not.)
    random.seed()
    np.random.seed()

    # fresh process -> fresh default physics client
    p.connect(p.DIRECT)
    try:
        p.setGravity(0, 0, -9.8)
        p.setAdditionalSearchPath(pybullet_data.getDataPath())
        rob_id = p.loadURDF("franka_panda/panda.urdf")

        for i in range(worker_id, num_scenes, num_workers):
            data = format_pybullet_render(None, None, rob_id)
            torch.save(data, f"{datadir}/{i:05d}.pt")
            progress_queue.put(1)
    finally:
        # Release the physics client even if rendering or saving fails.
        p.disconnect()


if __name__ == "__main__":
    # Script entry point: render n_scenes samples with n_workers processes and
    # show a single global progress bar.
    import queue  # local import: only the entry point needs queue.Empty

    n_workers = 20
    n_scenes   = 1000
    datadir    = "/data/cameron/robot_calibration_testing/justjoints/"
    os.makedirs(datadir, exist_ok=True)

    # a queue for workers to send "I finished one" messages
    progress_queue = mp.Queue()

    # spawn workers
    procs = []
    for wid in range(n_workers):
        p2 = mp.Process(
            target=worker,
            args=(wid, n_workers, datadir, n_scenes, progress_queue)
        )
        p2.start()
        procs.append(p2)

    # global tqdm bar; waits for n_scenes messages but gives up once every
    # worker has exited, so a crashed worker cannot hang the script forever
    n_done = 0
    with tqdm(total=n_scenes) as pbar:
        while n_done < n_scenes:
            # Sample liveness BEFORE waiting: if all workers were already gone and the
            # wait still times out, no further message can arrive.
            workers_alive = any(proc.is_alive() for proc in procs)
            try:
                progress_queue.get(timeout=1.0)  # wait for the next "done" signal
            except queue.Empty:
                if not workers_alive:
                    break
                continue
            n_done += 1
            pbar.update(1)

    # all scenes done (or all workers gone) - now join workers
    for p2 in procs:
        p2.join()

    if n_done < n_scenes:
        exit_codes = [proc.exitcode for proc in procs]
        raise RuntimeError(f"only {n_done}/{n_scenes} scenes were rendered; worker exit codes: {exit_codes}")

