"""Scenario helpers: stochastic sampling, midpoint geometry, and reset event.

Stochastic sampling helpers handle all types from spawn_config.json:
  - uniform_vector: per-element uniform over min/max vectors
  - uniform: scalar uniform over min/max
  - uniform_discrete: random pick from values list
  - uniform_discrete_string: random pick from string values list
  - rpy: roll-pitch-yaw (deg components may themselves be stochastic)
  - literal list: constant tensor
  - None/null: zeros (translation) or identity quat (rotation)

Midpoint / quaternion helpers resolve stochastic specs to their
deterministic midpoints (used for USD spawn initial poses).

Includes physics-based collision detection for reset validation:
after sampling poses for all objects, writes them to sim, steps physics
once, and checks the contact forces between objects that a
``non_collision_constraint`` forbids from touching. Objects whose contact
force exceeds a threshold are resampled. Uses the actual collision
meshes — works for any geometry (bowls, containers, arbitrary shapes).
"""

from __future__ import annotations

import json
import math
import re

import h5py
import numpy as np
import torch
import omni.physics.tensors.impl.api as physx
from isaaclab.assets.articulation import Articulation
from isaaclab.assets.rigid_object import RigidObject
from isaaclab.envs import ManagerBasedEnv
from isaaclab.managers import EventTermCfg, ManagerTermBase
from isaaclab.sim.utils import get_current_stage_id
from isaaclab.utils import math as math_utils

# Module-level storage for Y-up object sets, keyed by the resolved scene_path.
# Each value is the set of object names that need the Rx(90deg) Y-up correction.
# Populated by ScenarioSceneCfg.dynamic_setup(), consumed by ResetScenarioStochastic.
# Kept here (not on the @configclass) so Isaac Lab's entity iterator doesn't see it.
YUP_OBJECTS_REGISTRY: dict[str, set[str]] = {}


def _compute_child_frame_z_offset(
    model_path: str, child_frame: str | None, needs_yup: bool,
) -> float:
    """Compute Z offset so that the named child_frame sits at the placement point.

    For ``child_frame`` names containing "bottom", returns the negative of the
    mesh's minimum extent along the vertical axis (Z for Z-up USDs, Y for Y-up
    USDs that get an Rx(90°) correction). This lifts the object so its bottom
    aligns with the target position instead of its center.

    Returns 0.0 for unrecognised child_frame names or missing meshes.
    """
    if not child_frame or "bottom" not in child_frame.lower():
        return 0.0
    try:
        from pxr import Usd, UsdGeom  # type: ignore
        stage = Usd.Stage.Open(model_path)
        for prim in stage.Traverse():
            if prim.GetTypeName() == "Mesh":
                mesh = UsdGeom.Mesh(prim)
                pts = mesh.GetPointsAttr().Get()
                if pts:
                    if needs_yup:
                        # Y-up mesh: after Rx(90°), local Y becomes world Z
                        min_val = min(p[1] for p in pts)
                    else:
                        # Z-up mesh: Z is already vertical
                        min_val = min(p[2] for p in pts)
                    return -min_val
    except Exception:
        pass
    return 0.0


# ---------------------------------------------------------------------------
# Sampling helpers
# ---------------------------------------------------------------------------

def _sample_stochastic(spec, n: int, device: str) -> torch.Tensor:
    """Recursively resolve a stochastic spec to a ``(n,)`` or ``(n, dim)`` tensor.

    Args:
        spec: A stochastic spec dict (with ``_type`` key), a literal list, a
            scalar, or ``None``.
        n: Batch size (number of environments).
        device: Torch device string.

    Returns:
        Sampled tensor.
    """
    if spec is None:
        return torch.zeros(n, device=device)

    if isinstance(spec, (int, float)):
        return torch.full((n,), float(spec), device=device)

    if isinstance(spec, list):
        # Each element may itself be stochastic
        cols = [_sample_stochastic(elem, n, device) for elem in spec]
        return torch.stack(cols, dim=-1)  # (n, len)

    if not isinstance(spec, dict):
        raise ValueError(f"Unsupported stochastic spec: {spec}")

    stype = spec.get("_type")

    if stype == "uniform_vector":
        lo = torch.tensor(spec["min"], dtype=torch.float32, device=device)
        hi = torch.tensor(spec["max"], dtype=torch.float32, device=device)
        return lo + (hi - lo) * torch.rand(n, len(spec["min"]), device=device)

    if stype == "uniform":
        lo = float(spec["min"])
        hi = float(spec["max"])
        return torch.empty(n, device=device).uniform_(lo, hi)

    if stype == "uniform_discrete":
        values = spec["values"]
        idx = torch.randint(0, len(values), (n,), device=device)
        vals = torch.tensor(values, dtype=torch.float32, device=device)
        return vals[idx]

    raise ValueError(f"Unknown stochastic _type: {stype}")


def _sample_stochastic_string(spec, n: int) -> list[str]:
    """Sample a ``uniform_discrete_string`` spec, returning a list of *n* strings."""
    if isinstance(spec, str):
        return [spec] * n
    if isinstance(spec, dict) and spec.get("_type") == "uniform_discrete_string":
        values = spec["values"]
        idx = torch.randint(0, len(values), (n,))
        return [values[i] for i in idx.tolist()]
    raise ValueError(f"Cannot sample string from: {spec}")


def _sample_rotation(spec, n: int, device: str) -> torch.Tensor:
    """Resolve a rotation spec to quaternions ``(n, 4)`` in ``(w, x, y, z)``."""
    if spec is None:
        quat = torch.zeros(n, 4, device=device)
        quat[:, 0] = 1.0
        return quat

    if not isinstance(spec, dict):
        raise ValueError(f"Unsupported rotation spec: {spec}")

    stype = spec.get("_type")
    if stype == "rpy":
        deg_spec = spec["deg"]
        deg = _sample_stochastic(deg_spec, n, device)  # (n, 3) degrees
        rad = deg * (math.pi / 180.0)
        return math_utils.quat_from_euler_xyz(rad[:, 0], rad[:, 1], rad[:, 2])

    raise ValueError(f"Unknown rotation _type: {stype}")


# ---------------------------------------------------------------------------
# Midpoint / quaternion helpers
# ---------------------------------------------------------------------------

def _quat_mul_single(q1: tuple, q2: tuple) -> tuple:
    """Multiply two quaternions (w,x,y,z) on CPU (single values)."""
    w1, x1, y1, z1 = q1
    w2, x2, y2, z2 = q2
    return (
        w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2,
        w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2,
        w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2,
        w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2,
    )


def _midpoint_scalar(spec) -> float:
    """Resolve a stochastic scalar spec to its midpoint."""
    if spec is None:
        return 0.0
    if isinstance(spec, (int, float)):
        return float(spec)
    if isinstance(spec, dict):
        stype = spec.get("_type", "")
        if stype in ("uniform", "uniform_vector"):
            return (float(spec["min"]) + float(spec["max"])) / 2.0
        if stype in ("uniform_discrete", "uniform_discrete_string"):
            vals = spec.get("values", [0])
            return float(vals[0]) if vals else 0.0
        if stype == "gaussian":
            return float(spec.get("mean", 0.0))
    return 0.0


def _midpoint_translation(spec) -> tuple:
    """Resolve a stochastic translation to its midpoint (x, y, z)."""
    if spec is None:
        return (0.0, 0.0, 0.0)
    if isinstance(spec, list) and len(spec) == 3:
        return tuple(_midpoint_scalar(v) for v in spec)
    if isinstance(spec, dict):
        stype = spec.get("_type", "")
        if stype == "uniform_vector":
            mins = spec["min"]
            maxs = spec["max"]
            return tuple((a + b) / 2.0 for a, b in zip(mins, maxs))
    return (0.0, 0.0, 0.0)


def _midpoint_rotation(spec) -> tuple:
    """Resolve a stochastic rotation to its midpoint quaternion (w, x, y, z)."""
    if spec is None:
        return (1.0, 0.0, 0.0, 0.0)
    if isinstance(spec, dict):
        stype = spec.get("_type", "")
        if stype == "rpy":
            deg = spec.get("deg", [0, 0, 0])
            r = math.radians(_midpoint_scalar(deg[0]) if isinstance(deg, list) else 0.0)
            p = math.radians(_midpoint_scalar(deg[1]) if isinstance(deg, list) and len(deg) > 1 else 0.0)
            y = math.radians(_midpoint_scalar(deg[2]) if isinstance(deg, list) and len(deg) > 2 else 0.0)
            # RPY to quaternion (extrinsic XYZ)
            cr, sr = math.cos(r / 2), math.sin(r / 2)
            cp, sp = math.cos(p / 2), math.sin(p / 2)
            cy, sy = math.cos(y / 2), math.sin(y / 2)
            w = cr * cp * cy + sr * sp * sy
            x = sr * cp * cy - cr * sp * sy
            yy = cr * sp * cy + sr * cp * sy
            z = cr * cp * sy - sr * sp * cy
            return (w, x, yy, z)
    return (1.0, 0.0, 0.0, 0.0)


# Rx(90deg) quaternion for Y-up correction: cos(45), sin(45), 0, 0
_RX90_W = math.cos(math.pi / 4)
_RX90_X = math.sin(math.pi / 4)


def _yup_quat(n: int, device: str) -> torch.Tensor:
    """Return ``(n, 4)`` quaternion for Rx(90deg) Y-up correction."""
    q = torch.zeros(n, 4, device=device)
    q[:, 0] = _RX90_W
    q[:, 1] = _RX90_X
    return q


# ---------------------------------------------------------------------------
# Reset event term
# ---------------------------------------------------------------------------

class ResetScenarioStochastic(ManagerTermBase):
    """Reset manipulable object poses using stochastic ranges.

    Reads reset specs from ``reset_config.json`` (legacy) or ``spawn_config.json``
    (library mode). Follows the ``ManagerTermBase`` pattern.

    When ``collision_detection=True`` (default in library mode), uses
    physics-based collision detection: writes candidate poses to sim,
    steps physics once, and checks the contact forces between objects that
    a ``non_collision_constraint`` forbids from touching. Objects whose
    contact force exceeds the force threshold get resampled.
    Uses the actual collision meshes — works for any geometry.
    """

    # Contact force threshold (N) for physics-based collision detection.
    # Collision resolution forces are large (>>10N); resting/gravity
    # forces are small (~1N for typical objects). 5N separates them well.
    DEFAULT_FORCE_THRESHOLD = 5.0

    def __init__(self, cfg: EventTermCfg, env: ManagerBasedEnv):
        super().__init__(cfg, env)

        scene_path: str = cfg.params["scene_path"]
        self._scene_path = scene_path
        self._library_dir: str | None = cfg.params.get("library_dir")

        # Collision detection params
        self._collision_detection: bool = cfg.params.get(
            "collision_detection", self._library_dir is not None
        )
        self._max_resample_attempts: int = cfg.params.get("max_resample_attempts", 100)
        self._force_threshold: float = cfg.params.get(
            "force_threshold", self.DEFAULT_FORCE_THRESHOLD
        )

        # Lazily created RigidContactViews for inter-object collision detection.
        # Each entry: (object_name, contact_view, filter_names) where the view
        # monitors contacts only against objects forbidden by non_collision_constraint.
        self._contact_views: list[tuple[str, physx.RigidContactView, list[str]]] | None = None

        # Load config directly from the specified path
        with open(scene_path) as f:
            self.config = json.load(f)

        # ---- reference frames lookup ----
        self.ref_frames: dict[str, dict] = {}
        for frame in self.config.get("reference_frames", []):
            self.ref_frames[frame["name"]] = frame

        # ---- collect manipulable object specs ----
        self.object_specs: list[dict] = []

        # Direct objects
        for obj in self.config.get("objects", []):
            name = obj["name"]
            entity = self._find_entity(env, name)
            if entity is None:
                print(f"[ResetScenarioStochastic] Warning: object '{name}' not found in scene, skipping")
                continue
            needs_yup = self._check_yup(env, name)
            z_offset = 0.0
            if self._library_dir and obj.get("model"):
                from pathlib import Path
                usd_path = str(Path(self._library_dir) / obj["model"])
                z_offset = _compute_child_frame_z_offset(
                    usd_path, obj.get("child_frame"), needs_yup,
                )
            is_welded = bool(obj.get("welded", False))
            self.object_specs.append({
                "name": name,
                "entity": entity,
                "translation": obj.get("translation"),
                "rotation": obj.get("rotation"),
                "base_frame": obj.get("base_frame") or obj.get("parent"),
                "needs_yup": needs_yup,
                "child_frame_z_offset": z_offset,
                "welded": is_welded,
            })

        # uniform_manipuland groups — use first variant's pose ranges
        for group in self.config.get("randomized_groups", []):
            if group.get("group_type") != "uniform_manipuland":
                continue
            name = group["name"]
            entity = self._find_entity(env, name)
            if entity is None:
                print(f"[ResetScenarioStochastic] Warning: group '{name}' not found in scene, skipping")
                continue
            variant = group["variants"][0]
            needs_yup = self._check_yup(env, name)
            z_offset = 0.0
            if self._library_dir and variant.get("model"):
                from pathlib import Path
                usd_path = str(Path(self._library_dir) / variant["model"])
                z_offset = _compute_child_frame_z_offset(
                    usd_path, variant.get("child_frame"), needs_yup,
                )
            self.object_specs.append({
                "name": name,
                "entity": entity,
                "translation": variant.get("translation"),
                "rotation": variant.get("rotation"),
                "base_frame": variant.get("base_frame") or variant.get("parent"),
                "needs_yup": needs_yup,
                "child_frame_z_offset": z_offset,
                "welded": bool(variant.get("welded", False)),
            })

        # ---- robot initial joint positions ----
        self.robot_joint_specs: dict[str, dict] = {}
        initial_pos = self.config.get("initial_position", {})

        robot_name_map = {
            "right::panda": "right_panda",
            "left::panda": "left_panda",
        }
        for config_name, scene_name in robot_name_map.items():
            if config_name in initial_pos:
                try:
                    entity = env.scene[scene_name]
                    if isinstance(entity, Articulation):
                        self.robot_joint_specs[scene_name] = {
                            "entity": entity,
                            "joint_pos": initial_pos[config_name],
                        }
                except KeyError:
                    pass

        # Parse non_collision_constraint entries → per-object set of forbidden neighbors.
        # Only these pairs get checked for collisions.
        self._collision_pairs: dict[str, set[str]] = {}
        self._parse_collision_constraints()

    def _find_entity(self, env: ManagerBasedEnv, name: str) -> RigidObject | None:
        try:
            entity = env.scene[name]
            if isinstance(entity, RigidObject):
                return entity
        except KeyError:
            pass
        return None

    def _check_yup(self, env: ManagerBasedEnv, name: str) -> bool:
        """Check if object was flagged as needing Y-up correction during dynamic_setup."""
        yup_set = YUP_OBJECTS_REGISTRY.get(self._scene_path, set())
        return name in yup_set

    def _parse_collision_constraints(self) -> None:
        """Parse ``non_collision_constraint`` entries from config.

        Builds ``self._collision_pairs``: a dict mapping each object name to
        the set of other object names it must NOT collide with. Only objects
        appearing in at least one constraint will have an entry.
        """
        obj_names = [s["name"] for s in self.object_specs]

        for constraint in self.config.get("constraints", []):
            if constraint.get("constraint_type") != "non_collision_constraint":
                continue

            # Each pattern is like "bell_pepper_0::.*" or "fruit_.*::.*"
            # The part before "::" is the object name pattern.
            patterns = constraint.get("forbid_collision_between", [])

            # Resolve patterns to actual spawned object names
            matched_names: list[str] = []
            for pat in patterns:
                # Extract the object-name portion (before "::")
                obj_pat = pat.split("::")[0] if "::" in pat else pat
                # Match against spawned objects using regex
                for name in obj_names:
                    if re.fullmatch(obj_pat, name) and name not in matched_names:
                        matched_names.append(name)

            # Every matched object is forbidden from colliding with every other
            for i, a in enumerate(matched_names):
                for b in matched_names[i + 1:]:
                    self._collision_pairs.setdefault(a, set()).add(b)
                    self._collision_pairs.setdefault(b, set()).add(a)

        if self._collision_pairs:
            print(
                f"[ResetScenarioStochastic] Collision constraints: "
                + ", ".join(
                    f"{k} vs {{{', '.join(sorted(v))}}}"
                    for k, v in sorted(self._collision_pairs.items())
                )
            )

    def _sample_pose(
        self,
        spec: dict,
        n: int,
        env: ManagerBasedEnv,
        env_ids: torch.Tensor,
        pose_buffers: dict[str, torch.Tensor] | None = None,
    ) -> torch.Tensor:
        """Sample a random pose for an object spec.

        Returns:
            ``(n, 7)`` tensor of ``[x, y, z, qw, qx, qy, qz]`` in world frame
            (with env origins added).
        """
        device = env.device

        frame_pos, frame_quat = self._resolve_base_frame(
            spec["base_frame"], n, device, env=env, env_ids=env_ids,
            pose_buffers=pose_buffers,
        )

        if spec["translation"] is not None:
            obj_trans = _sample_stochastic(spec["translation"], n, device)
        else:
            obj_trans = torch.zeros(n, 3, device=device)

        obj_quat = _sample_rotation(spec["rotation"], n, device)

        world_pos = frame_pos + math_utils.quat_apply(frame_quat, obj_trans)
        world_quat = math_utils.quat_mul(frame_quat, obj_quat)

        if spec["needs_yup"]:
            yup_q = _yup_quat(n, device)
            world_quat = math_utils.quat_mul(world_quat, yup_q)

        # Apply child_frame Z offset (e.g., lift "bottom" of bin to placement point)
        z_off = spec.get("child_frame_z_offset", 0.0)
        if z_off != 0.0:
            world_pos[:, 2] += z_off

        world_pos = world_pos + env.scene.env_origins[env_ids]
        return torch.cat([world_pos, world_quat], dim=-1)

    def _ensure_contact_views(self) -> None:
        """Lazily create RigidContactViews for inter-object collision detection.

        Only creates views for objects that appear in ``non_collision_constraint``
        entries, and only filters against their forbidden neighbors — not all
        other objects. Objects not in any constraint are skipped entirely.
        """
        if self._contact_views is not None:
            return

        if not self._collision_pairs:
            self._contact_views = []
            return

        stage_id = get_current_stage_id()
        sim_view = physx.create_simulation_view("torch", stage_id)
        sim_view.set_subspace_roots("/")

        # Build name → spec lookup
        spec_by_name = {s["name"]: s for s in self.object_specs}

        self._contact_views = []
        for obj_name, forbidden_set in self._collision_pairs.items():
            spec = spec_by_name.get(obj_name)
            if spec is None:
                continue

            entity: RigidObject = spec["entity"]
            sensor_pattern = re.sub(
                r"/env_\d+/", "/*/", entity.root_physx_view.prim_paths[0], count=1
            )

            # Only filter against forbidden neighbors
            filter_patterns = []
            filter_obj_names = []
            for other_name in sorted(forbidden_set):
                other_spec = spec_by_name.get(other_name)
                if other_spec is None:
                    continue
                other_entity: RigidObject = other_spec["entity"]
                filt = re.sub(
                    r"/env_\d+/", "/*/", other_entity.root_physx_view.prim_paths[0], count=1
                )
                filter_patterns.append(filt)
                filter_obj_names.append(other_name)

            if not filter_patterns:
                continue

            contact_view = sim_view.create_rigid_contact_view(
                sensor_pattern, filter_patterns=filter_patterns
            )
            self._contact_views.append((obj_name, contact_view, filter_obj_names))

        print(
            f"[ResetScenarioStochastic] Created {len(self._contact_views)} "
            f"contact views for constrained collision pairs"
        )

    def __call__(
        self,
        env: ManagerBasedEnv,
        env_ids: torch.Tensor,
        scene_path: str | None = None,
        library_dir: str | None = None,
        collision_detection: bool | None = None,
        max_resample_attempts: int | None = None,
    ) -> None:
        if env_ids is None:
            env_ids = torch.arange(env.scene.num_envs, device=env.device)

        n = len(env_ids)
        device = env.device

        use_collision_detection = (
            collision_detection if collision_detection is not None
            else self._collision_detection
        )
        max_attempts = (
            max_resample_attempts if max_resample_attempts is not None
            else self._max_resample_attempts
        )

        # ---- 1. Reset robot joints first (deterministic) ----
        self._reset_robot_joints(env, env_ids)

        # ---- 2. Sample initial poses for ALL objects ----
        # pose_buffers are (num_envs, 7) indexed by env_id so that subsets
        # (e.g. during collision resampling) can look up parent poses.
        # Sort by dependency order: parents before children
        def get_parent_name(spec):
            bf = spec.get("base_frame")
            if isinstance(bf, str):
                return bf.split("::")[0]
            if isinstance(bf, dict):
                # Stochastic base_frame — collect all possible parent names
                return None  # conservative: treat as no dependency
            return None

        obj_names = {s["name"] for s in self.object_specs}
        sorted_specs = []
        remaining = list(self.object_specs)
        while remaining:
            # Find specs whose parent is not in remaining (already processed or external)
            ready = [s for s in remaining if get_parent_name(s) not in obj_names or
                     get_parent_name(s) in {ss["name"] for ss in sorted_specs}]
            if not ready:
                # Circular dependency or unresolvable - just add remaining
                sorted_specs.extend(remaining)
                break
            sorted_specs.extend(ready)
            for s in ready:
                remaining.remove(s)

        num_envs = env.scene.num_envs
        zero_vel_all = torch.zeros(n, 6, device=device)
        pose_buffers: dict[str, torch.Tensor] = {}
        for spec in sorted_specs:
            pose = self._sample_pose(spec, n, env, env_ids, pose_buffers=pose_buffers)
            buf = torch.zeros(num_envs, 7, device=device)
            buf[env_ids] = pose
            pose_buffers[spec["name"]] = buf
            spec["entity"].write_root_pose_to_sim(pose, env_ids=env_ids)
            spec["entity"].write_root_velocity_to_sim(zero_vel_all, env_ids=env_ids)

        if use_collision_detection and self._collision_pairs:
            # ---- 3. Physics-based inter-object collision detection ----
            #
            # Write candidate poses to sim, step physics once, and check
            # contact forces between spawned objects via RigidContactView.
            # Only inter-object contacts are flagged — contacts with static
            # geometry (tables, shelves, station) are ignored.
            # Save/restore all state so non-resetting envs are unaffected.

            self._ensure_contact_views()

            all_env_ids = torch.arange(env.scene.num_envs, device=device)
            physics_dt = env.sim.cfg.dt

            # Save state of ALL objects across ALL envs (so we can undo
            # the extra physics step for non-resetting envs)
            saved_obj_poses: dict[str, torch.Tensor] = {}
            saved_obj_vels: dict[str, torch.Tensor] = {}
            for spec in self.object_specs:
                entity: RigidObject = spec["entity"]
                saved_obj_poses[spec["name"]] = entity.data.root_link_pose_w.clone()
                saved_obj_vels[spec["name"]] = entity.data.root_com_vel_w.clone()

            saved_robot_state: dict[str, tuple[torch.Tensor, torch.Tensor]] = {}
            for scene_name, rspec in self.robot_joint_specs.items():
                art: Articulation = rspec["entity"]
                saved_robot_state[scene_name] = (
                    art.data.joint_pos.clone(),
                    art.data.joint_vel.clone(),
                )

            for attempt in range(max_attempts):
                # Write candidate poses for resetting envs
                for spec in self.object_specs:
                    spec["entity"].write_root_pose_to_sim(
                        pose_buffers[spec["name"]][env_ids], env_ids=env_ids
                    )
                    spec["entity"].write_root_velocity_to_sim(
                        zero_vel_all, env_ids=env_ids
                    )

                # Step physics once to populate contact reports
                env.sim.step(render=False)

                # Check inter-object contact forces (per-pair)
                collision_found = False
                for obj_name, contact_view, filter_names in self._contact_views:
                    # force_matrix: (num_envs, num_filters, 3)
                    force_matrix = contact_view.get_contact_force_matrix(physics_dt)
                    # Per-filter force magnitudes: (num_envs, num_filters)
                    force_mag = force_matrix.norm(dim=-1)
                    # Check each filter separately
                    for fi, filt_name in enumerate(filter_names):
                        fmag = force_mag[:, fi]  # (num_envs,)
                        colliding = fmag[env_ids] > self._force_threshold
                        if colliding.any():
                            collision_found = True
                            n_col = int(colliding.sum().item())
                            max_f = fmag[env_ids][colliding].max().item()
                            print(
                                f"[CollisionCheck] attempt {attempt}: "
                                f"{obj_name} <-> {filt_name} in {n_col} envs "
                                f"(max_force={max_f:.1f}N)"
                            )
                            # Resample only non-welded objects
                            col_env_ids = env_ids[colliding]
                            for resample_name in (obj_name, filt_name):
                                spec = next(
                                    (s for s in self.object_specs if s["name"] == resample_name),
                                    None,
                                )
                                if spec is None:
                                    continue
                                if spec.get("welded", False):
                                    continue
                                new_pose = self._sample_pose(
                                    spec, n_col, env, col_env_ids,
                                    pose_buffers=pose_buffers,
                                )
                                pose_buffers[resample_name][col_env_ids] = new_pose

                # Restore ALL envs to pre-step state
                for spec in self.object_specs:
                    spec["entity"].write_root_pose_to_sim(
                        saved_obj_poses[spec["name"]], env_ids=all_env_ids
                    )
                    spec["entity"].write_root_velocity_to_sim(
                        saved_obj_vels[spec["name"]], env_ids=all_env_ids
                    )
                for scene_name, rspec in self.robot_joint_specs.items():
                    art: Articulation = rspec["entity"]
                    jpos, jvel = saved_robot_state[scene_name]
                    art.write_joint_state_to_sim(jpos, jvel, env_ids=all_env_ids)

                if not collision_found:
                    if attempt > 0:
                        print(f"[CollisionCheck] converged after {attempt} attempts")
                    break
            else:
                print(
                    f"[CollisionCheck] WARNING: did not converge after "
                    f"{max_attempts} attempts"
                )

            # ---- 4. Write final validated poses ----
            for spec in self.object_specs:
                spec["entity"].write_root_pose_to_sim(
                    pose_buffers[spec["name"]][env_ids], env_ids=env_ids
                )
                spec["entity"].write_root_velocity_to_sim(zero_vel_all, env_ids=env_ids)
        else:
            # ---- No collision detection: poses already written in step 2 ----
            pass

    def _reset_robot_joints(
        self, env: ManagerBasedEnv, env_ids: torch.Tensor
    ) -> None:
        """Reset robot joints to initial positions from config."""
        for scene_name, rspec in self.robot_joint_specs.items():
            articulation: Articulation = rspec["entity"]
            joint_pos_cfg: dict = rspec["joint_pos"]

            default_pos = articulation.data.default_joint_pos[env_ids].clone()
            joint_names = articulation.joint_names

            for jname, jvals in joint_pos_cfg.items():
                val = jvals[0] if isinstance(jvals, list) else jvals
                if jname in joint_names:
                    idx = joint_names.index(jname)
                    default_pos[:, idx] = float(val)

            zero_vel = torch.zeros_like(default_pos)
            articulation.write_joint_state_to_sim(
                default_pos, zero_vel, env_ids=env_ids
            )

    def _resolve_base_frame(
        self, base_frame_spec, n: int, device: str,
        env: ManagerBasedEnv | None = None,
        env_ids: torch.Tensor | None = None,
        pose_buffers: dict[str, torch.Tensor] | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Resolve a base_frame spec to ``(pos, quat)`` each ``(n, 3)`` / ``(n, 4)``.

        If the base_frame is a string like ``"fruit_bowl::origin"``, resolves
        to that object's already-sampled pose from ``pose_buffers`` (preferred)
        or falls back to the current sim pose.
        """
        if base_frame_spec is None:
            pos = torch.zeros(n, 3, device=device)
            quat = torch.zeros(n, 4, device=device)
            quat[:, 0] = 1.0
            return pos, quat

        if isinstance(base_frame_spec, str):
            obj_name = base_frame_spec.split("::")[0]

            # Check if this is a named reference frame first
            if obj_name in self.ref_frames:
                frame = self.ref_frames[obj_name]
                # Recursively resolve the reference frame's own base_frame
                parent_pos, parent_quat = self._resolve_base_frame(
                    frame.get("base_frame"), n, device, env, env_ids, pose_buffers
                )
                # Add the reference frame's translation offset
                ft = _sample_stochastic(frame.get("translation"), n, device)
                if ft.dim() == 1:
                    ft = ft.unsqueeze(0).expand(n, -1)
                fq = _sample_rotation(frame.get("rotation"), n, device)
                pos = parent_pos + ft
                quat = math_utils.quat_mul(parent_quat, fq)
                return pos, quat

            # Use parent's POSITION only — child translations are world-aligned,
            # not rotated into the parent's frame.
            if pose_buffers is not None and obj_name in pose_buffers:
                # pose_buffers are (num_envs, 7) indexed by env_id
                parent_pose = pose_buffers[obj_name][env_ids]  # (n, 7)
                env_origins = env.scene.env_origins[env_ids] if env is not None and env_ids is not None else 0
                pos = parent_pose[:, :3] - env_origins
                quat = torch.zeros(n, 4, device=device)
                quat[:, 0] = 1.0
                return pos, quat

            # Fallback: unresolved frame → world origin
            pos = torch.zeros(n, 3, device=device)
            quat = torch.zeros(n, 4, device=device)
            quat[:, 0] = 1.0
            return pos, quat

        # uniform_discrete_string → sample a frame name per env, then resolve each
        if isinstance(base_frame_spec, dict) and base_frame_spec.get("_type") == "uniform_discrete_string":
            frame_names = _sample_stochastic_string(base_frame_spec, n)
            pos = torch.zeros(n, 3, device=device)
            quat = torch.zeros(n, 4, device=device)
            quat[:, 0] = 1.0

            for i, fname in enumerate(frame_names):
                if fname not in self.ref_frames:
                    continue
                frame = self.ref_frames[fname]
                # Frame translation (may be stochastic or literal)
                ft = _sample_stochastic(frame.get("translation"), 1, device)  # (1, 3) or (1,)
                if ft.dim() == 1:
                    ft = ft.unsqueeze(0)
                pos[i] = ft[0]

                # Frame rotation
                fq = _sample_rotation(frame.get("rotation"), 1, device)  # (1, 4)
                quat[i] = fq[0]

            return pos, quat

        raise ValueError(f"Unsupported base_frame spec: {base_frame_spec}")


# ---------------------------------------------------------------------------
# Dataset-based reset
# ---------------------------------------------------------------------------

class ResetFromDataset(ManagerTermBase):
    """Reset event that samples states from an HDF5 demonstration dataset.

    Loads demo states into tensors on ``env.device``: only timestep 0 of each
    demo by default, or every timestep of every demo when ``initial_only`` is
    False. On each reset, randomly samples from that pool for the given
    env_ids and applies them via ``scene.reset_to()``.

    Params:
        dataset_path: Path to the HDF5 file.
        initial_only: If True (default), only load timestep-0 from each demo.
            If False, load every timestep from every demo.
    """

    def __init__(self, cfg: EventTermCfg, env: ManagerBasedEnv):
        super().__init__(cfg, env)
        dataset_path = cfg.params["dataset_path"]
        initial_only = cfg.params.get("initial_only", True)
        self._device = env.device
        self._batched, self._progress, self._num_states = self._load_and_precompute(dataset_path, initial_only)
        mode = "initial states" if initial_only else "all timesteps"
        print(f"[ResetFromDataset] Loaded {self._num_states} states ({mode}) from {dataset_path}")
        # Filled by __call__ so eval scripts can read back per-env reset progress
        self._last_sampled_progress = torch.zeros(env.num_envs, device=self._device)

    def _load_and_precompute(self, path: str, initial_only: bool) -> tuple[dict, torch.Tensor, int]:
        """Load states from HDF5 and stack into batched tensors on the env device.

        Demos without a ``states`` group, without any leaf dataset, or with a
        zero-length trajectory are skipped so that the progress labels stay
        aligned one-to-one with the stacked state rows.

        Args:
            path: Path to the HDF5 file (expects ``data/<demo>/states/...``).
            initial_only: If True keep only timestep 0 of each demo, else all.

        Returns:
            (batched, progress, num_states) where batched is
            {entity_type: {entity_name: {component: (N, dim) tensor}}}
            and progress is a (N,) tensor of normalized demo progress [0, 1].

        Raises:
            ValueError: If the dataset contains no usable states.
        """
        # First pass: collect all state arrays
        all_arrays: dict[str, dict[str, dict[str, list]]] = {}
        all_progress: list[np.ndarray] = []
        num_states = 0

        with h5py.File(path, "r") as f:
            for demo_key in sorted(f["data"].keys()):
                states_path = f"data/{demo_key}/states"
                if states_path not in f:
                    continue
                sg = f[states_path]

                # Peek at the first leaf dataset to get the time dimension.
                sample_ds = next(
                    (
                        sg[etype][ename][comp]
                        for etype in sg.keys()
                        for ename in sg[etype].keys()
                        for comp in sg[etype][ename].keys()
                    ),
                    None,
                )
                if sample_ds is None:
                    continue

                T_full = sample_ds.shape[0]
                if T_full == 0:
                    # An empty demo would add a progress label (and, for
                    # initial_only, a state count) with no matching state row.
                    continue
                T = 1 if initial_only else T_full
                progress = np.linspace(0.0, 1.0, T_full)[:T] if T_full > 1 else np.array([0.0])
                all_progress.append(progress)

                for entity_type in sg.keys():
                    type_bucket = all_arrays.setdefault(entity_type, {})
                    for entity_name in sg[entity_type].keys():
                        entity_bucket = type_bucket.setdefault(entity_name, {})
                        for component in sg[entity_type][entity_name].keys():
                            # Grab [0:T] — either just t=0 or all timesteps
                            entity_bucket.setdefault(component, []).append(
                                sg[entity_type][entity_name][component][:T]
                            )
                num_states += T

        if num_states == 0:
            # Fail early with a readable message instead of an opaque
            # np.concatenate error (or a randint(0, 0) failure at reset time).
            raise ValueError(f"[ResetFromDataset] No states found in dataset: {path}")

        # Second pass: concatenate and move to the env device
        batched: dict[str, dict[str, dict[str, torch.Tensor]]] = {
            etype: {
                ename: {
                    comp: torch.from_numpy(np.concatenate(chunks, axis=0)).float().to(self._device)
                    for comp, chunks in comps.items()
                }
                for ename, comps in entities.items()
            }
            for etype, entities in all_arrays.items()
        }

        progress_tensor = torch.from_numpy(
            np.concatenate(all_progress, axis=0)
        ).float().to(self._device)

        return batched, progress_tensor, num_states

    def __call__(
        self,
        env: ManagerBasedEnv,
        env_ids: torch.Tensor | None,
        dataset_path: str | None = None,
        initial_only: bool = True,
    ) -> None:
        """Reset ``env_ids`` to states drawn uniformly at random from the pool.

        Args:
            env: The environment whose scene is reset.
            env_ids: Indices of the envs to reset; ``None`` means all envs.
            dataset_path: Unused here; the value is read from ``cfg.params``
                in ``__init__``.
            initial_only: Unused here; the value is read from ``cfg.params``
                in ``__init__``.
        """
        if env_ids is None:
            env_ids = torch.arange(env.scene.num_envs, device=env.device)

        n = len(env_ids)
        # Sample random state indices for each env being reset
        idx = torch.randint(0, self._num_states, (n,), device=self._device)
        self._last_sampled_progress[env_ids] = self._progress[idx]

        # Build the state dict for these envs by indexing into pre-batched tensors
        sampled_state: dict[str, dict[str, dict[str, torch.Tensor]]] = {
            etype: {
                ename: {comp: tensor[idx] for comp, tensor in comps.items()}
                for ename, comps in entities.items()
            }
            for etype, entities in self._batched.items()
        }

        env.scene.reset_to(sampled_state, env_ids=env_ids, is_relative=True)


class ResetCurriculum(ManagerTermBase):
    """Reverse-curriculum reset: start near the goal, push back as success improves.

    Loads all demo timesteps and tracks normalized progress (t/T ∈ [0,1]) for
    each state. A moving ``frontier`` controls the earliest progress we sample
    from. When the measured success rate reaches ``advance_threshold``, the
    frontier moves earlier (toward 0); when it drops below
    ``retreat_threshold``, it moves later (toward 1). Once the frontier has
    reached ``min_frontier`` and is still mastered, the term "graduates" and
    switches to normal resets until the success rate drops again.

    Params:
        dataset_path: Path to the HDF5 file.
        initial_frontier: Starting frontier value (default 0.8 = last 20% of demos).
        advance_threshold: Success rate above which frontier moves back.
        retreat_threshold: Success rate below which frontier moves forward.
        step_size: How much to move the frontier per update.
        min_frontier: Minimum frontier value (0.0 = full trajectory).
        window_size: Number of finished episodes between frontier checks
            (default ``max(2 * num_envs, 500)``).
        scene_path: If non-empty, a ``ResetScenarioStochastic`` term is built
            from it and used for resets after graduation.
        library_dir: Passed through to ``ResetScenarioStochastic``.
    """

    def __init__(self, cfg: EventTermCfg, env: ManagerBasedEnv):
        super().__init__(cfg, env)
        dataset_path = cfg.params["dataset_path"]
        self._device = env.device

        self.frontier: float = cfg.params.get("initial_frontier", 0.8)
        self.advance_threshold: float = cfg.params.get("advance_threshold", 0.7)
        self.retreat_threshold: float = cfg.params.get("retreat_threshold", 0.3)
        self.step_size: float = cfg.params.get("step_size", 0.05)
        self.min_frontier: float = cfg.params.get("min_frontier", 0.0)

        self._batched, self._progress, self._num_states = self._load(dataset_path)
        print(
            f"[ResetCurriculum] Loaded {self._num_states} states, "
            f"initial frontier={self.frontier:.2f}"
        )

        # Per-env success tracking (unbiased).
        # The old sliding-window approach was biased: successful episodes are
        # shorter, so successful envs reset more often and contribute more
        # True entries, inflating the measured success rate.
        # Instead, each env stores the outcome of its most recent episode.
        # Success rate = mean over all envs, giving each env equal weight.
        self._num_envs = env.num_envs
        self._per_env_success = torch.zeros(env.num_envs, dtype=torch.bool, device=env.device)
        self._per_env_has_episode = torch.zeros(env.num_envs, dtype=torch.bool, device=env.device)
        self._episodes_since_check: int = 0
        # How many episodes between frontier checks (controls reactivity).
        default_window = max(env.num_envs * 2, 500)
        self._check_interval: int = cfg.params.get("window_size", default_window)
        self._graduated: bool = False  # True once frontier=0.0 is mastered → use normal resets

        # Build a stochastic reset instance for use after graduation.
        scene_path = cfg.params.get("scene_path", "")
        library_dir = cfg.params.get("library_dir", "")
        if scene_path:
            # EventTermCfg is already imported at module level.
            stochastic_cfg = EventTermCfg(
                func=ResetScenarioStochastic,
                mode="reset",
                params={"scene_path": scene_path, "library_dir": library_dir},
            )
            self._stochastic_reset: ResetScenarioStochastic | None = ResetScenarioStochastic(stochastic_cfg, env)
            print("[ResetCurriculum] Stochastic reset ready for post-graduation")
        else:
            self._stochastic_reset = None
            print("[ResetCurriculum] WARNING: no scene_path param, graduation will use dataset initial states")

        print(f"[ResetCurriculum] check_interval={self._check_interval} (num_envs={env.num_envs})")

    def _load(self, path: str) -> tuple[dict, torch.Tensor, int]:
        """Load all demo timesteps with progress labels.

        Demos without a ``states`` group, without any leaf dataset, or with a
        zero-length trajectory are skipped so that the progress labels stay
        aligned one-to-one with the stacked state rows.

        Args:
            path: Path to the HDF5 file (expects ``data/<demo>/states/...``).

        Returns:
            (batched, progress, num_states) where batched is
            {entity_type: {entity_name: {component: (N, dim) tensor}}}
            and progress is a (N,) tensor of normalized demo progress [0, 1].

        Raises:
            ValueError: If the dataset contains no usable states.
        """
        all_arrays: dict[str, dict[str, dict[str, list]]] = {}
        all_progress: list[np.ndarray] = []
        num_states = 0

        with h5py.File(path, "r") as f:
            for demo_key in sorted(f["data"].keys()):
                states_path = f"data/{demo_key}/states"
                if states_path not in f:
                    continue
                sg = f[states_path]

                # Find trajectory length from the first leaf dataset
                sample_ds = next(
                    (
                        sg[etype][ename][comp]
                        for etype in sg.keys()
                        for ename in sg[etype].keys()
                        for comp in sg[etype][ename].keys()
                    ),
                    None,
                )
                if sample_ds is None:
                    continue

                T = sample_ds.shape[0]
                if T == 0:
                    # An empty demo would add a progress label with no
                    # matching state row and shift every later index.
                    continue
                # Normalized progress: 0.0 = start, 1.0 = end of demo
                progress = np.linspace(0.0, 1.0, T) if T > 1 else np.array([1.0])
                all_progress.append(progress)

                for entity_type in sg.keys():
                    type_bucket = all_arrays.setdefault(entity_type, {})
                    for entity_name in sg[entity_type].keys():
                        entity_bucket = type_bucket.setdefault(entity_name, {})
                        for component in sg[entity_type][entity_name].keys():
                            entity_bucket.setdefault(component, []).append(
                                sg[entity_type][entity_name][component][:]
                            )
                num_states += T

        if num_states == 0:
            # Fail early with a readable message instead of an opaque
            # np.concatenate error.
            raise ValueError(f"[ResetCurriculum] No states found in dataset: {path}")

        # Stack and move to the env device
        batched: dict[str, dict[str, dict[str, torch.Tensor]]] = {
            etype: {
                ename: {
                    comp: torch.from_numpy(np.concatenate(chunks, axis=0)).float().to(self._device)
                    for comp, chunks in comps.items()
                }
                for ename, comps in entities.items()
            }
            for etype, entities in all_arrays.items()
        }

        progress_tensor = torch.from_numpy(
            np.concatenate(all_progress, axis=0)
        ).float().to(self._device)

        return batched, progress_tensor, num_states

    @staticmethod
    def _log_wandb(metrics: dict[str, float]) -> None:
        """Log ``metrics`` to wandb if it is importable and a run is active.

        Deliberately best-effort: any failure (wandb missing, no run, logging
        error) is swallowed so that logging can never break an env reset.
        """
        try:
            import wandb
            if wandb.run is not None:
                wandb.log(metrics, commit=False)
        except Exception:
            pass

    def _reset_tracking(self):
        """Clear per-env tracking (call when frontier or graduation state changes)."""
        self._per_env_has_episode[:] = False
        self._episodes_since_check = 0

    def _try_update_frontier(self):
        """Check success rate from per-env outcomes and update frontier/graduation.

        Only acts when enough episodes have been collected and all envs have
        completed at least one episode at the current setting.
        """
        if self._episodes_since_check < self._check_interval:
            return
        if not self._per_env_has_episode.all():
            return

        rate = self._per_env_success.float().mean().item()
        self._episodes_since_check = 0  # reset counter, keep per-env data
        old = self.frontier

        if rate >= self.advance_threshold:
            if self.frontier <= self.min_frontier:
                # Already at the earliest frontier: graduate (once).
                if not self._graduated:
                    self._graduated = True
                    self._reset_tracking()
                    print(
                        f"[ResetCurriculum] success_rate={rate:.2f}, "
                        f"frontier={self.frontier:.2f} mastered — switching to normal resets"
                    )
                    # frontier=-1.0 is the logged marker for "graduated".
                    self._log_wandb({
                        "curriculum/frontier": -1.0,
                        "curriculum/success_rate": rate,
                        "curriculum/graduated": 1,
                    })
                return
            self.frontier = max(self.frontier - self.step_size, self.min_frontier)
        elif rate < self.retreat_threshold:
            if self._graduated:
                # Normal resets turned out too hard: resume dataset resets.
                self._graduated = False
                self._reset_tracking()
                print(
                    f"[ResetCurriculum] success_rate={rate:.2f}, "
                    f"un-graduating — resuming dataset resets at frontier={self.frontier:.2f}"
                )
                self._log_wandb({
                    "curriculum/frontier": self.frontier,
                    "curriculum/success_rate": rate,
                    "curriculum/graduated": 0,
                })
                return
            self.frontier = min(self.frontier + self.step_size, 1.0)

        if self.frontier != old:
            self._reset_tracking()
            print(
                f"[ResetCurriculum] success_rate={rate:.2f}, "
                f"frontier {old:.2f} -> {self.frontier:.2f}"
            )
            self._log_wandb({
                "curriculum/frontier": self.frontier,
                "curriculum/success_rate": rate,
            })
        else:
            # Log rate even when frontier doesn't change
            self._log_wandb({
                "curriculum/frontier": -1.0 if self._graduated else self.frontier,
                "curriculum/success_rate": rate,
            })

    def _check_success_and_update(self, env: ManagerBasedEnv, env_ids: torch.Tensor):
        """Check if resetting envs succeeded and update per-env tracking.

        This is called from ``__call__`` which runs inside ``_reset_idx``, right
        after ``termination_manager.compute()`` — so ``get_term`` is fresh.

        Each env gets exactly one vote (its most recent episode outcome),
        so short successful episodes don't inflate the measured success rate.
        If the termination manager has no ``"success"`` term, tracking is
        skipped and the frontier is never updated.
        """
        try:
            success_mask = env.termination_manager.get_term("success")  # (num_envs,) bool
        except (ValueError, KeyError):
            return

        self._per_env_success[env_ids] = success_mask[env_ids]
        self._per_env_has_episode[env_ids] = True
        self._episodes_since_check += len(env_ids)

        self._try_update_frontier()

    def __call__(
        self,
        env: ManagerBasedEnv,
        env_ids: torch.Tensor | None,
        dataset_path: str | None = None,
        scene_path: str | None = None,
        library_dir: str | None = None,
        initial_frontier: float = 0.8,
        advance_threshold: float = 0.7,
        retreat_threshold: float = 0.3,
        step_size: float = 0.05,
        window_size: int = 200,
        min_frontier: float = 0.0,
    ) -> None:
        """Record episode outcomes for ``env_ids`` and reset them per the curriculum.

        Args:
            env: The environment whose scene is reset.
            env_ids: Indices of the envs to reset; ``None`` means all envs.

        All remaining keyword arguments are ignored here; their values are
        read once from ``cfg.params`` in ``__init__``. They are accepted so
        that every documented param key (including ``min_frontier``) can be
        passed as a keyword argument without raising — presumably the event
        manager forwards ``cfg.params`` as kwargs; verify against the manager.
        """
        if env_ids is None:
            env_ids = torch.arange(env.scene.num_envs, device=env.device)

        # Update curriculum based on whether these envs succeeded
        self._check_success_and_update(env, env_ids)

        n = len(env_ids)

        if self._graduated:
            # Graduated: use the full stochastic reset (same as default env reset)
            if self._stochastic_reset is not None:
                self._stochastic_reset(env, env_ids)
                return
            # Fallback if no stochastic reset: sample from dataset initial states
            initial_mask = self._progress <= 0.0
            initial_indices = torch.where(initial_mask)[0]
            if len(initial_indices) == 0:
                # No state labelled progress 0: use the lowest-progress states instead.
                _, initial_indices = torch.topk(self._progress, min(n, self._num_states), largest=False)
            pick = torch.randint(0, len(initial_indices), (n,), device=self._device)
            idx = initial_indices[pick]
        else:
            # Filter to states at or beyond the current frontier
            valid_mask = self._progress >= self.frontier
            valid_indices = torch.where(valid_mask)[0]

            if len(valid_indices) == 0:
                # Fallback: sample from everything
                idx = torch.randint(0, self._num_states, (n,), device=self._device)
            else:
                # Sort valid indices by progress so index 0 = frontier (hardest)
                valid_progress = self._progress[valid_indices]
                sorted_order = torch.argsort(valid_progress)
                valid_indices = valid_indices[sorted_order]

                # Geometric distribution: P(k) ∝ (1-p)^k, concentrates mass near k=0 (frontier)
                # p=0.05 → ~50% of samples from the first ~14 entries, gentle falloff
                geom_p = 0.05
                k = len(valid_indices)
                probs = torch.pow(1.0 - geom_p, torch.arange(k, device=self._device, dtype=torch.float32))
                pick = torch.multinomial(probs, n, replacement=True)
                idx = valid_indices[pick]

        sampled_state: dict[str, dict[str, dict[str, torch.Tensor]]] = {
            etype: {
                ename: {comp: tensor[idx] for comp, tensor in comps.items()}
                for ename, comps in entities.items()
            }
            for etype, entities in self._batched.items()
        }

        env.scene.reset_to(sampled_state, env_ids=env_ids, is_relative=True)
