from pxr import Usd, UsdGeom
from pxr import Gf
from omni.usd import get_context
import numpy as np
import torch

import isaacsim.core.utils.bounds as bounds_utils
import sim_improvement.environments.mdp as mdp
from isaaclab.envs.manager_based_rl_env import ManagerBasedRLEnv
from isaaclab.managers import SceneEntityCfg, RewardTermCfg, ManagerTermBase
from isaaclab.assets import RigidObject
from isaaclab.assets import Articulation
from isaaclab.utils import math as math_utils
from isaaclab.sensors import FrameTransformer


def _check_point_in_obb(points, centroids, axes, half_extents, check_axes=(0, 1, 2)) -> torch.Tensor:
    """Check if points are inside Oriented Bounding Boxes along specified axes.

    Args:
        points: Points to test (num_envs, 3) - torch tensor on GPU
        centroids: Centers of OBBs (num_envs, 3) - torch tensor on GPU
        axes: Orientation axes of OBBs (num_envs, 3, 3) - torch tensor on GPU
        half_extents: Half extents of OBB (3,) - torch tensor on GPU
        check_axes: Tuple of OBB axis indices to check. Default (0, 1, 2) checks all three.
            E.g. (0, 1) to only check the first two axes (xy plane of the OBB).

    Returns:
        torch.Tensor: Boolean tensor (num_envs,) - True if point is inside OBB on checked axes
    """
    d = points - centroids  # (num_envs, 3)
    projections = torch.abs(torch.bmm(d.unsqueeze(1), axes).squeeze(1))  # (num_envs, 3)
    check_axes = list(check_axes)
    return (projections[:, check_axes] <= half_extents[check_axes].unsqueeze(0)).all(dim=1)  # (num_envs,)


class point_in_obb_reward(ManagerTermBase):
    """Reward term: 1.0 where the insertive object's root lies inside the receptacle's OBB, else 0.0.

    At construction, the receptacle's oriented bounding box (OBB) is computed once from USD. The prim
    path used has its first ``.*`` replaced by ``0`` (presumably environment 0). The box is then
    converted into the receptacle's body frame. Each call re-poses that cached box with every
    environment's current receptacle root pose and tests the insertive object's root position
    against it.

    NOTE(review): ``__init__`` reads the ``insertive_asset_cfg`` / ``receptive_asset_cfg`` params.
    ``__call__``, however, names its parameters ``insertive_object_cfg`` / ``receptacle_object_cfg``
    and ignores them in favour of the assets cached here. The visible caller (``ProgressContext``)
    passes them positionally, but a caller passing the params by keyword would not match. Confirm
    before registering this class directly as a reward term.
    """

    def __init__(self, cfg: RewardTermCfg, env: ManagerBasedRLEnv):
        super().__init__(cfg, env)

        self._env = env

        params = cfg.params
        self.insertive_object_cfg = params.get("insertive_asset_cfg")
        self.receptacle_object_cfg = params.get("receptive_asset_cfg")
        scene = env.scene
        self.insertive_object = scene[self.insertive_object_cfg.name]
        self.receptacle_object = scene[self.receptacle_object_cfg.name]

        # A single bbox cache is created here and used for the one-off OBB computation below.
        self._bbox_cache = bounds_utils.create_bbox_cache()
        self._compute_receptacle_obb()

    def _compute_receptacle_obb(self):
        """Compute the receptacle OBB from USD and cache it in the receptacle's body frame.

        Sets ``_receptacle_obb_centroid`` (3,), ``_receptacle_obb_axes`` (3, 3) and
        ``_receptacle_obb_half_extents`` (3,) as float32 tensors on the environment device.
        """
        receptacle = self.receptacle_object
        prim_path = receptacle.cfg.prim_path.replace(".*", "0", 1)
        obb_center_w, obb_axes_w, obb_half_extents = bounds_utils.compute_obb(self._bbox_cache, prim_path)

        root_pos_w = receptacle.data.root_pos_w[0]
        root_quat_w = receptacle.data.root_quat_w[0]
        device = self._env.device

        center_w = torch.tensor(obb_center_w, device=device, dtype=torch.float32)
        axes_w = torch.tensor(obb_axes_w, device=device, dtype=torch.float32)
        half_extents = torch.tensor(obb_half_extents, device=device, dtype=torch.float32)

        # Centroid: offset from the root, rotated by the inverse of the root orientation.
        self._receptacle_obb_centroid = math_utils.quat_apply_inverse(root_quat_w, center_w - root_pos_w)

        # Axes: each row of ``axes_w`` is rotated by R^T, with R the root rotation matrix.
        rot_w = math_utils.matrix_from_quat(root_quat_w.unsqueeze(0))[0]
        self._receptacle_obb_axes = torch.matmul(rot_w.T, axes_w.T).T
        self._receptacle_obb_half_extents = half_extents

    def _receptacle_obb_in_world(self, num_envs: int) -> tuple[torch.Tensor, torch.Tensor]:
        """Pose the cached body-frame OBB with each environment's current receptacle root pose.

        Returns:
            World-frame centroids (num_envs, 3) and axes (num_envs, 3, 3), one axis per row.
        """
        root_pos_w = self.receptacle_object.data.root_pos_w  # (num_envs, 3)
        root_quat_w = self.receptacle_object.data.root_quat_w  # (num_envs, 4)

        local_centroid = self._receptacle_obb_centroid.unsqueeze(0).expand(num_envs, -1)
        centroids_w = root_pos_w + math_utils.quat_apply(root_quat_w, local_centroid)

        # Rotate each body-frame axis row by R: (R @ A_body^T)^T.
        local_axes_cols = self._receptacle_obb_axes.unsqueeze(0).expand(num_envs, -1, -1).transpose(1, 2)
        axes_w = torch.bmm(math_utils.matrix_from_quat(root_quat_w), local_axes_cols).transpose(1, 2)
        return centroids_w, axes_w

    def __call__(
        self,
        env: ManagerBasedRLEnv,
        insertive_object_cfg: SceneEntityCfg,
        receptacle_object_cfg: SceneEntityCfg,
        check_axes: tuple[int, ...] = (0, 1, 2),
    ) -> torch.Tensor:
        """Return a float (num_envs,) tensor: 1.0 inside the receptacle OBB, 0.0 outside.

        Args:
            env: The environment; only ``env.num_envs`` is read here.
            insertive_object_cfg: Not read; the asset cached at construction is used instead.
            receptacle_object_cfg: Not read; the asset cached at construction is used instead.
            check_axes: OBB axis indices to check. Default (0, 1, 2) checks all three.
                E.g. (0, 1) to only check the xy plane of the OBB.
        """
        probe_pos_w = self.insertive_object.data.root_pos_w  # (num_envs, 3)
        centroids_w, axes_w = self._receptacle_obb_in_world(env.num_envs)
        is_inside = _check_point_in_obb(
            probe_pos_w,
            centroids_w,
            axes_w,
            self._receptacle_obb_half_extents,
            check_axes=check_axes,
        )
        return is_inside.float()


class ProgressContext(ManagerTermBase):
    """Context reward term that caches per-environment progress flags and contributes zero reward.

    Each call refreshes two boolean (num_envs,) buffers:
      * ``insertive_point_in_receptive_obb``: whether the insertive asset's root lies inside the
        receptive asset's OBB on ``obb_check_axes`` (default: the OBB's first two axes).
      * ``finger_open``: the result of ``gripper_open`` for ``robot_cfg``.

    The buffers are presumably read by other reward terms through
    ``env.reward_manager.get_term_cfg(<name>).func``.
    """

    def __init__(self, cfg: RewardTermCfg, env: ManagerBasedRLEnv):
        super().__init__(cfg, env)
        self.insertive_asset: Articulation | RigidObject = env.scene[cfg.params.get("insertive_asset_cfg").name]  # type: ignore
        self.receptive_asset: Articulation | RigidObject = env.scene[cfg.params.get("receptive_asset_cfg").name]  # type: ignore

        # Per-environment boolean flags, overwritten on every call.
        self.insertive_point_in_receptive_obb = torch.zeros((env.num_envs), device=env.device, dtype=torch.bool)
        self.finger_open = torch.zeros((env.num_envs), device=env.device, dtype=torch.bool)

        # Reuses this term's cfg; its params carry the asset cfg keys the OBB term reads.
        self._point_in_obb_reward = point_in_obb_reward(cfg, env)

    def reset(self, env_ids: torch.Tensor | None = None) -> None:
        """Clear the cached flags for ``env_ids`` only (all environments when ``None``)."""
        super().reset(env_ids)

        # Clear only the environments being reset. Clearing every row on a partial reset would
        # wipe the flags of environments that are still mid-episode.
        if env_ids is None:
            env_ids = slice(None)
        self.insertive_point_in_receptive_obb[env_ids] = False
        self.finger_open[env_ids] = False

    def __call__(
        self,
        env: ManagerBasedRLEnv,
        insertive_asset_cfg: SceneEntityCfg,
        receptive_asset_cfg: SceneEntityCfg,
        obb_check_axes: tuple[int, ...] = (0, 1),
        open_finger_threshold: float = 0.07,
        robot_cfg: SceneEntityCfg = SceneEntityCfg("robot"),
    ) -> torch.Tensor:
        """Refresh the cached flags and return a zero reward of shape (num_envs,).

        Args:
            env: The environment.
            insertive_asset_cfg: Forwarded positionally to the OBB term.
            receptive_asset_cfg: Forwarded positionally to the OBB term.
            obb_check_axes: OBB axis indices checked for containment.
            open_finger_threshold: Threshold forwarded to ``gripper_open``.
            robot_cfg: Robot whose gripper state is sampled.
        """
        point_in_xy_obb = self._point_in_obb_reward(env, insertive_asset_cfg, receptive_asset_cfg, check_axes=obb_check_axes)
        # The OBB term returns a float 0/1 tensor; assigning into the bool buffer casts it.
        self.insertive_point_in_receptive_obb[:] = point_in_xy_obb

        self.finger_open[:] = gripper_open(env, robot_cfg=robot_cfg, open_finger_threshold=open_finger_threshold)

        return torch.zeros(env.num_envs, device=env.device)

def dense_ee_to_object_distance(env: ManagerBasedRLEnv, context: str = "progress_context", std: float = 1.0) -> torch.Tensor:
    """Tanh-kernel reward ``1 - tanh(d / std)`` on a distance cached by a context term.

    Args:
        env: Environment whose reward manager holds the context term.
        context: Name of the reward term whose ``func`` exposes an ``ee_to_object_distance`` tensor.
        std: Kernel width; larger values make the reward decay more slowly with distance.

    NOTE(review): ``ProgressContext`` in the visible code does not set ``ee_to_object_distance``.
    Confirm which context term provides it.
    """
    term_cfg = env.reward_manager.get_term_cfg(context)
    distance: torch.Tensor = term_cfg.func.ee_to_object_distance  # type: ignore
    return 1 - torch.tanh(distance / std)

def dense_lift_height(
    env: ManagerBasedRLEnv,
    context: str = "progress_context",
    threshold: float = 0.04,
    ee_gate_distance: float = 0.05,
) -> torch.Tensor:
    """Boolean reward: object lifted above ``threshold`` while the EE is near the object.

    Args:
        env: Environment whose reward manager holds the context term.
        context: Name of the reward term whose ``func`` exposes ``lift_height`` and
            ``ee_to_object_distance`` tensors.
        threshold: Lift height that must be strictly exceeded.
        ee_gate_distance: The EE-to-object distance must be strictly below this value for the lift
            to count. Defaults to 0.05.

    Returns:
        Boolean tensor, True where both conditions hold.
    """
    context_term: ManagerTermBase = env.reward_manager.get_term_cfg(context).func  # type: ignore
    lift_height: torch.Tensor = getattr(context_term, "lift_height")
    # Only count a lift while the end effector is close to the object.
    ee_close: torch.Tensor = getattr(context_term, "ee_to_object_distance") < ee_gate_distance
    return (lift_height > threshold) & ee_close

def dense_object_to_goal_distance(
    env: ManagerBasedRLEnv,
    context: str = "progress_context",
    std: float = 1.0,
    lift_threshold: float = 0.04,
) -> torch.Tensor:
    """Tanh-kernel reward on the cached object-to-goal distance, gated on a successful lift.

    Returns ``1 - tanh(object_to_goal_distance / std)`` where ``dense_lift_height`` (evaluated with
    ``threshold=lift_threshold``) is True, and 0.0 elsewhere.

    Args:
        env: Environment whose reward manager holds the context term.
        context: Name of the reward term whose ``func`` exposes ``object_to_goal_distance``.
        std: Kernel width of the tanh reward.
        lift_threshold: Lift height forwarded to ``dense_lift_height`` for the gate. Defaults to 0.04.
    """
    context_term: ManagerTermBase = env.reward_manager.get_term_cfg(context).func  # type: ignore
    object_to_goal_distance: torch.Tensor = getattr(context_term, "object_to_goal_distance")
    lift_gate = dense_lift_height(env, context, threshold=lift_threshold)
    return torch.where(lift_gate, 1 - torch.tanh(object_to_goal_distance / std), 0.0)

def gripper_open(
    env: ManagerBasedRLEnv,
    robot_cfg: SceneEntityCfg = SceneEntityCfg("robot"),
    open_finger_threshold: float = 0.02,
    joint_name: str = "finger_joint",
) -> torch.Tensor:
    """Return a boolean (num_envs,) tensor that is True where the gripper counts as open.

    The gripper is treated as open when the position of ``joint_name`` is strictly below
    ``open_finger_threshold``. NOTE(review): this presumes a gripper whose joint value increases as
    it closes; confirm for the robot in use.

    Args:
        env: The environment.
        robot_cfg: Scene entity of the robot articulation.
        open_finger_threshold: Joint position below which the gripper is considered open.
        joint_name: Name of the gripper joint to sample. Defaults to ``"finger_joint"``.

    Raises:
        ValueError: If ``joint_name`` is not in ``robot.data.joint_names``.
    """
    robot: Articulation = env.scene[robot_cfg.name]
    joint_idx = robot.data.joint_names.index(joint_name)
    finger_pos = robot.data.joint_pos[:, joint_idx]
    # No per-step printing: printing a device tensor every step forces a sync and floods the logs.
    return finger_pos < open_finger_threshold

def lift_reward(
    env: ManagerBasedRLEnv,
    object_cfg: SceneEntityCfg = SceneEntityCfg("object"),
) -> torch.Tensor:
    """Reward ``tanh(h)``, where ``h`` is the object's height above its default root height.

    ``root_pos_w`` is in the world frame, while ``default_root_state`` is expressed relative to each
    environment's origin. The origin's z is therefore added to the default height before taking the
    difference. The reward is negative when the object is below its default height.

    NOTE(review): if reset events randomize the spawn height, the default state is not the actual
    episode start height; confirm this reference is intended.
    """
    obj: RigidObject = env.scene[object_cfg.name]
    current_z = obj.data.root_pos_w[:, 2]
    default_z = obj.data.default_root_state[:, 2] + env.scene.env_origins[:, 2]
    return torch.tanh(current_z - default_z)

def object_to_object_distance(
    env: ManagerBasedRLEnv,
    std: float,
    object1_cfg: SceneEntityCfg = SceneEntityCfg("object1"),
    object2_cfg: SceneEntityCfg = SceneEntityCfg("object2"),
) -> torch.Tensor:
    """Tanh-kernel reward on the Euclidean distance between two objects' root positions.

    Returns ``1 - tanh(||p1 - p2|| / std)`` per environment. The value is 1.0 when the roots
    coincide and decays toward 0.0 as they separate.
    """
    first: RigidObject = env.scene[object1_cfg.name]
    second: RigidObject = env.scene[object2_cfg.name]
    gap = first.data.root_pos_w - second.data.root_pos_w
    return 1 - torch.tanh(gap.norm(dim=1) / std)

def object_ee_distance(
    env: ManagerBasedRLEnv,
    std: float,
    object_cfg: SceneEntityCfg = SceneEntityCfg("object"),
    ee_frame_cfg: SceneEntityCfg = SceneEntityCfg("ee_frame"),
) -> torch.Tensor:
    """Reward the agent for reaching the object using tanh-kernel.

    Computes ``1 - tanh(d / std)``. Here ``d`` is the distance from the object's root to the first
    target frame tracked by the ``ee_frame_cfg`` frame transformer.
    """
    target: RigidObject = env.scene[object_cfg.name]
    ee_sensor: FrameTransformer = env.scene[ee_frame_cfg.name]
    target_pos_w = target.data.root_pos_w  # (num_envs, 3)
    ee_pos_w = ee_sensor.data.target_pos_w[..., 0, :]  # first target frame
    reach_dist = (target_pos_w - ee_pos_w).norm(dim=1)  # (num_envs,)
    return 1 - torch.tanh(reach_dist / std)


def ee_to_object_distance(
    env: ManagerBasedRLEnv,
    std: float,
    object_cfg: SceneEntityCfg,
    left_arm_cfg: SceneEntityCfg = SceneEntityCfg("left_panda"),
    right_arm_cfg: SceneEntityCfg = SceneEntityCfg("right_panda"),
    link_name: str = "panda_link8",
) -> torch.Tensor:
    """Tanh-kernel reward for whichever arm's ``link_name`` body is closer to the object.

    Takes the element-wise minimum of the two arms' link-to-object distances, so either arm can be
    guided to the object. Returns ``1 - tanh(min_dist / std)``.
    """
    obj_pos_w = env.scene[object_cfg.name].data.root_pos_w  # (num_envs, 3)

    def _link_pos(arm_cfg: SceneEntityCfg) -> torch.Tensor:
        # World position (num_envs, 3) of ``link_name`` on the given arm.
        arm: Articulation = env.scene[arm_cfg.name]
        return arm.data.body_link_pos_w[:, arm.data.body_names.index(link_name)]

    left_dist, right_dist = (
        torch.norm(obj_pos_w - _link_pos(arm_cfg), dim=1) for arm_cfg in (left_arm_cfg, right_arm_cfg)
    )
    closest = torch.minimum(left_dist, right_dist)
    return 1 - torch.tanh(closest / std)

def action_l2_clamped(env: ManagerBasedRLEnv) -> torch.Tensor:
    """Penalize the actions using L2 squared kernel, clamped to ``[0, 1e4]``.

    Returns the per-environment sum of squared action components. The upper clamp presumably
    guards against a diverging policy producing an unbounded penalty.
    """
    squared_norm = env.action_manager.action.square().sum(dim=1)
    return squared_norm.clamp(0, 1e4)


def action_rate_l2_clamped(env: ManagerBasedRLEnv) -> torch.Tensor:
    """Penalize the rate of change of the actions using L2 squared kernel, clamped to ``[0, 1e4]``.

    Returns the per-environment sum of squared differences between the current and previous action.
    """
    delta = env.action_manager.action - env.action_manager.prev_action
    return delta.square().sum(dim=1).clamp(0, 1e4)


def joint_vel_l2_clamped(env: ManagerBasedRLEnv, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")) -> torch.Tensor:
    """Penalize joint velocities on the articulation using L2 squared kernel, clamped to ``[0, 1e4]``.

    NOTE: Only the joints configured in :attr:`asset_cfg.joint_ids` will have their joint velocities contribute to the term.
    """
    asset: Articulation = env.scene[asset_cfg.name]
    selected_vel = asset.data.joint_vel[:, asset_cfg.joint_ids]
    return selected_vel.square().sum(dim=1).clamp(0, 1e4)


def success_reward(
    env: ManagerBasedRLEnv,
    termination_term_name: str = "success",
) -> torch.Tensor:
    """Return 1.0 where the named termination condition holds and 0.0 elsewhere.

    The termination term is re-evaluated with its configured ``params``. Terms that take arguments
    beyond ``env`` therefore evaluate the same way as when called as ``func(env, **params)``
    (presumably how the termination manager calls them).

    Args:
        env: The environment.
        termination_term_name: Name of the termination term to evaluate.
    """
    term_cfg = env.termination_manager.get_term_cfg(termination_term_name)
    # Forward the configured params; calling ``func(env)`` alone would silently drop them.
    success = term_cfg.func(env, **term_cfg.params)
    return success.float()


def _min_singular_value(
    asset: Articulation,
    body_name: str = "panda_link8",
    joint_names: list[str] | None = None,
) -> torch.Tensor:
    """Compute the minimum singular value of the Jacobian of ``body_name``.

    Args:
        asset: Articulation whose PhysX Jacobians are read.
        body_name: Body to analyse; the first match of ``asset.find_bodies`` is used.
        joint_names: Joints whose Jacobian columns are used; all joints when falsy.

    Returns:
        (num_envs,) tensor of minimum singular values. Near 0 = near singularity.
    """
    body_idx = asset.find_bodies(body_name)[0][0]
    joint_ids = asset.find_joints(joint_names)[0] if joint_names else list(range(asset.num_joints))

    # Fixed- and floating-base Jacobians are indexed differently: a fixed base shifts the body
    # index down by one, while a floating base offsets joint columns by 6 (presumably the root DOFs).
    if asset.is_fixed_base:
        jac_body = body_idx - 1
        jac_joints = joint_ids
    else:
        jac_body = body_idx
        jac_joints = [j + 6 for j in joint_ids]

    jacobians = asset.root_physx_view.get_jacobians()[:, jac_body, :, jac_joints]  # (N, 6, num_joints)
    # svdvals returns singular values in descending order, so the last column is the minimum.
    return torch.linalg.svdvals(jacobians)[:, -1]


def singularity_penalty(
    env: ManagerBasedRLEnv,
    asset_cfg: SceneEntityCfg,
    threshold: float = 0.05,
    body_name: str = "panda_link8",
    joint_names: list[str] | None = None,
) -> torch.Tensor:
    """Indicator that an arm is approaching a kinematic singularity.

    Returns 1.0 for envs where the minimum singular value of the ``body_name`` Jacobian is below
    ``threshold``, and 0.0 otherwise. The value is non-negative, so give this term a (large)
    negative weight in the reward config to turn it into a penalty.

    Args:
        env: The environment.
        asset_cfg: Scene entity of the articulation to check.
        threshold: Minimum singular value below which the arm counts as near-singular.
        body_name: Body whose Jacobian is analysed. Defaults to ``"panda_link8"``.
        joint_names: Joint names as accepted by ``asset.find_joints``; ``None`` uses all joints.
    """
    asset: Articulation = env.scene[asset_cfg.name]
    min_sv = _min_singular_value(asset, body_name=body_name, joint_names=joint_names)
    return (min_sv < threshold).float()