"""replay_shifted_trajectory.py — Replay a LIBERO demo with shifted object positions.

Takes the first demo of task 0, centers the pick/place objects, removes distractors
and furniture, then replays the trajectory by:
  1. For each original EEF position, shift it by the same centering offset
  2. Zero out rotation (use zero rotation delta)
  3. Servo the robot to the shifted position (closed-loop OSC teleport)
  4. Apply gripper action
  5. Render and save video

Usage:
    python ood_libero/replay_shifted_trajectory.py [--image_size 448] [--out_dir ood_libero/out]
"""

import argparse
import os
import sys
from pathlib import Path

import cv2
import h5py
import numpy as np
from tqdm import tqdm

sys.path.insert(0, "/data/cameron/LIBERO")
os.environ.setdefault("LIBERO_DATA_PATH", "/data/libero")

from libero.libero import benchmark as bm_lib, get_libero_path
from libero.libero.envs import OffScreenRenderEnv
from robosuite.utils.camera_utils import get_camera_transform_matrix

# ---- State array layout (same as object_removal_test.py) ----
# Demo "states" vectors carry one leading entry before qpos, so object qpos
# indices must be shifted by this offset when indexing into a state array
# (see _si below).
STATE_QPOS_OFFSET = 1

# Task-0 scene objects. qpos_start is the object's starting index in the raw
# sim qpos array (each entry spans 7 values — presumably free-joint pos+quat;
# consistent with the 7-wide slices in DISTRACTOR_DOFS below). role selects
# how the replay treats the object: pick/place are kept and re-centered,
# distractors are moved off-screen and hidden.
TASK0_OBJECTS = {
    "akita_black_bowl_1":             {"qpos_start": 9,  "role": "pick"},
    "akita_black_bowl_2":             {"qpos_start": 16, "role": "distractor"},
    "cookies_1":                      {"qpos_start": 23, "role": "distractor"},
    "glazed_rim_porcelain_ramekin_1": {"qpos_start": 30, "role": "distractor"},
    "plate_1":                        {"qpos_start": 37, "role": "place"},
}

# Fixed furniture bodies are hidden by moving body_pos directly (hide_furniture).
FURNITURE_BODIES = ["wooden_cabinet_1_main", "flat_stove_1_main"]
SINK_POS = np.array([0.0, 0.0, -5.0])  # for furniture only (fixed bodies, no physics instability)

# Far-away position for distractor free-joint objects (on table height so first frame is ok)
DISTRACTOR_POS = np.array([10.0, 10.0, 0.9])

# Distractor DOF mappings (qpos and qvel indices into the raw sim arrays)
# Note the qpos slices are 7 wide while the dof (qvel) slices are 6 wide —
# matching a free joint's 7 position vs 6 velocity coordinates.
DISTRACTOR_DOFS = {
    "akita_black_bowl_2":            {"qpos": slice(16, 23), "dof": slice(15, 21)},
    "cookies_1":                     {"qpos": slice(23, 30), "dof": slice(21, 27)},
    "glazed_rim_porcelain_ramekin_1": {"qpos": slice(30, 37), "dof": slice(27, 33)},
}


def _si(qpos_start):
    """Translate a raw sim qpos index into its index within a demo state array."""
    state_index = STATE_QPOS_OFFSET + qpos_start
    return state_index


def move_distractors_in_state(state):
    """Return a copy of `state` with every distractor teleported to DISTRACTOR_POS."""
    shifted = state.copy()
    distractor_infos = (
        info for info in TASK0_OBJECTS.values() if info["role"] == "distractor"
    )
    for info in distractor_infos:
        base = _si(info["qpos_start"])
        shifted[base:base + 3] = DISTRACTOR_POS
    return shifted


def hide_distractors_visual(sim):
    """Zero the alpha channel of every geom on a distractor body.

    Returns a {geom_id: original_rgba} mapping so restore_distractors_visual()
    can undo the change.
    """
    hidden_body_ids = {
        sim.model.body_name2id(f"{name}_main")
        for name, info in TASK0_OBJECTS.items()
        if info["role"] == "distractor"
    }

    saved_rgba = {}
    for gid in range(sim.model.ngeom):
        if sim.model.geom_bodyid[gid] not in hidden_body_ids:
            continue
        saved_rgba[gid] = sim.model.geom_rgba[gid].copy()
        sim.model.geom_rgba[gid][3] = 0.0  # alpha = 0 -> invisible
    return saved_rgba


def restore_distractors_visual(sim, originals):
    """Undo hide_distractors_visual() by writing back the saved rgba values."""
    for gid in originals:
        sim.model.geom_rgba[gid] = originals[gid]


def freeze_distractors(sim):
    """Pin distractors at DISTRACTOR_POS and zero their velocities.

    Intended to be called after each env.step() so physics cannot move them.
    """
    for mapping in DISTRACTOR_DOFS.values():
        pos_start = mapping["qpos"].start
        sim.data.qpos[pos_start:pos_start + 3] = DISTRACTOR_POS
        sim.data.qvel[mapping["dof"]] = 0.0


def shift_pick_place(state, dx, dy):
    """Return a copy of `state` with pick/place objects translated by (dx, dy)."""
    shifted = state.copy()
    for info in TASK0_OBJECTS.values():
        if info["role"] not in ("pick", "place"):
            continue
        base = _si(info["qpos_start"])
        shifted[base] += dx
        shifted[base + 1] += dy
    return shifted


def hide_furniture(sim):
    """Sink all furniture bodies to SINK_POS (below the floor).

    Returns {body_name: (body_id, original_pos)} for restore_furniture().
    """
    saved = {}
    for body_name in FURNITURE_BODIES:
        body_id = sim.model.body_name2id(body_name)
        saved[body_name] = (body_id, sim.model.body_pos[body_id].copy())
        sim.model.body_pos[body_id] = SINK_POS
    sim.forward()
    return saved


def restore_furniture(sim, originals):
    """Put furniture bodies back at the positions saved by hide_furniture()."""
    for body_id, saved_pos in originals.values():
        sim.model.body_pos[body_id] = saved_pos
    sim.forward()


def render_obs(obs, camera, image_size):
    """Extract the camera image from `obs` as an upright uint8 RGB frame.

    Args:
        obs: observation dict containing "<camera>_image".
        camera: camera name (selects the obs key).
        image_size: desired square output size; image is resized if needed.

    Returns:
        (image_size, image_size, 3) uint8 array, vertically flipped so it is
        upright (offscreen renders come out inverted).
    """
    img_key = f"{camera}_image"
    rgb = np.asarray(obs[img_key]).copy()
    # Detect float images by dtype, not by value: the old `max() <= 1.0` check
    # would wrongly rescale a near-black uint8 frame, and left float frames
    # with values > 1 in a non-uint8 dtype.
    if np.issubdtype(rgb.dtype, np.floating):
        if rgb.max() <= 1.0:
            rgb = rgb * 255.0  # normalized [0, 1] floats -> [0, 255]
        rgb = rgb.astype(np.uint8)
    rgb = np.ascontiguousarray(np.flipud(rgb))
    if rgb.shape[0] != image_size or rgb.shape[1] != image_size:
        rgb = cv2.resize(rgb, (image_size, image_size), interpolation=cv2.INTER_LINEAR)
    return rgb


def servo_to_position(env, target_pos, gripper_cmd, max_servo=50, threshold=0.003):
    """Closed-loop OSC servo toward `target_pos` with zero rotation delta.

    Steps the env up to `max_servo` times, stopping early once the EEF is
    within `threshold` meters of the target or the episode reports done.
    Distractors are re-frozen after every step. Returns the last observation.
    """
    sim = env.env.sim
    final_obs = None
    for _ in range(max_servo):
        observation = env.env._get_observations()
        eef = np.array(observation["robot0_eef_pos"], dtype=np.float64)
        error = target_pos - eef
        if np.linalg.norm(error) < threshold:
            final_obs = observation
            break
        cmd = np.zeros(7, dtype=np.float32)
        cmd[:3] = np.clip(error / 0.05, -1.0, 1.0)  # position delta, saturated
        # cmd[3:6] (rotation delta) intentionally stays zero
        cmd[6] = gripper_cmd
        final_obs, _, done, _ = env.step(cmd)
        freeze_distractors(sim)  # prevent distractor drift / instability
        if done:
            break
    if final_obs is None:
        final_obs = env.env._get_observations()
    return final_obs


def extract_demo_eef_positions(env, sim, states, camera):
    """Set each demo state in turn and record the resulting EEF position.

    Note: `camera` is accepted for interface compatibility but is not used.

    Returns an (T, 3) float64 array of EEF positions, one row per state.
    """
    positions = []
    for state in states:
        env.set_init_state(state)
        sim.forward()
        observation = env.env._get_observations()
        positions.append(np.array(observation["robot0_eef_pos"], dtype=np.float64))
    return np.array(positions)


def main():
    """Replay a LIBERO demo with centered pick/place objects and no clutter.

    Pipeline:
      1. Load the demo's states/actions from HDF5.
      2. Extract the original EEF trajectory by stepping a throwaway env
         through the demo states.
      3. Compute the XY offset that centers the pick bowl at the origin.
      4. In a fresh env, hide furniture and distractors, shift pick/place
         objects, then servo the EEF along the shifted trajectory.
      5. Save an annotated MP4 and a keyframe strip PNG to --out_dir.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--image_size", type=int, default=448)
    parser.add_argument("--benchmark", type=str, default="libero_spatial")
    parser.add_argument("--task_id", type=int, default=0)
    parser.add_argument("--demo_id", type=int, default=0)
    parser.add_argument("--camera", type=str, default="agentview")
    parser.add_argument("--frame_stride", type=int, default=1,
                        help="Use every Nth frame from the demo")
    parser.add_argument("--max_servo", type=int, default=50,
                        help="Max OSC steps per waypoint for servo convergence")
    parser.add_argument("--z_offset", type=float, default=-0.015,
                        help="Lower EEF target by this amount during grasp (gripper closing)")
    parser.add_argument("--fps", type=int, default=15)
    parser.add_argument("--out_dir", type=str, default=None)
    args = parser.parse_args()

    script_dir = Path(__file__).resolve().parent
    out_dir = Path(args.out_dir) if args.out_dir else script_dir / "out"
    out_dir.mkdir(parents=True, exist_ok=True)

    # ---- Load demo ----
    bench = bm_lib.get_benchmark_dict()[args.benchmark]()
    task = bench.get_task(args.task_id)
    demo_path = os.path.join(get_libero_path("datasets"), bench.get_task_demonstration(args.task_id))
    bddl_file = os.path.join(get_libero_path("bddl_files"), task.problem_folder, task.bddl_file)

    with h5py.File(demo_path, "r") as f:
        demo_keys = sorted([k for k in f["data"].keys() if k.startswith("demo_")])
        # Clamp demo_id so an out-of-range request falls back to the last demo
        demo_key = demo_keys[min(args.demo_id, len(demo_keys) - 1)]
        states = f[f"data/{demo_key}/states"][()]
        actions = f[f"data/{demo_key}/actions"][()]

    print(f"Task: {task.name}")
    print(f"Demo: {demo_key}, {len(states)} frames")

    # ---- Extract original EEF trajectory ----
    # Use a temporary env to read EEF positions without modifying the scene
    env_extract = OffScreenRenderEnv(
        bddl_file_name=bddl_file,
        camera_heights=args.image_size,
        camera_widths=args.image_size,
        camera_names=[args.camera],
    )
    env_extract.seed(0)
    env_extract.reset()

    print("Extracting original EEF trajectory...")
    eef_positions = extract_demo_eef_positions(env_extract, env_extract.env.sim, states, args.camera)
    print(f"  EEF range: x=[{eef_positions[:,0].min():.3f}, {eef_positions[:,0].max():.3f}] "
          f"y=[{eef_positions[:,1].min():.3f}, {eef_positions[:,1].max():.3f}] "
          f"z=[{eef_positions[:,2].min():.3f}, {eef_positions[:,2].max():.3f}]")
    env_extract.close()

    # ---- Compute centering offset ----
    # Offset that moves the pick bowl's initial XY to the origin
    bowl_i = _si(TASK0_OBJECTS["akita_black_bowl_1"]["qpos_start"])
    bowl_orig_x = states[0][bowl_i]
    bowl_orig_y = states[0][bowl_i + 1]
    center_dx = -bowl_orig_x
    center_dy = -bowl_orig_y
    print(f"\nBowl original pos: ({bowl_orig_x:.3f}, {bowl_orig_y:.3f})")
    print(f"Centering offset:  ({center_dx:+.3f}, {center_dy:+.3f})")

    # ---- Create fresh env for replay ----
    env = OffScreenRenderEnv(
        bddl_file_name=bddl_file,
        camera_heights=args.image_size,
        camera_widths=args.image_size,
        camera_names=[args.camera],
    )
    env.seed(0)
    env.reset()
    sim = env.env.sim

    # Hide furniture (move body_pos underground — safe since they're fixed bodies)
    furniture_orig = hide_furniture(sim)
    # Hide distractors: move far away in state + make invisible + freeze each step
    distractor_orig = hide_distractors_visual(sim)

    # ---- Prepare initial state with centered pick/place objects ----
    state_0 = states[0].copy()
    state_0 = move_distractors_in_state(state_0)
    state_0 = shift_pick_place(state_0, center_dx, center_dy)

    # Set initial state
    env.set_init_state(state_0)
    sim.forward()

    # ---- Replay trajectory with servo teleport ----
    # Subsample frames
    frame_indices = list(range(0, len(eef_positions), args.frame_stride))
    print(f"\nReplaying {len(frame_indices)} waypoints (stride={args.frame_stride})...")

    frames = []
    gripper_cmd = -1.0  # start open

    # Render initial frame
    obs = env.env._get_observations()
    frames.append(render_obs(obs, args.camera, args.image_size))

    for t in tqdm(frame_indices, desc="Servo replay"):
        # Shifted target position: original EEF waypoint + centering offset
        target_pos = eef_positions[t].copy()
        target_pos[0] += center_dx
        target_pos[1] += center_dy

        # Gripper from demo actions (hold last command past the actions array)
        if t < len(actions):
            gripper_cmd = float(np.clip(actions[t, 6], -1.0, 1.0))

        # Lower the grasp point when gripper is closing
        if gripper_cmd > 0 and args.z_offset != 0:
            target_pos[2] += args.z_offset

        # Servo to target
        obs = servo_to_position(env, target_pos, gripper_cmd, max_servo=args.max_servo)

        # Render
        frame = render_obs(obs, args.camera, args.image_size)

        # Annotate: draw the label twice (white outline, dark fill) for contrast
        cur_pos = np.array(obs["robot0_eef_pos"], dtype=np.float64)
        err = np.linalg.norm(cur_pos - target_pos)
        label = f"t={t} err={err*1000:.1f}mm grip={gripper_cmd:+.1f}"
        cv2.putText(frame, label, (10, 22), cv2.FONT_HERSHEY_SIMPLEX, 0.5,
                    (255, 255, 255), 2, cv2.LINE_AA)
        cv2.putText(frame, label, (10, 22), cv2.FONT_HERSHEY_SIMPLEX, 0.5,
                    (20, 20, 20), 1, cv2.LINE_AA)
        frames.append(frame)

    # ---- Save video ----
    video_path = out_dir / "replay_shifted_trajectory.mp4"
    h, w = frames[0].shape[:2]
    writer = cv2.VideoWriter(
        str(video_path),
        cv2.VideoWriter_fourcc(*"mp4v"),
        float(args.fps),
        (w, h),
    )
    for frame in frames:
        writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
    writer.release()
    print(f"\nSaved video ({len(frames)} frames): {video_path}")

    # ---- Also save a few key frame snapshots ----
    key_frames = [0, len(frames) // 4, len(frames) // 2, 3 * len(frames) // 4, len(frames) - 1]
    import matplotlib.pyplot as plt
    fig, axes = plt.subplots(1, len(key_frames), figsize=(4 * len(key_frames), 4))
    for i, fi in enumerate(key_frames):
        axes[i].imshow(frames[fi])
        axes[i].set_title(f"Frame {fi}/{len(frames)-1}", fontsize=9)
        axes[i].axis("off")
    plt.suptitle("Shifted trajectory replay (key frames)", fontsize=12)
    plt.tight_layout()
    keyframes_path = out_dir / "replay_keyframes.png"
    plt.savefig(str(keyframes_path), dpi=150, bbox_inches="tight")
    plt.close()
    print(f"Saved keyframes: {keyframes_path}")

    # Restore scene modifications before closing so the env shuts down cleanly
    restore_distractors_visual(sim, distractor_orig)
    restore_furniture(sim, furniture_orig)
    env.close()
    print("Done.")


if __name__ == "__main__":
    main()
