"""Generate OOD translation dataset (camera translation variation, fixed orientation).

Creates a grid of camera translations in camera-local coordinates (right, up).
5 positions: center (0,0) + 4 corners at (±dx, ±dy).
Multiple demos per position; in each, the pick/place objects are translated to a random position sampled uniformly from the --dx_min/--dx_max and --dy_min/--dy_max ranges.
Clean scene (no distractors/furniture).

Usage:
    python generate_ood_translation.py --demos_per_view 10
"""
import argparse
import os
import sys
from pathlib import Path

import cv2
import h5py
import numpy as np
from tqdm import tqdm

sys.path.insert(0, "/data/cameron/LIBERO")
os.environ.setdefault("LIBERO_DATA_PATH", "/data/libero")

from libero.libero import benchmark as bm_lib, get_libero_path
from libero.libero.envs import OffScreenRenderEnv
from robosuite.utils.camera_utils import (
    get_camera_extrinsic_matrix,
    get_camera_intrinsic_matrix,
    get_camera_transform_matrix,
    project_points_from_world_to_camera,
)

# Name of the MuJoCo camera that is translated and rendered.
AGENT_CAM = "agentview"
# Presumably the table-top height in world z; not referenced elsewhere in this script.
TABLE_Z = 0.90
# Offset from a qpos index to the matching entry of the flattened sim state
# stored in the demo HDF5 (shift_state and main index states with qpos + this).
# Presumably the flattened state is laid out as [time, qpos..., qvel...] — TODO confirm.
STATE_QPOS_OFFSET = 1
# qpos index where the pick object's position starts (main calls it the bowl);
# entries +0/+1 are shifted as x/y by shift_state.
PICK_QPOS = 9
# qpos index where the place-target object's position starts; shifted by the
# same (dx, dy) as the pick object in shift_state.
PLACE_QPOS = 37
# qpos start indices of the three distractor objects; shift_state overwrites
# the 3 entries starting at each index with DISTRACTOR_POS.
DISTRACTOR_QPOS_STARTS = [16, 23, 30]
# Furniture bodies whose model body_pos setup_clean_scene sets to (0, 0, -5).
FURNITURE_BODIES = ["wooden_cabinet_1_main", "flat_stove_1_main"]
# Distractor bodies whose geoms setup_clean_scene makes fully transparent (alpha = 0).
DISTRACTOR_BODIES = ["akita_black_bowl_2_main", "cookies_1_main", "glazed_rim_porcelain_ramekin_1_main"]
# Far-away position at which distractors are parked (written into their qpos position entries).
DISTRACTOR_POS = np.array([10.0, 10.0, 0.9])
# qpos start -> qvel slice for each distractor (keys are the same indices as
# DISTRACTOR_QPOS_STARTS). freeze_distractors re-pins the position and zeroes
# these velocities. The 7-qpos / 6-dof spacing suggests free joints — TODO
# confirm against the scene MJCF.
DISTRACTOR_DOFS = {16: slice(15, 21), 23: slice(21, 27), 30: slice(27, 33)}


def generate_translation_grid(default_pos, cam_xmat, dx_mag, dy_mag):
    """Return the five camera positions used for the translation sweep.

    Index 0 is ``default_pos`` itself. Indices 1-4 are the corners
    (-dx, -dy), (-dx, +dy), (+dx, -dy), (+dx, +dy), with the offsets measured
    along the camera's own axes:

      * right = column 0 of ``cam_xmat`` (camera X axis)
      * up    = column 1 of ``cam_xmat`` (camera Y axis)

    Returns:
        (positions, dx_vals, dy_vals) with shapes (5, 3), (5,), (5,).
    """
    axis_right = cam_xmat[:, 0]
    axis_up = cam_xmat[:, 1]

    dx_vals = np.array([0.0, -dx_mag, -dx_mag, dx_mag, dx_mag])
    dy_vals = np.array([0.0, -dy_mag, dy_mag, -dy_mag, dy_mag])

    corners = [
        default_pos + off_r * axis_right + off_u * axis_up
        for off_r, off_u in zip(dx_vals, dy_vals)
    ]
    return np.array(corners), dx_vals, dy_vals


def find_grasp_timestep(actions):
    """Return the first timestep at which the gripper command switches to closed.

    The gripper command is column 6 of ``actions``. The grasp is the first
    index ``t >= 1`` whose command is positive while the previous command was
    non-positive. If no such transition exists, the midpoint
    ``len(actions) // 2`` is returned instead.
    """
    grip = actions[:, 6]
    closing_now = grip[1:] > 0
    was_open = grip[:-1] <= 0
    transitions = np.flatnonzero(closing_now & was_open)
    if transitions.size:
        # +1 because the comparison arrays are offset by one step.
        return int(transitions[0]) + 1
    return len(grip) // 2


def setup_clean_scene(sim):
    """Strip furniture and hide distractor objects so the rendered scene is uncluttered.

    * For every body in ``FURNITURE_BODIES``, its model ``body_pos`` is set to
      ``(0, 0, -5)``; ``sim.forward()`` is called afterwards.
    * Every geom whose owning body (``geom_bodyid``) is in ``DISTRACTOR_BODIES``
      gets its alpha channel (``geom_rgba[gid][3]``) set to 0.
      NOTE(review): only geoms attached directly to the listed bodies are
      hidden; geoms on child bodies are not — confirm this covers the meshes.

    Bodies that cannot be resolved are skipped (best-effort, since a scene may
    not contain every listed body), but a warning is printed so that a renamed
    or misspelled body does not silently leave clutter in every generated frame.
    """
    model = sim.model
    for fname in FURNITURE_BODIES:
        try:
            model.body_pos[model.body_name2id(fname)] = np.array([0, 0, -5.0])
        except Exception as exc:  # best-effort: report and continue
            print(f"[setup_clean_scene] WARNING: could not move furniture body {fname!r}: {exc}")
    sim.forward()

    dist_bodies = set()
    for dn in DISTRACTOR_BODIES:
        try:
            dist_bodies.add(model.body_name2id(dn))
        except Exception as exc:  # best-effort: report and continue
            print(f"[setup_clean_scene] WARNING: could not find distractor body {dn!r}: {exc}")
    for gid in range(model.ngeom):
        if model.geom_bodyid[gid] in dist_bodies:
            model.geom_rgba[gid][3] = 0.0


def freeze_distractors(sim):
    """Pin each distractor at ``DISTRACTOR_POS`` and zero its velocities.

    For every ``qpos_start -> dof_slice`` pair in ``DISTRACTOR_DOFS`` this
    overwrites ``qpos[qpos_start:qpos_start + 3]`` with ``DISTRACTOR_POS`` and
    zeroes ``qvel[dof_slice]``. The remaining qpos entries of each object are
    left untouched. This script calls it after its ``env.step`` calls so the
    hidden distractors do not drift.
    """
    data = sim.data
    for qpos_start, dof_slice in DISTRACTOR_DOFS.items():
        data.qpos[qpos_start:qpos_start + 3] = DISTRACTOR_POS
        data.qvel[dof_slice] = 0.0

def shift_state(state, dx, dy):
    """Return a copy of a flattened sim state with the task objects translated.

    The x/y entries of the pick and place objects are shifted by (dx, dy).
    The first three entries of each distractor are overwritten with
    ``DISTRACTOR_POS``. The input ``state`` is not modified.
    """
    shifted = state.copy()

    for obj_qpos in (PICK_QPOS, PLACE_QPOS):
        x_idx = obj_qpos + STATE_QPOS_OFFSET
        shifted[x_idx] += dx
        shifted[x_idx + 1] += dy

    for start in DISTRACTOR_QPOS_STARTS:
        lo = start + STATE_QPOS_OFFSET
        shifted[lo:lo + 3] = DISTRACTOR_POS

    return shifted


def servo_to(env, target, gripper_cmd, max_servo=25, threshold=0.003):
    """Drive the end effector toward ``target`` with proportional delta actions.

    Each iteration does the following:
      * Reads ``robot0_eef_pos``.
      * Stops if the Euclidean error to ``target`` is below ``threshold``.
      * Otherwise steps the env with an action whose first three entries are
        ``error / 0.05`` clipped to [-1, 1] and whose slot 6 is ``gripper_cmd``.
      * Re-freezes the distractors.
    The loop also stops if the env reports done, and runs at most
    ``max_servo`` iterations.

    Returns the latest observation dict. If that observation is falsy (e.g.
    ``max_servo <= 0``), a fresh ``_get_observations()`` result is returned instead.
    """
    inner_env = env.env
    sim = inner_env.sim
    last_obs = None
    for _attempt in range(max_servo):
        last_obs = inner_env._get_observations()
        eef_now = np.array(last_obs["robot0_eef_pos"], dtype=np.float64)
        error = target - eef_now
        if np.linalg.norm(error) < threshold:
            break

        command = np.zeros(7, dtype=np.float32)
        command[:3] = np.clip(error / 0.05, -1, 1)
        command[6] = gripper_cmd
        last_obs, _reward, finished, _info = env.step(command)
        freeze_distractors(sim)
        if finished:
            break
    return last_obs if last_obs else inner_env._get_observations()


def interpolate_waypoints(start, end, n):
    """Return ``n`` evenly spaced points on the segment start -> end.

    ``start`` is excluded and ``end`` is included. ``n == 0`` yields an empty list.
    """
    fractions = np.linspace(0, 1, n + 1)
    waypoints = [start + frac * (end - start) for frac in fractions[1:]]
    return waypoints


def _extract_eef_trajectory(bddl_file, states, image_size):
    """Replay ``states`` in a throwaway env and return the EEF position after each one.

    Returns a float64 array with one ``robot0_eef_pos`` row per state. The
    temporary env is always closed, even if the replay raises.
    """
    env_tmp = OffScreenRenderEnv(bddl_file_name=bddl_file, camera_heights=image_size,
                                 camera_widths=image_size, camera_names=[AGENT_CAM])
    try:
        env_tmp.seed(0)
        env_tmp.reset()
        eef = []
        for state in states:
            env_tmp.set_init_state(state)
            env_tmp.env.sim.forward()
            eef.append(np.array(env_tmp.env._get_observations()["robot0_eef_pos"], dtype=np.float64))
        return np.array(eef)
    finally:
        env_tmp.close()


def _record_episode(env, state_0, eef_orig, actions, t_pregrasp, total_dx, total_dy,
                    frame_stride, z_offset):
    """Reset to ``state_0`` and servo along the (dx, dy)-shifted reference EEF trajectory.

    Phase 1 servos through 8 interpolated waypoints from the home EEF position
    to the shifted pre-grasp position with the gripper open (-1). Phase 2
    follows every ``frame_stride``-th reference position from ``t_pregrasp``
    on. It uses the demo's gripper command (column 6 of ``actions``) and adds
    ``z_offset`` to the target height while that command is positive. One
    observation is recorded initially and after every servo.

    Returns (frames, eef_pos, eef_quat, gripper_cmds, success).
    """
    sim = env.env.sim
    env.env.timestep = 0
    env.env.done = False
    env.set_init_state(state_0)
    sim.forward()
    # A few zero-action steps to let the scene settle before recording.
    for _ in range(5):
        env.step(np.zeros(7, dtype=np.float32))
        freeze_distractors(sim)

    def shifted_target(t):
        target = eef_orig[t].copy()
        target[0] += total_dx
        target[1] += total_dy
        return target

    obs = env.env._get_observations()
    home_pos = np.array(obs["robot0_eef_pos"], dtype=np.float64)

    gripper_cmd = -1.0
    frames, eef, quat, grip = [], [], [], []

    def record(o):
        # Reads the *current* gripper_cmd (closure over the enclosing local).
        eef.append(np.array(o["robot0_eef_pos"], dtype=np.float32))
        quat.append(np.array(o["robot0_eef_quat"], dtype=np.float32))
        grip.append(gripper_cmd)
        # Rendered images are flipped vertically before saving.
        frames.append(np.flipud(o[f"{AGENT_CAM}_image"]).copy())

    record(obs)

    for wp in interpolate_waypoints(home_pos, shifted_target(t_pregrasp), 8):
        obs = servo_to(env, wp, -1.0, max_servo=25)
        record(obs)

    success = False
    for t in range(t_pregrasp, len(eef_orig), frame_stride):
        target = shifted_target(t)
        if t < len(actions):
            gripper_cmd = float(np.clip(actions[t, 6], -1, 1))
        if gripper_cmd > 0 and z_offset != 0:
            target[2] += z_offset
        obs = servo_to(env, target, gripper_cmd, max_servo=25)
        record(obs)
        if env.env.done or (hasattr(env.env, "_check_success") and env.env._check_success()):
            success = True
    return frames, eef, quat, grip, success


def _project_eef_to_pixels(eef_arr, w2c, height, width):
    """Project world-frame EEF positions to float32 (u, v) = (col, row) pixel coordinates.

    ``project_points_from_world_to_camera`` returns (row, col) per point, so
    the columns are swapped here. All points are projected in one batched call.
    """
    pix_rc = project_points_from_world_to_camera(eef_arr.astype(np.float64), w2c, height, width)
    return pix_rc[:, [1, 0]].astype(np.float32)


def _save_demo(demo_dir, sim, height, width, frames, eef, quat, grip, meta):
    """Write one recorded episode to ``demo_dir``.

    Frames go to ``frames/NNNNNN.png``; trajectories and the current
    ``AGENT_CAM`` parameters go to ``.npy`` files. ``meta.npz`` is written
    LAST, so its presence marks a complete demo (main() uses it as the resume
    marker).

    Raises:
        OSError: if a frame cannot be written by ``cv2.imwrite``.
    """
    demo_dir.mkdir(parents=True, exist_ok=True)
    frames_dir = demo_dir / "frames"
    frames_dir.mkdir(exist_ok=True)
    for fi, frame in enumerate(frames):
        frame_path = frames_dir / f"{fi:06d}.png"
        # cv2.imwrite reports failure via its return value, not an exception.
        if not cv2.imwrite(str(frame_path), cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)):
            raise OSError(f"cv2.imwrite failed for {frame_path}")

    eef_arr = np.stack(eef)
    w2c = get_camera_transform_matrix(sim, AGENT_CAM, height, width)
    K = get_camera_intrinsic_matrix(sim, AGENT_CAM, height, width)
    K_norm = K.copy()
    K_norm[0] /= width
    K_norm[1] /= height
    ext = get_camera_extrinsic_matrix(sim, AGENT_CAM)

    np.save(demo_dir / "eef_pos.npy", eef_arr)
    np.save(demo_dir / "eef_quat.npy", np.stack(quat))
    np.save(demo_dir / "gripper.npy", np.array(grip, dtype=np.float32))
    np.save(demo_dir / "pix_uv.npy", _project_eef_to_pixels(eef_arr, w2c, height, width))
    np.save(demo_dir / "cam_extrinsic.npy", ext.astype(np.float32))
    np.save(demo_dir / "cam_K_norm.npy", K_norm.astype(np.float32))
    np.save(demo_dir / "world_to_cam.npy", w2c.astype(np.float32))
    np.save(demo_dir / "base_z.npy", np.float32(0.912))
    # Placeholder: executed actions are not recorded, only zeros are saved.
    np.save(demo_dir / "actions.npy", np.zeros((len(eef_arr), 7), dtype=np.float32))
    np.savez(demo_dir / "meta.npz", **meta)


def main():
    """Generate the OOD camera-translation dataset for libero_spatial task 0.

    Steps:
      1. Replay the first demo of task 0 to obtain the reference EEF trajectory.
      2. Build a clean scene (furniture moved away, distractors hidden and parked).
      3. For each of the 5 translated agentview positions, record
         ``--demos_per_view`` episodes. In each, the pick and place objects are
         translated together so that the pick object's x/y equals a random
         offset drawn from the ``--dx_*`` / ``--dy_*`` ranges, and the robot
         servos along the equally shifted reference trajectory.
      4. Save frames, EEF data, projected pixel coordinates and camera
         parameters under ``<out_root>/libero_spatial/task_0/demo_<i>``.

    Re-running resumes: demos whose ``meta.npz`` exists are skipped. Object
    offsets are drawn for every demo index (skipped or not), so a resumed run
    produces the same offsets as an uninterrupted run with the same seed.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--dx_cam", type=float, default=0.10,
                        help="Camera translation magnitude along camera-local right axis")
    parser.add_argument("--dy_cam", type=float, default=0.075,
                        help="Camera translation magnitude along camera-local up axis")
    parser.add_argument("--demos_per_view", type=int, default=10, help="Demos per camera position")
    parser.add_argument("--dx_min", type=float, default=-0.40, help="Min object dx offset")
    parser.add_argument("--dx_max", type=float, default=-0.01, help="Max object dx offset")
    parser.add_argument("--dy_min", type=float, default=-0.30, help="Min object dy offset")
    parser.add_argument("--dy_max", type=float, default=0.30, help="Max object dy offset")
    parser.add_argument("--image_size", type=int, default=448)
    parser.add_argument("--frame_stride", type=int, default=3)
    parser.add_argument("--z_offset", type=float, default=-0.015)
    parser.add_argument("--out_root", type=str, default="/data/libero/ood_translation_v1")
    args = parser.parse_args()

    bench = bm_lib.get_benchmark_dict()["libero_spatial"]()
    task = bench.get_task(0)
    demo_path = os.path.join(get_libero_path("datasets"), bench.get_task_demonstration(0))
    bddl_file = os.path.join(get_libero_path("bddl_files"), task.problem_folder, task.bddl_file)

    with h5py.File(demo_path, "r") as f:
        dk = sorted(k for k in f["data"].keys() if k.startswith("demo_"))[0]
        states = f[f"data/{dk}/states"][()]
        actions = f[f"data/{dk}/actions"][()]

    print("Extracting EEF trajectory...")
    eef_orig = _extract_eef_trajectory(bddl_file, states, args.image_size)

    # Shift that moves the pick object (bowl) to world x = y = 0; the random
    # per-demo offsets are added on top of it.
    bowl_si = PICK_QPOS + STATE_QPOS_OFFSET
    center_dx = -states[0][bowl_si]
    center_dy = -states[0][bowl_si + 1]
    print(f"Center offset: ({center_dx:+.3f}, {center_dy:+.3f})")

    env = OffScreenRenderEnv(bddl_file_name=bddl_file, camera_heights=args.image_size,
                             camera_widths=args.image_size, camera_names=[AGENT_CAM])
    try:
        env.seed(0)
        env.reset()
        env.env.horizon = 100000
        sim = env.env.sim
        setup_clean_scene(sim)
        H = W = args.image_size

        cam_id = sim.model.camera_name2id(AGENT_CAM)
        default_pos = sim.data.cam_xpos[cam_id].copy()
        default_quat = sim.model.cam_quat[cam_id].copy()
        # Copy: indexing sim.data may return a live view into simulator memory.
        cam_xmat = sim.data.cam_xmat[cam_id].reshape(3, 3).copy()
        print(f"Default cam pos: {default_pos}")
        print(f"Default cam quat: {default_quat}")
        print(f"Camera right axis: {cam_xmat[:, 0]}")
        print(f"Camera up axis: {cam_xmat[:, 1]}")

        # 5 translated positions: center + 4 corners.
        tp_positions, dx_vals, dy_vals = generate_translation_grid(
            default_pos, cam_xmat, args.dx_cam, args.dy_cam)

        n_positions = len(tp_positions)
        n_total = n_positions * args.demos_per_view
        print(f"\n{n_positions} camera positions x {args.demos_per_view} demos = {n_total} episodes")
        print(f"Camera dx magnitudes: {dx_vals}")
        print(f"Camera dy magnitudes: {dy_vals}")
        print(f"Object position range: dx=[{args.dx_min}, {args.dx_max}], dy=[{args.dy_min}, {args.dy_max}]")

        task_dir = Path(args.out_root) / "libero_spatial" / "task_0"
        task_dir.mkdir(parents=True, exist_ok=True)

        np.savez(task_dir / "viewpoint_meta.npz",
                 dx_vals=dx_vals, dy_vals=dy_vals,
                 dx_cam=args.dx_cam, dy_cam=args.dy_cam,
                 demos_per_view=args.demos_per_view,
                 tp_positions=tp_positions,
                 default_quat=default_quat,
                 center_dx=center_dx, center_dy=center_dy,
                 dx_min=args.dx_min, dx_max=args.dx_max,
                 dy_min=args.dy_min, dy_max=args.dy_max)

        rng = np.random.RandomState(42)
        t_grasp = find_grasp_timestep(actions)
        t_pregrasp = max(0, t_grasp - 6)
        demo_idx = 0
        successes = 0
        n_generated = 0
        n_skipped = 0

        for vi in tqdm(range(n_positions), desc="Camera positions"):
            # Translate the camera; orientation (cam_quat) is left unchanged.
            # NOTE(review): model.cam_pos is parent-relative while tp_positions
            # are world-frame (derived from data.cam_xpos); these agree only if
            # the camera's parent is the world body — confirm for this scene.
            sim.model.cam_pos[cam_id] = tp_positions[vi]

            for di in range(args.demos_per_view):
                demo_dir = task_dir / f"demo_{demo_idx}"
                # Draw offsets BEFORE the resume check so the RNG stream (and
                # therefore every demo's object placement) does not depend on
                # which demos already exist on disk.
                dx_offset = rng.uniform(args.dx_min, args.dx_max)
                dy_offset = rng.uniform(args.dy_min, args.dy_max)
                # meta.npz is written last, so it marks a fully saved demo.
                if (demo_dir / "meta.npz").exists():
                    n_skipped += 1
                    demo_idx += 1
                    continue

                total_dx = center_dx + dx_offset
                total_dy = center_dy + dy_offset
                state_0 = shift_state(states[0], total_dx, total_dy)
                frames, eef, quat, grip, success = _record_episode(
                    env, state_0, eef_orig, actions, t_pregrasp, total_dx, total_dy,
                    args.frame_stride, args.z_offset)

                _save_demo(demo_dir, sim, H, W, frames, eef, quat, grip,
                           meta=dict(vi=vi, di=di,
                                     cam_dx=dx_vals[vi], cam_dy=dy_vals[vi],
                                     obj_dx=dx_offset, obj_dy=dy_offset))

                if success:
                    successes += 1
                n_generated += 1
                demo_idx += 1

                if demo_idx % 10 == 0:
                    print(f"  [{demo_idx}/{n_total}] vi={vi} success={success} "
                          f"(total: {successes}/{n_generated} generated)")
    finally:
        env.close()

    print(f"\nDone. {successes}/{n_generated} newly generated demos succeeded "
          f"({n_skipped} existing demos skipped).")
    print(f"Saved to {task_dir}")


if __name__ == "__main__":
    main()
