"""All-in-memory data loader for DA3 pixel-aligned training (smith300 robot data).

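Per sample (shapes assume the default image_size=504):
  rgb:           (3, 504, 504) float32 in [0, 1] — DA3 expects [0, 1] (not ImageNet-normalized)
  gt_pix_504:    (N_WINDOW, 2) GT EEF pixel coords (u, v) in 504-space (the model trains its
                  heatmap at 288 res — we rescale at loss time)
  gt_pix_valid:  (N_WINDOW,) bool — False for future steps that ran past the episode end and
                  were clamped to its last frame
  da3_depth:     (504, 504) float32 — frozen DA3 depth, used as distillation target
  ep_idx:        scalar int64 — index of the episode the sample comes from
  start_t:       scalar int64 — index of the input frame within that episode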
"""
import os, sys, json
from pathlib import Path

import cv2
import numpy as np
import torch
from torch.utils.data import Dataset
from scipy.spatial.transform import Rotation as ScipyR

# Reuse para_mac's MuJoCo FK
sys.path.insert(0, "/data/cameron/para/para_mac")
from data_smith300_para import EEF_BODY_NAME, _scale_K_to
# Use local XML+assets copy (the mac mount path is flaky for file reads).
DEFAULT_SMITH300_XML = "/data/cameron/para/libero/example_twolink.xml"
import mujoco

# Default square side length of the DA3 input; RGB frames are resized to it and
# EEF pixel targets are expressed in this image space.
DA3_INPUT: int = 504
# Default number of future EEF positions returned per sample.
N_WINDOW: int = 8


def project(world_pts, world_to_camera, K):
    """Pinhole-project world-frame points into pixel coordinates (NumPy).

    Args:
        world_pts: (N, 3) points in the world frame.
        world_to_camera: homogeneous world -> camera transform (rows beyond
            the third are ignored after the product).
        K: camera intrinsics matrix.

    Returns:
        pix: (N, 2) pixel coordinates (u, v).
        depth: (N,) camera-frame z of each point, NOT clamped.

    Only the perspective divide uses z clamped to >= 1e-3, so points at or
    behind the camera still yield finite (but meaningless) pixels; callers
    must check ``depth`` themselves.
    """
    n_pts = world_pts.shape[0]
    world_h = np.concatenate([world_pts, np.ones((n_pts, 1))], axis=-1)
    pts_cam = (world_to_camera @ world_h.T).T[:, :3]
    z_cam = pts_cam[:, 2]
    # Guard the divide only; the raw z is what gets returned.
    z_div = np.maximum(z_cam, 1e-3)
    rays = np.concatenate([pts_cam[:, :2] / z_div[:, None], np.ones((n_pts, 1))], axis=-1)
    pix = (K @ rays.T).T[:, :2]
    return pix, z_cam


class Smith300DA3Dataset(Dataset):
    """In-memory smith300 dataset for DA3 pixel-aligned training.

    Every usable frame of every session under ``root_dir`` is decoded once at
    construction time and kept in memory (S = ``image_size``):

      * ``rgb_t``   (N, 3, S, S) float32 in [0, 1]
      * ``depth_t`` (N, S, S) float32, loaded from
        ``<session>/<depth_subdir>/rgb_XXXXXX.npy`` (resized to S x S if needed)
      * ``pix_t``   (N, 2) float32 EEF pixel (u, v) in S-space, from MuJoCo
        forward kinematics on ``joints.npz`` and the calibration in ``meta.json``
      * ``session_idx`` list mapping each frame to its session's sorted index

    A frame is dropped when its EEF lies behind the camera or projects outside
    the image, or when its RGB/depth file is missing or unreadable.

    Each item is one (episode, t) pair: the RGB/depth of frame t plus the EEF
    pixels of the following ``n_window`` frames at stride ``frame_stride``.
    """

    def __init__(self, root_dir="/data/cameron/mac_robot_datasets/first_mobile_collection",
                 image_size=DA3_INPUT, n_window=N_WINDOW, frame_stride=1,
                 mujoco_xml=DEFAULT_SMITH300_XML, depth_subdir="da3_depth"):
        """Scan ``root_dir`` for sessions and load every usable frame into memory.

        Args:
            root_dir: directory whose immediate subdirectories are sessions. A
                session is used only if it has ``meta.json`` and an episodes
                file (``rgb_overlay/episodes.json`` preferred over ``episodes.json``).
            image_size: side length S that RGB (and depth) are resized to; the
                intrinsics are rescaled to S via ``_scale_K_to``.
            n_window: number of future EEF pixels returned per sample.
            frame_stride: episode-local step between future targets (must be >= 1).
            mujoco_xml: MJCF model used for forward kinematics of the EEF body.
            depth_subdir: per-session subdirectory holding ``rgb_XXXXXX.npy`` depth maps.

        Raises:
            ValueError: ``frame_stride < 1``, or no usable frame was found.
            FileNotFoundError: ``root_dir`` contains no subdirectories.
        """
        if frame_stride < 1:
            # A zero/negative stride would silently yield the current or
            # wrapped-around (negative-index) frames as "future" targets.
            raise ValueError(f"frame_stride must be >= 1, got {frame_stride}")
        self.depth_subdir = depth_subdir
        self.image_size = image_size
        self.n_window = n_window
        self.s = frame_stride

        root = Path(root_dir)
        sessions = sorted(d for d in root.iterdir() if d.is_dir())
        if not sessions:
            raise FileNotFoundError(f"No sessions under {root}")
        print(f"Smith300DA3Dataset: loading {len(sessions)} sessions")

        mj_model = mujoco.MjModel.from_xml_path(mujoco_xml)
        mj_data = mujoco.MjData(mj_model)
        eef_id = mujoco.mj_name2id(mj_model, mujoco.mjtObj.mjOBJ_BODY, EEF_BODY_NAME)

        self.episodes = []     # [{"frames": int64 array of global frame indices}]
        # Per-frame buffers: Python lists while loading, stacked into tensors below.
        self.rgb_t = []        # (N, 3, S, S)
        self.depth_t = []      # (N, S, S)
        self.pix_t = []        # (N, 2) EEF pixel in S-space
        self.session_idx = []  # (N,) index of each frame's session in `sessions`

        for sess_idx, sess in enumerate(sessions):
            n_before = len(self.rgb_t)
            if not self._load_session(sess, sess_idx, mj_model, mj_data, eef_id):
                continue
            print(f"  {sess.name}: {len(self.rgb_t) - n_before} frames loaded")

        if not self.rgb_t:
            # np.stack on an empty list would fail with an opaque message.
            raise ValueError(f"Smith300DA3Dataset: no usable frames found under {root}")

        self.rgb_t = torch.from_numpy(np.stack(self.rgb_t, axis=0))
        self.depth_t = torch.from_numpy(np.stack(self.depth_t, axis=0))
        self.pix_t = torch.from_numpy(np.stack(self.pix_t, axis=0))

        # One sample per (episode, t); the last frame of an episode is never an
        # input because it has no real future step.
        self.samples = [(ep_idx, t)
                        for ep_idx, ep in enumerate(self.episodes)
                        for t in range(len(ep["frames"]) - 1)]

        n = len(self.rgb_t)
        gb_rgb = self.rgb_t.element_size() * self.rgb_t.numel() / 1e9
        gb_d = self.depth_t.element_size() * self.depth_t.numel() / 1e9
        print(f"Smith300DA3Dataset ready: {len(self.episodes)} eps, {n} frames, "
              f"{len(self.samples)} samples, rgb={gb_rgb:.2f} GB, depth={gb_d:.2f} GB")

    @staticmethod
    def _episode_local_frames(ep, n_frames):
        """Session-local frame indices of ``ep``: [start, end] inclusive, clipped to [0, n_frames - 1]."""
        return range(max(0, int(ep["start"])), min(int(ep["end"]), n_frames - 1) + 1)

    def _load_session(self, sess, sess_idx, mj_model, mj_data, eef_id):
        """Append every usable frame of one session and register its episodes.

        Returns False (loading nothing) when the session has no ``meta.json``
        or no episodes file, True otherwise.
        """
        meta_path = sess / "meta.json"
        if not meta_path.exists():
            return False
        with open(meta_path) as fh:
            meta = json.load(fh)
        img_w, img_h = meta["image_size_wh"]
        K_orig = np.array(meta["K"], dtype=np.float64)
        T_camera_arucoBase = np.array(meta["T_camera_arucoBase"], dtype=np.float64)
        T_W_baseBody = np.array(meta["T_W_baseBody_inv_aruco_offset"], dtype=np.float64)
        T_cam_world = T_camera_arucoBase @ T_W_baseBody
        K_target = _scale_K_to((img_w, img_h), K_orig, self.image_size)

        ep_path = next((sess / p for p in ["rgb_overlay/episodes.json", "episodes.json"]
                        if (sess / p).exists()), None)
        if ep_path is None:
            return False
        with open(ep_path) as fh:
            sess_episodes = json.load(fh)["episodes"]

        with np.load(sess / "joints.npz") as joints:
            q_motors_all = np.asarray(joints["q_motors"], dtype=np.float64)
        n_motors = q_motors_all.shape[1]
        n_frames = q_motors_all.shape[0]

        n_qpos = mj_model.nq
        n_copy = min(n_motors, n_qpos)  # motors beyond nq are ignored; missing qpos stay 0
        size = self.image_size

        needed = sorted({f for ep in sess_episodes
                         for f in self._episode_local_frames(ep, n_frames)})
        local_to_global = {}
        for f in needed:
            # Forward kinematics: EEF world position for this frame's joint state.
            q = np.zeros(n_qpos, dtype=np.float64)
            q[:n_copy] = q_motors_all[f, :n_copy]
            mj_data.qpos[:n_qpos] = q
            mujoco.mj_forward(mj_model, mj_data)
            eef_pos = mj_data.xpos[eef_id].copy()
            pix, z_cam = project(eef_pos.reshape(1, 3), T_cam_world, K_target)
            # project() clamps z only for the divide, so an EEF behind the
            # camera could still land in-bounds; reject it explicitly.
            if z_cam[0] <= 0.0:
                continue
            pix = pix[0].astype(np.float32)
            # Skip frames where the EEF projects outside the image.
            if not (0 <= pix[0] < size and 0 <= pix[1] < size):
                continue

            loaded = self._load_frame(sess, f)
            if loaded is None:
                continue
            rgb_chw, depth = loaded

            local_to_global[f] = len(self.rgb_t)
            self.rgb_t.append(rgb_chw)
            self.depth_t.append(depth)
            self.pix_t.append(pix)
            self.session_idx.append(sess_idx)

        for ep in sess_episodes:
            fs = [local_to_global[f] for f in self._episode_local_frames(ep, n_frames)
                  if f in local_to_global]
            # Need at least one (input, future) pair to form a sample.
            if len(fs) < 2:
                continue
            self.episodes.append({"frames": np.asarray(fs, dtype=np.int64)})
        return True

    def _load_frame(self, sess, f):
        """Load frame ``f`` of ``sess`` as (rgb (3, S, S) float32 in [0, 1], depth float32).

        Returns None when the RGB or depth file is missing or the image cannot
        be decoded. A 2-D depth map whose shape differs from (S, S) is resized
        so that it stays pixel-aligned with the resized RGB.
        """
        size = self.image_size
        img_path = sess / f"rgb_{f:06d}.jpg"
        depth_path = sess / self.depth_subdir / f"rgb_{f:06d}.npy"
        if not img_path.exists() or not depth_path.exists():
            return None
        bgr = cv2.imread(str(img_path))
        if bgr is None:
            return None
        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        rgb = cv2.resize(rgb, (size, size), interpolation=cv2.INTER_LINEAR)
        rgb_chw = (rgb.astype(np.float32) / 255.0).transpose(2, 0, 1)
        depth = np.load(depth_path).astype(np.float32)
        if depth.ndim == 2 and depth.shape != (size, size):
            depth = cv2.resize(depth, (size, size), interpolation=cv2.INTER_LINEAR)
        return rgb_chw, depth

    def __len__(self):
        """Number of (episode, t) samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return input frame t of an episode plus its next ``n_window`` EEF pixels.

        Future targets are the frames at episode-local offsets t+s, t+2s, ...,
        t+n_window*s (s = frame_stride), clamped to the episode's last frame;
        ``gt_pix_valid`` is False for every clamped step.
        """
        ep_idx, t = self.samples[idx]
        frames = self.episodes[ep_idx]["frames"]
        last_real = len(frames) - 1
        s = self.s

        # NOTE: `frames` only holds frames that survived filtering in __init__
        # (behind-camera / out-of-image EEF, missing files), so an episode-local
        # step of `s` can span more than `s` raw capture frames.
        cur_g = int(frames[t])
        future_local_raw = [t + (i + 1) * s for i in range(self.n_window)]
        future_local = [min(i, last_real) for i in future_local_raw]
        valid = torch.tensor([raw <= last_real for raw in future_local_raw], dtype=torch.bool)
        future_global = frames[future_local]
        gt_pix = self.pix_t[future_global]                        # (n_window, 2)
        rgb = self.rgb_t[cur_g]                                    # (3, S, S)
        depth = self.depth_t[cur_g]                                # (S, S)
        return {
            "rgb":          rgb,
            "gt_pix_504":   gt_pix,
            "gt_pix_valid": valid,
            "da3_depth":    depth,
            "ep_idx":       torch.tensor(ep_idx, dtype=torch.long),
            "start_t":      torch.tensor(t, dtype=torch.long),
        }
