"""replay_natural_start.py — Replay shifted trajectory with natural robot start.

Instead of shifting the entire trajectory (which puts the robot in a weird start
position), this script:
  1. Starts the robot at its natural home position (original demo t=0)
  2. Smoothly interpolates from home to the shifted pre-grasp position
  3. Executes the shifted grasp/lift/place trajectory from there

Usage:
    python ood_libero/replay_natural_start.py [--dx -0.15] [--dy 0.0]
"""

import argparse
import os
import sys
from pathlib import Path

import cv2
import h5py
import numpy as np
from tqdm import tqdm

sys.path.insert(0, "/data/cameron/LIBERO")
os.environ.setdefault("LIBERO_DATA_PATH", "/data/libero")

from libero.libero import benchmark as bm_lib, get_libero_path
from libero.libero.envs import OffScreenRenderEnv

# ---- Constants ----
# Demo state vectors prepend one extra scalar ahead of the MuJoCo qpos block
# (see _si); presumably sim time — TODO confirm against LIBERO's state layout.
STATE_QPOS_OFFSET = 1

# Objects in task 0 of the benchmark scene. "qpos_start" is the raw qpos index
# of each object's free joint (7 qpos entries per object, spacing below; the
# first 3 are world x/y/z, as used by the shift/teleport helpers); "role" tags
# the object as pick target, place target, or distractor.
TASK0_OBJECTS = {
    "akita_black_bowl_1":             {"qpos_start": 9,  "role": "pick"},
    "akita_black_bowl_2":             {"qpos_start": 16, "role": "distractor"},
    "cookies_1":                      {"qpos_start": 23, "role": "distractor"},
    "glazed_rim_porcelain_ramekin_1": {"qpos_start": 30, "role": "distractor"},
    "plate_1":                        {"qpos_start": 37, "role": "place"},
}

# Static furniture bodies teleported out of view by hide_furniture().
FURNITURE_BODIES = ["wooden_cabinet_1_main", "flat_stove_1_main"]
# Below-floor position the furniture bodies are sunk to.
SINK_POS = np.array([0.0, 0.0, -5.0])
# Far-away parking spot for distractor objects.
DISTRACTOR_POS = np.array([10.0, 10.0, 0.9])

# Raw sim.data index ranges per distractor: "qpos" spans the free joint's
# 7 position entries, "dof" its 6 velocity entries. These index sim.data
# directly, so STATE_QPOS_OFFSET does NOT apply here (unlike TASK0_OBJECTS,
# whose indices are shifted via _si when indexing demo state vectors).
DISTRACTOR_DOFS = {
    "akita_black_bowl_2":            {"qpos": slice(16, 23), "dof": slice(15, 21)},
    "cookies_1":                     {"qpos": slice(23, 30), "dof": slice(21, 27)},
    "glazed_rim_porcelain_ramekin_1": {"qpos": slice(30, 37), "dof": slice(27, 33)},
}


def _si(qpos_start):
    """Translate a raw qpos index into its slot in a demo state vector."""
    return STATE_QPOS_OFFSET + qpos_start


def move_distractors_in_state(state):
    """Return a copy of *state* with every distractor teleported off-screen."""
    shifted = state.copy()
    distractors = (v for v in TASK0_OBJECTS.values() if v["role"] == "distractor")
    for info in distractors:
        base = _si(info["qpos_start"])
        shifted[base:base + 3] = DISTRACTOR_POS
    return shifted


def shift_pick_place(state, dx, dy):
    """Return a copy of *state* with pick/place object x/y translated by (dx, dy)."""
    out = state.copy()
    for obj in TASK0_OBJECTS.values():
        if obj["role"] not in ("pick", "place"):
            continue
        x = _si(obj["qpos_start"])
        out[x] += dx
        out[x + 1] += dy
    return out


def hide_furniture(sim):
    """Sink the furniture bodies below the floor so they are out of view."""
    body_ids = [sim.model.body_name2id(body) for body in FURNITURE_BODIES]
    for bid in body_ids:
        sim.model.body_pos[bid] = SINK_POS
    sim.forward()


def hide_distractors_visual(sim):
    """Make every geom attached to a distractor body fully transparent."""
    hidden_bodies = {
        sim.model.body_name2id(f"{name}_main")
        for name, info in TASK0_OBJECTS.items()
        if info["role"] == "distractor"
    }
    for gid in range(sim.model.ngeom):
        if sim.model.geom_bodyid[gid] in hidden_bodies:
            sim.model.geom_rgba[gid][3] = 0.0  # zero alpha hides the geom


def freeze_distractors(sim):
    """Re-pin distractors at their parking spot and zero their velocities."""
    for slices in DISTRACTOR_DOFS.values():
        q0 = slices["qpos"].start
        sim.data.qpos[q0:q0 + 3] = DISTRACTOR_POS  # only the xyz part
        sim.data.qvel[slices["dof"]] = 0.0


def render_obs(obs, camera, image_size):
    """Extract the camera image from *obs* as an upright uint8 RGB square."""
    rgb = np.asarray(obs[f"{camera}_image"]).copy()
    if rgb.max() <= 1.0:  # float image in [0, 1] -> uint8
        rgb = (rgb * 255).astype(np.uint8)
    rgb = np.ascontiguousarray(np.flipud(rgb))  # flip vertically
    if rgb.shape[:2] != (image_size, image_size):
        rgb = cv2.resize(rgb, (image_size, image_size), interpolation=cv2.INTER_LINEAR)
    return rgb


def servo_to_position(env, target_pos, gripper_cmd, max_servo=50, threshold=0.003):
    """Closed-loop Cartesian servo of the end-effector toward *target_pos*.

    Issues proportional delta actions (capped to [-1, 1]) until the EEF is
    within *threshold* meters of the target or *max_servo* steps elapse.
    Returns (obs, done): the latest observation and whether the env reported
    episode termination while servoing.
    """
    sim = env.env.sim
    obs = None
    for _ in range(max_servo):
        cur_obs = env.env._get_observations()
        err = target_pos - np.array(cur_obs["robot0_eef_pos"], dtype=np.float64)
        if np.linalg.norm(err) < threshold:
            obs = cur_obs  # close enough — stop without stepping again
            break
        cmd = np.zeros(7, dtype=np.float32)
        cmd[:3] = np.clip(err / 0.05, -1.0, 1.0)  # P-control, saturated
        cmd[6] = gripper_cmd
        obs, _, done, _ = env.step(cmd)
        freeze_distractors(sim)  # re-pin distractors after each physics step
        if done:
            return obs, True
    if obs is None:
        # Loop body never produced an observation (e.g. max_servo == 0).
        obs = env.env._get_observations()
    return obs, False


def extract_demo_eef_positions(env, states):
    """Replay each demo state and record the resulting EEF world position.

    NOTE: this mutates the sim by setting its state frame-by-frame; callers
    should reset the env afterwards.
    """
    sim = env.env.sim
    positions = []
    for state in states:
        env.set_init_state(state)
        sim.forward()
        obs = env.env._get_observations()
        positions.append(np.array(obs["robot0_eef_pos"], dtype=np.float64))
    return np.array(positions)


def find_grasp_timestep(actions):
    """Find the first timestep where gripper transitions from open to close."""
    gripper = actions[:, 6]
    # A transition is a step where the command crosses from <= 0 to > 0.
    edges = np.flatnonzero((gripper[1:] > 0) & (gripper[:-1] <= 0))
    if edges.size:
        return int(edges[0]) + 1  # +1: edges index the diffs, not the samples
    return len(gripper) // 2  # no open->close edge found — fall back to midpoint


def interpolate_waypoints(start_pos, end_pos, n_steps):
    """Linearly interpolate from *start_pos* to *end_pos*.

    Returns an (n_steps, d) array of waypoints that excludes start_pos and
    includes end_pos. Vectorized via broadcasting; unlike the previous
    list-comprehension version, n_steps == 0 now yields a correctly shaped
    (0, d) array instead of a shape-(0,) one.
    """
    start = np.asarray(start_pos, dtype=np.float64)
    end = np.asarray(end_pos, dtype=np.float64)
    alphas = np.linspace(0.0, 1.0, n_steps + 1)[1:]  # skip start, include end
    return start + alphas[:, None] * (end - start)


def main() -> None:
    """Replay a shifted pick-and-place demo starting from the robot's home pose.

    Pipeline: load a LIBERO demo, extract its end-effector trajectory from the
    stored states, shift the pick/place objects (centering the bowl plus a
    user-specified dx/dy), servo from the home pose to a shifted pre-grasp
    waypoint, track the shifted trajectory from there, and write an annotated
    MP4 plus a keyframe strip to --out_dir.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--image_size", type=int, default=448)
    parser.add_argument("--benchmark", type=str, default="libero_spatial")
    parser.add_argument("--task_id", type=int, default=0)
    parser.add_argument("--demo_id", type=int, default=0)
    parser.add_argument("--camera", type=str, default="agentview")
    parser.add_argument("--frame_stride", type=int, default=3)
    parser.add_argument("--max_servo", type=int, default=50)
    parser.add_argument("--dx", type=float, default=0.0)
    parser.add_argument("--dy", type=float, default=0.0)
    parser.add_argument("--z_offset", type=float, default=-0.015)
    parser.add_argument("--pregrasp_lead", type=int, default=6,
                        help="Number of timesteps before grasp to start the shifted approach")
    parser.add_argument("--interp_steps", type=int, default=8,
                        help="Number of interpolated waypoints from home to pre-grasp")
    parser.add_argument("--fps", type=int, default=10)
    parser.add_argument("--out_dir", type=str, default=None)
    args = parser.parse_args()

    script_dir = Path(__file__).resolve().parent
    out_dir = Path(args.out_dir) if args.out_dir else script_dir / "out"
    out_dir.mkdir(parents=True, exist_ok=True)

    # ---- Load demo ----
    bench = bm_lib.get_benchmark_dict()[args.benchmark]()
    task = bench.get_task(args.task_id)
    demo_path = os.path.join(get_libero_path("datasets"), bench.get_task_demonstration(args.task_id))
    bddl_file = os.path.join(get_libero_path("bddl_files"), task.problem_folder, task.bddl_file)

    with h5py.File(demo_path, "r") as f:
        demo_keys = sorted([k for k in f["data"].keys() if k.startswith("demo_")])
        # Clamp demo_id so an out-of-range request falls back to the last demo.
        demo_key = demo_keys[min(args.demo_id, len(demo_keys) - 1)]
        states = f[f"data/{demo_key}/states"][()]
        actions = f[f"data/{demo_key}/actions"][()]

    print(f"Task: {task.name}")
    print(f"Demo: {demo_key}, {len(states)} frames")

    # ---- Extract original EEF trajectory ----
    env = OffScreenRenderEnv(
        bddl_file_name=bddl_file,
        camera_heights=args.image_size,
        camera_widths=args.image_size,
        camera_names=[args.camera],
    )
    env.seed(0)
    env.reset()
    sim = env.env.sim

    # Replays every demo state through the sim, so this clobbers sim state;
    # the env is reset again below before the actual rollout.
    print("Extracting original EEF trajectory...")
    eef_positions = extract_demo_eef_positions(env, states)

    # ---- Compute centering + user shift ----
    # Total shift = (move bowl to workspace origin) + (user-requested dx/dy).
    bowl_i = _si(TASK0_OBJECTS["akita_black_bowl_1"]["qpos_start"])
    bowl_orig_x = states[0][bowl_i]
    bowl_orig_y = states[0][bowl_i + 1]
    center_dx = -bowl_orig_x
    center_dy = -bowl_orig_y
    total_dx = center_dx + args.dx
    total_dy = center_dy + args.dy
    print(f"Bowl original: ({bowl_orig_x:.3f}, {bowl_orig_y:.3f})")
    print(f"Centering: ({center_dx:+.3f}, {center_dy:+.3f}), user shift: ({args.dx:+.3f}, {args.dy:+.3f})")
    print(f"Total shift: ({total_dx:+.3f}, {total_dy:+.3f})")

    # ---- Find grasp point ----
    t_grasp = find_grasp_timestep(actions)
    t_pregrasp = max(0, t_grasp - args.pregrasp_lead)
    print(f"\nGrasp at t={t_grasp}, pre-grasp at t={t_pregrasp}")

    # ---- Setup scene ----
    env.reset()
    sim = env.env.sim
    hide_furniture(sim)
    hide_distractors_visual(sim)

    # Initial state: original robot pose, shifted objects, distractors off-screen
    state_0 = states[0].copy()
    state_0 = move_distractors_in_state(state_0)
    state_0 = shift_pick_place(state_0, total_dx, total_dy)
    # NOTE: robot joints stay at their original starting configuration

    env.set_init_state(state_0)
    sim.forward()
    freeze_distractors(sim)
    # Reset the wrapper's episode bookkeeping after manual state injection,
    # and raise the horizon so the env does not auto-terminate mid-replay.
    env.env.timestep = 0
    env.env.done = False
    env.env.horizon = 100000

    # ---- Phase 1: Home → Pre-grasp (interpolated approach) ----
    obs = env.env._get_observations()
    home_pos = np.array(obs["robot0_eef_pos"], dtype=np.float64)

    # Shifted pre-grasp target
    pregrasp_target = eef_positions[t_pregrasp].copy()
    pregrasp_target[0] += total_dx
    pregrasp_target[1] += total_dy

    print(f"Home EEF:      {home_pos}")
    print(f"Pre-grasp EEF: {pregrasp_target}")
    print(f"Distance:      {np.linalg.norm(pregrasp_target - home_pos)*1000:.1f}mm")

    interp_waypoints = interpolate_waypoints(home_pos, pregrasp_target, args.interp_steps)

    frames = []
    # Render initial frame (white-on-dark double putText for a readable label)
    frame = render_obs(obs, args.camera, args.image_size)
    cv2.putText(frame, "HOME", (10, 22), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 2, cv2.LINE_AA)
    cv2.putText(frame, "HOME", (10, 22), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (20, 20, 20), 1, cv2.LINE_AA)
    frames.append(frame)

    print(f"\nPhase 1: Interpolating home → pre-grasp ({args.interp_steps} waypoints)...")
    # Gripper held open (-1.0) during the whole approach.
    for i, wp in enumerate(tqdm(interp_waypoints, desc="Approach")):
        obs, done = servo_to_position(env, wp, -1.0, max_servo=args.max_servo)
        frame = render_obs(obs, args.camera, args.image_size)
        label = f"APPROACH {i+1}/{args.interp_steps}"
        cv2.putText(frame, label, (10, 22), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 2, cv2.LINE_AA)
        cv2.putText(frame, label, (10, 22), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 200, 0), 1, cv2.LINE_AA)
        frames.append(frame)
        if done:
            break

    # ---- Phase 2: Execute shifted trajectory from pre-grasp onward ----
    # Use every frame_stride-th timestep from t_pregrasp to end
    phase2_indices = list(range(t_pregrasp, len(eef_positions), args.frame_stride))
    print(f"Phase 2: Executing shifted trajectory ({len(phase2_indices)} waypoints from t={t_pregrasp})...")

    gripper_cmd = -1.0
    for t in tqdm(phase2_indices, desc="Execute"):
        target_pos = eef_positions[t].copy()
        target_pos[0] += total_dx
        target_pos[1] += total_dy

        # states may have one more entry than actions; guard the lookup.
        if t < len(actions):
            gripper_cmd = float(np.clip(actions[t, 6], -1.0, 1.0))

        # Apply z_offset during grasp
        if gripper_cmd > 0 and args.z_offset != 0:
            target_pos[2] += args.z_offset

        obs, done = servo_to_position(env, target_pos, gripper_cmd, max_servo=args.max_servo)

        frame = render_obs(obs, args.camera, args.image_size)
        cur_pos = np.array(obs["robot0_eef_pos"], dtype=np.float64)
        err = np.linalg.norm(cur_pos - target_pos)
        phase = "GRASP" if gripper_cmd > 0 else "MOVE"
        label = f"{phase} t={t} err={err*1000:.1f}mm"
        cv2.putText(frame, label, (10, 22), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 2, cv2.LINE_AA)
        cv2.putText(frame, label, (10, 22), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (20, 20, 20), 1, cv2.LINE_AA)

        if done:
            cv2.putText(frame, "SUCCESS", (10, 44), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2, cv2.LINE_AA)
            frames.append(frame)
            # Pad a few success frames
            for _ in range(5):
                frames.append(frame.copy())
            break

        frames.append(frame)

    # ---- Save video ----
    # frames is never empty: the HOME frame is always appended above.
    video_path = out_dir / "replay_natural_start.mp4"
    h, w = frames[0].shape[:2]
    writer = cv2.VideoWriter(str(video_path), cv2.VideoWriter_fourcc(*"mp4v"), float(args.fps), (w, h))
    for frame in frames:
        # Frames are RGB; OpenCV writers expect BGR.
        writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
    writer.release()
    print(f"\nSaved video ({len(frames)} frames): {video_path}")

    # Save keyframes
    import matplotlib.pyplot as plt
    # Up to 7 frames sampled evenly across the rollout.
    n_kf = min(7, len(frames))
    key_idxs = [int(i * (len(frames) - 1) / (n_kf - 1)) for i in range(n_kf)]
    fig, axes = plt.subplots(1, n_kf, figsize=(4 * n_kf, 4))
    for i, fi in enumerate(key_idxs):
        axes[i].imshow(frames[fi])
        axes[i].set_title(f"Frame {fi}/{len(frames)-1}", fontsize=9)
        axes[i].axis("off")
    plt.suptitle(f"Natural start: dx={args.dx:+.2f} dy={args.dy:+.2f} z_off={args.z_offset:+.3f}", fontsize=12)
    plt.tight_layout()
    kf_path = out_dir / "replay_natural_start_keyframes.png"
    plt.savefig(str(kf_path), dpi=150, bbox_inches="tight")
    plt.close()
    print(f"Saved keyframes: {kf_path}")

    env.close()
    print("Done.")


# Script entry point.
if __name__ == "__main__":
    main()
