
import logging
import re
from collections.abc import Callable

import numpy as np
import torch

import isaacsim.core.utils.bounds as bounds_utils
from isaaclab.assets import Articulation, RigidObject, RigidObjectCollection
from isaaclab.envs import ManagerBasedEnv, ManagerBasedRLEnv
from isaaclab.managers import ManagerTermBase, SceneEntityCfg, TerminationTermCfg
from isaaclab.sensors import ContactSensor
from isaaclab.utils import math as math_utils

logger = logging.getLogger(__name__)


def abnormal_robot_state(env: ManagerBasedRLEnv, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")) -> torch.Tensor:
    """Flag environments in which some joint moves faster than twice its velocity limit.

    Args:
        env: The environment whose scene holds the articulation.
        asset_cfg: Scene entity to inspect (defaults to ``"robot"``).

    Returns:
        Tensor reduced with ``any`` over dim 1: ``True`` where at least one
        joint's ``|joint_vel|`` exceeds ``2 * joint_vel_limits``.
    """
    robot: Articulation = env.scene[asset_cfg.name]
    joint_speed = robot.data.joint_vel.abs()
    speed_bound = 2 * robot.data.joint_vel_limits
    return torch.any(joint_speed > speed_bound, dim=1)


def singularity_termination(
    env: ManagerBasedRLEnv,
    asset_cfg: SceneEntityCfg,
    threshold: float = 0.02,
) -> torch.Tensor:
    """Flag environments whose arm is near a kinematic singularity.

    The minimum singular value of the end-effector Jacobian is obtained from
    ``_min_singular_value`` in the rewards module. An environment is flagged
    when that value is below ``threshold``.

    Args:
        env: The environment whose scene holds the articulation.
        asset_cfg: Scene entity naming the arm articulation.
        threshold: Singular-value cutoff below which termination triggers.

    Returns:
        Boolean tensor, ``True`` where ``min_sv < threshold``.
    """
    # Imported lazily inside the function, presumably to avoid an import
    # cycle with the rewards module.
    from sim_improvement.environments.mdp.rewards import _min_singular_value

    return _min_singular_value(env.scene[asset_cfg.name]) < threshold


def success_termination(env: ManagerBasedRLEnv, context: str = "progress_context") -> torch.Tensor:
    """Report success when the insertive point is in the receptive OBB and the fingers are open.

    Both flags are read from the callable of the reward term named
    ``context``, fetched via ``env.reward_manager.get_term_cfg``. That term
    must expose ``insertive_point_in_receptive_obb`` and ``finger_open``
    (presumably per-env boolean tensors refreshed when the reward runs).

    Returns:
        ``insertive_point_in_receptive_obb & finger_open``.
    """
    progress_term = env.reward_manager.get_term_cfg(context).func  # type: ignore
    point_in_obb = progress_term.insertive_point_in_receptive_obb
    fingers_open = progress_term.finger_open
    return point_in_obb & fingers_open


# ---------------------------------------------------------------------------
# Standalone predicate functions
# ---------------------------------------------------------------------------

def on_top_predicate(
    env: ManagerBasedRLEnv,
    top_object_cfg: SceneEntityCfg,
    bottom_object_cfg: SceneEntityCfg,
    xy_threshold: float = 0.1,
    z_threshold: float = 0.5,
    velocity_bound: float = 0.1,
    require_support: bool = False,
    support_z_threshold: float = 0.15,
    require_no_fingers_touching: bool = False,
    gripper_cfgs: list[SceneEntityCfg] | None = None,
    gripper_open_threshold: float = -0.038,
    gripper_proximity_threshold: float = 0.1,
) -> torch.Tensor:
    """True when *top* object sits on *bottom* object.

    All checks use the objects' root positions and are evaluated per env:
        1. XY distance between root positions < ``xy_threshold``.
        2. Top object is above bottom object and within ``z_threshold``.
        3. Norm of the first three components of the top object's
           ``root_vel_w`` (its linear velocity) < ``velocity_bound``.
        4. (if ``require_support``) Z gap is small enough that the top
           object is resting on the bottom (< ``support_z_threshold``).
        5. (if ``require_no_fingers_touching`` and ``gripper_cfgs`` is
           non-empty) For every gripper, either the finger joint
           ``gcfg.joint_names[0]`` is below ``gripper_open_threshold``
           (open), or the nearest body whose name contains ``"finger"`` is
           farther than ``gripper_proximity_threshold`` from the top object's
           root. A gripper with no such bodies is judged on the joint check
           alone, since no proximity can be measured.

    Returns:
        ``(num_envs,)`` boolean tensor.
    """
    top_obj: RigidObject = env.scene[top_object_cfg.name]
    bot_obj: RigidObject = env.scene[bottom_object_cfg.name]

    top_pos = top_obj.data.root_pos_w
    bot_pos = bot_obj.data.root_pos_w

    xy_dist = torch.norm(top_pos[:, :2] - bot_pos[:, :2], dim=1)
    z_diff = top_pos[:, 2] - bot_pos[:, 2]
    vel = torch.norm(top_obj.data.root_vel_w[:, :3], dim=1)

    result = (xy_dist < xy_threshold) & (z_diff > 0) & (z_diff < z_threshold) & (vel < velocity_bound)

    # Approximate "support" as a tight z proximity check: the object is
    # resting on the surface, not hovering above it.
    if require_support:
        result = result & (z_diff < support_z_threshold)

    # Check that all grippers have released the object.
    # Joint convention assumed here: panda_finger_joint1 -0.04 = open, 0.0 = closed.
    # A closed gripper is OK if it's far from the object (not gripping it).
    if require_no_fingers_touching and gripper_cfgs:
        for gcfg in gripper_cfgs:
            robot: Articulation = env.scene[gcfg.name]
            jidx = robot.data.joint_names.index(gcfg.joint_names[0])
            finger_ok = robot.data.joint_pos[:, jidx] < gripper_open_threshold
            finger_body_idxs = [i for i, name in enumerate(robot.data.body_names) if "finger" in name]
            # Guard: min() over an empty body dimension would raise.
            if finger_body_idxs:
                finger_positions = robot.data.body_pos_w[:, finger_body_idxs]  # (num_envs, num_fingers, 3)
                dists = torch.norm(finger_positions - top_pos[:, None, :], dim=-1)  # (num_envs, num_fingers)
                gripper_far = dists.min(dim=1).values > gripper_proximity_threshold
                finger_ok = finger_ok | gripper_far
            result = result & finger_ok

    return result


def contact_flag_predicate(
    env: ManagerBasedRLEnv,
    sensor_cfg: SceneEntityCfg,
    bodies_in_contact_expected: bool = True,
    force_threshold: float = 1.0,
) -> torch.Tensor:
    """True when contact state between two bodies matches expectation.

    Queries a ``ContactSensor`` (set up by ``ScenarioSceneCfg.dynamic_setup``)
    for the pairwise normal contact force between body_1 and body_2.

    Args:
        sensor_cfg: SceneEntityCfg pointing to the ContactSensor
            (named ``contact__{body1}__{body2}``).
        bodies_in_contact_expected: Whether contact is the desired state.
        force_threshold: Minimum force magnitude (N) to count as contact.

    Returns:
        Boolean tensor: ``in_contact`` if contact is expected, else its negation.

    Raises:
        RuntimeError: if the sensor provides no ``force_matrix_w`` (i.e. it
            was not configured with filter prims).
    """
    sensor: ContactSensor = env.scene[sensor_cfg.name]
    # force_matrix_w: (num_envs, num_sensor_bodies, num_filter_bodies, 3)
    force_matrix = sensor.data.force_matrix_w
    if force_matrix is None:
        raise RuntimeError(
            f"Contact sensor '{sensor_cfg.name}' has no force_matrix_w; "
            "configure it with filter_prim_paths_expr to get pairwise contact forces."
        )
    # Collapse across bodies/filters -> scalar force per env
    contact_force = force_matrix.norm(dim=-1).sum(dim=(1, 2))
    in_contact = contact_force > force_threshold
    return in_contact if bodies_in_contact_expected else ~in_contact


def _resolve_frame_world_pos(
    env: ManagerBasedRLEnv,
    frame_spec: str,
    frame_offsets: dict[str, torch.Tensor],
) -> torch.Tensor:
    """Return the world position ``(num_envs, 3)`` of ``"object_name::frame_name"``.

    The frame's pre-computed local offset (``frame_offsets[frame_spec]``) is
    rotated by the object's root orientation and added to its root position.

    Raises:
        ValueError: if ``frame_spec`` does not split into exactly two parts on ``"::"``.
        KeyError: if the frame has no entry in ``frame_offsets``.
    """
    obj_name, _frame_name = frame_spec.split("::")
    obj: RigidObject = env.scene[obj_name]
    # After a two-part split, frame_spec is already the "obj::frame" key.
    local_offset = frame_offsets[frame_spec].to(device=env.device)  # (3,)
    tiled_offset = local_offset.unsqueeze(0).expand(env.num_envs, -1)
    rotated_offset = math_utils.quat_apply(obj.data.root_quat_w, tiled_offset)
    return obj.data.root_pos_w + rotated_offset


def _point_in_oriented_box(
    point: torch.Tensor,
    box_pos: torch.Tensor,
    box_quat: torch.Tensor,
    box_lo: torch.Tensor,
    box_hi: torch.Tensor,
) -> torch.Tensor:
    """Test, per env, whether ``point`` lies inside an oriented box (bounds inclusive).

    Args:
        point: ``(num_envs, 3)`` query positions.
        box_pos: ``(num_envs, 3)`` box origins.
        box_quat: ``(num_envs, 4)`` box orientations.
        box_lo: ``(3,)`` lower bounds in the box's local frame.
        box_hi: ``(3,)`` upper bounds in the box's local frame.

    Returns:
        ``(num_envs,)`` boolean tensor.
    """
    # Express the point in the box frame by undoing the box rotation.
    local = math_utils.quat_apply(math_utils.quat_inv(box_quat), point - box_pos)
    within = (local >= box_lo) & (local <= box_hi)
    return within.all(dim=1)


def mug_on_branch_predicate(
    env: ManagerBasedRLEnv,
    branch_frame: str,
    mug_handle_frames: list[str],
    mug_center_frame: str,
    branch_length: float,
    handle_length: float,
    frame_offsets: dict[str, torch.Tensor],
    branch_rotation: torch.Tensor,
) -> torch.Tensor:
    """True when the mug hangs on the single branch ``branch_frame``.

    Checks (mirroring anzu's MugOnBranchPredicate), in the branch frame:
    1. At least one mug handle frame lies in the box
       x in [-0.02, 0], y in [-handle_length/2, handle_length/2],
       z in [0, 0.9 * branch_length] (above/around the branch).
    2. The mug center lies in the box
       x in [0.01, 0.1], y in [-handle_length, handle_length],
       z in [0, 0.9 * branch_length] (below the branch).

    The branch frame's x-axis points perpendicular from the branch downward.
    Returns a ``(num_envs,)`` boolean tensor; with no handle frames it is all False.
    """
    device = env.device
    branch_pos = _resolve_frame_world_pos(env, branch_frame, frame_offsets)
    # NOTE(review): ``branch_rotation`` is used directly as the box's world
    # orientation; it is not composed with the holder's root_quat_w. This
    # assumes the holder sits in the orientation the stored rotation expects.
    # Confirm for holders that are rotated at spawn.
    branch_quat = branch_rotation.unsqueeze(0).expand(env.num_envs, -1).to(device=device)

    half_handle = handle_length / 2
    reach = branch_length * 0.9
    handle_lo = torch.tensor([-0.02, -half_handle, 0.0], device=device)
    handle_hi = torch.tensor([0.0, half_handle, reach], device=device)
    center_lo = torch.tensor([0.01, -handle_length, 0.0], device=device)
    center_hi = torch.tensor([0.1, handle_length, reach], device=device)

    any_handle_in_box = torch.zeros(env.num_envs, device=device, dtype=torch.bool)
    for handle_spec in mug_handle_frames:
        handle_world = _resolve_frame_world_pos(env, handle_spec, frame_offsets)
        any_handle_in_box = any_handle_in_box | _point_in_oriented_box(
            handle_world, branch_pos, branch_quat, handle_lo, handle_hi
        )

    center_world = _resolve_frame_world_pos(env, mug_center_frame, frame_offsets)
    center_in_box = _point_in_oriented_box(center_world, branch_pos, branch_quat, center_lo, center_hi)

    return any_handle_in_box & center_in_box


def mug_on_mug_holder_predicate(
    env: ManagerBasedRLEnv,
    branch_frames: list[str],
    mug_handle_frames: list[str],
    mug_center_frame: str,
    branch_length: float,
    handle_length: float,
    frame_offsets: dict[str, torch.Tensor],
    branch_rotations: list[torch.Tensor],
) -> torch.Tensor:
    """True when the mug hangs on ANY branch of the mug holder.

    ``branch_frames`` and ``branch_rotations`` are paired positionally (zip),
    and each pair is tested with ``mug_on_branch_predicate``. With no
    branches the result is all False.
    """
    hanging = torch.zeros(env.num_envs, device=env.device, dtype=torch.bool)
    for frame_spec, rotation in zip(branch_frames, branch_rotations):
        on_this_branch = mug_on_branch_predicate(
            env,
            branch_frame=frame_spec,
            mug_handle_frames=mug_handle_frames,
            mug_center_frame=mug_center_frame,
            branch_length=branch_length,
            handle_length=handle_length,
            frame_offsets=frame_offsets,
            branch_rotation=rotation,
        )
        hanging = hanging | on_this_branch
    return hanging


def frames_in_bounds_predicate(
    env: ManagerBasedRLEnv,
    object_cfgs: list[SceneEntityCfg],
    reference_cfg: SceneEntityCfg,
    bounds_lo: tuple[float, float, float] = (-1.0, -1.0, 0.0),
    bounds_hi: tuple[float, float, float] = (1.0, 1.0, 1.0),
) -> torch.Tensor:
    """True when every tracked object's root lies inside a box around the reference root.

    For each object, ``root_pos_w - reference.root_pos_w`` must satisfy
    ``bounds_lo <= rel <= bounds_hi`` on every axis (inclusive). With no
    objects the result is all True.

    NOTE(review): the offset is taken in world axes; the reference body's
    orientation is not applied, so the box is world-axis-aligned. Confirm this
    is intended if the bounds come from a frame-relative config.
    """
    reference: RigidObject = env.scene[reference_cfg.name]
    origin = reference.data.root_pos_w

    lower = torch.tensor(bounds_lo, device=env.device, dtype=torch.float32)
    upper = torch.tensor(bounds_hi, device=env.device, dtype=torch.float32)

    all_inside = torch.ones(env.num_envs, device=env.device, dtype=torch.bool)
    for obj_cfg in object_cfgs:
        tracked: RigidObject = env.scene[obj_cfg.name]
        offset = tracked.data.root_pos_w - origin
        all_inside = all_inside & ((offset >= lower) & (offset <= upper)).all(dim=1)
    return all_inside


# ---------------------------------------------------------------------------
# Composite success termination
# ---------------------------------------------------------------------------

class scenario_success(ManagerTermBase):
    """Success termination that evaluates ALL task predicates from spawn_config.

    ``cfg.params`` is read for:

    * ``predicates`` - list of predicate-config dicts dispatched on their
      ``"_type"`` key (``ontoppredicateconfig``, ``contactflagpredicateconfig``,
      ``framesinrelativeboxpredicateconfig``, ``mugonmugholderpredicateconfig``).
    * ``object_frames`` - ``{object_name: {frame_name: {"translation": [...],
      "rotation_rpy_rad": [...]}}}``, used by the mug-on-holder predicate.
    * ``yup_objects`` - names of objects whose frame translations are
      converted from Z-up to Y-up.

    Each supported predicate becomes a ``(func, kwargs)`` check built from the
    standalone functions above. All checks must hold simultaneously for an
    environment to count as a success. If no check survives parsing, the
    term never reports success.

    Predicates that reference entities or frames not available (e.g.
    ``manipuland_table`` baked into the station), and predicates of an unknown
    type, are skipped with a warning.
    """

    def __init__(self, cfg: TerminationTermCfg, env: ManagerBasedRLEnv):
        super().__init__(cfg, env)
        params = cfg.params
        predicates = params.get("predicates") or []
        scene_entity_names = set(env.scene.keys())

        # Auto-detect gripper articulations for require_no_fingers_touching.
        gripper_cfgs = self._detect_gripper_cfgs(env, scene_entity_names)

        # Each entry is a (func, kwargs) pair to call at every step.
        self._checks: list[tuple[Callable[..., torch.Tensor], dict]] = []

        for pred in predicates:
            ptype = pred.get("_type", "")
            if ptype == "ontoppredicateconfig":
                check = self._build_on_top_check(pred, scene_entity_names, gripper_cfgs)
            elif ptype == "contactflagpredicateconfig":
                check = self._build_contact_check(pred, scene_entity_names)
            elif ptype == "framesinrelativeboxpredicateconfig":
                check = self._build_bounds_check(pred, scene_entity_names)
            elif ptype == "mugonmugholderpredicateconfig":
                check = self._build_mug_on_holder_check(pred, params, scene_entity_names)
            else:
                logger.warning("predicate skipped: unsupported type %r", ptype)
                check = None
            if check is not None:
                self._checks.append(check)

    @staticmethod
    def _detect_gripper_cfgs(env: ManagerBasedRLEnv, scene_entity_names: set[str]) -> list[SceneEntityCfg]:
        """Return a cfg for every articulation exposing ``panda_finger_joint1`` (sorted by name)."""
        gripper_cfgs: list[SceneEntityCfg] = []
        # Sorted so the result does not depend on set iteration order.
        for name in sorted(scene_entity_names):
            entity = env.scene[name]
            if isinstance(entity, Articulation) and "panda_finger_joint1" in entity.data.joint_names:
                gripper_cfgs.append(SceneEntityCfg(name=name, joint_names=["panda_finger_joint1"]))
        return gripper_cfgs

    @staticmethod
    def _build_on_top_check(
        pred: dict,
        scene_entity_names: set[str],
        gripper_cfgs: list[SceneEntityCfg],
    ) -> tuple[Callable[..., torch.Tensor], dict] | None:
        """Translate an ``ontoppredicateconfig`` into an ``on_top_predicate`` check."""
        top_name = pred["top_object_name"]
        bot_name = pred["bottom_object_name"]
        if top_name not in scene_entity_names or bot_name not in scene_entity_names:
            logger.warning("on_top predicate skipped: %s or %s not in scene", top_name, bot_name)
            return None
        require_no_fingers = bool(pred.get("require_no_fingers_touching", False))
        kwargs = {
            "top_object_cfg": SceneEntityCfg(top_name),
            "bottom_object_cfg": SceneEntityCfg(bot_name),
            "xy_threshold": float(pred.get("xy_threshold", 0.1)),
            "z_threshold": float(pred.get("z_threshold", 0.5)),
            "velocity_bound": float(pred.get("top_object_velocity_bound", 0.1)),
            "require_support": bool(pred.get("require_support", False)),
            "support_z_threshold": float(pred.get("support_z_threshold", 0.15)),
            "require_no_fingers_touching": require_no_fingers,
        }
        if require_no_fingers:
            if gripper_cfgs:
                kwargs["gripper_cfgs"] = gripper_cfgs
            else:
                # on_top_predicate ignores the finger check without gripper cfgs.
                logger.warning(
                    "on_top predicate: require_no_fingers_touching set but no gripper found; finger check disabled"
                )
        return on_top_predicate, kwargs

    @staticmethod
    def _build_contact_check(
        pred: dict,
        scene_entity_names: set[str],
    ) -> tuple[Callable[..., torch.Tensor], dict] | None:
        """Translate a ``contactflagpredicateconfig`` into a ``contact_flag_predicate`` check."""
        b1 = pred["body_1_name"]
        b2 = pred["body_2_name"]
        sensor_name = f"contact__{b1}__{b2}"
        if sensor_name not in scene_entity_names:
            logger.warning("contact predicate skipped: sensor %s not in scene", sensor_name)
            return None
        return contact_flag_predicate, {
            "sensor_cfg": SceneEntityCfg(sensor_name),
            "bodies_in_contact_expected": bool(pred.get("bodies_in_contact_expected", True)),
        }

    @staticmethod
    def _build_bounds_check(
        pred: dict,
        scene_entity_names: set[str],
    ) -> tuple[Callable[..., torch.Tensor], dict] | None:
        """Translate a ``framesinrelativeboxpredicateconfig`` into a ``frames_in_bounds_predicate`` check.

        Only the object part (before ``"::"``) of each frame name is used. The
        object part of each ``frame_A_names`` entry is a regex that must match
        a whole scene entity name.
        """
        ref_name = pred.get("frame_B_name", "").split("::")[0]
        if ref_name not in scene_entity_names:
            logger.warning("bounds predicate skipped: reference %s not in scene", ref_name)
            return None
        frame_a_patterns = pred.get("frame_A_names", [])
        # fullmatch keeps alternations like "a|b" anchored as a whole
        # (a "^{p}$" wrapper would bind ^ and $ to the first/last branch only).
        obj_regexes = [re.compile(pattern.split("::")[0]) for pattern in frame_a_patterns]
        matched = sorted(
            name for name in scene_entity_names if any(rx.fullmatch(name) for rx in obj_regexes)
        )
        if not matched:
            logger.warning("bounds predicate skipped: no objects matched %s", frame_a_patterns)
            return None
        return frames_in_bounds_predicate, {
            "object_cfgs": [SceneEntityCfg(n) for n in matched],
            "reference_cfg": SceneEntityCfg(ref_name),
            "bounds_lo": tuple(pred.get("p_BP_lo", [-1, -1, 0])),
            "bounds_hi": tuple(pred.get("p_BP_hi", [1, 1, 1])),
        }

    @staticmethod
    def _build_frame_offsets(object_frames_raw: dict, yup_objects) -> dict[str, torch.Tensor]:
        """Map ``"obj::frame"`` to the frame's local translation as a float32 ``(3,)`` tensor.

        Frame translations in JSON are Z-up (SDF convention). For objects in
        ``yup_objects``, root_quat_w includes Rx(90 deg), which expects Y-up
        mesh coords, so translations are converted: [x, y, z] -> [x, z, -y].
        """
        frame_offsets: dict[str, torch.Tensor] = {}
        for obj_name, frames_dict in object_frames_raw.items():
            needs_yup = obj_name in yup_objects
            for frame_name, frame_data in frames_dict.items():
                x, y, z = frame_data.get("translation", [0, 0, 0])
                xyz = [x, z, -y] if needs_yup else [x, y, z]
                frame_offsets[f"{obj_name}::{frame_name}"] = torch.tensor(xyz, dtype=torch.float32)
        return frame_offsets

    @classmethod
    def _build_mug_on_holder_check(
        cls,
        pred: dict,
        params: dict,
        scene_entity_names: set[str],
    ) -> tuple[Callable[..., torch.Tensor], dict] | None:
        """Translate a ``mugonmugholderpredicateconfig`` into a ``mug_on_mug_holder_predicate`` check.

        Every branch, handle and center frame must belong to a scene entity
        and have an entry in ``object_frames``; otherwise the predicate is
        skipped (an unresolvable frame would raise KeyError on every step).
        """
        object_frames_raw = params.get("object_frames") or {}
        if not object_frames_raw:
            logger.warning("mugonmugholder predicate skipped: no object_frames in params")
            return None
        # yup_objects may be present but None; treat that as "no Y-up objects".
        yup_objects = params.get("yup_objects") or set()
        frame_offsets = cls._build_frame_offsets(object_frames_raw, yup_objects)

        # Collect branch frame names and their rotations.
        branch_frame_names = pred.get("branch_frame_names", [])
        branch_rotations: list[torch.Tensor] = []
        for bf in branch_frame_names:
            obj_name, frame_name = bf.split("::")
            if obj_name not in scene_entity_names:
                logger.warning("mugonmugholder predicate skipped: %s not in scene", obj_name)
                return None
            if bf not in frame_offsets:
                logger.warning("mugonmugholder predicate skipped: frame %s not in object_frames", bf)
                return None
            rpy = object_frames_raw[obj_name][frame_name].get("rotation_rpy_rad", [0, 0, 0])
            # Explicit float32 so integer angles in JSON don't produce int tensors.
            roll, pitch, yaw = (torch.tensor([float(rpy[i])], dtype=torch.float32) for i in range(3))
            branch_rotations.append(math_utils.quat_from_euler_xyz(roll, pitch, yaw).squeeze(0))

        # Validate mug frames exist.
        mug_handle_frames = pred.get("mug_handle_frame_names", [])
        mug_center_frame = pred.get("mug_center_frame_name", "")
        for mf in [*mug_handle_frames, mug_center_frame]:
            obj_name = mf.split("::")[0]
            if obj_name not in scene_entity_names:
                logger.warning("mugonmugholder predicate skipped: %s not in scene", obj_name)
                return None
            if mf not in frame_offsets:
                logger.warning("mugonmugholder predicate skipped: frame %s not in object_frames", mf)
                return None

        return mug_on_mug_holder_predicate, {
            "branch_frames": branch_frame_names,
            "mug_handle_frames": mug_handle_frames,
            "mug_center_frame": mug_center_frame,
            "branch_length": float(pred.get("branch_length", 0.08)),
            "handle_length": float(pred.get("handle_length", 0.05)),
            "frame_offsets": frame_offsets,
            "branch_rotations": branch_rotations,
        }

    def __call__(
        self,
        env: ManagerBasedRLEnv,
        predicates: list | None = None,
        object_frames: dict | None = None,
        yup_objects: set | None = None,
    ) -> torch.Tensor:
        """Return a boolean tensor that is True where every parsed check holds.

        ``predicates``, ``object_frames`` and ``yup_objects`` are unused here.
        They mirror the ``cfg.params`` keys consumed in ``__init__`` and are
        accepted presumably because the manager forwards those params as
        keyword arguments.
        """
        if not self._checks:
            return torch.zeros(env.num_envs, device=env.device, dtype=torch.bool)

        success = torch.ones(env.num_envs, device=env.device, dtype=torch.bool)
        for func, kwargs in self._checks:
            success = success & func(env, **kwargs)
        return success

