"""viewpoint_distribution.py — Sample camera viewpoints on a spherical cap and render LIBERO frame 0.

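Defines a distribution of camera viewpoints on a spherical cap centered on the
default camera direction (view_0; agentview unless --camera is given). The
sphere is centered at the look-at point, where the default camera's viewing ray
meets the table plane, and its radius is the default camera's distance to that
point. The cap is parameterized by a maximum angle theta_max (default 30 deg),
measured at the look-at point, from the default view direction.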

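Outputs (default directory ood_libero/out/, override with --out_dir):
    viewpoints_3d.png    — 3D scatter of the sampled camera positions, the default camera and the look-at point
    renders_grid.png     — grid of renders of demo frame 0 from the default and each sampled viewpoint
    renders/view_*.png   — the same renders saved individually
    viewpoint_meta.npz   — default pose, look-at point, radius, sampled positions and quaternions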

Usage:
    python ood_libero/viewpoint_distribution.py [--n_views 16] [--theta_max 30] [--image_size 256]
"""

import argparse
import os
import sys
from pathlib import Path

import cv2
import h5py
import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401

sys.path.insert(0, "/data/cameron/LIBERO")
os.environ.setdefault("LIBERO_DATA_PATH", "/data/libero")

from libero.libero import benchmark as bm_lib, get_libero_path
from libero.libero.envs import OffScreenRenderEnv
from robosuite.utils.camera_utils import (
    get_camera_extrinsic_matrix,
    get_camera_intrinsic_matrix,
    get_camera_transform_matrix,
)


# ---------------------------------------------------------------------------
# Spherical-cap sampling and camera-orientation helpers
# ---------------------------------------------------------------------------

def fibonacci_spherical_cap(n_points: int, theta_max_deg: float = 30.0) -> np.ndarray:
    """Return ``n_points`` near-uniformly spread unit vectors on a cap around +z.

    The cap contains every direction within ``theta_max_deg`` of the +z pole.
    Heights are spaced evenly in z (equal z-steps cut equal-area bands on a
    sphere) and azimuths advance by a golden-ratio step, i.e. the Fibonacci /
    golden-angle lattice restricted to the cap.

    Point 0 sits exactly on the pole and point ``n_points - 1`` on the rim.

    Args:
        n_points: number of directions to generate (0 yields an empty array).
        theta_max_deg: half-angle of the cap, in degrees.

    Returns:
        Array of shape (n_points, 3) holding unit vectors.
    """
    k = np.arange(n_points)
    phi_golden = (1 + np.sqrt(5)) / 2
    cap_height = 1.0 - np.cos(np.radians(theta_max_deg))  # z-extent of the cap
    step_den = max(n_points - 1, 1)  # avoids 0/0 when a single point is requested

    # k = 0 lands on the pole (z = 1), k = n_points - 1 on the rim.
    z = 1.0 - k / step_den * cap_height
    ring = np.sqrt(np.clip(1.0 - z * z, 0, None))
    azimuth = 2 * np.pi * k / phi_golden

    xyz = np.column_stack((ring * np.cos(azimuth), ring * np.sin(azimuth), z))
    lengths = np.linalg.norm(xyz, axis=-1, keepdims=True)
    return xyz / (lengths + 1e-12)


def rotate_cap_to_direction(cap_points: np.ndarray, center_dir: np.ndarray) -> np.ndarray:
    """Rotate points sampled around the +z pole so the pole maps onto center_dir.

    Applies the minimal rotation taking +z to ``center_dir`` (about the axis
    ``z x center_dir``) with Rodrigues' formula. The sine and cosine of the
    rotation angle come directly from the cross and dot products instead of a
    round trip through ``arccos``. As a result, directions arbitrarily close to
    +z or -z are still rotated exactly rather than being snapped to the pole.

    Args:
        cap_points: (n, 3) array of unit vectors centered on +z.
        center_dir: (3,) target direction for the pole; need not be unit length.

    Returns:
        (n, 3) array of rotated unit vectors.

    Raises:
        ValueError: if ``center_dir`` is the zero vector (no direction to aim at).
    """
    cap_points = np.asarray(cap_points, dtype=float)
    center_dir = np.asarray(center_dir, dtype=float)
    length = np.linalg.norm(center_dir)
    if length == 0.0:
        raise ValueError("center_dir must be a non-zero vector")
    center_dir = center_dir / length
    z_axis = np.array([0.0, 0.0, 1.0])

    cross = np.cross(z_axis, center_dir)
    sin_a = np.linalg.norm(cross)              # |z x c| = sin(angle) for unit c
    cos_a = float(np.dot(z_axis, center_dir))  # z . c   = cos(angle)

    if sin_a < 1e-12:
        # Rotation axis is undefined: center_dir is (numerically) +z or -z.
        if cos_a > 0.0:
            return cap_points.copy()
        # Antiparallel: any 180-degree turn about a horizontal axis works; use x.
        return cap_points * np.array([1.0, -1.0, -1.0])

    axis = cross / sin_a
    # Rodrigues: R v = v cos(a) + (k x v) sin(a) + k (k . v) (1 - cos(a))
    rotated = (
        cap_points * cos_a
        + np.cross(axis[None, :], cap_points) * sin_a
        + axis[None, :] * (cap_points @ axis[:, None]) * (1.0 - cos_a)
    )
    # Re-normalize to absorb floating-point drift.
    rotated /= np.linalg.norm(rotated, axis=-1, keepdims=True) + 1e-12
    return rotated


def look_at_quaternion_mujoco(cam_pos: np.ndarray, target: np.ndarray) -> np.ndarray:
    """Return the MuJoCo (w, x, y, z) quaternion that aims a camera at ``target``.

    MuJoCo cameras look down their local -z axis, with local +x pointing right
    and +y pointing up in the image. World +z is used as the "up" reference.
    World +y is used instead when the view direction is within ~8 degrees of
    vertical (|cos| > 0.99), where +z would make the cross products degenerate.

    Args:
        cam_pos: (3,) camera position.
        target: (3,) point the camera should look at.

    Returns:
        (4,) unit quaternion (w, x, y, z) whose rotation matrix has the camera's
        right / up / backward axes as columns.
    """
    def _unit(v):
        return v / (np.linalg.norm(v) + 1e-12)

    view_dir = _unit(target - cam_pos)
    back = -view_dir  # camera local +z points away from the scene

    world_up = np.array([0.0, 0.0, 1.0])
    near_vertical = abs(np.dot(view_dir, world_up)) > 0.99
    up_ref = np.array([0.0, 1.0, 0.0]) if near_vertical else world_up

    right = _unit(np.cross(up_ref, back))
    image_up = _unit(np.cross(back, right))

    return rotmat_to_quat_wxyz(np.column_stack((right, image_up, back)))


def rotmat_to_quat_wxyz(R: np.ndarray) -> np.ndarray:
    """Convert a 3x3 rotation matrix into a unit quaternion ordered (w, x, y, z).

    Uses the trace when it is positive. Otherwise it pivots on the largest
    diagonal entry to keep the square root well-conditioned. The sign of w is
    not canonicalized (q and -q encode the same rotation).
    """
    m00, m01, m02 = R[0, 0], R[0, 1], R[0, 2]
    m10, m11, m12 = R[1, 0], R[1, 1], R[1, 2]
    m20, m21, m22 = R[2, 0], R[2, 1], R[2, 2]
    tr = m00 + m11 + m22

    if tr > 0:
        k = 0.5 / np.sqrt(tr + 1.0)
        wxyz = (0.25 / k, (m21 - m12) * k, (m02 - m20) * k, (m10 - m01) * k)
    elif m00 > m11 and m00 > m22:
        k = 2.0 * np.sqrt(1.0 + m00 - m11 - m22)
        wxyz = ((m21 - m12) / k, 0.25 * k, (m01 + m10) / k, (m02 + m20) / k)
    elif m11 > m22:
        k = 2.0 * np.sqrt(1.0 + m11 - m00 - m22)
        wxyz = ((m02 - m20) / k, (m01 + m10) / k, 0.25 * k, (m12 + m21) / k)
    else:
        k = 2.0 * np.sqrt(1.0 + m22 - m00 - m11)
        wxyz = ((m10 - m01) / k, (m02 + m20) / k, (m12 + m21) / k, 0.25 * k)

    quat = np.array(wxyz)
    return quat / (np.linalg.norm(quat) + 1e-12)


# ---------------------------------------------------------------------------
# Camera info extraction
# ---------------------------------------------------------------------------

def get_default_camera_info(sim, camera_name="agentview"):
    """Read the named camera's current world pose from the simulator.

    Args:
        sim: simulator handle exposing ``model.camera_name2id`` and
            ``data.cam_xpos`` / ``data.cam_xmat`` (mujoco-py style).
        camera_name: name of the camera to inspect.

    Returns:
        Tuple ``(cam_pos, forward, cam_xmat)``:
        - a copy of the camera position;
        - the unit viewing direction, i.e. the negated third column of the
          orientation matrix, since MuJoCo cameras look down their local -z;
        - a 3x3 copy of the orientation matrix.
    """
    idx = sim.model.camera_name2id(camera_name)
    position = sim.data.cam_xpos[idx].copy()
    rot = sim.data.cam_xmat[idx].reshape(3, 3).copy()

    # Column j of the orientation matrix is camera axis j expressed in world coords.
    look = -rot[:, 2]
    look = look / (np.linalg.norm(look) + 1e-12)
    return position, look, rot


def compute_look_at_point(cam_pos, forward, table_z=0.85, fallback_dist=1.0):
    """Intersect the camera's viewing ray with the horizontal plane z = table_z.

    If the ray does not hit the plane strictly in front of the camera, the
    point ``cam_pos + forward * fallback_dist`` is returned instead. That covers
    three cases: the ray is (nearly) horizontal, it points away from the plane,
    or the camera lies exactly on the plane, where t == 0 would make the
    look-at point coincide with the camera.

    Args:
        cam_pos: (3,) camera position in world coordinates.
        forward: (3,) viewing direction; the fallback distance is measured in
            multiples of its length (unit length expected).
        table_z: height of the plane to intersect.
        fallback_dist: distance along ``forward`` used when there is no usable
            intersection.

    Returns:
        (3,) numpy array with the look-at point.
    """
    cam_pos = np.asarray(cam_pos, dtype=float)
    forward = np.asarray(forward, dtype=float)

    if abs(forward[2]) < 1e-6:
        # Ray is (nearly) parallel to the plane.
        return cam_pos + forward * fallback_dist
    t = (table_z - cam_pos[2]) / forward[2]
    if t <= 0:
        # Plane is behind the camera or the camera sits on it (t may be -0.0).
        t = fallback_dist
    return cam_pos + forward * t


def generate_viewpoints(
    look_at: np.ndarray,
    radius: float,
    n_views: int,
    center_dir: np.ndarray,
    theta_max_deg: float = 30.0,
):
    """Place ``n_views`` cameras on a spherical cap, each aimed at ``look_at``.

    Directions are drawn from a Fibonacci cap around +z with half-angle
    ``theta_max_deg``. They are rotated so the cap is centered on
    ``center_dir``, then scaled to distance ``radius`` from ``look_at``.

    Args:
        look_at: (3,) point every camera looks at; also the sphere center.
        radius: distance from ``look_at`` to each camera.
        n_views: number of viewpoints.
        center_dir: direction (from ``look_at``) at the middle of the cap.
        theta_max_deg: half-angle of the cap in degrees.

    Returns:
        positions: (n_views, 3) camera positions in world frame.
        quaternions: (n_views, 4) MuJoCo (w, x, y, z) quaternions aiming each
            camera at ``look_at``.
    """
    sphere_dirs = rotate_cap_to_direction(
        fibonacci_spherical_cap(n_views, theta_max_deg), center_dir
    )
    positions = look_at[None, :] + radius * sphere_dirs

    quaternions = np.zeros((n_views, 4))
    for row, cam in enumerate(positions):
        quaternions[row] = look_at_quaternion_mujoco(cam, look_at)

    return positions, quaternions


# ---------------------------------------------------------------------------
# Rendering
# ---------------------------------------------------------------------------

def render_from_viewpoint(env, sim, cam_id, cam_pos, cam_quat, state, camera_name, image_size):
    """Render ``state`` from a temporarily relocated camera and return the image.

    The camera's model-level pose (``sim.model.cam_pos`` / ``cam_quat``) is
    overwritten with ``cam_pos`` / ``cam_quat``. The env is then reset to
    ``state`` and an observation is collected. The original pose is always
    restored afterwards, including when rendering raises, so later renders
    are not silently done from a stale viewpoint.

    NOTE(review): model cam_pos/cam_quat are relative to the camera's parent
    body. Passing world-frame values (as ``main`` does) is only correct if the
    camera is attached to the worldbody — confirm for the camera in use.

    Args:
        env: LIBERO env providing ``set_init_state`` and ``env._get_observations()``.
        sim: the env's simulator (``env.env.sim`` in ``main``).
        cam_id: camera id in ``sim.model``.
        cam_pos: (3,) position written to ``sim.model.cam_pos[cam_id]``.
        cam_quat: (4,) (w, x, y, z) quaternion written to ``sim.model.cam_quat[cam_id]``.
        state: simulator state passed to ``env.set_init_state``.
        camera_name: the image is read from ``obs[f"{camera_name}_image"]``.
        image_size: unused here; the resolution is fixed when the env is built.

    Returns:
        The observation image flipped vertically (renderer rows are bottom-up),
        as a contiguous array. Floating-point images are converted to uint8
        (scaled by 255 when their values lie in [0, 1]); integer images are
        returned with their original dtype and values.
    """
    orig_pos = sim.model.cam_pos[cam_id].copy()
    orig_quat = sim.model.cam_quat[cam_id].copy()
    try:
        sim.model.cam_pos[cam_id] = cam_pos
        sim.model.cam_quat[cam_id] = cam_quat
        env.set_init_state(state)
        sim.forward()
        obs = env.env._get_observations()
        rgb = np.asarray(obs[f"{camera_name}_image"]).copy()
    finally:
        # Always put the camera back so later renders use the default pose.
        sim.model.cam_pos[cam_id] = orig_pos
        sim.model.cam_quat[cam_id] = orig_quat
        sim.forward()

    # Decide scaling by dtype, not by value range: a dark uint8 frame
    # (max <= 1) must not be multiplied by 255.
    if np.issubdtype(rgb.dtype, np.floating):
        if rgb.max() <= 1.0:
            rgb = rgb * 255
        rgb = np.clip(rgb, 0, 255).astype(np.uint8)
    return np.ascontiguousarray(np.flipud(rgb))


# ---------------------------------------------------------------------------
# Plotting
# ---------------------------------------------------------------------------

def plot_viewpoints_3d(
    look_at, default_pos, positions, out_path, table_z=0.85
):
    """Save a 3D scatter of the sampled cameras, the default camera and the look-at point.

    Draws a translucent square at height ``table_z`` around the look-at point,
    plus a line from every camera to the look-at point. Axis limits use one
    shared span so the plot keeps an equal aspect ratio.

    Args:
        look_at: (3,) look-at point.
        default_pos: (3,) default (view_0) camera position.
        positions: (n, 3) sampled camera positions.
        out_path: destination image path.
        table_z: height at which the table square is drawn.
    """
    fig = plt.figure(figsize=(10, 8))
    ax = fig.add_subplot(111, projection="3d")

    # Translucent table square (0.6 on each side of the look-at point).
    half = 0.6
    gx, gy = np.meshgrid(
        np.linspace(look_at[0] - half, look_at[0] + half, 2),
        np.linspace(look_at[1] - half, look_at[1] + half, 2),
    )
    ax.plot_surface(gx, gy, np.full_like(gx, table_z), alpha=0.15, color="brown", label="table")

    ax.scatter(*look_at, s=120, c="red", marker="x", linewidths=3, label="look-at")
    ax.scatter(*default_pos, s=100, c="blue", marker="^", label="view_0 (default)")
    ax.scatter(*positions.T, s=50, c="green", alpha=0.8, label="sampled views")

    # Sight lines: faint green for sampled views, dashed blue for the default one.
    for cam in positions:
        ax.plot(*zip(cam, look_at), "g-", alpha=0.2, linewidth=0.5)
    ax.plot(*zip(default_pos, look_at), "b--", alpha=0.5, linewidth=1.5)

    ax.set_xlabel("X")
    ax.set_ylabel("Y")
    ax.set_zlabel("Z")
    ax.set_title(f"Camera Viewpoints ({len(positions)} views on spherical cap)")
    ax.legend(loc="upper left")

    # Equal aspect: same half-span (largest extent, padded 20%) on every axis.
    cloud = np.vstack([positions, look_at.reshape(1, 3), default_pos.reshape(1, 3)])
    mid = cloud.mean(axis=0)
    half_span = (cloud.max(axis=0) - cloud.min(axis=0)).max() / 2 * 1.2
    for set_lim, c in zip((ax.set_xlim, ax.set_ylim, ax.set_zlim), mid):
        set_lim(c - half_span, c + half_span)

    fig.tight_layout()
    fig.savefig(out_path, dpi=150, bbox_inches="tight")
    plt.close(fig)
    print(f"Saved 3D viewpoint plot: {out_path}")


def plot_renders_grid(images, labels, out_path):
    """Tile images into a near-square grid with per-image titles and save it.

    The grid has ``ceil(sqrt(n))`` columns and as many rows as needed. Unused
    cells are left blank.

    Args:
        images: sequence of images accepted by ``Axes.imshow``.
        labels: one title per image; extra labels are ignored.
        out_path: destination image path.

    Raises:
        ValueError: if ``images`` is empty (the grid would have zero columns)
            or there are fewer labels than images.
    """
    n = len(images)
    if n == 0:
        raise ValueError("plot_renders_grid: no images to plot")
    if len(labels) < n:
        raise ValueError(f"plot_renders_grid: {len(labels)} labels for {n} images")

    cols = int(np.ceil(np.sqrt(n)))
    rows = int(np.ceil(n / cols))

    # squeeze=False always yields a 2-D axes array, even for a 1x1 or 1xN grid.
    fig, axes = plt.subplots(rows, cols, figsize=(3 * cols, 3 * rows), squeeze=False)
    for idx, ax in enumerate(axes.flat):  # row-major, matching divmod(idx, cols)
        if idx < n:
            ax.imshow(images[idx])
            ax.set_title(labels[idx], fontsize=8)
        ax.axis("off")

    fig.suptitle("Renders from sampled viewpoints", fontsize=14)
    fig.tight_layout()
    fig.savefig(out_path, dpi=150, bbox_inches="tight")
    plt.close(fig)
    print(f"Saved render grid: {out_path}")


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------

def main():
    """CLI entry point: render LIBERO demo frame 0 from viewpoints on a spherical cap.

    Steps:
    1. Parse and validate arguments.
    2. Load the states of the selected demo and build an off-screen LIBERO env.
    3. Read the default camera pose and sample ``--n_views`` viewpoints on a
       cap around it.
    4. Render frame 0 from the default camera and from every sampled viewpoint.
    5. Write plots, per-view PNGs and metadata into the output directory.

    The env is closed even if a step fails.
    """
    parser = argparse.ArgumentParser(description="Sample viewpoints and render LIBERO frame 0")
    parser.add_argument("--n_views", type=int, default=16,
                        help="Number of near-uniformly spaced viewpoints on the spherical cap")
    parser.add_argument("--image_size", type=int, default=256,
                        help="Render resolution")
    parser.add_argument("--benchmark", type=str, default="libero_spatial")
    parser.add_argument("--task_id", type=int, default=0)
    parser.add_argument("--demo_id", type=int, default=0)
    parser.add_argument("--camera", type=str, default="agentview")
    parser.add_argument("--theta_max", type=float, default=30.0,
                        help="Max angle (degrees) from default view for spherical cap")
    parser.add_argument("--table_z", type=float, default=0.85,
                        help="Approximate table surface height in world Z for look-at computation")
    parser.add_argument("--out_dir", type=str, default=None,
                        help="Output directory (default: ood_libero/out/)")
    args = parser.parse_args()

    # Reject values that would otherwise fail deep inside numpy (negative counts
    # or sizes) or silently wrap around the sphere (angles outside [0, 180]).
    if args.n_views < 0:
        parser.error("--n_views must be non-negative")
    if args.image_size <= 0:
        parser.error("--image_size must be positive")
    if not 0.0 <= args.theta_max <= 180.0:
        parser.error("--theta_max must be within [0, 180] degrees")

    script_dir = Path(__file__).resolve().parent
    out_dir = Path(args.out_dir) if args.out_dir else script_dir / "out"
    out_dir.mkdir(parents=True, exist_ok=True)

    # ----- Load LIBERO task + demo states -----
    bench = bm_lib.get_benchmark_dict()[args.benchmark]()
    task = bench.get_task(args.task_id)
    demo_path = os.path.join(get_libero_path("datasets"), bench.get_task_demonstration(args.task_id))
    bddl_file = os.path.join(get_libero_path("bddl_files"), task.problem_folder, task.bddl_file)

    with h5py.File(demo_path, "r") as f:
        # Sort numerically ("demo_2" before "demo_10"): for same-prefix keys with
        # un-padded digits, shorter means smaller. Plain string order would map
        # --demo_id 2 to demo_10.
        demo_keys = sorted(
            (k for k in f["data"].keys() if k.startswith("demo_")),
            key=lambda k: (len(k), k),
        )
        if not demo_keys:
            raise RuntimeError(f"No demo_* groups found under data/ in {demo_path}")
        # Out-of-range demo ids are clamped to the last demo.
        demo_key = demo_keys[min(args.demo_id, len(demo_keys) - 1)]
        states = f[f"data/{demo_key}/states"][()]

    env = OffScreenRenderEnv(
        bddl_file_name=bddl_file,
        camera_heights=args.image_size,
        camera_widths=args.image_size,
        camera_names=[args.camera],
    )
    try:
        env.seed(0)
        env.reset()
        sim = env.env.sim

        # ----- Extract default camera (view_0) info -----
        cam_id = sim.model.camera_name2id(args.camera)
        default_pos, forward, _ = get_default_camera_info(sim, args.camera)
        default_quat = sim.model.cam_quat[cam_id].copy()

        print(f"Default camera '{args.camera}':")
        print(f"  pos:     {default_pos}")
        print(f"  quat:    {default_quat}  (w,x,y,z)")
        print(f"  forward: {forward}")

        # Compute look-at point; its distance to the camera is the cap radius.
        look_at = compute_look_at_point(default_pos, forward, table_z=args.table_z)
        radius = np.linalg.norm(default_pos - look_at)
        print(f"  look-at: {look_at}")
        print(f"  radius:  {radius:.4f}")

        # ----- Generate viewpoints -----
        # center_dir = unit direction from look_at toward the default camera.
        center_dir = default_pos - look_at
        center_dir = center_dir / (np.linalg.norm(center_dir) + 1e-12)

        # NOTE(review): positions/quaternions are world-frame (default_pos comes
        # from data.cam_xpos), but render_from_viewpoint writes them into
        # model.cam_pos / model.cam_quat, which are relative to the camera's
        # parent body. Only consistent if the camera is attached to the
        # worldbody — confirm for the selected camera.
        positions, quaternions = generate_viewpoints(
            look_at, radius, args.n_views,
            center_dir=center_dir,
            theta_max_deg=args.theta_max,
        )
        print(f"\nGenerated {args.n_views} viewpoints on spherical cap "
              f"(radius={radius:.3f}, theta_max={args.theta_max}°)")

        # ----- Render frame 0 from each viewpoint -----
        state_0 = states[0]

        # view_0: render with the camera's own, unmodified model pose.
        images = [
            render_from_viewpoint(
                env, sim, cam_id,
                sim.model.cam_pos[cam_id].copy(), default_quat,
                state_0, args.camera, args.image_size,
            )
        ]
        labels = ["view_0 (default)"]

        for i in range(args.n_views):
            print(f"  Rendering view {i+1}/{args.n_views}...", end="\r")
            rgb = render_from_viewpoint(
                env, sim, cam_id,
                positions[i], quaternions[i],
                state_0, args.camera, args.image_size,
            )
            images.append(rgb)
            labels.append(f"view_{i+1}")

        print(f"  Rendered {args.n_views + 1} views (1 default + {args.n_views} sampled)")

        # ----- Plot -----
        plot_viewpoints_3d(
            look_at, default_pos, positions,
            str(out_dir / "viewpoints_3d.png"),
            table_z=args.table_z,
        )

        plot_renders_grid(images, labels, str(out_dir / "renders_grid.png"))

        # Also save individual renders (file name = first word of the label).
        renders_dir = out_dir / "renders"
        renders_dir.mkdir(exist_ok=True)
        for img, label in zip(images, labels):
            fname = f"{label.split(' ')[0]}.png"
            path = renders_dir / fname
            # cv2.imwrite reports failure via its return value, not an exception.
            if not cv2.imwrite(str(path), cv2.cvtColor(img, cv2.COLOR_RGB2BGR)):
                raise OSError(f"cv2.imwrite failed for {path}")
        print(f"Saved individual renders to {renders_dir}/")

        # Save viewpoint metadata
        np.savez(
            str(out_dir / "viewpoint_meta.npz"),
            default_pos=default_pos,
            default_quat=default_quat,
            look_at=look_at,
            radius=radius,
            positions=positions,
            quaternions=quaternions,
            n_views=args.n_views,
        )
        print(f"Saved viewpoint metadata to {out_dir / 'viewpoint_meta.npz'}")
    finally:
        env.close()
    print("Done.")


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported for its helpers.
    main()
