# note for davis dataloader later: temporally consistent depth estimator: https://github.com/yu-li/TCMonoDepth
# note: cool idea - skip downloading the data entirely and just stream it from YouTube: https://gist.github.com/Mxhmovd/41e7690114e7ddad8bcd761a76272cc3
import matplotlib.pyplot as plt; 
import cv2
import os
import multiprocessing as mp
import torch.nn.functional as F
import torch
import random
import imageio
import numpy as np
from glob import glob
from collections import defaultdict
from pdb import set_trace as pdb
from itertools import combinations
from random import choice
import matplotlib.pyplot as plt
import imageio.v3 as iio

from torchvision import transforms

import sys

from glob import glob
import os
import gzip
import json
import numpy as np

from PIL import Image
def _load_16big_png_depth(depth_png) -> np.ndarray:
    """Decode a depth PNG whose 16-bit pixels hold raw float16 bit patterns.

    PIL reads such 16-bit images as mode "I" (32-bit ints), so the pixels are
    narrowed to uint16, reinterpreted bit-for-bit as float16, and widened to a
    float32 array of shape (height, width).

    Note: an identical definition appears again later in this module and
    replaces this one when the module is imported.
    """
    with Image.open(depth_png) as pil_img:
        width, height = pil_img.size
        raw_bits = np.array(pil_img, dtype=np.uint16)
        as_half = np.frombuffer(raw_bits, dtype=np.float16)
        depth = as_half.astype(np.float32).reshape((height, width))
    return depth
def _load_depth(path, scale_adjustment) -> np.ndarray:
    """Load a float16-encoded depth PNG, scale it, and add a leading axis.

    Non-finite values (NaN/inf) are replaced with 0. For a scalar
    ``scale_adjustment`` the result has shape (1, height, width).

    Note: redefined identically later in this module.
    """
    scaled = _load_16big_png_depth(path) * scale_adjustment
    not_finite = ~np.isfinite(scaled)
    scaled[not_finite] = 0.0
    return scaled[np.newaxis]  # fake feature channel

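# Helper and camera-geometry functions below (image globbing, back-projecting pixels with depth); mostly safe to ignore, though PokemonRooms uses hom/ch_fst/ch_sec.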
def glob_imgs(path):
    """List the image files directly inside directory ``path``.

    Matches the glob patterns ``*.png``, ``*.jpg``, ``*.JPEG`` and ``*.JPG``;
    results are grouped by pattern in that order. A missing directory yields
    an empty list.
    """
    patterns = ("*.png", "*.jpg", "*.JPEG", "*.JPG")
    return [hit for pattern in patterns for hit in glob(os.path.join(path, pattern))]


def pick(list, item_idcs):
    """Gather the elements of ``list`` at positions ``item_idcs``, in that order.

    A falsy ``list`` (e.g. empty or None) is returned unchanged without any
    indexing. Note: the parameter name shadows the builtin ``list`` here.
    """
    if list:
        return [list[idx] for idx in item_idcs]
    return list


def parse_intrinsics(intrinsics):
    """Split intrinsics matrices (e.g. shape (..., 3, 3)) into their parameters.

    Returns ``(fx, fy, cx, cy)`` taken from entries [0, 0], [1, 1], [0, 2] and
    [1, 2]. Each keeps a trailing singleton axis, i.e. shape ``(..., 1)``.
    """
    first_row = intrinsics[..., 0, :]
    second_row = intrinsics[..., 1, :]
    fx = first_row[..., 0:1]
    fy = second_row[..., 1:2]
    cx = first_row[..., 2:3]
    cy = second_row[..., 2:3]
    return fx, fy, cx, cy


from einops import rearrange, repeat
def ch_sec(x):
    """Move channels last and flatten the grid: (..., c, x, y) -> (..., x*y, c)."""
    return rearrange(x, "... c x y -> ... (x y) c")


def ch_fst(src, x=None):
    """Inverse of ``ch_sec``: (..., x*y, c) -> (..., c, x, y).

    ``x`` is the size of the first spatial axis; when omitted the grid is
    assumed square and ``x`` is the integer part of sqrt(x*y).
    """
    if x is None:
        x = int(src.size(-2) ** 0.5)
    return rearrange(src, "... (x y) c -> ... c x y", x=x)


def hom(x, i=-1):
    """Append a slice of ones along dim ``i`` (homogeneous coordinates)."""
    ones = torch.ones_like(x.select(i, 0).unsqueeze(i))
    return torch.cat((x, ones), i)


def expand_as(x, y):
    """Append trailing singleton dims to tensor ``x`` until it has ``y``'s rank.

    Only ``unsqueeze(-1)`` is applied; nothing is broadcast. If ``x`` already
    has at least as many dims as ``y``, it is returned unchanged.
    """
    n_missing = len(y.shape) - len(x.shape)
    while n_missing > 0:
        x = x.unsqueeze(-1)
        n_missing -= 1
    return x


def lift(x, y, z, intrinsics, homogeneous=False):
    """Back-project pixel coordinates into camera space using depth ``z``.

    Computes ``(x - cx) / fx * z`` and ``(y - cy) / fy * z``, where fx, fy,
    cx, cy come from ``parse_intrinsics(intrinsics)`` and get trailing
    singleton axes so they broadcast against ``x`` / ``y``.

    :param x: pixel x coordinates, e.g. shape (batch_size, num_points).
    :param y: pixel y coordinates, same shape as ``x``.
    :param z: depth per point; expected to match the shape of ``x`` so the
        components can be stacked.
    :param intrinsics: intrinsics matrices.
    :param homogeneous: if True, append a fourth component of ones.
    :return: points stacked on a new last axis, size 3 (or 4 if homogeneous).
    """
    fx, fy, cx, cy = parse_intrinsics(intrinsics)

    def _back_project(coord, principal, focal):
        return (coord - expand_as(principal, coord)) / expand_as(focal, coord) * z

    components = [_back_project(x, cx, fx), _back_project(y, cy, fy), z]
    if homogeneous:
        components.append(torch.ones_like(z).to(x.device))
    return torch.stack(components, dim=-1)


def world_from_xy_depth(xy, depth, cam2world, intrinsics):
    """Lift pixel coordinates with depth into world space.

    :param xy: pixel coordinates, x in ``xy[..., 0]`` and y in ``xy[..., 1]``.
    :param depth: camera-space depth per pixel.
    :param cam2world: transforms applied to the homogeneous camera points;
        presumably (batch, ..., 4, 4) camera-to-world matrices (TODO confirm).
    :param intrinsics: intrinsics forwarded to ``lift``.
    :return: world-space points with a last axis of size 3.
    """
    pixel_points_cam = lift(
        xy[..., 0], xy[..., 1], depth, intrinsics=intrinsics, homogeneous=True
    )
    # out[..., k, i] = sum_j cam2world[..., i, j] * point[..., k, j]
    world_coords = torch.einsum("b...ij,b...kj->b...ki", cam2world, pixel_points_cam)
    # Drop the homogeneous coordinate.
    return world_coords[..., :3]


def get_ray_directions(xy, cam2world, intrinsics, normalize=True):
    """Compute directions from the camera centre through the pixels ``xy``.

    Each pixel is lifted to world space at unit depth; the ray direction is
    that point minus the camera position (``cam2world[..., :3, 3]``).

    :param xy: pixel coordinates with (x, y) in the last axis.
    :param cam2world: camera transforms, forwarded to ``world_from_xy_depth``.
    :param intrinsics: intrinsics, forwarded to ``world_from_xy_depth``.
    :param normalize: if True, return unit-length directions.
    :return: directions, presumably (batch, num_samples, 3).
    """
    # Allocate directly on the target device instead of creating on CPU and copying.
    unit_depth = torch.ones(xy.shape[:-1], device=xy.device)
    pixel_points = world_from_xy_depth(
        xy, unit_depth, intrinsics=intrinsics, cam2world=cam2world
    )

    cam_pos = cam2world[..., :3, 3]
    ray_dirs = pixel_points - cam_pos[..., None, :]
    if normalize:
        ray_dirs = F.normalize(ray_dirs, dim=-1)
    return ray_dirs

from PIL import Image
def _load_16big_png_depth(depth_png) -> np.ndarray:
    """Read float16 depth values bit-packed into the pixels of a 16-bit PNG.

    Returns a float32 array of shape (height, width). This redefinition is
    identical to the earlier one in this module and is the one in effect.
    """
    with Image.open(depth_png) as png:
        # PIL's size is (width, height); the array must be (rows, cols).
        out_shape = (png.size[1], png.size[0])
        # PIL reads the 16-bit data as mode "I" (32-bit); keep the low 16 bits.
        bits = np.array(png, dtype=np.uint16)
    # Reinterpret those bits as float16, then widen to float32.
    return np.frombuffer(bits, dtype=np.float16).astype(np.float32).reshape(out_shape)
def _load_depth(path, scale_adjustment) -> np.ndarray:
    """Scaled depth map from ``path`` with non-finite entries zeroed.

    The result gains a leading singleton axis (a fake feature channel). This
    is identical to the earlier definition in this module, which it replaces.
    """
    raw = _load_16big_png_depth(path)
    depth = raw * scale_adjustment
    invalid = np.logical_not(np.isfinite(depth))
    depth[invalid] = 0.0
    return depth[None, ...]

# NOTE: currently using CO3D v1 because v2 switched to NDC cameras. TODO: write conversion code (different intrinsics), verify point clouds, and switch.

class PokemonRooms(torch.utils.data.Dataset):
    """Clip dataset over rendered "pokemon rooms" scenes.

    Each item is one scene directory under ``<root>/{static,dynamic}_pokemon``
    with ``rgb/``, ``instance_map/``, ``depth/`` (npz files), ``pose/`` and
    ``intrinsics.txt``. ``__getitem__`` loads up to ``num_trgt`` frames spaced
    ``n_skip + 1`` apart, resizes them to ``low_res`` and returns a
    ``(model_input, gt)`` pair of dicts, including pose-derived backward flow.
    """

    # Segmentation values every scene is expected to contain, one per pose set
    # (pose/cam_*, pose/obj_0_*, pose/obj_1_*). Presumably background + two
    # objects - TODO confirm.
    _EXPECTED_SEG_VALUES = (0.0000e00, 1.5259e-05, 1.5274e-02)

    def __init__(
        self,
        num_context=2,
        n_skip=0,
        num_trgt=1,
        dynamic=True,
        low_res=(128, 128),
        val=False,
        root="/nobackup/nvme1/pokemon_rooms/",
        max_scenes=50,
        n_train=40,
    ):
        """
        :param num_context: unused; kept for interface compatibility.
        :param n_skip: frames skipped between consecutive samples (an int, or a
            list from which one value is drawn per item).
        :param num_trgt: maximum number of frames loaded per item.
        :param dynamic: use ``dynamic_pokemon`` (True) or ``static_pokemon``.
        :param low_res: (height, width) that images/depths/segs are resized to.
        :param val: if True use scenes ``[n_train:max_scenes]``, else ``[:n_train]``.
        :param root: directory containing the ``*_pokemon`` folders.
        :param max_scenes: number of scenes (sorted by trailing ``_<int>``) used.
        :param n_train: how many of those scenes belong to the training split.
        """
        self.n_trgt = num_trgt
        self.num_skip = n_skip
        self.low_res = low_res
        self.base_path = os.path.join(root, "%s_pokemon" % ["static", "dynamic"][dynamic])
        print(self.base_path)

        scenedirs = sorted(
            glob(os.path.join(self.base_path, "*")),
            key=lambda p: int(p.split("_")[-1]),
        )[:max_scenes]
        self.scenedirs = scenedirs[n_train:] if val else scenedirs[:n_train]

        print("done with dataloader init")

    def set_seq(self, seq_query):
        """Restrict the dataset to one sequence.

        NOTE(review): relies on ``total_sorted_seq`` / ``total_all_frame_names``,
        which this class never sets - it raises AttributeError unless a caller
        provides them.
        """
        self.seqs = {seq_query: self.total_sorted_seq[seq_query]}
        self.all_frame_names = [x for x in self.total_all_frame_names if x[0] == seq_query]
        self.total_num_data = len(self.all_frame_names)

    def sparsify(self, dict, sparsity):
        """Randomly keep ``sparsity`` entries of the flattened "rgb" and "uv".

        Other keys are copied through unchanged; ``sparsity=None`` returns the
        input as-is. Requires ``self.img_sidelength`` (not set by ``__init__``).
        """
        if sparsity is None:
            return dict
        # Sample `sparsity` distinct pixel indices at random.
        rand_idcs = np.random.choice(self.img_sidelength ** 2, size=sparsity, replace=False)
        new_dict = {}
        for key in ["rgb", "uv"]:
            new_dict[key] = dict[key][rand_idcs]
        for key in dict:
            if key not in ["rgb", "uv"]:
                new_dict[key] = dict[key]
        return new_dict

    def set_img_sidelength(self, new_img_sidelength):
        """For multi-resolution training: update the image sidelength.

        NOTE(review): ``self.all_instances`` is never set by this class, so the
        loop raises AttributeError unless a caller provides it.
        """
        self.img_sidelength = new_img_sidelength
        for instance in self.all_instances:
            instance.set_img_sidelength(new_img_sidelength)

    def __len__(self):
        return len(self.scenedirs)

    def collate_fn(self, batch_list):
        """Batch a list of samples.

        Dict samples are merged key by key: tensor values are stacked along a
        new dim 0, and values that can't be stacked (non-tensors, mismatched
        shapes) are left as lists. Tuple/list samples, such as the
        ``(model_input, gt)`` pairs from ``__getitem__``, are collated
        element-wise and returned as a tuple.
        """
        first = batch_list[0]
        if isinstance(first, (tuple, list)):
            return tuple(
                self.collate_fn([entry[j] for entry in batch_list])
                for j in range(len(first))
            )

        keys = first.keys()
        result = defaultdict(list)
        for entry in batch_list:
            for key in keys:
                result[key].append(entry[key])

        for key in keys:
            try:
                result[key] = torch.stack(result[key], dim=0)
            except (RuntimeError, TypeError):
                # Non-tensor values or mismatched shapes: keep the plain list.
                continue

        return result

    @staticmethod
    def _load_npz_array(path):
        """Return the ``arr_0`` array of an ``.npz`` file, closing the file afterwards."""
        with np.load(path) as data:
            return data["arr_0"]

    @staticmethod
    def _backward_flows(K, poses, uv, depths, segs, low_res):
        """Compute backward flow from frame i to frame i-1 for every i >= 1.

        Frame-i pixels are lifted to 3D with their depth. Each pixel is then
        moved by the motion of every pose set: entry 0 of ``poses`` is the
        camera set and gets identity object motion, while the other entries
        use per-object poses. The moved points are projected into frame i-1
        with ``K``, and the per-set projections are blended with frame i's
        one-hot ``segs`` masks.

        Returns a (T-1, 2, h, w) tensor in the same coordinates as ``uv``
        (a (0, 2, h, w) tensor when only one frame is loaded).
        """
        cam_poses = poses[0]
        K_inv = K.inverse()
        flows = []
        for i in range(1, len(depths)):
            # Frame-i pixels as homogeneous camera-space points, (4, h*w).
            pts_cam = hom(K_inv @ (hom(uv[i]) * depths[i].flatten().unsqueeze(-1)).T, 0)
            projections = []
            for body_i, body_poses in enumerate(poses):
                motion = (body_poses[i - 1] @ body_poses[i].inverse()) if body_i else torch.eye(4)
                pix = K @ (cam_poses[i - 1].inverse() @ motion @ cam_poses[i] @ pts_cam)[:3]
                projections.append(pix[:2] / (pix[2] + 1e-5))
            projections = torch.stack(projections)  # (n_sets, 2, h*w)
            # Pick each pixel's projection using its segmentation mask.
            pix_proj = (segs[i].flatten(1, 2).unsqueeze(1) * projections).sum(0)
            flows.append(pix_proj.T - uv[0])
        if not flows:
            return torch.zeros(0, 2, *low_res)
        # Pass the height explicitly so non-square low_res unflattens correctly.
        return ch_fst(torch.stack(flows), x=low_res[0])

    def __getitem__(self, idx, seq_query=None):
        """Load scene ``idx`` and return ``(model_input, gt)``.

        ``seq_query`` is unused. If a scene's segmentation labels don't match
        ``_EXPECTED_SEG_VALUES``, scene 0 is returned instead. A RuntimeError is
        raised if scene 0 itself is such a scene, instead of recursing forever.

        model_input:
            rgb (T, 3, h, w) mapped to [-1, 1]; rgb_large (T, 3, H, W) at full
            resolution (imread values * 255); depth (T, h, w); intrinsics
            (T, 3, 3) with rows 0/1 divided by image width/height; bwd_flow
            (T-1, 2, h, w); trgt_c2w (T, 4, 4) from pose/cam_*; x_pix
            (T, h*w, 2) in [0, 1]; segs (T, 3, h, w) one-hot masks.
        gt:
            rgb (T, h*w, 3) = rgb * 0.5 + 0.5; depth (T, h*w, 1); intrinsics;
            x_pix.
        """
        low_res = self.low_res

        # num_skip may be a list of candidate skips; the stride is skip + 1.
        skip = random.choice(self.num_skip) if isinstance(self.num_skip, list) else self.num_skip
        n_skip = skip + 1

        scenedir = self.scenedirs[idx]

        img_paths = sorted(glob(os.path.join(scenedir, "rgb", "*")))
        seg_paths = sorted(glob(os.path.join(scenedir, "instance_map", "*")))
        depth_paths = sorted(glob(os.path.join(scenedir, "depth", "*")))
        pose_paths = [
            sorted(glob(os.path.join(scenedir, "pose", "%s_*" % s)))
            for s in ["cam", "obj_0", "obj_1"]
        ]

        # Up to n_trgt frame indices starting at 0, n_skip apart.
        idxs = list(range(len(img_paths)))[: self.n_trgt * n_skip : n_skip]

        try:
            depths = torch.stack(
                [torch.from_numpy(self._load_npz_array(depth_paths[i])) for i in idxs]
            ).float()
        except Exception as err:
            # Fail here, naming the scene, rather than with an unrelated error later.
            raise RuntimeError(f"could not load depth maps for {scenedir}") from err
        imgs = torch.stack([torch.from_numpy(plt.imread(img_paths[i])) for i in idxs])[..., :3] * 255
        poses = torch.stack([
            torch.stack([
                torch.from_numpy(np.genfromtxt(paths[i], dtype=np.float32)).view(4, 4)
                for i in idxs
            ])
            for paths in pose_paths
        ])
        segs = torch.stack([torch.from_numpy(plt.imread(seg_paths[i])) for i in idxs])

        seg_vals = segs.unique()
        expected = torch.tensor(self._EXPECTED_SEG_VALUES)
        # The flow blend needs exactly one mask per pose set; any other label
        # count is a bad scene (and would otherwise break the comparison below).
        # NOTE(review): `.all()` rejects only when every label differs from the
        # reference, and the comparison uses exact float equality - confirm intent.
        if len(seg_vals) != len(expected) or (seg_vals != expected).all():
            print("bad set")
            if idx == 0:
                raise RuntimeError(f"scene {scenedir} has unexpected segmentation labels")
            return self[0]
        # One-hot masks, one channel per label: (T, n_labels, H, W).
        segs = torch.stack([(segs - v).abs() < 1e-7 for v in seg_vals], 1).float()

        with open(os.path.join(scenedir, "intrinsics.txt")) as f:
            # NOTE(review): drops the last 4 characters of the first line before
            # parsing - assumes a fixed trailing token; confirm the file format.
            raw_K_flat = np.array(f.readline()[:-4].split()).astype(np.float32)

        K = torch.eye(3)
        K[0, 0] = K[1, 1] = float(raw_K_flat[0])
        K[0, 2] = float(raw_K_flat[1])
        K[1, 2] = float(raw_K_flat[2])
        K[0] /= imgs.size(2)  # imgs is (T, H, W, 3): divide by width
        K[1] /= imgs.size(1)  # divide by height
        K = K[None].expand(len(imgs), -1, -1)

        imgs_large = imgs  # full resolution, (T, H, W, 3)
        imgs = F.interpolate(torch.stack([x.permute(2, 0, 1) for x in imgs]), low_res, antialias=True, mode="bilinear")
        depths = F.interpolate(depths[:, None], low_res, antialias=True, mode="bilinear").squeeze(1)
        segs = F.interpolate(segs.flatten(0, 1)[:, None], low_res, mode="nearest").squeeze(1).unflatten(0, segs.shape[:2])

        imgs = imgs / 255 * 2 - 1

        # Pixel grid as (col, row), normalised to [0, 1].
        uv = np.mgrid[0:low_res[0], 0:low_res[1]].astype(float).transpose(1, 2, 0)
        uv = torch.from_numpy(np.flip(uv, axis=-1).copy()).long()
        uv = uv / torch.tensor([low_res[1] - 1, low_res[0] - 1])  # uv in [0,1]
        uv = uv[None].expand(len(imgs), -1, -1, -1).flatten(1, 2)

        bwd_optflows = self._backward_flows(K[0], poses, uv, depths, segs, low_res)
        # TODO verify RAFT flow is about the same but just less detailed

        c2w = poses[0]

        model_input = {
            "rgb": imgs,
            "rgb_large": imgs_large.permute(0, 3, 1, 2),
            "depth": depths.squeeze(1),
            "intrinsics": K,
            "bwd_flow": bwd_optflows,
            "trgt_c2w": c2w,
            "x_pix": uv,
            "segs": segs,
        }

        gt = {
            "rgb": ch_sec(imgs) * 0.5 + 0.5,
            "depth": depths.squeeze(1).flatten(1, 2).unsqueeze(-1),
            "intrinsics": K,
            "x_pix": uv,
        }
        return model_input, gt
