"""prerender_dataset.py — Pre-render LIBERO demos to disk for fast training.

Replaces the per-step OffScreenRenderEnv calls in RealTrajectoryDataset with
a one-time render pass. Output can be used by CachedTrajectoryDataset in data.py.

Output layout:
    <out_root>/
      libero_spatial/
        task_0/
          demo_0/
            frames/  000000.png  000001.png  ...   (flipud uint8 RGB at image_size)
            eef_pos.npy          (T, 3)  world-frame EEF position
            eef_quat.npy         (T, 4)  EEF quaternion xyzw
            gripper.npy          (T,)    gripper value from actions[:, 6]
            pix_uv.npy           (T, 2)  projected 2D pixel (u, v) in training image space
            cam_extrinsic.npy    (4, 4)  camera→world  (frame 0, static for agentview)
            cam_K_norm.npy       (3, 3)  normalized intrinsics (frame 0)
            world_to_cam.npy     (4, 4)  world→camera (frame 0)
            base_z.npy           ()      robot base Z scalar
            actions.npy          (T, 7)  full demo action array

Usage:
    python libero/prerender_dataset.py \
        --benchmark libero_spatial \
        --out_root /data/libero/parsed_libero \
        --image_size 448 \
        --camera agentview
"""

import argparse
import os
import sys
from pathlib import Path

import cv2
import h5py
import numpy as np
from tqdm import tqdm

# Make sibling modules importable when this file is run as a script.
sys.path.insert(0, os.path.dirname(__file__))

from libero.libero import benchmark as bm_lib, get_libero_path
from libero.libero.envs import OffScreenRenderEnv
from robosuite.utils.camera_utils import (
    get_camera_extrinsic_matrix,
    get_camera_intrinsic_matrix,
    get_camera_transform_matrix,
    project_points_from_world_to_camera,
)


def render_demo(env, states, actions, camera, image_size, out_dir):
    """Render one demo's stored sim states and save frames + per-step arrays.

    For each state in ``states`` the sim is reset to that state, the camera
    image is rendered, flipped, resized to ``image_size`` and written as a PNG,
    and the EEF pose / gripper value / projected 2D pixel are recorded.
    Camera parameters are captured once at frame 0 (static for agentview).

    Args:
        env: OffScreenRenderEnv wrapping a robosuite sim.
        states: (T, state_dim) array of stored MuJoCo states.
        actions: (T, 7) demo actions; column 6 is the gripper command.
        camera: camera name to render/project from (e.g. "agentview").
        image_size: square output resolution in pixels.
        out_dir: destination directory for this demo (created if missing).
    """
    out_dir = Path(out_dir)
    frames_dir = out_dir / "frames"
    frames_dir.mkdir(parents=True, exist_ok=True)

    T = states.shape[0]
    eef_pos_list  = np.zeros((T, 3), dtype=np.float32)
    eef_quat_list = np.zeros((T, 4), dtype=np.float32)
    gripper_list  = np.zeros((T,),   dtype=np.float32)
    pix_uv_list   = np.zeros((T, 2), dtype=np.float32)

    cam_extrinsic = None
    cam_K_norm    = None
    world_to_cam  = None
    base_z        = None

    # The original conditional produced the identical string for both
    # branches ("robot0_eye_in_hand_image"), so the plain f-string covers
    # every camera name.
    img_key = f"{camera}_image"
    h, w = image_size, image_size  # loop-invariant: hoisted out of the loop

    for t in range(T):
        obs = env.set_init_state(states[t])
        env.env.sim.forward()

        # RGB → flipud → resize → save as PNG.
        rgb = np.asarray(obs[img_key]).copy()
        # Heuristic: float images in [0, 1] get scaled to uint8.
        # NOTE(review): an all-black uint8 frame (max <= 1) would also be
        # scaled — harmless for black pixels, but confirm the obs dtype.
        if rgb.max() <= 1.0:
            rgb = (rgb * 255).astype(np.uint8)
        rgb = np.ascontiguousarray(np.flipud(rgb))  # renderer output is upside-down
        if rgb.shape[0] != image_size or rgb.shape[1] != image_size:
            rgb = cv2.resize(rgb, (image_size, image_size), interpolation=cv2.INTER_LINEAR)
        # OpenCV expects BGR on write, so convert from RGB first.
        cv2.imwrite(str(frames_dir / f"{t:06d}.png"), cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR))

        # EEF state.
        eef_pos  = np.asarray(obs["robot0_eef_pos"],  dtype=np.float64)
        eef_quat = np.asarray(obs["robot0_eef_quat"], dtype=np.float64)
        eef_pos_list[t]  = eef_pos.astype(np.float32)
        eef_quat_list[t] = eef_quat.astype(np.float32)
        # Demos may hold one more state than action; clamp for the last frame.
        gripper_list[t]  = float(np.clip(actions[min(t, len(actions) - 1), 6], -1.0, 1.0))

        # Project the world-frame EEF position into training-image pixels.
        wtc = get_camera_transform_matrix(env.env.sim, camera, h, w)
        pix_rc = project_points_from_world_to_camera(
            points=eef_pos.reshape(1, 3),
            world_to_camera_transform=wtc,
            camera_height=h,
            camera_width=w,
        )[0]
        # robosuite returns (row, col) = (v, u); store as (u, v), clipped.
        v_raw, u_raw = float(pix_rc[0]), float(pix_rc[1])
        pix_uv_list[t] = [np.clip(u_raw, 0, w - 1), np.clip(v_raw, 0, h - 1)]

        # Camera params (static for agentview — save once from frame 0).
        if cam_extrinsic is None:
            cam_extrinsic = get_camera_extrinsic_matrix(env.env.sim, camera).astype(np.float32)
            cam_K         = get_camera_intrinsic_matrix(env.env.sim, camera, h, w).astype(np.float32)
            cam_K_norm    = cam_K.copy()
            cam_K_norm[0] /= float(w)  # normalize fx, cx by width
            cam_K_norm[1] /= float(h)  # normalize fy, cy by height
            world_to_cam  = wtc.astype(np.float32)
            # body_name2id raises ValueError for unknown bodies and never
            # returns a negative id, so the original `>= 0` guard was dead
            # code; realize its fall-back intent with an explicit except.
            try:
                base_body_id = env.env.sim.model.body_name2id("robot0_base")
                base_z = float(env.env.sim.data.xpos[base_body_id][2])
            except (ValueError, KeyError):
                base_z = 0.0

    np.save(out_dir / "eef_pos.npy",        eef_pos_list)
    np.save(out_dir / "eef_quat.npy",       eef_quat_list)
    np.save(out_dir / "gripper.npy",        gripper_list)
    np.save(out_dir / "pix_uv.npy",         pix_uv_list)
    np.save(out_dir / "cam_extrinsic.npy",  cam_extrinsic)
    np.save(out_dir / "cam_K_norm.npy",     cam_K_norm)
    np.save(out_dir / "world_to_cam.npy",   world_to_cam)
    np.save(out_dir / "base_z.npy",         np.float32(base_z))
    np.save(out_dir / "actions.npy",        actions.astype(np.float32))


def main():
    """CLI entry point: render every requested task/demo of a benchmark.

    For each selected task, loads the HDF5 demonstration file, spins up one
    OffScreenRenderEnv (reused across that task's demos), and calls
    :func:`render_demo` for each demo that is not already fully rendered.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--benchmark",  type=str, default="libero_spatial")
    parser.add_argument("--out_root",   type=str, default="/data/libero/parsed_libero")
    parser.add_argument("--image_size", type=int, default=448)
    parser.add_argument("--camera",     type=str, default="agentview")
    parser.add_argument("--task_ids",   type=str, default="all",
                        help="Comma-separated task indices or 'all'")
    parser.add_argument("--max_demos",  type=int, default=None,
                        help="Max demos per task (default: all)")
    args = parser.parse_args()

    bench    = bm_lib.get_benchmark_dict()[args.benchmark]()
    n_tasks  = bench.get_num_tasks()
    task_ids = list(range(n_tasks)) if args.task_ids == "all" else [int(x) for x in args.task_ids.split(",")]

    for task_id in task_ids:
        task      = bench.get_task(task_id)
        demo_path = os.path.join(get_libero_path("datasets"), bench.get_task_demonstration(task_id))
        print(f"\n{'='*60}")
        print(f"Task {task_id}: {task.name}")

        # Load all demo states/actions up front so the HDF5 file is closed
        # before the (long) render loop starts.
        with h5py.File(demo_path, "r") as f:
            demo_keys = sorted(k for k in f["data"].keys() if k.startswith("demo_"))
            # Explicit None check: `if args.max_demos:` treated --max_demos 0
            # as "render everything" instead of "render nothing".
            if args.max_demos is not None:
                demo_keys = demo_keys[:args.max_demos]
            all_states  = [f[f"data/{k}/states"][()] for k in demo_keys]
            all_actions = [f[f"data/{k}/actions"][()] for k in demo_keys]

        bddl_file = os.path.join(get_libero_path("bddl_files"), task.problem_folder, task.bddl_file)
        env = OffScreenRenderEnv(
            bddl_file_name=bddl_file,
            camera_heights=args.image_size,
            camera_widths=args.image_size,
            camera_names=[args.camera],
        )
        env.seed(0)
        env.reset()

        task_out = Path(args.out_root) / args.benchmark / f"task_{task_id}"

        try:
            for demo_key, states, actions in tqdm(
                zip(demo_keys, all_states, all_actions), total=len(demo_keys), desc=f"Task {task_id}"
            ):
                demo_out = demo_out_path = task_out / demo_key
                # Skip if already fully rendered (actions saved + frame count matches).
                if (demo_out / "actions.npy").exists() \
                        and len(list((demo_out / "frames").glob("*.png"))) == states.shape[0]:
                    tqdm.write(f"  {demo_key}: already rendered, skipping")
                    continue
                render_demo(env, states, actions, args.camera, args.image_size, demo_out)
        finally:
            # Release the offscreen renderer even if a demo render fails.
            env.close()
        print(f"Task {task_id} done → {task_out}")

    print(f"\nAll done. Output: {args.out_root}")


# Script entry point.
if __name__ == "__main__":
    main()
