"""object_removal_test.py — Render LIBERO task 0 with distractor removal and object position shifts.

For task 0 ("pick up the black bowl between the plate and the ramekin and place it on the plate"),
the essential objects are:
  - akita_black_bowl_1  (pick object)   — free joint at qpos[9:16]
  - plate_1             (place target)  — free joint at qpos[37:44]

Distractor objects (sunk underground to remove):
  - akita_black_bowl_2                  — free joint at qpos[16:23]
  - cookies_1                           — free joint at qpos[23:30]
  - glazed_rim_porcelain_ramekin_1      — free joint at qpos[30:37]

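Outputs (written to --out_dir; by default the "out" directory next to this script):
    ood_libero/out/object_removal.png       — with/without distractors (and furniture) comparison
    ood_libero/out/object_positions.png     — grid of shifted bowl+plate positions
    ood_libero/out/frame0_original.png      — first frame, unmodified scene
    ood_libero/out/frame0_clean.png         — first frame, distractors + furniture removed

Usage:
    python ood_libero/object_removal_test.py [--image_size 256] [--n_shifts 5] [--shift_range 0.1]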
"""

import argparse
import os
import sys
from pathlib import Path

import cv2
import h5py
import matplotlib.pyplot as plt
import numpy as np

sys.path.insert(0, "/data/cameron/LIBERO")
os.environ.setdefault("LIBERO_DATA_PATH", "/data/libero")

from libero.libero import benchmark as bm_lib, get_libero_path
from libero.libero.envs import OffScreenRenderEnv


# Layout of the task-0 scene objects inside the simulator state.
#
# Every movable object hangs off a free joint, which occupies seven consecutive
# qpos entries: [x, y, z, qw, qx, qy, qz].
#
# A flattened LIBERO state has one extra leading element in front of qpos, so
# an index into the state array is the qpos index plus STATE_QPOS_OFFSET:
#   state[0]    = extra scalar (e.g. time)
#   state[1:49] = qpos[0:48]
#   state[49:]  = qvel[0:43]
STATE_QPOS_OFFSET = 1

# Per-object metadata: where the free joint starts in qpos, and what the
# object is used for ("pick", "place", or "distractor").
TASK0_OBJECTS = {
    # qpos[9:16]
    "akita_black_bowl_1": {
        "qpos_start": 9,
        "role": "pick",
    },
    # qpos[16:23]
    "akita_black_bowl_2": {
        "qpos_start": 16,
        "role": "distractor",
    },
    # qpos[23:30]
    "cookies_1": {
        "qpos_start": 23,
        "role": "distractor",
    },
    # qpos[30:37]
    "glazed_rim_porcelain_ramekin_1": {
        "qpos_start": 30,
        "role": "distractor",
    },
    # qpos[37:44]
    "plate_1": {
        "qpos_start": 37,
        "role": "place",
    },
}

# Parking spot for anything that should disappear from view: far underground.
SINK_POS = np.array([0.0, 0.0, -5.0])

# Furniture attached as fixed bodies. These have no free joint, so they are
# relocated through sim.model.body_pos instead of through qpos.
FURNITURE_BODIES = [
    "wooden_cabinet_1_main",
    "flat_stove_1_main",
]


def hide_furniture(sim):
    """Sink every body listed in FURNITURE_BODIES to SINK_POS.

    The bodies are moved by overwriting their entry in ``sim.model.body_pos``;
    ``sim.forward()`` is called once afterwards.

    Returns:
        dict mapping body name -> (body id, copy of the position the body had
        before it was moved), suitable for passing to ``restore_furniture``.
    """
    model = sim.model
    saved = {}
    for body_name in FURNITURE_BODIES:
        body_id = model.body_name2id(body_name)
        # Copy before overwriting: body_pos[body_id] is about to be replaced.
        previous_pos = model.body_pos[body_id].copy()
        saved[body_name] = (body_id, previous_pos)
        model.body_pos[body_id] = SINK_POS
    sim.forward()
    return saved


def restore_furniture(sim, originals):
    """Write saved body positions back into the model and re-run forward.

    Args:
        sim: simulator exposing ``model.body_pos`` and ``forward()``.
        originals: dict of body name -> (body id, position), as returned by
            ``hide_furniture``. Only the values are used.
    """
    body_pos = sim.model.body_pos
    for body_id, saved_pos in originals.values():
        body_pos[body_id] = saved_pos
    sim.forward()


def render_state(env, sim, state, camera, image_size):
    """Load ``state`` into the environment, render ``camera``, return a uint8 RGB frame.

    Args:
        env: environment wrapper; must provide ``set_init_state(state)`` and
            ``env._get_observations()``.
        sim: simulator handle; ``forward()`` is called after the state is set.
        state: flattened simulator state to render.
        camera: camera name; the frame is read from the ``"<camera>_image"``
            observation key.
        image_size: unused here; kept so existing callers keep working.

    Returns:
        A C-contiguous, vertically flipped copy of the camera frame. Float
        frames are converted to uint8; integer frames are returned with their
        original dtype.
    """
    env.set_init_state(state)
    sim.forward()
    obs = env.env._get_observations()
    rgb = np.asarray(obs[f"{camera}_image"]).copy()

    # Decide on rescaling from the dtype, not just the value range: a nearly
    # black uint8 frame can legitimately have max() <= 1 and must not be
    # multiplied by 255.
    if np.issubdtype(rgb.dtype, np.floating):
        if rgb.size and rgb.max() <= 1.0:
            # Float frame normalised to [0, 1] -> stretch to [0, 255].
            rgb = rgb * 255
        # Clip before casting so out-of-range floats cannot wrap around.
        rgb = np.clip(rgb, 0, 255).astype(np.uint8)

    # The frame is flipped vertically before being returned.
    return np.ascontiguousarray(np.flipud(rgb))


def _state_idx(qpos_start):
    """Translate an index into qpos to the matching index in a flattened state.

    The flattened state stores STATE_QPOS_OFFSET leading entries before qpos,
    so the mapping is a constant shift.
    """
    state_index = STATE_QPOS_OFFSET + qpos_start
    return state_index


def remove_distractors(state):
    """Build a copy of ``state`` in which every distractor sits at SINK_POS.

    Only the x, y, z entries of each distractor's free joint are overwritten;
    the orientation entries and the input ``state`` are left untouched.
    """
    cleaned = state.copy()
    distractor_starts = [
        _state_idx(entry["qpos_start"])
        for entry in TASK0_OBJECTS.values()
        if entry["role"] == "distractor"
    ]
    for start in distractor_starts:
        # First three entries of a free joint are the position.
        cleaned[start:start + 3] = SINK_POS
    return cleaned


def shift_pick_place(state, dx, dy):
    """Build a copy of ``state`` with the pick and place objects translated in XY.

    Args:
        state: flattened simulator state; not modified.
        dx: amount added to the x entry of each pick/place object.
        dy: amount added to the y entry of each pick/place object.

    Returns:
        The shifted copy. Objects with any other role keep their position.
    """
    shifted = state.copy()
    for entry in TASK0_OBJECTS.values():
        if entry["role"] not in ("pick", "place"):
            continue
        x_idx = _state_idx(entry["qpos_start"])
        y_idx = x_idx + 1
        shifted[x_idx] += dx
        shifted[y_idx] += dy
    return shifted


def _demo_sort_key(key):
    """Sort key ordering ``demo_<N>`` names by integer N instead of lexicographically."""
    suffix = key[len("demo_"):]
    if suffix.isdigit():
        return (0, int(suffix), key)
    # Names without a purely numeric suffix go last, in alphabetical order.
    return (1, 0, key)


def main():
    """Render the distractor-removal comparison and the position-shift grid.

    Steps:
      1. Parse CLI arguments and resolve the output directory.
      2. Load the first state of the selected demo from the task's HDF5 file.
      3. Render the original scene and the scene with distractors and
         furniture removed; save a side-by-side figure and both raw frames.
      4. Move the pick object to x=0, y=0 (the place object gets the same
         offset), then render an n x n grid of XY shifts around that pose.

    The environment is always closed, and hidden furniture restored, even if
    rendering or saving fails part-way through.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--image_size", type=int, default=256)
    parser.add_argument("--benchmark", type=str, default="libero_spatial")
    parser.add_argument("--task_id", type=int, default=0)
    parser.add_argument("--demo_id", type=int, default=0)
    parser.add_argument("--camera", type=str, default="agentview")
    parser.add_argument("--n_shifts", type=int, default=5,
                        help="Grid size per axis for position shifts (n x n grid)")
    parser.add_argument("--shift_range", type=float, default=0.1,
                        help="Max XY shift in meters from original position")
    parser.add_argument("--out_dir", type=str, default=None)
    args = parser.parse_args()

    if args.n_shifts < 1:
        parser.error("--n_shifts must be >= 1")
    if args.benchmark != "libero_spatial" or args.task_id != 0:
        # TASK0_OBJECTS / FURNITURE_BODIES describe one specific scene.
        print("WARNING: object indices and furniture names are hard-coded for "
              "libero_spatial task 0; results for other tasks may be wrong.",
              file=sys.stderr)

    script_dir = Path(__file__).resolve().parent
    out_dir = Path(args.out_dir) if args.out_dir else script_dir / "out"
    out_dir.mkdir(parents=True, exist_ok=True)

    # Load env + demo
    bench = bm_lib.get_benchmark_dict()[args.benchmark]()
    task = bench.get_task(args.task_id)
    demo_path = os.path.join(get_libero_path("datasets"), bench.get_task_demonstration(args.task_id))
    bddl_file = os.path.join(get_libero_path("bddl_files"), task.problem_folder, task.bddl_file)

    with h5py.File(demo_path, "r") as f:
        # Numeric ordering: a plain sort would put "demo_10" before "demo_2",
        # so --demo_id would not select the demo with that number.
        demo_keys = sorted(
            (k for k in f["data"].keys() if k.startswith("demo_")),
            key=_demo_sort_key,
        )
        if not demo_keys:
            raise RuntimeError(f"No demo_* groups found in {demo_path}")
        # Out-of-range ids fall back to the last available demo.
        demo_key = demo_keys[min(args.demo_id, len(demo_keys) - 1)]
        states = f[f"data/{demo_key}/states"][()]

    env = OffScreenRenderEnv(
        bddl_file_name=bddl_file,
        camera_heights=args.image_size,
        camera_widths=args.image_size,
        camera_names=[args.camera],
    )
    furniture_orig = None
    try:
        env.seed(0)
        env.reset()
        sim = env.env.sim

        state_0 = states[0]

        # Render original
        print("Rendering original scene...")
        rgb_original = render_state(env, sim, state_0, args.camera, args.image_size)

        # Render with distractors + furniture removed
        state_clean = remove_distractors(state_0)
        furniture_orig = hide_furniture(sim)
        print("Rendering scene with distractors + furniture removed...")
        rgb_clean = render_state(env, sim, state_clean, args.camera, args.image_size)

        # Side-by-side plot
        fig, axes = plt.subplots(1, 2, figsize=(10, 5))
        axes[0].imshow(rgb_original)
        axes[0].set_title("Original (all objects)", fontsize=11)
        axes[0].axis("off")

        axes[1].imshow(rgb_clean)
        axes[1].set_title("Distractors + furniture removed\n(bowl + plate only)", fontsize=11)
        axes[1].axis("off")

        plt.suptitle(f"Task {args.task_id}: {task.name}", fontsize=10, y=0.02)
        plt.tight_layout()
        out_path = out_dir / "object_removal.png"
        plt.savefig(str(out_path), dpi=150, bbox_inches="tight")
        plt.close(fig)
        print(f"Saved: {out_path}")

        # Also save individual images (OpenCV expects BGR channel order).
        for fname, frame in (("frame0_original.png", rgb_original),
                             ("frame0_clean.png", rgb_clean)):
            frame_path = out_dir / fname
            # cv2.imwrite signals failure through its return value only.
            if not cv2.imwrite(str(frame_path), cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)):
                print(f"WARNING: failed to write {frame_path}", file=sys.stderr)

        # ----- Render grid of shifted object positions (distractors removed) -----
        # First, center the pick object (bowl) in the scene.
        # Read the bowl's original XY from the clean state, compute offset to table center (x=0, y=0).
        # y=0 lines up with the robot base, centering the bowl horizontally in agentview.
        bowl_i = _state_idx(TASK0_OBJECTS["akita_black_bowl_1"]["qpos_start"])
        bowl_orig_x = state_clean[bowl_i]
        bowl_orig_y = state_clean[bowl_i + 1]
        center_dx = -bowl_orig_x  # shift bowl to x=0
        center_dy = -bowl_orig_y  # shift bowl to y=0
        state_centered = shift_pick_place(state_clean, center_dx, center_dy)
        print(f"\nBowl original pos: ({bowl_orig_x:.3f}, {bowl_orig_y:.3f})")
        print(f"Centering offset:  ({center_dx:+.3f}, {center_dy:+.3f})")

        n = args.n_shifts
        r = args.shift_range
        offsets = np.linspace(-r, r, n)
        # dx varies along columns, dy along rows.
        dx_grid, dy_grid = np.meshgrid(offsets, offsets)

        print(f"Rendering {n}x{n} position-shift grid (range +/-{r}m from center)...")
        # squeeze=False keeps `axes` two-dimensional even for a 1x1 grid, so
        # axes[row, col] indexing below is always valid.
        fig, axes = plt.subplots(n, n, figsize=(3 * n, 3 * n), squeeze=False)
        for row in range(n):
            for col in range(n):
                dx, dy = dx_grid[row, col], dy_grid[row, col]
                state_shifted = shift_pick_place(state_centered, dx, dy)
                rgb = render_state(env, sim, state_shifted, args.camera, args.image_size)
                ax = axes[row, col]
                ax.imshow(rgb)
                is_center = abs(dx) < 1e-6 and abs(dy) < 1e-6
                color = "red" if is_center else "black"
                label = "dx=0, dy=0\n(centered)" if is_center else f"dx={dx:+.2f}, dy={dy:+.2f}"
                ax.set_title(label, fontsize=8, color=color)
                ax.axis("off")
                print(f"  [{row},{col}] dx={dx:+.3f} dy={dy:+.3f}", end="\r")
        # Terminate the carriage-return progress line before further output.
        print()

        plt.suptitle(f"Object position shifts ({n}x{n}, +/-{r}m)\n"
                     f"Pick object centered at scene origin, distractors removed",
                     fontsize=12)
        plt.tight_layout()
        out_path2 = out_dir / "object_positions.png"
        plt.savefig(str(out_path2), dpi=150, bbox_inches="tight")
        plt.close(fig)
        print(f"Saved: {out_path2}")
    finally:
        # Undo the model edit and release the renderer even on failure.
        try:
            if furniture_orig is not None:
                restore_furniture(sim, furniture_orig)
        finally:
            env.close()
    print("Done.")


# Run the rendering pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
