"""object_position_demos.py — 3x3 grid video of trajectory replays at different object positions.

For each cell in a 3x3 grid of (dx, dy) offsets, replays the demo trajectory with
pick/place objects shifted to that position. Outputs a single tiled video.

Usage:
    python ood_libero/object_position_demos.py [--shift_range 0.1] [--frame_stride 3]
"""

import argparse
import os
import sys
from pathlib import Path

import cv2
import h5py
import numpy as np
from tqdm import tqdm

sys.path.insert(0, "/data/cameron/LIBERO")
os.environ.setdefault("LIBERO_DATA_PATH", "/data/libero")

from libero.libero import benchmark as bm_lib, get_libero_path
from libero.libero.envs import OffScreenRenderEnv

# ---- Reuse constants/helpers from replay_shifted_trajectory.py ----
# Demo "states" vectors prepend one extra entry before the sim qpos, so
# state index = qpos index + STATE_QPOS_OFFSET (see _si below).
STATE_QPOS_OFFSET = 1

# Objects in task 0: each object's free joint occupies 7 consecutive qpos
# entries starting at qpos_start (presumably 3 position + 4 quaternion —
# the starts are spaced 7 apart and only the first 3 are written below).
# Roles: "pick" is grasped, "place" is the target, "distractor" is hidden.
TASK0_OBJECTS = {
    "akita_black_bowl_1":             {"qpos_start": 9,  "role": "pick"},
    "akita_black_bowl_2":             {"qpos_start": 16, "role": "distractor"},
    "cookies_1":                      {"qpos_start": 23, "role": "distractor"},
    "glazed_rim_porcelain_ramekin_1": {"qpos_start": 30, "role": "distractor"},
    "plate_1":                        {"qpos_start": 37, "role": "place"},
}

# Furniture bodies teleported out of view to declutter the scene.
FURNITURE_BODIES = ["wooden_cabinet_1_main", "flat_stove_1_main"]
# Position far below the floor, used to hide furniture bodies.
SINK_POS = np.array([0.0, 0.0, -5.0])
# Far off-screen position where distractor objects are parked.
DISTRACTOR_POS = np.array([10.0, 10.0, 0.9])

# Per-distractor qpos (7-wide) and qvel ("dof", 6-wide) slices, used to pin
# the distractors in place after every physics step (see freeze_distractors).
DISTRACTOR_DOFS = {
    "akita_black_bowl_2":            {"qpos": slice(16, 23), "dof": slice(15, 21)},
    "cookies_1":                     {"qpos": slice(23, 30), "dof": slice(21, 27)},
    "glazed_rim_porcelain_ramekin_1": {"qpos": slice(30, 37), "dof": slice(27, 33)},
}


def _si(qpos_start):
    """Map a raw qpos index to its index inside a stored demo state vector."""
    return STATE_QPOS_OFFSET + qpos_start


def move_distractors_in_state(state):
    """Return a copy of *state* with every distractor teleported to DISTRACTOR_POS."""
    shifted = state.copy()
    for info in TASK0_OBJECTS.values():
        if info["role"] != "distractor":
            continue
        base = _si(info["qpos_start"])
        shifted[base:base + 3] = DISTRACTOR_POS
    return shifted


def shift_pick_place(state, dx, dy):
    """Return a copy of *state* with the pick and place objects translated by (dx, dy)."""
    shifted = state.copy()
    for info in TASK0_OBJECTS.values():
        if info["role"] not in ("pick", "place"):
            continue
        base = _si(info["qpos_start"])
        shifted[base] += dx
        shifted[base + 1] += dy
    return shifted


def hide_furniture(sim):
    """Teleport the furniture bodies to SINK_POS (below the floor).

    Returns a dict mapping body name -> (body_id, original body_pos copy)
    so the caller could restore the scene later.
    """
    saved = {}
    for body_name in FURNITURE_BODIES:
        body_id = sim.model.body_name2id(body_name)
        saved[body_name] = (body_id, sim.model.body_pos[body_id].copy())
        sim.model.body_pos[body_id] = SINK_POS
    # Propagate the model change into the current sim state.
    sim.forward()
    return saved


def hide_distractors_visual(sim):
    """Make every geom belonging to a distractor body fully transparent.

    Returns a dict mapping geom_id -> original RGBA copy for later restore.
    """
    hidden_bodies = {
        sim.model.body_name2id(f"{name}_main")
        for name, info in TASK0_OBJECTS.items()
        if info["role"] == "distractor"
    }
    saved = {}
    for gid in range(sim.model.ngeom):
        if sim.model.geom_bodyid[gid] in hidden_bodies:
            saved[gid] = sim.model.geom_rgba[gid].copy()
            # Zero alpha hides the geom without touching the physics.
            sim.model.geom_rgba[gid][3] = 0.0
    return saved


def freeze_distractors(sim):
    """Pin each distractor at DISTRACTOR_POS and zero its velocities.

    Called after every env.step so physics cannot move the parked objects.
    """
    for spans in DISTRACTOR_DOFS.values():
        qpos_span = spans["qpos"]
        # Only the 3 position entries are overwritten; orientation is left alone.
        sim.data.qpos[qpos_span.start:qpos_span.start + 3] = DISTRACTOR_POS
        sim.data.qvel[spans["dof"]] = 0.0


def render_obs(obs, camera, image_size):
    """Extract, orient, and resize the camera image from an observation dict.

    Args:
        obs: observation mapping containing the key f"{camera}_image".
        camera: camera name prefix used to build the image key.
        image_size: desired square output resolution in pixels.

    Returns:
        An (image_size, image_size, 3) uint8 RGB array.
    """
    img_key = f"{camera}_image"
    rgb = np.asarray(obs[img_key]).copy()
    # Decide conversion by dtype, not by value: the previous `max() <= 1.0`
    # heuristic wrongly rescaled near-black uint8 frames (all pixels <= 1).
    # Non-uint8 input is assumed to be a normalized float image in [0, 1].
    if rgb.dtype != np.uint8:
        rgb = (rgb * 255).astype(np.uint8)
    # Offscreen renders arrive vertically flipped; flipud + contiguous for cv2.
    rgb = np.ascontiguousarray(np.flipud(rgb))
    if rgb.shape[:2] != (image_size, image_size):
        rgb = cv2.resize(rgb, (image_size, image_size), interpolation=cv2.INTER_LINEAR)
    return rgb


def servo_to_position(env, target_pos, gripper_cmd, max_servo=50, threshold=0.003):
    """Drive the end-effector toward target_pos with proportional actions.

    Returns (obs, done): the most recent observation and whether the episode
    terminated during servoing.
    """
    sim = env.env.sim
    obs = None
    for _ in range(max_servo):
        current = env.env._get_observations()
        eef = np.array(current["robot0_eef_pos"], dtype=np.float64)
        error = target_pos - eef
        if np.linalg.norm(error) < threshold:
            obs = current
            break
        # Proportional control: scale the position error, clip to action range.
        action = np.zeros(7, dtype=np.float32)
        action[:3] = np.clip(error / 0.05, -1.0, 1.0)
        action[6] = gripper_cmd
        obs, _, done, _ = env.step(action)
        # Keep parked distractors pinned after every physics step.
        freeze_distractors(sim)
        if done:
            return obs, True
    if obs is None:
        # Never stepped and never converged within max_servo iterations.
        obs = env.env._get_observations()
    return obs, False


def extract_demo_eef_positions(env, states):
    """Replay each recorded sim state and record the resulting EEF position.

    Args:
        env: wrapped environment exposing set_init_state and the raw sim.
        states: sequence of MuJoCo state vectors from the demo file.

    Returns:
        (T, 3) float64 array of end-effector positions, one per state.
    """
    sim = env.env.sim
    eef_positions = []
    # Iterate the states directly rather than indexing via range(len(...)).
    for state in states:
        env.set_init_state(state)
        sim.forward()
        obs = env.env._get_observations()
        eef_positions.append(np.array(obs["robot0_eef_pos"], dtype=np.float64))
    return np.array(eef_positions)


def replay_trajectory(env, states_0_base, eef_positions, actions,
                      shift_dx, shift_dy, center_dx, center_dy,
                      frame_indices, camera, image_size, max_servo,
                      z_offset=-0.015):
    """Replay one trajectory with a given (shift_dx, shift_dy) offset. Returns list of RGB frames.

    Args:
        env: wrapped environment (already reset by the caller).
        states_0_base: initial demo state, already centered and with
            distractors moved off-screen.
        eef_positions: (T, 3) original EEF trajectory from the demo.
        actions: (T, 7) original demo actions; only column 6 (gripper) is read.
        shift_dx, shift_dy: additional object/waypoint offset for this cell.
        center_dx, center_dy: centering offset already baked into states_0_base;
            re-applied to the EEF waypoints so they track the moved objects.
        frame_indices: subsampled timestep indices to servo through.
        camera, image_size: rendering parameters forwarded to render_obs.
        max_servo: per-waypoint iteration cap for servo_to_position.
        z_offset: added to the target z while the gripper is closing
            (negative lowers the grasp point).
    """
    # Prepare initial state: center + additional shift
    state_0 = shift_pick_place(states_0_base, shift_dx, shift_dy)

    # Set initial state
    env.set_init_state(state_0)
    sim = env.env.sim
    sim.forward()
    freeze_distractors(sim)

    # Reset robosuite internal counters so horizon doesn't terminate us early
    env.env.timestep = 0
    env.env.done = False
    env.env.horizon = 100000  # effectively infinite

    frames = []
    gripper_cmd = -1.0  # start with the gripper open (open = -1, close = +1 per actions[:, 6])

    # Render initial frame
    obs = env.env._get_observations()
    frames.append(render_obs(obs, camera, image_size))

    for t in frame_indices:
        # Waypoint = original EEF position translated by the total object offset.
        target_pos = eef_positions[t].copy()
        target_pos[0] += center_dx + shift_dx
        target_pos[1] += center_dy + shift_dy

        # Hold the last gripper command if t runs past the recorded actions.
        if t < len(actions):
            gripper_cmd = float(np.clip(actions[t, 6], -1.0, 1.0))

        # Lower the grasp point when gripper is closing
        if gripper_cmd > 0 and z_offset != 0:
            target_pos[2] += z_offset

        obs, done = servo_to_position(env, target_pos, gripper_cmd, max_servo=max_servo)
        frames.append(render_obs(obs, camera, image_size))
        if done:
            break  # episode terminated (e.g. success), stop early

    return frames


def tile_frames_3x3(all_cell_frames, cell_size, labels):
    """Given 9 lists of frames, produce tiled 3x3 grid frames with labels.

    Args:
        all_cell_frames: list of 9 frame-lists (one per grid cell, row-major).
        cell_size: kept for interface compatibility; not used (frames are
            assumed to already be at the cell resolution).
        labels: 9 label strings, one per cell, row-major.

    Returns:
        List of tiled RGB frames, length = longest input sequence. Shorter
        sequences are held on their last frame. Unlike the previous version,
        the input lists are NOT mutated (no in-place padding).
    """
    n_frames = max(len(f) for f in all_cell_frames)

    tiled = []
    for fi in range(n_frames):
        rows = []
        for row in range(3):
            row_cells = []
            for col in range(3):
                idx = row * 3 + col
                frames = all_cell_frames[idx]
                # Clamp to the last frame instead of pre-padding the lists.
                cell = frames[min(fi, len(frames) - 1)].copy()
                # Add label: white outline then dark fill for readability.
                label = labels[idx]
                cv2.putText(cell, label, (8, 24), cv2.FONT_HERSHEY_SIMPLEX, 0.55,
                            (255, 255, 255), 2, cv2.LINE_AA)
                cv2.putText(cell, label, (8, 24), cv2.FONT_HERSHEY_SIMPLEX, 0.55,
                            (20, 20, 20), 1, cv2.LINE_AA)
                row_cells.append(cell)
            rows.append(np.concatenate(row_cells, axis=1))
        tiled.append(np.concatenate(rows, axis=0))
    return tiled


def main():
    """Parse CLI args, replay the demo across a 3x3 grid of object offsets, and write the tiled video."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--image_size", type=int, default=256,
                        help="Per-cell render resolution")
    parser.add_argument("--benchmark", type=str, default="libero_spatial")
    parser.add_argument("--task_id", type=int, default=0)
    parser.add_argument("--demo_id", type=int, default=0)
    parser.add_argument("--camera", type=str, default="agentview")
    parser.add_argument("--frame_stride", type=int, default=3)
    parser.add_argument("--max_servo", type=int, default=50)
    parser.add_argument("--shift_range", type=float, default=0.1,
                        help="Symmetric range for dy (horizontal). Also used for dx if dx_min/dx_max not set.")
    parser.add_argument("--dx_min", type=float, default=None,
                        help="Min dx (toward robot). Default: -shift_range")
    parser.add_argument("--dx_max", type=float, default=None,
                        help="Max dx (away from robot). Default: +shift_range")
    parser.add_argument("--z_offset", type=float, default=-0.015,
                        help="Lower EEF target by this amount during grasp (gripper closing)")
    parser.add_argument("--fps", type=int, default=10)
    parser.add_argument("--out_dir", type=str, default=None)
    args = parser.parse_args()

    # Output directory defaults to ./out next to this script.
    script_dir = Path(__file__).resolve().parent
    out_dir = Path(args.out_dir) if args.out_dir else script_dir / "out"
    out_dir.mkdir(parents=True, exist_ok=True)

    # ---- Load demo ----
    bench = bm_lib.get_benchmark_dict()[args.benchmark]()
    task = bench.get_task(args.task_id)
    demo_path = os.path.join(get_libero_path("datasets"), bench.get_task_demonstration(args.task_id))
    bddl_file = os.path.join(get_libero_path("bddl_files"), task.problem_folder, task.bddl_file)

    with h5py.File(demo_path, "r") as f:
        demo_keys = sorted([k for k in f["data"].keys() if k.startswith("demo_")])
        # Clamp demo_id so an out-of-range request falls back to the last demo.
        demo_key = demo_keys[min(args.demo_id, len(demo_keys) - 1)]
        states = f[f"data/{demo_key}/states"][()]
        actions = f[f"data/{demo_key}/actions"][()]

    print(f"Task: {task.name}")
    print(f"Demo: {demo_key}, {len(states)} frames")

    # ---- Extract original EEF trajectory ----
    env = OffScreenRenderEnv(
        bddl_file_name=bddl_file,
        camera_heights=args.image_size,
        camera_widths=args.image_size,
        camera_names=[args.camera],
    )
    env.seed(0)
    env.reset()
    sim = env.env.sim

    print("Extracting original EEF trajectory...")
    eef_positions = extract_demo_eef_positions(env, states)
    print(f"  EEF range: x=[{eef_positions[:,0].min():.3f}, {eef_positions[:,0].max():.3f}] "
          f"y=[{eef_positions[:,1].min():.3f}, {eef_positions[:,1].max():.3f}] "
          f"z=[{eef_positions[:,2].min():.3f}, {eef_positions[:,2].max():.3f}]")

    # ---- Compute centering offset ----
    # Offset that moves the pick bowl to (0, 0); the grid shifts are applied
    # relative to this centered position.
    bowl_i = _si(TASK0_OBJECTS["akita_black_bowl_1"]["qpos_start"])
    bowl_orig_x = states[0][bowl_i]
    bowl_orig_y = states[0][bowl_i + 1]
    center_dx = -bowl_orig_x
    center_dy = -bowl_orig_y
    print(f"Bowl original: ({bowl_orig_x:.3f}, {bowl_orig_y:.3f}), centering: ({center_dx:+.3f}, {center_dy:+.3f})")

    # ---- Setup scene ----
    furniture_orig = hide_furniture(sim)
    distractor_orig = hide_distractors_visual(sim)

    # Base state: distractors moved off-screen, objects centered (no extra shift yet)
    state_0_base = states[0].copy()
    state_0_base = move_distractors_in_state(state_0_base)
    state_0_base = shift_pick_place(state_0_base, center_dx, center_dy)

    # Subsample the trajectory: one servo waypoint every frame_stride steps.
    frame_indices = list(range(0, len(eef_positions), args.frame_stride))

    # ---- 3x3 grid of shifts ----
    r = args.shift_range
    dx_min = args.dx_min if args.dx_min is not None else -r
    dx_max = args.dx_max if args.dx_max is not None else r
    dx_offsets = np.linspace(dx_min, dx_max, 3)
    dy_offsets = np.linspace(-r, r, 3)

    # Build the row-major list of (dx, dy) shifts and a label per cell.
    grid_shifts = []
    labels = []
    for row in range(3):
        for col in range(3):
            dx, dy = dx_offsets[col], dy_offsets[row]
            grid_shifts.append((dx, dy))
            if abs(dx) < 1e-6 and abs(dy) < 1e-6:
                labels.append("dx=0,dy=0 (center)")
            else:
                labels.append(f"dx={dx:+.2f},dy={dy:+.2f}")

    print(f"\n3x3 grid: shift_range=+/-{r}m, {len(frame_indices)} waypoints each")
    print(f"Total: 9 trajectories x {len(frame_indices)} waypoints\n")

    # ---- Replay each trajectory ----
    all_cell_frames = []
    for cell_idx, (dx, dy) in enumerate(grid_shifts):
        print(f"[{cell_idx+1}/9] {labels[cell_idx]}")

        # Reset env to clear done flag, then re-apply model modifications
        # (reset restores the original model, so hide_* must run again).
        env.reset()
        sim = env.env.sim
        hide_furniture(sim)
        hide_distractors_visual(sim)

        cell_frames = replay_trajectory(
            env, state_0_base, eef_positions, actions,
            dx, dy, center_dx, center_dy,
            frame_indices, args.camera, args.image_size, args.max_servo,
            z_offset=args.z_offset,
        )
        all_cell_frames.append(cell_frames)
        print(f"  -> {len(cell_frames)} frames")

    # ---- Tile and write video ----
    print("\nTiling 3x3 grid video...")
    tiled_frames = tile_frames_3x3(all_cell_frames, args.image_size, labels)

    video_path = out_dir / "object_position_demos.mp4"
    h, w = tiled_frames[0].shape[:2]
    writer = cv2.VideoWriter(
        str(video_path),
        cv2.VideoWriter_fourcc(*"mp4v"),
        float(args.fps),
        (w, h),
    )
    # Frames are RGB internally; OpenCV's writer expects BGR.
    for frame in tiled_frames:
        writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
    writer.release()
    print(f"Saved: {video_path} ({len(tiled_frames)} frames, {w}x{h})")

    # Save first frame as preview image
    preview_path = out_dir / "object_position_demos_preview.png"
    cv2.imwrite(str(preview_path), cv2.cvtColor(tiled_frames[0], cv2.COLOR_RGB2BGR))
    print(f"Saved preview: {preview_path}")

    env.close()
    print("Done.")


# Script entry point.
if __name__ == "__main__":
    main()
