"""Minimal window dataset for the volume AR model.

Per sample (at frame t of demo d); shapes shown for image_size=448, n_past=20, t_future=8:
  rgb:                (3, 448, 448) — current frame, ImageNet-normalized
  past_eef_world:     (20, 3)       — world EEF at frames [t-19..t]; clamped at 0 (repeat earliest)
  current_eef_world:  (3,)          — == past_eef_world[-1]
  target_eef_world:   (8, 3)        — world EEF at frames [t+stride..t+8*stride] (clamped at demo end)
  target_grip:        (8,)          — gripper at those frames (-1 / +1)
  target_rot_euler:   (8, 3)        — EEF Euler angles at those frames (scipy extrinsic 'xyz', radians)
  target_voxel_idx:   (8,)          — flat voxel index of each target world coord
  world_to_camera:    (4, 4)
  valid_mask:         (8,)          — False where future is clamped at demo end (skip loss)
  demo_idx:           ()            — index of demo d within the dataset
  start_t:            ()            — frame index t
"""
import re
from pathlib import Path

import cv2
import numpy as np
import torch
from torch.utils.data import Dataset
from scipy.spatial.transform import Rotation as ScipyR

from robot_volume import (
    voxel_centers_world, world_to_voxel_idx,
    N_PAST_EEF, T_FUTURE, IMAGE_SIZE,
)

# Per-channel (R, G, B) ImageNet statistics used to normalize frames that
# have already been scaled to [0, 1]. float32 so normalization stays float32.
IMAGENET_MEAN = np.asarray((0.485, 0.456, 0.406), dtype=np.float32)
IMAGENET_STD = np.asarray((0.229, 0.224, 0.225), dtype=np.float32)


class VolumeWindowDataset(Dataset):
    """One sample per usable frame of every cached demo of a single task.

    Expected layout (as read in ``__init__``)::

        <cache_root>/<benchmark_name>/task_<task_id>/demo_*/
            frames/*.png
            eef_pos.npy  eef_quat.npy  gripper.npy  world_to_cam.npy

    ``eef_pos``, ``eef_quat`` and ``gripper`` are indexed by frame number, so
    each must have at least one row per frame. Demo directories and frame
    files are ordered with a natural sort (``demo_2`` before ``demo_10``,
    ``9.png`` before ``10.png``); zero-padded names sort as plain lexical.

    Args:
        cache_root: Root directory of the cache.
        benchmark_name: Benchmark sub-directory under ``cache_root``.
        task_id: Selects the ``task_<task_id>`` sub-directory.
        image_size: Frames are resized to ``image_size x image_size``.
        n_past: Past EEF positions per sample, including frame t (>= 1).
        t_future: Future target steps per sample.
        frame_stride: Frame spacing between consecutive future targets (>= 1).
        max_demos: If > 0, load at most this many non-empty demos.

    Raises:
        FileNotFoundError: If the task directory does not exist.
        ValueError: If ``n_past`` or ``frame_stride`` is < 1, or a per-frame
            array has fewer rows than the demo has frames.
    """

    def __init__(self, cache_root, benchmark_name="libero_spatial", task_id=0,
                 image_size=IMAGE_SIZE, n_past=N_PAST_EEF, t_future=T_FUTURE,
                 frame_stride=3, max_demos=0):
        if frame_stride < 1:
            raise ValueError(f"frame_stride must be >= 1, got {frame_stride}")
        if n_past < 1:
            raise ValueError(f"n_past must be >= 1, got {n_past}")
        self.image_size = image_size
        self.n_past = n_past
        self.t_future = t_future
        self.s = frame_stride

        bench_root = Path(cache_root) / benchmark_name / f"task_{task_id}"
        if not bench_root.exists():
            raise FileNotFoundError(bench_root)

        self.demos = []
        self.samples = []
        for demo_dir in sorted(bench_root.glob("demo_*"), key=self._natural_key):
            # Count demos actually loaded, not directories visited (empty ones are skipped).
            if max_demos > 0 and len(self.demos) >= max_demos:
                break
            frames = sorted((demo_dir / "frames").glob("*.png"), key=self._natural_key)
            if not frames:
                continue
            n_frames = len(frames)
            demo = {
                "frame_paths": frames,
                "eef_pos":     np.load(demo_dir / "eef_pos.npy"),
                "eef_quat":    np.load(demo_dir / "eef_quat.npy"),
                "gripper":     np.load(demo_dir / "gripper.npy"),
                "world_to_cam": np.load(demo_dir / "world_to_cam.npy"),
                "T":           n_frames,
            }
            # Fail fast here instead of with an IndexError inside a DataLoader worker.
            short = [k for k in ("eef_pos", "eef_quat", "gripper") if len(demo[k]) < n_frames]
            if short:
                raise ValueError(f"{demo_dir}: {short} have fewer rows than the "
                                 f"{n_frames} frames")
            self.demos.append(demo)
            # A sample needs at least one real future target, i.e. t + stride <= T - 1.
            # Later frames would have an all-False valid_mask and contribute no loss.
            d = len(self.demos) - 1
            self.samples.extend((d, t) for t in range(n_frames - self.s))
        print(f"VolumeWindowDataset: {len(self.demos)} demos, {len(self.samples)} samples "
              f"(past={n_past}, future={t_future}, stride={frame_stride})")

    @staticmethod
    def _natural_key(path):
        """Sort key comparing runs of digits in ``path.name`` numerically."""
        # re.split with a capture group alternates [text, digits, text, ...], so
        # odd positions are always digit runs and list comparisons stay type-aligned.
        parts = re.split(r"(\d+)", path.name)
        return [int(p) if i % 2 else p for i, p in enumerate(parts)], path.name

    @staticmethod
    def _quat_wxyz_to_euler(quat_wxyz):
        """Convert (K, 4) WXYZ quaternions to (K, 3) float32 extrinsic 'xyz' Euler (radians).

        Rows with zero or non-finite norm (which scipy cannot decode) become
        zeros; valid rows are still converted.
        """
        quat_xyzw = quat_wxyz[:, [1, 2, 3, 0]]  # scipy expects scalar-last
        euler = np.zeros((quat_xyzw.shape[0], 3), dtype=np.float32)
        norms = np.linalg.norm(quat_xyzw, axis=1)
        ok = np.isfinite(norms) & (norms > 0)
        if ok.any():
            euler[ok] = ScipyR.from_quat(quat_xyzw[ok]).as_euler('xyz')
        return euler

    def __len__(self):
        return len(self.samples)

    def _load_frame(self, path):
        """Read a PNG as an ImageNet-normalized float32 (3, image_size, image_size) tensor.

        Raises:
            OSError: If OpenCV cannot read the file.
        """
        bgr = cv2.imread(str(path))
        if bgr is None:  # imread signals failure by returning None, not by raising
            raise OSError(f"cv2.imread could not read frame: {path}")
        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
        if rgb.shape[:2] != (self.image_size, self.image_size):
            rgb = cv2.resize(rgb, (self.image_size, self.image_size), interpolation=cv2.INTER_LINEAR)
        rgb = (rgb - IMAGENET_MEAN) / IMAGENET_STD
        return torch.from_numpy(rgb.transpose(2, 0, 1)).float()

    def __getitem__(self, idx):
        """Build the sample for ``self.samples[idx] == (d, t)``.

        Returns a dict with ``rgb`` (frame t), ``past_eef_world`` (n_past, 3),
        ``current_eef_world`` (3,), ``target_eef_world`` (t_future, 3),
        ``target_grip`` (t_future,), ``target_rot_euler`` (t_future, 3),
        ``target_voxel_idx`` (t_future,) long, ``valid_mask`` (t_future,) bool,
        ``world_to_camera`` (4, 4), and scalar ``demo_idx`` / ``start_t``.
        """
        d, t = self.samples[idx]
        demo = self.demos[d]
        last_real = demo["T"] - 1
        s = self.s

        # Past EEF indices [t-(N-1), ..., t]; clamped at 0 so the earliest frame repeats.
        past_idx = np.clip(np.arange(t - self.n_past + 1, t + 1), 0, None)
        past_pos = demo["eef_pos"][past_idx].astype(np.float32)              # (N, 3)
        current_pos = past_pos[-1]                                            # (3,)

        # Future indices [t+s, t+2s, ..., t+Tf*s]; clamped at T-1 and flagged invalid.
        future_idx_raw = t + s * np.arange(1, self.t_future + 1)
        future_idx = np.minimum(future_idx_raw, last_real)
        valid_mask = torch.from_numpy(future_idx_raw <= last_real)           # (Tf,) bool

        target_pos = demo["eef_pos"][future_idx].astype(np.float32)          # (Tf, 3)
        target_quat = demo["eef_quat"][future_idx].astype(np.float32)        # (Tf, 4)
        # Per the original author: robosuite obs["robot0_eef_quat"] is stored here as
        # WXYZ (despite some 'xyzw' comments in the cache code).
        # NOTE(review): confirm the storage order against the cache writer.
        target_eul = self._quat_wxyz_to_euler(target_quat)                   # (Tf, 3)
        target_grp = demo["gripper"][future_idx].astype(np.float32)          # (Tf,)

        target_voxel_idx = world_to_voxel_idx(torch.from_numpy(target_pos)).long()  # (Tf,)

        rgb = self._load_frame(demo["frame_paths"][t])

        return {
            "rgb":               rgb,
            "past_eef_world":    torch.from_numpy(past_pos),                  # (N, 3)
            "current_eef_world": torch.from_numpy(current_pos),               # (3,)
            "target_eef_world":  torch.from_numpy(target_pos),                # (Tf, 3)
            "target_grip":       torch.from_numpy(target_grp),                # (Tf,)
            "target_rot_euler":  torch.from_numpy(target_eul),                # (Tf, 3)
            "target_voxel_idx":  target_voxel_idx,                            # (Tf,)
            "valid_mask":        valid_mask,                                  # (Tf,)
            # astype copies, so callers never alias the dataset's stored matrix.
            "world_to_camera":   torch.from_numpy(demo["world_to_cam"].astype(np.float32)),  # (4, 4)
            "demo_idx":          torch.tensor(d, dtype=torch.long),
            "start_t":           torch.tensor(t, dtype=torch.long),
        }
