# Note for a future DAVIS dataloader: temporally consistent monocular depth estimator: https://github.com/yu-li/TCMonoDepth
# Note on the idea of streaming data directly from YouTube instead of downloading it: https://gist.github.com/Mxhmovd/41e7690114e7ddad8bcd761a76272cc3
import matplotlib.pyplot as plt; 
import cv2
import os
import statistics 
import multiprocessing as mp
import torch.nn.functional as F
import torch
import random
import imageio
import numpy as np
from glob import glob
from collections import defaultdict
from pdb import set_trace as pdb
from itertools import combinations
from random import choice
import matplotlib.pyplot as plt
import imageio.v3 as iio

from torchvision import transforms

import sys

from glob import glob
import os
import gzip
import json
import numpy as np

import torchvision.transforms as transforms
from PIL import Image

def add_noise_tensor(image, noise_factor=0.1):
    """Add zero-mean Gaussian noise to a tensor and clamp the result to [0, 1].

    Args:
        image: input tensor of any shape (presumably an image with values in
            [0, 1] — the clamp assumes that range).
        noise_factor: standard deviation of the additive noise.

    Returns:
        A new tensor with the same shape as ``image``, values in [0.0, 1.0].
    """
    perturbed = image + noise_factor * torch.randn_like(image)
    return perturbed.clamp(min=0.0, max=1.0)

# Define the augmentations.
# Disabled variant kept for reference (random rotation + random resized crop):
#augmentation = transforms.Compose([ transforms.RandomRotation(degrees=20), transforms.RandomResizedCrop(size=(256, 256), scale=(0.8, 1)), ])
# Currently an empty Compose, i.e. an identity pipeline.
# NOTE(review): not referenced in the visible part of this file — confirm it is still used.
augmentation = transforms.Compose([  ])
from einops import rearrange, repeat
# ch_sec: move the channel axis last and flatten the two trailing spatial axes,
# i.e. (..., C, X, Y) -> (..., X*Y, C).
ch_sec = lambda x: rearrange(x,"... c x y -> ... (x y) c")
# hom: append a slice of ones along dim `i` (default: last), lifting points to
# homogeneous coordinates.
# NOTE(review): this name is rebound later in the module to a NumPy-only version,
# so once the module has finished importing, `hom` refers to that one, not this one.
hom = lambda x, i=-1: torch.cat((x, torch.ones_like(x.unbind(i)[0].unsqueeze(i))), i)
import albumentations as A
# Pool of background photos used by format_render_sample for random backgrounds.
# Globbed once at import time from a hard-coded, machine-specific directory; the
# list is simply empty if that directory does not exist.
bg_paths=glob("/data/cameron/pytorch_vision_datasets/data_large/*/*/*.jpg")

def format_render_sample(data_,use_rand_bg=False,use_color_aug=False):
    """Post-process a rendered sample dict for training.

    Adds a foreground mask and a background channel to the segmentation,
    optionally colour-augments the image, optionally composites the foreground
    over a random background photo, then resizes every spatial entry to
    256x256 and drops the leading batch dimension.

    Args:
        data_: dict of tensors, each with a leading batch dimension of size 1.
            Must contain "segs" (presumably bool per-link masks shaped
            (1, L, H, W) — confirm) and "img" (float RGB in [0, 1], presumably
            (1, 3, H, W)). Mutated in place: "fg" is added, "segs" is replaced
            and "img" may be replaced.
        use_rand_bg: if True, pixels outside the foreground are replaced by a
            random image from the module-level ``bg_paths`` (grayscale
            backgrounds are skipped). The background is only loaded when this
            flag is set.
        use_color_aug: if True, apply albumentations colour/blur augmentation
            to "img" (before any background compositing).

    Returns:
        A new dict with the same keys. Tensors whose last dimension is larger
        than 63 are resized to (256, 256) (``F.interpolate`` default nearest
        mode, original dtype restored); every entry has its batch dim removed.
    """
    # Foreground = pixel covered by any link mask. Its complement is prepended
    # as channel 0 of "segs" so that channel acts as the background class.
    data_["fg"] = fg = data_["segs"].max(dim=1)[0][:, None]
    data_["segs"] = torch.cat((~fg, data_["segs"]), 1)

    if use_color_aug:
        # Colour augmentations (the pipeline is only built when actually used).
        transform = A.Compose([
            A.RandomBrightnessContrast(p=0.2),
            A.MotionBlur(p=0.2),
            A.MedianBlur(blur_limit=3, p=0.1),
            A.Blur(blur_limit=3, p=0.1),
            A.HueSaturationValue(p=0.3),
            A.RGBShift(r_shift_limit=70, g_shift_limit=70, b_shift_limit=70, p=0.9),
            A.Blur(p=.3),
            A.ChannelShuffle(p=.8),
        ])
        img_hwc = (data_["img"] * 255).to(torch.uint8)[0].permute(1, 2, 0).numpy()
        augmented = transform(image=img_hwc)["image"]
        data_["img"] = torch.from_numpy(augmented).float().permute(2, 0, 1)[None] / 255

    if use_rand_bg:
        # Only touch the disk (and bg_paths) when a background is actually needed;
        # previously this read happened on every call and crashed on an empty pool.
        rand_bg = plt.imread(random.choice(bg_paths))
        if rand_bg.ndim == 3:  # skip grayscale backgrounds
            bg = torch.from_numpy(rand_bg).permute(2, 0, 1)[None].float() / 255
            bg = F.interpolate(bg, data_["img"].shape[-2:])
            data_["img"] = data_["img"] * fg.float() + bg * (1 - fg.float())

    return {
        k: F.interpolate(v.float(), (256, 256))[0].to(v.dtype) if v.size(-1) > 63 else v[0]
        for k, v in data_.items()
    }

class PyBulletFolder(torch.utils.data.Dataset):
    """Randomly sampled dataset of pre-rendered PyBullet scenes stored as ``*.pt`` files.

    Every access loads a random ``.pt`` file from ``path`` with ``torch.load``
    and post-processes it with ``format_render_sample`` (random background and
    colour augmentation enabled). ``idx`` is ignored.

    The last 100 files in sorted order form the validation split; all other
    files form the training split.
    """

    def __init__(self, path=".", val=False):
        # Sort so the train/val split is identical on every run and machine:
        # glob() returns files in arbitrary filesystem order.
        self.scene_paths = sorted(glob(os.path.join(path, "*.pt")))
        self.val = val
        if val:
            self.scene_paths = self.scene_paths[-100:]
        else:
            self.scene_paths = self.scene_paths[:-100]

    def __len__(self):
        # Nominal length only: items are drawn at random, so a huge value lets a
        # DataLoader iterate effectively forever.
        return 100000000

    def __getitem__(self, idx):
        data_ = torch.load(random.choice(self.scene_paths))
        return format_render_sample(data_, use_rand_bg=True, use_color_aug=True)
class PointCloudFolder(torch.utils.data.Dataset):
    """Randomly sampled dataset of rendered samples stored as sibling files.

    Each sample is keyed by a ``<prefix>target.png`` RGB image. Its arrays live
    next to it as ``<prefix>transfs.npy``, ``<prefix>segs.npy``,
    ``<prefix>camera_pointmap.npy`` and ``<prefix>obj_pointmap.npy``.
    ``idx`` is ignored; a random sample is drawn on every access. The last 100
    targets in sorted order form the validation split.
    """

    _TARGET_SUFFIX = "target.png"

    def __init__(self, path=".", val=False):
        # Sorted so the train/val split is deterministic (glob order is arbitrary).
        self.scene_paths = sorted(glob(os.path.join(path, "*" + self._TARGET_SUFFIX)))
        self.val = val
        if val:
            self.scene_paths = self.scene_paths[-100:]
        else:
            self.scene_paths = self.scene_paths[:-100]

    def __len__(self):
        # Nominal length only: items are drawn at random.
        return 100000000

    @classmethod
    def _sibling_paths(cls, path):
        """Map a ``...target.png`` path to the paths of its companion ``.npy`` files.

        Only the trailing ``target.png`` is replaced, so directory names that
        happen to contain "target" or ".png" are left untouched.

        Raises:
            ValueError: if ``path`` does not end with ``target.png``.
        """
        if not path.endswith(cls._TARGET_SUFFIX):
            raise ValueError(f"expected a path ending in {cls._TARGET_SUFFIX!r}: {path!r}")
        prefix = path[: -len(cls._TARGET_SUFFIX)]
        return {
            "transfs": prefix + "transfs.npy",
            "segs": prefix + "segs.npy",
            "cam_pointmap": prefix + "camera_pointmap.npy",
            "obj_pointmap": prefix + "obj_pointmap.npy",
        }

    def __getitem__(self, idx):
        path = random.choice(self.scene_paths)
        siblings = self._sibling_paths(path)
        transfs = np.load(siblings["transfs"])
        segs = np.load(siblings["segs"])
        # Each key is loaded from the file of the same name (camera_pointmap.npy ->
        # "cam_pointmap", obj_pointmap.npy -> "obj_pointmap"); these were crossed before.
        cam_pointmap = np.load(siblings["cam_pointmap"])
        obj_pointmap = np.load(siblings["obj_pointmap"])
        # Context manager closes the file handle PIL keeps open.
        with Image.open(path) as im:
            rgb = np.array(im)[..., :3]  # drop alpha if present
        img_target = torch.from_numpy(rgb).permute(2, 0, 1)[None].float() / 255
        return {
            "cam_pointmap": torch.from_numpy(cam_pointmap).permute(2, 0, 1),
            "obj_pointmap": torch.from_numpy(obj_pointmap).permute(2, 0, 1),
            "obj_transfs": torch.from_numpy(transfs),
            "segs": torch.from_numpy(segs),
            "img": img_target,
        }
import pybullet as p
import geometry
from tqdm import tqdm
# Square render resolution; the trailing comment records an earlier 1200x1200 setting.
width=height = 256#1200, 1200
# NumPy homogeneous lift: (N, D) -> (N, D+1) by appending a column of ones
# (used on (N, 3) point arrays in format_pybullet_render).
# NOTE(review): this rebinds the module-level `hom` defined earlier (a torch version
# with a selectable axis); once the module has finished importing, `hom` is this one.
hom = lambda x: np.concatenate([x, np.ones((x.shape[0], 1))], axis=1)  # [N, 4]
# Camera field of view, shared by the intrinsics helper and PyBullet's projection matrix.
fov = 60
# NOTE(review): `aspect` is not used below; the projection call recomputes width/height.
aspect = width / height
# Near/far clip planes; also used to linearize the depth buffer in format_pybullet_render.
near, far = 0.1, 5.0
# Camera intrinsics (fx, fy, cx, cy and a camera matrix) from the project-local geometry helper.
fx, fy, cx, cy, camera_matrix = geometry.get_camera_intrinsics(width,height,fov)
# diag(1, -1, -1, 1): negates the y and z axes (its own inverse). Used in
# format_pybullet_render as the OpenGL <-> OpenCV camera-convention flip.
Tc = np.array([[1,  0,  0,  0], [0,  -1,  0,  0], [0,  0,  -1,  0], [0,  0,  0,  1]]).reshape(4,4)
# PyBullet projection matrix, built once at import time and reused for every render.
proj_matrix_ = p.computeProjectionMatrixFOV(fov, width/height, near, far)
def format_pybullet_render(joint_states=None,view_matrix_=None,rob_id=0): # todo parallelize with threading
    """Render one PyBullet view of the robot and package it as batched tensors.

    Args:
        joint_states: optional sequence of joint positions; entry ``i`` is
            applied to joint ``i`` via ``p.resetJointState``. If None, the
            joints are moved randomly by ``geometry.randomly_move_joints``.
        view_matrix_: optional 4x4 NumPy world-to-camera view matrix, laid out
            like the one built below (PyBullet's flat column-major matrix
            reshaped with ``order='F'``). If None, a random camera is sampled
            at a radius drawn from U(1, 2).
        rob_id: PyBullet body id of the robot.

    Returns:
        dict of tensors, each with a leading batch dimension of 1:
        "segs" (1, L, H, W) bool masks, one per visual link (sorted by link index);
        "img" (1, 3, H, W) float RGB in [0, 1];
        "view_matrix_" (1, 4, 4);
        "points_cam" (1, 3, H, W) camera-space points;
        "points_link" (1, 3, H, W) each pixel's point expressed in the frame of
        the link covering it (zeros where no visual link is visible);
        "curr_link_transforms" (1, L, 4, 4) one transform per link in "segs";
        "joint_states" (1, J) current joint positions.
    """
    # 1. Camera view: random unless a view matrix is supplied.
    if view_matrix_ is None:
        radius = np.random.uniform(1, 2)
        camera_position, look_at_position, roll = geometry.sample_camera_positions(radius)
        # NOTE(review): `roll` is passed as PyBullet's third argument (cameraUpVector).
        view_matrix_ = np.array(p.computeViewMatrix(camera_position, look_at_position, roll)).reshape(4, 4, order='F')

    # 2. Robot pose: random, or the supplied joint positions.
    if joint_states is None:
        geometry.randomly_move_joints(rob_id)
    else:
        for joint_i, joint_pos in enumerate(joint_states):
            p.resetJointState(rob_id, joint_i, joint_pos)

    # 3. Render. np.asarray guards against PyBullet builds that return plain lists
    #    instead of NumPy arrays (a no-op otherwise).
    _, _, rgb, depth, seg = p.getCameraImage(width=width, height=height, viewMatrix=view_matrix_.flatten(order='F').tolist(), projectionMatrix=proj_matrix_,flags=p.ER_SEGMENTATION_MASK_OBJECT_AND_LINKINDEX)
    depth = np.asarray(depth)
    seg = np.asarray(seg)
    target_rgb = np.reshape(rgb, (height, width, 4))[:, :, :3].astype(np.float32) / 255.0  # drop alpha
    # Linearize the OpenGL depth buffer with the near/far clip planes before back-projecting.
    points_cam = geometry.depth_to_point_cloud((far * near) / (far - (far - near) * depth), fx, fy, cx, cy)

    curr_link_transforms, visual_links = geometry.get_all_link_transforms(rob_id)

    # 4. Link-space pointmap. Per link: Tc (OpenGL<->OpenCV flip), then the inverse
    #    view matrix (camera -> world/robot base), then the inverse link transform.
    cam_to_world = np.linalg.inv(view_matrix_)  # loop-invariant
    # With ER_SEGMENTATION_MASK_OBJECT_AND_LINKINDEX, (linkIndex + 1) sits in the high bits.
    link_ids = seg >> 24
    pointmap_link_img = np.zeros_like(points_cam)
    segmasks = {}
    for link_idx in visual_links:
        segmasks[link_idx] = mask = (link_ids == (link_idx + 1)).reshape(-1)
        cam_to_link = np.linalg.inv(curr_link_transforms[link_idx]) @ cam_to_world @ Tc
        pointmap_link_img[mask] = np.einsum('ij,nj->ni', cam_to_link, hom(points_cam[mask]))[:, :3]

    link_order = sorted(segmasks)
    all_joints = list(range(p.getNumJoints(rob_id)))
    datas = {}
    datas["segs"] = torch.stack([torch.from_numpy(segmasks[k]).reshape(height, width) for k in link_order])[None]
    datas["img"] = torch.from_numpy(target_rgb).permute(2, 0, 1)[None]
    datas["view_matrix_"] = torch.from_numpy(view_matrix_)[None]
    datas["points_cam"] = torch.from_numpy(points_cam.reshape(height, width, 3)).permute(2, 0, 1)[None]
    datas["points_link"] = torch.from_numpy(pointmap_link_img.reshape(height, width, 3)).permute(2, 0, 1)[None]
    datas["curr_link_transforms"] = torch.stack([torch.from_numpy(curr_link_transforms[k]) for k in link_order])[None]
    datas["joint_states"] = torch.from_numpy(np.array([state[0] for state in p.getJointStates(rob_id, all_joints)]))[None]
    return datas

import pybullet as p
import pybullet_data
import multiprocessing as mp
from tqdm import tqdm

def worker(worker_id, num_workers, datadir, num_scenes, progress_queue):
    """Render and save every ``num_workers``-th scene, starting at index ``worker_id``.

    Intended to run in its own process. It opens its own headless (DIRECT)
    PyBullet client and loads the Franka Panda URDF. Scene ``i`` is saved to
    ``<datadir>/<i:05d>.pt``, and one token is put on ``progress_queue`` per
    saved scene. The PyBullet client is disconnected even if rendering fails.
    """
    # Reseed from OS entropy so each worker draws independent samples: under the
    # 'fork' start method NumPy's global RNG state is otherwise copied from the
    # parent, and every worker would render the same random cameras.
    np.random.seed()
    random.seed()

    client = p.connect(p.DIRECT)
    try:
        p.setGravity(0, 0, -9.8)
        p.setAdditionalSearchPath(pybullet_data.getDataPath())
        rob_id = p.loadURDF("franka_panda/panda.urdf")

        for i in range(worker_id, num_scenes, num_workers):
            data = format_pybullet_render(None, None, rob_id)
            torch.save(data, os.path.join(datadir, f"{i:05d}.pt"))
            progress_queue.put(1)
    finally:
        p.disconnect(physicsClientId=client)


if __name__ == "__main__":
    import queue  # for queue.Empty, raised by mp.Queue.get(timeout=...)

    n_workers = 20
    n_scenes = 10000
    datadir = "/data/cameron/robot_calibration_testing/fulldiversecam_redo/"
    os.makedirs(datadir, exist_ok=True)

    # Workers put one token on this queue per saved scene.
    progress_queue = mp.Queue()

    # Spawn workers. Named `proc` to avoid shadowing the pybullet module `p`.
    procs = [
        mp.Process(target=worker, args=(wid, n_workers, datadir, n_scenes, progress_queue))
        for wid in range(n_workers)
    ]
    for proc in procs:
        proc.start()

    # Global progress bar driven by worker tokens. Poll with a timeout so that a
    # crashed worker (whose scenes will never arrive) aborts the run instead of
    # blocking forever. Workers that exited cleanly have already flushed their
    # tokens, so only a non-zero exit code signals missing scenes.
    with tqdm(total=n_scenes) as pbar:
        received = 0
        while received < n_scenes:
            try:
                progress_queue.get(timeout=10)
            except queue.Empty:
                crashed = [proc for proc in procs if proc.exitcode not in (None, 0)]
                if crashed:
                    for proc in procs:
                        proc.terminate()
                    raise RuntimeError(f"{len(crashed)} worker(s) exited with an error; aborting")
                continue
            received += 1
            pbar.update(1)

    # All scenes done: join the workers.
    for proc in procs:
        proc.join()

