"""OpenPI Pi0.5 bridge for YAM robot deployment.

Implements the ``ModelBridge`` interface from raiden. Connects to a running
OpenPI policy server (``serve_policy.py``) via websocket and translates
between raiden observations and the OpenPI observation format.

The OpenPI server handles all model inference, normalization, and
delta→absolute action conversion. This bridge is purely an observation
reformatter and action reorderer.

Usage (server + bridge)::

    # Terminal 1: start the OpenPI policy server
    python scripts/serve_policy.py \\
        --policy.config pi05_yam \\
        --policy.dir /path/to/checkpoint \\
        --port 8000

    # Terminal 2: run raiden inference (note: --resize_images must be empty
    # so the bridge receives raw 1280x720 frames and can letterbox correctly)
    rd infer \\
        --bridge openpi_bridge:OpenPiBridge \\
        --ckpt_path unused \\
        --action_hz 30.0 \\
        --resize_images '' \\
        --bridge-kwargs host=localhost port=8000 action_horizon=10 \\
            prompt='Sort cans shortest to heighest / left to right (robot'\\''s perspective)'

OpenPI 14D layout: [left_joint(6), left_grip(1), right_joint(6), right_grip(1)]
Raiden 14D layout: [right_joint(6), right_grip(1), left_joint(6), left_grip(1)]
"""

import cv2
import numpy as np

from raiden.inference import ModelBridge

# Edge length, in pixels, of the square images placed in the OpenPI
# observation; used as both the target height and width of the letterbox
# resize.
MODEL_IMG_SIZE = 224

# Raiden camera name -> key under which that camera's image is stored in the
# OpenPI observation dict. Cameras whose name is not listed here are skipped.
_CAM_MAP: dict[str, str] = dict(
    scene_camera="observation/image_head",
    left_wrist_camera="observation/image_left_wrist",
    right_wrist_camera="observation/image_right_wrist",
)


def _raiden_to_openpi_state(
    r_joint_pos: np.ndarray,
    l_joint_pos: np.ndarray,
) -> np.ndarray:
    """Pack the two per-arm raiden proprio vectors into one OpenPI state.

    Each input is laid out as ``[joints(6), grip(1)]``. The left arm is placed
    first and the right arm second, producing the OpenPI ordering
    ``[l_joints(6), l_grip(1), r_joints(6), r_grip(1)]``.

    Args:
        r_joint_pos: Right-arm vector; only its first 7 entries are used.
        l_joint_pos: Left-arm vector; only its first 7 entries are used.

    Returns:
        The concatenated state, cast to float32.
    """
    left_arm = l_joint_pos[:7]  # 6 joints followed by 1 gripper value
    right_arm = r_joint_pos[:7]
    state = np.concatenate((left_arm, right_arm))
    return state.astype(np.float32)


def _openpi_action_to_raiden(action_14d: np.ndarray) -> np.ndarray:
    """Reorder an OpenPI action into raiden's motor-command layout.

    The two 7-element arm halves are swapped:

    OpenPI: [l_joints(6), l_grip(1), r_joints(6), r_grip(1)]
    Raiden: [r_joints(6), r_grip(1), l_joints(6), l_grip(1)]

    Args:
        action_14d: Action in OpenPI order; entries 0-6 are the left arm and
            entries 7-13 the right arm.

    Returns:
        The reordered action, cast to float32.
    """
    left_half = action_14d[0:7]  # left joints (6) + left gripper (1)
    right_half = action_14d[7:14]  # right joints (6) + right gripper (1)
    reordered = np.concatenate((right_half, left_half))
    return reordered.astype(np.float32)


def _resize_with_pad(image: np.ndarray, height: int, width: int) -> np.ndarray:
    """Letterbox ``image`` onto a ``height`` x ``width`` canvas.

    The image is scaled by a single factor so that it fits inside the target
    size (scaled side lengths are truncated to integers), then padded with
    zeros so the result has exactly the requested size. Any odd leftover
    pixel goes to the bottom / right edge.

    Intended to mirror the server-side ``resize_with_pad`` so that inference
    images are letterboxed rather than squashed.

    Args:
        image: Input image; the first two axes are (rows, columns).
        height: Target number of rows.
        width: Target number of columns.

    Returns:
        The resized and zero-padded image.
    """
    src_h, src_w = image.shape[:2]

    # The larger of the two per-axis factors guarantees both sides fit.
    scale = max(src_w / width, src_h / height)
    dst_w, dst_h = int(src_w / scale), int(src_h / scale)
    scaled = cv2.resize(image, (dst_w, dst_h), interpolation=cv2.INTER_LINEAR)

    # Split the leftover space between the two sides of each axis.
    spare_rows = height - dst_h
    spare_cols = width - dst_w
    top = spare_rows // 2
    left = spare_cols // 2
    bottom = spare_rows - top
    right = spare_cols - left

    return cv2.copyMakeBorder(
        scaled, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(0, 0, 0)
    )


class OpenPiBridge(ModelBridge):
    """ModelBridge implementation for OpenPI Pi0.5 on YAM.

    Connects to a running OpenPI websocket server and translates between
    raiden observations and OpenPI's expected input format.

    Typical use: construct, call :meth:`load` once to connect, then call
    :meth:`predict` once per control step. :meth:`reset` clears the buffered
    action chunk and the step/timing counters.
    """

    # Number of initial steps (after construction or reset) that print
    # verbose debug information.
    _DEBUG_STEPS = 3
    # A timing summary line is printed on steps 1, 51, 101, ...
    _LOG_EVERY = 50
    # Directory that receives the model input images of the first step.
    _DEBUG_DIR = "/tmp/openpi_debug"

    def __init__(self, action_horizon: int = 10, action_scale: float = 1.0):
        """Create an unconnected bridge.

        Args:
            action_horizon: Default number of actions executed per inference
                call; may be overridden by the ``action_horizon`` kwarg of
                :meth:`load`.
            action_scale: Interpolation factor between the current joint
                positions and the predicted action (1.0 = full predicted
                action); may be overridden by :meth:`load`.
        """
        self._broker = None  # set by load()
        self._action_horizon = action_horizon
        self._action_scale = action_scale
        self._prompt: str = ""
        self._step = 0
        # Initialised and reset here but not otherwise updated by this class.
        self._n_infer = 0
        self._t_infer_sum = 0.0

    def load(self, ckpt_path: str, **kwargs) -> None:
        """Connect to the OpenPI policy server.

        The ckpt_path is unused (the server loads the model). Connection
        parameters are passed via kwargs.

        Keyword args:
            host: Server host (default: "localhost").
            port: Server port (default: 8000).
            action_horizon: Number of actions to execute per inference call
                (default: the value given to the constructor).
            prompt: Language instruction for the task.
            action_scale: Interpolation factor (1.0 = full predicted action).
        """
        from openpi_client import action_chunk_broker
        from openpi_client import websocket_client_policy as _ws

        host = kwargs.get("host", "localhost")
        port = int(kwargs.get("port", 8000))
        # Remember the effective horizon so the attribute matches the broker.
        self._action_horizon = int(
            kwargs.get("action_horizon", self._action_horizon)
        )
        self._action_scale = float(kwargs.get("action_scale", self._action_scale))
        self._prompt = kwargs.get("prompt", "")

        print(f"[openpi_bridge] Connecting to server at {host}:{port}")
        ws_policy = _ws.WebsocketClientPolicy(host=host, port=port)

        metadata = ws_policy.get_server_metadata()
        print(f"[openpi_bridge] Server metadata: {metadata}")

        self._broker = action_chunk_broker.ActionChunkBroker(
            policy=ws_policy,
            action_horizon=self._action_horizon,
        )
        print(f"[openpi_bridge] Ready (action_horizon={self._action_horizon})")

    def reset(self) -> None:
        """Reset action chunk buffer."""
        if self._broker is not None:
            self._broker.reset()
        self._step = 0
        self._n_infer = 0
        self._t_infer_sum = 0.0

    def predict(self, obs) -> np.ndarray:
        """Convert raiden observation → OpenPI input → inference → 14D action.

        Returns (14,) float32 motor command in raiden order:
        [r_joints(6), r_grip(1), l_joints(6), l_grip(1)]

        Raises:
            RuntimeError: If called before :meth:`load` has connected to the
                policy server.
        """
        import time

        if self._broker is None:
            raise RuntimeError(
                "OpenPiBridge.predict() called before load(); "
                "no policy server connection is available."
            )

        # The measured time covers preprocessing plus the broker call.
        t0 = time.perf_counter()
        obs_dict = self._preprocess(obs)
        result = self._broker.infer(obs_dict)
        infer_ms = (time.perf_counter() - t0) * 1e3

        self._t_infer_sum += infer_ms
        self._step += 1

        actions = np.asarray(result["actions"])  # (14,) from ActionChunkBroker

        if self._step <= self._DEBUG_STEPS:
            self._print_debug(obs_dict, actions)

        if self._step % self._LOG_EVERY == 1:
            print(
                f"[openpi_bridge] step={self._step:4d}  "
                f"infer={infer_ms:.1f}ms  "
                f"avg={self._t_infer_sum / self._step:.1f}ms"
            )

        action_raiden = _openpi_action_to_raiden(actions)

        if self._action_scale != 1.0:
            # Move only a fraction of the way from the current joint
            # positions towards the predicted target.
            r_pos, l_pos = self._read_proprios(obs)
            current = np.concatenate([r_pos[:7], l_pos[:7]])
            blended = current + self._action_scale * (action_raiden - current)
            # The proprios may have a wider dtype; keep the float32 contract.
            action_raiden = blended.astype(np.float32)

        return action_raiden

    @staticmethod
    def _read_proprios(obs) -> tuple:
        """Return ``(right, left)`` follower joint vectors from ``obs``.

        A missing entry falls back to a 7-element float32 zero vector.
        """
        r_pos = obs.proprios.get(
            "follower_r_joint_pos", np.zeros(7, dtype=np.float32)
        )
        l_pos = obs.proprios.get(
            "follower_l_joint_pos", np.zeros(7, dtype=np.float32)
        )
        return np.asarray(r_pos), np.asarray(l_pos)

    def _print_debug(self, obs_dict: dict, actions: np.ndarray) -> None:
        """Print the model inputs/outputs of an early step.

        On step 1 the camera images sent to the server are additionally
        written as PNG files into ``_DEBUG_DIR``.
        """
        import os

        state = obs_dict["observation/state"]
        print(f"[DEBUG] step={self._step}")
        print(f"  state sent (openpi order): {np.array2string(state, precision=4)}")
        print(
            "  raw action from server (openpi order): "
            f"{np.array2string(actions, precision=4)}"
        )

        images = {
            key: obs_dict[key] for key in _CAM_MAP.values() if key in obs_dict
        }
        for cam_key, img in images.items():
            print(
                f"  {cam_key}: shape={img.shape} dtype={img.dtype} "
                f"mean={img.mean():.1f}"
            )

        if self._step == 1:
            os.makedirs(self._DEBUG_DIR, exist_ok=True)
            for cam_key, img in images.items():
                fname = cam_key.replace("observation/", "") + ".png"
                # Images are RGB; reverse the channel axis because
                # cv2.imwrite expects BGR, and hand it a contiguous copy
                # rather than a negative-stride view.
                bgr = np.ascontiguousarray(img[..., ::-1])
                cv2.imwrite(f"{self._DEBUG_DIR}/{fname}", bgr)
            print(f"  [saved debug images to {self._DEBUG_DIR}/]")

    def _preprocess(self, obs) -> dict:
        """Convert raiden Observation to OpenPI observation dict.

        Key details:
        - Images arrive from raiden as RGB uint8 (raiden converts BGR→RGB
          in the capture loop). We do NOT apply another color conversion.
        - Images are resized using aspect-ratio-preserving letterbox padding
          to match the server-side resize_with_pad used during training.
        - Cameras whose name is not in ``_CAM_MAP`` are ignored.
        - The ``prompt`` key is only included when a non-empty prompt is set.
        """
        obs_dict: dict = {}

        for cam in obs.cameras:
            obs_key = _CAM_MAP.get(cam.name)
            if obs_key is None:
                continue
            # cam.image is already RGB uint8 from raiden's capture loop.
            # Apply letterbox resize to match training preprocessing.
            obs_dict[obs_key] = _resize_with_pad(
                cam.image, MODEL_IMG_SIZE, MODEL_IMG_SIZE
            )

        # State: assemble 14D from raiden proprios
        r_pos, l_pos = self._read_proprios(obs)
        obs_dict["observation/state"] = _raiden_to_openpi_state(r_pos, l_pos)

        # Language prompt
        if self._prompt:
            obs_dict["prompt"] = self._prompt

        return obs_dict


# ---------------------------------------------------------------------------
# Standalone entry point
# ---------------------------------------------------------------------------


def main():
    """Standalone entry point: parse CLI options and run the inference loop.

    Builds an ``OpenPiBridge`` from the parsed ``--action_horizon`` and hands
    it, together with the connection, prompt, camera and stereo options, to
    ``RaidenInferenceLoop``, then runs the loop.
    """
    import argparse

    # (flag, add_argument keyword options), registered in this order.
    option_specs = (
        ("--ckpt_path", dict(default="unused", help="Unused (server loads model)")),
        ("--host", dict(default="localhost", help="OpenPI server host")),
        ("--port", dict(type=int, default=8000, help="OpenPI server port")),
        ("--action_horizon", dict(type=int, default=10, help="Actions per inference")),
        ("--action_hz", dict(type=float, default=30.0)),
        ("--prompt", dict(default="", help="Language instruction")),
        ("--camera_config_file", dict(default="./config/camera_config.json")),
        ("--calibration_file", dict(default="./config/calibration_results.json")),
        ("--stereo_method", dict(default="zed", choices=["zed", "ffs"])),
        ("--depth_mode", dict(default="NEURAL_LIGHT")),
    )

    cli = argparse.ArgumentParser(
        description="Deploy OpenPI Pi0.5 policy on YAM robot"
    )
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    opts = cli.parse_args()

    from raiden.inference import RaidenInferenceLoop

    policy_bridge = OpenPiBridge(action_horizon=opts.action_horizon)

    # Forwarded to the bridge as keyword arguments by the inference loop.
    connection_kwargs = {
        "host": opts.host,
        "port": opts.port,
        "action_horizon": opts.action_horizon,
        "prompt": opts.prompt,
    }

    RaidenInferenceLoop(
        bridge=policy_bridge,
        ckpt_path=opts.ckpt_path,
        action_hz=opts.action_hz,
        bridge_kwargs=connection_kwargs,
        camera_config_file=opts.camera_config_file,
        calibration_file=opts.calibration_file,
        stereo_method=opts.stereo_method,
        depth_mode=opts.depth_mode,
    ).run()


# Run the standalone entry point only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
