"""Smith300 arm dataset for PARA training.

Adapted from panda_streaming/data_panda_para.py for the smith300 capture
format:
  rgb_NNNNNN.jpg     # 960x540 (anisotropic resize to 448x448 for the model)
  joints.npz         # q_motors[T,M] (M=6 without gripper, 7 with) + ticks/timestamps;
                     # optionally eef_pos/eef_quat/eef_euler (UMI captures)
  meta.json          # K, T_camera_arucoBase, T_W_baseBody_inv_aruco_offset, image_size_wh
  rgb_overlay/episodes.json  # parsed episodes (or root episodes.json)

Key differences vs the panda version:
  - q_motors usually has 6 entries; the smith300 MuJoCo model has 7 hinge
    joints (6 arm + 1 gripper finger). Joints that aren't recorded (normally
    q[6], the gripper) are zero-padded before FK.
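  - World->camera transform is T_camera_arucoBase @ T_W_baseBody_inv_aruco_offset,
    applied right-to-left to a world point: first the aruco offset into the
    arucoBase frame, then the arucoBase->camera extrinsic. MuJoCo's "world"
    here is the base body.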
  - EEF body in the smith300 XML is "virtual_gripper_keypoint" (not "hand").
  - trajectory_gripper is q[6] when joints.npz records a 7th motor and 0
    everywhere otherwise; for those older datasets train_smith300_para.py sets
    GRIPPER_LOSS_WEIGHT=0 so the head is still wired but contributes no loss.
"""
import os
import json

import cv2
import numpy as np
import torch
from torch.utils.data import Dataset
from scipy.spatial.transform import Rotation as ScipyR
import mujoco

N_WINDOW: int = 6  # default window length (timesteps per sample)
N_MOTORS_FULL: int = 7  # smith300 model has 7 hinge joints
IMAGE_SIZE: int = 448  # default square side length of the model input
EEF_BODY_NAME: str = "virtual_gripper_keypoint"  # MuJoCo body used as the EEF
DEFAULT_SMITH300_XML: str = (
    "/home/cameronsmith/mnt/mac/smith300_para_stuff/"
    "example_twolink.xml"
)


def project_to_pixel(pos_world, T_cw, K):
    """Pinhole-project a 3D world point into pixel coordinates.

    Args:
        pos_world: 3-vector in the world frame.
        T_cw: 4x4 world->camera transform (rotation block + translation).
        K: 3x3 intrinsics; only fx, fy, cx, cy are used (skew ignored).

    Returns:
        float32 array ``[u, v]``, or None when the point is not strictly in
        front of the camera (camera-frame z <= 0).
    """
    rotation, translation = T_cw[:3, :3], T_cw[:3, 3]
    x_c, y_c, z_c = rotation @ pos_world + translation
    if z_c <= 0:
        return None
    fx, fy = K[0, 0], K[1, 1]
    cx, cy = K[0, 2], K[1, 2]
    return np.array([fx * x_c / z_c + cx, fy * y_c / z_c + cy], dtype=np.float32)


# All five non-identity permutations of the RGB channels, one per row, for
# the channel-swap step of the color augmentation. Indexing an (H, W, 3)
# image as img[..., row] puts old channel row[c] into new channel c.
_CHAN_PERMS = np.array(
    [
        [0, 2, 1],  # -> R, B, G
        [1, 0, 2],  # -> G, R, B
        [1, 2, 0],  # -> G, B, R
        [2, 0, 1],  # -> B, R, G
        [2, 1, 0],  # -> B, G, R
    ],
    dtype=np.int64,
)


def _augment_color(rgb_hwc: np.ndarray, rng: np.random.Generator) -> np.ndarray:
    """Apply strong, randomized color augmentation to an RGB image.

    Intended to stop the model from keying on gripper color when training on
    one rig (UMI: green gripper) and deploying on another (smith300: white
    gripper). Stages, each gated by its own coin flip, run in this order:

      1. Channel permutation, p=0.8: a random non-identity row of _CHAN_PERMS.
      2. HSV jitter. Hue (p=0.7) is shifted by an integer in [-90, 90] OpenCV
         units (hue range 0..180, so any hue is reachable). Saturation (p=0.7)
         is scaled by U(0, 1.8) and clipped to [0, 255]. This pass round-trips
         through uint8 by truncation.
      3. Grayscale, p=0.3: BT.601 luma (0.299, 0.587, 0.114) copied into all
         three channels.
      4. Brightness, p=0.6: scale by U(0.6, 1.4), then clip to [0, 1].

    Args:
        rgb_hwc: (H, W, 3) float RGB image in [0, 1]. It is never modified in
            place; when no stage fires, the same object is returned.
        rng: source of all randomness. The draw order is fixed, so equal
            generator states give equal results.

    Returns:
        The augmented (H, W, 3) image.
    """
    img = rgb_hwc

    # Stage 1: channel shuffle.
    if rng.random() < 0.8:
        choice = rng.integers(0, len(_CHAN_PERMS))
        img = img[..., _CHAN_PERMS[choice]]

    # Stage 2: both coin flips are drawn before either jitter value, so a
    # single HSV round-trip can serve both.
    jitter_hue, jitter_sat = (rng.random() < 0.7), (rng.random() < 0.7)
    if jitter_hue or jitter_sat:
        img_u8 = (img * 255).astype(np.uint8)
        hsv = cv2.cvtColor(img_u8, cv2.COLOR_RGB2HSV).astype(np.float32)
        if jitter_hue:
            shift = int(rng.integers(-90, 91))
            hsv[..., 0] = (hsv[..., 0] + shift) % 180
        if jitter_sat:
            gain = float(rng.uniform(0.0, 1.8))
            hsv[..., 1] = np.clip(hsv[..., 1] * gain, 0, 255)
        rgb_u8 = cv2.cvtColor(hsv.astype(np.uint8), cv2.COLOR_HSV2RGB)
        img = rgb_u8.astype(np.float32) / 255.0

    # Stage 3: drop color entirely, forcing shape/texture cues.
    if rng.random() < 0.3:
        red, green, blue = img[..., 0], img[..., 1], img[..., 2]
        luma = 0.299 * red + 0.587 * green + 0.114 * blue
        img = np.stack((luma,) * 3, axis=-1)

    # Stage 4: global brightness.
    if rng.random() < 0.6:
        img = np.clip(img * float(rng.uniform(0.6, 1.4)), 0, 1)
    return img


def _scale_K_to(image_size_wh_orig, K_orig, target_size):
    """Rescale pinhole intrinsics for an anisotropic resize to a square image.

    Args:
        image_size_wh_orig: ``(W, H)`` of the image K_orig was calibrated for.
        K_orig: 3x3 intrinsics array. It is not modified.
        target_size: side length of the square target image.

    Returns:
        A new float64 3x3 matrix. Row 0 (fx, cx) is scaled by target/W and
        row 1 (fy, cy) by target/H; skew and the last row are left as-is.
    """
    width, height = image_size_wh_orig
    scale_x = target_size / float(width)
    scale_y = target_size / float(height)
    K = K_orig.astype(np.float64)  # astype copies, so K_orig stays untouched
    K[0, [0, 2]] *= scale_x
    K[1, [1, 2]] *= scale_y
    return K


class Smith300TrajectoryDataset(Dataset):
    """Windows of smith300 frames with EEF pose / pixel / gripper targets.

    Same contract as PandaTrajectoryDataset (returns the same dict shape).

    Two sampling modes:
      contiguous (default): one sample per frame t in [ep_start, ep_end]. The
          window [t, t+stride, ..., t+(n_window-1)*stride] is clamped to
          ep_end, so the last frame is repeated as padding and the model sees
          every frame as a starting state.
      keyframes (use_keyframes=True): one sample per annotated keyframe. The
          window [kf[i], kf[i+1], ..., kf[i+n_window-1]] is padded with the
          last keyframe.

    In both modes, samples whose START frame has an invalid EEF projection
    (off-image or behind the camera) are dropped. Invalid later timesteps
    are flagged via ``trajectory_valid`` so the loss can mask them.
    """

    def __init__(self, data_dir, episodes_json=None,
                 image_size=IMAGE_SIZE, frame_stride=1,
                 mujoco_xml=DEFAULT_SMITH300_XML,
                 augment_color: bool = False,
                 n_window: int = N_WINDOW,
                 use_keyframes: bool = False):
        """Load camera metadata, episodes, joint states and RGB frames.

        Args:
            data_dir: capture directory with meta.json, joints.npz,
                rgb_NNNNNN.jpg and an episodes file.
            episodes_json: explicit episodes file. When None, searches
                ``rgb_overlay/episodes.json`` then ``episodes.json`` under
                data_dir.
            image_size: square side length fed to the model. Frames are
                resized anisotropically and K is rescaled to match.
            frame_stride: frame step between window entries (contiguous mode).
            mujoco_xml: MuJoCo model used for FK. It is only loaded when
                joints.npz does not carry saved EEF poses.
            augment_color: apply _augment_color to the model-input frame.
            n_window: number of timesteps per sample.
            use_keyframes: build windows from annotated keyframes.

        Raises:
            FileNotFoundError: episodes file or an RGB frame is missing.
            RuntimeError: the EEF body is missing from the MuJoCo XML.
        """
        self.data_dir = data_dir
        self.image_size = image_size
        self.n_window = int(n_window)
        self.frame_stride = frame_stride
        self.augment_color = bool(augment_color)
        self.use_keyframes = bool(use_keyframes)
        # Augmentation RNG. DataLoader workers receive a *copy* of this
        # object (fork or pickle), i.e. a Generator with identical state.
        # _get_rng() therefore re-creates it from fresh OS entropy the first
        # time it runs in a new process. Otherwise every worker, and every
        # epoch when workers are re-spawned, replays the same augmentations.
        self._rng = np.random.default_rng()
        self._rng_pid = os.getpid()

        # ── Camera intrinsics + extrinsics from meta.json ───────────────
        self._load_camera(data_dir, image_size)

        # ── Episodes ─────────────────────────────────────────────────────
        episodes_path = self._resolve_episodes_json(data_dir, episodes_json)
        with open(episodes_path) as f:
            self.episodes = json.load(f)["episodes"]

        # ── Joint state (+ optional saved EEF poses) ─────────────────────
        q_motors_all, saved_eef = self._load_joints(data_dir)
        n_frames_total = q_motors_all.shape[0]

        # ── Pre-compute per-frame FK + projection ────────────────────────
        all_frame_indices = set()
        for ep in self.episodes:
            ep_end = min(int(ep["end"]), n_frames_total - 1)
            all_frame_indices.update(range(int(ep["start"]), ep_end + 1))
        frame_indices = sorted(all_frame_indices)
        self.frame_data = self._compute_frame_data(
            frame_indices, q_motors_all, saved_eef, mujoco_xml)

        # ── Build sample windows ─────────────────────────────────────────
        self._build_samples(n_frames_total)

        # ImageNet normalization
        self.mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
        self.std = np.array([0.229, 0.224, 0.225], dtype=np.float32)

        if self.use_keyframes:
            # Counts are of the stored (unpadded) keyframe lists.
            kf_counts = [len(k) for k in self.episode_keyframes if k]
            print(f"Smith300Dataset (keyframes): {len(self.episodes)} episodes, "
                  f"{len(self.samples)} samples, n_window={self.n_window}, "
                  f"keyframes/ep={kf_counts}", flush=True)
        else:
            print(f"Smith300Dataset: {len(self.episodes)} episodes, "
                  f"{len(self.samples)} samples, stride={frame_stride}", flush=True)

        # ── Preload all referenced RGBs into memory ─────────────────────
        # JPEG decode + resize dominate __getitem__ when each sample touches
        # n_window frames. At image_size=448 a frame is ~600 KB as uint8, so
        # caching is affordable for small datasets and removes disk I/O from
        # the training loop.
        self.rgb_cache: dict[int, np.ndarray] = {
            int(frame_idx): self._read_rgb_u8(frame_idx)
            for frame_idx in frame_indices
        }
        n_bytes = sum(a.nbytes for a in self.rgb_cache.values())
        print(f"Preloaded {len(self.rgb_cache)} RGB frames "
              f"({n_bytes/(1<<20):.1f} MB) into memory.", flush=True)

    # ── Construction helpers ─────────────────────────────────────────────

    def _load_camera(self, data_dir, image_size):
        """Read meta.json and set image size, intrinsics and extrinsics."""
        with open(os.path.join(data_dir, "meta.json")) as f:
            meta = json.load(f)
        img_w, img_h = meta["image_size_wh"]
        K_orig = np.array(meta["K"], dtype=np.float64)
        T_camera_arucoBase = np.array(meta["T_camera_arucoBase"], dtype=np.float64)
        T_W_baseBody = np.array(meta["T_W_baseBody_inv_aruco_offset"], dtype=np.float64)

        self.IMG_W = img_w
        self.IMG_H = img_h
        self.K_orig = K_orig
        # Intrinsics for the resized image_size x image_size frames.
        self.K_target = _scale_K_to((img_w, img_h), K_orig, image_size)
        # World (= MuJoCo base body) -> camera extrinsic. It is used as T_cw
        # by project_to_pixel and returned as "world_to_camera"; its inverse
        # (the camera pose) is returned as "camera_pose".
        self.T_CAM_WORLD = T_camera_arucoBase @ T_W_baseBody
        # Intrinsics normalized by the ORIGINAL image size (not the resize).
        self.cam_k_norm = K_orig.astype(np.float32)
        self.cam_k_norm[0] /= img_w
        self.cam_k_norm[1] /= img_h

    @staticmethod
    def _resolve_episodes_json(data_dir, episodes_json):
        """Return the episodes file path, searching data_dir when not given."""
        if episodes_json is None:
            for cand in ("rgb_overlay/episodes.json", "episodes.json"):
                p = os.path.join(data_dir, cand)
                if os.path.exists(p):
                    return p
            raise FileNotFoundError(
                f"No episodes.json found under {data_dir}/rgb_overlay/ or {data_dir}/."
            )
        if not os.path.exists(episodes_json):
            raise FileNotFoundError(f"episodes_json not found: {episodes_json}")
        return episodes_json

    @staticmethod
    def _load_joints(data_dir):
        """Read joints.npz.

        Returns:
            ``(q_motors_all, saved_eef)``. q_motors_all is a float64
            (frames, motors) array. saved_eef is ``(eef_pos, eef_quat,
            eef_euler)`` as float64 arrays when all three keys exist
            (UMI capture path), else None.
        """
        # The context manager closes the underlying file; the arrays are
        # fully read into memory before it exits.
        with np.load(os.path.join(data_dir, "joints.npz")) as joints:
            q_motors_all = np.asarray(joints["q_motors"], dtype=np.float64)
            eef_keys = ("eef_pos", "eef_quat", "eef_euler")
            if all(key in joints.files for key in eef_keys):
                saved_eef = tuple(np.asarray(joints[key], dtype=np.float64)
                                  for key in eef_keys)
            else:
                saved_eef = None
        return q_motors_all, saved_eef

    @staticmethod
    def _load_fk_model(mujoco_xml):
        """Load the MuJoCo model and return ``(model, data, EEF body id)``."""
        print(f"Loading MuJoCo model: {mujoco_xml}", flush=True)
        mj_model = mujoco.MjModel.from_xml_path(mujoco_xml)
        mj_data = mujoco.MjData(mj_model)
        eef_id = mujoco.mj_name2id(mj_model, mujoco.mjtObj.mjOBJ_BODY, EEF_BODY_NAME)
        if eef_id < 0:
            raise RuntimeError(f"body {EEF_BODY_NAME!r} missing from XML {mujoco_xml}")
        return mj_model, mj_data, eef_id

    def _compute_frame_data(self, frame_indices, q_motors_all, saved_eef,
                            mujoco_xml):
        """Return ``{frame_idx: targets dict}`` for every index in frame_indices.

        The EEF pose is taken from ``saved_eef`` (already in the baseBody
        frame, quaternion xyzw) when available. Otherwise it comes from MuJoCo
        FK on the recorded motor angles, zero-padding unrecorded joints.
        ``pixel_valid`` is True only when the EEF projects inside the
        image_size x image_size frame.
        """
        image_size = self.image_size
        max_px = float(image_size - 1)
        n_recorded_motors = q_motors_all.shape[1]
        use_saved_eef = saved_eef is not None

        if use_saved_eef:
            eef_pos_saved, eef_quat_saved, eef_euler_saved = saved_eef
        else:
            # MuJoCo is only needed for FK, so datasets shipping saved EEF
            # poses no longer require the XML to exist.
            mj_model, mj_data, eef_id = self._load_fk_model(mujoco_xml)
            n_qpos = mj_model.nq
            n_copy = min(n_recorded_motors, n_qpos)

        print(f"Pre-computing per-frame data for {len(frame_indices)} "
              f"frames "
              f"({'using saved EEF poses (UMI)' if use_saved_eef else 'running arm FK'})",
              flush=True)
        frame_data = {}
        for idx in frame_indices:
            q_in = q_motors_all[idx]
            if use_saved_eef:
                eef_pos = eef_pos_saved[idx].astype(np.float32)
                eef_quat = eef_quat_saved[idx].astype(np.float32)  # xyzw
                eef_euler = eef_euler_saved[idx].astype(np.float32)
            else:
                q = np.zeros(n_qpos, dtype=np.float64)
                q[:n_copy] = q_in[:n_copy]  # unrecorded joints (e.g. gripper) stay 0
                mj_data.qpos[:n_qpos] = q
                mujoco.mj_forward(mj_model, mj_data)
                eef_pos = mj_data.xpos[eef_id].astype(np.float32)
                # MuJoCo stores quaternions as wxyz; reorder to xyzw.
                eef_quat = mj_data.xquat[eef_id][[1, 2, 3, 0]].astype(np.float32)
                eef_euler = ScipyR.from_quat(eef_quat).as_euler('xyz').astype(np.float32)

            pix = project_to_pixel(
                eef_pos.astype(np.float64), self.T_CAM_WORLD, self.K_target,
            )
            # A clamped pixel is still emitted for invalid projections (for
            # shape consistency), but the training loss MUST drop those
            # timesteps. Otherwise the heatmap target trains a strong peak at
            # the image border / top-left corner.
            if pix is None:
                pixel_valid = False
                pixel_2d = np.zeros(2, dtype=np.float32)
            else:
                pixel_valid = (0.0 <= float(pix[0]) <= max_px
                               and 0.0 <= float(pix[1]) <= max_px)
                pixel_2d = np.array([
                    np.clip(pix[0], 0, image_size - 1),
                    np.clip(pix[1], 0, image_size - 1),
                ], dtype=np.float32)

            # Gripper target: q[6] when a 7th motor was recorded, else 0.
            # train_smith300_para.py zeros GRIPPER_LOSS_WEIGHT in the latter
            # case. Per the original notes, the downstream loss treats
            # gripper > 0 as "closed".
            grip = np.float32(q_in[6]) if n_recorded_motors >= 7 else np.float32(0.0)

            frame_data[idx] = {
                "eef_pos": eef_pos,
                "eef_quat": eef_quat,
                "eef_euler": eef_euler,
                "pixel_2d": pixel_2d,
                "pixel_valid": bool(pixel_valid),
                "gripper": grip,
            }
        return frame_data

    def _build_samples(self, n_frames_total):
        """Populate samples, episode_keyframes and episode_ends.

        ``samples`` holds ``(ep_idx, start)`` pairs. ``start`` is a frame index
        in contiguous mode and a keyframe index in keyframes mode.
        """
        self.samples = []
        self.episode_keyframes = []  # only non-empty in keyframe mode
        self.episode_ends = []       # ep_end per ep_idx (for contiguous padding)
        n_skipped_invalid = 0
        for ep_idx, ep in enumerate(self.episodes):
            ep_start = int(ep["start"])
            ep_end = min(int(ep["end"]), n_frames_total - 1)
            self.episode_ends.append(ep_end)
            if self.use_keyframes:
                kf_frames = [int(kf["frame"]) for kf in ep.get("keyframes", [])
                             if ep_start <= int(kf["frame"]) <= ep_end]
                if not kf_frames:
                    print(f"WARN: episode {ep.get('id', ep_idx)} has no "
                          f"keyframes; skipping in keyframes mode.", flush=True)
                    self.episode_keyframes.append([])
                    continue
                # Stored UNPADDED; __getitem__ pads with the last keyframe, so
                # each episode yields one sample per keyframe.
                self.episode_keyframes.append(kf_frames)
                for start_kf, kf_frame in enumerate(kf_frames):
                    if not self.frame_data[kf_frame].get("pixel_valid", True):
                        n_skipped_invalid += 1
                        continue
                    self.samples.append((ep_idx, start_kf))
            else:
                self.episode_keyframes.append([])  # keep alignment
                for f in range(ep_start, ep_end + 1):  # one sample per frame
                    # An invalid start frame would feed start_kp = clamped
                    # pixel to the model and train an attractor there. Later
                    # invalid timesteps are masked via "trajectory_valid".
                    if not self.frame_data[f].get("pixel_valid", True):
                        n_skipped_invalid += 1
                        continue
                    self.samples.append((ep_idx, f))
        if n_skipped_invalid:
            print(f"  skipped {n_skipped_invalid} samples whose start frame "
                  f"had an off-frame EEF projection.", flush=True)

    def _read_rgb_u8(self, frame_idx):
        """Decode ``rgb_{frame_idx:06d}.jpg`` as uint8 RGB at image_size^2.

        Raises:
            FileNotFoundError: the JPEG is missing or unreadable.
        """
        path = os.path.join(self.data_dir, f"rgb_{frame_idx:06d}.jpg")
        bgr = cv2.imread(path, cv2.IMREAD_COLOR)
        if bgr is None:
            raise FileNotFoundError(f"missing frame: {path}")
        if bgr.shape[1] != self.image_size or bgr.shape[0] != self.image_size:
            bgr = cv2.resize(bgr, (self.image_size, self.image_size),
                             interpolation=cv2.INTER_AREA)
        return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

    def _get_rng(self) -> np.random.Generator:
        """Return an augmentation RNG unique to the current process.

        Re-creates the generator from OS entropy whenever the owning pid
        changes, e.g. inside a freshly started DataLoader worker.
        """
        pid = os.getpid()
        if getattr(self, "_rng_pid", None) != pid:
            self._rng = np.random.default_rng()
            self._rng_pid = pid
        return self._rng

    # ── Dataset protocol ─────────────────────────────────────────────────

    def __len__(self):
        return len(self.samples)

    def _load_rgb_resized(self, frame_idx):
        """Return frame ``frame_idx`` as float32 RGB in [0, 1]."""
        # All frames referenced by episode ranges are preloaded, so a cache
        # miss should not happen; decode from disk defensively if it does.
        rgb_u8 = self.rgb_cache.get(int(frame_idx))
        if rgb_u8 is None:
            rgb_u8 = self._read_rgb_u8(frame_idx)
        return rgb_u8.astype(np.float32) / 255.0

    def __getitem__(self, idx):
        """Build the training dict for sample ``idx``.

        Note: in keyframes mode "start_t" is the start KEYFRAME index, not a
        frame index.
        """
        ep_idx, start_frame = self.samples[idx]
        if self.use_keyframes:
            # start_frame is a keyframe index here; clamp to the last keyframe
            # when the window runs past the end.
            kf_frames = self.episode_keyframes[ep_idx]
            last = len(kf_frames) - 1
            window_frames = [kf_frames[min(start_frame + k, last)]
                             for k in range(self.n_window)]
        else:
            # Clamp to ep_end so windows near the end pad with the last frame.
            ep_end = self.episode_ends[ep_idx]
            window_frames = [min(start_frame + k * self.frame_stride, ep_end)
                             for k in range(self.n_window)]

        rgb_list = [self._load_rgb_resized(f) for f in window_frames]
        rgb_ref = rgb_list[0]  # the model only sees the first frame
        rgb_frames_raw = np.stack(rgb_list)

        fds = [self.frame_data[f] for f in window_frames]
        trajectory_2d = np.stack([fd["pixel_2d"] for fd in fds])
        trajectory_3d = np.stack([fd["eef_pos"] for fd in fds])
        trajectory_gripper = np.array([fd["gripper"] for fd in fds], dtype=np.float32)
        trajectory_quat = np.stack([fd["eef_quat"] for fd in fds])
        trajectory_euler = np.stack([fd["eef_euler"] for fd in fds])
        trajectory_valid = np.asarray([fd.get("pixel_valid", True) for fd in fds],
                                      dtype=bool)

        # One-hot heatmap targets at the model's image_size resolution.
        heatmap_targets = np.zeros((self.n_window, self.image_size, self.image_size),
                                   dtype=np.float32)
        for t in range(self.n_window):
            x_i = int(np.clip(round(float(trajectory_2d[t, 0])), 0, self.image_size - 1))
            y_i = int(np.clip(round(float(trajectory_2d[t, 1])), 0, self.image_size - 1))
            heatmap_targets[t, y_i, x_i] = 1.0

        # Augment the model-input frame ONLY; rgb_frames_raw stays untouched
        # so visualization shows the real capture.
        rgb_for_model = rgb_ref
        if self.augment_color:
            rgb_for_model = _augment_color(rgb_ref.copy(), self._get_rng())
        rgb_t = (np.transpose(rgb_for_model, (2, 0, 1)) - self.mean[:, None, None]) / self.std[:, None, None]

        world_to_camera = self.T_CAM_WORLD.astype(np.float32)
        camera_pose = np.linalg.inv(self.T_CAM_WORLD).astype(np.float32)

        return {
            "rgb": torch.from_numpy(rgb_t).float(),
            "heatmap_target": torch.from_numpy(heatmap_targets).float(),
            "trajectory_2d": torch.from_numpy(trajectory_2d).float(),
            "trajectory_3d": torch.from_numpy(trajectory_3d).float(),
            "trajectory_gripper": torch.from_numpy(trajectory_gripper).float(),
            "trajectory_quat": torch.from_numpy(trajectory_quat).float(),
            "trajectory_euler": torch.from_numpy(trajectory_euler).float(),
            "trajectory_valid": torch.from_numpy(trajectory_valid),
            "rgb_frames_raw": torch.from_numpy(rgb_frames_raw).float(),
            "world_to_camera": torch.from_numpy(world_to_camera).float(),
            "base_z": torch.tensor(0.0, dtype=torch.float32),
            "target_3d": torch.from_numpy(trajectory_3d[-1]).float(),
            "camera_pose": torch.from_numpy(camera_pose).float(),
            "cam_K_norm": torch.from_numpy(self.cam_k_norm).float(),
            "demo_idx": torch.tensor(ep_idx, dtype=torch.long),
            "start_t": torch.tensor(start_frame, dtype=torch.long),
        }


if __name__ == "__main__":
    import sys

    # Smoke test: build the dataset and print the shape/dtype of every field
    # of the first sample. An optional first CLI argument overrides the
    # default capture directory.
    _DEFAULT_DATA_DIR = "/data/cameron/mac_robot_datasets/dataset_20260501_180125"
    data_dir = sys.argv[1] if len(sys.argv) > 1 else _DEFAULT_DATA_DIR
    ds = Smith300TrajectoryDataset(data_dir)
    print(f"Dataset size: {len(ds)}")
    if len(ds) == 0:
        sys.exit("Dataset has no valid samples; nothing to inspect.")
    s = ds[0]
    for k, v in s.items():
        if isinstance(v, torch.Tensor):
            print(f"  {k}: {tuple(v.shape)} {v.dtype}")
        else:
            print(f"  {k}: {v}")
