from __future__ import annotations

import torch
from typing import TYPE_CHECKING

import isaaclab.utils.math as math_utils
from isaaclab.envs.mdp import JointAction
from isaaclab.envs.mdp.actions.task_space_actions import DifferentialInverseKinematicsAction


if TYPE_CHECKING:
    from isaaclab.envs import ManagerBasedEnv

    from . import actions_cfg


class EnvFrameDiffIKAction(DifferentialInverseKinematicsAction):
    """DiffIK with commands in env frame (world-aligned axes, origin at env origin).

    The standard DiffIK operates in the root body frame, so if the robot base
    has a rotation (e.g. 90deg Z baked into the USD), the command axes rotate
    with it. This subclass computes the EE pose and Jacobian in the env frame
    instead, so action axes always align with the world/env coordinate system.

    With ``cfg.open_loop`` enabled, the IK pose error is measured against the
    last commanded end-effector pose instead of the simulated one, so the IK
    dead-reckons from its own commands. Each env re-seeds that pose from the
    simulation on the first :meth:`apply_actions` call after it is reset.

    NOTE(review): no singularity protection (K_VX gain scaling, Cartesian or
    joint velocity clamping) is applied by this class. ``cfg.max_joint_vel`` is
    converted to a per-physics-step joint delta (``_max_joint_delta``), but no
    clamp uses it, so it currently has no effect. See EnvFrameDiffIKActionCfg.
    """

    cfg: actions_cfg.EnvFrameDiffIKActionCfg

    def __init__(self, cfg: actions_cfg.EnvFrameDiffIKActionCfg, env: ManagerBasedEnv):
        super().__init__(cfg, env)
        self._dt = env.sim.get_physics_dt()
        # Max joint position change per physics step implied by the velocity limit
        # (scalar or per-joint sequence); None when no limit is configured.
        if cfg.max_joint_vel is not None:
            if isinstance(cfg.max_joint_vel, (int, float)):
                self._max_joint_delta = cfg.max_joint_vel * self._dt
            else:
                self._max_joint_delta = torch.tensor(cfg.max_joint_vel, device=self.device) * self._dt
        else:
            self._max_joint_delta = None

        # Open-loop mode: last commanded EE pose in the env frame. None until the
        # first apply_actions() call (or a full clear), which seeds it from the sim.
        self._open_loop_pos: torch.Tensor | None = None
        self._open_loop_quat: torch.Tensor | None = None
        # Per-env flags marking envs whose open-loop pose must be re-seeded from the
        # sim on the next apply_actions() call (set by partial resets).
        self._open_loop_reseed = torch.zeros(self.num_envs, dtype=torch.bool, device=self.device)

    def reset(self, env_ids=None):
        """Reset the term for ``env_ids`` (every env when ``None``).

        Only the reset envs drop their dead-reckoned open-loop pose; the other
        envs keep tracking their own commands instead of snapping back to the
        simulated pose.
        """
        super().reset(env_ids)
        if env_ids is None or self._open_loop_pos is None:
            # Full reset, or nothing tracked yet: re-seed every env lazily.
            self._open_loop_pos = None
            self._open_loop_quat = None
            self._open_loop_reseed[:] = False
        else:
            # Defer the re-seed to apply_actions(), where the sim pose is read
            # (same timing as the full-reset path above).
            self._open_loop_reseed[env_ids] = True

    def apply_actions(self):
        """Run one differential-IK update in the env frame and set the joint position target."""
        ee_pos_curr, ee_quat_curr = self._compute_frame_pose()
        joint_pos = self._asset.data.joint_pos[:, self._joint_ids]

        # Open-loop: use the last commanded pose as the "current" pose for
        # error computation, so the IK dead-reckons from its own commands.
        if self.cfg.open_loop:
            if self._open_loop_pos is None:
                # First step (or after a full clear): seed every env with the sim pose.
                self._open_loop_pos = ee_pos_curr.clone()
                self._open_loop_quat = ee_quat_curr.clone()
            else:
                # Re-seed only the envs flagged by a partial reset; torch.where keeps
                # this branch-free (no host sync on the mask).
                reseed = self._open_loop_reseed.unsqueeze(-1)
                self._open_loop_pos = torch.where(reseed, ee_pos_curr, self._open_loop_pos)
                self._open_loop_quat = torch.where(reseed, ee_quat_curr, self._open_loop_quat)
            self._open_loop_reseed[:] = False
            ee_pos_for_err = self._open_loop_pos
            ee_quat_for_err = self._open_loop_quat
        else:
            ee_pos_for_err = ee_pos_curr
            ee_quat_for_err = ee_quat_curr

        # Skip IK when the EE quaternions are all zero (norm over the whole batch),
        # presumably before the sim has produced a valid pose; hold the joints instead.
        if ee_quat_curr.norm() != 0:
            jacobian = self._compute_frame_jacobian()
            joint_pos_des = self._ik_controller.compute(ee_pos_for_err, ee_quat_for_err, jacobian, joint_pos)
        else:
            joint_pos_des = joint_pos

        # NOTE(review): joint-velocity clamping (limiting joint_pos_des - joint_pos to
        # +/- self._max_joint_delta) is not applied, so cfg.max_joint_vel has no effect.

        if self.cfg.open_loop:
            # Track the controller's target as the next "current" pose.
            # NOTE(review): until the controller target changes again, later calls see
            # zero pose error, so the IK result presumably collapses to the measured
            # joint_pos for the remaining physics steps of the policy step - confirm
            # this is the intended open-loop behavior.
            self._open_loop_pos = self._ik_controller.ee_pos_des.clone()
            self._open_loop_quat = self._ik_controller.ee_quat_des.clone()

        self._asset.set_joint_position_target(joint_pos_des, self._joint_ids)

    def _compute_frame_pose(self):
        """Return the target-frame pose relative to the env origin with world-aligned axes.

        Returns:
            Tuple ``(pos, quat)``: the body position minus the env origin, and the
            body world orientation, both composed with ``cfg.body_offset`` when set.
        """
        ee_pos_w = self._asset.data.body_pos_w[:, self._body_idx]
        ee_quat_w = self._asset.data.body_quat_w[:, self._body_idx]
        # Pose relative to env origin with world-aligned axes
        ee_pos_env = ee_pos_w - self._env.scene.env_origins
        ee_quat_env = ee_quat_w
        if self.cfg.body_offset is not None:
            # The offset is expressed in the body's local frame.
            ee_pos_env, ee_quat_env = math_utils.combine_frame_transforms(
                ee_pos_env, ee_quat_env, self._offset_pos, self._offset_rot
            )
        return ee_pos_env, ee_quat_env

    def _compute_frame_jacobian(self):
        """Return the geometric Jacobian of the target frame with world-aligned axes.

        Rows 0:3 map joint velocities to the linear velocity of the target point
        and rows 3:6 to the angular velocity, both in world axes. These are the same
        axes used by :meth:`_compute_frame_pose`, since env-frame axes are world-aligned.
        """
        jacobian = self.jacobian_w
        if self.cfg.body_offset is None:
            return jacobian

        jac_lin = jacobian[:, 0:3, :]
        jac_ang = jacobian[:, 3:, :]
        # The offset position is given in the body's local frame, but this Jacobian
        # uses world axes, so the lever arm must be rotated into world axes first:
        #   v_ee = v_body + w x r_w = v_body - [r_w]_x w,   r_w = R_wb @ offset_pos
        body_quat_w = self._asset.data.body_quat_w[:, self._body_idx]
        offset_pos_w = torch.bmm(
            math_utils.matrix_from_quat(body_quat_w), self._offset_pos.unsqueeze(-1)
        ).squeeze(-1)
        jac_lin = jac_lin - torch.bmm(math_utils.skew_symmetric_matrix(offset_pos_w), jac_ang)
        # A rigid body has one angular velocity, so the offset rotation does not change
        # the angular rows. The IK controller's axis-angle orientation error
        # (q_des * q_curr^-1) is likewise expressed in world axes.
        return torch.cat((jac_lin, jac_ang), dim=1)


class DefaultJointPositionStaticAction(JointAction):
    """Action term that holds the selected joints at their default positions.

    The term consumes no policy actions (``action_dim`` is 0). Anything passed to
    :meth:`process_actions` is ignored, and :meth:`apply_actions` always sets the
    position targets of the selected joints to the articulation's default joint
    positions, as captured when the term was constructed.
    """

    # The configuration of the action term.
    cfg: actions_cfg.DefaultJointPositionStaticActionCfg

    def __init__(self, cfg: actions_cfg.DefaultJointPositionStaticActionCfg, env: ManagerBasedEnv):
        super().__init__(cfg, env)
        # Default positions of the joints this term controls (cloned below).
        defaults = self._asset.data.default_joint_pos[:, self._joint_ids]
        if cfg.use_default_offset:
            # Mirrors the offset handling of other joint action terms; not read by this class.
            self._offset = defaults.clone()
        self._default_actions = defaults.clone()

    @property
    def action_dim(self) -> int:
        # No policy-controlled dimensions.
        return 0

    def process_actions(self, actions: torch.Tensor):
        # Nothing to process: incoming actions are ignored.
        return None

    def apply_actions(self):
        # Re-issue the stored default joint positions as position targets on every call.
        self._asset.set_joint_position_target(self._default_actions, joint_ids=self._joint_ids)

class TransformedOneShotDifferentialIKAction(DifferentialInverseKinematicsAction):
    """One-shot Differential IK action term with coordinate frame transformation.

    Unlike standard DifferentialInverseKinematicsAction which recomputes IK at every physics
    step (e.g., 500 Hz), this action computes IK ONCE per policy step and holds the joint
    target fixed. This makes it equivalent to JointPos in terms of control structure,
    eliminating the distillation gap when training a JointPos student from an IK expert.

    The workflow is:

    1. Receive 6-DOF Cartesian commands [x, y, z, rx, ry, rz] in transformed frame
    2. Apply coordinate frame transformation to standard robot base frame
    3. Compute IK ONCE to get joint target (not recomputed during decimation)
    4. Hold joint target fixed for all physics steps (PD actuator tracks it)

    This exposes ``delta_joint_pos`` which is the relative joint position action - exactly
    what a JointPos policy would need to output to achieve the same behavior. For envs that
    are reset it is zeroed until the next :meth:`process_actions` call recomputes it.
    """

    # The configuration of the action term.
    cfg: actions_cfg.TransformedOneShotDifferentialIKActionCfg

    def __init__(self, cfg: actions_cfg.TransformedOneShotDifferentialIKActionCfg, env: ManagerBasedEnv):
        # Initialize the parent IK action
        super().__init__(cfg, env)

        # Setup action root offset transformation. Only its rotation is applied to the
        # (delta) commands; the position part is kept as an attribute for reference.
        if self.cfg.action_root_offset is not None:
            self._action_root_offset_pos = torch.tensor(cfg.action_root_offset.pos, device=self.device).repeat(
                self.num_envs, 1
            )
            self._action_root_offset_quat = torch.tensor(cfg.action_root_offset.rot, device=self.device).repeat(
                self.num_envs, 1
            )
            # The offset is constant, so the offset->standard rotation matrix is built once
            # here rather than on every process_actions() call.
            # NOTE(review): using the inverse assumes the configured quaternion rotates
            # standard-frame coordinates into offset-frame coordinates. If it is instead the
            # orientation of the offset frame expressed in the standard frame (the convention
            # combine_frame_transforms uses for body offsets), offset-frame vectors map to the
            # standard frame with R(q) itself, not its inverse. Confirm against the cfg docs.
            self._rot_offset_to_standard = math_utils.matrix_from_quat(
                math_utils.quat_inv(self._action_root_offset_quat)
            )
        else:
            self._action_root_offset_pos = None
            self._action_root_offset_quat = None
            self._rot_offset_to_standard = None

        # Buffers for one-shot IK: joint target computed once per step
        self._joint_pos_des = torch.zeros(self.num_envs, self._num_joints, device=self.device)
        self._delta_joint_pos = torch.zeros(self.num_envs, self._num_joints, device=self.device)

    @property
    def delta_joint_pos(self) -> torch.Tensor:
        """Joint target minus measured joint positions from the last IK solve, shape (num_envs, num_joints)."""
        return self._delta_joint_pos

    def reset(self, env_ids=None):
        """Reset the term for ``env_ids`` (every env when ``None``).

        Zeroes ``delta_joint_pos`` for the reset envs. Readers of it (e.g. observation
        terms) then see a zero delta, rather than the previous episode's last action,
        until the next :meth:`process_actions` call.
        """
        super().reset(env_ids)
        if env_ids is None:
            env_ids = slice(None)
        self._delta_joint_pos[env_ids] = 0.0

    def process_actions(self, actions: torch.Tensor):
        """Process raw actions: transform coordinates, apply scaling, compute IK, store joint target.

        Args:
            actions: The raw actions in shape (num_envs, 6) representing [x, y, z, rx, ry, rz].
        """
        # Store raw actions
        self._raw_actions[:] = actions

        # Transform actions from offset frame to standard frame (if offset is configured)
        actions_standard = self._transform_actions_to_standard_frame(actions)

        # Apply scaling and clipping
        self._processed_actions[:] = actions_standard * self._scale
        if self.cfg.clip is not None:
            self._processed_actions = torch.clamp(
                self._processed_actions, min=self._clip[:, :, 0], max=self._clip[:, :, 1]
            )

        # Obtain quantities from simulation
        ee_pos_curr, ee_quat_curr = self._compute_frame_pose()
        joint_pos = self._asset.data.joint_pos[:, self._joint_ids]

        # Set command into controller
        self._ik_controller.set_command(self._processed_actions, ee_pos_curr, ee_quat_curr)

        # Compute IK ONCE and store the joint target (not recomputed in apply_actions)
        self._compute_joint_target(ee_pos_curr, ee_quat_curr, joint_pos)

    def apply_actions(self):
        """Apply the pre-computed joint position target to the articulation.

        Unlike the parent class which recomputes IK every physics step, this simply
        applies the joint target that was computed once in process_actions().
        """
        self._asset.set_joint_position_target(self._joint_pos_des, self._joint_ids)

    def _transform_actions_to_standard_frame(self, actions: torch.Tensor) -> torch.Tensor:
        """Transform actions from offset coordinate frame to standard robot base frame.

        Both the position delta and the axis-angle rotation delta are vectors, so each is
        rotated by the same offset->standard rotation matrix.

        Args:
            actions: The raw actions in offset frame, shape (num_envs, 6).

        Returns:
            The transformed actions in standard frame, shape (num_envs, 6), or ``actions``
            unchanged when no action root offset is configured.
        """
        if self._rot_offset_to_standard is None:
            return actions

        rot = self._rot_offset_to_standard
        # Position delta [x, y, z]
        delta_pos_standard = torch.bmm(rot, actions[:, :3].unsqueeze(-1)).squeeze(-1)
        # Rotation delta [rx, ry, rz] in axis-angle: rotate the axis vector
        delta_rot_standard = torch.bmm(rot, actions[:, 3:6].unsqueeze(-1)).squeeze(-1)
        return torch.cat([delta_pos_standard, delta_rot_standard], dim=-1)

    def _compute_joint_target(self, ee_pos: torch.Tensor, ee_quat: torch.Tensor, joint_pos: torch.Tensor):
        """Compute and store the joint position target using IK.

        When the EE quaternions are all zero (norm over the whole batch), IK is skipped
        and the current joint positions are held with a zero delta.

        Args:
            ee_pos: Current end-effector position in shape (num_envs, 3).
            ee_quat: Current end-effector orientation in shape (num_envs, 4).
            joint_pos: Current joint positions in shape (num_envs, num_joints).
        """
        if ee_quat.norm() != 0:
            jacobian = self._compute_frame_jacobian()
            self._joint_pos_des[:] = self._ik_controller.compute(ee_pos, ee_quat, jacobian, joint_pos)
            self._delta_joint_pos[:] = self._joint_pos_des - joint_pos
        else:
            # In-place copy into the buffer; no intermediate clone needed.
            self._joint_pos_des[:] = joint_pos
            self._delta_joint_pos[:] = 0.0
