"""Generate point-track dataset with circle overlay at EEF position.

ALL robot geoms hidden. A colored circle is drawn at the projected EEF pixel.
This makes the data look like generic point tracking — no robot visible at all.
Tests cross-embodiment pretraining: can PARA learn from non-robot visual data?

Usage:
    CUDA_VISIBLE_DEVICES=8 python generate_circle_overlay.py --start_idx 0 --end_idx 128
    CUDA_VISIBLE_DEVICES=9 python generate_circle_overlay.py --start_idx 128 --end_idx 256
"""
import argparse
import os
import sys
from pathlib import Path

import cv2
import h5py
import numpy as np
from tqdm import tqdm

# Make the local LIBERO checkout importable and give LIBERO a default data root
# (an already-set LIBERO_DATA_PATH is left untouched). Both lines must run before
# the libero imports below; the env var is presumably read by libero at import
# time — NOTE(review): confirm.
sys.path.insert(0, "/data/cameron/LIBERO")
os.environ.setdefault("LIBERO_DATA_PATH", "/data/libero")

from libero.libero import benchmark as bm_lib, get_libero_path
from libero.libero.envs import OffScreenRenderEnv
from robosuite.utils.camera_utils import (
    get_camera_extrinsic_matrix,
    get_camera_intrinsic_matrix,
    get_camera_transform_matrix,
    project_points_from_world_to_camera,
)

from generate_ood_objpos import (
    TASK0_OBJECTS, FURNITURE_BODIES, DISTRACTOR_POS, DISTRACTOR_DOFS,
    AGENT_CAM, _si, move_distractors_in_state, hide_furniture,
    hide_distractors_visual, freeze_distractors, shift_pick_place,
    servo_to_position, find_grasp_timestep, interpolate_waypoints,
    extract_demo_eef_positions,
)

# Circle overlay style. The radius is in pixels (it is passed straight to cv2.circle).
# Colors are in RGB channel order because they are drawn onto the RGB frames returned
# by the env; save_demo converts frames to BGR only when writing the PNGs.
CIRCLE_RADIUS = 15
CIRCLE_COLOR_RGB = (255, 100, 50)  # bright orange fill
CIRCLE_OUTLINE_RGB = (200, 60, 20)  # darker outline, drawn 2 px thick


def hide_all_robot(sim, name_markers=("robot0", "gripper0")):
    """Hide ALL robot geoms — arm, gripper, everything — by zeroing their alpha.

    A body counts as part of the robot when its name contains any of
    ``name_markers``. Every geom attached to such a body gets ``rgba[3] = 0``,
    then ``sim.forward()`` is called.

    Args:
        sim: MuJoCo sim wrapper exposing ``model.nbody``, ``model.body_id2name``,
            ``model.ngeom``, ``model.geom_bodyid``, ``model.geom_rgba`` and ``forward()``.
        name_markers: substrings identifying robot bodies. The default matches
            the original behaviour (``robot0`` arm links and ``gripper0`` bodies).

    Returns:
        set[int]: ids of the bodies treated as robot. Callers reuse this set to
        re-hide the same geoms after a state reset.
    """
    robot_bids = set()
    for body_id in range(sim.model.nbody):
        # Unnamed bodies come back as None; treat them as "" so the substring
        # test below does not raise TypeError.
        body_name = sim.model.body_id2name(body_id) or ""
        if any(marker in body_name for marker in name_markers):
            robot_bids.add(body_id)

    for geom_id in range(sim.model.ngeom):
        if sim.model.geom_bodyid[geom_id] in robot_bids:
            sim.model.geom_rgba[geom_id][3] = 0.0
    sim.forward()
    return robot_bids


def _project_world_point(point, world_to_cam):
    """Project one 3-D world point with a 4x4 world->pixel matrix, without rounding or clipping.

    Returns ``(col, row)`` as floats, or ``None`` when the homogeneous depth is
    not positive (the point is not in front of the camera, so the projection is
    meaningless).
    """
    hom = np.append(np.asarray(point, dtype=np.float64).reshape(3), 1.0)
    cam = np.asarray(world_to_cam, dtype=np.float64) @ hom
    depth = cam[2]
    if depth <= 0:
        return None
    return float(cam[0] / depth), float(cam[1] / depth)


def draw_circle_at_eef(img_rgb, eef_pos, world_to_cam, H, W,
                       radius=None, color=None, outline_color=None):
    """Draw a filled, outlined circle at the projected EEF position.

    Args:
        img_rgb: H x W x 3 image, drawn on in place (cv2 requires a writable,
            contiguous array).
        eef_pos: world-frame xyz of the end effector.
        world_to_cam: 4x4 world->pixel matrix, e.g. from
            ``get_camera_transform_matrix``.
        H, W: image height and width in pixels.
        radius, color, outline_color: optional overrides. When omitted, the
            module-level CIRCLE_RADIUS / CIRCLE_COLOR_RGB / CIRCLE_OUTLINE_RGB
            are used.

    Returns:
        ``img_rgb`` itself. It is left unchanged when the EEF projects outside
        the image or lies behind the camera.
    """
    # robosuite's project_points_from_world_to_camera rounds AND clips its output
    # to the image (robosuite.utils.camera_utils). An in-bounds test on its result
    # can therefore never fail, and an off-screen EEF would be painted on the
    # border. Project manually so off-screen points are really skipped.
    uv = _project_world_point(eef_pos, world_to_cam)
    if uv is None:
        return img_rgb
    # np.round is the same rounding robosuite applies, so on-screen pixels match it.
    col, row = int(np.round(uv[0])), int(np.round(uv[1]))

    if 0 <= col < W and 0 <= row < H:
        radius = CIRCLE_RADIUS if radius is None else radius
        color = CIRCLE_COLOR_RGB if color is None else color
        outline_color = CIRCLE_OUTLINE_RGB if outline_color is None else outline_color
        cv2.circle(img_rgb, (col, row), radius, color, -1)  # filled disc
        cv2.circle(img_rgb, (col, row), radius, outline_color, 2)  # 2 px ring
    return img_rgb


def generate_trajectory_circle(env, states, actions, eef_orig, dx, dy, center_dx, center_dy,
                                robot_bids, frame_stride=3, z_offset=-0.015, max_servo=25,
                                image_size=448, pregrasp_lead=6, interp_steps=8):
    """Replay the demo's pick-and-place shifted in xy, with the robot hidden and a circle at the EEF.

    The scene is reset from ``states[0]`` with distractors moved and the
    pick/place objects shifted by ``(center_dx + dx, center_dy + dy)``. The EEF
    is then servoed in two phases:

    * Phase 1: from its home position to the shifted pre-grasp point, through
      ``interpolate_waypoints``.
    * Phase 2: through every ``frame_stride``-th shifted demo EEF position,
      starting from the pre-grasp timestep.

    After each servo target, one agent-view frame is recorded with a circle drawn
    at the projected EEF.

    Args:
        env: LIBERO ``OffScreenRenderEnv`` rendering ``AGENT_CAM``.
        states: demo simulator states; only ``states[0]`` is used for the reset.
        actions: demo actions. Column 6 supplies the gripper command, and
            ``find_grasp_timestep`` locates the grasp.
        eef_orig: per-timestep demo EEF positions (xyz), indexed by timestep.
        dx, dy: grid offset, applied on top of ``center_dx`` / ``center_dy``.
        center_dx, center_dy: base xy offset (main() passes the negated initial
            bowl xy).
        robot_bids: body ids whose geoms must stay transparent (from ``hide_all_robot``).
        frame_stride: step between replayed demo timesteps in Phase 2.
        z_offset: added to the target z whenever the gripper command is positive.
        max_servo: forwarded to ``servo_to_position``.
        image_size: render size; frames are ``image_size`` x ``image_size``.
        pregrasp_lead: Phase 2 starts this many timesteps before the grasp.
        interp_steps: forwarded to ``interpolate_waypoints`` for Phase 1.

    Returns:
        dict with:

        * per-frame data: ``frames_agent`` (list of RGB frames), ``eef_pos``,
          ``eef_quat``, ``gripper`` and ``pix_uv`` (col, row);
        * camera parameters: ``cam_extrinsic``, ``cam_K_norm`` and ``world_to_cam``;
        * the constant ``base_z``;
        * ``success``: True if the env reported done or ``_check_success()`` at
          any Phase-2 step;
        * the grid offsets ``dx`` and ``dy``.
    """
    sim = env.env.sim
    H = W = image_size
    total_dx = center_dx + dx
    total_dy = center_dy + dy

    env.env.timestep = 0
    env.env.done = False

    state_0 = states[0].copy()
    state_0 = move_distractors_in_state(state_0)
    state_0 = shift_pick_place(state_0, total_dx, total_dy)

    env.set_init_state(state_0)
    sim.forward()

    # Re-hide robot after state reset
    for geom_id in range(sim.model.ngeom):
        if sim.model.geom_bodyid[geom_id] in robot_bids:
            sim.model.geom_rgba[geom_id][3] = 0.0
    hide_furniture(sim)
    sim.forward()

    # A few zero-action steps, re-freezing distractors after each step.
    for _ in range(5):
        env.step(np.zeros(7, dtype=np.float32))
        freeze_distractors(sim)

    t_grasp = find_grasp_timestep(actions)
    t_pregrasp = max(0, t_grasp - pregrasp_lead)

    obs = env.env._get_observations()
    home_pos = np.array(obs["robot0_eef_pos"], dtype=np.float64)

    pregrasp_target = eef_orig[t_pregrasp].copy()
    pregrasp_target[0] += total_dx
    pregrasp_target[1] += total_dy

    gripper_cmd = -1.0
    recorded_frames = []
    recorded_eef_pos = []
    recorded_eef_quat = []
    recorded_gripper = []

    # Get camera matrix (static)
    w2c = get_camera_transform_matrix(sim, AGENT_CAM, H, W)

    def record(obs):
        # Reads the enclosing gripper_cmd at call time, so each frame is logged
        # with the command that was most recently sent.
        eef_pos = np.array(obs["robot0_eef_pos"], dtype=np.float32)
        eef_quat = np.array(obs["robot0_eef_quat"], dtype=np.float32)
        recorded_eef_pos.append(eef_pos)
        recorded_eef_quat.append(eef_quat)
        recorded_gripper.append(gripper_cmd)
        # Render arm-deleted frame, then draw circle. The raw frame is flipped
        # vertically before drawing.
        img = np.flipud(obs[f"{AGENT_CAM}_image"]).copy()
        img = draw_circle_at_eef(img, eef_pos.astype(np.float64), w2c, H, W)
        recorded_frames.append(img)

    record(obs)

    # Phase 1: Interpolate home → pre-grasp
    for wp in interpolate_waypoints(home_pos, pregrasp_target, interp_steps):
        obs = servo_to_position(env, wp, -1.0, max_servo=max_servo)
        record(obs)

    # Phase 2: Execute shifted trajectory
    phase2_indices = list(range(t_pregrasp, len(eef_orig), frame_stride))
    success = False
    for t in phase2_indices:
        target_pos = eef_orig[t].copy()
        target_pos[0] += total_dx
        target_pos[1] += total_dy
        if t < len(actions):
            gripper_cmd = float(np.clip(actions[t, 6], -1.0, 1.0))
        if gripper_cmd > 0 and z_offset != 0:
            target_pos[2] += z_offset
        obs = servo_to_position(env, target_pos, gripper_cmd, max_servo=max_servo)
        record(obs)
        if env.env.done or (hasattr(env.env, '_check_success') and env.env._check_success()):
            success = True

    # Camera params
    agent_ext = get_camera_extrinsic_matrix(sim, AGENT_CAM).astype(np.float32)
    agent_K = get_camera_intrinsic_matrix(sim, AGENT_CAM, H, W).astype(np.float32)
    agent_K_norm = agent_K.copy()
    agent_K_norm[0] /= W  # fx, skew, cx -> fractions of image width
    agent_K_norm[1] /= H  # fy, cy -> fractions of image height

    eef_arr = np.stack(recorded_eef_pos)
    # One batched projection (the helper accepts [..., 3] points and returns
    # [..., 2] (row, col)). Reverse the last axis to store (u, v) = (col, row).
    # NOTE(review): robosuite's helper rounds to int and clips to the image, so
    # an off-screen EEF is stored at the border — confirm downstream expects this.
    pix_rc = project_points_from_world_to_camera(eef_arr.astype(np.float64), w2c, H, W)
    pix_uv = np.ascontiguousarray(pix_rc[:, ::-1], dtype=np.float32)

    return {
        "frames_agent": recorded_frames,
        "eef_pos": eef_arr,
        "eef_quat": np.stack(recorded_eef_quat),
        "gripper": np.array(recorded_gripper, dtype=np.float32),
        "pix_uv": pix_uv,
        "cam_extrinsic": agent_ext,
        "cam_K_norm": agent_K_norm,
        "world_to_cam": w2c.astype(np.float32),
        "base_z": np.float32(0.912),
        "success": success,
        "dx": dx, "dy": dy,
    }


def save_demo(data, demo_dir):
    """Write one rollout dict (as returned by generate_trajectory_circle) to ``demo_dir``.

    Layout:

    * ``frames/{i:06d}.png``: one PNG per RGB frame (converted to BGR for cv2).
    * ``<key>.npy``: one file per array.
    * ``actions.npy``: an all-zero (T, 7) placeholder, since no actions are
      recorded for these rollouts.

    ``eef_pos.npy`` is written last and atomically. main() treats its presence
    as "this demo is complete" when resuming, so an interrupted write must never
    leave it behind next to missing files.

    Raises:
        OSError: if a frame PNG cannot be written.
    """
    demo_dir = Path(demo_dir)
    frames_dir = demo_dir / "frames"
    frames_dir.mkdir(parents=True, exist_ok=True)
    for i, frame in enumerate(data["frames_agent"]):
        frame_path = frames_dir / f"{i:06d}.png"
        # cv2.imwrite reports failure through its return value, not an exception.
        if not cv2.imwrite(str(frame_path), cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)):
            raise OSError(f"cv2.imwrite failed for {frame_path}")

    arrays = {
        "eef_quat": data["eef_quat"],
        "gripper": data["gripper"],
        "pix_uv": data["pix_uv"],
        "cam_extrinsic": data["cam_extrinsic"],
        "cam_K_norm": data["cam_K_norm"],
        "world_to_cam": data["world_to_cam"],
        "base_z": data["base_z"],
        "actions": np.zeros((len(data["eef_pos"]), 7), dtype=np.float32),
    }
    for key, value in arrays.items():
        np.save(demo_dir / f"{key}.npy", value)

    # Completion marker: write to a temp file, then rename into place.
    marker = demo_dir / "eef_pos.npy"
    tmp_marker = demo_dir / "eef_pos.npy.tmp"
    with open(tmp_marker, "wb") as fh:
        np.save(fh, data["eef_pos"])
    os.replace(tmp_marker, marker)


def main():
    """Command-line entry point: render a grid of shifted, robot-free task-0 rollouts.

    Demo index ``k`` maps to cell ``(k // grid_size, k % grid_size)`` of the
    ``dx_vals`` x ``dy_vals`` linspace grid. ``--start_idx`` / ``--end_idx``
    select a slice, so several processes can split the grid (see the module
    docstring). Demos whose ``eef_pos.npy`` already exists are skipped, so an
    interrupted run can be resumed.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--grid_size", type=int, default=16)
    parser.add_argument("--dx_min", type=float, default=-0.40)
    parser.add_argument("--dx_max", type=float, default=-0.01)
    parser.add_argument("--dy_min", type=float, default=-0.30)
    parser.add_argument("--dy_max", type=float, default=0.30)
    parser.add_argument("--image_size", type=int, default=448)
    parser.add_argument("--frame_stride", type=int, default=3)
    parser.add_argument("--z_offset", type=float, default=-0.015)
    parser.add_argument("--max_servo", type=int, default=25)
    parser.add_argument("--out_root", type=str, default="/data/libero/ood_objpos_circle_overlay")
    parser.add_argument("--start_idx", type=int, default=0)
    parser.add_argument("--end_idx", type=int, default=256)
    args = parser.parse_args()

    bench = bm_lib.get_benchmark_dict()["libero_spatial"]()
    task = bench.get_task(0)
    demo_path = os.path.join(get_libero_path("datasets"), bench.get_task_demonstration(0))
    bddl_file = os.path.join(get_libero_path("bddl_files"), task.problem_folder, task.bddl_file)

    # Only the first demo's states/actions are used as the template trajectory.
    with h5py.File(demo_path, "r") as f:
        demo_keys = sorted([k for k in f["data"].keys() if k.startswith("demo_")])
        states = f[f"data/{demo_keys[0]}/states"][()]
        actions = f[f"data/{demo_keys[0]}/actions"][()]

    def make_env():
        # Build a fresh env; the camera list is rebuilt so separate envs share nothing.
        return OffScreenRenderEnv(
            bddl_file_name=bddl_file, camera_heights=args.image_size,
            camera_widths=args.image_size, camera_names=[AGENT_CAM])

    print("Extracting EEF trajectory...")
    env_tmp = make_env()
    try:
        env_tmp.seed(0)
        env_tmp.reset()
        eef_orig = extract_demo_eef_positions(env_tmp, states)
    finally:
        env_tmp.close()

    # Base offset: the negated initial xy of the black bowl.
    bowl_i = _si(TASK0_OBJECTS["akita_black_bowl_1"]["qpos_start"])
    center_dx = -states[0][bowl_i]
    center_dy = -states[0][bowl_i + 1]

    N = args.grid_size
    dx_vals = np.linspace(args.dx_min, args.dx_max, N)
    dy_vals = np.linspace(args.dy_min, args.dy_max, N)
    # Only indices inside the N x N grid map to a cell. A negative index would
    # silently wrap around in dx_vals[i] and create bogus "demo_-k" directories.
    start_idx = max(args.start_idx, 0)
    end_idx = min(args.end_idx, N * N)

    env = make_env()
    try:
        env.seed(0)
        env.reset()
        # NOTE(review): presumably keeps long servo rollouts from hitting the
        # env's horizon and being flagged done — confirm.
        env.env.horizon = 100000
        sim = env.env.sim

        hide_furniture(sim)
        hide_distractors_visual(sim)
        robot_bids = hide_all_robot(sim)

        task_dir = Path(args.out_root) / "libero_spatial" / "task_0"
        task_dir.mkdir(parents=True, exist_ok=True)
        # Parallel workers all write this same file. Write to a per-process temp
        # file and rename it, so no reader ever sees a half-written archive.
        meta_path = task_dir / "grid_meta.npz"
        tmp_meta = task_dir / f".grid_meta.{os.getpid()}.npz.tmp"
        with open(tmp_meta, "wb") as fh:
            np.savez(fh, dx_vals=dx_vals, dy_vals=dy_vals,
                     center_dx=center_dx, center_dy=center_dy)
        os.replace(tmp_meta, meta_path)

        generated = 0
        successes = 0
        skipped = 0

        print(f"Generating demos {start_idx}-{end_idx} (circle overlay, no robot)")

        for demo_idx in tqdm(range(start_idx, end_idx), desc="Demos"):
            i, j = divmod(demo_idx, N)
            dx, dy = dx_vals[i], dy_vals[j]

            demo_dir = task_dir / f"demo_{demo_idx}"
            if (demo_dir / "eef_pos.npy").exists():
                skipped += 1
                continue

            data = generate_trajectory_circle(
                env, states, actions, eef_orig,
                dx, dy, center_dx, center_dy,
                robot_bids,
                frame_stride=args.frame_stride,
                z_offset=args.z_offset,
                max_servo=args.max_servo,
                image_size=args.image_size,
            )

            save_demo(data, demo_dir)
            generated += 1
            if data["success"]:
                successes += 1

            if generated % 10 == 0:
                print(f"  [{generated}] dx={dx:.3f} dy={dy:.3f} success={data['success']}")
    finally:
        env.close()

    # Success is only known for demos generated in this run, so keep skipped
    # (pre-existing) demos out of the ratio.
    print(f"\nDone. {successes}/{generated} newly generated demos succeeded "
          f"({skipped} already existed and were skipped).")


if __name__ == "__main__":
    main()
