# Note for a future DAVIS dataloader: temporally consistent monocular depth estimator: https://github.com/yu-li/TCMonoDepth
# Note: idea for streaming training data straight from YouTube instead of downloading it: https://gist.github.com/Mxhmovd/41e7690114e7ddad8bcd761a76272cc3
import matplotlib.pyplot as plt; 
import cv2
import os
import statistics 
import multiprocessing as mp
import torch.nn.functional as F
import torch
import random
import imageio
import numpy as np
from glob import glob
from collections import defaultdict
from pdb import set_trace as pdb
from itertools import combinations
from random import choice
import matplotlib.pyplot as plt
import imageio.v3 as iio

from torchvision import transforms

import sys

from glob import glob
import os
import gzip
import json
import numpy as np

import torchvision.transforms as transforms
from PIL import Image

# Custom function to add Gaussian noise to tensor
def add_noise_tensor(image, noise_factor=0.1):
    """Return ``image`` perturbed by zero-mean Gaussian noise, clipped to [0, 1].

    Args:
        image: input tensor; the noise has the same shape, dtype and device.
        noise_factor: standard deviation of the added noise.

    Returns:
        A new tensor ``clamp(image + noise_factor * N(0, 1), 0, 1)``.
    """
    perturbed = image + noise_factor * torch.randn_like(image)
    return perturbed.clamp(min=0.0, max=1.0)

# Define the augmentations
#augmentation = transforms.Compose([ transforms.RandomRotation(degrees=20), transforms.RandomResizedCrop(size=(256, 256), scale=(0.8, 1)), ])
# Currently an empty Compose, i.e. an identity transform; the rotation/crop variant above is kept for reference.
# NOTE(review): `augmentation` is not referenced anywhere else in this module as shown — confirm it is still needed.
augmentation = transforms.Compose([  ])
from einops import rearrange, repeat
# Move the channel axis last and flatten the two trailing spatial axes: (..., C, X, Y) -> (..., X*Y, C).
ch_sec = lambda x: rearrange(x,"... c x y -> ... (x y) c")
# Torch homogeneous-coordinate helper: appends a slice of ones along dim `i` (default: last dim).
# NOTE(review): `hom` is rebound later in this module to a NumPy-only lambda, so any call to `hom`
# made after import resolves to that NumPy version, not this one.
hom = lambda x, i=-1: torch.cat((x, torch.ones_like(x.unbind(i)[0].unsqueeze(i))), i)
import albumentations as A
# Candidate background images for the random-background option of format_render_sample.
# Globbed once at import time; the list is simply empty if the folder does not exist.
bg_paths=glob("/data/cameron/pytorch_vision_datasets/data_large/*/*/*.jpg")

# URDFs loaded with p.loadURDF after pybullet_data's directory is added to the search path.
# The first path component (e.g. "xarm") is used as the robot/embodiment name elsewhere in this module.
urdfs=["franka_panda/panda.urdf","kuka_iiwa/model.urdf","xarm/xarm6_robot.urdf"]

def format_render_sample(data_,use_rand_bg=False,use_color_aug=False):
    """Optionally augment one rendered sample, then drop its leading batch dimension.

    Args:
        data_: dict of tensors, each with a leading batch dimension (every entry
            is indexed with ``[0]`` on return). ``data_["img"]`` is handled as a
            (1, C, H, W) image; ``data_["fg"]`` (needed only when ``use_rand_bg``)
            is a foreground mask broadcast against it.
            NOTE(review): both augmentations treat ``img`` as lying in [0, 1]
            (``*255`` before the uint8 cast, ``/255`` for the background), whereas
            ``format_pybullet_render`` in this module stores images in [-1, 1] —
            confirm which convention the loaded data uses.
        use_rand_bg: if True, replace background pixels (where ``fg`` is false)
            with a random image from ``bg_paths`` resized to the image
            resolution. Grayscale (2-D) backgrounds are skipped, leaving the
            image unchanged.
        use_color_aug: if True, apply a random albumentations colour/blur
            pipeline to ``img``.

    Returns:
        dict with the same keys, each value indexed with ``[0]``. When an
        augmentation runs, ``data_["img"]`` is also replaced in the input dict.
    """
    if use_color_aug:
        # Built only when used, instead of on every call.
        transform = A.Compose([A.RandomBrightnessContrast(p=0.2),A.MotionBlur(p=0.2),A.MedianBlur(blur_limit=3, p=0.1),A.Blur(blur_limit=3, p=0.1), A.HueSaturationValue(p=0.3),A.RGBShift(r_shift_limit=70, g_shift_limit=70, b_shift_limit=70, p=0.9),A.Blur(p=.3),A.ChannelShuffle(p=.8)])
        # albumentations operates on HWC uint8 arrays: convert there and back to a (1, C, H, W) float tensor.
        img_hwc = (data_["img"]*255).to(torch.uint8)[0].permute(1,2,0).numpy()
        data_["img"] = torch.from_numpy(transform(image=img_hwc)["image"]).float().permute(2,0,1)[None]/255

    if use_rand_bg:
        # Only read a background from disk when it will actually be composited.
        rand_bg = plt.imread(random.choice(bg_paths))
        if rand_bg.ndim == 3:
            bg = torch.from_numpy(rand_bg).permute(2,0,1)[None].float()/255
            bg = F.interpolate(bg, data_["img"].shape[-2:])
            fg = data_["fg"].float()  # was an undefined name `fg`, raising NameError
            data_["img"] = data_["img"]*fg + bg*(1-fg)

    return {k: v[0] for k, v in data_.items()}

class PyBulletFolder(torch.utils.data.Dataset):
    """Random-sampling dataset over pre-rendered PyBullet scenes (``*.pt`` files) in one folder.

    The last 100 scene files in sorted order form the validation split, and the remaining
    files form the training split. ``__getitem__`` ignores ``idx`` and draws a random scene.
    ``__len__`` therefore reports a huge constant so the dataset behaves as effectively infinite.
    """

    def __init__( self, path=".",val=False, link_img_dir="/data/cameron/robot_calibration_testing"):
        """
        Args:
            path: folder containing the scene ``.pt`` files. The folder's name is used as
                the embodiment name and must match the first path component of an entry
                in ``urdfs`` (e.g. ``.../kpts/xarm``).
            val: select the validation split (last 100 files) instead of the training split.
            link_img_dir: folder holding one ``<embodiment>.pt`` link-image tensor per robot.
        """
        # Sort so the train/val split is reproducible; raw glob order depends on the filesystem.
        scene_paths = sorted(glob(os.path.join(path, "*.pt")))
        self.val = val
        self.link_img_dir = link_img_dir
        self.scene_paths = scene_paths[-100:] if val else scene_paths[:-100]

    def __len__(self):
        # Effectively infinite: samples are drawn at random in __getitem__, not indexed.
        return 100000000

    def __getitem__(self, idx):
        """Return a random scene (``idx`` is ignored) with its robot's link images and id attached."""
        if not self.scene_paths:
            raise IndexError("PyBulletFolder has no scenes in its %s split" % ("val" if self.val else "train"))
        scene_path = random.choice(self.scene_paths)
        data_ = torch.load(scene_path)

        # The scene's parent folder names the robot, e.g. ".../kpts/xarm/00001.pt" -> "xarm".
        embodiment = os.path.basename(os.path.dirname(scene_path))
        link_imgs = torch.load(os.path.join(self.link_img_dir, f"{embodiment}.pt"))
        data_["link_imgs"] = link_imgs.permute(0,3,1,2)[None]
        robot_names = [x.split("/")[0] for x in urdfs]
        data_["rob_id"] = torch.tensor([robot_names.index(embodiment)])

        return format_render_sample(data_,use_rand_bg=False,use_color_aug=False)
import pybullet as p
import geometry
from tqdm import tqdm
# Square render resolution in pixels (the trailing comment records an alternative 1200x1200 setting).
width=height = 256#1200, 1200
# NumPy homogeneous-coordinate helper: (N, D) -> (N, D+1) by appending a column of ones.
# NOTE(review): this rebinds the torch `hom` lambda defined earlier in the module; any call to
# `hom` after import (e.g. inside format_pybullet_render) gets this NumPy version.
hom = lambda x: np.concatenate([x, np.ones((x.shape[0], 1))], axis=1)  # [N, 4]
# Field of view passed to p.computeProjectionMatrixFOV and geometry.get_camera_intrinsics.
fov = 60
# NOTE(review): `aspect` is not used below; the projection matrix recomputes width/height inline.
aspect = width / height
# Near/far clip planes; also used to linearize the depth buffer in format_pybullet_render.
near, far = 0.1, 5.0
# Pinhole intrinsics for this resolution and fov, from the project's `geometry` module.
fx, fy, cx, cy, camera_matrix = geometry.get_camera_intrinsics(width,height,fov)
# diag(1, -1, -1, 1): flips the y and z axes. format_pybullet_render uses inv(Tc) @ view_matrix
# to go from pybullet's OpenGL-style view matrix to the "_cv" camera convention.
Tc = np.array([[1,  0,  0,  0], [0,  -1,  0,  0], [0,  0,  -1,  0], [0,  0,  0,  1]]).reshape(4,4)
# Fixed projection matrix shared by every render.
proj_matrix_ = p.computeProjectionMatrixFOV(fov, width/height, near, far)
def format_pybullet_render(joint_states=None,view_matrix_=None,rob_id=0,cam2world_cv=None,urdf=None): # todo parallelize with threading
    """Render one PyBullet frame of the robot and package it as a dict of batched tensors.

    Args:
        joint_states: per-joint positions applied in joint-index order with
            ``p.resetJointState``; if None, ``geometry.randomly_move_joints`` poses the robot.
        view_matrix_: camera view matrix, either a 4x4 array or pybullet's flat
            16-element column-major sequence. If None, a camera is sampled with
            ``geometry.sample_camera_positions`` (and any passed ``cam2world_cv`` is ignored).
        rob_id: PyBullet body id of the robot. When ``urdf`` is not None it is instead
            used as an index into ``urdfs``, and it is replaced by the newly loaded body id.
        cam2world_cv: 4x4 matrix that maps world-frame link positions into the camera
            frame (``inv(Tc) @ view_matrix_``). Derived from ``view_matrix_`` when None.
        urdf: if not None, reconnects to a fresh DIRECT physics server and loads
            ``urdfs[rob_id]``. NOTE(review): only whether ``urdf`` is None matters; its
            value is never used — confirm that is intended.

    Returns:
        dict of tensors, each with a leading batch dimension of 1:
        ``joint_sdfs`` (1, L, 3, H, W) offsets from each pixel's 3D point to each link
        position (assuming ``points_cam`` is (H*W, 3) — TODO confirm);
        ``points_cam`` (1, 3, H, W); ``img`` (1, 3, H, W) in [-1, 1];
        ``cam2world_cv`` (1, 4, 4); ``fg`` (1, 1, H, W) bool mask of non-background pixels;
        ``link_positions_cam`` / ``link_positions`` (1, L, 3); ``joint_states`` (1, num_joints).
    """
    if urdf is not None:
        print("loading new urdf")
        p.disconnect();p.connect(p.DIRECT);p.setGravity(0, 0, -9.8)
        p.setAdditionalSearchPath(pybullet_data.getDataPath())
        print(urdfs[rob_id])
        rob_id = p.loadURDF(urdfs[rob_id])

    # Camera: sample a random view, or use the caller's.
    if view_matrix_ is None:
        camera_position, look_at_position, roll = geometry.sample_camera_positions(fixed=False)
        view_matrix_ = np.array(p.computeViewMatrix(camera_position, look_at_position, roll)).reshape(4,4, order='F')
        cam2world_cv = np.linalg.inv(Tc) @ view_matrix_
    else:
        # reshape(4, 4, order='F') leaves a 4x4 array unchanged and unpacks pybullet's flat column-major list.
        view_matrix_ = np.asarray(view_matrix_, dtype=float).reshape(4, 4, order='F')
        if cam2world_cv is None:
            # Derive it exactly as for a sampled view so callers may pass only the view matrix.
            cam2world_cv = np.linalg.inv(Tc) @ view_matrix_

    # Pose the robot.
    if joint_states is None:
        geometry.randomly_move_joints(rob_id)
    else:
        for joint_i, joint_value in enumerate(joint_states):
            p.resetJointState(rob_id, joint_i, joint_value)

    # Render. np.reshape is a no-op on already-shaped numpy outputs and also accepts flat sequences.
    _, _, rgb, depth, seg = p.getCameraImage(width=width, height=height, viewMatrix=view_matrix_.flatten(order='F').tolist(), projectionMatrix=proj_matrix_,flags=p.ER_SEGMENTATION_MASK_OBJECT_AND_LINKINDEX)
    rgb = np.reshape(rgb, (height, width, 4))[:, :, :3]  # drop alpha
    depth = np.reshape(depth, (height, width))
    seg = np.reshape(seg, (height, width))

    # Convert the nonlinear depth buffer to linear depth using near/far, then back-project.
    points_cam = geometry.depth_to_point_cloud((far * near) / (far - (far - near) * depth), fx, fy, cx, cy)

    # Only links with visual geometry are kept; index -1 denotes the base.
    _, links_with_visual = geometry.get_all_link_transforms(rob_id)
    link_positions = np.stack([np.array(p.getLinkState(rob_id, i, computeForwardKinematics=True)[4] if i >= 0 else p.getBasePositionAndOrientation(rob_id)[0]) for i in sorted(links_with_visual)])
    # NOTE(review): despite its name, cam2world_cv is applied here to map world points into the camera frame.
    link_positions_cam = np.einsum('ij,nj->ni', cam2world_cv, hom(link_positions))[:, :3]

    fg = (seg != -1).reshape(-1)  # -1 marks pixels that hit no object
    joint_sdfs = link_positions_cam[:, None] - points_cam[None, :]
    joint_positions = [state[0] for state in p.getJointStates(rob_id, list(range(p.getNumJoints(rob_id))))]

    return {
        "joint_sdfs": torch.from_numpy(joint_sdfs).permute(0, 2, 1).unflatten(-1, (height, width))[None],
        "points_cam": torch.from_numpy(points_cam).T.unflatten(-1, (height, width))[None],
        "img": torch.from_numpy(rgb).permute(2, 0, 1)[None] / 255 * 2 - 1,  # 8-bit colour -> [-1, 1]
        "cam2world_cv": torch.from_numpy(cam2world_cv)[None],
        "fg": torch.from_numpy(fg).unflatten(0, (height, width))[None, None],
        "link_positions_cam": torch.from_numpy(link_positions_cam)[None],
        "link_positions": torch.from_numpy(link_positions)[None],
        "joint_states": torch.from_numpy(np.array(joint_positions))[None],
    }

import pybullet as p
import pybullet_data
import multiprocessing as mp
from tqdm import tqdm

def worker(worker_id, num_workers, datadir, num_scenes, progress_queue,urdf):
    """Render and save this worker's share of scenes in its own PyBullet connection.

    Scenes ``worker_id, worker_id + num_workers, ...`` (below ``num_scenes``) are
    rendered with a random camera and random joints, then saved to
    ``{datadir}/{index:05d}.pt``. After each save, ``1`` is posted to ``progress_queue``.
    """
    p.connect(p.DIRECT)
    p.setGravity(0, 0, -9.8)
    p.setAdditionalSearchPath(pybullet_data.getDataPath())
    body = p.loadURDF(urdf)

    for scene_idx in range(worker_id, num_scenes, num_workers):
        sample = format_pybullet_render(joint_states=None, view_matrix_=None, rob_id=body)
        out_path = f"{datadir}/{scene_idx:05d}.pt"
        torch.save(sample, out_path)
        progress_queue.put(1)

if __name__ == "__main__":
    # Render n_scenes random scenes of one robot in parallel worker processes and save them as .pt files.
    import queue  # only the launcher needs queue.Empty

    # Single-process smoke test:
    #p.connect(p.DIRECT);p.setGravity(0, 0, -9.8);p.setAdditionalSearchPath(pybullet_data.getDataPath());rob_id = p.loadURDF("franka_panda/panda.urdf")
    #format_pybullet_render(None,None,rob_id);#zz

    n_workers = 20
    n_scenes   = 2000
    urdf = urdfs[2]  # module-level list; index 2 is "xarm/xarm6_robot.urdf"
    # The output folder is named after the URDF's parent directory (e.g. "xarm").
    datadir    = "/data/cameron/robot_calibration_testing/kpts/"+urdf.split("/")[-2]
    os.makedirs(datadir, exist_ok=True)

    # Workers post one message per finished scene.
    progress_queue = mp.Queue()

    procs = []
    for wid in range(n_workers):
        proc = mp.Process(target=worker, args=(wid, n_workers, datadir, n_scenes, progress_queue, urdf))
        proc.start()
        procs.append(proc)

    # A plain blocking get() would hang forever if a worker died before finishing its share,
    # so poll with a timeout and check whether any worker exited abnormally (or all have exited).
    with tqdm(total=n_scenes) as pbar:
        n_done = 0
        while n_done < n_scenes:
            try:
                progress_queue.get(timeout=10)
            except queue.Empty:
                crashed = [proc.exitcode for proc in procs if proc.exitcode not in (None, 0)]
                if crashed or all(proc.exitcode is not None for proc in procs):
                    for proc in procs:
                        proc.terminate()
                    raise RuntimeError(f"render workers stopped after {n_done}/{n_scenes} scenes (abnormal exit codes: {crashed})")
                continue
            n_done += 1
            pbar.update(1)

    # All scenes reported; reap the worker processes.
    for proc in procs:
        proc.join()

