"""Replay LIBERO demos with zero rotation and save the resulting trajectories.

This replays each demo's actions with the rotation components zeroed out and
records the EEF positions and camera images that actually result. The training
targets then match evaluation rollouts, which also use zero rotation.

Usage:
    python prerender_zero_rot.py --benchmark libero_spatial --task_id 0
"""
import argparse
import os
from pathlib import Path

import cv2
import h5py
import numpy as np
from tqdm import tqdm

from libero.libero import benchmark as bm, get_libero_path
from libero.libero.envs import OffScreenRenderEnv
from robosuite.utils.camera_utils import (
    get_camera_extrinsic_matrix,
    get_camera_intrinsic_matrix,
    get_camera_transform_matrix,
    project_points_from_world_to_camera,
)

# Default square render resolution in pixels (height == width) for both cameras;
# overridable with --image_size.
IMAGE_SIZE = 448
# robosuite camera names rendered and recorded for every frame.
AGENT_CAM, WRIST_CAM = "agentview", "robot0_eye_in_hand"


def _demo_index(demo_key):
    """Return the integer ``N`` from an HDF5 demo group name ``demo_N``."""
    return int(demo_key.split("_")[1])


def _load_demos(demo_file, max_demos=0):
    """Read ``states`` and ``actions`` for the selected demos in ``demo_file``.

    Demo groups are ordered by numeric index (``demo_2`` before ``demo_10``).
    A plain string sort would put ``demo_10`` first. A positive ``max_demos``
    therefore keeps the N lowest-indexed demos, and only those are read.

    Args:
        demo_file: Path to the LIBERO HDF5 demonstration file.
        max_demos: Keep at most this many demos; ``<= 0`` keeps all of them.

    Returns:
        ``(demo_keys, all_data)``, where ``all_data[key]`` is a dict holding the
        ``"states"`` and ``"actions"`` arrays of that demo.
    """
    with h5py.File(demo_file, "r") as f:
        demo_keys = sorted(
            (k for k in f["data"].keys() if k.startswith("demo_")), key=_demo_index
        )
        if max_demos > 0:
            demo_keys = demo_keys[:max_demos]
        all_data = {
            dk: {
                "states": f[f"data/{dk}/states"][()],
                "actions": f[f"data/{dk}/actions"][()],
            }
            for dk in demo_keys
        }
    return demo_keys, all_data


def _normalize_intrinsics(K, H, W):
    """Return a float32 copy of ``K`` with row 0 divided by ``W`` and row 1 by ``H``."""
    K_norm = np.array(K, dtype=np.float32)
    K_norm[0] /= W
    K_norm[1] /= H
    return K_norm


def _project_uv(point, w2c, H, W):
    """Project one world-frame 3-D point with ``w2c`` and return float32 ``(u, v)``.

    The first projected coordinate is treated as the row and the second as the
    column. They are swapped so the result is ``(col, row)``.
    """
    rc = project_points_from_world_to_camera(
        np.asarray(point).reshape(1, 3).astype(np.float64), w2c, H, W
    )[0]
    return np.array([rc[1], rc[0]], dtype=np.float32)


def _robot_base_z(sim, default=0.912):
    """Return the world z of the ``robot0_base`` body, or ``default`` if unavailable.

    ``default`` is the script's long-standing fallback, labelled as the default
    Panda base height. NOTE(review): confirm this value for other robots.
    """
    try:
        return float(sim.data.body_xpos[sim.model.body_name2id("robot0_base")][2])
    except (AttributeError, KeyError, ValueError):
        # Sim binding lacks the attribute, or the body name is not in the model.
        return default


def _replay_demo(env, states, actions, demo_dir, H, W, settle_steps=5):
    """Replay one demo with rotation zeroed, writing images and returning per-frame data.

    The env is reset, restored to ``states[0]``, and stepped ``settle_steps``
    times with an all-zero 7-D action. Frame 0 is recorded after this settling.
    Frame ``t + 1`` is recorded after executing ``actions[t]`` with indices 3:6
    (the rotation part) set to 0. Replay stops early if the env reports ``done``.

    Agent and wrist images are written as PNGs to ``demo_dir/frames`` and
    ``demo_dir/wrist_frames``.

    Returns:
        A dict of per-frame lists: ``eef_pos``, ``eef_quat``, ``pix_uv``,
        ``wrist_pix_uv``, ``wrist_ext``, ``wrist_w2c``. It also holds
        ``gripper``, the gripper command executed from each frame. That list is
        padded with the last command (or -1.0 if no action ran) so it has one
        entry per frame.
    """
    env.reset()
    obs = env.set_init_state(states[0])
    for _ in range(settle_steps):
        obs, _, _, _ = env.step(np.zeros(7, dtype=np.float32))

    frames_dir = demo_dir / "frames"
    wrist_frames_dir = demo_dir / "wrist_frames"
    frames_dir.mkdir(parents=True, exist_ok=True)
    wrist_frames_dir.mkdir(parents=True, exist_ok=True)

    rec = {
        "eef_pos": [],
        "eef_quat": [],
        "gripper": [],
        "pix_uv": [],
        "wrist_pix_uv": [],
        "wrist_ext": [],
        "wrist_w2c": [],
    }

    def record_frame(obs, t):
        """Write both camera images for frame ``t`` and append its pose/camera data."""
        sim = env.env.sim  # re-fetched on every call rather than cached across steps
        eef_pos = np.array(obs["robot0_eef_pos"], dtype=np.float32)
        rec["eef_pos"].append(eef_pos)
        rec["eef_quat"].append(np.array(obs["robot0_eef_quat"], dtype=np.float32))

        # Images are flipped vertically and converted RGB -> BGR for cv2.imwrite.
        for cam, out_dir in ((AGENT_CAM, frames_dir), (WRIST_CAM, wrist_frames_dir)):
            img = np.flipud(obs[f"{cam}_image"]).copy()
            cv2.imwrite(str(out_dir / f"{t:06d}.png"), cv2.cvtColor(img, cv2.COLOR_RGB2BGR))

        # EEF position projected into the agent camera.
        agent_w2c = get_camera_transform_matrix(sim, AGENT_CAM, H, W)
        rec["pix_uv"].append(_project_uv(eef_pos, agent_w2c, H, W))

        # Wrist camera params are stored per frame because the camera moves.
        wrist_w2c = get_camera_transform_matrix(sim, WRIST_CAM, H, W)
        rec["wrist_ext"].append(get_camera_extrinsic_matrix(sim, WRIST_CAM).astype(np.float32))
        rec["wrist_w2c"].append(wrist_w2c.astype(np.float32))
        rec["wrist_pix_uv"].append(_project_uv(eef_pos, wrist_w2c, H, W))

    record_frame(obs, 0)

    for t, demo_action in enumerate(actions):
        action = demo_action.copy()  # never mutate the demo's own action array
        action[3:6] = 0.0  # zero rotation
        rec["gripper"].append(float(action[6]))
        obs, _, done, _ = env.step(action)
        record_frame(obs, t + 1)
        if done:
            break

    # Pad so the gripper series has one entry per recorded frame.
    rec["gripper"].append(rec["gripper"][-1] if rec["gripper"] else -1.0)
    return rec


def _save_demo(env, demo_dir, rec, actions, H, W):
    """Write the recorded trajectory, camera parameters, and raw actions as .npy files."""
    # Read the current sim; a reference cached before env.reset() may be stale.
    sim = env.env.sim
    T = len(rec["eef_pos"])

    np.save(demo_dir / "eef_pos.npy", np.stack(rec["eef_pos"]))
    np.save(demo_dir / "eef_quat.npy", np.stack(rec["eef_quat"]))
    np.save(demo_dir / "gripper.npy", np.array(rec["gripper"][:T], dtype=np.float32))
    np.save(demo_dir / "pix_uv.npy", np.stack(rec["pix_uv"]))

    # Agentview is static, so a single set of camera params per demo is saved.
    np.save(
        demo_dir / "cam_extrinsic.npy",
        get_camera_extrinsic_matrix(sim, AGENT_CAM).astype(np.float32),
    )
    np.save(
        demo_dir / "cam_K_norm.npy",
        _normalize_intrinsics(get_camera_intrinsic_matrix(sim, AGENT_CAM, H, W), H, W),
    )
    np.save(
        demo_dir / "world_to_cam.npy",
        get_camera_transform_matrix(sim, AGENT_CAM, H, W).astype(np.float32),
    )
    np.save(demo_dir / "base_z.npy", np.float32(_robot_base_z(sim)))
    # Raw demo actions: rotation is NOT zeroed here, and every action is kept
    # even if replay stopped early on `done`.
    np.save(demo_dir / "actions.npy", actions)

    # Wrist data
    np.save(demo_dir / "wrist_extrinsics.npy", np.stack(rec["wrist_ext"]))
    np.save(demo_dir / "wrist_w2c.npy", np.stack(rec["wrist_w2c"]))
    np.save(
        demo_dir / "wrist_cam_K_norm.npy",
        _normalize_intrinsics(get_camera_intrinsic_matrix(sim, WRIST_CAM, H, W), H, W),
    )
    np.save(demo_dir / "wrist_pix_uv.npy", np.stack(rec["wrist_pix_uv"]))


def main():
    """Replay the demos of one LIBERO task with zeroed rotation and save per-frame data.

    For each selected demo, the actions are replayed with indices 3:6 set to 0.
    Agent and wrist images, EEF pose, projected EEF pixel coordinates, and camera
    parameters are written under ``<out_root>/<benchmark>/task_<task_id>/demo_<N>/``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--benchmark", type=str, default="libero_spatial")
    parser.add_argument("--task_id", type=int, default=0)
    parser.add_argument("--out_root", type=str, default="/data/libero/parsed_libero_zero_rot_task1")
    parser.add_argument("--image_size", type=int, default=IMAGE_SIZE,
                        help="square render resolution for both cameras")
    parser.add_argument("--max_demos", type=int, default=0,
                        help="process only the N lowest-indexed demos (0 = all)")
    args = parser.parse_args()

    bench = bm.get_benchmark_dict()[args.benchmark]()
    task = bench.get_task(args.task_id)
    print(f"Task {args.task_id}: {task.name}")

    demo_file = os.path.join(get_libero_path("datasets"), bench.get_task_demonstration(args.task_id))
    demo_keys, all_data = _load_demos(demo_file, args.max_demos)

    bddl = os.path.join(get_libero_path("bddl_files"), task.problem_folder, task.bddl_file)
    env = OffScreenRenderEnv(
        bddl_file_name=bddl,
        camera_heights=args.image_size,
        camera_widths=args.image_size,
        camera_names=[AGENT_CAM, WRIST_CAM],
    )
    H = W = args.image_size
    task_dir = Path(args.out_root) / args.benchmark / f"task_{args.task_id}"

    # Close the env even if replaying or saving a demo raises.
    try:
        env.seed(0)
        for dk in tqdm(demo_keys, desc="Demos"):
            demo_dir = task_dir / f"demo_{_demo_index(dk)}"
            actions = all_data[dk]["actions"]
            rec = _replay_demo(env, all_data[dk]["states"], actions, demo_dir, H, W)
            _save_demo(env, demo_dir, rec, actions, H, W)
    finally:
        env.close()
    print(f"\nDone. Saved {len(demo_keys)} demos to {task_dir}")


# Script entry point: run main() only when executed directly, not when imported.
if __name__ == "__main__":
    main()
