import numpy as np
from scipy.spatial.transform import Rotation
import isaaclab.utils.math as math_utils
from openpi_client import websocket_client_policy, image_tools


# Widths of the consecutive fields packed into one 20-dim action vector, in order:
#   left_xyz(3), right_xyz(3), left_rot6d(6), right_rot6d(6), left_grip(1), right_grip(1)
FIELD_DIMS = np.asarray((3, 3, 6, 6, 1, 1))
# Interior boundaries between fields, usable directly as np.split indices.
FIELD_SPLITS = FIELD_DIMS.cumsum()[:-1]

# rot6d encoding of "no rotation": columns 0 and 1 of the 3x3 identity laid out
# back to back, i.e. [1, 0, 0, 0, 1, 0].
_IDENTITY_ROT6D = np.eye(3)[:, :2].T.ravel()


def _rot6d_to_matrix(rot6d: np.ndarray) -> np.ndarray:
    """Convert 6D continuous rotations (Zhou et al. 2019) to 3x3 rotation matrices.

    The six numbers are interpreted as the first two *columns* of the rotation
    matrix laid out back to back, ``[R[:, 0], R[:, 1]]``.  This is the layout
    ``_preprocess_obs`` produces for proprioception and the one the
    ``_IDENTITY_ROT6D`` constant is described with, so encoding a matrix that
    way and decoding it here round-trips.

    NOTE(review): this assumes the policy emits rot6d actions in the same
    column layout it receives as observations - confirm against the
    server-side output transform.

    Args:
        rot6d: Array of shape (..., 6).  The two 3-vectors need not be unit
            length or mutually orthogonal; they are orthonormalised with
            Gram-Schmidt.

    Returns:
        Array of shape (..., 3, 3) whose columns are the orthonormal basis
        ``b1, b2, b3`` with ``b3 = b1 x b2``.
    """
    a1 = rot6d[..., :3]
    a2 = rot6d[..., 3:6]

    # Gram-Schmidt: normalise the first vector, strip its component from the
    # second, normalise again, and complete the right-handed frame.
    b1 = a1 / np.linalg.norm(a1, axis=-1, keepdims=True)
    b2 = a2 - np.sum(b1 * a2, axis=-1, keepdims=True) * b1
    b2 = b2 / np.linalg.norm(b2, axis=-1, keepdims=True)
    b3 = np.cross(b1, b2)

    # Stack along the last axis so b1/b2/b3 become the matrix *columns*.
    # Stacking along axis=-2 would place them as rows and return R^T.
    return np.stack([b1, b2, b3], axis=-1)


def _rot_to_quat_wxyz(rot: np.ndarray) -> np.ndarray:
    """Turn rotation matrices of shape (..., 3, 3) into wxyz quaternions of shape (..., 4)."""
    batch_shape = rot.shape[:-2]
    # SciPy works on a flat batch and returns scalar-last (x, y, z, w) rows.
    xyzw = Rotation.from_matrix(rot.reshape(-1, 3, 3)).as_quat()
    # Move the scalar part to the front: (x, y, z, w) -> (w, x, y, z).
    wxyz = np.concatenate([xyzw[:, 3:], xyzw[:, :3]], axis=-1)
    return wxyz.reshape(*batch_shape, 4)


def _parse_action_fields(vec: np.ndarray) -> dict:
    """Cut a (..., 20) relative action vector into its named per-arm pieces.

    The last axis is split at FIELD_SPLITS; the resulting six sub-arrays are
    returned under the keys listed below, in that order.
    """
    field_names = (
        "left_xyz",
        "right_xyz",
        "left_rot6d",
        "right_rot6d",
        "left_gripper",
        "right_gripper",
    )
    pieces = np.split(vec, FIELD_SPLITS, axis=-1)
    return dict(zip(field_names, pieces))


def _compose_relative(parsed: dict, reference: dict[str, dict]) -> np.ndarray:
    """Turn relative per-arm actions into absolute IK targets.

    Each arm's relative transform is applied on top of its reference pose,
    T_abs = T_ref @ T_rel, which expands to

        abs_pos = ref_rot @ rel_pos + ref_pos
        abs_rot = ref_rot @ rel_rot

    Args:
        parsed: Output of _parse_action_fields, with (N, .) arrays per field.
        reference: Maps "left"/"right" to {"pos": (N, 3), "rot": (N, 3, 3)},
            expressed in the env frame.

    Returns:
        (N, 16) array laid out as [left_pos(3), left_quat_wxyz(4),
        right_pos(3), right_quat_wxyz(4), left_grip(1), right_grip(1)].
    """
    sides = ("left", "right")

    poses = []
    for side in sides:
        delta_pos = parsed[f"{side}_xyz"]                          # (N, 3)
        delta_rot = _rot6d_to_matrix(parsed[f"{side}_rot6d"])      # (N, 3, 3)
        base = reference[side]

        # Rotate the relative offset into the env frame, then shift by the base position.
        target_pos = np.einsum('nij,nj->ni', base["rot"], delta_pos) + base["pos"]
        target_rot = base["rot"] @ delta_rot
        poses.append(
            np.concatenate([target_pos, _rot_to_quat_wxyz(target_rot)], axis=-1)
        )                                                          # (N, 7)

    # Gripper commands pass through untouched and trail both arm poses.
    grippers = [parsed[f"{side}_gripper"] for side in sides]       # 2 x (N, 1)
    return np.concatenate(poses + grippers, axis=-1)               # (N, 16)


class LbmOpenpiClient:
    """
    OpenPI websocket client for the LBM policy.

    Sends observations in the format expected by LBMInputs transform
    (see openpi/src/openpi/policies/lbm_policy.py).

    The model outputs relative actions (xyz_relative, rot_6d_relative).
    This client converts them to absolute poses using the EE pose captured
    at the chunk boundary as the reference frame.

    Expected observation format (passed to infer()):
        obs = {
            "vision": {
                "external_camera_right": torch.Tensor (num_envs, H, W, 3),
                "external_camera_left": torch.Tensor (num_envs, H, W, 3),
                "wrist_camera_right": torch.Tensor (num_envs, H, W, 3),
                "wrist_camera_left": torch.Tensor (num_envs, H, W, 3),
                "left_ee_pos": torch.Tensor (num_envs, 3),
                "left_ee_quat": torch.Tensor (num_envs, 4),
                "right_ee_pos": torch.Tensor (num_envs, 3),
                "right_ee_quat": torch.Tensor (num_envs, 4),
            },
        }

    Returned action (from infer()):
        np.ndarray of shape (num_envs, 16):
            [left_pos(3), left_quat_wxyz(4), right_pos(3), right_quat_wxyz(4),
             left_grip(1), right_grip(1)]
    """

    def __init__(self, host: str, port: int, open_loop_horizon: int, action_shape: tuple[int, ...], **kwargs) -> None:
        self.client = websocket_client_policy.WebsocketClientPolicy(
            host=host, port=port
        )
        self.num_envs = action_shape[0]
        self._chunk_step = 0
        self._action_chunks = None  # (num_envs, horizon, 20) relative actions
        self.open_loop_horizon = open_loop_horizon
        # Reference EE pose captured at chunk boundary for relative→absolute
        self._reference: dict[str, dict] | None = None

    @property
    def rerender(self) -> bool:
        return (
            self._chunk_step == 0
            or self._chunk_step >= self.open_loop_horizon
        )

    def reset(self, env_ids=None, obs=None):
        if env_ids is None:
            self._chunk_step = 0
            self._action_chunks = None
            self._reference = None
        else:
            if self._action_chunks is not None:
                # Zero out remaining chunk actions for reset envs with identity
                # rot6d so they hold their reference pose until next inference.
                # Action layout: [l_xyz(3), r_xyz(3), l_rot6d(6), r_rot6d(6), l_grip(1), r_grip(1)]
                identity_action = np.zeros(20)
                identity_action[6:12] = _IDENTITY_ROT6D   # left rot6d
                identity_action[12:18] = _IDENTITY_ROT6D  # right rot6d
                identity_action[18] = 1.0                  # left gripper open
                identity_action[19] = 1.0                  # right gripper open
                self._action_chunks[env_ids, self._chunk_step:] = identity_action

            # Update reference pose for reset envs so identity actions
            # compose against the new post-reset EE pose, not the stale one.
            if obs is not None and self._reference is not None:
                vision = obs["vision"]
                for side in ("left", "right"):
                    pos = vision[f"{side}_ee_pos"].detach().cpu().numpy()
                    rot_mat = math_utils.matrix_from_quat(
                        vision[f"{side}_ee_quat"]
                    ).detach().cpu().numpy()
                    self._reference[side]["pos"][env_ids] = pos[env_ids]
                    self._reference[side]["rot"][env_ids] = rot_mat[env_ids]

    def infer(
        self, obs: dict, instruction: str, return_viz: bool = False
    ) -> tuple[np.ndarray, np.ndarray | None]:
        viz = None
        if (
            self._chunk_step == 0
            or self._chunk_step >= self.open_loop_horizon
        ):
            request_data = self._preprocess_obs(obs)
            request_data["prompt"] = [instruction] * self.num_envs

            # Capture reference EE pose at chunk boundary
            self._reference = self._extract_ee_reference(obs)

            self._chunk_step = 0
            server_response = self.client.infer(request_data)
            self._action_chunks = np.array(server_response["actions"])  # (num_envs, horizon, 20)

            if return_viz:
                viz = self._make_viz(request_data)

        if return_viz and viz is None:
            request_data = self._preprocess_obs(obs)
            viz = self._make_viz(request_data)

        if self._action_chunks is None:
            raise ValueError("No action chunk predicted")

        # Convert relative action to absolute using reference pose
        rel_action = self._action_chunks[:, self._chunk_step]  # (num_envs, 20)
        parsed = _parse_action_fields(rel_action)
        action = _compose_relative(parsed, self._reference)  # (num_envs, 16)

        self._chunk_step += 1

        return action, viz

    def _extract_ee_reference(self, obs: dict) -> dict[str, dict]:
        """Extract current EE pose from obs as the reference for relative→absolute."""
        vision = obs["vision"]
        reference = {}
        for side in ("left", "right"):
            pos = vision[f"{side}_ee_pos"].detach().cpu().numpy()  # (num_envs, 3)
            rot_mat = math_utils.matrix_from_quat(
                vision[f"{side}_ee_quat"]
            ).detach().cpu().numpy()  # (num_envs, 3, 3)
            reference[side] = {"pos": pos, "rot": rot_mat}
        return reference

    def _preprocess_obs(self, obs: dict) -> dict:
        """Convert sim observation to the format expected by LBMInputs."""
        vision = obs["vision"]

        # Extract and resize images to 224x224 — resize_with_pad handles (*b, h, w, c)
        wrist_right = image_tools.resize_with_pad(
            vision["wrist_camera_right"].detach().cpu().numpy(), 224, 224
        )
        wrist_left = image_tools.resize_with_pad(
            vision["wrist_camera_left"].detach().cpu().numpy(), 224, 224
        )
        scene_left = image_tools.resize_with_pad(
            vision["external_camera_left"].detach().cpu().numpy(), 224, 224
        )
        scene_right = image_tools.resize_with_pad(
            vision["external_camera_right"].detach().cpu().numpy(), 224, 224
        )

        # Extract EE positions — (num_envs, 3)
        left_pos = vision["left_ee_pos"].detach().cpu().numpy()
        right_pos = vision["right_ee_pos"].detach().cpu().numpy()

        # Convert quaternions to rot_6d (first two columns of rotation matrix)
        left_rot_mat = math_utils.matrix_from_quat(vision["left_ee_quat"]).detach().cpu().numpy()  # (num_envs, 3, 3)
        right_rot_mat = math_utils.matrix_from_quat(vision["right_ee_quat"]).detach().cpu().numpy()  # (num_envs, 3, 3)
        left_rot_6d = np.concatenate([left_rot_mat[:, :, 0], left_rot_mat[:, :, 1]], axis=-1)   # (num_envs, 6)
        right_rot_6d = np.concatenate([right_rot_mat[:, :, 0], right_rot_mat[:, :, 1]], axis=-1)  # (num_envs, 6)

        return {
            "observation": {
                "proprioception": {
                    "robot__actual__poses__left::panda__xyz": left_pos,
                    "robot__actual__poses__right::panda__xyz": right_pos,
                    "robot__actual__poses__left::panda__rot_6d": left_rot_6d,
                    "robot__actual__poses__right::panda__rot_6d": right_rot_6d,
                },
                "images": {
                    "wrist_right_minus_t0": wrist_right,
                    "wrist_left_plus_t0": wrist_left,
                    "scene_left_0_t0": scene_left,
                    "scene_right_0_t0": scene_right,
                },
            },
        }

    def _make_viz(self, request_data: dict) -> list[np.ndarray]:
        """Create a side-by-side visualization of the camera views (env 0 only)."""
        images = request_data["observation"]["images"]
        all_env_frames = []
        for idx in range(self.num_envs):
            frame = [
                images["wrist_right_minus_t0"][idx],
                images["wrist_left_plus_t0"][idx],
                images["scene_left_0_t0"][idx],
                images["scene_right_0_t0"][idx],
            ]
            all_env_frames.append(np.concatenate(frame, axis=1))
        return all_env_frames