from copy import deepcopy
from dataclasses import dataclass
import lightning.pytorch as pl
# from lightning.pytorch.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS
import torch
from torch import LongTensor
from torch.utils import data
from torch.utils.data import DataLoader, Dataset
from typing import Dict, List, Tuple, Union, Callable
import os
import numpy as np

from .raw_data import RawData
from .asset import Asset
from .transform import TransformConfig, transform_asset
from .datapath import DatapathConfig, Datapath
from .spec import ConfigSpec

from ..tokenizer.spec import TokenizerSpec, TokenizerConfig
from ..tokenizer.parse import get_tokenizer
from ..model.spec import ModelInput

@dataclass
class DatasetConfig(ConfigSpec):
    '''
    Config to handle dataset format: dataloader settings (shuffle, batch size,
    workers, pin memory, persistent workers) plus the datapath config that
    locates the data.
    '''
    # shuffle dataset
    shuffle: bool

    # batch size
    batch_size: int

    # number of workers
    num_workers: int

    # datapath
    datapath_config: DatapathConfig

    # use pin memory
    pin_memory: bool = True

    # use persistent workers
    persistent_workers: bool = True

    @classmethod
    def parse(cls, config) -> 'DatasetConfig':
        '''
        Build a DatasetConfig from an attribute-access config object.

        Every field (including `pin_memory` and `persistent_workers`) is read
        directly from `config`, so the dataclass defaults above are not used as
        fallbacks here. `cls.check_keys(config)` (inherited from ConfigSpec) is
        called first — presumably it validates the key set; see ConfigSpec.
        `config.datapath_config` is parsed with `DatapathConfig.parse`.
        '''
        cls.check_keys(config)
        return DatasetConfig(
            shuffle=config.shuffle,
            batch_size=config.batch_size,
            num_workers=config.num_workers,
            pin_memory=config.pin_memory,
            persistent_workers=config.persistent_workers,
            datapath_config=DatapathConfig.parse(config.datapath_config),
        )

    def split_by_cls(self) -> Dict[str, 'DatasetConfig']:
        '''
        Split this config into one DatasetConfig per class key.

        Keys are taken from `self.datapath_config.data_path`. Each value is a
        deep copy of this config whose `datapath_config` is replaced by the
        matching entry of `self.datapath_config.split_by_cls()`.
        '''
        res: Dict[str, DatasetConfig] = {}
        split_datapaths = self.datapath_config.split_by_cls()
        for key in self.datapath_config.data_path:
            res[key] = deepcopy(self)
            res[key].datapath_config = split_datapaths[key]
        return res

class UniRigDatasetModule(pl.LightningDataModule):
    '''
    LightningDataModule that builds UniRig prediction datasets and dataloaders.

    Prediction data comes either from `predict_dataset_config` (one
    DatasetConfig per class key) or from a single explicit `datapath`, which
    takes precedence and is registered under the key `cls`. Only the predict
    stage is implemented here; `train_datapath` / `validate_datapath` are None.
    '''
    def __init__(
        self,
        process_fn: Union[Callable[[List[ModelInput]], Dict], None]=None,
        predict_dataset_config: Union[Dict[str, DatasetConfig], None]=None,
        predict_transform_config: Union[TransformConfig, None]=None,
        tokenizer_config: Union[TokenizerConfig, None]=None,
        debug: bool=False,
        data_name: str='raw_data.npz',
        datapath: Union[Datapath, None]=None,
        cls: Union[str, None]=None,
    ):
        '''
        Args:
            process_fn: applied to a list of ModelInput before default collation
                (forwarded to every UniRigDataset).
            predict_dataset_config: per-class dataset configs for prediction.
                Replaced when `datapath` is given.
            predict_transform_config: transform config forwarded to every
                predict dataset.
            tokenizer_config: when given, a tokenizer is built with
                `get_tokenizer`; otherwise no tokenizer is used.
            debug: forwarded to datasets; prints a warning when enabled.
            data_name: file name loaded inside each data directory.
            datapath: a single predict source. When given it overrides
                `predict_dataset_config` and uses batch_size=1, num_workers=0,
                no shuffling, no pin memory and no persistent workers.
            cls: key under which `datapath` is registered.
        '''
        super().__init__()
        self.process_fn                 = process_fn
        self.predict_dataset_config     = predict_dataset_config
        self.predict_transform_config   = predict_transform_config
        self.tokenizer_config           = tokenizer_config
        self.debug                      = debug
        self.data_name                  = data_name

        # Training/validation are not wired up in this module; define these
        # attributes on every path so reading them never raises AttributeError.
        self.train_datapath = None
        self.validate_datapath = None

        if debug:
            print("\033[31mWARNING: debug mode, dataloader will be extremely slow !!!\033[0m")

        if datapath is not None:
            # A single explicit datapath overrides `predict_dataset_config`.
            self.predict_datapath = {
                cls: deepcopy(datapath),
            }
            self.predict_dataset_config = {
                cls: DatasetConfig(
                    shuffle=False,
                    batch_size=1,
                    num_workers=0,
                    # NOTE(review): a Datapath is stored where a DatapathConfig
                    # is declared; this class only reads the loader settings of
                    # this config — confirm no other consumer relies on it.
                    datapath_config=deepcopy(datapath),
                    pin_memory=False,
                    persistent_workers=False,
                )
            }
        elif self.predict_dataset_config is not None:
            # Build one Datapath per class from its DatasetConfig.
            self.predict_datapath = {
                name: Datapath(dataset_config.datapath_config)
                for name, dataset_config in self.predict_dataset_config.items()
            }
        else:
            self.predict_datapath = None

        # get tokenizer
        if tokenizer_config is None:
            self.tokenizer = None
        else:
            self.tokenizer = get_tokenizer(config=tokenizer_config)

    def prepare_data(self):
        '''No download/preparation step is needed.'''
        pass

    def setup(self, stage=None):
        '''
        Build one UniRigDataset per predict class into `self._predict_ds`.

        `stage` is accepted for Lightning compatibility but not used. Nothing
        is built when no predict data was configured.
        '''
        if self.predict_datapath is None:
            return
        self._predict_ds = {
            name: UniRigDataset(
                process_fn=self.process_fn,
                data=datapath.get_data(),
                name=f"predict-{name}",
                tokenizer=self.tokenizer,
                transform_config=self.predict_transform_config,
                debug=self.debug,
                data_name=self.data_name,
            )
            for name, datapath in self.predict_datapath.items()
        }

    def predict_dataloader(self):
        '''
        Return a dict of DataLoaders, one per predict class, building the
        datasets on first use.

        Raises:
            RuntimeError: if neither `predict_dataset_config` nor `datapath`
                was provided, so there is nothing to predict on.
        '''
        if not hasattr(self, "_predict_ds"):
            self.setup()
        if not hasattr(self, "_predict_ds"):
            raise RuntimeError(
                "no predict data configured: pass `predict_dataset_config` or `datapath`"
            )
        return self._create_dataloader(
            dataset=self._predict_ds,
            config=self.predict_dataset_config,
            is_train=False,
            drop_last=False,
        )

    def _create_dataloader(
        self,
        dataset: Union[Dataset, Dict[str, Dataset]],
        config: Union[DatasetConfig, Dict[str, DatasetConfig]],
        is_train: bool,
        **kwargs,
    ) -> Union[DataLoader, Dict[str, DataLoader]]:
        '''
        Wrap a dataset (or a dict of datasets) in DataLoader(s).

        When `dataset` is a dict, `config` must be a dict with the same keys
        and a dict of loaders is returned. Loader settings come from the
        DatasetConfig; `collate_fn` comes from the dataset. Extra `kwargs` are
        forwarded to every DataLoader. `is_train` is currently unused.
        '''
        def create_single_dataloader(dataset, config: DatasetConfig, **kwargs):
            return DataLoader(
                dataset,
                batch_size=config.batch_size,
                shuffle=config.shuffle,
                num_workers=config.num_workers,
                pin_memory=config.pin_memory,
                # DataLoader raises ValueError for persistent_workers=True with
                # num_workers=0 (the DatasetConfig default is True), so only
                # enable it when worker processes are actually used.
                persistent_workers=bool(config.persistent_workers) and config.num_workers > 0,
                collate_fn=dataset.collate_fn,
                **kwargs,
            )
        if isinstance(dataset, dict):
            return {k: create_single_dataloader(v, config[k], **kwargs) for k, v in dataset.items()}
        return create_single_dataloader(dataset, config, **kwargs)

class UniRigDataset(Dataset):
    '''
    Map-style dataset over (cls, directory) pairs.

    Each item loads `<directory>/<data_name>` with RawData.load, wraps it in an
    Asset, runs `transform_asset`, tokenizes it when a tokenizer is set and the
    asset has `parents`, and returns a ModelInput. Batches are collated by
    `process_fn` followed by torch's default collate, unless `debug` is set, in
    which case the list of ModelInput is returned as-is.
    '''
    def __init__(
        self,
        data: List[Tuple[str, str]], # (cls, part)
        name: str,
        process_fn: Union[Callable[[List[ModelInput]], Dict]]=None,
        tokenizer: Union[TokenizerSpec, None]=None,
        transform_config: Union[TransformConfig, None]=None,
        debug: bool=False,
        data_name: str='raw_data.npz',
    ) -> None:
        super().__init__()

        self.data = data
        self.name = name
        self.process_fn = process_fn
        self.tokenizer = tokenizer
        self.transform_config = transform_config
        self.debug = debug
        self.data_name = data_name

        # Outside debug mode `_collate_fn` calls `process_fn` unconditionally,
        # so it must be provided.
        assert debug or self.process_fn is not None, 'missing data processing function'

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx) -> ModelInput:
        asset_cls, asset_dir = self.data[idx]
        raw = RawData.load(path=os.path.join(asset_dir, self.data_name))
        asset = Asset.from_raw_data(raw_data=raw, cls=asset_cls, path=asset_dir, data_name=self.data_name)

        # The two returned augment values are discarded (ModelInput.augments is
        # always None below); presumably transform_asset updates `asset`.
        _first_augments, _second_augments = transform_asset(
            asset=asset,
            transform_config=self.transform_config,
        )

        tokenizer = self.tokenizer
        tokens = (
            tokenizer.tokenize(input=asset.get_tokenize_input())
            if tokenizer is not None and asset.parents is not None
            else None
        )

        f32 = np.float32
        return ModelInput(
            tokens=tokens,
            pad=tokenizer.pad if tokenizer is not None else None,
            vertices=asset.sampled_vertices.astype(f32),
            normals=asset.sampled_normals.astype(f32),
            joints=asset.joints.astype(f32) if asset.joints is not None else None,
            tails=asset.tails.astype(f32) if asset.tails is not None else None,
            asset=asset,
            augments=None,
        )

    def _collate_fn_debug(self, batch):
        # Debug mode: hand back the list of samples untouched.
        return batch

    def _collate_fn(self, batch):
        # Let the model-specific function reshape the batch first, then stack
        # it with torch's default collation.
        processed = self.process_fn(batch)
        return data.dataloader.default_collate(processed)

    def collate_fn(self, batch):
        '''Collate `batch` with the debug or the normal path, chosen by `self.debug`.'''
        collate = self._collate_fn_debug if self.debug else self._collate_fn
        return collate(batch)