"""Generate OOD viewpoint + object position dataset.

5x5 viewpoint grid × 3 random object positions = 75 episodes.
Uses servo replay (like objpos) so EEF trajectories are consistent with object shifts.

Usage:
    python generate_ood_viewpoint_objpos.py --n_views 5 --n_objpos 3
"""
import argparse
import os
import sys
from pathlib import Path

import cv2
import h5py
import numpy as np
from tqdm import tqdm
from scipy.spatial.transform import Rotation as ScipyR

sys.path.insert(0, "/data/cameron/LIBERO")
os.environ.setdefault("LIBERO_DATA_PATH", "/data/libero")

from libero.libero import benchmark as bm_lib, get_libero_path
from libero.libero.envs import OffScreenRenderEnv
from robosuite.utils.camera_utils import (
    get_camera_extrinsic_matrix,
    get_camera_intrinsic_matrix,
    get_camera_transform_matrix,
    project_points_from_world_to_camera,
)

AGENT_CAM = "agentview"
TABLE_Z = 0.90  # approximate table surface height (m); used to intersect the camera ray with the table

# State layout: indices into the flat demo state vector.
# NOTE(review): these offsets are specific to libero_spatial task 0 — verify
# against the demo HDF5 if the task index changes.
STATE_QPOS_OFFSET = 1  # qpos index i lives at states[i + 1]; leading entry is presumably sim time — TODO confirm
PICK_QPOS = 9   # qpos start of the pick object (x, y at +0/+1 are shifted in shift_state)
PLACE_QPOS = 37  # qpos start of the place-target object
DISTRACTOR_QPOS_STARTS = [16, 23, 30]  # qpos starts of the three distractor objects
FURNITURE_BODIES = ["wooden_cabinet_1_main", "flat_stove_1_main"]
DISTRACTOR_BODIES = ["akita_black_bowl_2_main", "cookies_1_main", "glazed_rim_porcelain_ramekin_1_main"]
DISTRACTOR_POS = np.array([10.0, 10.0, 0.9])  # far off-table parking position for distractors
DISTRACTOR_DOFS = {16: slice(15, 21), 23: slice(21, 27), 30: slice(27, 33)}  # qpos start -> qvel slice to zero


def look_at_quat(cam_pos, target):
    """MuJoCo camera quaternion (w,x,y,z) looking at target."""
    view_dir = target - cam_pos
    view_dir = view_dir / (np.linalg.norm(view_dir) + 1e-12)
    # MuJoCo cameras look along their local -z axis, so +z points away from target.
    z_axis = -view_dir
    world_up = np.array([0.0, 0.0, 1.0])
    # Near-vertical view: fall back to +y as the up hint to avoid a degenerate cross product.
    if abs(np.dot(view_dir, world_up)) > 0.99:
        world_up = np.array([0.0, 1.0, 0.0])
    x_axis = np.cross(world_up, z_axis)
    x_axis = x_axis / (np.linalg.norm(x_axis) + 1e-12)
    y_axis = np.cross(z_axis, x_axis)
    rot = np.column_stack([x_axis, y_axis, z_axis])
    x, y, z, w = ScipyR.from_matrix(rot).as_quat()  # scipy is scalar-last
    return np.array([w, x, y, z])  # reorder to MuJoCo's scalar-first convention


def generate_viewpoint_grid(default_pos, look_at_point, n_views, theta_max_deg):
    """Sample an n_views x n_views grid of camera poses on a sphere.

    Poses keep the default camera's distance to ``look_at_point`` and are
    parameterized by polar angle theta (0..theta_max_deg off the default
    direction) and azimuth phi (full circle). Returns (positions, quaternions).
    """
    offset_vec = default_pos - look_at_point
    radius = np.linalg.norm(offset_vec)
    default_dir = offset_vec / radius
    # Build an orthonormal frame (right, true_up) around the default direction.
    up_hint = np.array([0, 0, 1.0])
    if abs(np.dot(default_dir, up_hint)) > 0.99:
        up_hint = np.array([1, 0, 0.0])
    right = np.cross(default_dir, up_hint)
    right /= np.linalg.norm(right)
    true_up = np.cross(right, default_dir)

    thetas = np.linspace(0, np.radians(theta_max_deg), n_views)
    # Evenly spaced azimuths, omitting the duplicate 2*pi endpoint.
    phis = np.linspace(0, 2 * np.pi * (1 - 1 / n_views), n_views)

    positions, quaternions = [], []
    for theta in thetas:
        sin_t, cos_t = np.sin(theta), np.cos(theta)
        for phi in phis:
            direction = (sin_t * np.cos(phi) * right
                         + sin_t * np.sin(phi) * true_up
                         + cos_t * default_dir)
            cam_pos = look_at_point + radius * direction
            positions.append(cam_pos)
            quaternions.append(look_at_quat(cam_pos, look_at_point))
    return np.array(positions), np.array(quaternions)


def find_grasp_timestep(actions):
    """Return the first timestep where the gripper command flips open -> closed.

    A grasp is the first index t with actions[t, 6] > 0 and actions[t-1, 6] <= 0.
    Falls back to the trajectory midpoint if no such transition exists.
    """
    grip = actions[:, 6]
    transitions = (t for t in range(1, len(grip)) if grip[t] > 0 >= grip[t - 1])
    return next(transitions, len(grip) // 2)


def setup_clean_scene(sim):
    """Hide furniture and distractor geometry from the scene.

    Furniture bodies are teleported far below the table; distractor geoms are
    made fully transparent. Missing bodies are skipped (best effort).
    """
    below_table = np.array([0, 0, -5.0])
    for body_name in FURNITURE_BODIES:
        try:
            body_id = sim.model.body_name2id(body_name)
            sim.model.body_pos[body_id] = below_table
        except Exception:
            pass  # body absent in this scene; ignore
    sim.forward()
    distractor_body_ids = set()
    for body_name in DISTRACTOR_BODIES:
        try:
            distractor_body_ids.add(sim.model.body_name2id(body_name))
        except Exception:
            pass
    for geom_id in range(sim.model.ngeom):
        if sim.model.geom_bodyid[geom_id] in distractor_body_ids:
            # Zero alpha renders the geom invisible.
            sim.model.geom_rgba[geom_id][3] = 0.0


def freeze_distractors(sim):
    """Pin every distractor object at DISTRACTOR_POS with zero velocity."""
    for qpos_start, qvel_slice in DISTRACTOR_DOFS.items():
        sim.data.qpos[qpos_start:qpos_start + 3] = DISTRACTOR_POS
        sim.data.qvel[qvel_slice] = 0.0


def shift_state(state, dx, dy):
    """Return a copy of ``state`` with pick/place object XY shifted by (dx, dy).

    Distractor positions are overwritten with DISTRACTOR_POS so they stay
    parked off-table. The input state is not modified.
    """
    shifted = state.copy()
    for qpos_start in (PICK_QPOS, PLACE_QPOS):
        idx = qpos_start + STATE_QPOS_OFFSET
        shifted[idx] += dx
        shifted[idx + 1] += dy
    for qpos_start in DISTRACTOR_QPOS_STARTS:
        idx = qpos_start + STATE_QPOS_OFFSET
        shifted[idx:idx + 3] = DISTRACTOR_POS
    return shifted


def servo_to(env, target, gripper_cmd, max_servo=25, threshold=0.003, gain=0.05):
    """Servo the EEF toward ``target`` with a saturated proportional controller.

    Args:
        env: LIBERO OffScreenRenderEnv wrapper (its inner robosuite env is stepped).
        target: desired world-frame EEF position, shape (3,).
        gripper_cmd: gripper action held constant while servoing.
        max_servo: maximum number of control steps before giving up.
        threshold: stop once the EEF is within this distance (m) of target.
        gain: position error (m) mapped to a full unit action; smaller values
            give more aggressive corrections. Default 0.05 matches the previous
            hard-coded behavior.

    Returns:
        The last observation dict (a fresh one if no loop iteration ran).
    """
    sim = env.env.sim
    obs = None
    for _ in range(max_servo):
        obs = env.env._get_observations()
        cur = np.array(obs["robot0_eef_pos"], dtype=np.float64)
        delta = target - cur
        if np.linalg.norm(delta) < threshold:
            break
        action = np.zeros(7, dtype=np.float32)
        action[:3] = np.clip(delta / gain, -1, 1)
        action[6] = gripper_cmd
        obs, _, done, _ = env.step(action)
        # Re-pin distractors every step so physics cannot drift them back on-table.
        freeze_distractors(sim)
        if done:
            break
    # Explicit None check: `obs or ...` would wrongly re-fetch a falsy-but-valid obs.
    return obs if obs is not None else env.env._get_observations()


def main():
    """Generate the OOD viewpoint x object-position dataset for libero_spatial task 0.

    Replays the first human demo via EEF position servoing under an
    n_views x n_views spherical grid of camera viewpoints, each paired with
    random in-plane object shifts, and writes per-episode frames plus EEF and
    camera metadata under --out_root.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--n_views", type=int, default=5)
    parser.add_argument("--n_objpos", type=int, default=3)
    parser.add_argument("--theta_max", type=float, default=20)
    parser.add_argument("--image_size", type=int, default=448)
    parser.add_argument("--frame_stride", type=int, default=3)
    parser.add_argument("--z_offset", type=float, default=-0.015)
    parser.add_argument("--out_root", type=str, default="/data/libero/ood_viewpoint_objpos_task0")
    parser.add_argument("--max_demos", type=int, default=0, help="Limit total demos (0=all)")
    args = parser.parse_args()

    bench = bm_lib.get_benchmark_dict()["libero_spatial"]()
    task = bench.get_task(0)
    demo_path = os.path.join(get_libero_path("datasets"), bench.get_task_demonstration(0))
    bddl_file = os.path.join(get_libero_path("bddl_files"), task.problem_folder, task.bddl_file)

    # Load states and actions of the first recorded demo only.
    with h5py.File(demo_path, "r") as f:
        dk = sorted([k for k in f["data"].keys() if k.startswith("demo_")])[0]
        states = f[f"data/{dk}/states"][()]
        actions = f[f"data/{dk}/actions"][()]

    # Extract original EEF trajectory by replaying each saved sim state.
    print("Extracting EEF trajectory...")
    env_tmp = OffScreenRenderEnv(bddl_file_name=bddl_file, camera_heights=args.image_size,
                                  camera_widths=args.image_size, camera_names=[AGENT_CAM])
    env_tmp.seed(0); env_tmp.reset()
    eef_orig = []
    for t in range(len(states)):
        env_tmp.set_init_state(states[t])
        env_tmp.env.sim.forward()
        eef_orig.append(np.array(env_tmp.env._get_observations()["robot0_eef_pos"], dtype=np.float64))
    eef_orig = np.array(eef_orig)
    env_tmp.close()

    # Centering offset: shift that moves the pick object's demo XY to the
    # origin, so sampled (dx, dy) are relative to a canonical position.
    bowl_si = PICK_QPOS + STATE_QPOS_OFFSET
    center_dx = -states[0][bowl_si]
    center_dy = -states[0][bowl_si + 1]
    print(f"Center offset: ({center_dx:+.3f}, {center_dy:+.3f})")

    # Object position grid (same range as objpos dataset)
    dx_vals = np.linspace(-0.15, 0.0, 16)
    dy_vals = np.linspace(-0.1, 0.1, 16)

    # Create env
    env = OffScreenRenderEnv(bddl_file_name=bddl_file, camera_heights=args.image_size,
                              camera_widths=args.image_size, camera_names=[AGENT_CAM])
    env.seed(0); env.reset()
    env.env.horizon = 100000  # effectively disable the episode timeout during replay
    sim = env.env.sim
    setup_clean_scene(sim)
    H = W = args.image_size

    # Get default camera pose, then intersect its viewing ray with the table
    # plane (z = TABLE_Z) to recover the look-at point.
    cam_id = sim.model.camera_name2id(AGENT_CAM)
    default_pos = sim.data.cam_xpos[cam_id].copy()
    cam_xmat = sim.data.cam_xmat[cam_id].reshape(3, 3)
    forward = -cam_xmat[:, 2]  # MuJoCo cameras look along their local -z axis
    t_hit = (TABLE_Z - default_pos[2]) / (forward[2] + 1e-8)
    look_at = default_pos + t_hit * forward
    print(f"Default cam: {default_pos}, look-at: {look_at}")

    # Generate viewpoints
    vp_positions, vp_quats = generate_viewpoint_grid(default_pos, look_at, args.n_views, args.theta_max)
    n_total = len(vp_positions) * args.n_objpos
    print(f"\n{args.n_views}x{args.n_views} viewpoints × {args.n_objpos} obj positions = {n_total} episodes")

    task_dir = Path(args.out_root) / "libero_spatial" / "task_0"
    task_dir.mkdir(parents=True, exist_ok=True)

    rng = np.random.RandomState(42)  # fixed seed: episode layouts are reproducible across runs
    demo_idx = 0

    for vi in tqdm(range(len(vp_positions)), desc="Viewpoints"):
        # Set camera
        sim.model.cam_pos[cam_id] = vp_positions[vi]
        sim.model.cam_quat[cam_id] = vp_quats[vi]

        for oi in range(args.n_objpos):
            if args.max_demos > 0 and demo_idx >= args.max_demos:
                break  # NOTE(review): exits only the inner loop; the outer loop keeps spinning (harmlessly)

            demo_dir = task_dir / f"demo_{demo_idx}"
            # Resume support: skip episodes whose final artifact already exists.
            if (demo_dir / "eef_pos.npy").exists():
                demo_idx += 1
                continue

            dx = rng.choice(dx_vals)
            dy = rng.choice(dy_vals)
            total_dx = center_dx + dx
            total_dy = center_dy + dy

            # Reset episode bookkeeping manually; a full env.reset() would
            # presumably rebuild the scene and discard the camera override.
            env.env.timestep = 0
            env.env.done = False

            state_0 = shift_state(states[0], total_dx, total_dy)
            env.set_init_state(state_0)
            sim.forward()
            # Let physics settle for a few null-action steps.
            for _ in range(5):
                env.step(np.zeros(7, dtype=np.float32))
                freeze_distractors(sim)

            # Find grasp point (depends only on `actions`; could be hoisted out of the loop)
            t_grasp = find_grasp_timestep(actions)
            t_pregrasp = max(0, t_grasp - 6)

            gripper_cmd = -1.0
            rec_frames, rec_eef, rec_quat, rec_grip = [], [], [], []

            def record(obs):
                # Late-binding closure: reads the CURRENT value of gripper_cmd.
                rec_eef.append(np.array(obs["robot0_eef_pos"], dtype=np.float32))
                rec_quat.append(np.array(obs["robot0_eef_quat"], dtype=np.float32))
                rec_grip.append(gripper_cmd)
                img = np.flipud(obs[f"{AGENT_CAM}_image"]).copy()  # offscreen render arrives vertically flipped
                rec_frames.append(img)

            obs = env.env._get_observations()
            home_pos = np.array(obs["robot0_eef_pos"], dtype=np.float64)
            record(obs)

            # Phase 1: Interpolate home → shifted pre-grasp
            pregrasp_target = eef_orig[t_pregrasp].copy()
            pregrasp_target[0] += total_dx
            pregrasp_target[1] += total_dy
            interp_steps = 8
            alphas = np.linspace(0, 1, interp_steps + 1)[1:]  # drop alpha=0 (the current pose)
            for a in alphas:
                wp = home_pos + a * (pregrasp_target - home_pos)
                obs = servo_to(env, wp, -1.0)
                record(obs)

            # Phase 2: Execute from pre-grasp onward, replaying shifted EEF waypoints
            frame_indices = list(range(t_pregrasp, len(eef_orig), args.frame_stride))
            for t in frame_indices:
                target = eef_orig[t].copy()
                target[0] += center_dx + dx
                target[1] += center_dy + dy
                if t < len(actions):
                    gripper_cmd = float(np.clip(actions[t, 6], -1, 1))
                # While gripper is closed, bias targets down by z_offset
                # (presumably to maintain a firm grasp — TODO confirm intent).
                if gripper_cmd > 0 and args.z_offset != 0:
                    target[2] += args.z_offset
                obs = servo_to(env, target, gripper_cmd)
                record(obs)

            T = len(rec_eef)
            eef_arr = np.stack(rec_eef)

            # Camera params for the current viewpoint.
            agent_ext = get_camera_extrinsic_matrix(sim, AGENT_CAM).astype(np.float32)
            agent_w2c = get_camera_transform_matrix(sim, AGENT_CAM, H, W).astype(np.float32)
            agent_K = get_camera_intrinsic_matrix(sim, AGENT_CAM, H, W).astype(np.float32)
            K_norm = agent_K.copy()
            K_norm[0] /= W; K_norm[1] /= H  # intrinsics normalized to [0, 1] image coordinates

            # Project EEF positions into pixel space. The index swap below
            # suggests robosuite returns (row, col), stored here as (u, v) —
            # verify against robosuite's camera_utils docs.
            pix_uvs = []
            for i in range(T):
                rc = project_points_from_world_to_camera(eef_arr[i:i+1].astype(np.float64), agent_w2c, H, W)[0]
                pix_uvs.append(np.array([rc[1], rc[0]], dtype=np.float32))

            # Save
            frames_dir = demo_dir / "frames"
            frames_dir.mkdir(parents=True, exist_ok=True)
            for i, f in enumerate(rec_frames):
                cv2.imwrite(str(frames_dir / f"{i:06d}.png"), cv2.cvtColor(f, cv2.COLOR_RGB2BGR))

            np.save(demo_dir / "eef_pos.npy", eef_arr)
            np.save(demo_dir / "eef_quat.npy", np.stack(rec_quat))
            np.save(demo_dir / "gripper.npy", np.array(rec_grip[:T], dtype=np.float32))
            np.save(demo_dir / "pix_uv.npy", np.stack(pix_uvs))
            np.save(demo_dir / "cam_extrinsic.npy", agent_ext)
            np.save(demo_dir / "cam_K_norm.npy", K_norm)
            np.save(demo_dir / "world_to_cam.npy", agent_w2c)
            np.save(demo_dir / "base_z.npy", np.float32(0.912))  # NOTE(review): hard-coded base height — confirm source
            np.save(demo_dir / "actions.npy", np.zeros((T, 7), dtype=np.float32))
            np.savez(demo_dir / "meta.npz", dx=dx, dy=dy, vi=vi, oi=oi)

            demo_idx += 1

    env.close()
    print(f"\nDone. {demo_idx} episodes saved to {task_dir}")


# Script entry point.
if __name__ == "__main__":
    main()
