"""Generate OOD object position dataset for PARA training.

Creates a 16x16 grid of (dx, dy) object position shifts for task 0.
For each grid point, replays the first demo trajectory with shifted objects
using servo teleport, and saves the resulting frames + EEF data in the
same format as parsed_libero for CachedTrajectoryDataset.

Usage:
    python generate_ood_objpos.py --grid_size 16 --out_root /data/libero/ood_objpos_task0
"""
import argparse
import os
import sys
from pathlib import Path

import cv2
import h5py
import numpy as np
from tqdm import tqdm

sys.path.insert(0, "/data/cameron/LIBERO")
os.environ.setdefault("LIBERO_DATA_PATH", "/data/libero")

from libero.libero import benchmark as bm_lib, get_libero_path
from libero.libero.envs import OffScreenRenderEnv
from robosuite.utils.camera_utils import (
    get_camera_extrinsic_matrix,
    get_camera_intrinsic_matrix,
    get_camera_transform_matrix,
    project_points_from_world_to_camera,
)

# ---- State array layout ----
# The flattened MuJoCo sim state stores time at index 0, so a qpos entry in
# the state vector sits one slot after its sim.data.qpos index (see _si()).
STATE_QPOS_OFFSET = 1

# Task-0 objects keyed by body name. "qpos_start" is the object's free-joint
# offset in sim.data.qpos (presumably 7 values: xyz + quaternion — the code
# only ever touches the first 3, the position). "pick"/"place" objects are
# shifted on the grid; "distractor" objects are parked off-scene.
TASK0_OBJECTS = {
    "akita_black_bowl_1":             {"qpos_start": 9,  "role": "pick"},
    "akita_black_bowl_2":             {"qpos_start": 16, "role": "distractor"},
    "cookies_1":                      {"qpos_start": 23, "role": "distractor"},
    "glazed_rim_porcelain_ramekin_1": {"qpos_start": 30, "role": "distractor"},
    "plate_1":                        {"qpos_start": 37, "role": "place"},
}

# Fixed furniture bodies that get sunk below the floor so they never render.
FURNITURE_BODIES = ["wooden_cabinet_1_main", "flat_stove_1_main"]
# Off-scene parking position for distractor objects (far outside the camera view).
DISTRACTOR_POS = np.array([10.0, 10.0, 0.9])

# Per-distractor qpos/qvel slices used to re-pin the objects after every
# physics step (qpos slice spans 7 entries, velocity slice spans 6 DoF).
DISTRACTOR_DOFS = {
    "akita_black_bowl_2":            {"qpos": slice(16, 23), "dof": slice(15, 21)},
    "cookies_1":                     {"qpos": slice(23, 30), "dof": slice(21, 27)},
    "glazed_rim_porcelain_ramekin_1": {"qpos": slice(30, 37), "dof": slice(27, 33)},
}

# Camera names; only the agent view is rendered/recorded in this script.
AGENT_CAM = "agentview"
WRIST_CAM = "robot0_eye_in_hand"


def _si(qpos_start):
    """Map a sim.data.qpos index to its position in the flat state vector."""
    return STATE_QPOS_OFFSET + qpos_start


def move_distractors_in_state(state):
    """Return a copy of *state* with every distractor's xyz parked off-scene."""
    shifted = state.copy()
    for info in TASK0_OBJECTS.values():
        if info["role"] != "distractor":
            continue
        pos = _si(info["qpos_start"])
        shifted[pos:pos + 3] = DISTRACTOR_POS
    return shifted


def hide_distractors_visual(sim):
    """Make every geom belonging to a distractor body fully transparent."""
    hidden_bodies = {
        sim.model.body_name2id(f"{name}_main")
        for name, info in TASK0_OBJECTS.items()
        if info["role"] == "distractor"
    }
    for gid in range(sim.model.ngeom):
        if sim.model.geom_bodyid[gid] in hidden_bodies:
            sim.model.geom_rgba[gid][3] = 0.0


def freeze_distractors(sim):
    """Pin each distractor at the parking position and zero its velocities."""
    for slices in DISTRACTOR_DOFS.values():
        first = slices["qpos"].start
        sim.data.qpos[first:first + 3] = DISTRACTOR_POS
        sim.data.qvel[slices["dof"]] = 0.0


def shift_pick_place(state, dx, dy):
    """Return a copy of *state* with pick/place object xy offset by (dx, dy)."""
    shifted = state.copy()
    for info in TASK0_OBJECTS.values():
        if info["role"] not in ("pick", "place"):
            continue
        x_idx = _si(info["qpos_start"])
        shifted[x_idx] += dx
        shifted[x_idx + 1] += dy
    return shifted


def hide_furniture(sim):
    """Teleport furniture bodies well below the floor, then refresh the sim."""
    sunk = np.array([0, 0, -5.0])
    for body_name in FURNITURE_BODIES:
        body_id = sim.model.body_name2id(body_name)
        sim.model.body_pos[body_id] = sunk
    sim.forward()


def extract_demo_eef_positions(env, states):
    """Replay each demo sim state and record the resulting EEF position.

    Args:
        env: LIBERO ``OffScreenRenderEnv`` wrapper (``env.env`` is the
            underlying robosuite env).
        states: (T, D) array of MuJoCo sim states from the demo HDF5 file.

    Returns:
        (T, 3) float64 array of end-effector world positions, one per state.
    """
    eef_positions = []
    # Iterate states directly instead of indexing via range(len(...)).
    for state in states:
        env.set_init_state(state)
        env.env.sim.forward()
        obs = env.env._get_observations()
        eef_positions.append(np.array(obs["robot0_eef_pos"], dtype=np.float64))
    return np.array(eef_positions)


def servo_to_position(env, target_pos, gripper_cmd, max_servo=50, threshold=0.003):
    """Closed-loop servo: step the controller until the EEF is within
    *threshold* of *target_pos* or *max_servo* steps elapse.

    Distractors are re-frozen after every physics step. Returns the most
    recent observation dict.
    """
    sim = env.env.sim
    obs = None
    for _ in range(max_servo):
        current = env.env._get_observations()
        error = target_pos - np.array(current["robot0_eef_pos"], dtype=np.float64)
        if np.linalg.norm(error) < threshold:
            obs = current
            break
        action = np.zeros(7, dtype=np.float32)
        # 0.05 m of positional error maps to a full-scale [-1, 1] command.
        action[:3] = np.clip(error / 0.05, -1.0, 1.0)
        action[6] = gripper_cmd
        obs, _, done, _ = env.step(action)
        freeze_distractors(sim)
        if done:
            break
    # Loop may exit without assigning obs (e.g. max_servo == 0); fetch one.
    if obs is None:
        obs = env.env._get_observations()
    return obs


def find_grasp_timestep(actions):
    """Return the index of the first open-to-close gripper transition.

    The gripper command is column 6 of *actions*; a transition is a step
    where the command goes from <= 0 to > 0. Falls back to the trajectory
    midpoint when no such transition exists.
    """
    grip = actions[:, 6]
    closing = (grip[1:] > 0) & (grip[:-1] <= 0)
    hits = np.flatnonzero(closing)
    if hits.size:
        return int(hits[0]) + 1
    return len(grip) // 2


def interpolate_waypoints(start_pos, end_pos, n_steps):
    """Return n_steps linearly spaced points from start (excluded) to end (included)."""
    fractions = np.linspace(0, 1, n_steps + 1)
    delta = end_pos - start_pos
    return [start_pos + frac * delta for frac in fractions[1:]]


def generate_trajectory(env, states, actions, eef_orig, dx, dy, center_dx, center_dy,
                        frame_stride=3, z_offset=-0.015, max_servo=25, image_size=448,
                        pregrasp_lead=6, interp_steps=8):
    """Generate a shifted trajectory with natural start (interpolate to pre-grasp).

    The robot starts at home; phase 1 servos along interpolated waypoints to a
    pre-grasp pose, phase 2 replays the (shifted) demo EEF trajectory from the
    pre-grasp timestep onward, subsampled by ``frame_stride``.

    Args:
        env: LIBERO OffScreenRenderEnv, already created/reset by the caller.
        states: (T, D) demo sim states; only states[0] is used (initial layout).
        actions: (T, 7) demo actions; column 6 drives the gripper command.
        eef_orig: (T, 3) EEF positions extracted from the original demo.
        dx, dy: grid shift applied on top of the centering offset.
        center_dx, center_dy: offset that re-centers the pick object.
        frame_stride: demo-timestep subsampling stride for phase 2.
        z_offset: extra z applied to targets while the gripper is closed.
        max_servo: servo iterations allowed per waypoint.
        image_size: camera height/width in pixels.
        pregrasp_lead: timesteps before the grasp used as the approach target.
        interp_steps: number of interpolated waypoints in phase 1.

    Returns:
        dict with recorded frames, EEF poses, gripper commands, pixel
        projections, static camera parameters, a success flag, and (dx, dy).
    """
    sim = env.env.sim
    H = W = image_size
    total_dx = center_dx + dx
    total_dy = center_dy + dy

    # Reset env
    env.env.timestep = 0
    env.env.done = False

    # Prepare shifted initial state (robot joints stay at home, objects shifted)
    state_0 = states[0].copy()
    state_0 = move_distractors_in_state(state_0)
    state_0 = shift_pick_place(state_0, total_dx, total_dy)

    env.set_init_state(state_0)
    sim.forward()

    # Settle: a few zero-action steps let objects come to rest after teleport.
    for _ in range(5):
        env.step(np.zeros(7, dtype=np.float32))
        freeze_distractors(sim)

    # Find grasp point and pre-grasp
    t_grasp = find_grasp_timestep(actions)
    t_pregrasp = max(0, t_grasp - pregrasp_lead)

    # Get home EEF position
    obs = env.env._get_observations()
    home_pos = np.array(obs["robot0_eef_pos"], dtype=np.float64)

    # Shifted pre-grasp target
    pregrasp_target = eef_orig[t_pregrasp].copy()
    pregrasp_target[0] += total_dx
    pregrasp_target[1] += total_dy

    gripper_cmd = -1.0

    recorded_frames_agent = []
    recorded_eef_pos = []
    recorded_eef_quat = []
    recorded_gripper = []

    def record(obs):
        # Snapshot EEF pose, current gripper command, and the agent camera
        # frame. NOTE: gripper_cmd is read from the enclosing scope at call
        # time, so phase-2 reassignments are picked up automatically.
        eef_pos = np.array(obs["robot0_eef_pos"], dtype=np.float32)
        eef_quat = np.array(obs["robot0_eef_quat"], dtype=np.float32)
        recorded_eef_pos.append(eef_pos)
        recorded_eef_quat.append(eef_quat)
        recorded_gripper.append(gripper_cmd)

        # Rendered images come out vertically flipped; flipud restores them.
        agent_img = np.flipud(obs[f"{AGENT_CAM}_image"]).copy()
        recorded_frames_agent.append(agent_img)

    record(obs)

    # Phase 1: Interpolate home → pre-grasp (approach)
    interp_wps = interpolate_waypoints(home_pos, pregrasp_target, interp_steps)
    for wp in interp_wps:
        obs = servo_to_position(env, wp, -1.0, max_servo=max_servo)
        record(obs)

    # Phase 2: Execute shifted trajectory from pre-grasp onward
    phase2_indices = list(range(t_pregrasp, len(eef_orig), frame_stride))
    success = False
    for t in phase2_indices:
        target_pos = eef_orig[t].copy()
        target_pos[0] += total_dx
        target_pos[1] += total_dy

        if t < len(actions):
            gripper_cmd = float(np.clip(actions[t, 6], -1.0, 1.0))

        # While holding the object, bias targets down by z_offset to
        # compensate for the grasped object's height in the demo trajectory.
        if gripper_cmd > 0 and z_offset != 0:
            target_pos[2] += z_offset

        obs = servo_to_position(env, target_pos, gripper_cmd, max_servo=max_servo)
        record(obs)

        if env.env.done or (hasattr(env.env, '_check_success') and env.env._check_success()):
            success = True

    # Get static camera params
    agent_ext = get_camera_extrinsic_matrix(sim, AGENT_CAM).astype(np.float32)
    agent_w2c = get_camera_transform_matrix(sim, AGENT_CAM, H, W).astype(np.float32)
    agent_K = get_camera_intrinsic_matrix(sim, AGENT_CAM, H, W).astype(np.float32)
    agent_K_norm = agent_K.copy()
    # Normalize intrinsics so they are resolution-independent (fx,cx by W; fy,cy by H).
    agent_K_norm[0] /= W
    agent_K_norm[1] /= H

    # Compute pixel projections of the recorded EEF positions.
    eef_arr = np.stack(recorded_eef_pos)
    pix_uvs = []
    for i in range(len(eef_arr)):
        # project_points_from_world_to_camera returns (row, col); swap to (u, v).
        pix_rc = project_points_from_world_to_camera(
            eef_arr[i:i+1].astype(np.float64), agent_w2c, H, W)[0]
        pix_uvs.append(np.array([pix_rc[1], pix_rc[0]], dtype=np.float32))

    return {
        "frames_agent": recorded_frames_agent,
        "eef_pos": np.stack(recorded_eef_pos),
        "eef_quat": np.stack(recorded_eef_quat),
        "gripper": np.array(recorded_gripper, dtype=np.float32),
        "pix_uv": np.stack(pix_uvs),
        "cam_extrinsic": agent_ext,
        "cam_K_norm": agent_K_norm,
        "world_to_cam": agent_w2c,
        # NOTE(review): 0.912 looks like the robot base height in world z — confirm.
        "base_z": np.float32(0.912),
        "success": success,
        "dx": dx, "dy": dy,
    }


def save_demo(data, demo_dir):
    """Write one trajectory to *demo_dir* in the parsed_libero layout."""
    demo_dir = Path(demo_dir)
    frames_dir = demo_dir / "frames"
    frames_dir.mkdir(parents=True, exist_ok=True)

    # Frames are stored RGB in memory; OpenCV expects BGR on disk.
    for idx, rgb in enumerate(data["frames_agent"]):
        bgr = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
        cv2.imwrite(str(frames_dir / f"{idx:06d}.png"), bgr)

    for key in ("eef_pos", "eef_quat", "gripper", "pix_uv",
                "cam_extrinsic", "cam_K_norm", "world_to_cam", "base_z"):
        np.save(demo_dir / f"{key}.npy", data[key])

    # Placeholder zero actions so the downstream loader finds the file.
    placeholder = np.zeros((len(data["eef_pos"]), 7), dtype=np.float32)
    np.save(demo_dir / "actions.npy", placeholder)


def main():
    """Generate a grid of object-position-shifted trajectories for one task.

    Replays one source demo per (dx, dy) grid point with the pick/place
    objects shifted, saving each result as a parsed_libero-style demo dir.
    Already-saved demos are skipped, so the script is resumable.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--grid_size", type=int, default=16)
    parser.add_argument("--dx_min", type=float, default=-0.40)
    parser.add_argument("--dx_max", type=float, default=-0.01)
    parser.add_argument("--dy_min", type=float, default=-0.30)
    parser.add_argument("--dy_max", type=float, default=0.30)
    parser.add_argument("--image_size", type=int, default=448)
    parser.add_argument("--frame_stride", type=int, default=3)
    parser.add_argument("--z_offset", type=float, default=-0.015)
    parser.add_argument("--max_servo", type=int, default=25)
    parser.add_argument("--benchmark", type=str, default="libero_spatial")
    parser.add_argument("--task_id", type=int, default=0)
    parser.add_argument("--demo_id", type=int, default=0)
    parser.add_argument("--out_root", type=str, default="/data/libero/ood_objpos_task0")
    args = parser.parse_args()

    # Resolve benchmark task, demo HDF5 path, and BDDL scene description.
    bench = bm_lib.get_benchmark_dict()[args.benchmark]()
    task = bench.get_task(args.task_id)
    demo_path = os.path.join(get_libero_path("datasets"), bench.get_task_demonstration(args.task_id))
    bddl_file = os.path.join(get_libero_path("bddl_files"), task.problem_folder, task.bddl_file)

    # Load the source demo's states and actions (clamp demo_id to valid range).
    with h5py.File(demo_path, "r") as f:
        demo_keys = sorted([k for k in f["data"].keys() if k.startswith("demo_")])
        demo_key = demo_keys[min(args.demo_id, len(demo_keys) - 1)]
        states = f[f"data/{demo_key}/states"][()]
        actions = f[f"data/{demo_key}/actions"][()]

    print(f"Task: {task.name}")
    print(f"Demo: {demo_key}, {len(states)} frames")

    # Extract original EEF trajectory (throwaway env; replay states only).
    print("Extracting original EEF trajectory...")
    env_tmp = OffScreenRenderEnv(
        bddl_file_name=bddl_file, camera_heights=args.image_size, camera_widths=args.image_size,
        camera_names=[AGENT_CAM])
    env_tmp.seed(0); env_tmp.reset()
    eef_orig = extract_demo_eef_positions(env_tmp, states)
    env_tmp.close()

    # Compute centering offset so the grid is defined relative to the pick
    # object's origin rather than its demo-specific position.
    bowl_i = _si(TASK0_OBJECTS["akita_black_bowl_1"]["qpos_start"])
    center_dx = -states[0][bowl_i]
    center_dy = -states[0][bowl_i + 1]
    print(f"Centering offset: ({center_dx:+.3f}, {center_dy:+.3f})")

    # Generate grid
    N = args.grid_size
    dx_vals = np.linspace(args.dx_min, args.dx_max, N)
    dy_vals = np.linspace(args.dy_min, args.dy_max, N)
    print(f"\nGrid: {N}x{N} = {N*N} trajectories")
    print(f"  dx: [{args.dx_min}, {args.dx_max}]")
    print(f"  dy: [{args.dy_min}, {args.dy_max}]")

    # Create env for replay (reused across all grid points).
    env = OffScreenRenderEnv(
        bddl_file_name=bddl_file, camera_heights=args.image_size, camera_widths=args.image_size,
        camera_names=[AGENT_CAM])
    env.seed(0); env.reset()
    # Effectively disable the horizon so long servo runs never truncate.
    env.env.horizon = 100000
    sim = env.env.sim
    hide_furniture(sim)
    hide_distractors_visual(sim)

    task_dir = Path(args.out_root) / args.benchmark / f"task_{args.task_id}"
    successes = 0
    total = 0

    # Save grid metadata
    task_dir.mkdir(parents=True, exist_ok=True)
    np.savez(task_dir / "grid_meta.npz", dx_vals=dx_vals, dy_vals=dy_vals,
             center_dx=center_dx, center_dy=center_dy)

    for i, dx in enumerate(tqdm(dx_vals, desc="dx")):
        for j, dy in enumerate(dy_vals):
            demo_idx = i * N + j
            demo_dir = task_dir / f"demo_{demo_idx}"

            # Resume support: skip grid points already saved to disk.
            if (demo_dir / "eef_pos.npy").exists():
                total += 1
                continue

            # Reset robot state without full env.reset() (avoids sim recreation)
            env.env.timestep = 0
            env.env.done = False

            data = generate_trajectory(
                env, states, actions, eef_orig,
                dx, dy, center_dx, center_dy,
                frame_stride=args.frame_stride,
                z_offset=args.z_offset,
                max_servo=args.max_servo,
                image_size=args.image_size,
            )

            save_demo(data, demo_dir)
            total += 1
            if data["success"]:
                successes += 1

            # Periodic progress line with running success rate.
            if total % 10 == 0:
                print(f"  [{total}/{N*N}] dx={dx:.3f} dy={dy:.3f} success={data['success']} "
                      f"(total: {successes}/{total})")

    env.close()
    print(f"\nDone. {successes}/{total} trajectories succeeded.")
    print(f"Saved to {task_dir}")


if __name__ == "__main__":
    main()
