"""Replay action chunks in simulation.

Three modes:
  open_loop:   Execute absolute_chunks from NPZ directly (blindly send absolute poses).
  closed_loop: Execute relative_chunks, converting to absolute at each chunk boundary
               using the sim's current EE pose as the reference frame.
  commanded:   Execute relative_chunks, converting to absolute by composing each
               relative action with the *last commanded* absolute pose (dead-reckoning
               from commands with no sim feedback).

Expected NPZ keys:
  relative_chunks:   (num_chunks, chunk_len, action_dim)
  absolute_chunks:   (num_chunks, chunk_len, action_dim)
  chunk_at_step:     (num_chunks,)  -- observation timestep each chunk was generated at
  action_fields:     (num_fields,)  -- field names describing the action_dim layout
  field_dims:        (num_fields,)  -- per-field dimensionality [3,3,6,6,1,1] = 20
  achieved_ee_poses: (T, action_dim) -- recorded EE pose per timestep, in the same
                                        field layout; used as the reference in the
                                        sim-vs-recorded tracking plots
"""

import argparse
import dataclasses
from pathlib import Path
from typing import Literal

import tyro
import mediapy
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

import tqdm
import gymnasium as gym
import torch
import numpy as np

from scipy.spatial.transform import Rotation
from isaaclab.app import AppLauncher


# ---------------------------------------------------------------------------
# Rotation helpers
# ---------------------------------------------------------------------------

def rot6d_to_matrix(rot6d: np.ndarray) -> np.ndarray:
    """Convert the 6D rotation representation (Zhou et al. 2019) to a 3x3 matrix.

    Components 0:3 and 3:6 are treated as two vectors and orthonormalised with
    Gram-Schmidt; the third basis vector is their cross product. The three
    basis vectors are stacked as matrix *rows* (``axis=-2``). Leading batch
    dimensions are preserved: ``(..., 6)`` -> ``(..., 3, 3)``.
    """
    first, second = rot6d[..., 0:3], rot6d[..., 3:6]
    row_x = first / np.linalg.norm(first, axis=-1, keepdims=True)
    # Remove the component of `second` along `row_x`, then normalise.
    projection = (row_x * second).sum(axis=-1, keepdims=True)
    row_y = second - projection * row_x
    row_y = row_y / np.linalg.norm(row_y, axis=-1, keepdims=True)
    row_z = np.cross(row_x, row_y)
    return np.stack((row_x, row_y, row_z), axis=-2)


def rot_to_quat_wxyz(rot: np.ndarray) -> np.ndarray:
    """Convert rotation matrix/matrices to wxyz (scalar-first) quaternion(s).

    Accepts a single ``(3, 3)`` matrix, returning ``(4,)``, or a stack
    ``(N, 3, 3)``, returning ``(N, 4)``. SciPy's ``as_quat`` is scalar-last
    (xyzw), so the last component is rotated to the front.
    """
    q_xyzw = Rotation.from_matrix(rot).as_quat()
    # Roll along the last axis so batched input is handled too (xyzw -> wxyz).
    return np.roll(q_xyzw, 1, axis=-1)


def quat_wxyz_to_matrix(quat_wxyz: np.ndarray) -> np.ndarray:
    """Convert wxyz (scalar-first) quaternion(s) to rotation matrix/matrices.

    Accepts ``(4,)``, returning ``(3, 3)``, or ``(N, 4)``, returning
    ``(N, 3, 3)``. Sequences such as lists are accepted as well. SciPy expects
    scalar-last (xyzw) order, so the scalar is moved to the end.
    """
    q = np.asarray(quat_wxyz)
    return Rotation.from_quat(np.roll(q, -1, axis=-1)).as_matrix()


# ---------------------------------------------------------------------------
# Data loading
# ---------------------------------------------------------------------------

def load_chunks(npz_path: str) -> dict:
    """Load action chunks and recorded EE poses from an NPZ file.

    Args:
        npz_path: Path to an ``.npz`` archive containing the keys
            ``relative_chunks``, ``absolute_chunks``, ``chunk_at_step``,
            ``action_fields``, ``field_dims`` and ``achieved_ee_poses``.

    Returns:
        Dict with the same keys. ``action_fields`` is converted to a Python
        list and ``field_dims`` is cast to ``int``. All other entries are the
        stored arrays.

    Raises:
        KeyError: if a required key is missing from the archive.
    """
    # NOTE(review): allow_pickle=True lets np.load unpickle object arrays
    # (presumably needed if `action_fields` was saved as an object array).
    # Unpickling can execute arbitrary code, so only load trusted files.
    # The context manager closes the archive's file handle. Every array is
    # read eagerly inside the block, so nothing is accessed after closing.
    with np.load(npz_path, allow_pickle=True) as data:
        result = {
            "relative_chunks": data["relative_chunks"],      # (C, H, D)
            "absolute_chunks": data["absolute_chunks"],      # (C, H, D)
            "chunk_at_step": data["chunk_at_step"],          # (C,)
            "action_fields": list(data["action_fields"]),    # list of str
            "field_dims": data["field_dims"].astype(int),    # (F,)
            "achieved_ee_poses": data["achieved_ee_poses"],  # (T, D)
        }
    C, H, D = result["relative_chunks"].shape
    T_achieved = result["achieved_ee_poses"].shape[0]
    print(f"Loaded {C} chunks, horizon={H}, action_dim={D}")
    print(f"  Achieved EE poses: {T_achieved} steps")
    print(f"  Fields: {result['action_fields']}")
    print(f"  Dims:   {list(result['field_dims'])} (sum={result['field_dims'].sum()})")
    print(f"  chunk_at_step: {result['chunk_at_step']}")
    return result


def parse_fields(vec: np.ndarray, field_dims: np.ndarray) -> dict:
    """Slice a flat ``(D,)`` action vector into named per-arm components.

    The field order is fixed and must match ``field_dims = [3, 3, 6, 6, 1, 1]``:
    left_xyz, right_xyz, left_rot6d, right_rot6d, left_gripper, right_gripper.
    Position and rot6d entries are returned as array pieces. Gripper entries
    are returned as Python floats taken from the first element of their field.
    """
    pieces = np.split(vec, np.cumsum(field_dims)[:-1])
    vector_names = ("left_xyz", "right_xyz", "left_rot6d", "right_rot6d")
    fields = {name: pieces[i] for i, name in enumerate(vector_names)}
    fields["left_gripper"] = float(pieces[4][0])
    fields["right_gripper"] = float(pieces[5][0])
    return fields


# ---------------------------------------------------------------------------
# Action building
# ---------------------------------------------------------------------------

def build_open_loop_action(parsed: dict) -> dict[str, np.ndarray]:
    """Turn parsed *absolute* chunk fields into per-arm IK targets.

    No reference frame is involved: the chunk's xyz and rot6d are used as-is.
    Returns ``{"left": ..., "right": ...}``. Each value is the ``(8,)`` vector
    ``[x, y, z, qw, qx, qy, qz, gripper]``.
    """
    def _arm_target(side: str) -> np.ndarray:
        orientation = rot6d_to_matrix(parsed[f"{side}_rot6d"])
        return np.concatenate(
            [parsed[f"{side}_xyz"], rot_to_quat_wxyz(orientation), [parsed[f"{side}_gripper"]]]
        )

    return {side: _arm_target(side) for side in ("left", "right")}


def _compose_relative(
    parsed: dict,
    reference: dict[str, dict],
) -> dict[str, np.ndarray]:
    """Express relative chunk fields in the env frame by left-composing a reference pose.

    ``reference`` maps each side to ``{"pos": (3,), "rot": (3, 3)}`` in the env
    frame. In homogeneous-transform form the result is ``T_abs = T_ref @ T_rel``.
    The relative translation is therefore rotated into the reference frame
    before being offset:

        abs_pos = ref_rot @ rel_pos + ref_pos
        abs_rot = ref_rot @ rel_rot

    Returns ``{"left": ..., "right": ...}`` with ``(8,)`` vectors
    ``[x, y, z, qw, qx, qy, qz, gripper]``. The gripper value passes through
    unchanged.
    """
    composed = {}
    for side in ("left", "right"):
        frame = reference[side]
        rot_ref, pos_ref = frame["rot"], frame["pos"]
        position = rot_ref @ parsed[f"{side}_xyz"] + pos_ref
        orientation = rot_ref @ rot6d_to_matrix(parsed[f"{side}_rot6d"])
        composed[side] = np.concatenate(
            [position, rot_to_quat_wxyz(orientation), [parsed[f"{side}_gripper"]]]
        )
    return composed


def build_closed_loop_action(
    parsed: dict,
    sim_reference: dict[str, dict],
) -> dict[str, np.ndarray]:
    """Closed-loop variant: compose relative fields with the simulated EE pose.

    ``sim_reference`` holds the per-side ``{"pos", "rot"}`` EE pose read from
    the simulator at the chunk boundary. The composition itself is done by
    ``_compose_relative``.
    """
    absolute_targets = _compose_relative(parsed=parsed, reference=sim_reference)
    return absolute_targets


def build_commanded_action(
    parsed: dict,
    commanded_reference: dict[str, dict],
) -> dict[str, np.ndarray]:
    """Dead-reckoning variant: compose relative fields with the last commanded pose.

    ``commanded_reference`` holds the per-side ``{"pos", "rot"}`` pose that was
    most recently sent as a command. No simulator feedback is involved. The
    composition itself is done by ``_compose_relative``.
    """
    absolute_targets = _compose_relative(parsed=parsed, reference=commanded_reference)
    return absolute_targets


# ---------------------------------------------------------------------------
# Sim helpers
# ---------------------------------------------------------------------------

def get_ee_pose(robot, ee_idx: int, env_origin: np.ndarray):
    """Read one end-effector body pose for env 0 of ``robot``.

    Returns a tuple ``(pos_env, rot, quat_wxyz)``:
      * ``pos_env``: world position of body ``ee_idx`` minus ``env_origin``.
      * ``rot``: 3x3 rotation matrix built from the body quaternion.
      * ``quat_wxyz``: the body quaternion exactly as read (scalar-first).
    """
    body = robot.data
    world_pos = body.body_pos_w[0, ee_idx].cpu().numpy()
    quat_wxyz = body.body_quat_w[0, ee_idx].cpu().numpy()
    w, x, y, z = quat_wxyz
    # SciPy expects scalar-last (xyzw) quaternions.
    rot = Rotation.from_quat([x, y, z, w]).as_matrix()
    return world_pos - env_origin, rot, quat_wxyz


# ---------------------------------------------------------------------------
# Plotting
# ---------------------------------------------------------------------------

# Fixed y-axis ranges used by the tracking plots: position and position error
# in metres, geodesic orientation error in degrees.
POS_YLIM, POS_ERR_YLIM = (-1, 1), (-1, 1)
ROT_ERR_YLIM = (0, 180)


def plot_ee_tracking(
    recorded_pos: np.ndarray,
    recorded_rot: np.ndarray,
    sim_pos: np.ndarray,
    sim_quat_wxyz: np.ndarray,
    title: str,
    save_path: str,
):
    """Plot recorded (achieved) vs sim EE position and orientation error.

    Draws a 3-row figure: position overlay, per-axis position error, and
    geodesic orientation error. Prints a mean/max/final error summary, saves
    the figure to ``save_path`` and closes it.

    Args:
        recorded_pos: (N, 3) recorded EE positions.
        recorded_rot: (N, 3, 3) recorded EE rotation matrices.
        sim_pos: (N, 3) simulated EE positions.
        sim_quat_wxyz: (N, 4) simulated EE orientations as wxyz quaternions.
        title: Figure title, also used as the header of the printed summary.
        save_path: Output image path.

    If ``N == 0`` nothing is plotted or saved; a message is printed instead.
    """
    N = len(recorded_pos)
    if N == 0:
        # Nothing to compare (e.g. no steps were executed). The max/final
        # statistics below would fail on empty arrays.
        print(f"{title}: no samples to plot, skipping {save_path}")
        return

    t = np.arange(N)
    labels = ["X", "Y", "Z"]

    fig, axes = plt.subplots(3, 1, figsize=(14, 12), sharex=True)

    # Position comparison: dashed = recorded, solid = sim, one colour per axis.
    ax = axes[0]
    for i, lab in enumerate(labels):
        c = f"C{i}"
        ax.plot(t, recorded_pos[:, i], label=f"Recorded {lab}", linestyle="--", color=c)
        ax.plot(t, sim_pos[:, i], label=f"Sim {lab}", color=c)
    ax.set_ylabel("Position (m)")
    ax.set_title("EE Position: Recorded vs Sim")
    ax.legend(fontsize=8, ncol=3)
    ax.grid(True, alpha=0.3)
    ax.set_ylim(POS_YLIM)

    # Position error
    pos_err = sim_pos - recorded_pos
    ax = axes[1]
    for i, lab in enumerate(labels):
        ax.plot(t, pos_err[:, i], label=lab)
    ax.axhline(0, color="k", lw=0.5, ls="--")
    ax.set_ylabel("Error (m)")
    ax.set_title("Position Error (sim - recorded)")
    ax.legend(fontsize=8, ncol=3)
    ax.grid(True, alpha=0.3)
    ax.set_ylim(POS_ERR_YLIM)

    # Orientation error (geodesic): rotation angle of R_sim @ R_rec^T, from its trace.
    sim_rot = Rotation.from_quat(np.roll(sim_quat_wxyz, -1, axis=1)).as_matrix()  # wxyz -> xyzw
    R_err = np.einsum("nij,nkj->nik", sim_rot, recorded_rot)
    traces = np.trace(R_err, axis1=1, axis2=2)
    ang_err = np.degrees(np.arccos(np.clip((traces - 1) / 2, -1, 1)))
    ax = axes[2]
    ax.plot(t, ang_err, color="tab:red")
    ax.axhline(0, color="k", lw=0.5, ls="--")
    ax.set_ylabel("Error (deg)")
    ax.set_title("Orientation Error (geodesic)")
    ax.grid(True, alpha=0.3)
    ax.set_ylim(ROT_ERR_YLIM)

    # Summary stats
    pos_l2 = np.linalg.norm(pos_err, axis=1)
    print(f"\n{'=' * 60}")
    print(f"{title}:")
    print(f"  {'Metric':<25} {'Mean':>10} {'Max':>10} {'Final':>10}")
    print(f"  {'-' * 55}")
    for i, lab in enumerate(labels):
        ae = np.abs(pos_err[:, i])
        print(f"  {'Position ' + lab + ' (m)':<25} {ae.mean():>10.4f} {ae.max():>10.4f} {pos_err[-1, i]:>10.4f}")
    print(f"  {'Position L2 (m)':<25} {pos_l2.mean():>10.4f} {pos_l2.max():>10.4f} {pos_l2[-1]:>10.4f}")
    print(f"  {'Orientation (deg)':<25} {ang_err.mean():>10.4f} {ang_err.max():>10.4f} {ang_err[-1]:>10.4f}")
    print(f"{'=' * 60}")

    fig.suptitle(title, fontsize=14, fontweight="bold")
    fig.tight_layout()
    fig.savefig(save_path, dpi=150)
    # pyplot keeps every figure alive until it is closed. This function is
    # called once per arm, so release the figure explicitly.
    plt.close(fig)
    print(f"Saved {save_path}")


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------

def _find_body_index(robot, link_name: str) -> int:
    """Return the index of the first entry in ``robot.body_names`` containing ``link_name``.

    Raises:
        ValueError: if no body name contains ``link_name``. The message lists
            the available names so a renamed link is easy to diagnose.
    """
    for idx, name in enumerate(robot.body_names):
        if link_name in name:
            return idx
    raise ValueError(
        f"No body containing '{link_name}' found; "
        f"available bodies: {list(robot.body_names)}"
    )


def _split_achieved_poses(achieved_ee_poses: np.ndarray, field_dims: np.ndarray) -> dict:
    """Split (T, D) achieved EE poses into per-arm positions and rotation matrices.

    Uses the same field layout as the action vectors:
    left_xyz(3), right_xyz(3), left_rot6d(6), right_rot6d(6), left_grip(1), right_grip(1).
    Returns ``{side: {"pos": (T, 3), "rot": (T, 3, 3)}}``. Gripper fields are dropped.
    """
    parts = np.split(achieved_ee_poses, np.cumsum(field_dims)[:-1], axis=1)
    return {
        "left": {"pos": parts[0], "rot": rot6d_to_matrix(parts[2])},
        "right": {"pos": parts[1], "rot": rot6d_to_matrix(parts[3])},
    }


def _resolve_action_horizon(requested, chunk_at_step, num_chunks: int, chunk_len: int) -> int:
    """Choose how many steps of each chunk to execute before switching to the next.

    Uses ``requested`` when given. Otherwise uses the spacing between the first
    two chunk timesteps, or the full chunk length when there is a single chunk.

    Raises:
        ValueError: if the horizon is < 1. Every chunk except the last would
            then execute zero steps.
    """
    if requested is not None:
        horizon = int(requested)
    elif num_chunks > 1:
        # Chunks are assumed evenly spaced; infer the spacing from the first pair.
        horizon = int(chunk_at_step[1] - chunk_at_step[0])
    else:
        horizon = chunk_len
    if horizon < 1:
        raise ValueError(
            f"action_horizon must be >= 1, got {horizon} "
            "(set action_horizon explicitly or check that chunk_at_step is increasing)"
        )
    return horizon


def _camera_frame(obs) -> "np.ndarray | None":
    """Tile env 0's camera images from ``obs["vision"]`` side by side.

    Returns None when ``obs`` has no ``"vision"`` group or none of the
    expected cameras are present.
    """
    if "vision" not in obs:
        return None
    camera_names = (
        "external_camera_left",
        "external_camera_right",
        "wrist_camera_left",
        "wrist_camera_right",
    )
    vision = obs["vision"]
    images = [
        vision[cam][0].cpu().numpy().astype(np.uint8)
        for cam in camera_names
        if cam in vision
    ]
    return np.concatenate(images, axis=1) if images else None


def _replay(env, args: "Args") -> None:
    """Replay the NPZ chunks in ``env`` according to ``args.mode`` and save outputs.

    Writes ``replay_cameras.mp4`` (when camera frames were collected) and
    ``{left,right}_ee_tracking.png`` into ``args.run_folder``.
    """
    sides = ("left", "right")

    # Load chunks
    chunks = load_chunks(args.actions_npz)
    relative_chunks = chunks["relative_chunks"]  # (C, H, D)
    absolute_chunks = chunks["absolute_chunks"]  # (C, H, D)
    field_dims = chunks["field_dims"]
    num_chunks, chunk_len, _ = relative_chunks.shape
    if num_chunks == 0:
        raise ValueError(f"No action chunks found in {args.actions_npz}")

    # Recorded (achieved) EE trajectory: the reference for the tracking plots.
    achieved = _split_achieved_poses(chunks["achieved_ee_poses"], field_dims)

    action_horizon = _resolve_action_horizon(
        args.action_horizon, chunks["chunk_at_step"], num_chunks, chunk_len
    )
    print(f"Mode: {args.mode} | action_horizon: {action_horizon} "
          f"| chunks: {num_chunks} | chunk_len: {chunk_len}")

    env.reset()

    scene = env.unwrapped.scene
    robots = {"left": scene["left_panda"], "right": scene["right_panda"]}
    print(f"Action space: {env.action_space.shape}")

    # EE body indices (panda_link8)
    ee_body_idx = {}
    for side, robot in robots.items():
        ee_body_idx[side] = _find_body_index(robot, "panda_link8")
        print(f"[{side}] EE body: idx={ee_body_idx[side]} "
              f"({robot.body_names[ee_body_idx[side]]})")

    env_origin = scene.env_origins[0].cpu().numpy()

    def sim_ee_pose(side):
        """Current (pos_env, rot, quat_wxyz) of ``side``'s panda_link8 body."""
        return get_ee_pose(robots[side], ee_body_idx[side], env_origin)

    sim_ee_log = {side: {"pos": [], "quat": []} for side in sides}
    video_frames = []
    # Pose that relative actions are composed with (unused in open_loop).
    reference = None
    # commanded mode: last absolute pose sent to IK. It is rebuilt (never
    # mutated) every step, so a chunk's `reference` stays fixed while it runs.
    last_commanded = None

    steps_per_chunk = min(action_horizon, chunk_len)
    total_steps = (num_chunks - 1) * steps_per_chunk + chunk_len

    with tqdm.tqdm(total=total_steps, desc=f"Replaying ({args.mode})") as bar:
        episode_over = False
        for chunk_idx in range(num_chunks):
            # The last chunk is played to its end; earlier ones stop at the horizon.
            steps_this_chunk = chunk_len if chunk_idx == num_chunks - 1 else steps_per_chunk

            if args.mode == "closed_loop" or (args.mode == "commanded" and chunk_idx == 0):
                # Anchor on the sim's current EE pose: at every chunk boundary
                # for closed_loop, only once (to seed dead-reckoning) for commanded.
                reference = {}
                for side in sides:
                    pos, rot, _ = sim_ee_pose(side)
                    reference[side] = {"pos": pos, "rot": rot}
            elif args.mode == "commanded":
                # Dead-reckon from the last pose commanded in the previous chunk.
                reference = last_commanded

            for step_j in range(steps_this_chunk):
                # Log sim EE pose (panda_link8) before stepping
                for side in sides:
                    pos, _, quat = sim_ee_pose(side)
                    sim_ee_log[side]["pos"].append(pos)
                    sim_ee_log[side]["quat"].append(quat)

                if args.mode == "open_loop":
                    parsed = parse_fields(absolute_chunks[chunk_idx, step_j], field_dims)
                    arm_actions = build_open_loop_action(parsed)
                elif args.mode == "closed_loop":
                    parsed = parse_fields(relative_chunks[chunk_idx, step_j], field_dims)
                    arm_actions = build_closed_loop_action(parsed, reference)
                else:  # commanded
                    parsed = parse_fields(relative_chunks[chunk_idx, step_j], field_dims)
                    arm_actions = build_commanded_action(parsed, reference)
                    last_commanded = {
                        side: {"pos": a[:3].copy(), "rot": quat_wxyz_to_matrix(a[3:7])}
                        for side, a in arm_actions.items()
                    }

                # Full action: [left pose(7) + left grip(1)] (+ [right pose(7) + right grip(1)])
                if args.left_only:
                    action = arm_actions["left"].astype(np.float32)
                else:
                    action = np.concatenate(
                        [arm_actions["left"], arm_actions["right"]]
                    ).astype(np.float32)
                action_tensor = torch.tensor(action, device="cuda").unsqueeze(0)

                obs, _, term, trunc, _ = env.step(action_tensor)

                frame = _camera_frame(obs)
                if frame is not None:
                    video_frames.append(frame)

                bar.update(1)
                if term or trunc:
                    print(f"Episode ended at chunk {chunk_idx}, step {step_j} "
                          f"(term={term}, trunc={trunc})")
                    episode_over = True
                    break

            if episode_over:
                break

    # Compare only the overlapping span of the sim log and the recording.
    T_plot = min(len(sim_ee_log["left"]["pos"]), len(achieved["left"]["pos"]))
    out_dir = Path(args.run_folder)
    out_dir.mkdir(parents=True, exist_ok=True)

    if video_frames:
        video_path = str(out_dir / "replay_cameras.mp4")
        mediapy.write_video(video_path, video_frames, fps=10)
        print(f"Saved video ({len(video_frames)} frames) -> {video_path}")

    # Plot sim vs recorded achieved EE for each arm
    for side, label in [("left", "Left Panda"), ("right", "Right Panda")]:
        plot_ee_tracking(
            achieved[side]["pos"][:T_plot],
            achieved[side]["rot"][:T_plot],
            np.array(sim_ee_log[side]["pos"])[:T_plot],
            np.array(sim_ee_log[side]["quat"])[:T_plot],
            f"{label} EE Tracking ({args.mode})",
            str(out_dir / f"{side}_ee_tracking.png"),
        )


def main(args: "Args"):
    """Launch Isaac Sim, create the environment and replay the NPZ action chunks.

    The environment and the simulation app are closed even when the replay
    raises, so a failed run does not leave the app running.
    """
    # >>>> Isaac Sim App Launcher <<<<
    parser = argparse.ArgumentParser()
    args_cli, _ = parser.parse_known_args()
    args_cli.enable_cameras = True
    args_cli.headless = args.headless
    app_launcher = AppLauncher(args_cli)
    simulation_app = app_launcher.app
    # >>>> Isaac Sim App Launcher <<<<

    env = None
    try:
        # Kept below the launcher as in the original layout (these modules
        # presumably require the launched app).
        from isaaclab_tasks.utils import parse_env_cfg  # noqa: E402
        import sim_improvement.environments  # noqa: E402, F401

        env_cfg = parse_env_cfg(
            args.environment,
            device="cuda",
            num_envs=1,
            use_fabric=True,
        )
        env = gym.make(args.environment, cfg=env_cfg)
        _replay(env, args)
        print("Replay complete!")
    except Exception:
        # NOTE(review): the traceback is printed here because the `finally`
        # cleanup runs before Python reports the exception, and
        # simulation_app.close() may not return (verify for the Isaac Sim
        # version in use). The traceback may therefore appear twice.
        import traceback
        traceback.print_exc()
        raise
    finally:
        try:
            if env is not None:
                env.close()
        finally:
            simulation_app.close()


@dataclasses.dataclass
class Args:
    environment: str
    """IsaacLab environment ID."""

    actions_npz: str
    """Path to NPZ file with action chunks."""

    mode: str = "open_loop"
    """Replay mode: 'open_loop' (absolute chunks), 'closed_loop' (relative chunks
    re-referenced to sim EE at chunk boundaries), or 'commanded' (relative chunks
    chained from the last commanded pose)."""

    headless: bool = True

    action_horizon: int | None = None
    """Steps to execute per chunk before switching to the next. Defaults to the
    inter-chunk spacing inferred from chunk_at_step."""

    left_only: bool = False
    """Only send left arm + gripper actions."""

    run_folder: str = "runs/test"


if __name__ == "__main__":
    args: Args = tyro.cli(Args)
    # Validate explicitly instead of with `assert`, which `python -O` strips.
    if args.mode not in ("open_loop", "closed_loop", "commanded"):
        raise SystemExit(
            f"Invalid mode '{args.mode}', must be 'open_loop', 'closed_loop', or 'commanded'"
        )
    main(args)
