import os
import numpy as np
from PIL import Image
import scipy, scipy.io
from glob import glob
from easydict import EasyDict
from collections import OrderedDict
from torch.utils.data import Dataset
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import torch.nn.functional as F
import random
import torch
import geometry
import vis_scene_graph
import sgvis2
import kornia

def _spec(image_size, num_classes, train_images, val_images, num_channels):
    """Bundle one dataset's static metadata into a plain dict."""
    return {
        "image_size": image_size,
        "num_classes": num_classes,
        "train_images": train_images,
        "val_images": val_images,
        "num_channels": num_channels,
    }


# Shared by the 128px scene-level datasets (hydrants / tetris / spheres).
_SCENE_METADATA = _spec(128, 102, 10000, 100, 3)

# name -> _spec(image_size, num_classes, train_images, val_images, num_channels)
_DATASET_METADATA = {
    "mnist": _spec(28, 10, 60000, 10000, 1),
    "mnist_m": _spec(28, 10, 60000, 10000, 3),
    "cifar10": _spec(32, 10, 50000, 10000, 3),
    "melanoma": _spec(64, 2, 33126, 0, 3),
    "afhq": _spec(64, 3, 14630, 1500, 3),
    "celeba": _spec(64, 4, 109036, 12376, 3),
    "cars": _spec(64, 196, 8144, 8041, 3),
    "hydrants": _SCENE_METADATA,
    "vid_tetris": _SCENE_METADATA,
    "tetris": _SCENE_METADATA,
    "spheres": _SCENE_METADATA,
    "spheres_imgs": _SCENE_METADATA,
    "spheres_scenegraph": _SCENE_METADATA,
    "flowers": _spec(64, 102, 2040, 6149, 3),
    "gtsrb": _spec(32, 43, 39252, 12631, 3),
}


def get_metadata(name):
    """Return static metadata for a supported dataset.

    Args:
        name: dataset identifier, one of the keys of ``_DATASET_METADATA``.

    Returns:
        A fresh ``EasyDict`` with ``image_size``, ``num_classes``,
        ``train_images``, ``val_images`` and ``num_channels``.

    Raises:
        ValueError: if ``name`` is not a supported dataset.
    """
    try:
        spec = _DATASET_METADATA[name]
    except (KeyError, TypeError):  # TypeError: unhashable name, also unsupported
        raise ValueError(f"{name} dataset not supported!") from None
    # Copy so a caller mutating the result cannot alter the shared table.
    return EasyDict(dict(spec))

class hydrants_dataset(Dataset):
    """CO3D hydrant frames paired with precomputed ``*_scenegraph.pt`` files.

    Each item is built from one ``.pt`` file holding ``(depth, masks, clip_embs)``
    plus the ``.jpg`` frame whose path is obtained by replacing the
    ``_scenegraph.pt`` suffix.

    NOTE(review): ``root_dir`` and ``transform`` are accepted for interface
    compatibility with ``get_dataset`` but are not used; the data location is
    hard-coded below.
    """

    def __init__(self, root_dir, transform=None):
        # Scene directories (kept as an attribute; not used by __getitem__).
        self.scenes = sorted(glob("/data/co3dhydrants/co3d/hydrants/hydrant/*"))
        # One entry per frame that has a precomputed .pt file.
        self.withpt_imgs = sorted(
            glob("/data/co3dhydrants/co3d/hydrants/hydrant/*/images/*.pt")
        )
        self.transform = transform

    def __len__(self):
        return len(self.withpt_imgs)

    def __getitem__(self, idx):
        """Return a dict with ``rgb``, ``point_cloud``, ``bboxs`` and ``clip_embs``.

        ``rgb`` is the frame scaled from [0, 255] to [-1, 1] and resized to
        128x128 with a leading batch dim. ``bboxs`` and ``clip_embs`` are
        truncated / zero-padded to ``max_masks`` rows.

        If the ``.pt`` file cannot be loaded, item 0 is returned instead; if
        item 0 itself fails, the loading error is re-raised.
        """
        scenepath = self.withpt_imgs[idx]
        img_path = scenepath.replace("_scenegraph.pt", ".jpg")
        try:
            depth, masks, clip_embs = torch.load(scenepath, map_location="cpu")
        except Exception:
            # Best-effort fallback to the first sample, as before, but never
            # recurse forever when sample 0 is itself unreadable.
            if idx == 0:
                raise
            return self[0]

        imsl = 128  # output resolution for rgb and point_cloud

        img = torch.from_numpy(plt.imread(img_path)).permute(2, 0, 1)[None] / 255 * 2 - 1
        scene = {"rgb": F.interpolate(img, (imsl, imsl), mode="bilinear", antialias=True)}

        depth = F.interpolate(depth[None, None], (256, 256))[0, 0]

        # Lift depth into a 3-D point cloud to get per-mask bounding boxes.
        mask_levels = [2, 1, 0]
        masks_flat = torch.cat([masks[i] for i in mask_levels])
        clip_embs_flat = torch.cat([clip_embs[i] for i in mask_levels])

        point_cloud = sgvis2.depth_map_to_point_cloud(depth, sgvis2.K, sgvis2.d)
        mask_points = [
            sgvis2.filter_points_by_mask(point_cloud.cpu(), m.cpu()) for m in masks_flat
        ]
        bboxs_flat = torch.stack(
            [torch.stack(sgvis2.compute_bounding_box(pts)) for pts in mask_points]
        ) / 100

        max_masks = 100 if len(mask_levels) == 3 else 40 if len(mask_levels) == 2 else 10

        # Resize the (256*256, 3)-style point cloud to imsl x imsl points.
        scene["point_cloud"] = (
            F.interpolate(point_cloud.T.unflatten(1, (256, 256))[None], (imsl, imsl))[0]
            .flatten(1, 2)
            .T.cpu()
        )
        # Truncate / zero-pad to exactly max_masks rows so items can be batched.
        scene["bboxs"] = torch.cat(
            (
                bboxs_flat.cpu()[:max_masks],
                torch.zeros(max(0, max_masks - len(bboxs_flat)), 2, 3),
            )
        )
        scene["clip_embs"] = torch.cat(
            (
                clip_embs_flat.cpu()[:max_masks],
                torch.zeros(max(0, max_masks - len(clip_embs_flat)), 512),
            )
        )

        return scene
class spheres_dataset(Dataset):
    """Toy sphere scenes loaded from ``.pt`` files that include a scene graph.

    Each item is a dict of tensors read from a randomly chosen file under the
    hard-coded directory ``/data/toy_spheres_scenegraph``. ``root_dir`` is
    ignored, and ``transform`` is stored but never applied in ``__getitem__``.
    """
    def __init__(self, root_dir, transform=None ,imgs=False):
        """
        Args:
            root_dir: unused; the scene directory is hard-coded below.
            transform: stored on ``self.transform``; not used by ``__getitem__``.
            imgs: if True, ``__getitem__`` keeps one random frame per scene
                instead of every second frame.
        """

        self.scenes = glob(os.path.join("/data/toy_spheres_scenegraph","*"))#[:1]
        self.transform=transform
        self.just_imgs=imgs

    def __len__(self):
        # Fixed nominal epoch length, independent of len(self.scenes), because
        # __getitem__ samples a scene file at random on every call.
        return 100000#len(self.scenes)

    def __getitem__(self, idx):
        """Load a random scene and return its (sub-sampled) per-frame tensors.

        ``idx`` does not select the file (``random.choice`` does); it is only
        recorded in ``scene["idx"]`` as a ``(len(idxs), 1)`` tensor.

        NOTE(review): assumes every stored entry except keys containing
        "graph" is frame-major (frames along dim 0), and that image-like
        entries are channels-last ``(N, H, W, C)``. This is inferred from the
        ``v[idxs]`` slicing and the ``permute(0,3,1,2)`` below; TODO confirm
        against the data generator.
        """

        #scene=torch.load(self.scenes[idx],map_location="cpu")
        scene=torch.load(random.choice(self.scenes),map_location="cpu")

        #scene = {k:v for k,v in scene.items() if k in ["rgb","cameras","scene_graph"]}
        # Camera pose -> concat(translation column [..., :3, -1], axis-angle of the
        # 3x3 rotation block), then scaled by 1/2. Presumably a 6-vector per
        # frame; the reason for the /2 scale is not documented here.
        scene["cameras"] = torch.cat((scene["cameras"][...,:3,-1],kornia.geometry.conversions.rotation_matrix_to_axis_angle(scene["cameras"][...,:3,:3])),-1)/2

        # (Disabled) replace camera pose with raymap image embedding (to keep all
        # tensors image level). Only w/h are still used: as the resize target below.
        w=h=64
        #uv = torch.stack(torch.meshgrid(torch.linspace(-1,1,w),torch.linspace(-1,1,w)),-1)
        #raymap = torch.cat(geometry.get_world_rays_(uv.flatten(0,1)[None],torch.eye(3)[None],scene["cameras"]),-1).unflatten(1,(h,w))
        #scene["raymap_origin"]= raymap[...,:3]
        #scene["raymap_dir"]= raymap[...,3:]
        # Clamp rgb into [-1, 1].
        scene["rgb"]=scene["rgb"].clip(-1,1)
        #if self.just_imgs: del scene["cameras"]

        #for k in list(scene.keys()): scene["conditioning_%s"%k]=scene[k][[0]]
        #idxs=[0]+torch.randint(0,len(scene["rgb"]),(6,1)).squeeze().sort()[0].tolist()+[-1]
        # Frame selection: every second frame, or a single random frame (numpy RNG)
        # when just_imgs is set.
        idxs=list(range(0,len(scene["rgb"])))[::2]
        if self.just_imgs: idxs=[np.random.randint(0,len(scene["rgb"]))]
        #idxs=[0];print("doing just first img")
        # Slice the selected frames out of every entry; keys containing "graph"
        # (the scene graph) are kept whole.
        scene={k:v[idxs] if "graph" not in k else v for k,v in scene.items()} 
        scene["idx"]=torch.tensor([idx])[None].expand(len(idxs),-1)

        #scene={k:F.interpolate(v.permute(0,3,1,2)[[np.random.randint(0,len(v))]],(h,w),mode="bilinear", antialias=True) for k,v in scene.items()} 
        #scene={k:F.interpolate(v.permute(0,3,1,2)[[0]+torch.randint(0,len(v),(6,1)).squeeze().numpy()+[-1]],(h,w),mode="bilinear", antialias=True) for k,v in scene.items()} 
        # pack_scene_graph (vis_scene_graph) output is not visible here; the loop
        # below calls .size(1) on it, so it is expected to be a tensor.
        scene["scene_graph"]=vis_scene_graph.pack_scene_graph(scene["scene_graph"])

        # Any entry whose dim 1 exceeds 20 is treated as a channels-last image and
        # resized to (h, w), channels-first. Reassigning existing keys while
        # iterating items() is safe because the dict size does not change.
        for k,v in scene.items():
            if v.size(1)>20: scene[k]= F.interpolate(v.permute(0,3,1,2),(h,w),mode="bilinear", antialias=True)

        #idxs=range(0,17,2)

        return scene
class vid_tetris_dataset(Dataset):
    """Video Tetris scenes, one ``torch.save``'d dict of tensors per file.

    Only the first file matched under ``root_dir`` is kept. ``transform`` is
    stored but not applied by ``__getitem__``.
    """

    def __init__(self, root_dir, transform=None):
        matches = glob(os.path.join(root_dir, "*"))
        self.scenes = matches[:1]
        self.transform = transform

    def __len__(self):
        return len(self.scenes)

    def __getitem__(self, idx):
        """Return every entry limited to its first 8 items along dim 0, remapped
        with ``x * 2 - 1`` and resized to 128x128; ``shape_type`` keys are
        additionally divided by 10 (14 channels when concatenated)."""
        raw = torch.load(self.scenes[idx])
        out = {}
        for key, frames in raw.items():
            resized = F.interpolate(
                frames[:8] * 2 - 1, (128, 128), mode="bilinear", antialias=True
            )
            if "shape_type" in key:
                resized = resized / 10
            out[key] = resized
        return out

class tetris_dataset(Dataset):
    """Static Tetris scenes saved as ``{"distorted_composited": ..., "composited": ...}``.

    ``transform`` is stored but not applied by ``__getitem__``.
    """

    def __init__(self, root_dir, transform=None):
        self.scenes = glob(os.path.join(root_dir, "*"))
        self.transform = transform

    def __len__(self):
        return len(self.scenes)

    def __getitem__(self, idx):
        """Merge distorted and ``undistorted_``-prefixed views, drop ``seg`` keys,
        remap with ``x * 2 - 1``, resize to 128x128 and divide ``shape_type``
        keys by 10 (14 channels when concatenated)."""
        raw = torch.load(self.scenes[idx])
        undistorted = {"undistorted_" + name: t for name, t in raw["composited"].items()}
        merged = raw["distorted_composited"] | undistorted

        out = {}
        for key, tensor in merged.items():
            if "seg" in key:
                # Segmentation would need a permutation-invariant loss; skip it.
                continue
            resized = F.interpolate(
                tensor[None] * 2 - 1, (128, 128), mode="bilinear", antialias=True
            )[0]
            out[key] = resized / 10 if "shape_type" in key else resized
        return out

class oxford_flowers_dataset(Dataset):
    """Oxford 102 Flowers images with integer class targets.

    Args:
        indexes: image ids; id ``i`` is read from
            ``<root_dir>/jpg/image_<i left-padded with "0" to 5 chars>.jpg``.
        labels: indexable label sequence; ``labels[i - 1] - 1`` becomes the target.
        root_dir: dataset root containing the ``jpg`` folder.
        transform: optional callable applied to each RGB PIL image.
    """

    def __init__(self, indexes, labels, root_dir, transform=None):
        self.images = []
        self.targets = []
        self.transform = transform

        for i in indexes:
            # rjust matches the original manual padding exactly (ids longer
            # than 5 characters are left unchanged).
            filename = "image_" + str(i).rjust(5, "0") + ".jpg"
            self.images.append(os.path.join(root_dir, "jpg", filename))
            # Shift the stored label by one to get the class target.
            self.targets.append(labels[i - 1] - 1)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        # Close the file handle as soon as the pixels are decoded instead of
        # leaving it open until the lazy Image object is garbage-collected.
        with Image.open(self.images[idx]) as img:
            image = img.convert("RGB")
        target = self.targets[idx]
        if self.transform is not None:
            image = self.transform(image)
        return image, target


# TODO: Add more datasets (birds, svhn, ...). imagenette is accepted by get_dataset but has no get_metadata entry yet.
def get_dataset(name, data_dir, metadata):
    """
    Return the training dataset registered under ``name``.

    Args:
        name: dataset identifier, e.g. "mnist", "cifar10", "celeba",
            "hydrants", "tetris"; any name containing "spheres" selects
            ``spheres_dataset``.
        data_dir: root directory / download location passed to the dataset.
        metadata: object from ``get_metadata``; only ``metadata.image_size`` is
            read (mnist / mnist_m crop size).

    Returns:
        A ``torch.utils.data.Dataset`` of training samples.

    Raises:
        ValueError: if ``name`` is not recognized.

    Note: To avoid learning the distribution of transformed data, don't use heavy
        data augmentation with diffusion models.
    """
    if name == "mnist":
        transform_train = transforms.Compose(
            [
                transforms.RandomResizedCrop(
                    metadata.image_size, scale=(0.8, 1.0), ratio=(0.8, 1.2)
                ),
                transforms.ToTensor(),
            ]
        )
        train_set = datasets.MNIST(
            root=data_dir,
            train=True,
            download=True,
            transform=transform_train,
        )
    elif name == "mnist_m":
        transform_train = transforms.Compose(
            [
                transforms.RandomResizedCrop(
                    metadata.image_size, scale=(0.8, 1.0), ratio=(0.8, 1.2)
                ),
                transforms.ToTensor(),
            ]
        )
        train_set = datasets.ImageFolder(
            data_dir,
            transform=transform_train,
        )
    elif name == "cifar10":
        transform_train = transforms.Compose(
            [
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
            ]
        )
        train_set = datasets.CIFAR10(
            root=data_dir,
            train=True,
            download=True,
            transform=transform_train,
        )
    elif name in ["imagenette", "melanoma", "afhq"]:
        transform_train = transforms.Compose(
            [
                transforms.Resize(74),
                transforms.RandomCrop(64),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
            ]
        )
        train_set = datasets.ImageFolder(
            data_dir,
            transform=transform_train,
        )
    elif name == "celeba":
        # celebA has a large number of images, avoiding randomcropping.
        transform_train = transforms.Compose(
            [
                transforms.Resize(64),
                transforms.CenterCrop(64),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
            ]
        )
        train_set = datasets.ImageFolder(
            data_dir,
            transform=transform_train,
        )
    elif name == "cars":
        transform_train = transforms.Compose(
            [
                transforms.Resize(64),
                transforms.RandomCrop(64),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
            ]
        )
        train_set = datasets.ImageFolder(
            data_dir,
            transform=transform_train,
        )
    elif name == "hydrants":
        # NOTE(review): hydrants/spheres/tetris datasets store this transform
        # but their __getitem__ does not apply it.
        transform_train = transforms.Compose(
            [
                transforms.Resize(64),
                transforms.ToTensor(),
            ]
        )
        train_set = hydrants_dataset(
            data_dir,
            transform_train,
        )
    elif "spheres" in name:
        transform_train = transforms.Compose([transforms.Resize(64), transforms.ToTensor()])
        train_set = spheres_dataset(data_dir, transform_train, imgs=True)
    elif name == "vid_tetris":
        transform_train = transforms.Compose(
            [
                transforms.Resize(64),
                transforms.ToTensor(),
            ]
        )
        train_set = vid_tetris_dataset(
            data_dir,
            transform_train,
        )
    elif name == "tetris":
        transform_train = transforms.Compose(
            [
                transforms.Resize(64),
                transforms.ToTensor(),
            ]
        )
        train_set = tetris_dataset(
            data_dir,
            transform_train,
        )
    elif name == "flowers":
        transform_train = transforms.Compose(
            [
                transforms.Resize(64),
                transforms.RandomCrop(64),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
            ]
        )
        splits = scipy.io.loadmat(os.path.join(data_dir, "setid.mat"))
        labels = scipy.io.loadmat(os.path.join(data_dir, "imagelabels.mat"))
        labels = labels["labels"][0]
        # Train on the official train + validation splits together.
        train_set = oxford_flowers_dataset(
            np.concatenate((splits["trnid"][0], splits["valid"][0]), axis=0),
            labels,
            data_dir,
            transform_train,
        )
    elif name == "gtsrb":
        # Resize to a fixed 32x32; no crop/flip augmentation for traffic signs.
        transform_train = transforms.Compose(
            [
                transforms.Resize((32, 32)),
                transforms.ToTensor(),
            ]
        )
        train_set = datasets.ImageFolder(
            data_dir,
            transform=transform_train,
        )
    else:
        raise ValueError(f"{name} dataset not supported!")
    return train_set


def remove_module(d):
    """Strip the ``"module."`` prefix that multi-GPU wrappers add to state-dict keys.

    Key order is preserved and values are passed through untouched; they do
    not need to be hashable. Keys without the prefix are left unchanged
    rather than having their first 7 characters chopped off.
    """
    return OrderedDict((k.removeprefix("module."), v) for k, v in d.items())


def fix_legacy_dict(d):
    """Unwrap a checkpoint into a plain state dict.

    Unwraps a ``"model"`` entry, then a ``"state_dict"`` entry, if present.
    Strips the multi-GPU ``"module."`` prefix via ``remove_module`` when the
    probed key contains it.
    """
    if "model" in d:
        d = d["model"]
    # Check the (possibly unwrapped) dict itself, not the outer key list;
    # otherwise {"model": m, "state_dict": s} would index m["state_dict"].
    if "state_dict" in d:
        d = d["state_dict"]
    keys = list(d.keys())
    # remove multi-gpu module. Probe the second key as before, falling back to
    # the only key for single-entry dicts (previously an IndexError).
    if keys and "module." in keys[min(1, len(keys) - 1)]:
        d = remove_module(d)
    return d
