"""VLA Foundry DiffusionPolicy bridge for YAM robot deployment.

Same pattern as the OpenPI bridge: connects to a policy server, sends
observations as a simple dict, gets back joint actions.

Usage::

    # Terminal 1: start the VLA Foundry policy server
    cd ~/vla_foundry
    uv run --group inference python serve_vla_foundry.py \\
        --checkpoint_directory /home/robot-lab/checkpoints/stack_blocks_vla_foundry/ \\
        --checkpoint_name checkpoint_11 \\
        --num_flow_steps 10 \\
        --open_loop_steps 26 \\
        --guidance_scale 1.0 \\
        --port 8200

    # Terminal 2: run raiden inference
    cd ~/raiden_pi05/raiden
    uv run rd infer \\
        --bridge vla_foundry_bridge:VlaFoundryBridge \\
        --ckpt_path unused \\
        --action_hz 30 \\
        --max_joint_delta 0.5 \\
        --resize_images '' \\
        --bridge-kwargs host=localhost port=8200 prompt='stack the blocks'

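For 14-D joint-space actions the bridge swaps the two arm halves:

Server returns: [l_joints(6), l_grip(1), r_joints(6), r_grip(1)]
Raiden expects:  [r_joints(6), r_grip(1), l_joints(6), l_grip(1)]

Actions of any other length (e.g. 20-D EE pose) are passed through unchanged.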
"""

import time

import cv2
import msgpack
import numpy as np
from chiral.types import Observation
from websockets.sync.client import connect

from raiden.inference import ModelBridge

# Side length in pixels: every camera image is resized to
# _IMG_SIZE x _IMG_SIZE before it is sent to the policy server.
_IMG_SIZE = 384


class VlaFoundryBridge(ModelBridge):
    """ModelBridge that connects to a VLA Foundry policy server.

    The bridge talks to the server over a websocket using msgpack-encoded
    dicts in a strict request/response pattern: every message sent is
    followed by exactly one ``recv()``.

    Lifecycle: ``load()`` opens the connection and performs a reset
    handshake, ``predict()`` sends one observation and returns one action,
    ``reset()`` tells the server to reset and clears the timing statistics.
    """

    # Proprio keys forwarded verbatim to the server under "ee_poses" when
    # they are present in the observation.
    _EE_POSE_KEYS = (
        "robot__actual__poses__left::yam__xyz",
        "robot__actual__poses__left::yam__rot_6d",
        "robot__actual__grippers__left::yam_hand",
        "robot__actual__poses__right::yam__xyz",
        "robot__actual__poses__right::yam__rot_6d",
        "robot__actual__grippers__right::yam_hand",
    )

    def __init__(self):
        self._ws = None  # websocket connection; opened by load()
        self._prompt = ""  # task prompt attached to every observation
        self._step = 0  # number of predict() calls since the last reset
        self._t_sum = 0.0  # accumulated predict() wall time in milliseconds

    def load(self, ckpt_path: str, **kwargs) -> None:
        """Connect to the policy server and perform the reset handshake.

        Args:
            ckpt_path: Not read by this bridge; the checkpoint is selected
                when the server is started.
            **kwargs: Optional ``host`` (default ``"localhost"``), ``port``
                (default ``8200``, converted with ``int``) and ``prompt``
                (default ``""``).

        Raises:
            RuntimeError: If the server's first reply is not a
                ``{"type": "reset_ack"}`` message.
        """
        host = kwargs.get("host", "localhost")
        port = int(kwargs.get("port", 8200))
        self._prompt = kwargs.get("prompt", "")

        # A repeated load() must not leak the previous connection.
        self._disconnect()

        uri = f"ws://{host}:{port}"
        print(f"[vla_foundry_bridge] Connecting to {uri}")
        ws = connect(uri, max_size=100_000_000)
        try:
            ws.send(msgpack.packb({"type": "reset"}))
            resp = msgpack.unpackb(ws.recv(), raw=False)
            resp_type = resp.get("type") if isinstance(resp, dict) else None
            # Explicit check instead of `assert`, which is removed under -O.
            if resp_type != "reset_ack":
                raise RuntimeError(
                    "[vla_foundry_bridge] Expected 'reset_ack' from "
                    f"{uri}, got: {resp!r}"
                )
        except BaseException:
            # Handshake failed: release the socket before propagating.
            ws.close()
            raise

        self._ws = ws
        print(f"[vla_foundry_bridge] Connected. prompt='{self._prompt}'")

    def reset(self) -> None:
        """Ask the server to reset (if connected) and clear timing stats."""
        if self._ws is not None:
            self._ws.send(msgpack.packb({"type": "reset"}))
            # The reply is consumed so that the next predict() does not read
            # it as its own response; its content is not checked.
            msgpack.unpackb(self._ws.recv(), raw=False)
        self._step = 0
        self._t_sum = 0.0

    def predict(self, obs: Observation) -> np.ndarray:
        """Send one observation to the server and return the action.

        A 14-element action is reordered from the server's
        ``[left(7), right(7)]`` layout to ``[right(7), left(7)]``; an action
        of any other length is returned as received.

        Args:
            obs: Current observation; see ``_preprocess`` for the fields
                that are read from it.

        Returns:
            The action as a float32 array.

        Raises:
            RuntimeError: If called before ``load()``, or if the server's
                reply does not contain an ``"action"`` entry.
        """
        if self._ws is None:
            raise RuntimeError(
                "[vla_foundry_bridge] predict() called before load()"
            )

        t0 = time.perf_counter()

        obs_dict = self._preprocess(obs)
        self._ws.send(msgpack.packb(obs_dict, use_bin_type=True))
        resp = msgpack.unpackb(self._ws.recv(), raw=False)

        if not isinstance(resp, dict) or "action" not in resp:
            raise RuntimeError(
                "[vla_foundry_bridge] Server reply has no 'action': "
                f"{resp!r}"
            )
        action = np.array(resp["action"], dtype=np.float32)

        if len(action) == 14:
            # Joint-space: server returns [l(7), r(7)], swap to raiden [r(7), l(7)]
            action_raiden = np.concatenate([action[7:14], action[0:7]])
        else:
            # Anything else is forwarded untouched. The expected case is the
            # 20-D EE pose, which is assumed to already be in raiden layout:
            # [l_xyz(3), r_xyz(3), l_rot6d(6), r_rot6d(6), l_grip(1), r_grip(1)]
            action_raiden = action

        # Measured time covers preprocessing plus the network round trip.
        infer_ms = (time.perf_counter() - t0) * 1e3
        self._t_sum += infer_ms
        self._step += 1

        # Log on steps 1, 51, 101, ... to keep the console readable.
        if self._step % 50 == 1:
            print(
                f"[vla_foundry_bridge] step={self._step:4d}  "
                f"rtt={infer_ms:.1f}ms  "
                f"avg={self._t_sum / self._step:.1f}ms"
            )

        return action_raiden

    def _disconnect(self) -> None:
        """Close the current websocket connection, if there is one."""
        ws, self._ws = self._ws, None
        if ws is not None:
            ws.close()

    def _preprocess(self, obs: Observation) -> dict:
        """Pack observation into a simple dict — images + joint state + prompt.

        Args:
            obs: Observation providing ``cameras`` (each with ``name`` and
                ``image``), ``proprios`` (a mapping) and ``timestamp``.

        Returns:
            A msgpack-serialisable ``{"type": "step", ...}`` message.
        """
        # Images: resized to _IMG_SIZE x _IMG_SIZE and shipped as raw bytes
        # plus shape. No colour conversion happens here, so channel order and
        # dtype are whatever the camera delivers (presumably BGR uint8 —
        # TODO confirm against the server's decoder).
        images = {}
        for cam in obs.cameras:
            img = cv2.resize(cam.image, (_IMG_SIZE, _IMG_SIZE))
            images[cam.name] = {
                "data": img.tobytes(),
                "shape": list(img.shape),
            }

        # Joint state per arm, sent as float32. NOTE(review): a missing key
        # silently falls back to seven zeros — confirm that is intended.
        r_pos = obs.proprios.get("follower_r_joint_pos", np.zeros(7, dtype=np.float32))
        l_pos = obs.proprios.get("follower_l_joint_pos", np.zeros(7, dtype=np.float32))

        # EE poses for proprioception; only keys present (and not None) in
        # the observation are forwarded.
        ee_poses = {
            key: np.asarray(val, dtype=np.float64).tolist()
            for key in self._EE_POSE_KEYS
            if (val := obs.proprios.get(key)) is not None
        }

        return {
            "type": "step",
            "images": images,
            "joint_pos_r": np.asarray(r_pos, dtype=np.float32).tolist(),
            "joint_pos_l": np.asarray(l_pos, dtype=np.float32).tolist(),
            "ee_poses": ee_poses,
            "timestamp": obs.timestamp,
            "prompt": self._prompt,
        }
