"""Generate OOD viewpoint dataset for PARA training.

Creates a 16x16 grid (by default) of camera viewpoints on a spherical cap for task 0.
Furniture and distractor objects are hidden, and the pick/place objects are shifted
by a per-viewpoint offset sampled from the objpos grid (centered on the calibrated
CENTER offset). Replays the first demo trajectory from each viewpoint and saves the
results.

Viewpoints are parameterized as (theta, phi) on a spherical cap:
  theta = polar angle from default camera direction (0 = default view)
  phi = azimuthal angle around the default direction

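Each viewpoint gets its own demo_<i>/ directory containing per-frame agentview
PNGs plus eef_pos.npy, eef_quat.npy, gripper.npy, pix_uv.npy, camera matrices,
actions.npy, and shift_meta.npz (the sampled object offset).
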
Usage:
    python generate_ood_viewpoint.py --grid_size 16 --out_root /data/libero/ood_viewpoint_task0
"""
import argparse
import os
import sys
from pathlib import Path

import cv2
import h5py
import numpy as np
from tqdm import tqdm
from scipy.spatial.transform import Rotation as ScipyR

sys.path.insert(0, "/data/cameron/LIBERO")
os.environ.setdefault("LIBERO_DATA_PATH", "/data/libero")

from libero.libero import benchmark as bm_lib, get_libero_path
from libero.libero.envs import OffScreenRenderEnv
from robosuite.utils.camera_utils import (
    get_camera_extrinsic_matrix,
    get_camera_intrinsic_matrix,
    get_camera_transform_matrix,
    project_points_from_world_to_camera,
)

AGENT_CAM = "agentview"
WRIST_CAM = "robot0_eye_in_hand"
TABLE_Z = 0.90  # table surface height


def look_at_quat(cam_pos, target):
    """Compute MuJoCo camera quaternion (w,x,y,z) so camera at cam_pos looks at target.

    MuJoCo camera convention: camera looks along -z in its local frame, y is up.
    Copied from ood_libero/viewpoint_distribution.py (verified working).
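
    Sanity check (a top-down camera should map to the identity rotation):

    >>> bool(np.allclose(look_at_quat(np.array([0.0, 0.0, 2.0]), np.zeros(3)),
    ...                  [1.0, 0.0, 0.0, 0.0]))
    True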
    """
    forward = target - cam_pos
    forward = forward / (np.linalg.norm(forward) + 1e-12)

    cam_z = -forward  # camera -z = forward => camera z = -forward

    up_hint = np.array([0.0, 0.0, 1.0])
    if abs(np.dot(forward, up_hint)) > 0.99:
        up_hint = np.array([0.0, 1.0, 0.0])

    cam_x = np.cross(up_hint, cam_z)
    cam_x = cam_x / (np.linalg.norm(cam_x) + 1e-12)
    cam_y = np.cross(cam_z, cam_x)
    cam_y = cam_y / (np.linalg.norm(cam_y) + 1e-12)

    R = np.stack([cam_x, cam_y, cam_z], axis=-1)
    quat_xyzw = ScipyR.from_matrix(R).as_quat()  # scipy: (x,y,z,w)
    return np.array([quat_xyzw[3], quat_xyzw[0], quat_xyzw[1], quat_xyzw[2]])  # MuJoCo: (w,x,y,z)


def generate_viewpoint_grid(default_pos, look_at_point, grid_size, theta_max_deg):
    """Generate a grid of camera positions on a spherical cap."""
    radius = np.linalg.norm(default_pos - look_at_point)
    default_dir = (default_pos - look_at_point) / radius

    # Create orthonormal basis around default direction
    up = np.array([0, 0, 1.0])
    if abs(np.dot(default_dir, up)) > 0.99:
        up = np.array([1, 0, 0.0])
    right = np.cross(default_dir, up)
    right /= np.linalg.norm(right)
    true_up = np.cross(right, default_dir)

    theta_max = np.radians(theta_max_deg)
    N = grid_size

    # Grid over (theta, phi); phi excludes the 2*pi endpoint to avoid duplicating
    # the phi=0 column. Note the theta=0 row collapses to the default view repeated
    # N times, since phi has no effect at the cap's apex.
    thetas = np.linspace(0, theta_max, N)
    phis = np.linspace(0, 2 * np.pi, N, endpoint=False)

    positions = []
    quaternions = []
    for theta in thetas:
        for phi in phis:
            # Spherical cap offset
            offset = (np.sin(theta) * np.cos(phi) * right +
                      np.sin(theta) * np.sin(phi) * true_up +
                      np.cos(theta) * default_dir)
            pos = look_at_point + radius * offset
            quat = look_at_quat(pos, look_at_point)
            positions.append(pos)
            quaternions.append(quat)

    return np.array(positions), np.array(quaternions)


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--grid_size", type=int, default=16)
    parser.add_argument("--theta_max", type=float, default=20,
                        help="Max angle in degrees from default view")
    parser.add_argument("--image_size", type=int, default=448)
    parser.add_argument("--benchmark", type=str, default="libero_spatial")
    parser.add_argument("--task_id", type=int, default=0)
    parser.add_argument("--demo_id", type=int, default=0)
    parser.add_argument("--out_root", type=str, default="/data/libero/ood_viewpoint_task0")
    args = parser.parse_args()

    bench = bm_lib.get_benchmark_dict()[args.benchmark]()
    task = bench.get_task(args.task_id)
    demo_path = os.path.join(get_libero_path("datasets"), bench.get_task_demonstration(args.task_id))
    bddl_file = os.path.join(get_libero_path("bddl_files"), task.problem_folder, task.bddl_file)

    with h5py.File(demo_path, "r") as f:
        demo_keys = sorted([k for k in f["data"].keys() if k.startswith("demo_")])
        demo_key = demo_keys[min(args.demo_id, len(demo_keys) - 1)]
        states = f[f"data/{demo_key}/states"][()]
        actions = f[f"data/{demo_key}/actions"][()]

    print(f"Task: {task.name}")
    print(f"Demo: {demo_key}, {len(states)} frames")

    # Create env
    env = OffScreenRenderEnv(
        bddl_file_name=bddl_file, camera_heights=args.image_size, camera_widths=args.image_size,
        camera_names=[AGENT_CAM, WRIST_CAM])
    env.seed(0)
    env.reset()
    sim = env.env.sim  # underlying MjSim of the wrapped robosuite env
    H = W = args.image_size

    # Get default camera position and forward direction from sim data (after forward())
    cam_id = sim.model.camera_name2id(AGENT_CAM)
    default_pos = sim.data.cam_xpos[cam_id].copy()
    default_quat = sim.model.cam_quat[cam_id].copy()
    cam_xmat = sim.data.cam_xmat[cam_id].reshape(3, 3).copy()
    forward = -cam_xmat[:, 2]  # MuJoCo: camera looks along -z in local frame
    forward = forward / (np.linalg.norm(forward) + 1e-12)

    # Trace ray to table plane to find look-at point
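    # Solving (default_pos + t * forward)[2] == TABLE_Z for t gives
    # t = (TABLE_Z - default_pos[2]) / forward[2]; the else branch guards a
    # near-horizontal view direction.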
    if abs(forward[2]) > 1e-6:
        t = (TABLE_Z - default_pos[2]) / forward[2]
        look_at_point = default_pos + t * forward
    else:
        look_at_point = np.array([0.0, 0.0, TABLE_Z])

    print(f"Default camera pos: {default_pos}")
    print(f"Look-at point: {look_at_point}")

    # Generate viewpoint grid
    N = args.grid_size
    positions, quaternions = generate_viewpoint_grid(
        default_pos, look_at_point, N, args.theta_max)
    print(f"\nGrid: {N}x{N} = {len(positions)} viewpoints")
    print(f"  theta_max: {args.theta_max}°")

    task_dir = Path(args.out_root) / args.benchmark / f"task_{args.task_id}"
    task_dir.mkdir(parents=True, exist_ok=True)

    # Save grid metadata
    np.savez(task_dir / "grid_meta.npz", positions=positions, quaternions=quaternions,
             default_pos=default_pos, default_quat=default_quat, look_at=look_at_point)

    # Object position grid for random sampling
    objpos_meta_path = Path("/data/libero/ood_objpos_task0/libero_spatial/task_0/grid_meta.npz")
    if objpos_meta_path.exists():
        objpos_meta = np.load(objpos_meta_path)
        dx_vals = objpos_meta["dx_vals"]
        dy_vals = objpos_meta["dy_vals"]
        center_dx = float(objpos_meta["center_dx"])
        center_dy = float(objpos_meta["center_dy"])
        print(f"Loaded objpos grid: {len(dx_vals)}x{len(dy_vals)}, center=({center_dx:.3f},{center_dy:.3f})")
    else:
        dx_vals = np.array([0.0])
        dy_vals = np.array([0.0])
        center_dx = center_dy = 0.0
        print("No objpos grid found, using default position")

    rng = np.random.RandomState(42)

    # State layout for shifting objects (from CLAUDE.md)
    STATE_QPOS_OFFSET = 1
    PICK_QPOS_START = 9   # akita_black_bowl_1
    PLACE_QPOS_START = 37  # plate_1
    DISTRACTOR_QPOS = {16: "bowl2", 23: "cookies", 30: "ramekin"}
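    # states[t] is the flattened MuJoCo state (index 0 is likely sim time,
    # hence the +1 offset into qpos). Each free object contributes 7 qpos
    # entries (xyz position + wxyz quaternion), matching the stride of 7
    # between the object start indices (9, 16, 23, 30, 37).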

    # Hide furniture + distractors once
    for fname in ["wooden_cabinet_1_main", "flat_stove_1_main"]:
        try:
            bid = sim.model.body_name2id(fname)
            sim.model.body_pos[bid] = np.array([0, 0, -5.0])
        except Exception:
            pass
    sim.forward()
    # Hide distractors visually by zeroing their geom alpha (their qpos is
    # also moved off-screen during replay below)
    distractor_bodies = set()
    for dname in ["akita_black_bowl_2_main", "cookies_1_main", "glazed_rim_porcelain_ramekin_1_main"]:
        try:
            distractor_bodies.add(sim.model.body_name2id(dname))
        except Exception:
            pass
    for geom_id in range(sim.model.ngeom):
        if sim.model.geom_bodyid[geom_id] in distractor_bodies:
            sim.model.geom_rgba[geom_id][3] = 0.0

    # For each viewpoint, replay demo with random object position
    for vi in tqdm(range(len(positions)), desc="Viewpoints"):
        demo_dir = task_dir / f"demo_{vi}"
        if (demo_dir / "eef_pos.npy").exists():
            continue

        frames_dir = demo_dir / "frames"
        frames_dir.mkdir(parents=True, exist_ok=True)

        # Set camera viewpoint
        sim.model.cam_pos[cam_id] = positions[vi]
        sim.model.cam_quat[cam_id] = quaternions[vi]
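        # sim.forward() (called during the replay below) propagates these
        # model fields into sim.data before anything is rendered.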

        # Random object position
        dx = rng.choice(dx_vals)
        dy = rng.choice(dy_vals)

        eef_positions = []
        eef_quats = []
        grippers = []

        # Replay demo states with shifted objects
        for t in range(len(states)):
            state_t = states[t].copy()
            # Move distractors off-screen
            for qps in DISTRACTOR_QPOS:
                si = qps + STATE_QPOS_OFFSET
                state_t[si:si+3] = [10.0, 10.0, 0.9]
            # Shift pick and place objects
            for qps in [PICK_QPOS_START, PLACE_QPOS_START]:
                si = qps + STATE_QPOS_OFFSET
                state_t[si] += center_dx + dx
                state_t[si+1] += center_dy + dy

            env.set_init_state(state_t)
            sim.forward()
            obs = env.env._get_observations()

            eef_pos = np.array(obs["robot0_eef_pos"], dtype=np.float32)
            eef_quat = np.array(obs["robot0_eef_quat"], dtype=np.float32)
            # Shift the recorded EEF to match the object shift: the replayed arm
            # still follows the original demo states, so the saved trajectory is
            # translated to stay consistent with the moved objects.
            eef_pos[0] += center_dx + dx
            eef_pos[1] += center_dy + dy
            eef_positions.append(eef_pos)
            eef_quats.append(eef_quat)

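            # actions usually has one fewer entry than states; past the end,
            # hold the last gripper command (default -1 = open).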
            if t < len(actions):
                grippers.append(float(np.clip(actions[t, 6], -1, 1)))
            else:
                grippers.append(grippers[-1] if grippers else -1.0)

            # Save agent image (robosuite offscreen renders arrive vertically flipped)
            agent_img = np.flipud(obs[f"{AGENT_CAM}_image"]).copy()
            cv2.imwrite(str(frames_dir / f"{t:06d}.png"), cv2.cvtColor(agent_img, cv2.COLOR_RGB2BGR))

        T = len(eef_positions)
        eef_arr = np.stack(eef_positions)
        quat_arr = np.stack(eef_quats)

        # Camera params for the modified viewpoint
        agent_ext = get_camera_extrinsic_matrix(sim, AGENT_CAM).astype(np.float32)
        agent_w2c = get_camera_transform_matrix(sim, AGENT_CAM, H, W).astype(np.float32)
        agent_K = get_camera_intrinsic_matrix(sim, AGENT_CAM, H, W).astype(np.float32)
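        # K's first row is (fx, 0, cx) in pixels and its second row (0, fy, cy);
        # dividing by W and H makes the intrinsics resolution-independent.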
        agent_K_norm = agent_K.copy()
        agent_K_norm[0] /= W
        agent_K_norm[1] /= H

        # Pixel projections: the helper returns (row, col); store as (u, v) = (col, row)
        pix_uvs = []
        for i in range(T):
            pix_rc = project_points_from_world_to_camera(
                eef_arr[i:i+1].astype(np.float64), agent_w2c, H, W)[0]
            pix_uvs.append(np.array([pix_rc[1], pix_rc[0]], dtype=np.float32))

        # Save
        np.save(demo_dir / "eef_pos.npy", eef_arr)
        np.save(demo_dir / "eef_quat.npy", quat_arr)
        np.save(demo_dir / "gripper.npy", np.array(grippers[:T], dtype=np.float32))
        np.save(demo_dir / "pix_uv.npy", np.stack(pix_uvs))
        np.save(demo_dir / "cam_extrinsic.npy", agent_ext)
        np.save(demo_dir / "cam_K_norm.npy", agent_K_norm)
        np.save(demo_dir / "world_to_cam.npy", agent_w2c)
        np.save(demo_dir / "base_z.npy", np.float32(0.912))
        np.save(demo_dir / "actions.npy", actions)
        # Save shift metadata
        np.savez(demo_dir / "shift_meta.npz", dx=dx, dy=dy, vi=vi)

    # Restore default camera
    sim.model.cam_pos[cam_id] = default_pos
    sim.model.cam_quat[cam_id] = default_quat

    env.close()
    print(f"\nDone. Saved {len(positions)} viewpoints to {task_dir}")


if __name__ == "__main__":
    main()
