"""Panda real-robot dataset for PARA training.

Loads pre-cached 448x448 JPG frames + joint state NPYs, with FK and projection
pre-computed at init time for fast __getitem__.
"""
import os
import json
import glob

import cv2
import numpy as np
import torch
from torch.utils.data import Dataset
from scipy.spatial.transform import Rotation as ScipyR
import mujoco

# Number of timesteps in each trajectory window returned by the dataset.
N_WINDOW = 6
# Number of arm joints taken from the front of each joint-state vector.
N_ARM_JOINTS = 7
# Scale applied to the normalized gripper value before writing it to the two
# MuJoCo finger joints (presumably max finger opening in metres — confirm).
GRIPPER_POS_MAX = 0.04
# Side length (pixels) of the square pre-cached frames.
IMAGE_SIZE = 448

# Calibrated camera params
# 4x4 homogeneous transform applied to world-frame points as
# p_cam = R @ p_world + t (i.e. world -> camera).
T_CAM_WORLD = np.array([
    [ 0.94774869,  0.29251588,  0.12730624, -0.41554978],
    [ 0.24809099, -0.42493276, -0.87056476,  0.33529506],
    [-0.20055743,  0.85666015, -0.47530002,  1.1555837 ],
    [ 0.,          0.,          0.,          1.        ]], dtype=np.float64)

# Pinhole intrinsics (fx, fy, cx, cy) at the original IMG_W x IMG_H resolution.
CAM_K = np.array([
    [1372.7, 0.0,    956.9],
    [0.0,    1357.2, 555.0],
    [0.0,    0.0,    1.0  ]], dtype=np.float64)

# Original image resolution
IMG_W, IMG_H = 1920, 1080

# Camera intrinsics scaled to 448x448
# x and y are scaled independently (448/1920 vs 448/1080).
# NOTE(review): this assumes the cached frames were produced by a plain resize
# of the full frame (aspect ratio not preserved, no crop); confirm against the
# pre-cache script.
CAM_K_448 = np.array([
    [CAM_K[0, 0] * IMAGE_SIZE / IMG_W, 0.0, CAM_K[0, 2] * IMAGE_SIZE / IMG_W],
    [0.0, CAM_K[1, 1] * IMAGE_SIZE / IMG_H, CAM_K[1, 2] * IMAGE_SIZE / IMG_H],
    [0.0, 0.0, 1.0],
], dtype=np.float64)


def project_to_pixel(pos_world, T_cw, K):
    """Project a 3D world point into the image with a pinhole model.

    Args:
        pos_world: 3-vector point in world coordinates.
        T_cw: 4x4 homogeneous world->camera transform.
        K: 3x3 camera intrinsics matrix (only fx, fy, cx, cy are used).

    Returns:
        ``np.float32`` array ``[u, v]``, or ``None`` when the point lies on or
        behind the camera plane (non-positive depth).
    """
    rotation = T_cw[:3, :3]
    translation = T_cw[:3, 3]
    point_cam = rotation @ pos_world + translation
    cam_x, cam_y, depth = point_cam[0], point_cam[1], point_cam[2]

    # Points at or behind the image plane have no valid projection.
    if depth <= 0:
        return None

    fx, fy = K[0, 0], K[1, 1]
    cx, cy = K[0, 2], K[1, 2]
    pixel = (fx * cam_x / depth + cx, fy * cam_y / depth + cy)
    return np.array(pixel, dtype=np.float32)


class PandaTrajectoryDataset(Dataset):
    """Fast dataset for PARA training on real Panda data.

    Pre-computes FK, projection, and gripper values at init.
    __getitem__ only reads cached JPGs and assembles tensors.

    Expected layout of ``data_dir``:
        * ``{idx:06d}.npy`` — per-frame joint state (first ``N_ARM_JOINTS``
          entries are arm joints; an optional next entry is the gripper value).
        * ``cached_448/{idx:06d}.jpg`` — pre-cached frames.
        * ``episodes.json`` (unless ``episodes_json`` is given) with an
          ``"episodes"`` list of dicts holding inclusive ``"start"``/``"end"``
          frame indices.

    Args:
        data_dir: dataset root directory (see layout above).
        episodes_json: optional path to the episode annotation file.
        image_size: side length of the square output images and heatmaps.
            Cached frames whose size differs are resized, and the projection
            intrinsics are scaled to match.
        frame_stride: spacing in frames between consecutive window steps.

    Raises:
        ValueError: if ``frame_stride < 1`` or the MuJoCo model lacks a
            ``"hand"`` body.
        FileNotFoundError: if the ``cached_448`` directory is missing.
    """

    def __init__(self, data_dir, episodes_json=None, image_size=IMAGE_SIZE, frame_stride=1):
        if frame_stride < 1:
            raise ValueError(f"frame_stride must be >= 1, got {frame_stride}")

        self.data_dir = data_dir
        self.cache_dir = os.path.join(data_dir, "cached_448")
        self.image_size = image_size
        self.n_window = N_WINDOW
        self.frame_stride = frame_stride

        if not os.path.isdir(self.cache_dir):
            raise FileNotFoundError(
                f"Cached images not found at {self.cache_dir}. "
                "Run the pre-cache script first.")

        # Load episode annotations
        if episodes_json is None:
            episodes_json = os.path.join(data_dir, "episodes.json")
        with open(episodes_json) as f:
            ep_data = json.load(f)
        self.episodes = ep_data["episodes"]

        # Intrinsics for the output resolution; for the default image_size
        # these are numerically identical to CAM_K_448.
        self._cam_k_img = self._scaled_intrinsics(image_size)

        # Pre-compute per-frame: eef_pos, eef_quat, eef_euler, pixel_2d, gripper
        self.frame_data = self._precompute_frame_data()  # frame_idx -> dict

        # Build (episode index, start frame) samples
        self.samples = self._build_samples()

        # Normalized camera intrinsics
        self.cam_k_norm = CAM_K.copy().astype(np.float32)
        self.cam_k_norm[0] /= IMG_W
        self.cam_k_norm[1] /= IMG_H

        # Fixed extrinsics: computed once instead of on every __getitem__.
        self._world_to_camera = T_CAM_WORLD.astype(np.float32)
        self._camera_pose = np.linalg.inv(T_CAM_WORLD).astype(np.float32)

        # ImageNet normalization constants
        self.mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
        self.std = np.array([0.229, 0.224, 0.225], dtype=np.float32)

        print(f"PandaDataset: {len(self.episodes)} episodes, "
              f"{len(self.samples)} samples, stride={frame_stride}", flush=True)

    @staticmethod
    def _scaled_intrinsics(image_size):
        """Return CAM_K rescaled from IMG_W x IMG_H to image_size x image_size."""
        return np.array([
            [CAM_K[0, 0] * image_size / IMG_W, 0.0, CAM_K[0, 2] * image_size / IMG_W],
            [0.0, CAM_K[1, 1] * image_size / IMG_H, CAM_K[1, 2] * image_size / IMG_H],
            [0.0, 0.0, 1.0],
        ], dtype=np.float64)

    def _precompute_frame_data(self):
        """Run FK and projection for every annotated frame that has a joint-state file.

        Returns:
            dict mapping frame index -> dict with ``eef_pos`` (float32 xyz of
            the "hand" body), ``eef_quat`` (float32 xyzw), ``eef_euler``
            (float32 'xyz' Euler angles), ``pixel_2d`` (float32 pixel clipped
            to the image, zeros if behind the camera) and ``gripper``
            (float32, gripper value mapped 0 -> -1, 1 -> +1).
            Frames without an ``.npy`` file are omitted.
        """
        print("Pre-computing FK for all frames...", flush=True)
        from ExoConfigs.panda_exo_handeye_4x2 import PANDA_HANDEYE_4X2_CONFIG
        mj_model = mujoco.MjModel.from_xml_string(PANDA_HANDEYE_4X2_CONFIG.xml)
        mj_data = mujoco.MjData(mj_model)
        hand_id = mujoco.mj_name2id(mj_model, mujoco.mjtObj.mjOBJ_BODY, "hand")
        if hand_id < 0:
            # mj_name2id returns -1 when the name is unknown; indexing xpos
            # with -1 would silently read the wrong body.
            raise ValueError('MuJoCo model has no body named "hand"')
        has_finger_joints = mj_data.qpos.size >= N_ARM_JOINTS + 2
        max_px = self.image_size - 1

        # Collect all unique frame indices across episodes (end is inclusive)
        all_frame_indices = {
            idx
            for ep in self.episodes
            for idx in range(ep["start"], ep["end"] + 1)
        }

        frame_data = {}
        for idx in sorted(all_frame_indices):
            npy_path = os.path.join(self.data_dir, f"{idx:06d}.npy")
            if not os.path.exists(npy_path):
                continue
            js = np.load(npy_path).astype(np.float64)

            # FK: arm joints, plus both finger joints driven by the gripper value
            mj_data.qpos[:N_ARM_JOINTS] = js[:N_ARM_JOINTS]
            gw = js[N_ARM_JOINTS] if len(js) > N_ARM_JOINTS else 1.0
            if has_finger_joints:
                mj_data.qpos[N_ARM_JOINTS] = gw * GRIPPER_POS_MAX
                mj_data.qpos[N_ARM_JOINTS + 1] = gw * GRIPPER_POS_MAX
            mujoco.mj_forward(mj_model, mj_data)

            eef_pos = mj_data.xpos[hand_id].copy().astype(np.float32)
            quat_wxyz = mj_data.xquat[hand_id].copy()
            eef_quat = quat_wxyz[[1, 2, 3, 0]].astype(np.float32)  # xyzw (scipy order)
            eef_euler = ScipyR.from_quat(eef_quat).as_euler('xyz').astype(np.float32)

            # Project to output-resolution pixel coords
            pix = project_to_pixel(eef_pos.astype(np.float64), T_CAM_WORLD, self._cam_k_img)
            if pix is not None:
                pixel_2d = np.clip(pix, 0, max_px).astype(np.float32)
            else:
                pixel_2d = np.zeros(2, dtype=np.float32)

            frame_data[idx] = {
                "eef_pos": eef_pos,
                "eef_quat": eef_quat,
                "eef_euler": eef_euler,
                "pixel_2d": pixel_2d,
                "gripper": np.float32(2.0 * gw - 1.0),  # 0->-1, 1->+1
            }

        print(f"  Pre-computed {len(frame_data)} frames", flush=True)
        return frame_data

    def _build_samples(self):
        """List (episode index, start frame) for every window whose frames all have FK data."""
        window_len = (self.n_window - 1) * self.frame_stride + 1
        samples = []
        n_dropped = 0
        for ep_idx, ep in enumerate(self.episodes):
            ep_start, ep_end = ep["start"], ep["end"]
            ep_len = ep_end - ep_start + 1
            for t in range(ep_len - window_len + 1):
                start = ep_start + t
                # Exactly the frames __getitem__ will read for this window.
                window = range(start, start + window_len, self.frame_stride)
                if all(f in self.frame_data for f in window):
                    samples.append((ep_idx, start))
                else:
                    n_dropped += 1
        if n_dropped:
            print(f"  Skipped {n_dropped} windows with missing joint-state frames",
                  flush=True)
        return samples

    def _load_rgb(self, frame_idx):
        """Load a cached frame as float32 RGB in [0, 1], shape (image_size, image_size, 3)."""
        img_path = os.path.join(self.cache_dir, f"{frame_idx:06d}.jpg")
        bgr = cv2.imread(img_path, cv2.IMREAD_COLOR)
        if bgr is None:
            raise FileNotFoundError(f"Could not read cached frame {img_path}")
        if bgr.shape[0] != self.image_size or bgr.shape[1] != self.image_size:
            interp = cv2.INTER_AREA if bgr.shape[0] > self.image_size else cv2.INTER_LINEAR
            bgr = cv2.resize(bgr, (self.image_size, self.image_size), interpolation=interp)
        return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Assemble one ``n_window``-step trajectory sample as a dict of tensors.

        ``rgb`` is the ImageNet-normalized first frame in (3, H, W) layout;
        ``heatmap_target`` is one-hot per timestep with shape
        (n_window, image_size, image_size).
        """
        ep_idx, start_frame = self.samples[idx]
        frame_indices = [start_frame + k * self.frame_stride for k in range(self.n_window)]
        frames = [self.frame_data[f] for f in frame_indices]

        rgb_frames_raw = np.stack([self._load_rgb(f) for f in frame_indices])
        trajectory_2d = np.stack([fd["pixel_2d"] for fd in frames])
        trajectory_3d = np.stack([fd["eef_pos"] for fd in frames])
        trajectory_gripper = np.array([fd["gripper"] for fd in frames], dtype=np.float32)
        trajectory_quat = np.stack([fd["eef_quat"] for fd in frames])
        trajectory_euler = np.stack([fd["eef_euler"] for fd in frames])

        # Heatmap targets: one hot pixel per timestep at the rounded (x, y).
        # np.rint rounds half-to-even, matching Python's round().
        heatmap_targets = np.zeros(
            (self.n_window, self.image_size, self.image_size), dtype=np.float32)
        px = np.clip(np.rint(trajectory_2d), 0, self.image_size - 1).astype(np.int64)
        heatmap_targets[np.arange(self.n_window), px[:, 1], px[:, 0]] = 1.0

        # Normalize the reference (first) frame for DINO: HWC -> CHW.
        rgb_ref = rgb_frames_raw[0]
        rgb_t = (np.transpose(rgb_ref, (2, 0, 1)) - self.mean[:, None, None]) / self.std[:, None, None]

        return {
            "rgb": torch.from_numpy(rgb_t).float(),
            "heatmap_target": torch.from_numpy(heatmap_targets).float(),
            "trajectory_2d": torch.from_numpy(trajectory_2d).float(),
            "trajectory_3d": torch.from_numpy(trajectory_3d).float(),
            "trajectory_gripper": torch.from_numpy(trajectory_gripper).float(),
            "trajectory_quat": torch.from_numpy(trajectory_quat).float(),
            "trajectory_euler": torch.from_numpy(trajectory_euler).float(),
            "rgb_frames_raw": torch.from_numpy(rgb_frames_raw).float(),
            # Copies so callers mutating a sample cannot corrupt the cached arrays.
            "world_to_camera": torch.from_numpy(self._world_to_camera.copy()),
            "base_z": torch.tensor(0.0, dtype=torch.float32),
            "target_3d": torch.from_numpy(trajectory_3d[-1]).float(),
            "camera_pose": torch.from_numpy(self._camera_pose.copy()),
            "cam_K_norm": torch.from_numpy(self.cam_k_norm.copy()).float(),
            "demo_idx": torch.tensor(ep_idx, dtype=torch.long),
            "start_t": torch.tensor(start_frame, dtype=torch.long),
        }


if __name__ == "__main__":
    import time

    ds = PandaTrajectoryDataset(
        "/data/cameron/panda_data/data_20260420_115853_632_frames")
    print(f"Dataset size: {len(ds)}")

    # Benchmark at most 20 samples and report the number actually loaded, so
    # the throughput is correct for small datasets and empty ones don't crash.
    n_bench = min(20, len(ds))
    t0 = time.perf_counter()
    for i in range(n_bench):
        _ = ds[i]
    elapsed = time.perf_counter() - t0
    if n_bench == 0:
        print("Dataset is empty; nothing to benchmark.")
    else:
        rate = n_bench / elapsed if elapsed > 0 else float("inf")
        print(f"{n_bench} samples in {elapsed:.2f}s = {rate:.1f} samples/sec")
