"""History-window dataset for the autoregressive transformer policy.

Wraps the per-demo cached arrays produced by prerender_dataset.py and yields
8-frame (or N-frame) history windows: past H frames + their EEF pixel coords,
plus the next-step EEF pixel as prediction target.

Designed to be a drop-in alternative to CachedTrajectoryDataset for train_ar.py.
"""
from pathlib import Path

import cv2
import numpy as np
import torch
from torch.utils.data import Dataset
from scipy.spatial.transform import Rotation as ScipyR


# Per-channel (R, G, B) mean / std used to normalize frames after scaling to [0, 1].
IMAGENET_MEAN = np.asarray((0.485, 0.456, 0.406), dtype=np.float32)
IMAGENET_STD = np.asarray((0.229, 0.224, 0.225), dtype=np.float32)


class HistoryTrajectoryDataset(Dataset):
    """Yields (history_imgs[H,3,S,S], history_eef_xy[H,2], target_eef_xy[2]) windows.

    H is ``history_len`` and S is ``image_size``.

    Layout under cache_root mirrors CachedTrajectoryDataset (data.py):
        <benchmark>/task_<id>/demo_<idx>/
            frames/000000.png ...
            pix_uv.npy        (T, 2)  EEF pixel coords in image-pixel space
            eef_pos.npy       (T, 3)  — kept for stats, not required here
            ...

    Sample index = (demo_idx, t) where t is the frame whose NEXT step is the prediction
    target. History indices are t - (H-1)*stride, ..., t - stride, t, left-padded by
    repeating frame 0 when they would fall before the demo start. The target index is
    t + stride, clamped to the last frame of the demo.
    """

    def __init__(
        self,
        cache_root,
        benchmark_name="libero_spatial",
        task_ids=None,        # iterable of ints, or None = all found
        image_size=448,
        history_len=8,
        frame_stride=1,       # AR mode wants dense history; default to 1 (every frame)
        max_demos=0,          # 0 = all
    ):
        """Scan the cache and index every (demo, t) sample.

        Args:
            cache_root: Root directory of the prerendered cache.
            benchmark_name: Sub-directory of ``cache_root`` containing ``task_*`` dirs.
            task_ids: Integer task ids to keep, or None to keep every task found.
            image_size: Side length frames are resized to; also the upper clamp
                bound (``image_size - 1``) applied to pixel coordinates.
            history_len: Number of history frames per sample; must be >= 1.
            frame_stride: Frame spacing inside the history and to the target; must be >= 1.
            max_demos: Per-task cap on the number of demos loaded; 0 disables the cap.

        Raises:
            ValueError: If ``history_len`` or ``frame_stride`` is < 1, or a demo's
                ``pix_uv.npy`` has fewer rows than the demo has frames.
            FileNotFoundError: If ``cache_root/benchmark_name`` does not exist.
        """
        if history_len < 1:
            raise ValueError(f"history_len must be >= 1, got {history_len}")
        if frame_stride < 1:
            raise ValueError(f"frame_stride must be >= 1, got {frame_stride}")

        self.image_size   = image_size
        self.history_len  = history_len
        self.frame_stride = frame_stride

        bench_root = Path(cache_root) / benchmark_name
        if not bench_root.exists():
            raise FileNotFoundError(f"Cache not found: {bench_root}")

        task_dirs = sorted(bench_root.glob("task_*"))
        if task_ids is not None:
            wanted = set(task_ids)  # O(1) membership; also safe for one-shot iterables
            task_dirs = [d for d in task_dirs if int(d.name.split("_")[1]) in wanted]

        self.demos = []
        self.samples = []
        for task_dir in task_dirs:
            task_demo_count = 0  # counts only demos actually loaded for this task
            for demo_dir in sorted(task_dir.glob("demo_*")):
                if max_demos > 0 and task_demo_count >= max_demos:
                    break
                frames_dir = demo_dir / "frames"
                if not frames_dir.exists():
                    continue
                frame_paths = sorted(frames_dir.glob("*.png"))
                T = len(frame_paths)
                if T < 2:
                    continue  # need at least one history frame + a target

                pix_uv = np.load(demo_dir / "pix_uv.npy")
                # Fail here rather than with an IndexError in the middle of an epoch.
                if len(pix_uv) < T:
                    raise ValueError(
                        f"{demo_dir}: pix_uv.npy has {len(pix_uv)} rows but {T} frames were found"
                    )

                demo_idx = len(self.demos)
                self.demos.append({"frame_paths": frame_paths, "pix_uv": pix_uv, "T": T})
                # Sample range: t from 0 to T-2 (inclusive). target = t + frame_stride (clamped).
                self.samples.extend((demo_idx, t) for t in range(T - 1))
                task_demo_count += 1

        print(f"HistoryTrajectoryDataset: {len(self.demos)} demos, {len(self.samples)} samples "
              f"(H={history_len}, stride={frame_stride})")

    def __len__(self):
        """Return the number of indexed (demo, t) samples."""
        return len(self.samples)

    def _load_frame(self, path):
        """Read one frame and return a normalized float tensor of shape (3, S, S).

        Raises:
            OSError: If OpenCV cannot read or decode ``path``.
        """
        bgr = cv2.imread(str(path))
        if bgr is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise OSError(f"Failed to read frame image (missing or undecodable): {path}")
        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
        if rgb.shape[0] != self.image_size or rgb.shape[1] != self.image_size:
            rgb = cv2.resize(rgb, (self.image_size, self.image_size), interpolation=cv2.INTER_LINEAR)
        # Normalize per ImageNet stats
        rgb = (rgb - IMAGENET_MEAN) / IMAGENET_STD
        return torch.from_numpy(rgb).permute(2, 0, 1).float()  # (3, H, W)

    def __getitem__(self, idx):
        """Build the sample dict for ``self.samples[idx]``.

        Returns:
            dict with keys ``history_imgs`` (H, 3, S, S), ``history_eef_xy`` (H, 2),
            ``target_eef_xy`` (2,), ``demo_idx`` and ``start_t`` (scalar long tensors).
        """
        demo_idx, t = self.samples[idx]
        demo = self.demos[demo_idx]
        T = demo["T"]
        H = self.history_len
        s = self.frame_stride

        # History frame indices: [t - (H-1)*s, ..., t - s, t] — left-padded by clamping at 0.
        hist_idx = [max(0, t - (H - 1 - k) * s) for k in range(H)]
        # Target index: t + s (clamped at T-1).
        tgt_idx = min(t + s, T - 1)

        imgs = torch.stack([self._load_frame(demo["frame_paths"][i]) for i in hist_idx], dim=0)  # (H, 3, S, S)
        eef_xy = demo["pix_uv"][hist_idx].astype(np.float32)      # (H, 2)
        target_xy = demo["pix_uv"][tgt_idx].astype(np.float32)    # (2,)

        # Clamp pixel coords to [0, image_size-1] just in case the prerender wrote something OOB
        upper = self.image_size - 1
        eef_xy = np.clip(eef_xy, 0, upper)
        target_xy = np.clip(target_xy, 0, upper)

        return {
            "history_imgs":   imgs,                                  # (H, 3, S, S) — normalized
            "history_eef_xy": torch.from_numpy(eef_xy).float(),       # (H, 2)
            "target_eef_xy":  torch.from_numpy(target_xy).float(),    # (2,)
            "demo_idx":       torch.tensor(demo_idx, dtype=torch.long),
            "start_t":        torch.tensor(t, dtype=torch.long),
        }


class WindowTrajectoryDataset(Dataset):
    """Yields W consecutive frames + their EEF pixel coords, for multi-target AR training.

    For a W-frame window, the AR model with attention context H predicts the EEF at each step
    t in [H, W-1] using frames[t-H:t]. So one window contributes (W - H) supervision targets,
    all sharing one DINO forward.

    Sample = (demo_idx, start_t). The window covers frame indices
    start_t, start_t + stride, ..., start_t + (W-1)*stride. Indices past the end of the demo
    are clamped to the last frame (i.e. the window is padded at its end by repeating the
    last frame) and flagged False in ``valid_mask``.
    """

    # Per-frame arrays loaded from <demo_dir>/<name>.npy; each must cover every frame.
    _PER_FRAME_ARRAYS = ("pix_uv", "eef_pos", "eef_quat", "gripper")
    # Per-demo camera arrays loaded from <demo_dir>/<name>.npy.
    _CAMERA_ARRAYS = ("cam_extrinsic", "cam_K_norm")

    def __init__(
        self,
        cache_root,
        benchmark_name="libero_spatial",
        task_ids=None,
        image_size=448,
        window_len=20,
        frame_stride=1,
        max_demos=0,
    ):
        """Scan the cache and index every (demo, start_t) window.

        Args:
            cache_root: Root directory of the prerendered cache.
            benchmark_name: Sub-directory of ``cache_root`` containing ``task_*`` dirs.
            task_ids: Integer task ids to keep, or None to keep every task found.
            image_size: Side length frames are resized to; also scales ``cam_K_norm``
                and bounds the clamp applied to pixel coordinates.
            window_len: Number of frames W per window; must be >= 1.
            frame_stride: Frame spacing inside a window; must be >= 1.
            max_demos: Per-task cap on the number of demos loaded; 0 disables the cap.

        Raises:
            ValueError: If ``window_len`` or ``frame_stride`` is < 1, or a per-frame
                array has fewer rows than the demo has frames.
            FileNotFoundError: If ``cache_root/benchmark_name`` does not exist.
        """
        if window_len < 1:
            raise ValueError(f"window_len must be >= 1, got {window_len}")
        if frame_stride < 1:
            raise ValueError(f"frame_stride must be >= 1, got {frame_stride}")

        self.image_size   = image_size
        self.window_len   = window_len
        self.frame_stride = frame_stride

        bench_root = Path(cache_root) / benchmark_name
        if not bench_root.exists():
            raise FileNotFoundError(f"Cache not found: {bench_root}")

        task_dirs = sorted(bench_root.glob("task_*"))
        if task_ids is not None:
            wanted = set(task_ids)  # O(1) membership; also safe for one-shot iterables
            task_dirs = [d for d in task_dirs if int(d.name.split("_")[1]) in wanted]

        self.demos = []
        self.samples = []
        for task_dir in task_dirs:
            task_demo_count = 0  # counts only demos actually loaded for this task
            for demo_dir in sorted(task_dir.glob("demo_*")):
                if max_demos > 0 and task_demo_count >= max_demos:
                    break
                frames_dir = demo_dir / "frames"
                if not frames_dir.exists():
                    continue
                frame_paths = sorted(frames_dir.glob("*.png"))
                T = len(frame_paths)
                if T < 2:
                    continue

                demo = {"frame_paths": frame_paths, "T": T}
                for name in self._PER_FRAME_ARRAYS:
                    arr = np.load(demo_dir / f"{name}.npy")
                    # Fail here rather than with an IndexError in the middle of an epoch.
                    if len(arr) < T:
                        raise ValueError(
                            f"{demo_dir}: {name}.npy has {len(arr)} rows but {T} frames were found"
                        )
                    demo[name] = arr
                for name in self._CAMERA_ARRAYS:
                    demo[name] = np.load(demo_dir / f"{name}.npy")

                demo_idx = len(self.demos)
                self.demos.append(demo)
                # One window per start frame; every window therefore has at least one real frame.
                self.samples.extend((demo_idx, start_t) for start_t in range(T))
                task_demo_count += 1

        print(f"WindowTrajectoryDataset: {len(self.demos)} demos, {len(self.samples)} samples "
              f"(W={window_len}, stride={frame_stride})")

    def __len__(self):
        """Return the number of indexed (demo, start_t) windows."""
        return len(self.samples)

    @staticmethod
    def _window_indices(start_t, num_frames, window_len, stride):
        """Return (frame indices clamped to the last frame, per-slot validity flags).

        A slot is valid iff its unclamped index ``start_t + k * stride`` lies inside the
        demo, so every slot that had to be clamped is reported as padding.
        """
        last = num_frames - 1
        raw = [start_t + k * stride for k in range(window_len)]
        return [min(last, i) for i in raw], [i <= last for i in raw]

    def _load_frame(self, path):
        """Read one frame and return a normalized float tensor of shape (3, S, S).

        Raises:
            OSError: If OpenCV cannot read or decode ``path``.
        """
        bgr = cv2.imread(str(path))
        if bgr is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise OSError(f"Failed to read frame image (missing or undecodable): {path}")
        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
        if rgb.shape[0] != self.image_size or rgb.shape[1] != self.image_size:
            rgb = cv2.resize(rgb, (self.image_size, self.image_size), interpolation=cv2.INTER_LINEAR)
        rgb = (rgb - IMAGENET_MEAN) / IMAGENET_STD
        return torch.from_numpy(rgb).permute(2, 0, 1).float()

    def __getitem__(self, idx):
        """Build the window dict for ``self.samples[idx]``.

        Returns:
            dict of tensors keyed ``window_imgs``, ``window_eef_xy``, ``window_eef_pos``,
            ``window_eef_quat``, ``window_eef_euler``, ``window_gripper``,
            ``window_eef_start``, ``valid_mask``, ``demo_idx``, ``start_t``, plus
            ``cam_K`` / ``cam_extrinsic`` when the demo record carries them.
        """
        demo_idx, start_t = self.samples[idx]
        demo = self.demos[demo_idx]

        # Window indices [start_t, start_t+s, ..., start_t+(W-1)*s], clamped at T-1.
        win_idx, valid = self._window_indices(start_t, demo["T"], self.window_len, self.frame_stride)

        imgs = torch.stack([self._load_frame(demo["frame_paths"][i]) for i in win_idx], dim=0)  # (W, 3, S, S)
        eef_xy   = demo["pix_uv"][win_idx].astype(np.float32)     # (W, 2)
        eef_pos  = demo["eef_pos"][win_idx].astype(np.float32)    # (W, 3)
        eef_quat = demo["eef_quat"][win_idx].astype(np.float32)   # (W, 4)
        gripper  = demo["gripper"][win_idx].astype(np.float32)    # (W,)
        eef_xy = np.clip(eef_xy, 0, self.image_size - 1)

        # Quat → euler XYZ for rotation supervision
        try:
            eef_euler = ScipyR.from_quat(eef_quat).as_euler("xyz").astype(np.float32)
        except ValueError:
            # Best-effort: quaternions SciPy rejects yield an all-zero rotation target.
            eef_euler = np.zeros_like(eef_pos)

        # valid_mask[k] = True iff slot k is NOT a clamp-pad of the demo end.
        valid_mask = torch.tensor(valid, dtype=torch.bool)

        # eef_start_xyz_world = EEF position at the FIRST frame of this window.
        # Used by voxel variant C as the anchor for relative-PE per Cameron's spec.
        eef_start_xyz = eef_pos[0].copy()

        out = {
            "window_imgs":      imgs,                                         # (W, 3, S, S) normalized
            "window_eef_xy":    torch.from_numpy(eef_xy).float(),              # (W, 2) image-pixel
            "window_eef_pos":   torch.from_numpy(eef_pos).float(),              # (W, 3) world-frame
            "window_eef_quat":  torch.from_numpy(eef_quat).float(),             # (W, 4) xyzw
            "window_eef_euler": torch.from_numpy(eef_euler).float(),            # (W, 3) euler xyz (rad)
            "window_gripper":   torch.from_numpy(gripper).float(),              # (W,) -1..+1
            "window_eef_start": torch.from_numpy(eef_start_xyz).float(),        # (3,) anchor for variant C
            "valid_mask":       valid_mask,                                     # (W,)
            "demo_idx":         torch.tensor(demo_idx, dtype=torch.long),
            "start_t":          torch.tensor(start_t, dtype=torch.long),
        }

        # Camera entries are emitted only when present in the demo record.
        cam_K_norm = demo.get("cam_K_norm")
        if cam_K_norm is not None:
            cam_K = cam_K_norm.copy()  # copy so the cached array is not rescaled in place
            cam_K[0] *= self.image_size
            cam_K[1] *= self.image_size
            out["cam_K"] = torch.from_numpy(cam_K).float()                      # (3, 3) pixel intrinsics
        cam_extrinsic = demo.get("cam_extrinsic")
        if cam_extrinsic is not None:
            out["cam_extrinsic"] = torch.from_numpy(cam_extrinsic).float()       # (4, 4) camera→world
        return out


def target_xy_to_grid_idx(target_xy, image_size, grid_size):
    """Map pixel coordinates (..., 2) to flat row-major cell indices in [0, grid_size**2).

    The image is divided into ``grid_size`` x ``grid_size`` square cells; coordinates that
    fall outside the image are assigned to the nearest border cell.
    """
    cell_size = image_size / grid_size
    max_cell = grid_size - 1
    # Truncate to an integer cell coordinate, then clamp out-of-range cells to the border.
    cells = (target_xy[..., :2] / cell_size).long().clamp(0, max_cell)
    col = cells[..., 0]
    row = cells[..., 1]
    return row * grid_size + col


def grid_idx_to_pixel(idx, image_size, grid_size):
    """Return the (x, y) pixel at the center of each flat row-major grid cell in ``idx``.

    Inverse of target_xy_to_grid_idx up to quantization: the result has a trailing
    dimension of size 2 appended to the shape of ``idx``.
    """
    cell_size = image_size / grid_size
    row = idx // grid_size
    col = idx % grid_size
    # +0.5 moves from the cell's top-left corner to its center, in cell units.
    center_x = (col.float() + 0.5) * cell_size
    center_y = (row.float() + 0.5) * cell_size
    return torch.stack((center_x, center_y), dim=-1)
