from io import BytesIO
from pathlib import Path
from random import shuffle
import random

import torch
import torchvision.transforms as tf
from einops import rearrange, repeat
from jaxtyping import Float, UInt8
from omegaconf import DictConfig
from PIL import Image
from torch import Tensor
from torch.utils.data import IterableDataset
from data import common

class DatasetRealEstate10k(torch.utils.data.Dataset):
    """Map-style dataset over pre-chunked RealEstate10k ``.torch`` files.

    Each index corresponds to one chunk file of the selected split.
    ``__getitem__`` loads that chunk with ``torch.load``, walks its examples in
    order, and returns a sample built by ``common.make_sample`` from the first
    example that can be processed. Examples that raise are skipped.
    """

    # Declared but never assigned by this class; the string form avoids
    # evaluating the omegaconf type when the class is created.
    cfg: "DictConfig"
    chunks: list
    # Default bound values read by ``get_bound`` via ``getattr``.
    near: float = 0.1
    far: float = 10.0

    # Base directory used when no ``root`` is passed to ``__init__``.
    DEFAULT_ROOT = Path("/data/scene-rep/Real-Estate-10k/re10k_pt")

    def __init__(
        self,
        num_query_views=2, imsl=256, n_skip=9, val=False, res_factor=1,
        root=None,
    ) -> None:
        """Index the ``.torch`` chunk files of one split.

        Args:
            num_query_views: number of frames sampled per example.
            imsl: accepted for interface compatibility; not used.
            n_skip: frame-spacing control; the sampled window spans
                ``num_query_views * (n_skip + 1)`` frames. If a list is
                given, only its first element is used.
            val: selects the split; truthy reads ``test``, falsy ``train``.
            res_factor: stored as ``self.res_factor``; not read by this class.
            root: base directory containing the ``train``/``test`` folders.
                Defaults to ``DEFAULT_ROOT``.
        """
        self.n_skip = n_skip[0] if isinstance(n_skip, list) else n_skip
        self.n_trgt = num_query_views
        self.res_factor = res_factor
        self.val = val
        base = self.DEFAULT_ROOT if root is None else Path(root)
        split_dir = base / ["train", "test"][val]
        # Sort so a given index always maps to the same chunk; iterdir()
        # yields entries in filesystem-dependent order.
        self.chunks = sorted(
            path for path in split_dir.iterdir() if path.suffix == ".torch"
        )
        print("making data")

    def __len__(self):
        """Return the number of chunk files (one item per chunk)."""
        return len(self.chunks)

    def __getitem__(self, idx, context=False, input_context=True):
        """Load chunk ``idx`` and return a sample from its first usable example.

        Args:
            idx: index into ``self.chunks``.
            context: accepted for interface compatibility; unused.
            input_context: accepted for interface compatibility; unused.

        Returns:
            ``common.make_sample(trgt, 154 / 86)``. Here ``trgt`` holds
            ``"rgb"`` (decoded frames rescaled by ``* 2 - 1``),
            ``"intrinsics"`` and ``"c2w"`` for the sampled frames. Returns
            ``None`` if every example in the chunk failed to process; a
            custom ``collate_fn`` must handle that case.
        """
        chunk = torch.load(self.chunks[idx])

        for example in chunk:
            try:
                c2w, intrinsics = self.convert_poses(example["cameras"])
                frame_idxs = self._sample_frame_indices(len(example["cameras"]))
                # ToTensor maps 8-bit images to [0, 1]; rescale to [-1, 1].
                imgs = self.convert_images(
                    [example["images"][i] for i in frame_idxs]
                ) * 2 - 1
                trgt = {
                    "rgb": imgs,
                    "intrinsics": intrinsics[frame_idxs],
                    "c2w": c2w[frame_idxs],
                }
            except Exception as err:
                print(f"bad idx: chunk {idx}: {err!r}")
                continue
            # NOTE(review): 154 / 86 presumably is a target aspect ratio
            # (width / height) — confirm against common.make_sample.
            return common.make_sample(trgt, 154 / 86)

        return None

    def _sample_frame_indices(self, n):
        """Pick ``self.n_trgt`` evenly spaced frame indices out of ``n`` frames."""
        span = self.n_trgt * (self.n_skip + 1)
        if span >= n:
            # Not enough frames for a random offset: spread over the whole clip.
            # (Using ``>`` here made span == n call randint(0, -1) and fail.)
            start, span = 0, n - 1
        else:
            # randint is inclusive, so start + span <= n - 1.
            start = random.randint(0, n - span - 1)
        return (
            torch.linspace(start, start + span, self.n_trgt)
            .long()
            .clip(0, n - 1)
            .tolist()
        )

    def get_bound(
        self,
        bound,
        num_views: int,
    ):
        """Return attribute ``bound`` (e.g. ``"near"``/``"far"``) as a float32
        tensor repeated ``num_views`` times."""
        value = torch.tensor(getattr(self, bound), dtype=torch.float32)
        return repeat(value, "-> v", v=num_views)

    def convert_poses(
        self,
        poses,
    ):
        """Split per-frame camera rows into camera-to-world and K matrices.

        Args:
            poses: 2-D tensor with one row per frame. Columns 0-3 are read as
                ``fx, fy, cx, cy``; columns 6 onward are reshaped to a 3x4
                world-to-camera matrix. Columns 4-5 are not read.

        Returns:
            ``(c2w, intrinsics)``: ``(b, 4, 4)`` inverses of the world-to-camera
            matrices, and ``(b, 3, 3)`` intrinsics matrices.
        """
        b, _ = poses.shape

        # Build 3x3 K matrices (the original author describes them as
        # normalized — not checked here).
        intrinsics = torch.eye(3, dtype=torch.float32)
        intrinsics = repeat(intrinsics, "h w -> b h w", b=b).clone()
        fx, fy, cx, cy = poses[:, :4].T
        intrinsics[:, 0, 0] = fx
        intrinsics[:, 1, 1] = fy
        intrinsics[:, 0, 2] = cx
        intrinsics[:, 1, 2] = cy

        # Embed the 3x4 world-to-camera rows into 4x4 homogeneous matrices
        # (OpenCV convention per the original comment), then invert to c2w.
        w2c = repeat(torch.eye(4, dtype=torch.float32), "h w -> b h w", b=b).clone()
        w2c[:, :3] = rearrange(poses[:, 6:], "b (h w) -> b h w", h=3, w=4)
        return w2c.inverse(), intrinsics

    def convert_images(
        self,
        images,
        ):
        """Decode encoded image byte tensors into one stacked float tensor.

        Args:
            images: sequence of tensors whose raw bytes form an encoded image
                file that PIL can open.

        Returns:
            ``torch.stack`` of the ``ToTensor()`` outputs, i.e.
            ``(len(images), C, H, W)``. This assumes all images share one size
            and mode — TODO confirm.
        """
        to_tensor = tf.ToTensor()
        decoded = []
        for raw in images:
            # Close the PIL handle once the pixels are copied into a tensor.
            with Image.open(BytesIO(raw.numpy().tobytes())) as image:
                decoded.append(to_tensor(image))
        return torch.stack(decoded)
