"""
gRPC inference client for the vla_foundry LBM policy server.

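Requires grpcio, grpcio-tools, numpy, torch, polaris and isaaclab (no pydrake/robot_gym);
PIL is imported lazily, only when camera frames must be resized for the visualization.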
Proto stubs are compiled at import time from the vla_foundry repo.

Intermediate observation format (built from the raw env obs by
LbmGrpcClient._preprocess_obs() and serialized by _obs_to_proto()):
    obs = {
        "cameras": {
            "<serial>": {
                "rgb": np.ndarray (H, W, 3) uint8,         # RGB image
                "intrinsics": np.ndarray (9,),              # row-major 3x3 K matrix
                "pose_translation": np.ndarray (3,),        # camera position in world frame
                "pose_rotation": np.ndarray (9,),           # camera orientation, row-major 3x3
            },
            ...
        },
        "poses": {
            "<arm_name>": {
                "translation": np.ndarray (3,),        # xyz in meters (z-up)
                "rotation": np.ndarray (9,),           # row-major 3x3 rotation matrix
            },
            ...
        },
        "grippers": {
            "<gripper_name>": float,                   # gripper position (e.g. 0.0=closed, 0.08=open)
            ...
        },
        "joint_positions": {                           # optional
            "<arm_name>": np.ndarray (N,),             # joint angles in radians
            ...
        },
    }

Default arm/gripper/camera names (bimanual panda setup):
    Arms:     "left::panda", "right::panda"
    Grippers: "left::panda_hand", "right::panda_hand"
    Cameras:  "6CD146030E99" (scene_right), "6CD146031C25" (scene_left)

Returned action (from infer()):
    np.ndarray of shape (16,):
        [ left_xyz(3), left_quat(4), left_grip(1), right_xyz(3), right_quat(4), right_grip(1) ]
    Where quat is produced by isaaclab.utils.math.quat_from_matrix from the server's
    row-major 3x3 rotation matrix (Isaac Lab's quaternion convention, w, x, y, z).

Note: The policy model only uses the RGB pixels — it does not read camera intrinsics
or pose. However, the gRPC proto requires them, so they must be provided
(_preprocess_obs fills in dummy values).
"""

import atexit
import os
import shutil
import subprocess
import sys
import tempfile
import uuid
import warnings
from pathlib import Path

import numpy as np
import torch
from polaris.config import PolicyArgs
from polaris.policy.abstract_client import InferenceClient
import isaaclab.utils.math as math_utils

# ---------------------------------------------------------------------------
# Compile proto stubs at import time.
# Set VLA_FOUNDRY_ROOT env var if the vla_foundry repo is not at ../vla_foundry
# relative to the sim-improvement repo root.
# ---------------------------------------------------------------------------
_SIM_IMPROVEMENT_ROOT = Path(__file__).resolve().parents[3]  # sim-improvement/
_VLA_FOUNDRY_ROOT = Path(
    os.environ.get("VLA_FOUNDRY_ROOT", _SIM_IMPROVEMENT_ROOT.parent / "vla_foundry")
)
_PROTO_DIR = _VLA_FOUNDRY_ROOT / "packages" / "grpc-workspace" / "src" / "grpc_workspace" / "proto"
if not _PROTO_DIR.exists():
    raise FileNotFoundError(
        f"Proto directory not found at {_PROTO_DIR}. "
        "Set VLA_FOUNDRY_ROOT to point to the vla_foundry repo."
    )

# Generated *_pb2 / *_pb2_grpc modules go into a fresh temp directory that is
# put on sys.path. Remove it at interpreter exit so repeated runs do not leave
# one stale directory behind per import (registered before compiling so a
# failed protoc call is cleaned up too).
_proto_out = tempfile.mkdtemp(prefix="lbm_grpc_protos_")
atexit.register(shutil.rmtree, _proto_out, ignore_errors=True)
for _proto_file in ["PolicyStep.proto", "PolicyReset.proto", "health.proto"]:
    subprocess.check_call(
        [
            sys.executable,
            "-m",
            "grpc_tools.protoc",
            f"--proto_path={_PROTO_DIR}",
            f"--python_out={_proto_out}",
            f"--grpc_python_out={_proto_out}",
            str(_PROTO_DIR / _proto_file),
        ],
    )
sys.path.insert(0, _proto_out)

import grpc  # noqa: E402
import health_pb2  # noqa: E402
import health_pb2_grpc  # noqa: E402
import PolicyReset_pb2  # noqa: E402
import PolicyReset_pb2_grpc  # noqa: E402
import PolicyStep_pb2  # noqa: E402
import PolicyStep_pb2_grpc  # noqa: E402

# Camera serials as listed in vla_foundry's field_mapping.yaml.
# NOTE: _preprocess_obs does not use these; it keys cameras by name
# ("scene_right_0", "scene_left_0", wrist cameras) instead.
DEFAULT_CAMERA_SERIALS = [
    "6CD146030E99",  # scene_right_0
    "6CD146031C25",  # scene_left_0
]

# Arm and gripper names, left before right; _parse_action emits actions in
# exactly this order, pairing the i-th arm with the i-th gripper.
DEFAULT_ARM_NAMES = [
    "left::panda",
    "right::panda",
]
DEFAULT_GRIPPER_NAMES = [
    "left::panda_hand",
    "right::panda_hand",
]


def _make_rigid_transform(translation, rotation):
    """Build a ``PolicyStep_pb2.RigidTransform`` proto from array-likes.

    Args:
        translation: position values (expected to be 3); any array-like,
            flattened in C order.
        rotation: rotation-matrix entries in row-major order; either a flat
            (9,) sequence or a (3, 3) array — both are flattened in C order.

    Returns:
        A ``RigidTransform`` message whose ``translation`` and ``rotation``
        fields hold plain Python floats.
    """
    # Flatten and coerce to float: list() on a (3, 3) array would yield row
    # arrays, and integer / numpy-scalar inputs (e.g. dummy camera poses) are
    # normalized to the float values the proto's numeric fields expect.
    return PolicyStep_pb2.RigidTransform(
        translation=np.asarray(translation, dtype=np.float64).ravel().tolist(),
        rotation=np.asarray(rotation, dtype=np.float64).ravel().tolist(),
    )


def _obs_to_proto(obs: dict, instruction: str) -> "PolicyStep_pb2.MultiarmObservation":
    """Serialize an intermediate observation dict into a ``MultiarmObservation``.

    Args:
        obs: dict with "cameras", "poses", "grippers" and, optionally,
            "joint_positions" entries (layout in the module docstring).
            Arms without joint positions get seven zeros.
        instruction: language instruction forwarded to the policy.

    Returns:
        The ``PolicyStep_pb2.MultiarmObservation`` to embed in a
        ``PolicyStepRequest``.
    """
    proto_version = 20241212

    # One CameraImageSet per camera: raw (uncompressed) uint8 RGB bytes plus
    # the intrinsics/pose the proto requires.
    camera_sets = []
    for cam_serial, cam in obs["cameras"].items():
        rgb = cam["rgb"]
        height, width = rgb.shape[:2]
        raw_image = PolicyStep_pb2.Image(
            height=height,
            width=width,
            channels=3,
            data=rgb.astype(np.uint8).tobytes(),
            dtype=PolicyStep_pb2.DTYPE_UINT8,
            compression=PolicyStep_pb2.NONE,
        )
        cam_info = PolicyStep_pb2.CameraInfo(height=height, width=width, k=list(cam["intrinsics"]))
        cam_pose = _make_rigid_transform(cam["pose_translation"], cam["pose_rotation"])
        camera_sets.append(
            PolicyStep_pb2.CameraImageSet(
                camera_serial=cam_serial,
                has_rgb=True,
                camera_rgb=PolicyStep_pb2.CameraImage(image=raw_image, info=cam_info, pose=cam_pose),
            )
        )

    joints_by_arm = obs.get("joint_positions", {})
    arm_states = [
        PolicyStep_pb2.RobotPoseStatus(
            robot_name=arm,
            joint_position=list(joints_by_arm.get(arm, [0.0] * 7)),
            pose=_make_rigid_transform(arm_pose["translation"], arm_pose["rotation"]),
        )
        for arm, arm_pose in obs["poses"].items()
    ]
    gripper_states = [
        PolicyStep_pb2.RobotGripperStatus(gripper_name=name, gripper_position=float(position))
        for name, position in obs["grippers"].items()
    ]

    measured = PolicyStep_pb2.PosesAndGrippers(pose_status=arm_states, gripper_status=gripper_states)
    # The same measured state is reported as both "actual" and "desired".
    robot_block = PolicyStep_pb2.PosesAndGrippersActualAndDesired(
        actual=measured,
        desired=measured,
        version=proto_version,
    )
    return PolicyStep_pb2.MultiarmObservation(
        robot=robot_block,
        visuo=camera_sets,
        version=proto_version,
        use_language_instruction=True,
        language_instruction=instruction,
    )


def _parse_action(action_proto) -> np.ndarray:
    """Parse a ``PosesAndGrippers`` action proto into a flat action array.

    Each pose's row-major 3x3 rotation matrix is converted to a quaternion
    with ``isaaclab.utils.math.quat_from_matrix``.

    Returns:
        np.ndarray of shape (16,):
            [left_xyz(3), left_quat(4), left_grip(1),
             right_xyz(3), right_quat(4), right_grip(1)]

        This matches the env action space ordering:
            left_panda_arm(7), left_panda_gripper(1),
            right_panda_arm(7), right_panda_gripper(1).

    Arms in DEFAULT_ARM_NAMES that are missing from the proto are filled with
    a zero position and the identity-rotation quaternion; missing grippers in
    DEFAULT_GRIPPER_NAMES get 0.0. A warning is emitted in either case.
    """
    poses = {}
    for ps in action_proto.pose_status:
        translation = np.array(ps.pose.translation, dtype=np.float64)
        rot_mat = torch.tensor(np.array(ps.pose.rotation, dtype=np.float64)).reshape(3, 3)
        quat = math_utils.quat_from_matrix(rot_mat).numpy()
        poses[ps.robot_name] = np.concatenate([translation, quat])  # (7,)

    grippers = {gs.gripper_name: gs.gripper_position for gs in action_proto.gripper_status}

    default_pose = None
    # Interleaved: left arm + left grip, right arm + right grip
    parts = []
    for arm, grip in zip(DEFAULT_ARM_NAMES, DEFAULT_GRIPPER_NAMES):
        pose = poses.get(arm)
        if pose is None:
            if default_pose is None:
                # A zero quaternion (the old default) is not a valid rotation.
                # Derive the identity from the same converter so its component
                # order matches the parsed poses.
                identity_quat = math_utils.quat_from_matrix(torch.eye(3, dtype=torch.float64)).numpy()
                default_pose = np.concatenate([np.zeros(3), identity_quat])
            warnings.warn(
                f"[LbmGrpcClient] Action has no pose for arm '{arm}'; using identity pose.",
                stacklevel=2,
            )
            pose = default_pose
        if grip not in grippers:
            warnings.warn(
                f"[LbmGrpcClient] Action has no state for gripper '{grip}'; using 0.0.",
                stacklevel=2,
            )
        parts.append(pose)
        parts.append(np.array([grippers.get(grip, 0.0)]))

    return np.concatenate(parts)  # (16,)


@InferenceClient.register(client_name="LbmGrpc")
class LbmGrpcClient(InferenceClient):
    """
    gRPC client for the vla_foundry LBM inference policy server.

    The server handles action chunking internally (via open_loop_steps),
    returning one action per step() call. This client calls the server
    every step.

    ``infer()`` takes the raw env observation; ``_preprocess_obs()`` converts
    it into the intermediate format described in the module docstring before
    it is serialized and sent to the server.
    """

    def __init__(self, host: str, port: int, *args, **kwargs) -> None:
        """Connect to the policy server and block until it reports SERVING.

        Args:
            host: server hostname or IP address.
            port: server port.
            *args, **kwargs: accepted and ignored.

        Raises:
            RuntimeError: if the health check reports a status other than SERVING.
        """
        self._server_uri = f"{host}:{port}"
        # Identifies this client's server-side policy state in reset/step requests.
        self._client_id = str(uuid.uuid4())
        # Raise gRPC's default message-size limits: every observation carries
        # several uncompressed RGB frames.
        self._grpc_options = [
            ("grpc.max_send_message_length", 30 * 1024 * 1024),
            ("grpc.max_receive_message_length", 30 * 1024 * 1024),
        ]

        print(f"[LbmGrpcClient] Connecting to {self._server_uri}...")
        self._channel = grpc.insecure_channel(self._server_uri, options=self._grpc_options)
        # wait_for_ready=True without a timeout: blocks until the server is reachable.
        resp = health_pb2_grpc.HealthStub(self._channel).Check(
            health_pb2.HealthCheckRequest(service=""),
            wait_for_ready=True,
        )
        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if resp.status != health_pb2.HealthCheckResponse.SERVING:
            raise RuntimeError(
                f"[LbmGrpcClient] Server at {self._server_uri} is not serving "
                f"(health status {resp.status})"
            )
        print("[LbmGrpcClient] Server is ready.")

        self._reset_stub = PolicyReset_pb2_grpc.PolicyResetServiceStub(self._channel)
        self._step_stub = PolicyStep_pb2_grpc.PolicyStepServiceStub(self._channel)

    def reset(self):
        """Reset the server-side policy state for this client.

        A fixed seed of 42 is sent with every reset.

        Raises:
            RuntimeError: if the server reports the reset as unsuccessful.
        """
        resp = self._reset_stub.PolicyReset(
            PolicyReset_pb2.PolicyResetRequest(
                client_identifier=self._client_id,
                seed=42,
            ),
            wait_for_ready=True,
        )
        if not resp.success:
            raise RuntimeError(f"[LbmGrpcClient] Reset rejected for client {self._client_id}")

    def _preprocess_obs(self, obs: dict) -> dict:
        """Convert the raw env observation into the intermediate dict used by ``_obs_to_proto``.

        Everything is read from ``obs["vision"]``: four camera images,
        end-effector positions/quaternions for both arms, and both gripper
        positions. Each tensor's leading dimension is squeezed away (presumably
        a batch of one environment — TODO confirm).

        Returns:
            dict with "cameras" (rgb plus dummy intrinsics/pose per camera),
            "poses" (translation (3,) and flattened row-major rotation (9,) per
            arm) and "grippers" (one Python scalar per gripper).
        """
        vision = obs["vision"]

        def to_numpy(key):
            # Tensor -> host numpy array with the leading (batch) dim removed.
            return vision[key].detach().cpu().numpy().squeeze(0)

        def rotation_9(quat_key):
            # Quaternion -> 3x3 rotation matrix -> flat (9,) row-major array.
            return math_utils.matrix_from_quat(vision[quat_key]).detach().cpu().numpy().squeeze(0).flatten()

        # Proto camera name -> key in obs["vision"].
        camera_sources = {
            "scene_right_0": "external_camera_right",
            "scene_left_0": "external_camera_left",
            "wrist_camera_right": "wrist_camera_right",
            "wrist_camera_left": "wrist_camera_left",
        }
        cameras = {}
        for cam_name, vision_key in camera_sources.items():
            cameras[cam_name] = {
                "rgb": to_numpy(vision_key),
                # Dummy calibration: the policy does not read these, but the
                # proto requires them.
                "intrinsics": np.array([1000.0, 0.0, 0.0, 0.0, 1000.0, 0.0, 0.0, 0.0, 1.0]),
                "pose_translation": np.zeros(3),
                "pose_rotation": np.eye(3).flatten(),
            }

        poses = {
            "left::panda": {
                "translation": to_numpy("left_ee_pos"),
                "rotation": rotation_9("left_ee_quat"),
            },
            "right::panda": {
                "translation": to_numpy("right_ee_pos"),
                "rotation": rotation_9("right_ee_quat"),
            },
        }

        # .item() yields a Python scalar and raises unless exactly one value is
        # present (avoids numpy's deprecated ndim>0 array -> float conversion).
        grippers = {
            "left::panda_hand": to_numpy("left_gripper_pos").item(),
            "right::panda_hand": to_numpy("right_gripper_pos").item(),
        }

        return {"cameras": cameras, "poses": poses, "grippers": grippers}

    def infer(
        self, obs: dict, instruction: str, return_viz: bool = False
    ) -> tuple[np.ndarray, np.ndarray | None]:
        """
        Send observation to the LBM policy server and return the action.

        Args:
            obs: Raw env observation; must contain the ``obs["vision"]`` entries
                 read by ``_preprocess_obs()``.
            instruction: Language instruction for the task.
            return_viz: Currently ignored — the side-by-side camera visualization
                 is always built and returned.

        Returns:
            action: np.ndarray (16,) —
                [left_xyz(3), left_quat(4), left_grip(1), right_xyz(3), right_quat(4), right_grip(1)]
            viz: np.ndarray (H, sum of widths, C) — all camera views side-by-side,
                resized to the smallest height; None only if there are no cameras.

        Raises:
            RuntimeError: if the server rejects the step.
        """
        processed = self._preprocess_obs(obs)
        obs_proto = _obs_to_proto(processed, instruction)

        resp = self._step_stub.PolicyStep(
            PolicyStep_pb2.PolicyStepRequest(
                client_identifier=self._client_id,
                observation=obs_proto,
            ),
            wait_for_ready=True,
        )
        if not resp.success:
            raise RuntimeError(f"[LbmGrpcClient] Step rejected for client {self._client_id}")

        action = _parse_action(resp.action)
        viz = self._side_by_side([cam["rgb"] for cam in processed["cameras"].values()])
        return action, viz

    @staticmethod
    def _side_by_side(frames):
        """Concatenate frames horizontally after resizing each to the smallest height.

        Aspect ratio is preserved (width scaled and truncated to int). Returns
        None when ``frames`` is empty.
        """
        if not frames:
            return None
        min_h = min(f.shape[0] for f in frames)
        resized = []
        for f in frames:
            if f.shape[0] != min_h:
                from PIL import Image  # lazy: only needed when heights differ

                new_w = int(f.shape[1] * (min_h / f.shape[0]))
                f = np.array(Image.fromarray(f).resize((new_w, min_h)))
            resized.append(f)
        return np.concatenate(resized, axis=1)

    def __del__(self):
        """Close the gRPC channel if ``__init__`` got far enough to create it."""
        channel = getattr(self, "_channel", None)
        if channel is not None:
            channel.close()
