"""Generate OOD viewpoint dataset (viewpoint variation only, fixed object position).

Creates a grid of camera viewpoints on a spherical cap around the default camera.
Multiple demos per viewpoint (with slight random object position jitter for variety).
Clean scene (no distractors/furniture).

Usage:
    python generate_ood_viewpoint.py --n_views 8 --theta_max 25 --demos_per_view 10
"""
import argparse
import os
import sys
from pathlib import Path

import cv2
import h5py
import numpy as np
from tqdm import tqdm
from scipy.spatial.transform import Rotation as ScipyR

sys.path.insert(0, "/data/cameron/LIBERO")
os.environ.setdefault("LIBERO_DATA_PATH", "/data/libero")

from libero.libero import benchmark as bm_lib, get_libero_path
from libero.libero.envs import OffScreenRenderEnv
from robosuite.utils.camera_utils import (
    get_camera_extrinsic_matrix,
    get_camera_intrinsic_matrix,
    get_camera_transform_matrix,
    project_points_from_world_to_camera,
)

AGENT_CAM = "agentview"
TABLE_Z = 0.90
STATE_QPOS_OFFSET = 1
PICK_QPOS = 9
PLACE_QPOS = 37
DISTRACTOR_QPOS_STARTS = [16, 23, 30]
FURNITURE_BODIES = ["wooden_cabinet_1_main", "flat_stove_1_main"]
DISTRACTOR_BODIES = ["akita_black_bowl_2_main", "cookies_1_main", "glazed_rim_porcelain_ramekin_1_main"]
DISTRACTOR_POS = np.array([10.0, 10.0, 0.9])
DISTRACTOR_DOFS = {16: slice(15, 21), 23: slice(21, 27), 30: slice(27, 33)}


def look_at_quat(cam_pos, target):
    """Return a (w, x, y, z) quaternion orienting a camera at *cam_pos* toward *target*.

    The camera's local z axis is set opposite to the viewing direction
    (MuJoCo-style camera frame); x and y are derived from a world-up hint,
    which is switched to +y when the view is nearly vertical.
    """
    view_dir = target - cam_pos
    view_dir = view_dir / (np.linalg.norm(view_dir) + 1e-12)
    z_axis = -view_dir
    hint = np.array([0.0, 0.0, 1.0])
    if abs(np.dot(view_dir, hint)) > 0.99:  # almost straight up/down: use a different hint
        hint = np.array([0.0, 1.0, 0.0])
    x_axis = np.cross(hint, z_axis)
    x_axis = x_axis / (np.linalg.norm(x_axis) + 1e-12)
    y_axis = np.cross(z_axis, x_axis)
    rot = np.stack([x_axis, y_axis, z_axis], axis=-1)
    qx, qy, qz, qw = ScipyR.from_matrix(rot).as_quat()  # scipy is scalar-last
    return np.array([qw, qx, qy, qz])


def generate_viewpoint_grid(default_pos, look_at_point, n_views, theta_max_deg):
    """Sample an n_views x n_views grid of camera poses on a spherical cap.

    The cap is centered on the default camera direction around *look_at_point*
    and keeps the original camera distance. Returns positions, (w, x, y, z)
    quaternions, the theta/phi sample arrays, and per-pose grid indices.
    """
    offset_vec = default_pos - look_at_point
    radius = np.linalg.norm(offset_vec)
    axis = offset_vec / radius

    # Orthonormal frame (right, true_up, axis) around the default direction.
    up_hint = np.array([0, 0, 1.0])
    if abs(np.dot(axis, up_hint)) > 0.99:
        up_hint = np.array([1, 0, 0.0])
    right = np.cross(axis, up_hint)
    right /= np.linalg.norm(right)
    true_up = np.cross(right, axis)

    thetas = np.linspace(0, np.radians(theta_max_deg), n_views)
    # Azimuths cover [0, 2*pi) without duplicating the wrap-around point.
    phis = np.linspace(0, 2 * np.pi * (1 - 1 / n_views), n_views)

    positions, quaternions, theta_indices, phi_indices = [], [], [], []
    for t_idx, theta in enumerate(thetas):
        for p_idx, phi in enumerate(phis):
            direction = (np.sin(theta) * np.cos(phi) * right
                         + np.sin(theta) * np.sin(phi) * true_up
                         + np.cos(theta) * axis)
            cam_pos = look_at_point + radius * direction
            positions.append(cam_pos)
            quaternions.append(look_at_quat(cam_pos, look_at_point))
            theta_indices.append(t_idx)
            phi_indices.append(p_idx)
    return np.array(positions), np.array(quaternions), thetas, phis, theta_indices, phi_indices


def find_grasp_timestep(actions):
    """Return the first timestep where the gripper command flips from open (<=0) to closed (>0).

    Falls back to the trajectory midpoint when no such transition exists.
    """
    grip = actions[:, 6]
    transitions = (
        t for t in range(1, len(grip)) if grip[t] > 0 and grip[t - 1] <= 0
    )
    return next(transitions, len(grip) // 2)


def setup_clean_scene(sim):
    """Strip the scene down to the task objects.

    Furniture bodies are teleported far below the floor and distractor geoms
    are made fully transparent. Missing body names are silently skipped so
    the same code works across task variants.
    """
    for body_name in FURNITURE_BODIES:
        try:
            sim.model.body_pos[sim.model.body_name2id(body_name)] = np.array([0, 0, -5.0])
        except Exception:
            continue  # body absent in this scene — best-effort cleanup
    sim.forward()

    distractor_ids = set()
    for body_name in DISTRACTOR_BODIES:
        try:
            distractor_ids.add(sim.model.body_name2id(body_name))
        except Exception:
            continue
    for geom_id in range(sim.model.ngeom):
        if sim.model.geom_bodyid[geom_id] in distractor_ids:
            sim.model.geom_rgba[geom_id][3] = 0.0  # alpha 0 -> invisible


def freeze_distractors(sim):
    """Pin every distractor free joint at the far-away parking pose with zero velocity."""
    for qpos_start, vel_slice in DISTRACTOR_DOFS.items():
        sim.data.qpos[qpos_start:qpos_start + 3] = DISTRACTOR_POS
        sim.data.qvel[vel_slice] = 0.0


def shift_state(state, dx, dy):
    """Return a copy of *state* with the pick/place objects shifted by (dx, dy)
    in the x-y plane and all distractors parked at DISTRACTOR_POS."""
    shifted = state.copy()
    for qpos_start in (PICK_QPOS, PLACE_QPOS):
        idx = qpos_start + STATE_QPOS_OFFSET
        shifted[idx] += dx
        shifted[idx + 1] += dy
    for qpos_start in DISTRACTOR_QPOS_STARTS:
        idx = qpos_start + STATE_QPOS_OFFSET
        shifted[idx:idx + 3] = DISTRACTOR_POS
    return shifted


def servo_to(env, target, gripper_cmd, max_servo=25, threshold=0.003):
    """Closed-loop position servo of the end effector toward *target*.

    Issues proportional delta-position actions (error / 0.05, clipped to
    [-1, 1]) while holding *gripper_cmd*, re-freezing the distractors after
    every step. Stops once the EEF is within *threshold* meters, the step
    budget is exhausted, or the environment reports done. Returns the most
    recent observation.
    """
    sim = env.env.sim
    obs = None
    for _ in range(max_servo):
        obs = env.env._get_observations()
        err = target - np.array(obs["robot0_eef_pos"], dtype=np.float64)
        if np.linalg.norm(err) < threshold:
            break
        cmd = np.zeros(7, dtype=np.float32)
        cmd[:3] = np.clip(err / 0.05, -1, 1)
        cmd[6] = gripper_cmd
        obs, _, done, _ = env.step(cmd)
        freeze_distractors(sim)
        if done:
            break
    # Fall back to a fresh observation if the loop never produced one.
    return obs or env.env._get_observations()


def interpolate_waypoints(start, end, n):
    """Return *n* points linearly spaced from *start* (exclusive) to *end* (inclusive)."""
    fractions = np.linspace(0, 1, n + 1)[1:]
    delta = end - start
    return [start + frac * delta for frac in fractions]


def main():
    """Generate the OOD-viewpoint dataset for LIBERO spatial task 0.

    Pipeline:
      1. Replay the source demo's stored states once to extract the original
         EEF world-space trajectory.
      2. Build an n_views x n_views spherical-cap grid of camera poses around
         the default agentview camera.
      3. For each viewpoint and each demo: shift the objects by a random
         (dx, dy), servo the EEF along the shifted reference trajectory, and
         save frames plus per-step pose/camera metadata under --out_root.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--n_views", type=int, default=8, help="Grid size (n_views x n_views viewpoints)")
    parser.add_argument("--theta_max", type=float, default=25, help="Max polar angle in degrees")
    parser.add_argument("--demos_per_view", type=int, default=10, help="Demos per viewpoint")
    parser.add_argument("--dx_min", type=float, default=-0.40, help="Min object dx offset")
    parser.add_argument("--dx_max", type=float, default=-0.01, help="Max object dx offset")
    parser.add_argument("--dy_min", type=float, default=-0.30, help="Min object dy offset")
    parser.add_argument("--dy_max", type=float, default=0.30, help="Max object dy offset")
    parser.add_argument("--image_size", type=int, default=448)
    parser.add_argument("--frame_stride", type=int, default=3)
    parser.add_argument("--z_offset", type=float, default=-0.015)
    parser.add_argument("--out_root", type=str, default="/data/libero/ood_viewpoint_v2")
    args = parser.parse_args()

    # Source demonstration: task 0 of the libero_spatial benchmark.
    bench = bm_lib.get_benchmark_dict()["libero_spatial"]()
    task = bench.get_task(0)
    demo_path = os.path.join(get_libero_path("datasets"), bench.get_task_demonstration(0))
    bddl_file = os.path.join(get_libero_path("bddl_files"), task.problem_folder, task.bddl_file)

    # Load the first demo's state/action sequences from the HDF5 dataset.
    with h5py.File(demo_path, "r") as f:
        dk = sorted([k for k in f["data"].keys() if k.startswith("demo_")])[0]
        states = f[f"data/{dk}/states"][()]
        actions = f[f"data/{dk}/actions"][()]

    # Pass 1: set every stored state in a throwaway env to recover the
    # EEF world trajectory (no stepping — just forward kinematics).
    print("Extracting EEF trajectory...")
    env_tmp = OffScreenRenderEnv(bddl_file_name=bddl_file, camera_heights=args.image_size,
                                  camera_widths=args.image_size, camera_names=[AGENT_CAM])
    env_tmp.seed(0); env_tmp.reset()
    eef_orig = []
    for t in range(len(states)):
        env_tmp.set_init_state(states[t])
        env_tmp.env.sim.forward()
        eef_orig.append(np.array(env_tmp.env._get_observations()["robot0_eef_pos"], dtype=np.float64))
    eef_orig = np.array(eef_orig)
    env_tmp.close()

    # Offset that would move the pick object to the qpos origin; per-demo
    # random offsets below are applied relative to this center.
    bowl_si = PICK_QPOS + STATE_QPOS_OFFSET
    center_dx = -states[0][bowl_si]
    center_dy = -states[0][bowl_si + 1]
    print(f"Center offset: ({center_dx:+.3f}, {center_dy:+.3f})")

    # Persistent rollout environment, reused across all episodes.
    env = OffScreenRenderEnv(bddl_file_name=bddl_file, camera_heights=args.image_size,
                              camera_widths=args.image_size, camera_names=[AGENT_CAM])
    env.seed(0); env.reset()
    env.env.horizon = 100000  # effectively disable the episode time limit
    sim = env.env.sim
    setup_clean_scene(sim)
    H = W = args.image_size

    # Intersect the default camera's viewing ray with the table plane
    # (z = TABLE_Z) to get the look-at point for the viewpoint grid.
    cam_id = sim.model.camera_name2id(AGENT_CAM)
    default_pos = sim.data.cam_xpos[cam_id].copy()
    cam_xmat = sim.data.cam_xmat[cam_id].reshape(3, 3)
    forward = -cam_xmat[:, 2]  # MuJoCo cameras look along their local -z axis
    t_hit = (TABLE_Z - default_pos[2]) / (forward[2] + 1e-8)
    look_at = default_pos + t_hit * forward
    print(f"Default cam: {default_pos}, look-at: {look_at}")

    vp_positions, vp_quats, thetas, phis, theta_idx, phi_idx = generate_viewpoint_grid(
        default_pos, look_at, args.n_views, args.theta_max)

    n_viewpoints = len(vp_positions)
    n_total = n_viewpoints * args.demos_per_view
    print(f"\n{args.n_views}x{args.n_views} = {n_viewpoints} viewpoints x {args.demos_per_view} demos = {n_total} episodes")
    print(f"Theta: {np.degrees(thetas).round(1)}")
    print(f"Phi: {np.degrees(phis).round(1)}")
    print(f"Object position range: dx=[{args.dx_min}, {args.dx_max}], dy=[{args.dy_min}, {args.dy_max}]")

    task_dir = Path(args.out_root) / "libero_spatial" / "task_0"
    task_dir.mkdir(parents=True, exist_ok=True)

    # Dataset-level metadata describing the viewpoint grid and offset ranges.
    np.savez(task_dir / "viewpoint_meta.npz",
             thetas_deg=np.degrees(thetas), phis_deg=np.degrees(phis),
             n_views=args.n_views, theta_max=args.theta_max,
             demos_per_view=args.demos_per_view,
             vp_positions=vp_positions, vp_quats=vp_quats,
             center_dx=center_dx, center_dy=center_dy,
             dx_min=args.dx_min, dx_max=args.dx_max,
             dy_min=args.dy_min, dy_max=args.dy_max)

    rng = np.random.RandomState(42)  # fixed seed for reproducible object offsets
    t_grasp = find_grasp_timestep(actions)
    t_pregrasp = max(0, t_grasp - 6)  # approach waypoint shortly before the grasp
    demo_idx = 0
    successes = 0

    for vi in tqdm(range(n_viewpoints), desc="Viewpoints"):
        # Move the camera to this viewpoint; the pose persists across demos.
        sim.model.cam_pos[cam_id] = vp_positions[vi]
        sim.model.cam_quat[cam_id] = vp_quats[vi]

        for di in range(args.demos_per_view):
            demo_dir = task_dir / f"demo_{demo_idx}"
            # Resume support: skip demos that already have saved outputs.
            if (demo_dir / "eef_pos.npy").exists():
                demo_idx += 1
                continue

            # Random object position from full range
            dx_offset = rng.uniform(args.dx_min, args.dx_max)
            dy_offset = rng.uniform(args.dy_min, args.dy_max)
            total_dx = center_dx + dx_offset
            total_dy = center_dy + dy_offset

            # Reset to the demo's initial state with objects shifted, then
            # let the sim settle with a few zero-action steps.
            env.env.timestep = 0
            env.env.done = False
            state_0 = shift_state(states[0], total_dx, total_dy)
            env.set_init_state(state_0)
            sim.forward()
            for _ in range(5):
                env.step(np.zeros(7, dtype=np.float32))
                freeze_distractors(sim)

            obs = env.env._get_observations()
            home_pos = np.array(obs["robot0_eef_pos"], dtype=np.float64)
            # Pre-grasp target = original trajectory point shifted by the
            # same offset applied to the objects.
            pregrasp_target = eef_orig[t_pregrasp].copy()
            pregrasp_target[0] += total_dx
            pregrasp_target[1] += total_dy

            gripper_cmd = -1.0  # start with the gripper open
            rec_frames, rec_eef, rec_quat, rec_grip = [], [], [], []

            def record(o):
                # NOTE: gripper_cmd is read late-bound from the enclosing
                # scope, so each call records the command current at call time.
                rec_eef.append(np.array(o["robot0_eef_pos"], dtype=np.float32))
                rec_quat.append(np.array(o["robot0_eef_quat"], dtype=np.float32))
                rec_grip.append(gripper_cmd)
                rec_frames.append(np.flipud(o[f"{AGENT_CAM}_image"]).copy())  # un-flip the offscreen render

            record(obs)

            # Phase 1: servo from home to the pre-grasp point via 8 waypoints.
            for wp in interpolate_waypoints(home_pos, pregrasp_target, 8):
                obs = servo_to(env, wp, -1.0, max_servo=25)
                record(obs)

            # Phase 2: follow the shifted reference trajectory from the
            # pre-grasp timestep onward, subsampled by frame_stride.
            phase2 = list(range(t_pregrasp, len(eef_orig), args.frame_stride))
            success = False
            for t in phase2:
                target = eef_orig[t].copy()
                target[0] += total_dx
                target[1] += total_dy
                if t < len(actions):
                    gripper_cmd = float(np.clip(actions[t, 6], -1, 1))
                # While holding the object, bias the target in z (presumably
                # to compensate for grasp-height differences — TODO confirm).
                if gripper_cmd > 0 and args.z_offset != 0:
                    target[2] += args.z_offset
                obs = servo_to(env, target, gripper_cmd, max_servo=25)
                record(obs)
                if env.env.done or (hasattr(env.env, '_check_success') and env.env._check_success()):
                    success = True

            # Write frames as PNGs (recorded RGB -> OpenCV BGR).
            demo_dir.mkdir(parents=True, exist_ok=True)
            frames_dir = demo_dir / "frames"
            frames_dir.mkdir(exist_ok=True)
            for fi, frame in enumerate(rec_frames):
                cv2.imwrite(str(frames_dir / f"{fi:06d}.png"), cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))

            # Camera matrices for this viewpoint (pose is fixed per episode).
            eef_arr = np.stack(rec_eef)
            w2c = get_camera_transform_matrix(sim, AGENT_CAM, H, W)
            K = get_camera_intrinsic_matrix(sim, AGENT_CAM, H, W)
            K_norm = K.copy()
            K_norm[0] /= W  # normalize intrinsics to [0, 1] image coordinates
            K_norm[1] /= H
            ext = get_camera_extrinsic_matrix(sim, AGENT_CAM)

            # Project each recorded EEF position into the image; the swap
            # suggests the projection returns (row, col), stored as (u, v).
            pix_uvs = []
            for ei in range(len(eef_arr)):
                pix_rc = project_points_from_world_to_camera(
                    eef_arr[ei:ei+1].astype(np.float64), w2c, H, W)[0]
                pix_uvs.append(np.array([pix_rc[1], pix_rc[0]], dtype=np.float32))

            np.save(demo_dir / "eef_pos.npy", eef_arr)
            np.save(demo_dir / "eef_quat.npy", np.stack(rec_quat))
            np.save(demo_dir / "gripper.npy", np.array(rec_grip, dtype=np.float32))
            np.save(demo_dir / "pix_uv.npy", np.stack(pix_uvs))
            np.save(demo_dir / "cam_extrinsic.npy", ext.astype(np.float32))
            np.save(demo_dir / "cam_K_norm.npy", K_norm.astype(np.float32))
            np.save(demo_dir / "world_to_cam.npy", w2c.astype(np.float32))
            np.save(demo_dir / "base_z.npy", np.float32(0.912))  # NOTE(review): magic base-height constant — confirm source
            np.save(demo_dir / "actions.npy", np.zeros((len(eef_arr), 7), dtype=np.float32))  # placeholder zero actions
            np.savez(demo_dir / "meta.npz", vi=vi, di=di,
                     theta_idx=theta_idx[vi], phi_idx=phi_idx[vi],
                     dx=dx_offset, dy=dy_offset)

            if success:
                successes += 1
            demo_idx += 1

            if demo_idx % 50 == 0:
                print(f"  [{demo_idx}/{n_total}] vi={vi} success={success} (total: {successes}/{demo_idx})")

    env.close()
    print(f"\nDone. {successes}/{demo_idx} succeeded.")
    print(f"Saved to {task_dir}")


if __name__ == "__main__":
    main()
