import random
import statistics
from glob import glob
from pathlib import Path
from typing import Tuple

import numpy as np
import torch
import torchvision.transforms as tf
from einops import rearrange, repeat
from jaxtyping import Float, Int64
from omegaconf import DictConfig
from PIL import Image
from torch import Tensor
from torch.utils.data import Dataset
from tqdm import tqdm

from data import common

def load_images(
    root: Path,
    max_images: int = 22,
    depth_root: str = "/data/cameron_depth_data_storage/",
):
    """Load RGB images from ``root`` together with their cached depth maps.

    Images are the first ``max_images`` entries of ``root`` in sorted order.
    For each image, the depth file is looked up in ``depth_root`` by the
    image's full path with every ``/`` replaced by ``-``, followed by any
    suffix and ``.pt``. The integer after the last ``_`` in that depth
    filename is parsed as the image's focal value.

    Args:
        root: Directory containing the images.
        max_images: Maximum number of images to load (default 22).
        depth_root: Directory prefix holding the ``.pt`` depth files.

    Returns:
        ``(images, depth_images, image_paths, focals)``: images stacked after
        ``ToTensor``, tensors from ``torch.load`` stacked, the image paths,
        and the parsed focal values (ints), all in the same order.

    Raises:
        FileNotFoundError: If no depth file matches one of the images.
    """
    # Only the glob() function is imported at module level.
    from glob import escape as glob_escape

    image_paths = sorted(root.iterdir())[:max_images]
    to_tensor = tf.ToTensor()
    images = []
    for path in tqdm(image_paths, desc="Loading images"):
        # PIL keeps the file handle open after a lazy open; close it once the
        # pixels have been converted.
        with Image.open(path) as image:
            images.append(to_tensor(image))

    depth_paths = []
    for path in image_paths:
        # Escape the literal prefix so characters such as '[' in a path are
        # not interpreted as glob wildcards; only the trailing '*' is a pattern.
        pattern = glob_escape(depth_root + str(path).replace("/", "-")) + "*.pt"
        # Sort so the choice is deterministic when several files match.
        matches = sorted(glob(pattern))
        if not matches:
            raise FileNotFoundError(
                f"No depth map matching {pattern!r} for image {path}"
            )
        depth_paths.append(matches[0])

    # The parse below expects depth filenames ending in "_<int>.pt".
    focals = [int(path.split("_")[-1].split(".")[0]) for path in depth_paths]
    depth_images = [
        torch.load(path) for path in tqdm(depth_paths, desc="Loading depth maps")
    ]
    return torch.stack(images), torch.stack(depth_images), image_paths, focals


class LLFF(Dataset):
    """LLFF-format scene dataset with cached depth maps.

    Camera poses come from ``<data_root>/<scene>/poses_bounds.npy``. RGB
    images, depth maps and per-image focals come from ``load_images`` applied
    to ``<data_root>/<scene>/images``. Every ``__getitem__`` call returns the
    same ``num_trgt`` frames, starting at frame 0 with stride ``n_skip + 1``,
    packed by ``common.make_sample``.
    """

    root: Path
    extrinsics: Float[Tensor, "batch 4 4"]
    intrinsics: Float[Tensor, "batch 3 3"]
    images: Float[Tensor, "batch 3 height width"]
    context_indices: Int64[Tensor, " view"]

    def __init__(
        self,
        low_res,
        num_trgt,
        n_skip,
        scene,
        val=False,
        data_root="/data/nerf_llff_data/",
    ) -> None:
        """Load poses, images and depth for one scene.

        Args:
            low_res: Not used; output resolutions are fixed in ``__getitem__``.
            num_trgt: Number of frames returned per sample.
            n_skip: Frames skipped between consecutive sampled frames. A list
                collapses to its first element.
            scene: Scene subdirectory under ``data_root``.
            val: Not used.
            data_root: Directory that contains the scene folders.
        """
        super().__init__()
        if isinstance(n_skip, list):
            n_skip = n_skip[0]
        self.num_trgt = num_trgt
        self.n_skip = n_skip
        self.low_res = False
        self.root = Path(data_root) / scene

        # One row per camera: a flattened 3x5 pose matrix followed by near/far.
        metadata = torch.tensor(np.load(self.root / "poses_bounds.npy"))
        num_cameras, _ = metadata.shape
        cameras = rearrange(metadata[:, :-2], "b (i j) -> b i j", i=3, j=5)
        rotation = cameras[:, :3, :3]
        translation = cameras[:, :3, 3]
        # Last column is (height, width, focal). The focal stored here is not
        # used; the median focal from the depth files replaces it below.
        h, w, _ = cameras[:, :3, 4].unbind(dim=-1)
        self.near, self.far = metadata[:, -2:].type(torch.float32).unbind(dim=-1)

        extrinsics = repeat(torch.eye(4), "i j -> b i j", b=num_cameras).clone()
        extrinsics[:, :3, :3] = rotation
        extrinsics[:, :3, 3] = translation

        # Right-multiplying swaps the camera x/y axes and negates z, giving
        # OpenCV-style camera-to-world matrices.
        conversion = torch.zeros((4, 4), dtype=torch.float32)
        conversion[0, 1] = 1
        conversion[1, 0] = 1
        conversion[2, 2] = -1
        conversion[3, 3] = 1
        self.extrinsics = extrinsics @ conversion

        self.images, self.depth_images, self.img_paths, focals = load_images(
            self.root / "images"
        )

        # A single focal shared by every camera.
        focal = statistics.median(focals)

        # Normalized intrinsics: principal point at (0.5, 0.5), focal divided
        # by the image width/height from poses_bounds.
        # NOTE(review): these tensors have one row per poses_bounds entry,
        # while load_images caps the image count (22 by default). Confirm
        # that frame indices line up.
        self.intrinsics = repeat(torch.eye(3), "i j -> b i j", b=num_cameras).clone()
        self.intrinsics[:, :2, 2] = 0.5
        self.intrinsics[:, 0, 0] = focal / w
        self.intrinsics[:, 1, 1] = focal / h

    def __getitem__(self, index: int):
        """Return ``num_trgt`` frames spaced ``n_skip + 1`` apart from frame 0.

        ``index`` is ignored: every call yields the same frames.
        """
        skip = random.choice(self.n_skip) if isinstance(self.n_skip, list) else self.n_skip
        stride = skip + 1
        idxs = list(range(0, self.num_trgt * stride, stride))

        img = self.images[idxs]
        # Rescale if the pixels look like 0-255 values rather than [0, 1].
        if img.max() > 2:
            img = img / 255

        org_ratio = img.size(-2) / img.size(-1)  # height / width
        trgt = {
            "rgb": img * 2 - 1,  # [0, 1] -> [-1, 1]
            "c2w": self.extrinsics[idxs],
            "intrinsics": self.intrinsics[idxs],
            "org_ratio": org_ratio,
            "depth_inp": self.depth_images[idxs],
        }
        # Kept for compatibility: earlier code set this flag on every call.
        self.low_res = 1
        return common.make_sample(
            trgt,
            1 / org_ratio,
            hires_factor=3,
            budget=192 * 640 / 8,
            low_res=[160, 256],
            hi_res=[640, 1024],
        )

    def __len__(self) -> int:
        """Return a fixed nominal epoch length; samples do not depend on index."""
        return 100
