"""Add wrist camera data to existing parsed LIBERO dataset.

Saves alongside existing agentview data:
    wrist_frames/000000.png ...     per-frame wrist images
    wrist_extrinsics.npy            (T, 4, 4) per-frame camera-to-world
    wrist_w2c.npy                   (T, 4, 4) per-frame world-to-pixel projection (intrinsics @ world-to-camera)
    wrist_cam_K_norm.npy            (3, 3) normalized intrinsics (constant)
    wrist_pix_uv.npy                (T, 2) EEF projected onto wrist camera, as [u, v] = [col, row] pixels

Usage:
    python prerender_wrist.py --benchmark libero_spatial --task_ids 0
"""
import argparse
import os
from pathlib import Path

import cv2
import h5py
import numpy as np
from tqdm import tqdm

from libero.libero import benchmark as bm, get_libero_path
from libero.libero.envs import OffScreenRenderEnv
from robosuite.utils.camera_utils import (
    get_camera_extrinsic_matrix,
    get_camera_intrinsic_matrix,
    get_camera_transform_matrix,
    project_points_from_world_to_camera,
)

# Default side length in pixels of the (square) rendered images; used as the
# default for the --image_size command-line option.
IMAGE_SIZE = 448
# Name of the camera whose frames, matrices and EEF projection are exported.
WRIST_CAM = "robot0_eye_in_hand"


# Per-demo array outputs, listed in the order they are written. A demo only
# counts as rendered when every one of them exists, so a run interrupted
# between two saves is re-rendered instead of being skipped forever.
_WRIST_OUTPUTS = (
    "wrist_extrinsics.npy",
    "wrist_w2c.npy",
    "wrist_cam_K_norm.npy",
    "wrist_pix_uv.npy",
)


def _parse_task_ids(spec, n_tasks):
    """Turn the ``--task_ids`` string into a validated list of task indices.

    Args:
        spec: Either ``"all"`` (case-insensitive) or comma-separated integers.
        n_tasks: Number of tasks in the selected benchmark.

    Returns:
        Task indices in the order given (``0 .. n_tasks - 1`` for ``"all"``).

    Raises:
        ValueError: If a token is not an integer, no id is given, or an id
            lies outside ``[0, n_tasks)``.
    """
    if spec.strip().lower() == "all":
        return list(range(n_tasks))

    # Ignore empty tokens so a trailing comma ("0,1,") is tolerated.
    task_ids = [int(tok) for tok in spec.split(",") if tok.strip()]
    if not task_ids:
        raise ValueError("--task_ids must be 'all' or a comma-separated list of integers")

    # Fail fast: a negative id would otherwise silently index from the end.
    out_of_range = [i for i in task_ids if not 0 <= i < n_tasks]
    if out_of_range:
        raise ValueError(
            f"task ids {out_of_range} are out of range for a benchmark with {n_tasks} tasks"
        )
    return task_ids


def _is_rendered(demo_dir):
    """Return True when every wrist array output already exists in *demo_dir*."""
    return all((demo_dir / name).exists() for name in _WRIST_OUTPUTS)


def _normalized_intrinsics(sim, image_size):
    """Return the wrist camera intrinsics with rows 0 and 1 divided by the image size."""
    K_norm = get_camera_intrinsic_matrix(sim, WRIST_CAM, image_size, image_size).copy()
    K_norm[0] /= image_size  # x row scaled by image width
    K_norm[1] /= image_size  # y row scaled by image height
    return K_norm


def _render_demo(env, states, demo_dir, wrist_K_norm, image_size):
    """Replay one demo state-by-state and write its wrist-camera outputs.

    Args:
        env: Environment exposing ``set_init_state`` and ``env.sim``.
        states: Array whose first axis indexes the demo's simulator states.
        demo_dir: Directory that receives ``wrist_frames/`` and the ``.npy`` files.
        wrist_K_norm: Normalized intrinsics to store alongside the demo.
        image_size: Side length in pixels of the square render.

    Raises:
        OSError: If a frame image cannot be written.
    """
    sim = env.env.sim
    H = W = image_size
    T = states.shape[0]

    wrist_frames_dir = demo_dir / "wrist_frames"
    wrist_frames_dir.mkdir(parents=True, exist_ok=True)

    wrist_extrinsics = np.zeros((T, 4, 4), dtype=np.float32)
    wrist_w2c = np.zeros((T, 4, 4), dtype=np.float32)
    wrist_pix_uv = np.zeros((T, 2), dtype=np.float32)

    for t, state in enumerate(states):
        obs = env.set_init_state(state)
        sim.forward()

        # Save wrist image (flipud to match agentview convention)
        wrist_img = np.flipud(obs[f"{WRIST_CAM}_image"]).copy()
        frame_path = wrist_frames_dir / f"{t:06d}.png"
        # cv2.imwrite reports failure through its return value, not an exception.
        if not cv2.imwrite(str(frame_path), cv2.cvtColor(wrist_img, cv2.COLOR_RGB2BGR)):
            raise OSError(f"Failed to write wrist frame: {frame_path}")

        # Camera params (change every frame since wrist moves)
        wrist_extrinsics[t] = get_camera_extrinsic_matrix(sim, WRIST_CAM)
        wrist_w2c[t] = get_camera_transform_matrix(sim, WRIST_CAM, H, W)

        # Project EEF onto wrist camera, using the stored (float32) matrix so
        # the saved pixels agree with the saved transform.
        eef_pos = np.array(obs["robot0_eef_pos"], dtype=np.float64)
        pix_rc = project_points_from_world_to_camera(
            eef_pos.reshape(1, 3), wrist_w2c[t], H, W
        )[0]
        wrist_pix_uv[t] = [pix_rc[1], pix_rc[0]]  # [col, row] = [u, v]

    # Keep this order in sync with _WRIST_OUTPUTS (see the note there).
    np.save(demo_dir / "wrist_extrinsics.npy", wrist_extrinsics)
    np.save(demo_dir / "wrist_w2c.npy", wrist_w2c)
    np.save(demo_dir / "wrist_cam_K_norm.npy", wrist_K_norm)
    np.save(demo_dir / "wrist_pix_uv.npy", wrist_pix_uv)


def main():
    """Render wrist-camera data for the selected LIBERO tasks and demos.

    For every selected task the demonstration HDF5 file is read, each demo's
    recorded simulator states are replayed in an off-screen environment, and
    the wrist frames, per-frame camera matrices, normalized intrinsics and
    EEF pixel projection are written under
    ``<cache_root>/<benchmark>/task_<id>/demo_<idx>/``.

    Demos whose array outputs all exist are skipped, so the script can be
    re-run to resume an interrupted job.

    Raises:
        ValueError: If ``--task_ids`` is malformed or out of range.
        OSError: If a frame image cannot be written.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--benchmark", type=str, default="libero_spatial")
    parser.add_argument("--task_ids", type=str, default="0", help="Comma-separated task IDs or 'all'")
    parser.add_argument("--cache_root", type=str, default="/data/libero/parsed_libero")
    parser.add_argument("--image_size", type=int, default=IMAGE_SIZE)
    parser.add_argument("--max_demos", type=int, default=0, help="0 = all demos")
    args = parser.parse_args()

    bench = bm.get_benchmark_dict()[args.benchmark]()
    task_ids = _parse_task_ids(args.task_ids, bench.get_num_tasks())

    for task_id in task_ids:
        task = bench.get_task(task_id)
        print(f"\n{'='*60}")
        print(f"Task {task_id}: {task.name}")

        task_dir = Path(args.cache_root) / args.benchmark / f"task_{task_id}"
        demo_file = os.path.join(get_libero_path("datasets"), bench.get_task_demonstration(task_id))
        with h5py.File(demo_file, "r") as f:
            demo_keys = sorted(k for k in f["data"].keys() if k.startswith("demo_"))
            if args.max_demos > 0:
                demo_keys = demo_keys[:args.max_demos]

            demo_dirs = {}
            for demo_key in demo_keys:
                demo_idx = int(demo_key.split("_")[1])
                demo_dirs[demo_key] = task_dir / f"demo_{demo_idx}"

            # Only read states for demos that still need rendering.
            pending = [k for k in demo_keys if not _is_rendered(demo_dirs[k])]
            all_states = {k: f[f"data/{k}/states"][()] for k in pending}

        n_skipped = len(demo_keys) - len(pending)
        if n_skipped:
            print(f"  Skipped {n_skipped} already-rendered demos")

        # Building the simulator is expensive; avoid it when nothing is left to do.
        if pending:
            bddl = os.path.join(get_libero_path("bddl_files"), task.problem_folder, task.bddl_file)
            env = OffScreenRenderEnv(
                bddl_file_name=bddl,
                camera_heights=args.image_size,
                camera_widths=args.image_size,
                camera_names=["agentview", WRIST_CAM],
            )
            try:
                env.seed(0)
                # Intrinsics are constant for the task, so compute them once.
                # This also keeps them valid for a demo with zero states.
                wrist_K_norm = _normalized_intrinsics(env.env.sim, args.image_size)

                for demo_key in tqdm(pending, desc=f"Task {task_id}"):
                    _render_demo(
                        env,
                        all_states[demo_key],
                        demo_dirs[demo_key],
                        wrist_K_norm,
                        args.image_size,
                    )
            finally:
                # Release the off-screen renderer even if rendering failed.
                env.close()

        print(f"  Done: {len(demo_keys)} demos")

    print("\nAll done.")


# Run the pre-render only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
