"""debug_execution.py — Debug a single shifted trajectory with z-offset for better grasps.

Runs the left-middle trajectory (dx=-0.15, dy=0.0) with the gripper lowered
by --z_offset during the grasp approach.

Usage:
    python ood_libero/debug_execution.py [--z_offset -0.015]
"""

import argparse
import os
import sys
from pathlib import Path

import cv2
import h5py
import numpy as np
from tqdm import tqdm

sys.path.insert(0, "/data/cameron/LIBERO")
os.environ.setdefault("LIBERO_DATA_PATH", "/data/libero")

from libero.libero import benchmark as bm_lib, get_libero_path
from libero.libero.envs import OffScreenRenderEnv

# ---- Same constants as replay_shifted_trajectory.py ----
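# The flattened MuJoCo state is [sim_time, qpos, qvel], so an object's qpos
# index i lives at state index i + STATE_QPOS_OFFSET.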
STATE_QPOS_OFFSET = 1

TASK0_OBJECTS = {
    "akita_black_bowl_1":             {"qpos_start": 9,  "role": "pick"},
    "akita_black_bowl_2":             {"qpos_start": 16, "role": "distractor"},
    "cookies_1":                      {"qpos_start": 23, "role": "distractor"},
    "glazed_rim_porcelain_ramekin_1": {"qpos_start": 30, "role": "distractor"},
    "plate_1":                        {"qpos_start": 37, "role": "place"},
}

FURNITURE_BODIES = ["wooden_cabinet_1_main", "flat_stove_1_main"]
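# Off-scene parking spots: SINK_POS drops furniture below the floor,
# DISTRACTOR_POS moves distractor objects far outside the camera frustum.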
SINK_POS = np.array([0.0, 0.0, -5.0])
DISTRACTOR_POS = np.array([10.0, 10.0, 0.9])

DISTRACTOR_DOFS = {
    "akita_black_bowl_2":            {"qpos": slice(16, 23), "dof": slice(15, 21)},
    "cookies_1":                     {"qpos": slice(23, 30), "dof": slice(21, 27)},
    "glazed_rim_porcelain_ramekin_1": {"qpos": slice(30, 37), "dof": slice(27, 33)},
}


def _si(qpos_start):
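    """Map a raw qpos index to its index in the flattened sim state (state[0] is sim time)."""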
    return qpos_start + STATE_QPOS_OFFSET


def move_distractors_in_state(state):
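    """Return a copy of the flattened state with every distractor parked at DISTRACTOR_POS."""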
    s = state.copy()
    for name, info in TASK0_OBJECTS.items():
        if info["role"] == "distractor":
            i = _si(info["qpos_start"])
            s[i:i + 3] = DISTRACTOR_POS
    return s


def shift_pick_place(state, dx, dy):
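    """Return a copy of the flattened state with the pick and place objects shifted by (dx, dy)."""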
    s = state.copy()
    for info in TASK0_OBJECTS.values():
        if info["role"] in ("pick", "place"):
            i = _si(info["qpos_start"])
            s[i] += dx
            s[i + 1] += dy
    return s


def hide_furniture(sim):
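    """Sink the furniture bodies below the floor so they are out of view and out of reach."""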
    for name in FURNITURE_BODIES:
        bid = sim.model.body_name2id(name)
        sim.model.body_pos[bid] = SINK_POS
    sim.forward()


def hide_distractors_visual(sim):
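    """Make every geom belonging to a distractor body fully transparent (alpha = 0)."""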
    distractor_bodies = set()
    for name, info in TASK0_OBJECTS.items():
        if info["role"] == "distractor":
            bid = sim.model.body_name2id(f"{name}_main")
            distractor_bodies.add(bid)
    for geom_id in range(sim.model.ngeom):
        if sim.model.geom_bodyid[geom_id] in distractor_bodies:
            sim.model.geom_rgba[geom_id][3] = 0.0


def freeze_distractors(sim):
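    """Re-pin distractors at the parking position and zero their velocities.

    Called after every env.step() so physics cannot pull them back into the scene.
    """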
    for dofs in DISTRACTOR_DOFS.values():
        sim.data.qpos[dofs["qpos"].start:dofs["qpos"].start + 3] = DISTRACTOR_POS
        sim.data.qvel[dofs["dof"]] = 0.0


def render_obs(obs, camera, image_size):
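    """Fetch the camera image from obs, flip it upright, and resize it to image_size."""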
    img_key = f"{camera}_image"
    rgb = np.asarray(obs[img_key]).copy()
    if rgb.dtype != np.uint8:  # some wrappers return float images in [0, 1]
        rgb = (np.clip(rgb, 0.0, 1.0) * 255).astype(np.uint8)
    rgb = np.ascontiguousarray(np.flipud(rgb))  # off-screen renders come out vertically flipped
    if rgb.shape[0] != image_size or rgb.shape[1] != image_size:
        rgb = cv2.resize(rgb, (image_size, image_size), interpolation=cv2.INTER_LINEAR)
    return rgb


def servo_to_position(env, target_pos, gripper_cmd, max_servo=50, threshold=0.003):
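    """Closed-loop Cartesian servo toward target_pos via OSC position deltas.

    Each iteration commands the clipped delta (normalized by the controller's
    ~0.05 m output scale) while holding gripper_cmd. Returns (obs, done), where
    done is True only if the env terminates early (task success).
    """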
    sim = env.env.sim
    obs = None
    for _ in range(max_servo):
        cur_obs = env.env._get_observations()
        cur_pos = np.array(cur_obs["robot0_eef_pos"], dtype=np.float64)
        delta = target_pos - cur_pos
        dist = np.linalg.norm(delta)
        if dist < threshold:
            obs = cur_obs
            break
        delta_clipped = np.clip(delta / 0.05, -1.0, 1.0)
        action = np.zeros(7, dtype=np.float32)
        action[:3] = delta_clipped
        action[6] = gripper_cmd
        obs, _, done, _ = env.step(action)
        freeze_distractors(sim)
        if done:
            return obs, True
    if obs is None:
        obs = env.env._get_observations()
    return obs, False


def extract_demo_eef_positions(env, states):
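    """Replay the recorded demo states through the sim to recover the EEF path in the world frame."""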
    sim = env.env.sim
    eef_positions = []
    for t in range(len(states)):
        env.set_init_state(states[t])
        sim.forward()
        obs = env.env._get_observations()
        eef_positions.append(np.array(obs["robot0_eef_pos"], dtype=np.float64))
    return np.array(eef_positions)


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--image_size", type=int, default=448)
    parser.add_argument("--benchmark", type=str, default="libero_spatial")
    parser.add_argument("--task_id", type=int, default=0)
    parser.add_argument("--demo_id", type=int, default=0)
    parser.add_argument("--camera", type=str, default="agentview")
    parser.add_argument("--frame_stride", type=int, default=3)
    parser.add_argument("--max_servo", type=int, default=50)
    parser.add_argument("--fps", type=int, default=10)
    parser.add_argument("--dx", type=float, default=-0.15,
                        help="X shift for this trajectory")
    parser.add_argument("--dy", type=float, default=0.0,
                        help="Y shift for this trajectory")
    parser.add_argument("--z_offset", type=float, default=-0.015,
                        help="Lower the EEF target by this amount (negative = lower). "
                             "Applied when gripper is closing (grasp phase).")
    parser.add_argument("--out_dir", type=str, default=None)
    args = parser.parse_args()

    script_dir = Path(__file__).resolve().parent
    out_dir = Path(args.out_dir) if args.out_dir else script_dir / "out"
    out_dir.mkdir(parents=True, exist_ok=True)

    # ---- Load demo ----
    bench = bm_lib.get_benchmark_dict()[args.benchmark]()
    task = bench.get_task(args.task_id)
    demo_path = os.path.join(get_libero_path("datasets"), bench.get_task_demonstration(args.task_id))
    bddl_file = os.path.join(get_libero_path("bddl_files"), task.problem_folder, task.bddl_file)

    with h5py.File(demo_path, "r") as f:
        demo_keys = sorted([k for k in f["data"].keys() if k.startswith("demo_")])
        demo_key = demo_keys[min(args.demo_id, len(demo_keys) - 1)]
        states = f[f"data/{demo_key}/states"][()]
        actions = f[f"data/{demo_key}/actions"][()]

    print(f"Task: {task.name}")
    print(f"Demo: {demo_key}, {len(states)} frames")
    print(f"Shift: dx={args.dx:+.3f}, dy={args.dy:+.3f}, z_offset={args.z_offset:+.4f}")

    # ---- Extract original EEF trajectory ----
    env = OffScreenRenderEnv(
        bddl_file_name=bddl_file,
        camera_heights=args.image_size,
        camera_widths=args.image_size,
        camera_names=[args.camera],
    )
    env.seed(0)
    env.reset()
    sim = env.env.sim

    print("Extracting original EEF trajectory...")
    eef_positions = extract_demo_eef_positions(env, states)

    # ---- Compute centering offset ----
    bowl_i = _si(TASK0_OBJECTS["akita_black_bowl_1"]["qpos_start"])
    bowl_orig_x = states[0][bowl_i]
    bowl_orig_y = states[0][bowl_i + 1]
    center_dx = -bowl_orig_x
    center_dy = -bowl_orig_y
    print(f"Bowl original: ({bowl_orig_x:.3f}, {bowl_orig_y:.3f}), centering: ({center_dx:+.3f}, {center_dy:+.3f})")

    # Total shift = centering + user shift
    total_dx = center_dx + args.dx
    total_dy = center_dy + args.dy
    print(f"Total shift: ({total_dx:+.3f}, {total_dy:+.3f})")

    # ---- Setup scene ----
    env.reset()
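    # reset() rebuilds the underlying MuJoCo sim, so re-fetch the handle
    # before patching the model.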
    sim = env.env.sim
    hide_furniture(sim)
    hide_distractors_visual(sim)

    state_0 = states[0].copy()
    state_0 = move_distractors_in_state(state_0)
    state_0 = shift_pick_place(state_0, total_dx, total_dy)

    env.set_init_state(state_0)
    sim.forward()
    freeze_distractors(sim)
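    # Reset episode bookkeeping and lift the horizon so the replay is not
    # truncated by the env's step limit.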
    env.env.timestep = 0
    env.env.done = False
    env.env.horizon = 100000

    # ---- Replay with z_offset during grasp ----
    frame_indices = list(range(0, len(eef_positions), args.frame_stride))
    print(f"\nReplaying {len(frame_indices)} waypoints...")

    frames = []
    gripper_cmd = -1.0  # gripper action convention: -1 = open, +1 = close

    # Render initial frame
    obs = env.env._get_observations()
    frames.append(render_obs(obs, args.camera, args.image_size))

    for t in tqdm(frame_indices, desc="Servo replay"):
        target_pos = eef_positions[t].copy()
        target_pos[0] += total_dx
        target_pos[1] += total_dy

        if t < len(actions):
            gripper_cmd = float(np.clip(actions[t, 6], -1.0, 1.0))

        # Apply z_offset while the gripper command is "close" (gripper_cmd > 0),
        # i.e. during the grasp approach and while carrying the object.
        if gripper_cmd > 0:
            target_pos[2] += args.z_offset

        obs, done = servo_to_position(env, target_pos, gripper_cmd, max_servo=args.max_servo)

        # Render with annotation
        frame = render_obs(obs, args.camera, args.image_size)
        cur_pos = np.array(obs["robot0_eef_pos"], dtype=np.float64)
        err = np.linalg.norm(cur_pos - target_pos)
        z_info = f" z_off={args.z_offset:+.3f}" if gripper_cmd > 0 else ""
        label = f"t={t} err={err*1000:.1f}mm grip={gripper_cmd:+.1f}{z_info}"
        cv2.putText(frame, label, (10, 22), cv2.FONT_HERSHEY_SIMPLEX, 0.45,
                    (255, 255, 255), 2, cv2.LINE_AA)
        cv2.putText(frame, label, (10, 22), cv2.FONT_HERSHEY_SIMPLEX, 0.45,
                    (20, 20, 20), 1, cv2.LINE_AA)

        # Mark done
        if done:
            cv2.putText(frame, "SUCCESS", (10, 44), cv2.FONT_HERSHEY_SIMPLEX, 0.6,
                        (0, 255, 0), 2, cv2.LINE_AA)

        frames.append(frame)

        if done:
            # Add a few extra frames showing the success state
            for _ in range(5):
                frames.append(frame.copy())
            break

    # ---- Save video ----
    video_path = out_dir / "debug_execution.mp4"
    h, w = frames[0].shape[:2]
    writer = cv2.VideoWriter(
        str(video_path),
        cv2.VideoWriter_fourcc(*"mp4v"),
        float(args.fps),
        (w, h),
    )
    for frame in frames:
        writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
    writer.release()
    print(f"\nSaved video ({len(frames)} frames): {video_path}")

    # Save keyframes (force a non-interactive backend in case we run headless)
    import matplotlib
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt
    key_idxs = [0, len(frames)//4, len(frames)//2, 3*len(frames)//4, len(frames)-1]
    fig, axes = plt.subplots(1, len(key_idxs), figsize=(4*len(key_idxs), 4))
    for i, fi in enumerate(key_idxs):
        axes[i].imshow(frames[fi])
        axes[i].set_title(f"Frame {fi}/{len(frames)-1}", fontsize=9)
        axes[i].axis("off")
    plt.suptitle(f"Debug: dx={args.dx:+.2f} dy={args.dy:+.2f} z_offset={args.z_offset:+.3f}", fontsize=12)
    plt.tight_layout()
    kf_path = out_dir / "debug_execution_keyframes.png"
    plt.savefig(str(kf_path), dpi=150, bbox_inches="tight")
    plt.close()
    print(f"Saved keyframes: {kf_path}")

    env.close()
    print("Done.")


if __name__ == "__main__":
    main()
