"""Generate 2view (BEV + wrist) OOD viewpoint dataset.

Same as generate_ood_viewpoint.py but with wrist camera enabled and per-frame wrist
params saved. Adds --left_hemi_only / --vi_start / --vi_end for parallel rendering.
"""
import argparse, os, sys
from pathlib import Path

import cv2
import h5py
import numpy as np
from tqdm import tqdm

# Put the local LIBERO checkout first on the import path so the `libero`
# imports below resolve to it.
sys.path.insert(0, "/data/cameron/LIBERO")
# `setdefault` only fills these in when the caller has not already exported
# them, so a different data path or GL backend can be chosen from the shell.
# NOTE(review): these sit before the libero/robosuite imports, presumably
# because the rendering backend is selected at import time — keep the order.
os.environ.setdefault("LIBERO_DATA_PATH", "/data/libero")
os.environ.setdefault("MUJOCO_GL", "osmesa")
os.environ.setdefault("PYOPENGL_PLATFORM", "osmesa")

from libero.libero import benchmark as bm_lib, get_libero_path
from libero.libero.envs import OffScreenRenderEnv
from robosuite.utils.camera_utils import (
    get_camera_extrinsic_matrix,
    get_camera_intrinsic_matrix,
    get_camera_transform_matrix,
    project_points_from_world_to_camera,
)

from generate_ood_viewpoint import (
    AGENT_CAM, TABLE_Z, STATE_QPOS_OFFSET, PICK_QPOS, PLACE_QPOS,
    DISTRACTOR_QPOS_STARTS, FURNITURE_BODIES, DISTRACTOR_BODIES,
    DISTRACTOR_POS, DISTRACTOR_DOFS,
    look_at_quat, generate_viewpoint_grid, find_grasp_timestep,
    setup_clean_scene, freeze_distractors, shift_state,
    servo_to, interpolate_waypoints,
)

# Name of the wrist camera: passed to the env as a camera name, used to build
# the "<name>_image" observation key, and passed to the robosuite camera utils.
WRIST_CAM = "robot0_eye_in_hand"


def _parse_args():
    """Parse the command-line options of the 2-view OOD viewpoint generator.

    Returns:
        argparse.Namespace holding the viewpoint-grid, object-offset,
        rendering and sharding options.

    Exits with a usage error when ``--frame_stride`` is below 1: a stride of 0
    would make ``range()`` raise in the middle of a render and a negative
    stride would silently produce an empty replay phase.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--n_views", type=int, default=8)
    p.add_argument("--theta_max", type=float, default=25)
    p.add_argument("--demos_per_view", type=int, default=10)
    p.add_argument("--dx_min", type=float, default=-0.40)
    p.add_argument("--dx_max", type=float, default=-0.01)
    p.add_argument("--dy_min", type=float, default=-0.30)
    p.add_argument("--dy_max", type=float, default=0.30)
    p.add_argument("--image_size", type=int, default=448)
    p.add_argument("--frame_stride", type=int, default=3)
    p.add_argument("--z_offset", type=float, default=-0.015)
    p.add_argument("--out_root", type=str, default="/data/libero/ood_viewpoint_v3_2view")
    p.add_argument("--left_hemi_only", action="store_true",
                   help="Skip viewpoints with phi in (90°, 270°) — i.e. only render phi ∈ [0,90] ∪ [270,360]")
    p.add_argument("--vi_start", type=int, default=0, help="Start vp index (inclusive) for parallel rendering")
    p.add_argument("--vi_end",   type=int, default=None, help="End vp index (exclusive)")
    args = p.parse_args()
    if args.frame_stride < 1:
        p.error("--frame_stride must be >= 1")
    return args


def _extract_eef_trajectory(bddl_file, states, image_size):
    """Replay the recorded simulator states and return the EEF position per step.

    A throw-away single-camera env is created for the replay and is always
    closed, even if the replay raises.

    Returns:
        float64 array with one ``robot0_eef_pos`` row per entry of ``states``.
    """
    env_tmp = OffScreenRenderEnv(
        bddl_file_name=bddl_file, camera_heights=image_size,
        camera_widths=image_size, camera_names=[AGENT_CAM])
    try:
        env_tmp.seed(0)
        env_tmp.reset()
        eef = []
        for state in states:
            env_tmp.set_init_state(state)
            env_tmp.env.sim.forward()
            eef.append(np.array(env_tmp.env._get_observations()["robot0_eef_pos"], dtype=np.float64))
        return np.array(eef)
    finally:
        env_tmp.close()


def _draw_offsets(rng, args):
    """Draw one (dx, dy) object offset; always dx first, then dy.

    Every demo index consumes exactly one such pair (rendered or skipped), which
    is what keeps offsets identical between sharded and single-process runs.
    """
    dx = rng.uniform(args.dx_min, args.dx_max)
    dy = rng.uniform(args.dy_min, args.dy_max)
    return dx, dy


def _record_frame(rec, obs, gripper_cmd, sim, H, W):
    """Append one time step (EEF pose, gripper command, both images, wrist camera matrices) to ``rec``."""
    rec["eef"].append(np.array(obs["robot0_eef_pos"], dtype=np.float32))
    rec["quat"].append(np.array(obs["robot0_eef_quat"], dtype=np.float32))
    rec["grip"].append(gripper_cmd)
    # Images are stored vertically flipped relative to the observation.
    rec["bev"].append(np.flipud(obs[f"{AGENT_CAM}_image"]).copy())
    rec["wrist"].append(np.flipud(obs[f"{WRIST_CAM}_image"]).copy())
    # The wrist camera is queried at every frame (unlike the agent camera,
    # which is queried once per demo at save time).
    rec["wrist_w2c"].append(get_camera_transform_matrix(sim, WRIST_CAM, H, W).astype(np.float32))
    rec["wrist_ext"].append(get_camera_extrinsic_matrix(sim, WRIST_CAM).astype(np.float32))


def _write_frames(frame_dir, frames):
    """Write RGB ``frames`` as zero-padded PNGs into ``frame_dir``.

    Raises:
        OSError: if OpenCV reports that a frame could not be written, so a
            demo with missing images is never marked complete.
    """
    frame_dir.mkdir(exist_ok=True)
    for fi, frame in enumerate(frames):
        path = frame_dir / f"{fi:06d}.png"
        if not cv2.imwrite(str(path), cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)):
            raise OSError(f"cv2.imwrite failed for {path}")


def _save_npy_atomic(path, arr):
    """Save ``arr`` to ``path`` via a temp file + rename so ``path`` never exists half-written."""
    tmp = path.with_name(path.name + ".tmp")
    # Passing an open handle stops np.save from appending ".npy" to the temp name.
    with open(tmp, "wb") as fh:
        np.save(fh, arr)
    os.replace(tmp, path)


def _save_demo(demo_dir, sim, H, W, rec, meta):
    """Write one recorded episode (frames, trajectories, camera params, meta) to ``demo_dir``.

    ``wrist_w2c.npy`` is the file the resume check looks for, so it is written
    last and atomically: if the process dies anywhere in here the demo is
    re-rendered on the next run instead of being skipped while incomplete.
    """
    demo_dir.mkdir(parents=True, exist_ok=True)
    _write_frames(demo_dir / "frames", rec["bev"])
    _write_frames(demo_dir / "wrist_frames", rec["wrist"])

    eef_arr = np.stack(rec["eef"])
    w2c = get_camera_transform_matrix(sim, AGENT_CAM, H, W)
    K = get_camera_intrinsic_matrix(sim, AGENT_CAM, H, W)
    # Normalise the intrinsics by the image size (row 0 by width, row 1 by height).
    K_norm = K.copy()
    K_norm[0] /= W
    K_norm[1] /= H
    ext = get_camera_extrinsic_matrix(sim, AGENT_CAM)

    wrist_K = get_camera_intrinsic_matrix(sim, WRIST_CAM, H, W).astype(np.float32)
    wrist_K_norm = wrist_K.copy()
    wrist_K_norm[0] /= W
    wrist_K_norm[1] /= H

    # Project every EEF position into both cameras; the two components returned
    # by the projection are swapped before being stored as pix_uv.
    pix_uvs, wrist_pix_uvs = [], []
    for ei, wrist_w2c_t in enumerate(rec["wrist_w2c"]):
        point = eef_arr[ei:ei + 1].astype(np.float64)
        pix_rc = project_points_from_world_to_camera(point, w2c, H, W)[0]
        pix_uvs.append(np.array([pix_rc[1], pix_rc[0]], dtype=np.float32))
        pix_rc_w = project_points_from_world_to_camera(point, wrist_w2c_t, H, W)[0]
        wrist_pix_uvs.append(np.array([pix_rc_w[1], pix_rc_w[0]], dtype=np.float32))

    np.save(demo_dir / "eef_pos.npy",    eef_arr)
    np.save(demo_dir / "eef_quat.npy",   np.stack(rec["quat"]))
    np.save(demo_dir / "gripper.npy",    np.array(rec["grip"], dtype=np.float32))
    np.save(demo_dir / "pix_uv.npy",     np.stack(pix_uvs))
    np.save(demo_dir / "cam_extrinsic.npy", ext.astype(np.float32))
    np.save(demo_dir / "cam_K_norm.npy", K_norm.astype(np.float32))
    np.save(demo_dir / "world_to_cam.npy", w2c.astype(np.float32))
    np.save(demo_dir / "base_z.npy",     np.float32(0.912))
    # Placeholder: an all-zero action array with one row per recorded frame.
    np.save(demo_dir / "actions.npy",    np.zeros((len(eef_arr), 7), dtype=np.float32))
    np.save(demo_dir / "wrist_pix_uv.npy",  np.stack(wrist_pix_uvs))
    np.save(demo_dir / "wrist_cam_K_norm.npy", wrist_K_norm)
    np.save(demo_dir / "wrist_extrinsics.npy", np.stack(rec["wrist_ext"]))
    np.savez(demo_dir / "meta.npz", **meta)
    # Completion marker — must stay the final write (see docstring).
    _save_npy_atomic(demo_dir / "wrist_w2c.npy", np.stack(rec["wrist_w2c"]))


def _render_shard(env, args, states, actions, eef_orig, center_dx, center_dy):
    """Render viewpoints ``[vi_start, vi_end)`` with ``env`` and write every demo to disk.

    Returns:
        ``(successes, n_processed)`` for the demos rendered by this call;
        demos skipped because they already exist are not counted.
    """
    env.seed(0)
    env.reset()
    # NOTE(review): presumably raised so the env never ends an episode on its
    # step limit during the long servo loops — confirm.
    env.env.horizon = 100000
    sim = env.env.sim
    setup_clean_scene(sim)
    H = W = args.image_size

    # Intersect the default camera's viewing ray with the z = TABLE_Z plane to
    # obtain the look-at point passed to the viewpoint grid.
    cam_id = sim.model.camera_name2id(AGENT_CAM)
    default_pos = sim.data.cam_xpos[cam_id].copy()
    cam_xmat = sim.data.cam_xmat[cam_id].reshape(3, 3)
    forward = -cam_xmat[:, 2]
    t_hit = (TABLE_Z - default_pos[2]) / (forward[2] + 1e-8)
    look_at = default_pos + t_hit * forward
    print(f"Default cam: {default_pos}, look-at: {look_at}")

    vp_positions, vp_quats, thetas, phis, theta_idx, phi_idx = generate_viewpoint_grid(
        default_pos, look_at, args.n_views, args.theta_max)
    n_viewpoints = len(vp_positions)
    n_total = n_viewpoints * args.demos_per_view
    print(f"\n{args.n_views}x{args.n_views} = {n_viewpoints} viewpoints x {args.demos_per_view} demos = {n_total} episodes")
    phis_deg = np.degrees(phis)
    print(f"Theta deg: {np.degrees(thetas).round(1)}")
    print(f"Phi deg:   {phis_deg.round(1)}")

    task_dir = Path(args.out_root) / "libero_spatial" / "task_0"
    task_dir.mkdir(parents=True, exist_ok=True)

    np.savez(task_dir / "viewpoint_meta.npz",
             thetas_deg=np.degrees(thetas), phis_deg=phis_deg,
             n_views=args.n_views, theta_max=args.theta_max,
             demos_per_view=args.demos_per_view,
             vp_positions=vp_positions, vp_quats=vp_quats,
             center_dx=center_dx, center_dy=center_dy,
             dx_min=args.dx_min, dx_max=args.dx_max,
             dy_min=args.dy_min, dy_max=args.dy_max)

    rng = np.random.RandomState(42)
    t_grasp = find_grasp_timestep(actions)
    t_pregrasp = max(0, t_grasp - 6)
    successes = 0
    n_processed = 0

    vi_end = args.vi_end if args.vi_end is not None else n_viewpoints
    vi_start = max(0, args.vi_start)
    vi_end = min(n_viewpoints, vi_end)

    # The rng state must match the single-process version's per-demo offsets,
    # so burn the draws belonging to the viewpoints before this shard.
    for _ in range(vi_start * args.demos_per_view):
        _draw_offsets(rng, args)

    for vi in tqdm(range(vi_start, vi_end), desc=f"VP[{vi_start}:{vi_end}]"):
        if args.left_hemi_only:
            # left hemi = phi ∈ [0, 90] ∪ [270, 360]. Wrapping with % 360 keeps
            # the test correct should the grid ever report negative angles.
            # NOTE(review): `vi % n_views` assumes phi is the fastest-varying
            # grid axis — confirm against generate_viewpoint_grid.
            p_deg = phis_deg[vi % args.n_views] % 360.0
            if not ((0 <= p_deg <= 90) or (270 <= p_deg <= 360)):
                # still advance the rng for skipped viewpoints to keep determinism
                for _ in range(args.demos_per_view):
                    _draw_offsets(rng, args)
                continue

        sim.model.cam_pos[cam_id] = vp_positions[vi]
        sim.model.cam_quat[cam_id] = vp_quats[vi]

        for di in range(args.demos_per_view):
            demo_idx = vi * args.demos_per_view + di
            demo_dir = task_dir / f"demo_{demo_idx}"
            # Draw before the resume check so already-rendered demos still
            # consume their pair of random numbers.
            dx_offset, dy_offset = _draw_offsets(rng, args)
            if (demo_dir / "wrist_w2c.npy").exists():
                continue

            total_dx = center_dx + dx_offset
            total_dy = center_dy + dy_offset

            env.env.timestep = 0
            env.env.done = False
            state_0 = shift_state(states[0], total_dx, total_dy)
            env.set_init_state(state_0)
            sim.forward()
            # A few zero-action steps, re-freezing the distractors after each.
            for _ in range(5):
                env.step(np.zeros(7, dtype=np.float32))
                freeze_distractors(sim)

            obs = env.env._get_observations()
            home_pos = np.array(obs["robot0_eef_pos"], dtype=np.float64)
            pregrasp_target = eef_orig[t_pregrasp].copy()
            pregrasp_target[0] += total_dx
            pregrasp_target[1] += total_dy

            gripper_cmd = -1.0
            rec = {key: [] for key in
                   ("bev", "wrist", "eef", "quat", "grip", "wrist_w2c", "wrist_ext")}

            # Phase 1: move from the home pose to the shifted pre-grasp pose.
            _record_frame(rec, obs, gripper_cmd, sim, H, W)
            for wp in interpolate_waypoints(home_pos, pregrasp_target, 8):
                obs = servo_to(env, wp, -1.0, max_servo=25)
                _record_frame(rec, obs, gripper_cmd, sim, H, W)

            # Phase 2: follow the recorded EEF path (shifted in x/y), sub-sampled
            # by frame_stride, replaying the recorded gripper command.
            success = False
            for t in range(t_pregrasp, len(eef_orig), args.frame_stride):
                target = eef_orig[t].copy()
                target[0] += total_dx
                target[1] += total_dy
                if t < len(actions):
                    gripper_cmd = float(np.clip(actions[t, 6], -1, 1))
                # z_offset is applied only while the gripper command is positive.
                if gripper_cmd > 0 and args.z_offset != 0:
                    target[2] += args.z_offset
                obs = servo_to(env, target, gripper_cmd, max_servo=25)
                _record_frame(rec, obs, gripper_cmd, sim, H, W)
                if env.env.done or (hasattr(env.env, '_check_success') and env.env._check_success()):
                    success = True

            _save_demo(demo_dir, sim, H, W, rec,
                       dict(vi=vi, di=di,
                            theta_idx=theta_idx[vi], phi_idx=phi_idx[vi],
                            dx=dx_offset, dy=dy_offset))

            if success:
                successes += 1
            n_processed += 1

    return successes, n_processed


def main():
    """Generate the 2-view (agent camera + wrist camera) OOD viewpoint dataset.

    Loads the first demonstration of LIBERO-spatial task 0, extracts its EEF
    trajectory, then for every viewpoint in the requested index range replays
    that trajectory with randomly shifted object positions and writes frames,
    trajectories and camera parameters under
    ``<out_root>/libero_spatial/task_0/demo_<idx>``.

    Demos whose ``wrist_w2c.npy`` already exists are skipped, so an interrupted
    run can simply be restarted. Both simulator environments are closed even
    when rendering fails.
    """
    args = _parse_args()

    bench = bm_lib.get_benchmark_dict()["libero_spatial"]()
    task = bench.get_task(0)
    demo_path = os.path.join(get_libero_path("datasets"), bench.get_task_demonstration(0))
    bddl_file = os.path.join(get_libero_path("bddl_files"), task.problem_folder, task.bddl_file)

    with h5py.File(demo_path, "r") as f:
        dk = sorted(k for k in f["data"].keys() if k.startswith("demo_"))[0]
        states = f[f"data/{dk}/states"][()]
        actions = f[f"data/{dk}/actions"][()]

    print("Extracting EEF trajectory...")
    eef_orig = _extract_eef_trajectory(bddl_file, states, args.image_size)

    # Offset equal to minus the x/y entries read from the first state at the
    # PICK_QPOS slot; added to every per-demo random offset.
    bowl_si = PICK_QPOS + STATE_QPOS_OFFSET
    center_dx = -states[0][bowl_si]
    center_dy = -states[0][bowl_si + 1]
    print(f"Center offset: ({center_dx:+.3f}, {center_dy:+.3f})")

    # Env with BOTH cams
    env = OffScreenRenderEnv(
        bddl_file_name=bddl_file, camera_heights=args.image_size,
        camera_widths=args.image_size, camera_names=[AGENT_CAM, WRIST_CAM])
    try:
        successes, n_processed = _render_shard(
            env, args, states, actions, eef_orig, center_dx, center_dy)
    finally:
        env.close()
    print(f"\nDone. {successes}/{n_processed} succeeded.")


# Run the generator only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
