import matplotlib.pyplot as plt; 
import cv2
import os
import multiprocessing as mp
import torch.nn.functional as F
import torch
import random
import imageio
import numpy as np
from glob import glob
from collections import defaultdict
from pdb import set_trace as pdb
from itertools import combinations
from random import choice
import matplotlib.pyplot as plt
import imageio.v3 as iio

from torchvision import transforms
from torchvision.transforms import Compose, ToTensor, Lambda
from torchvision.datasets.mnist import MNIST as MNIST_, FashionMNIST

import sys

from glob import glob
import os
import gzip
import json
import numpy as np

from PIL import Image


from einops import rearrange, repeat
def ch_sec(x):
    """Flatten the two trailing spatial axes and move the channel axis last.

    Rearranges ``(..., c, x, y)`` into ``(..., x*y, c)`` using einops; in the
    flattened axis ``x`` is the outer (slower-varying) index and ``y`` the inner.
    """
    return rearrange(x, "... c x y -> ... (x y) c")
def hom(x, i=-1):
    """Append a slice of ones to ``x`` along dimension ``i``.

    This is the usual "homogeneous coordinates" lift: e.g. points of shape
    ``(..., 3)`` become ``(..., 4)`` with a trailing 1 when ``i == -1``.

    Args:
        x: input tensor.
        i: dimension along which the ones are appended (default: last).

    Returns:
        A tensor with ``x.shape[i] + 1`` entries along ``i`` and the same
        dtype and device as ``x``.
    """
    # Build the ones block's shape directly rather than materializing every
    # slice with ``x.unbind(i)``; this also works when x has size 0 along i.
    ones_shape = list(x.shape)
    ones_shape[i] = 1
    return torch.cat((x, x.new_ones(ones_shape)), i)

class TetrisObjs(torch.utils.data.Dataset):
    """Image dataset over the ``*.pt`` scene files matched by ``data_glob``.

    Each file is loaded with ``torch.load``. Its
    ``["distorted_composited"]["rgb"]`` entry is resized to ``res`` and
    returned as ``{"rgb": image}``. The sorted file list is split into a
    leading training portion (``train_frac`` of the files, rounded down) and a
    trailing validation portion. Call :meth:`train` / :meth:`val` to choose
    which split indices refer to.
    """

    def __init__(
        self,
        overfit_size=11000,
        data_glob="/data/toy_tetris/dataset/*.pt",
        res=(128, 128),
        train_frac=0.9,
    ):
        """
        Args:
            overfit_size: accepted for backward compatibility; currently unused.
            data_glob: glob pattern matching the scene files.
            res: output spatial size passed as ``size`` to ``F.interpolate``.
            train_frac: fraction of files (rounded down) used for training.
        """
        # Sort so the train/val split is reproducible: glob() returns paths in
        # arbitrary, filesystem-dependent order.
        self.scene_dicts = sorted(glob(data_glob))
        self.res = tuple(res)

        self.train_idx = int(len(self.scene_dicts) * train_frac)
        self.is_val = False

    def train(self):
        """Index into the training split (the first ``train_idx`` files)."""
        self.is_val = False

    def val(self):
        """Index into the validation split (the files after ``train_idx``)."""
        self.is_val = True

    def __len__(self):
        """Return the number of files in the active split."""
        return self.train_idx if not self.is_val else (len(self.scene_dicts) - self.train_idx)

    def __getitem__(self, idx):
        """Load scene ``idx`` of the active split and return ``{"rgb": image}``.

        Negative indices count from the end of the active split.

        Raises:
            IndexError: if ``idx`` falls outside the active split, so training
                indices can never reach validation files and vice versa.
        """
        n = len(self)
        if idx < 0:
            idx += n
        if not 0 <= idx < n:
            raise IndexError(f"index out of range for split of size {n}")
        if self.is_val:
            idx += self.train_idx

        # NOTE(review): torch.load unpickles the file; only use on trusted data.
        scene_dict = torch.load(self.scene_dicts[idx])

        img = scene_dict["distorted_composited"]["rgb"]
        # Assumes img is a (C, H, W) floating-point tensor — TODO confirm.
        # [None] adds the batch axis F.interpolate expects; [0] removes it.
        img = F.interpolate(img[None], self.res, antialias=True, mode="bilinear")[0]

        return {"rgb": img}
