"""LIBERO-native dataset for PARA training.

Two dataset classes:
  RealTrajectoryDataset    — renders on-the-fly via OffScreenRenderEnv (slow, no multiprocessing)
  CachedTrajectoryDataset  — loads pre-rendered frames from disk (fast, supports num_workers>0)

Use prerender_dataset.py to generate the cache before using CachedTrajectoryDataset.
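
Typical usage (a sketch; the cache path and task ids below are placeholders):

    from torch.utils.data import DataLoader

    ds = CachedTrajectoryDataset("/path/to/prerender_cache",
                                 benchmark_name="libero_spatial", task_ids=[0])
    loader = DataLoader(ds, batch_size=8, shuffle=True, num_workers=4)
    batch = next(iter(loader))
    # batch["rgb"]:            (8, 3, 448, 448) ImageNet-normalized observations
    # batch["heatmap_target"]: (8, N_WINDOW, 448, 448) one-hot EEF pixel targets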
"""

import os
from pathlib import Path

import cv2
import h5py
import numpy as np
import torch
from torch.utils.data import Dataset
from scipy.spatial.transform import Rotation as ScipyR

from libero.libero import benchmark, get_libero_path
from libero.libero.envs import OffScreenRenderEnv
from robosuite.utils.camera_utils import (
    get_camera_extrinsic_matrix,
    get_camera_intrinsic_matrix,
    get_camera_transform_matrix,
    project_points_from_world_to_camera,
)

N_WINDOW = 4
MIN_GRIPPER = -1.0   # gripper action space: -1 = open, +1 = close
MAX_GRIPPER =  1.0


def process_gripper_value(gripper_value):
    """Clamp a raw gripper value into the action range [MIN_GRIPPER, MAX_GRIPPER]."""
    return max(MIN_GRIPPER, min(MAX_GRIPPER, float(gripper_value)))


def project_3d_to_2d(point_3d, camera_pose, cam_k):
    """Project a world-frame 3D point to (u, v) pixel coordinates.

    `camera_pose` must be a 4x4 *world-to-camera* transform in OpenCV
    convention (+z pointing into the scene) and `cam_k` the 3x3 intrinsic
    matrix. Returns None for points at or behind the camera plane.
    """
    p = np.append(np.asarray(point_3d, dtype=np.float64), 1.0)
    p_cam = camera_pose @ p
    if p_cam[2] <= 0:
        return None
    pix = cam_k @ p_cam[:3]
    return pix[:2] / pix[2]
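

# Minimal self-check of the projection convention above (a hypothetical pinhole
# camera, not a LIBERO calibration): with an identity world-to-camera transform,
# a point 2 m straight ahead must land on the principal point, and points
# behind the camera are rejected.
def _check_projection_conventions():
    cam_k = np.array([[100.0, 0.0, 50.0],
                      [0.0, 100.0, 50.0],
                      [0.0, 0.0, 1.0]])
    uv = project_3d_to_2d([0.0, 0.0, 2.0], np.eye(4), cam_k)
    assert np.allclose(uv, [50.0, 50.0]), uv
    assert project_3d_to_2d([0.0, 0.0, -1.0], np.eye(4), cam_k) is None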


class RealTrajectoryDataset(Dataset):
    def __init__(
        self,
        dataset_root=None,   # accepted but unused
        image_size=448,
        benchmark_name="libero_spatial",
        task_id=0,
        camera="agentview",
        max_demos=10,
        frame_stride=3,  # temporal stride inside each window; stride=3 @ 20 Hz → ~6.7 Hz, so N_WINDOW=4 spans ~0.45 s
    ):
        self.image_size = image_size
        self.n_window = N_WINDOW
        self.camera = camera
        self.frame_stride = frame_stride

        bench = benchmark.get_benchmark_dict()[benchmark_name]()
        task = bench.get_task(task_id)
        self.demo_path = os.path.join(get_libero_path("datasets"), bench.get_task_demonstration(task_id))
        if not os.path.isfile(self.demo_path):
            raise FileNotFoundError(f"LIBERO demo file not found: {self.demo_path}")

        with h5py.File(self.demo_path, "r") as f:
            demos = sorted(
                (k for k in f["data"].keys() if k.startswith("demo_")),
                key=lambda k: int(k.split("_")[-1]),  # numeric sort: demo_2 before demo_10
            )
            if max_demos is not None:
                demos = demos[: max(1, int(max_demos))]
            self.demo_states  = [f[f"data/{k}/states"][()] for k in demos]
            self.demo_actions = [f[f"data/{k}/actions"][()] for k in demos]
            self.demo_names   = demos

        bddl_file = os.path.join(get_libero_path("bddl_files"), task.problem_folder, task.bddl_file)
        self.env = OffScreenRenderEnv(
            bddl_file_name=bddl_file,
            camera_heights=image_size,
            camera_widths=image_size,
            camera_names=[camera],
        )
        self.env.seed(0)
        self.env.reset()

        # Query robot base z once (static in simulation). body_name2id raises
        # for unknown names rather than returning a sentinel, so guard it.
        try:
            base_body_id = self.env.env.sim.model.body_name2id("robot0_base")
            self.base_z = float(self.env.env.sim.data.xpos[base_body_id][2])
        except (ValueError, KeyError):
            self.base_z = 0.0

        self.samples = []
        for d_idx, states in enumerate(self.demo_states):
            for t in range(states.shape[0]):
                self.samples.append((d_idx, t))

        print(f"Loaded {len(self.demo_states)} demos from {Path(self.demo_path).name}")
        print(f"Created {len(self.samples)} samples (one per state)")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        demo_idx, start_t = self.samples[idx]
        states  = self.demo_states[demo_idx]
        actions = self.demo_actions[demo_idx]

        trajectory_2d = []
        trajectory_3d = []
        trajectory_gripper = []
        trajectory_quat = []
        rgb_frames_raw = []   # per-timestep flipud float [0,1] images for visualization
        rgb_ref = None
        cam_pose_ref = None
        cam_k_norm_ref = None
        world_to_camera_ref = None

        img_key = f"{self.camera}_image"  # holds for every camera name, including robot0_eye_in_hand
        for k in range(self.n_window):
            t = min(start_t + k * self.frame_stride, states.shape[0] - 1)
            obs = self.env.set_init_state(states[t])
            self.env.env.sim.forward()

            rgb = np.asarray(obs[img_key]).copy()
            if rgb.dtype != np.float32:
                rgb = rgb.astype(np.float32)
                if rgb.max() > 1.0:
                    rgb = rgb / 255.0
            rgb = np.ascontiguousarray(np.flipud(rgb))
            if rgb.shape[0] != self.image_size or rgb.shape[1] != self.image_size:
                rgb = cv2.resize(rgb, (self.image_size, self.image_size), interpolation=cv2.INTER_LINEAR)
            if rgb_ref is None:
                rgb_ref = rgb

            eef_pos  = np.asarray(obs["robot0_eef_pos"],  dtype=np.float64)
            eef_quat = np.asarray(obs["robot0_eef_quat"], dtype=np.float64)
            # Use demo action[:, 6] (clean -1/+1 signal) instead of gripper_qpos
            # whose symmetric fingers always average to ~0.
            grip = float(actions[min(t, len(actions) - 1), 6])
            trajectory_3d.append(eef_pos)
            trajectory_quat.append(eef_quat)
            trajectory_gripper.append(np.clip(grip, MIN_GRIPPER, MAX_GRIPPER))
            rgb_frames_raw.append(rgb.copy())  # flipud float [0,1], after resize

            h, w = rgb.shape[:2]
            world_to_camera = get_camera_transform_matrix(self.env.env.sim, self.camera, h, w)
            pix_rc = project_points_from_world_to_camera(
                points=eef_pos.reshape(1, 3),
                world_to_camera_transform=world_to_camera,
                camera_height=h,
                camera_width=w,
            )[0]
            v_raw, u_raw = int(pix_rc[0]), int(pix_rc[1])
            # v_raw is the correct row on flipud(obs_img) — no additional flip needed.
            # (debug_libero_projection.py draws at (u, v) directly on np.flipud(obs_img) and it's correct)
            trajectory_2d.append(
                np.array(
                    [
                        np.clip(float(u_raw), 0, w - 1),
                        np.clip(float(v_raw), 0, h - 1),
                    ],
                    dtype=np.float32,
                )
            )

            if cam_pose_ref is None:
                cam_pose_ref = get_camera_extrinsic_matrix(self.env.env.sim, self.camera).astype(np.float32)
                cam_k = get_camera_intrinsic_matrix(self.env.env.sim, self.camera, h, w).astype(np.float32)
                cam_k_norm = cam_k.copy()
                cam_k_norm[0] /= float(w)
                cam_k_norm[1] /= float(h)
                cam_k_norm_ref = cam_k_norm
                world_to_camera_ref = get_camera_transform_matrix(
                    self.env.env.sim, self.camera, h, w
                ).astype(np.float32)

        trajectory_2d = np.asarray(trajectory_2d, dtype=np.float32)
        trajectory_3d = np.asarray(trajectory_3d, dtype=np.float32)
        trajectory_gripper = np.asarray(trajectory_gripper, dtype=np.float32)
        trajectory_quat = np.asarray(trajectory_quat, dtype=np.float32)
        # Convert quaternions (xyzw) → euler XYZ angles in radians
        trajectory_euler = np.stack([
            ScipyR.from_quat(q).as_euler('xyz') for q in trajectory_quat
        ], axis=0).astype(np.float32)  # (N, 3)
        rgb_frames_raw = np.stack(rgb_frames_raw, axis=0).astype(np.float32)  # (N, H, W, 3)

        heatmap_targets = []
        for t in range(self.n_window):
            x, y = trajectory_2d[t]
            x_i = int(np.clip(round(float(x)), 0, self.image_size - 1))
            y_i = int(np.clip(round(float(y)), 0, self.image_size - 1))
            hm = np.zeros((self.image_size, self.image_size), dtype=np.float32)
            hm[y_i, x_i] = 1.0
            heatmap_targets.append(hm)
        heatmap_targets = np.stack(heatmap_targets, axis=0)
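        # Readout sketch (hypothetical; mirrors decoding a predicted heatmap):
        #   y, x = np.unravel_index(hm.argmax(), hm.shape)  # peak back to pixel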

        rgb_t = torch.from_numpy(rgb_ref).permute(2, 0, 1).float()
        mean = torch.tensor([0.485, 0.456, 0.406], dtype=torch.float32).view(3, 1, 1)
        std = torch.tensor([0.229, 0.224, 0.225], dtype=torch.float32).view(3, 1, 1)
        rgb_t = (rgb_t - mean) / std

        return {
            "rgb": rgb_t,
            "heatmap_target": torch.from_numpy(heatmap_targets).float(),
            "trajectory_2d": torch.from_numpy(trajectory_2d).float(),
            "trajectory_3d": torch.from_numpy(trajectory_3d).float(),
            "trajectory_gripper": torch.from_numpy(trajectory_gripper).float(),
            "trajectory_quat": torch.from_numpy(trajectory_quat).float(),        # (N, 4) EEF quaternions xyzw
            "trajectory_euler": torch.from_numpy(trajectory_euler).float(),      # (N, 3) euler XYZ radians
            "rgb_frames_raw": torch.from_numpy(rgb_frames_raw).float(),         # (N, H, W, 3) float [0,1]
            "world_to_camera": torch.from_numpy(world_to_camera_ref).float(),   # (4, 4) for projection
            "base_z": torch.tensor(self.base_z, dtype=torch.float32),           # robot base Z for vis
            "target_3d": torch.from_numpy(trajectory_3d[-1]).float(),
            "camera_pose": torch.from_numpy(cam_pose_ref).float(),
            "cam_K_norm": torch.from_numpy(cam_k_norm_ref).float(),
            "demo_idx": torch.tensor(demo_idx, dtype=torch.long),
            "start_t": torch.tensor(start_t, dtype=torch.long),
        }
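
# Visualization sketch (hypothetical, not part of the training loop): the cached
# "world_to_camera" lets you re-project "trajectory_3d" back onto the returned
# frames for debugging:
#
#   sample = RealTrajectoryDataset(max_demos=1)[0]
#   pix_rc = project_points_from_world_to_camera(
#       sample["trajectory_3d"].numpy(), sample["world_to_camera"].numpy(),
#       sample["rgb_frames_raw"].shape[1], sample["rgb_frames_raw"].shape[2],
#   )  # (N, 2) rows of (row, col) on the flipud frames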


# ---------------------------------------------------------------------------
# CachedTrajectoryDataset — loads pre-rendered frames from disk
# ---------------------------------------------------------------------------

IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD  = np.array([0.229, 0.224, 0.225], dtype=np.float32)


class CachedTrajectoryDataset(Dataset):
    """Dataset that loads pre-rendered LIBERO frames from disk.

    Run libero/prerender_dataset.py first to generate the cache.

    Expected layout under cache_root:
        <benchmark>/task_<id>/demo_<idx>/
            frames/000000.png ...
            eef_pos.npy   eef_quat.npy   gripper.npy   pix_uv.npy
            cam_extrinsic.npy   cam_K_norm.npy   world_to_cam.npy   base_z.npy

    Optional, picked up automatically when present:
        <benchmark>/task_<id>_clip.pt          (per-task CLIP text embedding)
        <benchmark>/task_<id>_description.txt  (natural-language task string)
        demo_<idx>/wrist_frames/ + wrist_*.npy (wrist-camera stream)

    __getitem__ is pure I/O + numpy — safe for num_workers > 0.
    """

    def __init__(
        self,
        cache_root,
        benchmark_name="libero_spatial",
        task_ids=None,       # list of ints, or None = all found
        image_size=448,
        n_window=N_WINDOW,
        frame_stride=3,
        max_demos=0,         # 0 = all demos, >0 = limit per task
        augment=False,       # apply viewpoint augmentations
    ):
        self.image_size   = image_size
        self.n_window     = n_window
        self.frame_stride = frame_stride
        self.augment      = augment

        bench_root = Path(cache_root) / benchmark_name
        if not bench_root.exists():
            raise FileNotFoundError(f"Cache not found: {bench_root}")

        task_dirs = sorted(bench_root.glob("task_*"))
        if task_ids is not None:
            task_dirs = [d for d in task_dirs if int(d.name.split("_")[1]) in task_ids]

        # Load all demo metadata into memory (npy arrays only, not images)
        self.demos   = []
        self.samples = []

        # Preload CLIP embeddings per task (if available)
        self.clip_embeddings = {}  # task_id → (D_clip,) tensor
        for clip_path in sorted(bench_root.glob("task_*_clip.pt")):
            tid = int(clip_path.stem.split("_")[1])
            self.clip_embeddings[tid] = torch.load(clip_path, map_location="cpu")
        if self.clip_embeddings:
            print(f"  Loaded CLIP embeddings for {len(self.clip_embeddings)} tasks")

        # Preload task descriptions per task (if available — used by VLA models)
        self.task_descriptions = {}  # task_id → str
        for desc_path in sorted(bench_root.glob("task_*_description.txt")):
            tid = int(desc_path.stem.split("_")[1])
            self.task_descriptions[tid] = desc_path.read_text().strip()
        if self.task_descriptions:
            print(f"  Loaded task descriptions for {len(self.task_descriptions)} tasks")

        for task_dir in task_dirs:
            task_id = int(task_dir.name.split("_")[1])
            clip_emb = self.clip_embeddings.get(task_id, None)
            task_desc = self.task_descriptions.get(task_id, f"task {task_id}")
            task_demo_count = 0
            for demo_dir in sorted(task_dir.glob("demo_*")):
                if max_demos > 0 and task_demo_count >= max_demos:
                    break
                frames_dir  = demo_dir / "frames"
                if not frames_dir.exists():
                    continue
                frame_paths = sorted(frames_dir.glob("*.png"))
                if not frame_paths:
                    continue
                T = len(frame_paths)

                # Load wrist camera data if available
                wrist_frames_dir = demo_dir / "wrist_frames"
                has_wrist = wrist_frames_dir.exists() and (demo_dir / "wrist_pix_uv.npy").exists()

                demo = {
                    "frame_paths":   frame_paths,
                    "eef_pos":       np.load(demo_dir / "eef_pos.npy"),
                    "eef_quat":      np.load(demo_dir / "eef_quat.npy"),
                    "gripper":       np.load(demo_dir / "gripper.npy"),
                    "pix_uv":        np.load(demo_dir / "pix_uv.npy"),
                    "cam_extrinsic": np.load(demo_dir / "cam_extrinsic.npy"),
                    "cam_K_norm":    np.load(demo_dir / "cam_K_norm.npy"),
                    "world_to_cam":  np.load(demo_dir / "world_to_cam.npy"),
                    "base_z":        float(np.load(demo_dir / "base_z.npy")),
                    "T":             T,
                    "clip_embedding": clip_emb,
                    "task_description": task_desc,
                    # Wrist camera (per-frame extrinsics since camera moves)
                    "has_wrist":          has_wrist,
                    "wrist_frame_paths":  sorted(wrist_frames_dir.glob("*.png")) if has_wrist else None,
                    "wrist_pix_uv":       np.load(demo_dir / "wrist_pix_uv.npy") if has_wrist else None,
                    "wrist_extrinsics":   np.load(demo_dir / "wrist_extrinsics.npy") if has_wrist else None,
                    "wrist_cam_K_norm":   np.load(demo_dir / "wrist_cam_K_norm.npy") if has_wrist else None,
                    "wrist_w2c":          np.load(demo_dir / "wrist_w2c.npy") if has_wrist else None,
                }
                demo_idx = len(self.demos)
                self.demos.append(demo)
                for t in range(T):
                    self.samples.append((demo_idx, t))
                task_demo_count += 1

        print(f"CachedTrajectoryDataset: {len(self.demos)} demos, {len(self.samples)} samples")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        demo_idx, start_t = self.samples[idx]
        demo = self.demos[demo_idx]
        T    = demo["T"]

        trajectory_2d      = []
        trajectory_3d      = []
        trajectory_gripper = []
        trajectory_quat    = []
        rgb_frames_raw     = []
        rgb_ref            = None

        # --- Viewpoint augmentation: random perspective warp ---
        # Sample once per trajectory so all frames get the same warp.
        if self.augment:
            h_strength = np.random.uniform(-0.15, 0.15)
            v_strength = np.random.uniform(-0.15, 0.15)
            H_img, W_img = self.image_size, self.image_size
            src_pts = np.float32([[0,0],[W_img,0],[W_img,H_img],[0,H_img]])
            dst_pts = np.float32([
                [h_strength*W_img + v_strength*W_img/3,
                 v_strength*H_img + h_strength*H_img/3],
                [W_img - h_strength*W_img - v_strength*W_img/3,
                 v_strength*H_img - h_strength*H_img/3],
                [W_img + h_strength*W_img + v_strength*W_img/3,
                 H_img - v_strength*H_img + h_strength*H_img/3],
                [-h_strength*W_img - v_strength*W_img/3,
                 H_img - v_strength*H_img - h_strength*H_img/3],
            ])
            persp_M = cv2.getPerspectiveTransform(src_pts, dst_pts)
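            # cv2.warpPerspective fills dst(x, y) = src(inv(persp_M) @ (x, y)), so a
            # source-frame point p lands at persp_M @ [p, 1] in the warped frame;
            # the 2D labels are pushed through the same homography further below.
            # Camera matrices (world_to_cam, cam_K_norm) are NOT adjusted, so 3D
            # reprojection is only exact on unaugmented samples.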

        for k in range(self.n_window):
            t = min(start_t + k * self.frame_stride, T - 1)

            bgr = cv2.imread(str(demo["frame_paths"][t]))
            if bgr is None:
                raise FileNotFoundError(f"Failed to read cached frame: {demo['frame_paths'][t]}")
            rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
            if rgb.shape[0] != self.image_size or rgb.shape[1] != self.image_size:
                rgb = cv2.resize(rgb, (self.image_size, self.image_size), interpolation=cv2.INTER_LINEAR)
            if self.augment:
                rgb = cv2.warpPerspective(rgb, persp_M, (W_img, H_img),
                                          borderMode=cv2.BORDER_REFLECT_101)
            if rgb_ref is None:
                rgb_ref = rgb
            rgb_frames_raw.append(rgb)

            trajectory_3d.append(demo["eef_pos"][t].astype(np.float64))
            trajectory_quat.append(demo["eef_quat"][t].astype(np.float64))
            trajectory_gripper.append(float(np.clip(demo["gripper"][t], MIN_GRIPPER, MAX_GRIPPER)))
            trajectory_2d.append(demo["pix_uv"][t].astype(np.float32))

        trajectory_2d      = np.stack(trajectory_2d,      axis=0).astype(np.float32)
        if self.augment:
            # Keep the 2D targets aligned with the warped frames: push the points
            # through the same homography, then clamp back into the image.
            pts = trajectory_2d.reshape(-1, 1, 2).astype(np.float64)
            trajectory_2d = cv2.perspectiveTransform(pts, persp_M)
            trajectory_2d = np.clip(trajectory_2d.reshape(-1, 2), 0.0, self.image_size - 1.0).astype(np.float32)
        trajectory_3d      = np.stack(trajectory_3d,      axis=0).astype(np.float32)
        trajectory_gripper = np.array(trajectory_gripper,         dtype=np.float32)
        trajectory_quat    = np.stack(trajectory_quat,    axis=0).astype(np.float32)
        trajectory_euler   = np.stack([
            ScipyR.from_quat(q).as_euler('xyz') for q in trajectory_quat
        ], axis=0).astype(np.float32)
        # Delta axis-angle from reference rotation (avoids euler wrapping)
        # Local import: keeps this module importable without pulling in model.py
        # at load time. REF_ROTATION_QUAT must be xyzw (scipy's convention).
        import model as _model_module
        ref_quat = np.array(_model_module.REF_ROTATION_QUAT, dtype=np.float64)
        ref_rot = ScipyR.from_quat(ref_quat)
        trajectory_delta_rotvec = np.stack([
            (ref_rot.inv() * ScipyR.from_quat(q)).as_rotvec() for q in trajectory_quat
        ], axis=0).astype(np.float32)
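        # Decoding sketch (hypothetical): a predicted delta rotvec v maps back to
        # an absolute orientation via (ref_rot * ScipyR.from_rotvec(v)).as_quat().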
        rgb_frames_raw = np.stack(rgb_frames_raw, axis=0).astype(np.float32)

        heatmap_targets = []
        for t in range(self.n_window):
            x, y = trajectory_2d[t]
            x_i  = int(np.clip(round(float(x)), 0, self.image_size - 1))
            y_i  = int(np.clip(round(float(y)), 0, self.image_size - 1))
            hm   = np.zeros((self.image_size, self.image_size), dtype=np.float32)
            hm[y_i, x_i] = 1.0
            heatmap_targets.append(hm)
        heatmap_targets = np.stack(heatmap_targets, axis=0)

        rgb_t = torch.from_numpy(rgb_ref).permute(2, 0, 1).float()
        mean  = torch.tensor(IMAGENET_MEAN, dtype=torch.float32).view(3, 1, 1)
        std   = torch.tensor(IMAGENET_STD,  dtype=torch.float32).view(3, 1, 1)
        rgb_t = (rgb_t - mean) / std

        # Wrist camera data (if available)
        has_wrist = demo.get("has_wrist", False)
        if has_wrist:
            wrist_t = min(start_t, demo["T"] - 1)
            wrist_bgr = cv2.imread(str(demo["wrist_frame_paths"][wrist_t]))
            wrist_rgb = cv2.cvtColor(wrist_bgr, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
            if wrist_rgb.shape[0] != self.image_size or wrist_rgb.shape[1] != self.image_size:
                wrist_rgb = cv2.resize(wrist_rgb, (self.image_size, self.image_size), interpolation=cv2.INTER_LINEAR)
            wrist_rgb_t = torch.from_numpy(wrist_rgb).permute(2, 0, 1).float()
            wrist_rgb_t = (wrist_rgb_t - mean) / std

            # Wrist trajectory 2D: project FUTURE EEF positions onto wrist camera AT start_t
            # (not onto their own frame's wrist camera, which is always ~the same pixel)
            wrist_w2c_start = demo["wrist_w2c"][wrist_t]  # (4, 4) at the observation frame
            wrist_traj_2d = []
            wrist_in_view = []
            for k in range(self.n_window):
                t = min(start_t + k * self.frame_stride, demo["T"] - 1)
                eef_3d = demo["eef_pos"][t].astype(np.float64)
                pix_rc = project_points_from_world_to_camera(
                    eef_3d.reshape(1, 3), wrist_w2c_start,
                    self.image_size, self.image_size
                )[0]
                uv = np.array([pix_rc[1], pix_rc[0]], dtype=np.float32)  # [col, row]
                wrist_traj_2d.append(uv)
                margin = 5
                in_view = (margin <= uv[0] < self.image_size - margin and
                           margin <= uv[1] < self.image_size - margin)
                wrist_in_view.append(float(in_view))
            wrist_traj_2d = np.stack(wrist_traj_2d, axis=0)
            wrist_in_view = np.array(wrist_in_view, dtype=np.float32)

            # Per-frame wrist extrinsics at start_t
            wrist_cam_extrinsic = demo["wrist_extrinsics"][wrist_t].astype(np.float32)
            wrist_cam_K_norm = demo["wrist_cam_K_norm"].astype(np.float32)
            wrist_w2c = demo["wrist_w2c"][wrist_t].astype(np.float32)
        else:
            wrist_rgb_t = torch.zeros_like(rgb_t)
            wrist_traj_2d = np.zeros_like(trajectory_2d)
            wrist_in_view = np.zeros(self.n_window, dtype=np.float32)
            wrist_cam_extrinsic = np.eye(4, dtype=np.float32)
            wrist_cam_K_norm = np.eye(3, dtype=np.float32)
            wrist_w2c = np.eye(4, dtype=np.float32)

        return {
            "rgb":                rgb_t,
            "heatmap_target":     torch.from_numpy(heatmap_targets).float(),
            "trajectory_2d":      torch.from_numpy(trajectory_2d).float(),
            "trajectory_3d":      torch.from_numpy(trajectory_3d).float(),
            "trajectory_gripper": torch.from_numpy(trajectory_gripper).float(),
            "trajectory_quat":    torch.from_numpy(trajectory_quat).float(),
            "trajectory_euler":   torch.from_numpy(trajectory_euler).float(),
            "rgb_frames_raw":     torch.from_numpy(rgb_frames_raw).float(),
            "world_to_camera":    torch.from_numpy(demo["world_to_cam"]).float(),
            "base_z":             torch.tensor(demo["base_z"], dtype=torch.float32),
            "target_3d":          torch.from_numpy(trajectory_3d[-1]).float(),
            "camera_pose":        torch.from_numpy(demo["cam_extrinsic"]).float(),
            "cam_K_norm":         torch.from_numpy(demo["cam_K_norm"]).float(),
            "trajectory_delta_rotvec": torch.from_numpy(trajectory_delta_rotvec).float(),
            "demo_idx":           torch.tensor(demo_idx, dtype=torch.long),
            "start_t":            torch.tensor(start_t,  dtype=torch.long),
            "clip_embedding":     demo["clip_embedding"] if demo["clip_embedding"] is not None else torch.zeros(512),
            "task_description":   demo["task_description"],
            # Wrist camera
            "has_wrist":              torch.tensor(has_wrist, dtype=torch.bool),
            "wrist_rgb":              wrist_rgb_t,
            "wrist_trajectory_2d":    torch.from_numpy(wrist_traj_2d).float(),
            "wrist_camera_pose":      torch.from_numpy(wrist_cam_extrinsic).float(),
            "wrist_cam_K_norm":       torch.from_numpy(wrist_cam_K_norm).float(),
            "wrist_world_to_camera":  torch.from_numpy(wrist_w2c).float(),
            "wrist_in_view":          torch.from_numpy(wrist_in_view).float(),
        }

