"""Generate 2view (BEV + wrist) OOD object position dataset for PARA training.

Same servo replay as generate_ood_objpos.py, but with the wrist camera enabled
and per-frame wrist params saved. Output layout per demo:

  frames/          — agentview PNGs (BEV)
  wrist_frames/    — robot0_eye_in_hand PNGs (per-frame)
  eef_pos.npy, eef_quat.npy, gripper.npy, pix_uv.npy
  cam_extrinsic.npy, cam_K_norm.npy, world_to_cam.npy   — BEV (static)
  wrist_pix_uv.npy (T, 2)
  wrist_extrinsics.npy (T, 4, 4) — wrist cam→world per frame
  wrist_cam_K_norm.npy (3, 3)
  wrist_w2c.npy (T, 4, 4)         — wrist world→cam per frame
  base_z.npy, actions.npy

Usage:
    python generate_ood_objpos_2view.py --out_root /data/libero/ood_objpos_task0_2view
"""
import argparse
import os
import sys
from pathlib import Path

import cv2
import h5py
import numpy as np
from tqdm import tqdm

# Put the local LIBERO checkout first on the import path; this runs before the
# `libero` imports further down so they can resolve against it.
sys.path.insert(0, "/data/cameron/LIBERO")
# `setdefault` only fills in a value when the variable is not already set, so
# anything exported by the caller's shell takes precedence over these defaults.
os.environ.setdefault("LIBERO_DATA_PATH", "/data/libero")
# NOTE(review): presumably selects headless OSMesa rendering and has to be set
# before MuJoCo / PyOpenGL are imported — confirm against the rendering stack.
os.environ.setdefault("MUJOCO_GL", "osmesa")
os.environ.setdefault("PYOPENGL_PLATFORM", "osmesa")

from libero.libero import benchmark as bm_lib, get_libero_path
from libero.libero.envs import OffScreenRenderEnv
from robosuite.utils.camera_utils import (
    get_camera_extrinsic_matrix,
    get_camera_intrinsic_matrix,
    get_camera_transform_matrix,
    project_points_from_world_to_camera,
)

# Import shared helpers from the 1-view generator
from generate_ood_objpos import (
    STATE_QPOS_OFFSET, TASK0_OBJECTS, FURNITURE_BODIES, DISTRACTOR_POS,
    DISTRACTOR_DOFS, AGENT_CAM, WRIST_CAM,
    _si, move_distractors_in_state, hide_distractors_visual,
    freeze_distractors, shift_pick_place, hide_furniture,
    extract_demo_eef_positions, servo_to_position, find_grasp_timestep,
    interpolate_waypoints,
)


def _normalized_intrinsics(sim, camera_name, height, width):
    """Return the float32 intrinsic matrix with row 0 divided by width and row 1 by height."""
    k_norm = get_camera_intrinsic_matrix(sim, camera_name, height, width).astype(np.float32)
    k_norm[0] /= width
    k_norm[1] /= height
    return k_norm


def _eef_pixel_uv(eef_xyz, world_to_cam, height, width):
    """Project one world-space point and return it as a float32 pair [index 1, index 0].

    The projection helper's output is treated as (row, col) here (assumed from
    how the result is consumed — confirm against robosuite), so the returned
    pair is ordered (u, v).
    """
    row_col = project_points_from_world_to_camera(
        eef_xyz[None].astype(np.float64), world_to_cam, height, width)[0]
    return np.array([row_col[1], row_col[0]], dtype=np.float32)


def generate_trajectory_2view(env, states, actions, eef_orig, dx, dy, center_dx, center_dy,
                              frame_stride=3, z_offset=-0.015, max_servo=25, image_size=448,
                              pregrasp_lead=6, interp_steps=8):
    """Replay a demo shifted by (center_dx + dx, center_dy + dy), recording BEV and wrist views.

    The episode is reset to ``states[0]`` (distractors moved, pick/place objects
    shifted), stepped five times with a zero action, then driven in two phases:

    1. servo through ``interp_steps`` interpolated waypoints from the current
       EEF position to the shifted pre-grasp pose (gripper command -1.0);
    2. servo through every ``frame_stride``-th original EEF position from the
       pre-grasp timestep onward, shifted in x/y, with the gripper command taken
       from ``actions[t, 6]`` (clipped to [-1, 1]) and ``z_offset`` added to the
       target height while that command is positive.

    One sample is recorded for the initial observation and after every servo
    call. Wrist camera transforms are read from the simulator at each sample;
    BEV camera parameters are read once at the end.

    Args:
        env: LIBERO off-screen env rendering both AGENT_CAM and WRIST_CAM.
        states: flattened demo states; only ``states[0]`` is used.
        actions: demo actions; column 6 is used as the gripper command.
        eef_orig: original per-timestep EEF positions of the demo.
        dx, dy: grid offset applied on top of the centering offset.
        center_dx, center_dy: centering offset applied to every trajectory.
        frame_stride: step between replayed demo timesteps in phase 2.
        z_offset: height offset applied while the gripper command is positive.
        max_servo: forwarded to ``servo_to_position``.
        image_size: square render size used for the camera matrices.
        pregrasp_lead: number of timesteps before the grasp at which phase 2 starts.
        interp_steps: forwarded to ``interpolate_waypoints``.

    Returns:
        dict with the recorded frames, EEF poses, gripper commands, pixel
        projections, camera matrices, ``base_z``, ``success`` and the ``dx`` /
        ``dy`` offsets, keyed as expected by ``save_demo_2view``.
    """
    sim = env.env.sim
    height = width = image_size
    shift_x = center_dx + dx
    shift_y = center_dy + dy

    # Fresh episode bookkeeping before the state is overwritten.
    env.env.timestep = 0
    env.env.done = False

    init_state = states[0].copy()
    init_state = move_distractors_in_state(init_state)
    init_state = shift_pick_place(init_state, shift_x, shift_y)

    env.set_init_state(init_state)
    sim.forward()

    # Five zero-action steps, re-freezing the distractors after each one.
    for _ in range(5):
        env.step(np.zeros(7, dtype=np.float32))
        freeze_distractors(sim)

    grasp_t = find_grasp_timestep(actions)
    pregrasp_t = max(0, grasp_t - pregrasp_lead)

    obs = env.env._get_observations()
    start_pos = np.array(obs["robot0_eef_pos"], dtype=np.float64)

    pregrasp_pos = eef_orig[pregrasp_t].copy()
    pregrasp_pos[0] += shift_x
    pregrasp_pos[1] += shift_y

    bev_frames, wrist_frames = [], []
    positions, quaternions, gripper_cmds = [], [], []
    wrist_w2c_seq, wrist_ext_seq = [], []

    def capture(observation, gripper_value):
        """Append one sample (images, EEF pose, gripper, wrist camera matrices)."""
        positions.append(np.array(observation["robot0_eef_pos"], dtype=np.float32))
        quaternions.append(np.array(observation["robot0_eef_quat"], dtype=np.float32))
        gripper_cmds.append(gripper_value)

        # Images are flipped vertically before being stored.
        bev_frames.append(np.flipud(observation[f"{AGENT_CAM}_image"]).copy())
        wrist_frames.append(np.flipud(observation[f"{WRIST_CAM}_image"]).copy())

        # The wrist camera matrices are re-read from the sim for every sample.
        wrist_w2c_seq.append(
            get_camera_transform_matrix(sim, WRIST_CAM, height, width).astype(np.float32))
        wrist_ext_seq.append(
            get_camera_extrinsic_matrix(sim, WRIST_CAM).astype(np.float32))

    gripper_cmd = -1.0
    capture(obs, gripper_cmd)

    # Phase 1: interpolate home → pre-grasp
    for waypoint in interpolate_waypoints(start_pos, pregrasp_pos, interp_steps):
        obs = servo_to_position(env, waypoint, -1.0, max_servo=max_servo)
        capture(obs, gripper_cmd)

    # Phase 2: execute shifted trajectory from pre-grasp onward
    success = False
    for t in range(pregrasp_t, len(eef_orig), frame_stride):
        goal = eef_orig[t].copy()
        goal[0] += shift_x
        goal[1] += shift_y

        # Past the end of `actions` the previous gripper command is kept.
        if t < len(actions):
            gripper_cmd = float(np.clip(actions[t, 6], -1.0, 1.0))
        if gripper_cmd > 0 and z_offset != 0:
            goal[2] += z_offset

        obs = servo_to_position(env, goal, gripper_cmd, max_servo=max_servo)
        capture(obs, gripper_cmd)

        # Success is latched; the replay continues to the end either way.
        finished = env.env.done
        if not finished and hasattr(env.env, '_check_success'):
            finished = env.env._check_success()
        if finished:
            success = True

    # Static BEV camera params
    bev_ext = get_camera_extrinsic_matrix(sim, AGENT_CAM).astype(np.float32)
    bev_w2c = get_camera_transform_matrix(sim, AGENT_CAM, height, width).astype(np.float32)
    bev_k_norm = _normalized_intrinsics(sim, AGENT_CAM, height, width)

    # Wrist K (constant) — intrinsics are camera-only, don't change with pose
    wrist_k_norm = _normalized_intrinsics(sim, WRIST_CAM, height, width)

    # BEV projections share one matrix; wrist projections use the per-sample one.
    bev_uv = [_eef_pixel_uv(p, bev_w2c, height, width) for p in positions]
    wrist_uv = [_eef_pixel_uv(p, w2c, height, width)
                for p, w2c in zip(positions, wrist_w2c_seq)]

    return {
        "bev_frames": bev_frames,
        "wrist_frames": wrist_frames,
        "eef_pos": np.stack(positions),
        "eef_quat": np.stack(quaternions),
        "gripper": np.array(gripper_cmds, dtype=np.float32),
        "pix_uv": np.stack(bev_uv),
        "cam_extrinsic": bev_ext,
        "cam_K_norm": bev_k_norm,
        "world_to_cam": bev_w2c,
        "wrist_pix_uv": np.stack(wrist_uv),
        "wrist_extrinsics": np.stack(wrist_ext_seq),
        "wrist_cam_K_norm": wrist_k_norm,
        "wrist_w2c": np.stack(wrist_w2c_seq),
        "base_z": np.float32(0.912),
        "success": success,
        "dx": dx, "dy": dy,
    }


def save_demo_2view(data, demo_dir):
    """Write one generated 2-view trajectory to ``demo_dir``.

    Layout written:
      * ``frames/NNNNNN.png`` from ``data["bev_frames"]`` and
        ``wrist_frames/NNNNNN.png`` from ``data["wrist_frames"]``; frames are
        converted RGB → BGR before being handed to OpenCV;
      * one ``<key>.npy`` per array entry of ``data``;
      * ``actions.npy``: an all-zero float32 placeholder with one 7-vector per
        recorded sample. It is deliberately written last so its presence
        indicates that every other file of the demo has been written.

    Args:
        data: dict as returned by ``generate_trajectory_2view``.
        demo_dir: output directory (str or Path); created if missing.

    Raises:
        OSError: if OpenCV reports that a frame could not be written.
    """
    demo_dir = Path(demo_dir)

    for subdir, frames_key in (("frames", "bev_frames"), ("wrist_frames", "wrist_frames")):
        frame_dir = demo_dir / subdir
        frame_dir.mkdir(parents=True, exist_ok=True)
        for i, frame in enumerate(data[frames_key]):
            out_path = frame_dir / f"{i:06d}.png"
            # cv2.imwrite signals most failures through its boolean return value
            # rather than an exception; surface them instead of silently
            # producing a demo with missing frames.
            if not cv2.imwrite(str(out_path), cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)):
                raise OSError(f"cv2.imwrite failed for {out_path}")

    array_keys = (
        "eef_pos",
        "eef_quat",
        "gripper",
        "pix_uv",
        "cam_extrinsic",
        "cam_K_norm",
        "world_to_cam",
        "wrist_pix_uv",
        "wrist_extrinsics",
        "wrist_cam_K_norm",
        "wrist_w2c",
        "base_z",
    )
    for key in array_keys:
        np.save(demo_dir / f"{key}.npy", data[key])

    # Placeholder actions; keep this as the final write (see docstring).
    np.save(demo_dir / "actions.npy",
            np.zeros((len(data["eef_pos"]), 7), dtype=np.float32))


def _demo_sort_key(key):
    """Sort key ordering ``demo_<n>`` names by their integer suffix.

    Plain string sorting puts ``demo_10`` before ``demo_2``; names whose suffix
    is not a plain integer are placed after the numeric ones, ordered by name.
    """
    suffix = key[len("demo_"):]
    if suffix.isascii() and suffix.isdigit():
        return (0, int(suffix), key)
    return (1, 0, key)


def main():
    """CLI entry point: replay one demo over a (dx, dy) grid and save 2-view demos.

    Steps:
      1. Load ``states`` / ``actions`` of the selected demo from the task's HDF5 file.
      2. Extract the demo's original EEF trajectory with a temporary BEV-only env.
      3. Compute the centering offset from the bowl's x/y entries in ``states[0]``.
      4. For each requested grid cell, generate a shifted trajectory with
         ``generate_trajectory_2view`` and write it with ``save_demo_2view`` to
         ``<out_root>/<benchmark>/task_<task_id>/demo_<i * N + j>``.

    Cells whose output already exists are skipped, so an interrupted run can be
    resumed and ``--i_start`` / ``--i_end`` slices can be rendered by separate
    processes.

    Raises:
        ValueError: if the demonstration file contains no ``demo_*`` group.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--grid_size", type=int, default=16)
    parser.add_argument("--dx_min", type=float, default=-0.15)
    parser.add_argument("--dx_max", type=float, default=0.0)
    parser.add_argument("--dy_min", type=float, default=-0.10)
    parser.add_argument("--dy_max", type=float, default=0.10)
    parser.add_argument("--image_size", type=int, default=448)
    parser.add_argument("--frame_stride", type=int, default=3)
    parser.add_argument("--z_offset", type=float, default=-0.015)
    parser.add_argument("--max_servo", type=int, default=25)
    parser.add_argument("--benchmark", type=str, default="libero_spatial")
    parser.add_argument("--task_id", type=int, default=0)
    parser.add_argument("--demo_id", type=int, default=0)
    parser.add_argument("--out_root", type=str, default="/data/libero/ood_objpos_task0_2view")
    parser.add_argument("--left_half_only", action="store_true",
                        help="Render only j < N/2 (left half), for faster turnaround")
    parser.add_argument("--i_start", type=int, default=0,
                        help="Start dx index (inclusive) for parallel rendering")
    parser.add_argument("--i_end", type=int, default=None,
                        help="End dx index (exclusive) for parallel rendering")
    args = parser.parse_args()

    bench = bm_lib.get_benchmark_dict()[args.benchmark]()
    task = bench.get_task(args.task_id)
    demo_path = os.path.join(get_libero_path("datasets"), bench.get_task_demonstration(args.task_id))
    bddl_file = os.path.join(get_libero_path("bddl_files"), task.problem_folder, task.bddl_file)

    with h5py.File(demo_path, "r") as f:
        # Numeric ordering so that --demo_id k selects the k-th demo; a plain
        # string sort would place demo_10 ahead of demo_2.
        demo_keys = sorted(
            (k for k in f["data"].keys() if k.startswith("demo_")),
            key=_demo_sort_key,
        )
        if not demo_keys:
            raise ValueError(f"No demo_* groups found in {demo_path}")
        # Indices past the end fall back to the last demo.
        demo_key = demo_keys[min(args.demo_id, len(demo_keys) - 1)]
        states = f[f"data/{demo_key}/states"][()]
        actions = f[f"data/{demo_key}/actions"][()]

    print(f"Task: {task.name}")
    print(f"Demo: {demo_key}, {len(states)} frames")

    # Extract original EEF trajectory (use bev cam only for this side-channel pass)
    print("Extracting original EEF trajectory...")
    env_tmp = OffScreenRenderEnv(
        bddl_file_name=bddl_file, camera_heights=args.image_size, camera_widths=args.image_size,
        camera_names=[AGENT_CAM])
    try:
        env_tmp.seed(0)
        env_tmp.reset()
        eef_orig = extract_demo_eef_positions(env_tmp, states)
    finally:
        # Release the renderer even if extraction fails.
        env_tmp.close()

    # The centering offset is the negated x/y of the bowl in the first state.
    bowl_i = _si(TASK0_OBJECTS["akita_black_bowl_1"]["qpos_start"])
    center_dx = -states[0][bowl_i]
    center_dy = -states[0][bowl_i + 1]
    print(f"Centering offset: ({center_dx:+.3f}, {center_dy:+.3f})")

    N = args.grid_size
    dx_vals = np.linspace(args.dx_min, args.dx_max, N)
    dy_vals = np.linspace(args.dy_min, args.dy_max, N)
    print(f"\nGrid: {N}x{N} = {N*N} trajectories ({'left half only' if args.left_half_only else 'full grid'})")
    print(f"  dx: [{args.dx_min}, {args.dx_max}]")
    print(f"  dy: [{args.dy_min}, {args.dy_max}]")

    successes = 0
    generated = 0
    skipped = 0

    # Create env with BOTH cameras
    env = OffScreenRenderEnv(
        bddl_file_name=bddl_file,
        camera_heights=args.image_size, camera_widths=args.image_size,
        camera_names=[AGENT_CAM, WRIST_CAM],
    )
    try:
        env.seed(0)
        env.reset()
        env.env.horizon = 100000
        sim = env.env.sim
        hide_furniture(sim)
        hide_distractors_visual(sim)

        task_dir = Path(args.out_root) / args.benchmark / f"task_{args.task_id}"
        task_dir.mkdir(parents=True, exist_ok=True)
        np.savez(task_dir / "grid_meta.npz", dx_vals=dx_vals, dy_vals=dy_vals,
                 center_dx=center_dx, center_dy=center_dy)

        j_range = range(N // 2) if args.left_half_only else range(N)
        i_end = args.i_end if args.i_end is not None else N
        i_start = max(0, args.i_start)
        i_end = min(N, i_end)
        dx_iter = list(enumerate(dx_vals))[i_start:i_end]

        for i, dx in tqdm(dx_iter, desc=f"dx[{i_start}:{i_end}]"):
            for j in j_range:
                dy = dy_vals[j]
                demo_idx = i * N + j
                demo_dir = task_dir / f"demo_{demo_idx}"

                # Resume support. wrist_w2c.npy marks a 2-view demo and
                # actions.npy is the last file save_demo_2view writes, so
                # requiring both avoids treating a half-written demo as done.
                if all((demo_dir / name).exists()
                       for name in ("wrist_w2c.npy", "actions.npy")):
                    skipped += 1
                    continue

                env.env.timestep = 0
                env.env.done = False

                data = generate_trajectory_2view(
                    env, states, actions, eef_orig,
                    dx, dy, center_dx, center_dy,
                    frame_stride=args.frame_stride,
                    z_offset=args.z_offset,
                    max_servo=args.max_servo,
                    image_size=args.image_size,
                )

                save_demo_2view(data, demo_dir)
                generated += 1
                if data["success"]:
                    successes += 1
                if generated % 10 == 0:
                    print(f"  [{generated}] dx={dx:.3f} dy={dy:.3f} success={data['success']} "
                          f"(total: {successes}/{generated})")
    finally:
        # Always release the renderer, including on errors or Ctrl-C.
        env.close()

    # Skipped demos were not replayed in this run, so they are reported
    # separately instead of diluting the success ratio.
    print(f"\nDone. {successes}/{generated} successes ({skipped} existing demos skipped).")


# Run the generator only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
