"""Generate arm-deleted dataset with red gripper fingers.

Same as generate_ood_objpos.py but with:
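- Robot arm bodies (robot0_base, robot0_link0-7, robot0_right_hand) made fully transparent
- Gripper bodies (gripper0_right_gripper, gripper0_eef, fingers and fingertips) rendered in solid red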
- All other data (eef_pos, gripper, pix_uv, etc.) unchanged

Usage:
    # Split across GPUs:
    CUDA_VISIBLE_DEVICES=4 python generate_arm_deleted.py --start_idx 0 --end_idx 64
    CUDA_VISIBLE_DEVICES=5 python generate_arm_deleted.py --start_idx 64 --end_idx 128
    CUDA_VISIBLE_DEVICES=6 python generate_arm_deleted.py --start_idx 128 --end_idx 192
    CUDA_VISIBLE_DEVICES=7 python generate_arm_deleted.py --start_idx 192 --end_idx 256
"""
import argparse
import os
import sys
from pathlib import Path

import cv2
import h5py
import numpy as np
from tqdm import tqdm

sys.path.insert(0, "/data/cameron/LIBERO")
os.environ.setdefault("LIBERO_DATA_PATH", "/data/libero")

from libero.libero import benchmark as bm_lib, get_libero_path
from libero.libero.envs import OffScreenRenderEnv
from robosuite.utils.camera_utils import (
    get_camera_extrinsic_matrix,
    get_camera_intrinsic_matrix,
    get_camera_transform_matrix,
    project_points_from_world_to_camera,
)

from generate_ood_objpos import (
    TASK0_OBJECTS, FURNITURE_BODIES, DISTRACTOR_POS, DISTRACTOR_DOFS,
    AGENT_CAM, _si, move_distractors_in_state, hide_furniture,
    hide_distractors_visual, freeze_distractors, shift_pick_place,
    servo_to_position, find_grasp_timestep, interpolate_waypoints,
    extract_demo_eef_positions,
)

# Bodies whose geoms are painted solid red: the "gripper0_"-prefixed hand/eef
# bodies plus both fingers and their fingertip bodies.
GRIPPER_BODY_NAMES = {
    f"gripper0_{suffix}"
    for suffix in (
        "right_gripper", "eef",
        "leftfinger", "finger_joint1_tip",
        "rightfinger", "finger_joint2_tip",
    )
}

# Bodies whose geoms are made fully transparent: robot0_base, robot0_link0
# through robot0_link7, and robot0_right_hand.
ARM_BODY_NAMES = {
    "robot0_base",
    *(f"robot0_link{k}" for k in range(8)),
    "robot0_right_hand",
}


def apply_arm_deletion(sim, arm_names=None, gripper_names=None):
    """Hide the robot arm and paint the gripper red in ``sim.model``.

    Every geom attached to an arm body gets its alpha set to 0 (RGB left as
    is); every geom attached to a gripper body is set to opaque red
    ``[1, 0, 0, 1]``. Changes are written into ``sim.model.geom_rgba`` and
    ``sim.forward()`` is called afterwards.

    Args:
        sim: sim wrapper exposing ``model`` (``nbody``, ``ngeom``,
            ``body_id2name``, ``geom_bodyid``, ``geom_rgba``) and ``forward()``.
        arm_names: body names to hide; defaults to ``ARM_BODY_NAMES``.
        gripper_names: body names to paint red; defaults to
            ``GRIPPER_BODY_NAMES``.

    Returns:
        ``(arm_bids, gripper_bids)``: the sets of matched body ids, so the same
        recoloring can be re-applied later without another name lookup.

    Raises:
        RuntimeError: if no arm body or no gripper body matched. Without this
            check a naming mismatch would silently render the full robot.
    """
    # Resolved here (not as default values) so the module constants are only
    # looked up when the caller does not supply explicit name sets.
    if arm_names is None:
        arm_names = ARM_BODY_NAMES
    if gripper_names is None:
        gripper_names = GRIPPER_BODY_NAMES

    gripper_bids = set()
    arm_bids = set()
    for bid in range(sim.model.nbody):
        name = sim.model.body_id2name(bid)
        # Gripper is checked first, so a name present in both sets is treated
        # as gripper (painted red, not hidden).
        if name in gripper_names:
            gripper_bids.add(bid)
        elif name in arm_names:
            arm_bids.add(bid)

    missing = [label for label, bids in (("arm", arm_bids), ("gripper", gripper_bids))
               if not bids]
    if missing:
        raise RuntimeError(
            f"apply_arm_deletion: no {' or '.join(missing)} bodies matched in the model; "
            "check the body-name sets against the robot's body names")

    for geom_id in range(sim.model.ngeom):
        bid = sim.model.geom_bodyid[geom_id]
        if bid in arm_bids:
            sim.model.geom_rgba[geom_id][3] = 0.0  # fully transparent, RGB kept
        elif bid in gripper_bids:
            sim.model.geom_rgba[geom_id] = [1.0, 0.0, 0.0, 1.0]  # opaque red
    sim.forward()
    return arm_bids, gripper_bids


def generate_trajectory_arm_deleted(env, states, actions, eef_orig, dx, dy, center_dx, center_dy,
                                     arm_bids, gripper_bids,
                                     frame_stride=3, z_offset=-0.015, max_servo=25, image_size=448,
                                     pregrasp_lead=6, interp_steps=8):
    """Same as generate_trajectory but with arm-deleted rendering.

    The scene is reset to ``states[0]``, with distractors moved and the
    pick/place objects shifted by ``(center_dx + dx, center_dy + dy)``. The
    arm/gripper recoloring is re-applied and the sim is settled for 5 zero-action
    steps. The robot is then servoed:

    1. from its settled pose to the shifted pre-grasp point
       ``eef_orig[t_pregrasp]`` through ``interp_steps`` interpolated waypoints
       (gripper open, -1.0);
    2. along the shifted demo EEF path, every ``frame_stride``-th step from
       ``t_pregrasp``, using the demo's gripper action. While the gripper
       command is positive, ``z_offset`` is added to the target height.

    One observation is recorded after settling and after every waypoint.

    Args:
        env: LIBERO ``OffScreenRenderEnv`` (already reset, with visuals applied).
        states: flattened sim states of the source demo.
        actions: source demo actions; column 6 is read as the gripper command.
        eef_orig: source demo EEF positions, one per state.
        dx, dy: grid offset for this rollout.
        center_dx, center_dy: offset added on top of ``dx``/``dy``.
        arm_bids, gripper_bids: body ids returned by ``apply_arm_deletion``.
        frame_stride, z_offset, max_servo, image_size, pregrasp_lead,
        interp_steps: rollout / rendering knobs (see above).

    Returns:
        dict with per-frame ``frames_agent`` (RGB frames, flipped vertically),
        ``eef_pos``, ``eef_quat``, ``gripper`` (commanded value), ``pix_uv``
        ((u, v) = (column, row) pixel of the EEF in the agent view), the camera
        matrices ``cam_extrinsic``/``cam_K_norm``/``world_to_cam``, ``base_z``,
        ``success`` (task succeeded at any Phase-2 step) and ``dx``/``dy``.
    """
    sim = env.env.sim
    H = W = image_size
    total_dx = center_dx + dx
    total_dy = center_dy + dy

    env.env.timestep = 0
    env.env.done = False

    state_0 = states[0].copy()
    state_0 = move_distractors_in_state(state_0)
    state_0 = shift_pick_place(state_0, total_dx, total_dy)

    env.set_init_state(state_0)
    sim.forward()

    # Re-apply the visual changes after the state reset.
    for geom_id in range(sim.model.ngeom):
        bid = sim.model.geom_bodyid[geom_id]
        if bid in arm_bids:
            sim.model.geom_rgba[geom_id][3] = 0.0  # transparent arm
        elif bid in gripper_bids:
            sim.model.geom_rgba[geom_id] = [1.0, 0.0, 0.0, 1.0]  # red gripper
    hide_furniture(sim)
    sim.forward()

    # Settle
    for _ in range(5):
        env.step(np.zeros(7, dtype=np.float32))
        freeze_distractors(sim)

    t_grasp = find_grasp_timestep(actions)
    t_pregrasp = max(0, t_grasp - pregrasp_lead)

    obs = env.env._get_observations()
    home_pos = np.array(obs["robot0_eef_pos"], dtype=np.float64)

    pregrasp_target = eef_orig[t_pregrasp].copy()
    pregrasp_target[0] += total_dx
    pregrasp_target[1] += total_dy

    # Read by record() at call time, so every frame stores the command that
    # was active when it was captured.
    gripper_cmd = -1.0

    recorded_frames_agent = []
    recorded_eef_pos = []
    recorded_eef_quat = []
    recorded_gripper = []

    def record(obs):
        recorded_eef_pos.append(np.array(obs["robot0_eef_pos"], dtype=np.float32))
        recorded_eef_quat.append(np.array(obs["robot0_eef_quat"], dtype=np.float32))
        recorded_gripper.append(gripper_cmd)
        recorded_frames_agent.append(np.flipud(obs[f"{AGENT_CAM}_image"]).copy())

    record(obs)

    # Phase 1: Interpolate home → pre-grasp
    interp_wps = interpolate_waypoints(home_pos, pregrasp_target, interp_steps)
    for wp in interp_wps:
        obs = servo_to_position(env, wp, -1.0, max_servo=max_servo)
        record(obs)

    # Phase 2: Execute shifted trajectory
    phase2_indices = list(range(t_pregrasp, len(eef_orig), frame_stride))
    success = False
    for t in phase2_indices:
        target_pos = eef_orig[t].copy()
        target_pos[0] += total_dx
        target_pos[1] += total_dy

        if t < len(actions):
            gripper_cmd = float(np.clip(actions[t, 6], -1.0, 1.0))

        if gripper_cmd > 0 and z_offset != 0:
            target_pos[2] += z_offset

        obs = servo_to_position(env, target_pos, gripper_cmd, max_servo=max_servo)
        record(obs)

        # Sticky: once the task has succeeded the rollout counts as a success.
        if env.env.done or (hasattr(env.env, '_check_success') and env.env._check_success()):
            success = True

    # Camera params
    agent_ext = get_camera_extrinsic_matrix(sim, AGENT_CAM).astype(np.float32)
    agent_w2c = get_camera_transform_matrix(sim, AGENT_CAM, H, W).astype(np.float32)
    agent_K = get_camera_intrinsic_matrix(sim, AGENT_CAM, H, W).astype(np.float32)
    agent_K_norm = agent_K.copy()
    agent_K_norm[0] /= W
    agent_K_norm[1] /= H

    eef_arr = np.stack(recorded_eef_pos)
    # One batched projection for all frames (the robosuite helper accepts
    # leading batch dims). It returns (row, col) pixel indices; swap to (u, v).
    pix_rc = project_points_from_world_to_camera(eef_arr.astype(np.float64), agent_w2c, H, W)
    pix_uv = np.ascontiguousarray(pix_rc[:, ::-1], dtype=np.float32)

    return {
        "frames_agent": recorded_frames_agent,
        "eef_pos": eef_arr,
        "eef_quat": np.stack(recorded_eef_quat),
        "gripper": np.array(recorded_gripper, dtype=np.float32),
        "pix_uv": pix_uv,
        "cam_extrinsic": agent_ext,
        "cam_K_norm": agent_K_norm,
        "world_to_cam": agent_w2c,
        # NOTE(review): hard-coded robot base height; confirm it matches the scene.
        "base_z": np.float32(0.912),
        "success": success,
        "dx": dx, "dy": dy,
    }


def save_demo(data, demo_dir):
    """Write one generated demo to ``demo_dir``.

    Layout:
        * ``frames/NNNNNN.png``: agent-view frames, converted RGB->BGR for OpenCV;
        * ``eef_pos``, ``eef_quat``, ``gripper``, ``pix_uv``, ``cam_extrinsic``,
          ``cam_K_norm``, ``world_to_cam``, ``base_z``: one ``.npy`` each;
        * ``actions.npy``: all-zero float32 array of shape ``(T, 7)`` with
          ``T = len(data["eef_pos"])``.

    ``eef_pos.npy`` is the file ``main`` checks to decide a demo already exists.
    It is therefore written last and atomically (temp file + rename). An
    interrupted save never leaves a partial demo that a resumed run would skip.

    Raises:
        OSError: if OpenCV fails to write a frame. ``cv2.imwrite`` signals
            failure only through its return value.
    """
    demo_dir = Path(demo_dir)
    frames_dir = demo_dir / "frames"
    frames_dir.mkdir(parents=True, exist_ok=True)

    for i, frame in enumerate(data["frames_agent"]):
        frame_path = frames_dir / f"{i:06d}.png"
        if not cv2.imwrite(str(frame_path), cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)):
            raise OSError(f"cv2.imwrite failed for {frame_path}")

    for key in ("eef_quat", "gripper", "pix_uv", "cam_extrinsic",
                "cam_K_norm", "world_to_cam", "base_z"):
        np.save(demo_dir / f"{key}.npy", data[key])
    np.save(demo_dir / "actions.npy", np.zeros((len(data["eef_pos"]), 7), dtype=np.float32))

    # Completion marker: written last, via temp file + atomic rename.
    marker_tmp = demo_dir / "eef_pos.npy.tmp"
    with open(marker_tmp, "wb") as fh:
        np.save(fh, data["eef_pos"])
    marker_tmp.replace(demo_dir / "eef_pos.npy")


def main():
    """Generate the arm-deleted demos for one slice of the (dx, dy) grid.

    Loads the first demo of LIBERO-spatial task 0 and replays it once to get
    the original EEF path. For each grid index in ``[start_idx, end_idx)``
    (clamped to ``[0, grid_size**2)``), it rolls the demo out with the
    pick/place objects shifted by ``(dx_vals[i], dy_vals[j])`` plus a
    centering offset. Each rollout is saved under
    ``<out_root>/libero_spatial/task_0/demo_<idx>``. Demos whose
    ``eef_pos.npy`` already exists are skipped, so runs can be resumed and
    split across processes by index range.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--grid_size", type=int, default=16)
    parser.add_argument("--dx_min", type=float, default=-0.40)
    parser.add_argument("--dx_max", type=float, default=-0.01)
    parser.add_argument("--dy_min", type=float, default=-0.30)
    parser.add_argument("--dy_max", type=float, default=0.30)
    parser.add_argument("--image_size", type=int, default=448)
    parser.add_argument("--frame_stride", type=int, default=3)
    parser.add_argument("--z_offset", type=float, default=-0.015)
    parser.add_argument("--max_servo", type=int, default=25)
    parser.add_argument("--out_root", type=str, default="/data/libero/ood_objpos_arm_deleted")
    parser.add_argument("--start_idx", type=int, default=0)
    parser.add_argument("--end_idx", type=int, default=256)
    args = parser.parse_args()

    bench = bm_lib.get_benchmark_dict()["libero_spatial"]()
    task = bench.get_task(0)
    demo_path = os.path.join(get_libero_path("datasets"), bench.get_task_demonstration(0))
    bddl_file = os.path.join(get_libero_path("bddl_files"), task.problem_folder, task.bddl_file)

    with h5py.File(demo_path, "r") as f:
        demo_keys = sorted(k for k in f["data"].keys() if k.startswith("demo_"))
        states = f[f"data/{demo_keys[0]}/states"][()]
        actions = f[f"data/{demo_keys[0]}/actions"][()]

    print(f"Task: {task.name}")
    print("Extracting original EEF trajectory...")
    env_tmp = OffScreenRenderEnv(
        bddl_file_name=bddl_file, camera_heights=args.image_size, camera_widths=args.image_size,
        camera_names=[AGENT_CAM])
    try:
        env_tmp.seed(0)
        env_tmp.reset()
        eef_orig = extract_demo_eef_positions(env_tmp, states)
    finally:
        env_tmp.close()

    bowl_i = _si(TASK0_OBJECTS["akita_black_bowl_1"]["qpos_start"])
    center_dx = -states[0][bowl_i]
    center_dy = -states[0][bowl_i + 1]
    print(f"Centering offset: ({center_dx:+.3f}, {center_dy:+.3f})")

    N = args.grid_size
    dx_vals = np.linspace(args.dx_min, args.dx_max, N)
    dy_vals = np.linspace(args.dy_min, args.dy_max, N)

    task_dir = Path(args.out_root) / "libero_spatial" / "task_0"
    task_dir.mkdir(parents=True, exist_ok=True)

    # Save grid metadata. Several workers (one per GPU) write this same file
    # concurrently, so write a per-process temp file and atomically rename it:
    # readers never see a half-written archive.
    meta_tmp = task_dir / f".grid_meta.{os.getpid()}.npz.tmp"
    with open(meta_tmp, "wb") as fh:
        np.savez(fh, dx_vals=dx_vals, dy_vals=dy_vals,
                 center_dx=center_dx, center_dy=center_dy)
    os.replace(meta_tmp, task_dir / "grid_meta.npz")

    # Clamp to the N x N grid so every index maps to a valid (i, j) cell.
    start_idx = max(args.start_idx, 0)
    end_idx = min(args.end_idx, N * N)

    generated = 0
    skipped = 0
    successes = 0

    env = OffScreenRenderEnv(
        bddl_file_name=bddl_file, camera_heights=args.image_size, camera_widths=args.image_size,
        camera_names=[AGENT_CAM])
    try:
        env.seed(0)
        env.reset()
        env.env.horizon = 100000
        sim = env.env.sim

        # Apply visual changes
        hide_furniture(sim)
        hide_distractors_visual(sim)
        arm_bids, gripper_bids = apply_arm_deletion(sim)

        print(f"\nGenerating demos {start_idx} to {end_idx} (arm-deleted, red gripper)")

        for demo_idx in tqdm(range(start_idx, end_idx), desc="Demos"):
            i, j = divmod(demo_idx, N)
            dx, dy = dx_vals[i], dy_vals[j]

            demo_dir = task_dir / f"demo_{demo_idx}"
            if (demo_dir / "eef_pos.npy").exists():
                # Already generated by an earlier/other run; its success flag
                # is not on disk, so it is tracked separately from `generated`.
                skipped += 1
                continue

            data = generate_trajectory_arm_deleted(
                env, states, actions, eef_orig,
                dx, dy, center_dx, center_dy,
                arm_bids, gripper_bids,
                frame_stride=args.frame_stride,
                z_offset=args.z_offset,
                max_servo=args.max_servo,
                image_size=args.image_size,
            )

            save_demo(data, demo_dir)
            generated += 1
            if data["success"]:
                successes += 1

            if generated % 10 == 0:
                # tqdm.write keeps the progress bar intact.
                tqdm.write(f"  [{generated}] dx={dx:.3f} dy={dy:.3f} success={data['success']}")
    finally:
        env.close()

    print(f"\nDone. {successes}/{generated} newly generated demos succeeded "
          f"({skipped} existing demos skipped).")


# Script entry point: generation runs only when the file is executed directly,
# not when it is imported.
if __name__ == "__main__":
    main()
