# Note for a future DAVIS dataloader: temporally consistent depth estimator: https://github.com/yu-li/TCMonoDepth
# Note: idea for streaming data directly from YouTube instead of downloading it: https://gist.github.com/Mxhmovd/41e7690114e7ddad8bcd761a76272cc3
import matplotlib.pyplot as plt; 
import cv2
import os
import statistics 
import multiprocessing as mp
import torch.nn.functional as F
import torch
import random
import imageio
import numpy as np
from glob import glob
from collections import defaultdict
from pdb import set_trace as pdb
from itertools import combinations
from random import choice
import matplotlib.pyplot as plt
import imageio.v3 as iio

from torchvision import transforms

import sys

from glob import glob
import os
import gzip
import json
import numpy as np

from einops import rearrange, repeat
def ch_sec(x):
    """Move channels last and flatten the two trailing spatial axes: ``(..., c, x, y) -> (..., x*y, c)``."""
    return rearrange(x, "... c x y -> ... (x y) c")


def hom(x, i=-1):
    """Append a size-1 slab of ones to ``x`` along dimension ``i`` (homogeneous coordinates).

    The ones tensor matches ``x``'s dtype and device; the result has ``x.size(i) + 1`` entries along ``i``.
    """
    ones_slab = torch.ones_like(x.narrow(i, 0, 1))
    return torch.cat((x, ones_slab), i)

def make_sample(sample,aspect,budget=192*640/4,hires_factor=2,med_factor=1,low_res=None,hi_res=None):
    """Resize a raw sample to the working resolution and split it into model inputs and targets.

    Args:
        sample: dict of tensors. ``"rgb"`` is required and is resized with
            antialiased bilinear interpolation; ``gt["rgb"]`` is computed as
            ``rgb * .5 + .5`` (callers in this file pass rgb scaled to [-1, 1]).
            Optional keys, resized or copied when present: ``"dino_pca"``,
            ``"bwd_flow"``, ``"rig_flow_masks"``, ``"pred_tracks"`` (together
            with ``"pred_visibility"``), ``"intrinsics"``, ``"depth_inp"``,
            ``"seg_imgs"``, ``"c2w"``, ``"org_ratio"``.
        aspect: x/y ratio of the default working resolution (callers in this
            file pass ``1/org_ratio`` with ``org_ratio = height/width``).
        budget: approximate pixel count of the default working resolution.
        hires_factor, med_factor, hi_res: accepted for backward compatibility
            only. The high/medium-resolution grids they described were never
            used in the returned sample, so they are no longer computed.
        low_res: ``[height, width]`` working resolution. When None it is
            derived from ``budget`` and ``aspect`` and rounded up to a multiple
            of 32.

    Returns:
        ``(model_input, gt)``: two dicts. ``model_input["x_pix"]`` holds the
        working-grid pixel coordinates as ``(x, y)`` normalized to [0, 1],
        flattened to one row per pixel and expanded once per entry of ``rgb``.
    """
    # Default resolution: y * x == budget and x / y == aspect.
    y=np.sqrt(budget/aspect)
    x=budget/y
    # NOTE: this rounds to the next multiple of 32 *strictly* above v
    # (64 -> 96, 65 -> 96); kept as-is because changing it would change every
    # derived resolution.
    mult32=lambda v:v-(v%32)+32
    if low_res is None: low_res=[mult32(v) for v in (int(y),int(x))]
    # hi_res / med_res are intentionally not computed: nothing below uses them.

    # Integer pixel grid (row, col) -> flipped to (x, y) -> normalized so the
    # first/last pixel along each axis maps to 0/1.
    uv = np.mgrid[0 : low_res[0], 0 : low_res[1]].astype(float).transpose(1, 2, 0)
    uv = torch.from_numpy(np.flip(uv, axis=-1).copy()).long()
    uv = uv / torch.tensor([low_res[1]-1, low_res[0]-1])  # uv in [0,1]

    model_input,gt={},{}
    model_input["rgb"]= F.interpolate(sample["rgb"],low_res,antialias=True,mode="bilinear")

    if "dino_pca" in sample: model_input["dino_pca"]= F.interpolate(sample["dino_pca"],low_res,antialias=True,mode="bilinear")

    if "bwd_flow" in sample: model_input["bwd_flow"]= F.interpolate(sample["bwd_flow"],low_res,antialias=True,mode="bilinear")

    if "rig_flow_masks" in sample:
        # Assumes a 4-D (N, K, H, W) tensor: every one of the N*K masks is
        # resized independently with nearest-neighbour sampling, then the
        # leading (N, K) shape is restored.
        masks=sample["rig_flow_masks"]
        model_input["rig_flow_masks"]= F.interpolate(masks.flatten(0,1)[:,None].float(),low_res,mode="nearest").squeeze(1).unflatten(0,masks.shape[:2])

    if "pred_tracks" in sample:
        model_input["pred_tracks"]= sample["pred_tracks"]
        model_input["pred_visibility"]= sample["pred_visibility"]

    model_input["x_pix"]=uv[None].flatten(1,2).expand(len(model_input["rgb"]),-1,-1)
    # Targets: rgb as (..., pixels, channels), mapped from [-1, 1] to [0, 1].
    gt["rgb"]=ch_sec(model_input["rgb"])*.5+.5

    if "intrinsics" in sample: model_input["gt_intrinsics"]=model_input["intrinsics"]=sample["intrinsics"]
    if "depth_inp" in sample:
        gt["depth_inp"]=model_input["depth_inp"]= ch_sec(F.interpolate(sample["depth_inp"][:,None],low_res,mode="nearest"))
    if "seg_imgs" in sample:
        gt["seg_imgs"]=model_input["seg_imgs"]= ch_sec(F.interpolate(sample["seg_imgs"].float(),low_res,mode="nearest"))
    if "c2w" in sample: model_input["c2w"]=sample["c2w"]
    if "org_ratio" in sample: model_input["org_ratio"]=sample["org_ratio"]
    return model_input,gt

class OptFlowFolder(torch.utils.data.Dataset):
    """Samples a random pair of consecutive frames from precomputed packages.

    Packages are the ``lowrespkg.pt`` files matched by the hard-coded glob in
    ``__init__``; each is loaded and converted to a list of dicts. The
    constructor arguments are stored but not used when sampling, and ``idx``
    is ignored: every ``__getitem__`` call draws a random package and frame pair.
    """

    # Number of packages tried by __getitem__ before giving up.
    max_retries = 100

    def __init__(
        self,
        n_skip=1,
        num_trgt=1,
        low_res=(96,112),
        path=".",
        val=False,
        sf=1,# img scale factor (fractional makes it cheaper)
    ):

        self.n_trgt=num_trgt-1
        self.val=val
        self.num_skip=n_skip
        self.low_res=torch.tensor(low_res)
        self.sf=sf

        #path="/data/cameron/monocular_ests/horns/lowrespkg_flow.pt"

        print("Loading data")
        #self.paths=list(glob("/data/cameron/monocular_ests/*/lowrespkg_flow.pt"))
        self.paths=list(glob("/data/cameron/monocular_ests/re10k/*/lowrespkg.pt"))

    def __len__(self): return 100000000

    def collate_fn(self, batch_list):
        """Group sample dicts by key, stacking each key's values along a new dim 0.

        Values that cannot be stacked (non-tensors, or tensors with mismatched
        shapes) are left as a plain list.
        """
        keys = batch_list[0].keys()
        result = defaultdict(list)

        for entry in batch_list:
            for key in keys: result[key].append(entry[key])

        for key in keys:
            try: result[key] = torch.stack(result[key], dim=0)
            except (TypeError, RuntimeError): continue
        return result

    def __getitem__(self, idx,seq_query=None):
        """Return one package restricted to a random consecutive frame pair ``(i-1, i)``.

        ``rig_flow_masks`` and ``bwd_flow`` keep only index ``i-1``,
        ``org_ratio`` is passed through unchanged, and every other entry keeps
        indices ``[i-1, i]``. Unreadable packages and packages with fewer than
        4 ``bwd_flow`` entries are skipped.

        Raises:
            RuntimeError: if ``paths`` is empty or no usable package is found
                within ``max_retries`` draws (previously this recursed until
                the interpreter's recursion limit).
        """
        if not self.paths:
            raise RuntimeError("OptFlowFolder: no packages matched the data glob")
        for _ in range(self.max_retries):
            pkg_path = random.choice(self.paths)
            # Best effort: a corrupt or missing package is skipped, not fatal.
            # (Exception, not bare except, so Ctrl-C is not swallowed.)
            try: data= list(torch.load(pkg_path))
            except Exception as err:
                print("bad load for ",pkg_path,err)
                continue
            n_flow=len(data[0]["bwd_flow"])
            if n_flow<4: continue
            i=random.randint(1,n_flow-1)
            return [{k:(v[[i-1]] if k in ["rig_flow_masks","bwd_flow"] else v if k=="org_ratio" else v[[i-1,i]]) for k,v in x.items() } for x in data]
        raise RuntimeError("OptFlowFolder: no usable package after %d attempts" % self.max_retries)

class PointTrackFolder(torch.utils.data.Dataset):
    """Samples two random frames of a random DAVIS scene with their precomputed point tracks.

    Scenes are the entries of ``/data/cameron/monocular_ests/davis`` whose name
    does not contain "2"; frames are read from ``/data/DAVIS/1080p/<scene>``.
    Constructor arguments are stored; only ``sf`` is used when sampling, and
    ``idx`` is ignored.
    """

    # Number of scenes tried by __getitem__ before giving up.
    max_retries = 100

    def __init__(
        self,
        n_skip=1,
        num_trgt=1,
        low_res=(96,112),
        path=".",
        val=False,
        sf=1,# img scale factor (fractional makes it cheaper)
    ):

        self.n_trgt=num_trgt-1
        self.val=val
        self.num_skip=n_skip
        self.low_res=torch.tensor(low_res)
        #self.imgs = torch.load(path+"/imgs.pt")
        self.sf=sf

        #self.img_paths = list(glob("/data/DAVIS/1080p/*/*"))
        #self.img_paths = list(glob("/data/ImageNet10k/*/*.JPEG"))
        #random.shuffle(self.img_paths)
        #self.imgs = torch.load(path+"/imgs.pt")

    def __len__(self): return int(1e6)

    def collate_fn(self, batch_list):
        """Group sample dicts by key, stacking each key's values along a new dim 0.

        Values that cannot be stacked (non-tensors, or tensors with mismatched
        shapes) are left as a plain list.
        """
        keys = batch_list[0].keys()
        result = defaultdict(list)

        for entry in batch_list:
            for key in keys: result[key].append(entry[key])

        for key in keys:
            try: result[key] = torch.stack(result[key], dim=0)
            except (TypeError, RuntimeError): continue
        return result

    def __getitem__(self, idx,seq_query=None):
        """Build a two-frame sample (via ``make_sample``) from a random DAVIS scene.

        Scenes whose frames or tracks cannot be loaded are skipped with a
        printed message.

        Raises:
            RuntimeError: if no scene loads within ``max_retries`` attempts
                (previously this recursed until the interpreter's recursion limit).
        """
        use_tracks=True
        for _ in range(self.max_retries):
            scene=np.random.choice([x for x in os.listdir("/data/cameron/monocular_ests/davis") if "2" not in x])
            #scene="bear"
            try:
                # Kept as instance attributes for backward compatibility.
                self.img_paths = sorted(list(glob("/data/DAVIS/1080p/%s/*.jpg"%scene)))
                idx1,idx2=np.random.randint(len(self.img_paths)),np.random.randint(len(self.img_paths))
                if use_tracks:self.tracks = list(torch.load("/data/cameron/monocular_ests/davis/%s/pred_tracks.pt"%scene))
                break
            # Exception, not bare except, so Ctrl-C is not swallowed.
            except Exception:
                print("bad load for ",scene)
        else:
            raise RuntimeError("PointTrackFolder: no loadable scene after %d attempts" % self.max_retries)

        # Each image -> (1, C, H, W); the two frames are concatenated on dim 0.
        frames = torch.cat([torch.from_numpy(plt.imread(img_path)).permute(2,0,1)[None].float() for img_path in [self.img_paths[idx1],self.img_paths[idx2]]])
        if frames.max()>2: frames=frames/255  # 0-255 pixel values -> [0, 1]

        if use_tracks:
            # Tracks follow the rearrange patterns (s, 1, t, xy, ...): regroup to
            # per-frame point lists and keep the two sampled frames.
            pred_tracks = rearrange( self.tracks[0],"s 1 t xy c -> t (xy s) c")[[idx1,idx2]]
            pred_visibility = rearrange( self.tracks[1],"s 1 t xy -> t (xy s)")[[idx1,idx2]]
        else: frames=frames[:1]

        org_ratio=frames[0].size(-2)/frames[0].size(-1)
        h,s=3,1
        hi_res=[640, 1024]
        sample = {
                "rgb":frames* 2-1,
                }
        if use_tracks:
            sample["pred_tracks"]=pred_tracks
            sample["pred_visibility"]=pred_visibility
        switch=[1,-1][0]
        return make_sample(sample, 1/org_ratio,hires_factor=h,budget=192*640/(8//s),
                low_res=[int(128*self.sf),int(224*self.sf)][::switch],#[::[-1,1][frames.size(-1)>frames.size(-2)]],
                hi_res=hi_res[::-1]#[::[-1,1][frames.size(-1)>frames.size(-2)]])
                )



class ImageFolder(torch.utils.data.Dataset):
    """Serves single random images from a shuffled subset of at most 10000 files.

    Files are matched by the hard-coded glob in ``__init__``. ``idx`` is
    ignored: every ``__getitem__`` call draws a random image.
    """

    # Number of images tried by __getitem__ before giving up.
    max_retries = 100

    def __init__(
        self,
        n_skip=1,
        num_trgt=1,
        low_res=(96,112),
        path=".",
        val=False,
        sf=1,# img scale factor (fractional makes it cheaper)
    ):

        self.n_trgt=num_trgt-1
        self.val=val
        self.num_skip=n_skip
        self.low_res=torch.tensor(low_res)
        #self.imgs = torch.load(path+"/imgs.pt")
        self.sf=sf

        #self.img_paths = list(glob("/data/co3dhydrants/co3d/hydrants/hydrant/*/images/*"))
        self.img_paths = list(glob("/data/pets_co3d/dog/*/images/*.jpg"))
        #self.img_paths = list(glob("/data/DAVIS/480p/*/*"))
        #self.img_paths = list(glob("/data/DAVIS/1080p/bear/*"))
        #self.img_paths = list(glob("/data/ImageNet10k/*/*.JPEG"))
        random.shuffle(self.img_paths)
        self.img_paths=self.img_paths[:10000]

    def __len__(self): return int(1e6)

    def collate_fn(self, batch_list):
        """Group sample dicts by key, stacking each key's values along a new dim 0.

        Values that cannot be stacked (non-tensors, or tensors with mismatched
        shapes) are left as a plain list.
        """
        keys = batch_list[0].keys()
        result = defaultdict(list)

        for entry in batch_list:
            for key in keys: result[key].append(entry[key])

        for key in keys:
            try: result[key] = torch.stack(result[key], dim=0)
            except (TypeError, RuntimeError): continue
        return result

    def __getitem__(self, idx,seq_query=None):
        """Build a single-image sample (via ``make_sample``) from a random readable RGB image.

        Unreadable images and images whose decoded array does not have exactly
        3 channels are skipped.

        Raises:
            RuntimeError: if no usable image is found within ``max_retries``
                draws (previously this recursed until the interpreter's
                recursion limit).
        """
        for _ in range(self.max_retries):
            img_path = np.random.choice(self.img_paths)
            # Exception, not bare except, so Ctrl-C is not swallowed.
            try: frames = torch.from_numpy(plt.imread(img_path)).permute(2,0,1)[None].float()
            except Exception: continue
            if frames.size(1)==3: break
        else:
            raise RuntimeError("ImageFolder: no readable RGB image after %d attempts" % self.max_retries)
        if frames.max()>2: frames=frames/255  # 0-255 pixel values -> [0, 1]

        org_ratio=frames[0].size(-2)/frames[0].size(-1)
        h,s=3,1
        hi_res=[640, 1024]
        # NOTE(review): with num_trgt=1 (the default) n_trgt is 0 and this slice
        # is empty; confirm callers construct this dataset with num_trgt >= 2.
        sample = {
                "rgb":frames[:self.n_trgt*self.num_skip:self.num_skip]* 2-1,
                }
        switch=[1,-1][0]
        return make_sample(sample, 1/org_ratio,hires_factor=h,budget=192*640/(8//s),
                low_res=[int(128*self.sf),int(224*self.sf)][::switch],#[::[-1,1][frames.size(-1)>frames.size(-2)]],
                hi_res=hi_res[::-1]#[::[-1,1][frames.size(-1)>frames.size(-2)]])
                )


class ImageFolder_(torch.utils.data.Dataset):
    """Single-scene dataset that always serves the same sample built from one scene directory.

    ``__init__`` eagerly loads the per-scene files under ``path``:
    - required: images, point tracks, backward flow, rig-flow masks, depths and intrinsics;
    - optional: segmentation images, DINO features and poses.

    ``__len__`` is 1. Every ``__getitem__`` returns the sample built from frames
    ``0, n_skip, 2*n_skip, ...`` (``num_trgt - 1`` of them).
    """

    def __init__(
        self,
        n_skip=1,
        num_trgt=1,
        low_res=(96,112),
        path=".",
        val=False,
        sf=1,# img scale factor (fractional makes it cheaper)
    ):

        self.n_trgt=num_trgt-1
        self.val=val
        self.num_skip=n_skip
        self.low_res=torch.tensor(low_res)
        self.sf=sf

        print("Loading data")
        self.path=path
        # Prefer the offline tracks; fall back to pred_tracks_more.pt.
        try: self.tracks = list(torch.load(path+"/pred_tracks_offline.pt"))
        except Exception:
            self.tracks = list(torch.load(path+"/pred_tracks_more.pt"))
        self.imgs = torch.load(path+"/imgs.pt")
        # Optional inputs are None when the file is missing or unreadable.
        # __getitem__ then omits the corresponding key; before this fix the
        # attribute was left unset and __getitem__ crashed with AttributeError.
        self.seg_imgs = self._load_optional(path+"/seg_imgs.pt")
        self.dino_feats = self._load_optional(path+"/dino_feats.pt",map_location="cpu")
        self.bwd_flow = torch.load(path+"/bwd_flow.pt")
        self.rig_flow_masks = torch.load(path+"/rig_flow_masks.pt")[:,:1]  # keep index 0 along dim 1
        # Merge the leading group axis g into the point axis and drop batch index 0.
        self.tracks[0] = rearrange(self.tracks[0],"g b t p c -> b t (p g) c")[0]
        self.tracks[1] = rearrange(self.tracks[1],"g b t p -> b t (p g)")[0]
        # NOTE(review): the video depths are loaded but then replaced by the
        # first metric-depth entry below (the alignment that used them is
        # disabled). The load is kept so the file remains required.
        self.depths = torch.load(path+"/video_depth_ests.pt")
        self.mdepths = torch.load(path+"/depth_ests.pt")

        # align video depths with first frame of metric depths
        #depth_map1,depth_map2=(1e-1+self.mdepths[0][0]).view(-1),(1e-1+self.depths[0]).view(-1)
        #A = torch.stack([depth_map2, torch.ones_like(depth_map2)], dim=1)
        #solution = torch.linalg.lstsq(A,depth_map1[:,None]).solution
        #scale, shift = solution[:,0]
        #self.depths=self.depths*scale+shift
        # scale using median depth
        #self.depths=self.mdepths[0].median()/self.depths[0].median() * self.depths

        self.depths=self.mdepths[0]

        print("Done loading data")
        self.poses = None
        if os.path.exists(path+"/poses.pt"): self.poses = torch.load(path+"/poses.pt")

        self.f= torch.load(path+"/intrinsics.pt")

    @staticmethod
    def _load_optional(file_path, **load_kwargs):
        """Return ``torch.load(file_path, **load_kwargs)``, or None (with a message) if loading fails."""
        try: return torch.load(file_path, **load_kwargs)
        except Exception as err:
            print("optional file not loaded:", file_path, err)
            return None

    def __len__(self): return 1

    def collate_fn(self, batch_list):
        """Group sample dicts by key, stacking each key's values along a new dim 0.

        Values that cannot be stacked (non-tensors, or tensors with mismatched
        shapes) are left as a plain list.
        """
        keys = batch_list[0].keys()
        result = defaultdict(list)

        for entry in batch_list:
            for key in keys: result[key].append(entry[key])

        for key in keys:
            try: result[key] = torch.stack(result[key], dim=0)
            except (TypeError, RuntimeError): continue
        return result

    def __getitem__(self, idx,seq_query=None):
        """Build the scene's sample via ``make_sample``; ``idx`` is ignored.

        ``seg_imgs`` / ``dino_pca`` / ``c2w`` are included only when the
        corresponding optional file was loaded.
        """
        frames = self.imgs
        depth_frames = self.depths

        if frames.max()>2: frames=frames/255  # 0-255 pixel values -> [0, 1]

        # Normalized intrinsics: principal point at (0.5, 0.5), fx = self.f,
        # fy = fx * size(-1) / size(-2) of the depth maps (presumably W / H).
        intrinsics = repeat(torch.eye(3), "i j -> b i j", b=len(depth_frames)).clone()
        intrinsics[:, :2, 2] = 0.5
        f=self.f
        intrinsics[:, 0, 0] = f
        intrinsics[:, 1, 1] = f * depth_frames.size(-1) / depth_frames.size(-2)

        org_ratio=frames[0].size(-2)/frames[0].size(-1)
        h,s=3,1
        hi_res=[640, 1024]
        # Frame selector: 0, num_skip, 2*num_skip, ... (n_trgt frames).
        sel=slice(None,self.n_trgt*self.num_skip,self.num_skip)

        pred_tracks = self.tracks[0][sel]
        pred_visibility = self.tracks[1][sel]
        # Optional spatial subsampling of the track grid (gs=1 keeps every
        # point); the patterns assume a track_sl x track_sl grid per group s.
        gs=1
        track_sl=64
        pred_tracks = rearrange( rearrange(pred_tracks,"t (x y s) c -> (t s) c x y",y=track_sl,x=track_sl)[...,::gs,::gs], "(t s) c x y -> t (x y s) c",t=self.n_trgt)
        pred_visibility = rearrange( rearrange(pred_visibility,"t (x y s) -> (t s) x y",y=track_sl,x=track_sl)[...,::gs,::gs], "(t s) x y -> t (x y s)",t=self.n_trgt)

        sample = {
                "intrinsics":intrinsics[sel],
                "rgb":frames[sel]* 2-1,
                "depth_inp":depth_frames[sel],"org_ratio":org_ratio,
                # flow / masks relate consecutive selected frames: one fewer entry
                "bwd_flow":self.bwd_flow[sel][:-1],
                "rig_flow_masks":self.rig_flow_masks[sel][:-1],
                "pred_tracks":pred_tracks,
                "pred_visibility":pred_visibility,
                }
        if self.seg_imgs is not None: sample["seg_imgs"]=self.seg_imgs[sel]
        if self.dino_feats is not None: sample["dino_pca"]=self.dino_feats[sel]
        if self.poses is not None: sample["c2w"]=self.poses[sel]
        switch=[1,-1][0]
        return make_sample(sample, 1/org_ratio,hires_factor=h,budget=192*640/(8//s),
                low_res=[int(128*self.sf),int(224*self.sf)][::switch],#[::[-1,1][frames.size(-1)>frames.size(-2)]],
                hi_res=hi_res[::-1]#[::[-1,1][frames.size(-1)>frames.size(-2)]])
                )
class MultiImageFolder(torch.utils.data.Dataset):
    """Serves whole precomputed ``lowrespkg.pt`` packages, one per index.

    Each package is loaded and converted to a list of dicts. Only packages whose
    ``rig_flow_masks`` have exactly ``num_rig_masks`` entries along dim 0 are
    returned; other indices are replaced by a randomly drawn one. Constructor
    arguments are accepted for interface compatibility but ignored.
    """

    # Packages must carry exactly this many rig-flow masks (dim 0) to be served.
    num_rig_masks = 9
    # Number of packages tried per __getitem__ call before giving up.
    max_retries = 100

    def __init__(
        self,
        n_skip=1,
        num_trgt=1,
        low_res=(96,112),
        path=".",
        val=False,
        sf=1,# img scale factor (fractional makes it cheaper)
    ):
        self.paths=glob("/data/cameron/monocular_ests/pets_dogs/*/lowrespkg.pt")
        #self.paths=glob("/data/cameron/monocular_ests/pets_dogs/518_74410_144191/lowrespkg.pt")
        #self.paths=glob("/data/cameron/monocular_ests/pets_dogs/1037_30321_22281/lowrespkg.pt")
        #self.paths=["/data/cameron/monocular_ests/pets_dogs/1037_30321_22281/lowrespkg.pt","/data/cameron/monocular_ests/pets_dogs/518_74410_144191/lowrespkg.pt",
        #            "/data/cameron/monocular_ests/pets_dogs/1037_30381_22807/lowrespkg.pt","/data/cameron/monocular_ests/pets_dogs/1037_30392_22995/lowrespkg.pt",]
        #self.paths=glob("/data/cameron/monocular_ests/hydrants_redo/*/lowrespkg.pt")#[:1]
        #self.paths=sorted(glob("/data/cameron/monocular_ests/re10k/*/lowrespkg.pt"))
        #self.paths=glob("/data/cameron/monocular_ests/epic_kitchens/*/lowrespkg.pt")
        #self.paths=glob("/data/cameron/monocular_ests/walkingvid_clips/*/lowrespkg.pt")
        #self.paths=glob("/data/cameron/monocular_ests/tri_robotics/*/lowrespkg.pt")
        #self.paths=glob("/data/cameron/monocular_ests/tri_robotics_sim_more/*/lowrespkg.pt")
        #self.paths=list(glob("/data/cameron/monocular_ests/tri_robotics_sim_more/*/lowrespkg.pt"))[:1]
        #self.paths=list(glob("/data/cameron/monocular_ests/robot_real_moredirs/*/lowrespkg.pt"))[:]
        #self.paths=list(glob("/data/cameron/monocular_ests/robot_real_wvideodepth/-data-cameron-LBM11-flattened_robotdirs-Bimanual*/lowres*"))[:1]

        #self.paths=glob("/data/cameron/monocular_ests/walkingvid_clips/0000264/lowrespkg.pt")
        #self.paths=glob("/data/cameron/monocular_ests/walkingvid_clips/5030659/lowrespkg.pt")

        # Number of packages loaded so far (one per attempt).
        self.step=0

    def __len__(self): return len(self.paths)

    def collate_fn(self, batch_list):
        """Group sample dicts by key, stacking each key's values along a new dim 0.

        Values that cannot be stacked (non-tensors, or tensors with mismatched
        shapes) are left as a plain list. (The per-key debug print was removed.)
        """
        keys = batch_list[0].keys()
        result = defaultdict(list)

        for entry in batch_list:
            for key in keys: result[key].append(entry[key])

        for key in keys:
            try: result[key] = torch.stack(result[key], dim=0)
            except (TypeError, RuntimeError): continue
        return result

    def __getitem__(self, idx,seq_query=None):
        """Load package ``paths[idx]`` as a list of dicts with float-valued entries dropped.

        For paths containing "re10k" or "hydrant", ``rig_flow_masks`` of the
        first dict is replaced by all ones. If that tensor does not have
        ``num_rig_masks`` entries along dim 0, a random index is tried instead.

        Raises:
            RuntimeError: if no matching package is found within
                ``max_retries`` loads (previously this recursed until the
                interpreter's recursion limit).
        """
        for _ in range(self.max_retries):
            self.step+=1
            pkg_path=self.paths[idx]
            data= list(torch.load(pkg_path))
            # Drop plain-float metadata (exact type check, as before).
            data=[{k:v for k,v in x.items() if type(v) is not float} for x in data]
            if any(x in pkg_path for x in ["re10k","hydrant"]):
                data[0]["rig_flow_masks"]=torch.ones_like(data[0]["rig_flow_masks"])
            if data[0]["rig_flow_masks"].size(0)==self.num_rig_masks:
                return data
            idx=random.randint(0,len(self)-1)
        raise RuntimeError("MultiImageFolder: no package with %d rig-flow masks after %d attempts" % (self.num_rig_masks, self.max_retries))
