from __future__ import annotations

import torch
from collections.abc import Sequence
from typing import TYPE_CHECKING

import omni.log

import isaaclab.utils.math as math_utils
from isaaclab.assets.articulation import Articulation
from isaaclab.controllers.differential_ik import DifferentialIKController
from isaaclab.controllers.differential_ik_cfg import DifferentialIKControllerCfg
from isaaclab.envs.mdp.actions.actions_cfg import DifferentialInverseKinematicsActionCfg
from isaaclab.envs.utils.io_descriptors import GenericActionIODescriptor
from isaaclab.managers.action_manager import ActionTerm
from isaaclab.managers.manager_base import ManagerTermBase
from isaaclab.managers.manager_term_cfg import ActionTermCfg
from isaaclab.utils import configclass

if TYPE_CHECKING:
    from isaaclab.envs import ManagerBasedEnv


class DualArmEnvFrameDiffIKAction(ActionTerm):
    """Dual-arm env-frame DiffIK action term.

    Like EnvFrameDiffIKAction but handles both left and right arms in a single
    action term. Actions are concatenated as ``[left, right]``, where each half
    has the width of the configured IK controller's action dimension
    (e.g. 6 for relative pose commands, 7 for absolute pose commands).

    Each arm gets its own DifferentialIKController and asset reference. The IK
    is computed in env frame (world-aligned axes, origin at env origin), same as
    EnvFrameDiffIKAction.

    When ``cfg.qp_net_path`` is set, a TorchScript QPNet replaces the IK solve:
    joint targets are computed once per :meth:`process_actions` call and held
    by :meth:`apply_actions`.

    Open-loop and QPNet tracking state is kept per environment. :meth:`reset`
    only flags the reset environments for re-seeding from the simulator, so
    environments that are not being reset keep their tracked state.
    """

    cfg: DualArmEnvFrameDiffIKActionCfg

    def __init__(self, cfg: DualArmEnvFrameDiffIKActionCfg, env: ManagerBasedEnv):
        """Resolve both arm assets, build per-arm IK controllers and allocate buffers.

        Args:
            cfg: Action term configuration.
            env: The environment this term belongs to.

        Raises:
            ValueError: If ``cfg.body_name`` does not match exactly one body on an
                arm, or if a list-valued ``cfg.max_joint_vel`` has neither one
                entry nor one entry per arm joint.
        """
        # Skip ActionTerm.__init__ which expects a single asset_name.
        # Call ManagerTermBase.__init__ directly for cfg/env setup.
        ManagerTermBase.__init__(self, cfg, env)

        # Set up ActionTerm bookkeeping that we skipped
        self._IO_descriptor = GenericActionIODescriptor()
        self._export_IO_descriptor = True
        self._debug_vis_handle = None

        # Resolve both assets from the scene
        self._left_asset: Articulation = env.scene[cfg.left_asset_name]
        self._right_asset: Articulation = env.scene[cfg.right_asset_name]

        # Set up per-arm joint/body/Jacobian indices
        self._setup_arm("left", self._left_asset)
        self._setup_arm("right", self._right_asset)

        # Create per-arm IK controllers (both share the same controller config)
        self._left_ik = DifferentialIKController(
            cfg=self.cfg.controller, num_envs=self.num_envs, device=self.device
        )
        self._right_ik = DifferentialIKController(
            cfg=self.cfg.controller, num_envs=self.num_envs, device=self.device
        )

        # Single-arm action dim, as reported by the IK controller (depends on its command type)
        self._single_arm_action_dim = self._left_ik.action_dim

        # Create action buffers; columns are [left, right]
        self._raw_actions = torch.zeros(self.num_envs, self.action_dim, device=self.device)
        self._processed_actions = torch.zeros_like(self._raw_actions)

        # Scale (the same per-arm scale is applied to each half of the action)
        self._scale = torch.zeros((self.num_envs, self._single_arm_action_dim), device=self.device)
        self._scale[:] = torch.tensor(self.cfg.scale, device=self.device)

        # Body offset tensors (shared for both arms since they use the same body_offset)
        if self.cfg.body_offset is not None:
            self._offset_pos = torch.tensor(self.cfg.body_offset.pos, device=self.device).repeat(self.num_envs, 1)
            self._offset_rot = torch.tensor(self.cfg.body_offset.rot, device=self.device).repeat(self.num_envs, 1)
        else:
            self._offset_pos, self._offset_rot = None, None

        # Physics dt: max_joint_vel (rad/s) becomes a per-physics-step delta-q limit,
        # because apply_actions runs once per physics step.
        self._dt = env.sim.get_physics_dt()
        self._max_joint_delta = self._resolve_max_joint_delta(cfg.max_joint_vel)

        # Open-loop tracking state (per arm, per env). The *_needs_seed masks flag envs
        # whose reference must be re-seeded from the sim EE pose on the next apply;
        # all envs start flagged so the first apply seeds everything.
        self._left_open_loop_pos = torch.zeros(self.num_envs, 3, device=self.device)
        self._left_open_loop_quat = torch.zeros(self.num_envs, 4, device=self.device)
        self._right_open_loop_pos = torch.zeros(self.num_envs, 3, device=self.device)
        self._right_open_loop_quat = torch.zeros(self.num_envs, 4, device=self.device)
        self._left_open_loop_needs_seed = torch.ones(self.num_envs, dtype=torch.bool, device=self.device)
        self._right_open_loop_needs_seed = torch.ones(self.num_envs, dtype=torch.bool, device=self.device)

        # QPNet mode (replaces IK when configured)
        self._qp_net: torch.jit.ScriptModule | None = None
        if cfg.qp_net_path is not None:
            self._qp_net = torch.jit.load(cfg.qp_net_path, map_location=self.device)
            self._qp_net.eval()
            omni.log.info(f"Loaded QPNet from {cfg.qp_net_path}")
            # Tracked joint positions for open-loop QPNet inference; envs flagged in
            # _qp_needs_seed are re-seeded from the sim before the next inference.
            self._qp_joint_pos_left = torch.zeros(self.num_envs, self._left_num_joints, device=self.device)
            self._qp_joint_pos_right = torch.zeros(self.num_envs, self._right_num_joints, device=self.device)
            self._qp_needs_seed = torch.ones(self.num_envs, dtype=torch.bool, device=self.device)
            # Pre-computed joint targets (set in process_actions, held in apply_actions)
            self._qp_joint_target_left = torch.zeros(self.num_envs, self._left_num_joints, device=self.device)
            self._qp_joint_target_right = torch.zeros(self.num_envs, self._right_num_joints, device=self.device)

    def _setup_arm(self, prefix: str, asset: Articulation):
        """Resolve joint IDs, body index, and Jacobian indices for one arm.

        Stores the results as ``_{prefix}_joint_ids``, ``_{prefix}_joint_names``,
        ``_{prefix}_num_joints``, ``_{prefix}_body_idx``, ``_{prefix}_body_name``,
        ``_{prefix}_jacobi_body_idx`` and ``_{prefix}_jacobi_joint_ids``.

        Raises:
            ValueError: If ``cfg.body_name`` does not match exactly one body.
        """
        joint_ids, joint_names = asset.find_joints(self.cfg.joint_names)
        body_ids, body_names = asset.find_bodies(self.cfg.body_name)
        if len(body_ids) != 1:
            raise ValueError(
                f"Expected one match for body name: {self.cfg.body_name}. "
                f"Found {len(body_ids)}: {body_names}."
            )

        body_idx = body_ids[0]
        num_joints = len(joint_ids)

        # The Jacobian buffer is indexed differently for fixed vs floating base:
        # fixed base has no root body row; floating base has 6 leading root DoF columns.
        if asset.is_fixed_base:
            jacobi_body_idx = body_idx - 1
            jacobi_joint_ids = joint_ids
        else:
            jacobi_body_idx = body_idx
            jacobi_joint_ids = [i + 6 for i in joint_ids]

        # Use slice for efficiency if controlling all joints
        if num_joints == asset.num_joints:
            joint_ids = slice(None)

        setattr(self, f"_{prefix}_joint_ids", joint_ids)
        setattr(self, f"_{prefix}_joint_names", joint_names)
        setattr(self, f"_{prefix}_num_joints", num_joints)
        setattr(self, f"_{prefix}_body_idx", body_idx)
        setattr(self, f"_{prefix}_body_name", body_names[0])
        setattr(self, f"_{prefix}_jacobi_body_idx", jacobi_body_idx)
        setattr(self, f"_{prefix}_jacobi_joint_ids", jacobi_joint_ids)

        omni.log.info(
            f"DualArmEnvFrameDiffIK [{prefix}] joints: {joint_names} [{joint_ids}], "
            f"body: {body_names[0]} [{body_idx}]"
        )

    def _resolve_max_joint_delta(self, max_joint_vel: list[float] | float | None) -> float | torch.Tensor | None:
        """Convert ``max_joint_vel`` (rad/s) into a per-physics-step delta-q limit.

        Returns:
            None when disabled, a float for a scalar limit, or a tensor of per-joint
            limits (shared by both arms) for a list.

        Raises:
            ValueError: If a list has neither one entry nor one entry per arm joint.
        """
        if max_joint_vel is None:
            return None
        if isinstance(max_joint_vel, (int, float)):
            return max_joint_vel * self._dt
        max_vel = torch.tensor(max_joint_vel, device=self.device)
        for prefix in ("left", "right"):
            num_joints = getattr(self, f"_{prefix}_num_joints")
            if max_vel.numel() not in (1, num_joints):
                raise ValueError(
                    f"max_joint_vel has {max_vel.numel()} entries but the {prefix} arm "
                    f"controls {num_joints} joints."
                )
        return max_vel * self._dt

    # ------------------------------------------------------------------
    # Properties
    # ------------------------------------------------------------------

    @property
    def action_dim(self) -> int:
        """Total action width: one IK command per arm, concatenated."""
        return self._single_arm_action_dim * 2

    @property
    def raw_actions(self) -> torch.Tensor:
        """Last actions passed to :meth:`process_actions` (unscaled)."""
        return self._raw_actions

    @property
    def processed_actions(self) -> torch.Tensor:
        """Last actions after per-arm scaling."""
        return self._processed_actions

    # ------------------------------------------------------------------
    # Operations
    # ------------------------------------------------------------------

    def process_actions(self, actions: torch.Tensor):
        """Split, scale and dispatch actions to both arms' IK controllers.

        The first half of ``actions`` drives the left arm, the second half the
        right arm. Each arm's command is set relative to its current env-frame
        EE pose. In QPNet mode, joint targets are also computed here.
        """
        self._raw_actions[:] = actions

        dim = self._single_arm_action_dim
        left_actions = actions[:, :dim] * self._scale
        right_actions = actions[:, dim:] * self._scale
        self._processed_actions[:, :dim] = left_actions
        self._processed_actions[:, dim:] = right_actions

        # Compute current EE poses in env frame and set IK commands
        left_pos, left_quat = self._compute_frame_pose(self._left_asset, self._left_body_idx)
        right_pos, right_quat = self._compute_frame_pose(self._right_asset, self._right_body_idx)

        self._left_ik.set_command(left_actions, left_pos, left_quat)
        self._right_ik.set_command(right_actions, right_pos, right_quat)

        if self._qp_net is not None:
            self._compute_qp_net_targets()

    def apply_actions(self):
        """Write joint position targets for both arms (IK solve, or held QPNet targets)."""
        if self._qp_net is not None:
            # QPNet mode: just hold the pre-computed joint targets
            self._left_asset.set_joint_position_target(self._qp_joint_target_left, self._left_joint_ids)
            self._right_asset.set_joint_position_target(self._qp_joint_target_right, self._right_joint_ids)
            return

        self._apply_arm(
            self._left_asset, self._left_ik,
            self._left_body_idx, self._left_joint_ids,
            self._left_jacobi_body_idx, self._left_jacobi_joint_ids,
            "left",
        )
        self._apply_arm(
            self._right_asset, self._right_ik,
            self._right_body_idx, self._right_joint_ids,
            self._right_jacobi_body_idx, self._right_jacobi_joint_ids,
            "right",
        )

    def reset(self, env_ids: Sequence[int] | None = None):
        """Zero raw actions and flag tracking state for re-seeding, for ``env_ids`` only.

        Args:
            env_ids: Environments being reset. None means all environments.
        """
        # A full slice selects every env; a sequence/tensor of ids selects only those rows.
        ids = slice(None) if env_ids is None else env_ids
        self._raw_actions[ids] = 0.0
        # Re-seed open-loop references from the sim on the next apply, but only for
        # the envs being reset; the others keep their open-loop tracking.
        self._left_open_loop_needs_seed[ids] = True
        self._right_open_loop_needs_seed[ids] = True
        # Same for the QPNet tracked joint positions.
        if self._qp_net is not None:
            self._qp_needs_seed[ids] = True

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    def _apply_arm(
        self,
        asset: Articulation,
        ik_controller: DifferentialIKController,
        body_idx: int,
        joint_ids,
        jacobi_body_idx: int,
        jacobi_joint_ids,
        prefix: str,
    ):
        """Run one IK step for one arm and write its joint position targets.

        Args:
            asset: The arm articulation.
            ik_controller: That arm's IK controller (command already set).
            body_idx: EE body index in the articulation.
            joint_ids: Controlled joint ids (list or ``slice(None)``).
            jacobi_body_idx: EE row in the PhysX Jacobian buffer.
            jacobi_joint_ids: Controlled joint columns in the PhysX Jacobian buffer.
            prefix: ``"left"`` or ``"right"``; selects the open-loop state attributes.
        """
        ee_pos_curr, ee_quat_curr = self._compute_frame_pose(asset, body_idx)
        joint_pos = asset.data.joint_pos[:, joint_ids]

        # Open-loop: use last commanded pose as "current" for error computation.
        # NOTE(review): after the first physics substep of an env step the reference
        # equals the command, so the pose error is zero and the target falls back to
        # the measured joint_pos — confirm this is intended when decimation > 1.
        if self.cfg.open_loop:
            ee_pos_for_err, ee_quat_for_err = self._open_loop_reference(prefix, ee_pos_curr, ee_quat_curr)
        else:
            ee_pos_for_err, ee_quat_for_err = ee_pos_curr, ee_quat_curr

        # A zero quaternion norm means body poses are not valid yet; hold current joints.
        if ee_quat_curr.norm() != 0:
            jacobian = self._compute_frame_jacobian(asset, jacobi_body_idx, jacobi_joint_ids)
            joint_pos_des = ik_controller.compute(ee_pos_for_err, ee_quat_for_err, jacobian, joint_pos)
        else:
            joint_pos_des = joint_pos

        # Velocity limit: clamp the per-physics-step joint change to max_joint_vel * dt.
        if self._max_joint_delta is not None:
            delta = torch.clamp(
                joint_pos_des - joint_pos, min=-self._max_joint_delta, max=self._max_joint_delta
            )
            joint_pos_des = joint_pos + delta

        if self.cfg.open_loop:
            # Copy (not alias) the commanded pose: the controller may overwrite its buffers.
            getattr(self, f"_{prefix}_open_loop_pos")[:] = ik_controller.ee_pos_des
            getattr(self, f"_{prefix}_open_loop_quat")[:] = ik_controller.ee_quat_des

        asset.set_joint_position_target(joint_pos_des, joint_ids)

    def _open_loop_reference(
        self, prefix: str, ee_pos_curr: torch.Tensor, ee_quat_curr: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Return one arm's open-loop reference pose, re-seeding flagged envs from the sim.

        Envs flagged in ``_{prefix}_open_loop_needs_seed`` take the current EE pose;
        the flags are then cleared. The returned tensors are the stored state buffers.
        """
        ol_pos = getattr(self, f"_{prefix}_open_loop_pos")
        ol_quat = getattr(self, f"_{prefix}_open_loop_quat")
        needs_seed = getattr(self, f"_{prefix}_open_loop_needs_seed")
        # torch.where avoids the host sync an `if needs_seed.any()` check would force.
        seed = needs_seed.unsqueeze(-1)
        ol_pos[:] = torch.where(seed, ee_pos_curr, ol_pos)
        ol_quat[:] = torch.where(seed, ee_quat_curr, ol_quat)
        needs_seed[:] = False
        return ol_pos, ol_quat

    def _compute_qp_net_targets(self):
        """Compute joint targets using QPNet instead of IK (one-shot).

        The QPNet input is the tracked joint positions of both arms plus the
        absolute EE targets from the IK controllers' ``set_command``. Its output
        is fed back as the next tracked joint positions (open-loop) and copied
        into the held joint targets.
        """
        # Seed tracked q from sim for envs flagged at init / reset; others keep their open-loop q.
        seed = self._qp_needs_seed.unsqueeze(-1)
        self._qp_joint_pos_left = torch.where(
            seed, self._left_asset.data.joint_pos[:, self._left_joint_ids], self._qp_joint_pos_left
        )
        self._qp_joint_pos_right = torch.where(
            seed, self._right_asset.data.joint_pos[:, self._right_joint_ids], self._qp_joint_pos_right
        )
        self._qp_needs_seed[:] = False

        # Get absolute EE targets computed by set_command
        ee_pos_left = self._left_ik.ee_pos_des
        ee_quat_left = self._left_ik.ee_quat_des
        ee_pos_right = self._right_ik.ee_pos_des
        ee_quat_right = self._right_ik.ee_quat_des

        # Convert quat → rot_6d (first two columns of rotation matrix, row-major)
        left_rot_6d = math_utils.matrix_from_quat(ee_quat_left)[:, :, :2].reshape(-1, 6)
        right_rot_6d = math_utils.matrix_from_quat(ee_quat_right)[:, :, :2].reshape(-1, 6)

        qp_input = {
            "robot__desired__joint_position__left::panda": self._qp_joint_pos_left,
            "robot__desired__joint_position__right::panda": self._qp_joint_pos_right,
            "robot__action__poses__left::panda__xyz": ee_pos_left,
            "robot__action__poses__right::panda__xyz": ee_pos_right,
            "robot__action__poses__left::panda__rot_6d": left_rot_6d,
            "robot__action__poses__right::panda__rot_6d": right_rot_6d,
        }

        with torch.no_grad():
            qp_output = self._qp_net(qp_input)

        # Update tracked q (open-loop: feed output back as next input)
        self._qp_joint_pos_left = qp_output["robot__desired__joint_position__left::panda"]
        self._qp_joint_pos_right = qp_output["robot__desired__joint_position__right::panda"]
        self._qp_joint_target_left[:] = self._qp_joint_pos_left
        self._qp_joint_target_right[:] = self._qp_joint_pos_right

    def _compute_frame_pose(self, asset: Articulation, body_idx: int):
        """Compute EE pose in env frame (world-aligned, relative to env origin).

        Applies ``cfg.body_offset`` when set. Returns ``(pos, quat)``.
        """
        ee_pos_w = asset.data.body_pos_w[:, body_idx]
        ee_quat_w = asset.data.body_quat_w[:, body_idx]
        # Env frame shares world axes, so only the translation by the env origin differs.
        ee_pos_env = ee_pos_w - self._env.scene.env_origins
        ee_quat_env = ee_quat_w
        if self.cfg.body_offset is not None:
            ee_pos_env, ee_quat_env = math_utils.combine_frame_transforms(
                ee_pos_env, ee_quat_env, self._offset_pos, self._offset_rot
            )
        return ee_pos_env, ee_quat_env

    def _compute_frame_jacobian(self, asset: Articulation, jacobi_body_idx: int, jacobi_joint_ids):
        """Compute the PhysX (world-axis) Jacobian for the given arm's EE.

        World axes coincide with env-frame axes, so no rotation is applied.
        Indexing with a joint-id list returns a copy, so the in-place offset
        correction below does not touch the simulator's buffer.
        """
        jacobian = asset.root_physx_view.get_jacobians()[:, jacobi_body_idx, :, jacobi_joint_ids]
        if self.cfg.body_offset is not None:
            # NOTE(review): this mirrors Isaac Lab's offset correction, but here the
            # Jacobian is world-axis while offset_pos/offset_rot are body-frame values.
            # A rigid offset would normally need the world-rotated offset and leave the
            # angular rows unrotated — confirm before relying on non-zero offsets.
            jacobian[:, 0:3, :] += torch.bmm(
                -math_utils.skew_symmetric_matrix(self._offset_pos), jacobian[:, 3:, :]
            )
            jacobian[:, 3:, :] = torch.bmm(
                math_utils.matrix_from_quat(self._offset_rot), jacobian[:, 3:, :]
            )
        return jacobian


@configclass
class DualArmEnvFrameDiffIKActionCfg(ActionTermCfg):
    """Config for dual-arm env-frame DiffIK action.

    Like EnvFrameDiffIKActionCfg but handles both arms in a single action term.
    Actions are concatenated as ``[left, right]``. Each half has the width of
    ``controller``'s action dimension: 6 for relative pose commands, or 7 for
    absolute pose commands such as the default controller below.
    """

    class_type: type[ActionTerm] = DualArmEnvFrameDiffIKAction

    # Override parent's MISSING default — not used, both arms resolved via left/right names
    asset_name: str = "unused"

    left_asset_name: str = "left_panda"
    """Scene entity name for the left arm articulation."""

    right_asset_name: str = "right_panda"
    """Scene entity name for the right arm articulation."""

    joint_names: list[str] = ["panda_joint.*"]
    """Regex for arm joint names (applied to both assets)."""

    body_name: str = "panda_link8"
    """End-effector body name (applied to both assets)."""

    body_offset: DifferentialInverseKinematicsActionCfg.OffsetCfg | None = None
    """Optional offset from body_name to the desired control frame (shared by both arms)."""

    scale: float | tuple[float, ...] = 1.0
    """Scaling applied to each arm's half of the raw actions before IK.

    A tuple needs one entry per single-arm action component; the same scale is
    used for both arms. With absolute pose commands the quaternion components
    are scaled too, so keep those entries at 1.0."""

    controller: DifferentialIKControllerCfg = DifferentialIKControllerCfg(
        command_type="pose",
        use_relative_mode=False,
        ik_method="dls",
        ik_params={"lambda_val": 0.1},
    )
    """Differential IK controller config (shared by both arms)."""

    open_loop: bool = False
    """When True, compute pose error relative to the previous commanded pose
    instead of the current sim EE pose."""

    max_joint_vel: list[float] | float | None = None
    """Max joint velocity (rad/s) for clamping IK output delta-q. A list gives one
    limit per arm joint and is shared by both arms. None disables."""

    qp_net_path: str | None = None
    """Path to a TorchScript QPNet checkpoint. When set, replaces IK with the
    learned QPNet for joint target computation of both arms. The model takes
    (q_commanded, ee_target_xyz, ee_target_rot6d) for each arm and outputs the
    next q_commanded for each arm."""
