# Note for a future DAVIS dataloader: temporally consistent depth estimator: https://github.com/yu-li/TCMonoDepth
# Note on an idea to skip downloading data entirely and stream it from YouTube: https://gist.github.com/Mxhmovd/41e7690114e7ddad8bcd761a76272cc3
import matplotlib.pyplot as plt; 
import cv2
import os
import statistics 
import multiprocessing as mp
import torch.nn.functional as F
import torch
import random
import imageio
import numpy as np
from glob import glob
from collections import defaultdict
from pdb import set_trace as pdb
from itertools import combinations
from random import choice
import matplotlib.pyplot as plt
import imageio.v3 as iio

from torchvision import transforms

import sys

from glob import glob
import os
import gzip
import json
import numpy as np

import torchvision.transforms as transforms
from PIL import Image

class PyBulletFolder(torch.utils.data.Dataset):
    """Randomly sampled dataset over a folder of ``*_img.png`` / ``*_segs.npy`` pairs.

    Each item is a dict with:

    * ``"img"``    -- tensor of shape (3, H, W): the PNG's first three channels
      (alpha dropped), as returned by ``plt.imread``;
    * ``"transf"`` -- float32 tensor loaded from the ``_segs.npy`` file with the
      same stem (``get_pybullet_render`` stores a 4x4 camera matrix there).

    ``__getitem__`` ignores ``idx`` and draws a random pair, and ``__len__`` is
    a large constant, so the dataset acts as an effectively infinite stream.
    With an empty folder ``__getitem__`` raises ``IndexError`` (from
    ``random.choice``).
    """

    _IMG_SUFFIX = "_img.png"
    _POSE_SUFFIX = "_segs.npy"

    def __init__(self, path=".", val=False):
        # ``val`` is accepted for interface compatibility but currently unused.
        # Glob only ``*_img.png``: any other PNG in the folder has no matching
        # ``_segs.npy`` and would make ``__getitem__`` call np.load on a PNG.
        # Sorting makes a seeded ``random`` choose the same files regardless of
        # the filesystem's listing order.
        self.scene_paths = sorted(glob(os.path.join(path, "*" + self._IMG_SUFFIX)))

    def __len__(self):
        # Effectively infinite; items are sampled at random, not indexed.
        return 100000000

    def __getitem__(self, idx):
        img_path = random.choice(self.scene_paths)
        img = torch.from_numpy(plt.imread(img_path))[..., :3].permute(2, 0, 1)
        # Every path ends with _IMG_SUFFIX (guaranteed by the glob), so swap
        # the suffix rather than str.replace anywhere in the path.
        pose_path = img_path[: -len(self._IMG_SUFFIX)] + self._POSE_SUFFIX
        transf = torch.from_numpy(np.load(pose_path)).float()
        return {"transf": transf, "img": img}

import pybullet as p
import geometry
from tqdm import tqdm
# Camera-axis flip between the OpenCV-style matrices saved to disk and the
# OpenGL-style view matrices PyBullet uses (negates y and z). It is diagonal
# with entries +/-1, so it is its own inverse. Defined once (was duplicated).
Tc = np.diag([1, -1, -1, 1])
width, height = 256, 256  # render resolution; alternative used before: 1200, 1200
fov = 60  # field of view passed to computeProjectionMatrixFOV (PyBullet takes degrees)
aspect = width / height
near, far = 0.1, 5.0  # near/far clipping distances
proj_matrix_ = p.computeProjectionMatrixFOV(fov, aspect, near, far)
def get_pybullet_render(num_scenes, dataset_root="/data/cameron/duck_testing/"):
    """Render and save randomized camera views of the currently loaded PyBullet scene.

    For each rotation-coverage level (0.1, 0.5, 1.0) this renders ``num_scenes``
    views and writes, into ``<dataset_root>/<"01"|"05"|"1">/``:

    * ``%05d_img.png``  -- RGB render (alpha dropped, pixel values / 255);
    * ``%05d_segs.npy`` -- the 4x4 matrix ``inv(Tc) @ view_matrix``. Despite the
      ``_segs`` suffix this is a camera matrix, not a segmentation; it is the
      exact input ``render_cam`` expects.

    Args:
        num_scenes: number of views rendered per coverage level.
        dataset_root: parent directory of the per-coverage subfolders. The
            subfolders are created if missing.
    """
    # TODO: parallelize rendering.
    for rotation_coverage, rotation_str in zip([.1, .5, 1], ["01", "05", "1"]):
        # Loop-invariant, and created up front: plt.imsave / np.save do not
        # create missing parent directories.
        dataset_dir = os.path.join(dataset_root, rotation_str)
        os.makedirs(dataset_dir, exist_ok=True)
        for scene_i in tqdm(range(num_scenes)):
            # Camera on a sphere of radius 2-4 around the origin; rotation_coverage
            # scales how much of the [0, 2*pi) angle range is sampled.
            # NOTE(review): elevation may exceed pi/2 here, i.e. the camera can pass
            # over the pole while the up vector stays +z -- confirm this is intended.
            radius = np.random.uniform(2, 4)
            azimuth_cam = np.random.uniform(0, 2 * np.pi) * rotation_coverage
            elevation_cam = np.random.uniform(0, 2 * np.pi) * rotation_coverage
            # Shift both camera and look-at point by the same offset in [-0.5, 0.5]^3.
            random_translation = np.array([
                np.random.uniform(-1, 1),
                np.random.uniform(-1, 1),
                np.random.uniform(-1, 1),
            ]) * .5
            cam_pos = np.array([
                radius * np.cos(elevation_cam) * np.cos(azimuth_cam),
                radius * np.cos(elevation_cam) * np.sin(azimuth_cam),
                radius * np.sin(elevation_cam),
            ]) + random_translation
            look_at_pos = np.array([0, 0, 0]) + random_translation

            view_matrix_ = p.computeViewMatrix(cam_pos.tolist(), look_at_pos.tolist(), [0, 0, 1])
            # PyBullet returns the view matrix as 16 column-major floats.
            # NOTE(review): an OpenGL view matrix maps world -> camera, so despite its
            # name this is presumably world->camera in OpenCV axes; render_cam undoes
            # exactly this product, so the round trip is consistent either way.
            cam2world_cv = np.linalg.inv(Tc) @ np.array(view_matrix_).reshape(4, 4, order='F')

            image_arr = p.getCameraImage(width=width, height=height, viewMatrix=view_matrix_,
                                         projectionMatrix=proj_matrix_,
                                         flags=p.ER_SEGMENTATION_MASK_OBJECT_AND_LINKINDEX)
            target_rgb = np.reshape(image_arr[2], (height, width, 4))[:, :, :3].astype(np.float32) / 255.0  # drop alpha

            plt.imsave(os.path.join(dataset_dir, '%05d_img.png' % scene_i), target_rgb)
            np.save(os.path.join(dataset_dir, '%05d_segs' % scene_i), cam2world_cv)
     
def render_cam(cam2world_cv):
    """Render the loaded PyBullet scene from a saved camera matrix.

    ``cam2world_cv`` is a 4x4 matrix in the form ``get_pybullet_render`` saves to
    ``*_segs.npy`` (``inv(Tc) @ view_matrix``). Left-multiplying by ``Tc``
    recovers the OpenGL view matrix that PyBullet expects.

    Returns:
        torch tensor of shape (height, width, 3), float32, pixel values / 255.
    """
    # PyBullet takes the view matrix as 16 floats in column-major order.
    gl_view = (Tc @ cam2world_cv).flatten(order='F').tolist()
    rendered = p.getCameraImage(
        width=width,
        height=height,
        viewMatrix=gl_view,
        projectionMatrix=proj_matrix_,
        flags=p.ER_SEGMENTATION_MASK_OBJECT_AND_LINKINDEX,
    )
    rgba = np.reshape(rendered[2], (height, width, 4))
    rgb = rgba[:, :, :3].astype(np.float32) / 255.0  # drop alpha
    return torch.from_numpy(rgb)

if __name__ == '__main__':
    # Offline dataset rendering: set up a headless PyBullet world containing a
    # single duck (scaled 10x) at the origin, then render views of it to disk.
    print("Doing rendering")
    import pybullet_data

    physicsClient = p.connect(p.DIRECT)  # DIRECT = no GUI window
    p.setGravity(0, 0, -9.8)
    # Add PyBullet's bundled data directory so "duck_vhacd.urdf" resolves by bare name.
    p.setAdditionalSearchPath(pybullet_data.getDataPath())
    p.loadURDF(
        "duck_vhacd.urdf",
        basePosition=[0, 0, 0],
        baseOrientation=p.getQuaternionFromEuler([0, 0, 0]),
        globalScaling=10,
    )
    get_pybullet_render(10000)
