"""All-in-memory data loader for the smith300 first_mobile_collection sessions.

Loads ALL 9 robot sessions, runs MuJoCo FK once per annotated frame at init, and
preloads every RGB frame (resized to image_size×image_size; 448×448 in the smoke
test) into a single ImageNet-normalized float tensor. The per-sample dict matches
the existing VolumeWindowDataset layout, so train_volume_smooth_v2.py runs
unchanged once it is pointed at this loader.

Per-sample dict:
  rgb:                (3, 448, 448)   ImageNet-normalized
  past_eef_world:     (20, 3)         clamped at episode start (repeat earliest)
  current_eef_world:  (3,)            == past_eef_world[-1]
  target_eef_world:   (8, 3)          future +1..+8 stride frames, clamped at episode end
  target_grip:        (8,)
  target_rot_euler:   (8, 3)          XYZ euler radians
  target_voxel_idx:   (8,)            flat voxel index
  valid_mask:         (8,)            False where future is clamped
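  world_to_camera:    (4, 4)          world → pixel projection: rows 0-2 are
                                      K_target @ (T_camera_arucoBase @ T_W_baseBody_inv_aruco_offset)[:3],
                                      row 3 is [0, 0, 0, 1]; (u, v) = proj[:2] / proj[2] at image_size
  ep_idx:             ()              index into dataset.episodes
  start_t:            ()              position of the current frame inside its episode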
"""
import os, sys, json
from pathlib import Path

import cv2
import numpy as np
import torch
from torch.utils.data import Dataset
from scipy.spatial.transform import Rotation as ScipyR

sys.path.insert(0, os.path.dirname(__file__))
from robot_volume import (
    voxel_centers_world, world_to_voxel_idx,
    N_PAST_EEF, T_FUTURE, IMAGE_SIZE,
)

# Reuse para_mac's MuJoCo XML path (same as their data_smith300_para.py).
sys.path.insert(0, "/data/cameron/para/para_mac")
from data_smith300_para import (
    DEFAULT_SMITH300_XML, EEF_BODY_NAME, project_to_pixel, _scale_K_to,
)
import mujoco

# Per-channel ImageNet statistics in RGB order; applied to pixels already scaled to [0, 1].
IMAGENET_MEAN = np.asarray((0.485, 0.456, 0.406), dtype=np.float32)
IMAGENET_STD = np.asarray((0.229, 0.224, 0.225), dtype=np.float32)


def _resize_normalize(bgr, image_size):
    """Turn a BGR image (as returned by cv2.imread) into a channel-first float32 array.

    Steps: BGR → RGB, scale to [0, 1], bilinear resize to a square
    ``image_size`` x ``image_size``, then standardize with the ImageNet
    mean/std. The result is a (3, image_size, image_size) view (not copied).
    """
    side = (image_size, image_size)
    scaled = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
    resized = cv2.resize(scaled, side, interpolation=cv2.INTER_LINEAR)
    standardized = (resized - IMAGENET_MEAN) / IMAGENET_STD
    return np.transpose(standardized, (2, 0, 1))  # HWC → CHW


class Smith300VolumeDataset(Dataset):
    """In-memory dataset over every annotated episode of the smith300 sessions.

    At construction, every sub-directory of ``root_dir`` named ``dataset_*`` or
    ``extra_singleview_capture`` is scanned. For each frame inside an annotated
    episode, the RGB image is decoded, resized and ImageNet-normalized, and
    MuJoCo forward kinematics is run once to get the end-effector pose. All
    frames are then stacked into flat tensors, so ``__getitem__`` is only
    indexing.

    Args:
        root_dir: directory holding the session sub-directories.
        image_size: side length (pixels) of the square resized RGB frame; the
            per-session projection matrix is scaled to the same resolution.
        n_past: number of past end-effector positions per sample.
        t_future: number of future targets per sample.
        frame_stride: spacing, in loaded frames, between consecutive future targets.
        mujoco_xml: MJCF model used for forward kinematics.

    Raises:
        ValueError: ``frame_stride`` < 1.
        FileNotFoundError: no session directory under ``root_dir``.
        RuntimeError: the end-effector body is missing from the model, or no
            frame could be loaded from any session.
    """

    def __init__(self, root_dir="/data/cameron/mac_robot_datasets/first_mobile_collection",
                 image_size=IMAGE_SIZE, n_past=N_PAST_EEF, t_future=T_FUTURE,
                 frame_stride=1, mujoco_xml=DEFAULT_SMITH300_XML):
        if frame_stride < 1:
            raise ValueError(f"frame_stride must be >= 1, got {frame_stride}")
        self.image_size = image_size
        self.n_past = n_past
        self.t_future = t_future
        self.s = frame_stride

        # ── Find sessions (dataset_* + extra_singleview_capture) ──
        root = Path(root_dir)
        session_dirs = sorted(d for d in root.iterdir() if d.is_dir()
                              and (d.name.startswith("dataset_")
                                   or d.name == "extra_singleview_capture"))
        if not session_dirs:
            raise FileNotFoundError(f"No sessions under {root}")
        print(f"Smith300VolumeDataset: loading {len(session_dirs)} sessions from {root}")

        # MuJoCo FK setup — one model/data pair, reused for every frame of every session
        mj_model = mujoco.MjModel.from_xml_path(mujoco_xml)
        mj_data = mujoco.MjData(mj_model)
        eef_id = mujoco.mj_name2id(mj_model, mujoco.mjtObj.mjOBJ_BODY, EEF_BODY_NAME)
        if eef_id < 0:
            raise RuntimeError(f"body {EEF_BODY_NAME!r} not in {mujoco_xml}")

        self.episodes = []         # flat list across all sessions
        self.frame_data = []       # per-frame dict (eef_pos, eef_quat, eef_euler, gripper, session_idx)
        self.rgb_tensor = []       # per-frame (3, image_size, image_size); stacked below
        self.w2c_per_session = []  # one (4, 4) world→pixel matrix per *registered* session
        self.session_dirs = []     # registered session paths, same order as w2c_per_session
        # Episode entry: {"session_idx", "ep_id", "global_frames"}. global_frames index
        # into self.frame_data / self.rgb_tensor. "session_idx" indexes
        # self.w2c_per_session; it is handed out only to sessions that are actually
        # registered, so a skipped session can never shift another's camera matrix.

        for sess in session_dirs:
            self._load_session(sess, image_size, mj_model, mj_data, eef_id)

        if not self.frame_data:
            raise RuntimeError(f"Smith300VolumeDataset: no frames could be loaded from {root}")

        # Stack the preloaded images into one big tensor on CPU (pin via DataLoader)
        self.rgb_tensor = torch.from_numpy(np.stack(self.rgb_tensor, axis=0))  # (N, 3, H, W) float32
        # Per-session world→pixel matrices, indexable by session_idx
        self.w2c_per_session = torch.from_numpy(np.stack(self.w2c_per_session, axis=0))  # (S, 4, 4)

        # Pre-extract per-frame fields as tensors (so __getitem__ is just index + gather)
        fds = self.frame_data
        self.eef_pos = torch.from_numpy(np.stack([fd["eef_pos"] for fd in fds], axis=0))      # (N, 3)
        self.eef_quat = torch.from_numpy(np.stack([fd["eef_quat"] for fd in fds], axis=0))    # (N, 4)
        self.eef_euler = torch.from_numpy(np.stack([fd["eef_euler"] for fd in fds], axis=0))  # (N, 3)
        self.gripper = torch.from_numpy(np.array([fd["gripper"] for fd in fds], dtype=np.float32))  # (N,)
        self.session = torch.tensor([fd["session_idx"] for fd in fds], dtype=torch.long)     # (N,)

        # One sample per (episode, t) such that the first future target t+stride is a
        # real frame; later samples would have every future step clamped (all-invalid).
        stride = self.s
        self.samples = [(ep_idx, t)
                        for ep_idx, ep in enumerate(self.episodes)
                        for t in range(len(ep["global_frames"]) - stride)]

        print(f"Smith300VolumeDataset ready: {len(self.episodes)} episodes, "
              f"{len(self.frame_data)} frames, {len(self.samples)} samples, "
              f"rgb tensor {self.rgb_tensor.shape} ({self.rgb_tensor.element_size() * self.rgb_tensor.numel() / 1e9:.2f} GB)")

    @staticmethod
    def _read_json(path):
        """Parse a JSON file, closing the handle afterwards."""
        with open(path, encoding="utf-8") as fh:
            return json.load(fh)

    @staticmethod
    def _world_to_pixel_matrix(meta, image_size):
        """Build a session's (4, 4) float32 world→pixel matrix from its meta.json dict."""
        img_w, img_h = meta["image_size_wh"]
        K_orig = np.array(meta["K"], dtype=np.float64)
        T_camera_arucoBase = np.array(meta["T_camera_arucoBase"], dtype=np.float64)
        T_W_baseBody = np.array(meta["T_W_baseBody_inv_aruco_offset"], dtype=np.float64)
        T_cam_world = T_camera_arucoBase @ T_W_baseBody                    # world → camera frame
        K_target = _scale_K_to((img_w, img_h), K_orig, image_size)         # (3, 3) pixel intrinsics
        # Bake K into the 4×4 so the same world_to_pixel_torch as LIBERO can be used:
        #   M[:3, :] = K @ T_cam_world[:3, :],  M[3, :] stays [0, 0, 0, 1]
        # proj = M @ [x, y, z, 1] = (u*z, v*z, z, 1), so proj[:2] / proj[2] = (u, v) in pixels.
        M = np.eye(4, dtype=np.float64)
        M[:3, :] = K_target @ T_cam_world[:3, :]
        return M.astype(np.float32)

    @staticmethod
    def _eef_pose(mj_model, mj_data, eef_id, q_motors):
        """Run MuJoCo FK for one joint vector; return float32 (pos, quat_xyzw, euler_xyz)."""
        n_qpos = mj_model.nq
        k = min(q_motors.shape[0], n_qpos)
        q = np.zeros(n_qpos, dtype=np.float64)
        q[:k] = q_motors[:k]  # extra motor columns (beyond nq) are ignored, missing ones stay 0
        mj_data.qpos[:n_qpos] = q
        mujoco.mj_forward(mj_model, mj_data)
        pos = mj_data.xpos[eef_id].astype(np.float32)                    # astype copies
        quat_xyzw = mj_data.xquat[eef_id][[1, 2, 3, 0]].astype(np.float32)  # MuJoCo stores wxyz
        euler = ScipyR.from_quat(quat_xyzw).as_euler('xyz').astype(np.float32)
        return pos, quat_xyzw, euler

    def _load_session(self, sess, image_size, mj_model, mj_data, eef_id):
        """Load one session's frames/episodes into the flat lists, or print why it is skipped."""
        meta_path = sess / "meta.json"
        if not meta_path.exists():
            print(f"  skip {sess.name}: no meta.json")
            return
        # Episodes (under rgb_overlay/episodes.json or episodes.json)
        ep_paths = [sess / "rgb_overlay" / "episodes.json", sess / "episodes.json"]
        ep_path = next((p for p in ep_paths if p.exists()), None)
        if ep_path is None:
            print(f"  skip {sess.name}: no episodes.json")
            return

        meta = self._read_json(meta_path)
        sess_episodes = self._read_json(ep_path)["episodes"]

        # Dense session index, assigned only after every skip check has passed.
        sess_idx = len(self.w2c_per_session)
        self.w2c_per_session.append(self._world_to_pixel_matrix(meta, image_size))
        self.session_dirs.append(sess)

        with np.load(sess / "joints.npz") as joints:
            q_motors_all = np.asarray(joints["q_motors"], dtype=np.float64)
        n_frames = q_motors_all.shape[0]
        n_motors = q_motors_all.shape[1]

        # Episode frame ranges, clipped to the recorded joint frames on both ends.
        ep_ranges = [(ep, max(0, int(ep["start"])), min(int(ep["end"]), n_frames - 1))
                     for ep in sess_episodes]
        # Only frames inside an annotated episode are loaded.
        needed_frames = sorted({f for _, lo, hi in ep_ranges for f in range(lo, hi + 1)})

        # Frame-index map: session-local → global
        local_to_global = {}
        n_missing = 0
        for f in needed_frames:
            # Decode the image first so FK is not wasted on frames that get dropped.
            bgr = cv2.imread(str(sess / f"rgb_{f:06d}.jpg"))
            if bgr is None:
                n_missing += 1
                continue
            eef_pos, eef_quat, eef_eul = self._eef_pose(mj_model, mj_data, eef_id, q_motors_all[f])
            grip = np.float32(q_motors_all[f, 6]) if n_motors >= 7 else np.float32(0.0)

            local_to_global[f] = len(self.frame_data)
            self.frame_data.append({
                "eef_pos":  eef_pos,
                "eef_quat": eef_quat,
                "eef_euler": eef_eul,
                "gripper":  grip,
                "session_idx": sess_idx,
            })
            self.rgb_tensor.append(_resize_normalize(bgr, image_size))

        # Register episodes. NOTE: frames whose image failed to load are simply left
        # out, so global_frames can step across a time gap inside an episode.
        n_eps = 0
        for ep, lo, hi in ep_ranges:
            global_frames = [local_to_global[f] for f in range(lo, hi + 1) if f in local_to_global]
            if len(global_frames) < 2:
                continue
            self.episodes.append({
                "session_idx": sess_idx,
                "ep_id": ep.get("id", f"ep_{len(self.episodes)}"),
                "global_frames": np.asarray(global_frames, dtype=np.int64),
            })
            n_eps += 1

        missing_note = f", {n_missing} images missing" if n_missing else ""
        print(f"  {sess.name}: {n_eps} eps, {len(local_to_global)} frames loaded{missing_note}")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return one sample dict (keys/shapes as listed in the module docstring).

        Past positions step one loaded frame at a time and are clamped at the
        episode start; future targets step ``frame_stride`` frames and are
        clamped at the episode end, with ``valid_mask`` False for clamped steps.
        """
        ep_idx, t = self.samples[idx]
        frames = self.episodes[ep_idx]["global_frames"]
        n_past, t_future, stride = self.n_past, self.t_future, self.s
        last_real = len(frames) - 1

        # Past indices: [t-(n_past-1), ..., t] clamped at 0 (repeat earliest)
        past_global = frames[[max(0, t - (n_past - 1 - i)) for i in range(n_past)]]
        # Future indices: [t+stride, ..., t+t_future*stride] clamped at last_real
        future_raw = [t + (i + 1) * stride for i in range(t_future)]
        future_global = frames[[min(i, last_real) for i in future_raw]]
        valid_mask = torch.tensor([raw <= last_real for raw in future_raw], dtype=torch.bool)

        cur = frames[t]
        tgt_pos = self.eef_pos[future_global]
        w2c = self.w2c_per_session[int(self.session[cur].item())]

        return {
            "rgb":               self.rgb_tensor[cur],
            "past_eef_world":    self.eef_pos[past_global],
            "current_eef_world": self.eef_pos[cur],
            "target_eef_world":  tgt_pos,
            "target_grip":       self.gripper[future_global],
            "target_rot_euler":  self.eef_euler[future_global],
            "target_voxel_idx":  world_to_voxel_idx(tgt_pos),
            "valid_mask":        valid_mask,
            "world_to_camera":   w2c,
            "ep_idx":            torch.tensor(ep_idx, dtype=torch.long),
            "start_t":           torch.tensor(t, dtype=torch.long),
        }


if __name__ == "__main__":
    # Smoke test: build the full dataset and print the shape of every field of sample 0.
    dataset = Smith300VolumeDataset(image_size=448, frame_stride=1)
    print("len(ds):", len(dataset))
    first = dataset[0]
    for name, value in first.items():
        described = tuple(value.shape) if hasattr(value, 'shape') else value
        print(f"  {name}: {described}")
