"""All-in-memory dataset for DA3 volume training (smith300).

Extends Smith300DA3Dataset to also store world-Z (EEF height) per frame and
compute height-bin discretization stats over the whole dataset.

Per sample:
  rgb:           (3, 504, 504) float32 in [0, 1]
  gt_pix_504:    (N_WINDOW, 2) GT EEF pixel coords in 504-space
  gt_pix_valid:  (N_WINDOW,) bool — False for clamped (off-episode) future steps
  gt_z_bin:      (N_WINDOW,) long — height bin index in [0, N_HEIGHT_BINS-1]
  da3_depth:     (504, 504) float32 — frozen DA3 depth for distillation
"""
import os, sys, json
from pathlib import Path

import cv2
import numpy as np
import torch
from torch.utils.data import Dataset
from scipy.spatial.transform import Rotation as ScipyR

sys.path.insert(0, "/data/cameron/para/para_mac")
from data_smith300_para import EEF_BODY_NAME, _scale_K_to
DEFAULT_SMITH300_XML = "/data/cameron/para/libero/example_twolink.xml"
import mujoco

DA3_INPUT = 504        # default square input side (pixels): frames, depth and pixel targets
N_WINDOW  = 8          # default number of future steps per sample
N_HEIGHT_BINS = 32     # default number of EEF world-Z (height) bins
N_ROT_BINS    = 32     # rotation bins: per euler axis in per-axis mode, or along the
                       # principal axis in 1D-PCA mode (rot_bin_t is then (N,) not (N, 3)).
                       # If a `rot_kmeans_path` is supplied it takes precedence over
                       # `rot_pca_path`: bins are then the K k-means clusters and this
                       # constant is not used for binning.
                       # Cameron 2026-05-20: 1D PCA preferred at deploy (per-axis joint argmax
                       # produces OOD axis combinations); 1D bounds output to a manifold
                       # through training data even if it loses some variance.
N_GRIPPER_BINS = 32    # number of gripper-value bins


def project(world_pts, world_to_camera, K):
    """Pinhole-project world points to pixel coordinates.

    Args:
        world_pts: (N, 3) points in the world frame.
        world_to_camera: (4, 4) homogeneous world->camera transform.
        K: (3, 3) camera intrinsics.

    Returns:
        (pix, depth): pix is (N, 2) pixel coordinates; depth is the (N,) raw
        camera-frame Z. Only the division uses Z clipped to >= 1e-3, so points
        behind the camera still get finite (but meaningless) pixels. Callers
        can detect them via depth <= 0.
    """
    n_pts = world_pts.shape[0]
    homog_world = np.hstack([world_pts, np.ones((n_pts, 1))])          # (N, 4)
    cam_xyz = (world_to_camera @ homog_world.T).T[:, :3]               # (N, 3) camera frame
    depth = cam_xyz[:, 2]
    safe_depth = np.clip(depth, 1e-3, None)
    normalized = np.hstack([cam_xyz[:, :2] / safe_depth[:, None],
                            np.ones((n_pts, 1))])                      # (N, 3) on z=1 plane
    pix = (K @ normalized.T).T[:, :2]
    return pix, depth


class Smith300DA3VolumeDataset(Dataset):
    """In-memory DA3 volume-training dataset over smith300 / UMI sessions.

    Variant of Smith300DA3Dataset that additionally stores, per frame, the world-Z
    EEF height, the EEF euler-xyz rotation and the gripper value, and discretizes
    each of them using statistics computed over the whole loaded dataset:

    * height   -> ``z_bin_t``    (N,) long in [0, n_height_bins - 1]
    * gripper  -> ``grip_bin_t`` (N,) long in [0, N_GRIPPER_BINS - 1]
    * rotation -> ``rot_bin_t``, whose mode is chosen in this priority order:
        - ``rot_kmeans_path`` exists: 'kmeans'   — nearest quaternion centroid, (N,)
        - ``rot_pca_path`` exists:    '1d_pca'   — principal-axis projection,   (N,)
        - otherwise:                  'per_axis' — independent euler bins,      (N, 3)

    Each sample is a current frame plus the next ``n_window`` frames (every
    ``frame_stride``-th loaded frame) of the same episode; see ``__getitem__``.
    """

    def __init__(self, root_dir="/data/cameron/mac_robot_datasets/first_mobile_collection",
                 image_size=DA3_INPUT, n_window=N_WINDOW, frame_stride=1,
                 mujoco_xml=DEFAULT_SMITH300_XML, depth_subdir="da3_depth_large",
                 n_height_bins=N_HEIGHT_BINS, height_pad_frac=0.05,
                 sessions_whitelist=None,
                 rot_pca_path=None,
                 rot_kmeans_path=None):
        """Load every session under ``root_dir`` into memory and compute bin stats.

        Args:
            root_dir: directory whose immediate subdirectories are sessions.
            image_size: side length frames are resized to (square); K is rescaled to match.
            n_window: number of future steps returned per sample.
            frame_stride: step between consecutive future frames, counted in the
                episode's list of loaded frames.
            mujoco_xml: MJCF model used for FK when joints.npz lacks eef_pos/eef_quat.
            depth_subdir: per-session subdir holding ``rgb_XXXXXX.npy`` depth maps;
                frames without one get an all-zero depth map.
            n_height_bins: number of height bins.
            height_pad_frac: fraction of the observed height range added on each side.
            sessions_whitelist: optional iterable of session directory names to keep.
            rot_pca_path: optional .npz with the 1D-PCA rotation basis.
            rot_kmeans_path: optional .npz with k-means rotation centroids (wins over PCA).

        Frames are dropped when the projected EEF pixel falls outside the image or
        the RGB jpg is missing/unreadable. Episodes with fewer than 2 kept frames
        are dropped.

        Raises:
            FileNotFoundError: no session directories (after whitelist filtering).
            ValueError: no frame survived filtering, or MuJoCo FK is required but
                EEF_BODY_NAME is not a body of ``mujoco_xml``.
        """
        self.sessions_whitelist = set(sessions_whitelist) if sessions_whitelist else None
        self.rot_pca_path = rot_pca_path
        self.rot_kmeans_path = rot_kmeans_path
        self.image_size   = image_size
        self.n_window     = n_window
        self.s            = frame_stride
        self.depth_subdir = depth_subdir
        self.n_height_bins = n_height_bins

        root = Path(root_dir)
        sessions = sorted(d for d in root.iterdir() if d.is_dir())
        if self.sessions_whitelist:
            sessions = [s for s in sessions if s.name in self.sessions_whitelist]
        if not sessions:
            raise FileNotFoundError(f"No sessions under {root}"
                                    + (f" matching whitelist {self.sessions_whitelist}"
                                       if self.sessions_whitelist else ""))
        print(f"Smith300DA3VolumeDataset: loading {len(sessions)} sessions"
              + (" (whitelisted)" if self.sessions_whitelist else ""))

        mj_model = mujoco.MjModel.from_xml_path(mujoco_xml)
        mj_data  = mujoco.MjData(mj_model)
        # -1 if the body is absent; only an error for sessions that need FK.
        eef_id   = mujoco.mj_name2id(mj_model, mujoco.mjtObj.mjOBJ_BODY, EEF_BODY_NAME)

        # Per-frame buffers: Python lists while loading, stacked into tensors below.
        self.episodes    = []   # [{"frames": (F,) int64 global frame indices}]
        self.rgb_t       = []   # (3, image_size, image_size) float32 in [0, 1]
        self.depth_t     = []   # depth map from depth_subdir, or zeros (image_size, image_size)
        self.pix_t       = []   # (2,) float32 EEF pixel in image_size space
        self.eef_z_t     = []   # world-Z (height) per frame
        self.eef_euler_t = []   # (3,) euler xyz per frame
        self.gripper_t   = []   # scalar gripper (q_motors[6]) per frame
        self.session_idx = []

        for sess_idx, sess in enumerate(sessions):
            n_loaded = self._load_session(sess, sess_idx, mj_model, mj_data, eef_id)
            if n_loaded is not None:
                print(f"  {sess.name}: {n_loaded} frames loaded")

        if not self.rgb_t:
            # np.stack below would otherwise fail with an opaque "need at least one array".
            raise ValueError(f"Smith300DA3VolumeDataset: no usable frames found under {root}")

        self.rgb_t      = torch.from_numpy(np.stack(self.rgb_t,    axis=0))
        self.depth_t    = torch.from_numpy(np.stack(self.depth_t,  axis=0))
        self.pix_t      = torch.from_numpy(np.stack(self.pix_t,    axis=0))
        self.eef_z_t    = torch.tensor(self.eef_z_t, dtype=torch.float32)
        self.eef_euler_t = torch.from_numpy(np.stack(self.eef_euler_t, axis=0)).float()  # (N, 3)
        self.gripper_t  = torch.tensor(self.gripper_t, dtype=torch.float32)               # (N,)

        # Gripper stats (pad ±5%; fixed ±0.05 when all values are identical).
        g_lo, g_hi = float(self.gripper_t.min()), float(self.gripper_t.max())
        g_pad = (g_hi - g_lo) * 0.05 if g_hi > g_lo else 0.05
        self.min_grip = g_lo - g_pad
        self.max_grip = g_hi + g_pad
        self.n_rot_bins     = N_ROT_BINS
        self.n_gripper_bins = N_GRIPPER_BINS

        self._init_rotation_bins()
        self.grip_bin_t = self._bin_gripper(self.gripper_t)        # (N,)   long

        # Height stats — pad range by `height_pad_frac` to keep edges from saturating.
        z_lo, z_hi = float(self.eef_z_t.min()), float(self.eef_z_t.max())
        z_range = z_hi - z_lo
        self.min_height = z_lo - z_range * height_pad_frac
        self.max_height = z_hi + z_range * height_pad_frac
        self.z_bin_t = self._bin_height(self.eef_z_t)  # (N,) long in [0, B-1]
        # bin centers in world units (for inference 3D recovery)
        bin_w = (self.max_height - self.min_height) / self.n_height_bins
        self.bin_centers = self.min_height + (torch.arange(self.n_height_bins).float() + 0.5) * bin_w

        # One sample per (episode, start position); the last frame has no future.
        self.samples = [(ep_idx, t)
                        for ep_idx, ep in enumerate(self.episodes)
                        for t in range(len(ep["frames"]) - 1)]

        n = len(self.rgb_t)
        gb_rgb = self.rgb_t.element_size() * self.rgb_t.numel() / 1e9
        gb_d   = self.depth_t.element_size() * self.depth_t.numel() / 1e9
        print(f"Smith300DA3VolumeDataset ready: {len(self.episodes)} eps, {n} frames, "
              f"{len(self.samples)} samples, rgb={gb_rgb:.2f} GB, depth={gb_d:.2f} GB")
        print(f"  Height range observed: [{z_lo:.4f}, {z_hi:.4f}] → padded "
              f"[{self.min_height:.4f}, {self.max_height:.4f}], {n_height_bins} bins of "
              f"~{bin_w*1000:.1f}mm each")
        # Bin occupancy summary
        with torch.no_grad():
            counts = torch.bincount(self.z_bin_t, minlength=n_height_bins).tolist()
        print(f"  Z-bin occupancy: min={min(counts)}, max={max(counts)}, "
              f"empty_bins={sum(1 for c in counts if c == 0)}")

    def _load_session(self, sess, sess_idx, mj_model, mj_data, eef_id):
        """Append one session's usable frames and episodes to the in-memory buffers.

        Returns the number of frames appended, or None when the session is skipped
        because ``meta.json`` or an ``episodes.json`` is missing.
        """
        image_size = self.image_size
        meta_path = sess / "meta.json"
        if not meta_path.exists():
            return None
        with open(meta_path) as fh:
            meta = json.load(fh)
        IMG_W, IMG_H = meta["image_size_wh"]
        K_orig = np.array(meta["K"], dtype=np.float64)
        T_camera_arucoBase = np.array(meta["T_camera_arucoBase"], dtype=np.float64)
        T_W_baseBody       = np.array(meta["T_W_baseBody_inv_aruco_offset"], dtype=np.float64)
        T_CAM_WORLD = T_camera_arucoBase @ T_W_baseBody
        K_target = _scale_K_to((IMG_W, IMG_H), K_orig, image_size)

        ep_path = next((sess / p for p in ["rgb_overlay/episodes.json", "episodes.json"]
                        if (sess / p).exists()), None)
        if ep_path is None:
            return None
        with open(ep_path) as fh:
            sess_episodes = json.load(fh)["episodes"]

        # NpzFile items are read into memory on access, so the file can be closed here.
        with np.load(sess / "joints.npz") as joints:
            q_motors_all = np.asarray(joints["q_motors"], dtype=np.float64)
            # UMI data (ArUco-tracked handheld gripper) has eef_pos / eef_quat directly in
            # joints.npz; smith300 data only has q_motors and needs MuJoCo FK. Prefer the
            # baked-in values when present.
            joint_eef_pos_all = (np.asarray(joints["eef_pos"], dtype=np.float64)
                                 if "eef_pos" in joints.files else None)
            joint_eef_quat_all = (np.asarray(joints["eef_quat"], dtype=np.float64)
                                  if "eef_quat" in joints.files else None)
        n_motors = q_motors_all.shape[1]
        n_frames = q_motors_all.shape[0]
        use_baked_eef = joint_eef_pos_all is not None and joint_eef_quat_all is not None

        def episode_frames(ep):
            # Inclusive [start, end], truncated to the frames present in joints.npz.
            return range(int(ep["start"]), min(int(ep["end"]), n_frames - 1) + 1)

        needed = set()
        for ep in sess_episodes:
            needed.update(episode_frames(ep))

        if needed and not use_baked_eef and eef_id < 0:
            # Otherwise xpos[-1] would silently return the last body of the model.
            raise ValueError(f"EEF body {EEF_BODY_NAME!r} not found in the MuJoCo model; "
                             f"cannot run FK for session {sess.name}")

        n_qpos = mj_model.nq
        local_to_global = {}
        n_loaded = 0
        for f in sorted(needed):
            if use_baked_eef:
                # UMI-style: use baked-in eef from ArUco tracking (quat already xyzw)
                eef_pos = joint_eef_pos_all[f].copy()
                eef_quat_xyzw = joint_eef_quat_all[f].copy()
            else:
                # Smith300: MuJoCo FK from q_motors
                q = np.zeros(n_qpos, dtype=np.float64)
                q[:min(n_motors, n_qpos)] = q_motors_all[f, :n_qpos]
                mj_data.qpos[:n_qpos] = q
                mujoco.mj_forward(mj_model, mj_data)
                eef_pos = mj_data.xpos[eef_id].copy()
                quat_wxyz = mj_data.xquat[eef_id].copy()
                eef_quat_xyzw = quat_wxyz[[1, 2, 3, 0]]   # MuJoCo wxyz -> SciPy xyzw
            eef_euler = ScipyR.from_quat(eef_quat_xyzw).as_euler('xyz').astype(np.float32)
            gripper = float(q_motors_all[f, 6]) if n_motors >= 7 else 0.0
            pix, _ = project(eef_pos.reshape(1, 3), T_CAM_WORLD, K_target)
            pix = pix[0].astype(np.float32)
            if not (0 <= pix[0] < image_size and 0 <= pix[1] < image_size):
                continue   # EEF projects off-image: no usable pixel target

            img_path = sess / f"rgb_{f:06d}.jpg"
            if not img_path.exists():
                continue
            bgr = cv2.imread(str(img_path))
            if bgr is None:
                continue
            rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
            rgb = cv2.resize(rgb, (image_size, image_size), interpolation=cv2.INTER_LINEAR)
            rgb_chw = (rgb.astype(np.float32) / 255.0).transpose(2, 0, 1)
            # Depth is optional now — current dino_kv models don't use depth distillation.
            # New datasets can skip the precompute step entirely.
            depth_path = sess / self.depth_subdir / f"rgb_{f:06d}.npy"
            if depth_path.exists():
                depth = np.load(depth_path).astype(np.float32)
            else:
                depth = np.zeros((image_size, image_size), dtype=np.float32)

            local_to_global[f] = len(self.rgb_t)
            self.rgb_t.append(rgb_chw)
            self.depth_t.append(depth)
            self.pix_t.append(pix)
            self.eef_z_t.append(float(eef_pos[2]))
            self.eef_euler_t.append(eef_euler)
            self.gripper_t.append(gripper)
            self.session_idx.append(sess_idx)
            n_loaded += 1

        for ep in sess_episodes:
            fs = [local_to_global[f] for f in episode_frames(ep) if f in local_to_global]
            if len(fs) < 2:
                continue
            self.episodes.append({"frames": np.asarray(fs, dtype=np.int64)})
        return n_loaded

    def _init_rotation_bins(self):
        """Pick the rotation binning mode and set ``rot_bin_t`` plus mode metadata.

        The padded (±5%) per-axis euler range is always stored in
        ``min_rot`` / ``max_rot``, whichever mode is selected.
        """
        r_lo, r_hi = self.eef_euler_t.min(0).values, self.eef_euler_t.max(0).values
        r_pad = (r_hi - r_lo) * 0.05
        self.min_rot = (r_lo - r_pad).tolist()
        self.max_rot = (r_hi + r_pad).tolist()
        if self.rot_kmeans_path and os.path.exists(self.rot_kmeans_path):
            with np.load(self.rot_kmeans_path) as km:
                self.kmeans_centroids_quat  = km['centroids_quat']        # (K, 4)
                self.kmeans_centroids_euler = km['centroids_euler']       # (K, 3)
                self.kmeans_bin_counts      = km['bin_counts']            # (K,)
                self.kmeans_n_clusters      = int(km['n_clusters'])
            self.rotation_mode          = 'kmeans'
            self.rot_pca_mean = None; self.rot_pca_axis = None
            self.rot_pca_min  = 0.0;  self.rot_pca_max  = 1.0
            self.rot_pca_ev_ratio = 0.0
            self.rot_bin_t = self._bin_rotation_kmeans(self.eef_euler_t)  # (N,) long in [0, K)
            print(f"  Rotation: k-means mode (K={self.kmeans_n_clusters}), basis={self.rot_kmeans_path}, "
                  f"bin counts={self.kmeans_bin_counts.tolist()}")
        elif self.rot_pca_path and os.path.exists(self.rot_pca_path):
            with np.load(self.rot_pca_path) as pca:
                self.rot_pca_mean     = pca['mean']                # (3,)
                self.rot_pca_axis     = pca['principal_axis']      # (3,)
                self.rot_pca_min      = float(pca['pca_min'])
                self.rot_pca_max      = float(pca['pca_max'])
                self.rot_pca_ev_ratio = float(pca['ev_ratio_pc1'])
            self.rotation_mode    = '1d_pca'
            self.rot_bin_t        = self._bin_rotation_1d_pca(self.eef_euler_t)   # (N,) long
            self.kmeans_n_clusters = 0
            print(f"  Rotation: 1D PCA mode, basis={self.rot_pca_path} (EV ratio {self.rot_pca_ev_ratio:.3f})")
        else:
            self.rot_pca_mean = None; self.rot_pca_axis = None
            self.rot_pca_min  = 0.0;  self.rot_pca_max  = 1.0
            self.rot_pca_ev_ratio = 0.0
            self.rotation_mode    = 'per_axis'
            self.kmeans_n_clusters = 0
            self.rot_bin_t        = self._bin_rotation(self.eef_euler_t)          # (N, 3) long
            print("  Rotation: per-axis mode (no PCA path provided)")

    def _bin_height(self, z):
        """Map heights to bin indices in [0, n_height_bins - 1] over [min_height, max_height]."""
        z = z.float()
        norm = (z - self.min_height) / max(self.max_height - self.min_height, 1e-8)
        return (norm * self.n_height_bins).long().clamp(0, self.n_height_bins - 1)

    def _bin_rotation(self, eul):
        """eul: (N, 3). Returns (N, 3) long — per-axis euler bin indices.

        NOTE: scales by (n_rot_bins - 1), unlike the height/gripper binning, so the
        top bin is only reached at exactly max_rot. Kept as-is for compatibility
        with whatever decodes these bins.
        """
        eul = eul.float()
        lo = torch.tensor(self.min_rot, dtype=torch.float32)
        hi = torch.tensor(self.max_rot, dtype=torch.float32)
        norm = (eul - lo) / (hi - lo).clamp_min(1e-8)
        return (norm.clamp(0, 1) * (self.n_rot_bins - 1)).long()                # (N, 3)

    def _bin_rotation_kmeans(self, eul):
        """eul: (N, 3) euler xyz. Returns (N,) long — assign each sample to nearest
        centroid in canonical-quaternion space (sign flipped so w >= 0)."""
        eul_np = eul.numpy()
        quats = ScipyR.from_euler('xyz', eul_np).as_quat()                         # (N, 4) xyzw
        mask = quats[:, 3] < 0
        quats[mask] *= -1
        quats /= (np.linalg.norm(quats, axis=-1, keepdims=True) + 1e-12)
        # Pairwise L2 distance to all centroids → argmin
        centroids = self.kmeans_centroids_quat                                     # (K, 4)
        d2 = ((quats[:, None, :] - centroids[None, :, :]) ** 2).sum(-1)            # (N, K)
        bins = d2.argmin(axis=-1)
        return torch.tensor(bins, dtype=torch.long)

    def _bin_rotation_1d_pca(self, eul):
        """eul: (N, 3). Returns (N,) long — bin index along the PCA-1D axis.

        Uses the same (n_rot_bins - 1) scaling as ``_bin_rotation``.
        """
        eul = eul.float()
        mean = torch.tensor(self.rot_pca_mean, dtype=torch.float32)
        axis = torch.tensor(self.rot_pca_axis, dtype=torch.float32)
        proj = (eul - mean) @ axis                                              # (N,)
        norm = (proj - self.rot_pca_min) / max(self.rot_pca_max - self.rot_pca_min, 1e-8)
        return (norm.clamp(0, 1) * (self.n_rot_bins - 1)).long()

    def _bin_gripper(self, g):
        """Map gripper values to bin indices in [0, n_gripper_bins - 1] over [min_grip, max_grip]."""
        g = g.float()
        norm = (g - self.min_grip) / max(self.max_grip - self.min_grip, 1e-8)
        return (norm * self.n_gripper_bins).long().clamp(0, self.n_gripper_bins - 1)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return the current frame plus its next ``n_window`` targets.

        Future steps are every ``frame_stride``-th entry of the episode's loaded
        frames; steps past the episode end are clamped to its last frame and
        flagged False in ``gt_pix_valid``.
        """
        ep_idx, t = self.samples[idx]
        frames = self.episodes[ep_idx]["frames"]
        last_real = len(frames) - 1
        s = self.s
        cur_g = int(frames[t])
        future_local_raw = [t + (i + 1) * s for i in range(self.n_window)]
        # Clamp off-episode steps so the window shape stays fixed; mask them via `valid`.
        future_local = [min(i, last_real) for i in future_local_raw]
        valid = torch.tensor([raw <= last_real for raw in future_local_raw], dtype=torch.bool)
        future_global = frames[future_local]
        gt_pix      = self.pix_t[future_global]                    # (N_WINDOW, 2)
        gt_z_bin    = self.z_bin_t[future_global]                  # (N_WINDOW,)
        gt_rot_bin  = self.rot_bin_t[future_global]                # (N_WINDOW, 3) per-axis, else (N_WINDOW,)
        gt_grip_bin = self.grip_bin_t[future_global]               # (N_WINDOW,)
        rgb         = self.rgb_t[cur_g]
        depth       = self.depth_t[cur_g]
        start_pix   = self.pix_t[cur_g]
        # CURRENT frame's gripper/rotation/height bins — used as conditioning for the
        # gripper / rotation prediction heads (so they know the starting state).
        cur_grip_bin = self.grip_bin_t[cur_g]                       # scalar
        cur_rot_bin  = self.rot_bin_t[cur_g]                        # (3,) per-axis, else scalar
        cur_z_bin    = self.z_bin_t[cur_g]                          # scalar (current height)
        return {
            "rgb":           rgb,
            "gt_pix_504":    gt_pix,
            "gt_pix_valid":  valid,
            "gt_z_bin":      gt_z_bin,
            "gt_rot_bin":    gt_rot_bin,
            "gt_grip_bin":   gt_grip_bin,
            "cur_grip_bin":  cur_grip_bin,
            "cur_rot_bin":   cur_rot_bin,
            "cur_z_bin":     cur_z_bin,
            "da3_depth":     depth,
            "start_pix_504": start_pix,
            "ep_idx":        torch.tensor(ep_idx, dtype=torch.long),
            "start_t":       torch.tensor(t, dtype=torch.long),
        }


if __name__ == "__main__":
    ds = Smith300DA3VolumeDataset()
    s = ds[0]
    for k, v in s.items():
        if hasattr(v, 'shape'): print(f"  {k}: {tuple(v.shape)} {v.dtype}")
        else: print(f"  {k}: {v}")
