import torch
from isaaclab.envs.manager_based_rl_env import ManagerBasedRLEnv, ManagerBasedEnv
from isaaclab.assets import Articulation
from isaaclab.managers import SceneEntityCfg
from isaaclab.assets import RigidObject
from isaaclab.utils import math as math_utils

def ee_pose_obs(env) -> torch.Tensor:
    """Return the end-effector pose as one concatenated tensor.

    Reads the ``"ee_frame"`` scene entity and uses only its first target
    (index 0 along the second axis of the target buffers).

    Args:
        env: Environment whose ``scene`` contains an ``"ee_frame"`` entity
            exposing ``data.target_pos_w`` and ``data.target_quat_w``.

    Returns:
        The first target's world-frame position followed by its world-frame
        quaternion, concatenated along the last dimension.
        NOTE(review): presumably 3 + 4 = 7 values per environment; confirm
        against the frame sensor's buffers.
    """
    frame_data = env.scene["ee_frame"].data
    first_target = 0
    # Both values stay in the world frame; env origins are not subtracted here.
    position_w = frame_data.target_pos_w[:, first_target]
    orientation_w = frame_data.target_quat_w[:, first_target]
    return torch.cat((position_w, orientation_w), dim=-1)

def selected_joint_pos(
    env: ManagerBasedEnv,
    asset_cfg: SceneEntityCfg,
    joint_names: list[str],
):
    """Return the positions of the named joints, in the order given.

    Args:
        env: Environment holding the asset in ``env.scene``.
        asset_cfg: Config whose ``name`` selects the asset.
        joint_names: Joint names to select. The output columns follow this order.

    Returns:
        ``asset.data.joint_pos`` restricted to the requested joints and viewed
        as ``(env.num_envs, -1)``.

    Raises:
        ValueError: If a requested name is not in ``asset.data.joint_names``
            (raised by ``list.index``).
    """
    asset: RigidObject | Articulation = env.scene[asset_cfg.name]
    known_names = asset.data.joint_names
    # Resolve the names to indices on every call; list.index raises on unknown names.
    joint_ids = list(map(known_names.index, joint_names))
    selected = asset.data.joint_pos[:, joint_ids]
    return selected.view(env.num_envs, -1)

def link_pos(
    env: ManagerBasedEnv,
    root_asset_cfg: SceneEntityCfg,
    link_name: str,
):
    """Return a named link's position relative to each environment's origin.

    Args:
        env: Environment holding the articulation and ``scene.env_origins``.
        root_asset_cfg: Config whose ``name`` selects the articulation.
        link_name: Name of the body, looked up in ``data.body_names``.

    Returns:
        The link's world position (``body_link_pos_w``, viewed as ``(-1, 3)``)
        minus ``env.scene.env_origins``.

    Raises:
        ValueError: If ``link_name`` is not in ``data.body_names``.
    """
    articulation: Articulation = env.scene[root_asset_cfg.name]
    body_index = articulation.data.body_names.index(link_name)
    world_position = articulation.data.body_link_pos_w[:, body_index].view(-1, 3)
    # Shift from the world frame into the per-environment frame.
    return world_position - env.scene.env_origins


def link_quat(
    env: ManagerBasedEnv,
    root_asset_cfg: SceneEntityCfg,
    link_name: str,
):
    """Return a named link's world-frame orientation quaternion.

    Args:
        env: Environment holding the articulation.
        root_asset_cfg: Config whose ``name`` selects the articulation.
        link_name: Name of the body, looked up in ``data.body_names``.

    Returns:
        ``body_link_quat_w`` for that body, viewed as ``(-1, 4)``.

    Raises:
        ValueError: If ``link_name`` is not in ``data.body_names``.
    """
    articulation: Articulation = env.scene[root_asset_cfg.name]
    link_data = articulation.data
    body_index = link_data.body_names.index(link_name)
    orientation_w = link_data.body_link_quat_w[:, body_index]
    return orientation_w.view(-1, 4)



def target_asset_pose_in_root_asset_frame(
    env: ManagerBasedEnv,
    target_asset_cfg: SceneEntityCfg,
    root_asset_cfg: SceneEntityCfg = SceneEntityCfg("robot"),
    target_asset_offset=None,
    root_asset_offset=None,
    rotation_repr: str = "quat",
):
    """Return the pose of a target body expressed in the frame of a root body.

    Args:
        env: Environment holding both assets in ``env.scene``.
        target_asset_cfg: Selects the target asset. If ``body_ids`` is a slice,
            body 0 is used; otherwise ``body_ids`` indexes the bodies directly.
        root_asset_cfg: Selects the reference asset, with the same body rule.
        target_asset_offset: Optional object with ``combine(pos, quat)`` that
            returns an adjusted ``(pos, quat)`` for the target body.
            NOTE(review): presumably a fixed local offset; confirm its type.
        root_asset_offset: Same as ``target_asset_offset``, for the root body.
        rotation_repr: ``"quat"`` or ``"axis_angle"``.

    Returns:
        The relative position followed by the relative rotation (quaternion or
        axis-angle), concatenated along dim 1.

    Raises:
        ValueError: If ``rotation_repr`` is not one of the supported values.
    """

    def _world_pose(cfg):
        # One body per asset: a slice (the "all bodies" default) collapses to body 0.
        asset = env.scene[cfg.name]
        body = 0 if isinstance(cfg.body_ids, slice) else cfg.body_ids
        pos = asset.data.body_link_pos_w[:, body].view(-1, 3)
        quat = asset.data.body_link_quat_w[:, body].view(-1, 4)
        return pos, quat

    target_pos, target_quat = _world_pose(target_asset_cfg)
    root_pos, root_quat = _world_pose(root_asset_cfg)

    if root_asset_offset is not None:
        root_pos, root_quat = root_asset_offset.combine(root_pos, root_quat)
    if target_asset_offset is not None:
        target_pos, target_quat = target_asset_offset.combine(target_pos, target_quat)

    rel_pos, rel_quat = math_utils.subtract_frame_transforms(root_pos, root_quat, target_pos, target_quat)

    if rotation_repr == "quat":
        rotation = rel_quat
    elif rotation_repr == "axis_angle":
        rotation = math_utils.axis_angle_from_quat(rel_quat)
    else:
        raise ValueError(f"Invalid rotation_repr: {rotation_repr}. Must be one of: 'quat', 'axis_angle'")
    return torch.cat([rel_pos, rotation], dim=1)

def asset_link_velocity_in_root_asset_frame(
    env: ManagerBasedEnv,
    target_asset_cfg: SceneEntityCfg,
    root_asset_cfg: SceneEntityCfg = SceneEntityCfg("robot"),
):
    """Return a target body's world velocities rotated into the root asset's frame.

    The target body's world-frame linear and angular velocities are rotated by
    the inverse of the root asset's root orientation. Velocities are free
    vectors, so only the rotation is applied; no position is subtracted. This
    is not a relative velocity: the root asset's own motion is not removed.

    Args:
        env: Environment holding both assets in ``env.scene``.
        target_asset_cfg: Selects the target asset. If ``body_ids`` is a slice,
            body 0 is used; otherwise ``body_ids`` indexes the bodies directly.
        root_asset_cfg: Selects the asset whose ``root_quat_w`` defines the frame.

    Returns:
        The rotated linear velocity followed by the rotated angular velocity,
        concatenated along dim 1.
    """
    target_asset: RigidObject | Articulation = env.scene[target_asset_cfg.name]
    root_asset: RigidObject | Articulation = env.scene[root_asset_cfg.name]

    target_body_idx = 0 if isinstance(target_asset_cfg.body_ids, slice) else target_asset_cfg.body_ids

    lin_vel_w = target_asset.data.body_lin_vel_w[:, target_body_idx].view(-1, 3)
    # Previously this read body_lin_vel_w, which duplicated the linear velocity.
    ang_vel_w = target_asset.data.body_ang_vel_w[:, target_body_idx].view(-1, 3)

    # Rotate world -> root frame. subtract_frame_transforms would also subtract
    # the root position, which is meaningless for velocity vectors.
    root_quat_inv = math_utils.quat_inv(root_asset.data.root_quat_w)
    asset_lin_vel_b = math_utils.quat_apply(root_quat_inv, lin_vel_w)
    asset_ang_vel_b = math_utils.quat_apply(root_quat_inv, ang_vel_w)

    return torch.cat([asset_lin_vel_b, asset_ang_vel_b], dim=1)

def get_material_properties(
    env: ManagerBasedRLEnv,
    asset_cfg: SceneEntityCfg,
):
    """Return the asset's PhysX material properties, flattened per environment.

    Args:
        env: Environment holding the asset and exposing ``num_envs``/``device``.
        asset_cfg: Config whose ``name`` selects the asset.

    Returns:
        ``root_physx_view.get_material_properties()`` viewed as
        ``(env.num_envs, -1)`` and placed on ``env.device``.
    """
    asset: RigidObject | Articulation = env.scene[asset_cfg.name]
    materials = asset.root_physx_view.get_material_properties()
    # PhysX view material buffers can live on the CPU (Isaac Lab's material
    # randomization indexes them with CPU env ids). Move the result to the env
    # device so it can be concatenated with the other observation terms.
    # .to() is a no-op when the tensor is already there.
    return materials.view(env.num_envs, -1).to(env.device)


def get_mass(
    env: ManagerBasedRLEnv,
    asset_cfg: SceneEntityCfg,
):
    """Return the PhysX body masses of an asset, flattened per environment.

    Args:
        env: Environment holding the asset and exposing ``num_envs``/``device``.
        asset_cfg: Config whose ``name`` selects the asset.

    Returns:
        ``root_physx_view.get_masses()`` viewed as ``(env.num_envs, -1)`` and
        placed on ``env.device``.
    """
    asset: RigidObject | Articulation = env.scene[asset_cfg.name]
    masses = asset.root_physx_view.get_masses()
    # PhysX view mass buffers can live on the CPU (Isaac Lab's mass
    # randomization indexes them with CPU env ids). Move the result to the env
    # device so the observation manager can concatenate it with GPU terms.
    return masses.view(env.num_envs, -1).to(env.device)


def get_joint_friction(
    env: ManagerBasedRLEnv,
    asset_cfg: SceneEntityCfg,
):
    """Return ``data.joint_friction_coeff`` of an asset, viewed as ``(num_envs, -1)``.

    Args:
        env: Environment holding the asset.
        asset_cfg: Config whose ``name`` selects the asset. All joints are
            returned; no joint subset is applied.
    """
    target: RigidObject | Articulation = env.scene[asset_cfg.name]
    friction = target.data.joint_friction_coeff
    return friction.view(env.num_envs, -1)


def get_joint_armature(
    env: ManagerBasedRLEnv,
    asset_cfg: SceneEntityCfg,
):
    """Return ``data.joint_armature`` of an asset, viewed as ``(num_envs, -1)``.

    Args:
        env: Environment holding the asset.
        asset_cfg: Config whose ``name`` selects the asset. All joints are
            returned; no joint subset is applied.
    """
    target: RigidObject | Articulation = env.scene[asset_cfg.name]
    armature = target.data.joint_armature
    return armature.view(env.num_envs, -1)


def get_joint_stiffness(
    env: ManagerBasedRLEnv,
    asset_cfg: SceneEntityCfg,
):
    """Return ``data.joint_stiffness`` of an asset, viewed as ``(num_envs, -1)``.

    Args:
        env: Environment holding the asset.
        asset_cfg: Config whose ``name`` selects the asset. All joints are
            returned; no joint subset is applied.
    """
    target: RigidObject | Articulation = env.scene[asset_cfg.name]
    stiffness = target.data.joint_stiffness
    return stiffness.view(env.num_envs, -1)


def get_joint_damping(
    env: ManagerBasedRLEnv,
    asset_cfg: SceneEntityCfg,
):
    """Return ``data.joint_damping`` of an asset, viewed as ``(num_envs, -1)``.

    Args:
        env: Environment holding the asset.
        asset_cfg: Config whose ``name`` selects the asset. All joints are
            returned; no joint subset is applied.
    """
    target: RigidObject | Articulation = env.scene[asset_cfg.name]
    damping = target.data.joint_damping
    return damping.view(env.num_envs, -1)


def object_pos(
    env: ManagerBasedEnv,
    asset_cfg: SceneEntityCfg,
) -> torch.Tensor:
    """Return a rigid object's root position relative to its env origin — (num_envs, 3).

    Args:
        env: Environment holding the object and ``scene.env_origins``.
        asset_cfg: Config whose ``name`` selects the object.
    """
    rigid_body: RigidObject = env.scene[asset_cfg.name]
    world_position = rigid_body.data.root_pos_w
    origins = env.scene.env_origins
    return world_position - origins


def object_quat(
    env: ManagerBasedEnv,
    asset_cfg: SceneEntityCfg,
) -> torch.Tensor:
    """Return a rigid object's world-frame root quaternion (wxyz) — (num_envs, 4).

    Args:
        env: Environment holding the object.
        asset_cfg: Config whose ``name`` selects the object.
    """
    rigid_body: RigidObject = env.scene[asset_cfg.name]
    orientation_w = rigid_body.data.root_quat_w
    return orientation_w

