# Note for a future DAVIS dataloader: temporally consistent depth estimator: https://github.com/yu-li/TCMonoDepth
# Note on an idea to skip downloading data entirely and stream it from YouTube instead: https://gist.github.com/Mxhmovd/41e7690114e7ddad8bcd761a76272cc3
import matplotlib.pyplot as plt; 
import cv2
import os
import multiprocessing as mp
import torch.nn.functional as F
import torch
import random
import imageio
import numpy as np
from glob import glob
from collections import defaultdict
from pdb import set_trace as pdb
from itertools import combinations
from random import choice
import matplotlib.pyplot as plt
import imageio.v3 as iio

from torchvision import transforms

from tqdm import tqdm

import sys

from glob import glob
import os
import gzip
import json
import numpy as np
from data import common

from einops import rearrange, repeat
def ch_sec(x):
    """Flatten the two trailing spatial dims and move channels last.

    Rearranges ``... c x y -> ... (x y) c``, turning an image-like array
    into a sequence of per-pixel channel vectors.
    """
    return rearrange(x, "... c x y -> ... (x y) c")


def hom(x, i=-1):
    """Append a slice of ones to ``x`` along dimension ``i``.

    The appended slice has size 1 along ``i`` and the dtype/device of ``x``
    (e.g. turns points into homogeneous coordinates when ``i`` is the
    coordinate axis).
    """
    ones = torch.ones_like(x.narrow(i, 0, 1))
    return torch.cat((x, ones), i)

class Tanks(torch.utils.data.Dataset):
    """Frame-sequence dataset built from a single COLMAP reconstruction.

    Camera poses and images are read through the official gaussian-splatting
    ``dataset_readers`` module. Each item is ``num_trgt`` frames spaced
    ``n_skip + 1`` apart, together with camera-to-world poses and normalized
    intrinsics, packed by ``common.make_sample``.
    """

    # Scene loaded when no ``path`` is given (the loader previously always
    # loaded this scene and ignored ``path``). An earlier alternative was
    # "/nobackup/nvme1/custom_tandt/caterpillar_subset".
    DEFAULT_SRC = "/nobackup/nvme1/mipnerf360/360_v2/bonsai/"
    # Checkout of the official gaussian-splatting repo providing ``dataset_readers``.
    DEFAULT_SPLATTING_REPO = "/home/camsmith/repos/official_splatting/scene"

    def __init__(
        self,
        n_skip=1,
        num_trgt=1,
        low_res=(96,112),
        path=".",
        scene=None,
        val=False,
        max_frames=40,
        splatting_repo=DEFAULT_SPLATTING_REPO,
    ):
        """Load poses, intrinsics and images for one COLMAP scene.

        Args:
            n_skip: number of frames skipped between consecutive targets.
            num_trgt: number of frames per item.
            low_res: stored as ``self.low_res``; not used by this class.
            path: scene directory containing ``sparse/0`` and ``images``;
                ``"."`` or ``None`` falls back to ``DEFAULT_SRC``.
            scene: accepted for interface compatibility; unused.
            val: stored as ``self.val``; not used by this class.
            max_frames: keep only the first ``max_frames`` cameras (sorted by
                image name); ``None`` keeps all of them.
            splatting_repo: directory put on ``sys.path`` so that
                ``dataset_readers`` can be imported.

        Raises:
            ValueError: if no cameras remain after applying ``max_frames``.
        """
        self.n_trgt = num_trgt
        self.val = val
        self.num_skip = n_skip
        self.low_res = torch.tensor(low_res)

        src = self.DEFAULT_SRC if path in (None, ".") else path

        # Avoid growing sys.path every time a dataset is constructed.
        if splatting_repo not in sys.path:
            sys.path.append(splatting_repo)
        import dataset_readers

        cam_extrinsics = dataset_readers.read_extrinsics_binary(os.path.join(src, "sparse/0", "images.bin"))
        cam_intrinsics = dataset_readers.read_intrinsics_binary(os.path.join(src, "sparse/0", "cameras.bin"))
        cam_infos = sorted(
            dataset_readers.readColmapCameras(
                cam_extrinsics=cam_extrinsics,
                cam_intrinsics=cam_intrinsics,
                images_folder=os.path.join(src, "images"),
            ),
            key=lambda cam: cam.image_name,
        )
        if max_frames is not None:
            cam_infos = cam_infos[:max_frames]
        if not cam_infos:
            raise ValueError(f"no COLMAP cameras found under {src!r}")

        self.extrinsics = torch.stack([self._colmap_c2w(cam.R, cam.T) for cam in cam_infos])
        # Images stored as (C, H, W); [:3] drops an alpha channel if present.
        self.images = torch.stack(
            [
                torch.from_numpy(plt.imread(cam.image_path)).permute(2, 0, 1)[:3]
                for cam in tqdm(cam_infos)
            ]
        )

        # NOTE(review): assumes a single COLMAP camera with id 1 whose
        # params[0] is the focal length in pixels, used for both axes —
        # confirm against the scene's camera model.
        focal = cam_intrinsics[1].params[0]
        width, height = cam_infos[0].width, cam_infos[0].height
        self.K = torch.eye(3)
        self.K[0, 0] = focal / width
        self.K[1, 1] = focal / height
        self.K[:2, 2] = .5  # principal point at the normalized image centre

    @staticmethod
    def _colmap_c2w(R, T):
        """Return the 4x4 camera-to-world matrix for one camera's ``(R, T)``.

        NOTE(review): the official gaussian-splatting ``readColmapCameras``
        stores ``R`` already transposed (``qvec2rotmat(qvec).T``) and ``T`` as
        COLMAP's ``tvec``, so the world-to-camera matrix is ``[R.T | T]``.
        Confirm against the checked-out ``dataset_readers`` if it diverges.
        """
        w2c = torch.eye(4)
        w2c[:3, :3] = torch.as_tensor(R).T
        w2c[:3, 3] = torch.as_tensor(T)
        return w2c.inverse()

    def __len__(self):
        """Return the number of valid start indices.

        Item ``idx`` reads frames up to ``idx + (n_trgt - 1) * (num_skip + 1)``,
        which must stay inside ``self.images``.
        """
        span = (self.n_trgt - 1) * (self.num_skip + 1)
        return max(0, len(self.images) - span)

    def __getitem__(self, idx, seq_query=None):
        """Build one sample of ``n_trgt`` frames starting at frame ``idx``.

        Frames are taken ``num_skip + 1`` apart. ``seq_query`` is accepted for
        interface compatibility and is unused. Returns whatever
        ``common.make_sample`` builds from the target dict.
        """
        stride = self.num_skip + 1
        idxs = [idx + i * stride for i in range(self.n_trgt)]
        frames = self.images[idxs]

        # plt.imread returns integer arrays for non-PNG files and floats in
        # [0, 1] for PNG; bring integer images to [0, 1] before mapping to [-1, 1].
        if not frames.is_floating_point():
            frames = frames.float() / 255

        org_ratio = self.images[0].size(-2) / self.images[0].size(-1)  # H / W
        trgt = {
            "rgb": frames * 2 - 1,
            "c2w": self.extrinsics[idxs],
            "intrinsics": self.K.expand(self.n_trgt, -1, -1),
            "org_ratio": org_ratio,
        }
        s, h = 4, 3
        return common.make_sample(trgt, 1 / org_ratio, hires_factor=h, budget=192 * 640 / (8 // s))

    def collate_fn(self, batch_list):
        """Merge a list of sample dicts into one dict of batched values.

        Keys are taken from the first sample. Values that ``torch.stack`` can
        combine are stacked along a new leading dimension; anything it rejects
        (non-tensors, mismatched shapes) is kept as a plain list.
        """
        result = defaultdict(list)
        if not batch_list:
            return result

        keys = batch_list[0].keys()
        for entry in batch_list:
            for key in keys:
                result[key].append(entry[key])

        for key in keys:
            try:
                result[key] = torch.stack(result[key], dim=0)
            except (RuntimeError, TypeError):
                # Non-tensor values or ragged shapes: leave them as a list.
                pass

        return result
