"""WebSocket policy server for VLA Foundry DiffusionPolicy.

Accepts simple observation dicts (images + joint state + prompt) from the
raiden bridge and returns actions (joint-space or end-effector pose, depending
on the model).

Usage::

    cd ~/vla_foundry
    uv run --group inference python serve_vla_foundry.py \\
        --checkpoint_directory /home/robot-lab/checkpoints/stack_blocks_vla_foundry/ \\
        --checkpoint_name checkpoint_11 \\
        --num_flow_steps 10 \\
        --open_loop_steps 26 \\
        --guidance_scale 1.0 \\
        --port 8200
"""

import argparse
import asyncio
import logging
import time
import uuid

import msgpack
import numpy as np
import websockets
from pydrake.math import RigidTransform, RotationMatrix
from robot_gym.multiarm_spaces import (
    CURRENT_VERSION,
    CameraImageSet,
    CameraRgbImage,
    MultiarmObservation,
    PosesAndGrippers,
    PosesAndGrippersActualAndDesired,
)

from vla_foundry.inference.robotics.inference_policy import InferenceDiffusionPolicy

# Configure logging at INFO level and create this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Dictionary keys that identify each arm and each gripper in the
# PosesAndGrippers structures this module builds and reads.
_RIGHT_ARM, _LEFT_ARM = "right::yam", "left::yam"
_RIGHT_GRIPPER, _LEFT_GRIPPER = "right::yam_hand", "left::yam_hand"


def _unpack_obs(data: dict) -> MultiarmObservation:
    """Convert the bridge's simple dict into a MultiarmObservation.

    Keys read from ``data``:

    * ``joint_pos_r`` / ``joint_pos_l`` (required): per-arm joint vectors. At
      most the first seven entries are forwarded as joint positions, and
      entry 6 is the fallback gripper value.
    * ``ee_poses`` (optional): flat dict that may carry
      ``robot__actual__poses__{side}::yam__xyz`` / ``...__rot_6d`` pairs and
      ``robot__actual__grippers__{side}::yam_hand`` values. A pose is built
      only when both its xyz and rot_6d entries are present; a gripper entry,
      when present, takes precedence over the joint-vector fallback.
    * ``images`` (optional): ``{camera_name: {"shape": ..., "data": buffer}}``
      where ``data`` is interpreted as raw uint8 values and reshaped to
      ``shape``.
    * ``timestamp`` (optional): defaults to ``time.monotonic()``.
    * ``prompt`` (optional): language instruction, defaults to ``""``.

    Optional keys may be missing or explicitly ``None``; both are treated the
    same way. The "desired" robot state is filled with the same values as the
    "actual" one.

    Raises:
        KeyError: if ``joint_pos_r`` / ``joint_pos_l`` (or an image's
            ``shape`` / ``data``) is missing.
        IndexError: if a gripper value must come from a joint vector that has
            fewer than seven entries.
    """
    from vla_foundry.data.robotics.utils import rot_6d_to_matrix

    q_r = np.array(data["joint_pos_r"], dtype=np.float32)
    q_l = np.array(data["joint_pos_l"], dtype=np.float32)

    # Optional fields can be absent or sent as an explicit nil; handle both.
    timestamp = data.get("timestamp")
    if timestamp is None:
        timestamp = time.monotonic()
    ee_poses = data.get("ee_poses") or {}
    images = data.get("images") or {}

    # EE poses from FK (for proprioception)
    poses = {}
    for side, arm_key in (("right", _RIGHT_ARM), ("left", _LEFT_ARM)):
        xyz_key = f"robot__actual__poses__{side}::yam__xyz"
        rot6d_key = f"robot__actual__poses__{side}::yam__rot_6d"
        if xyz_key in ee_poses and rot6d_key in ee_poses:
            xyz = np.array(ee_poses[xyz_key], dtype=np.float64)
            rot6d = np.array(ee_poses[rot6d_key], dtype=np.float64)
            R = rot_6d_to_matrix(rot6d)
            poses[arm_key] = RigidTransform(R=RotationMatrix(R), p=xyz)

    def _gripper_value(side: str, joints: np.ndarray) -> float:
        # Prefer the explicit gripper entry. Index the joint vector only when
        # it is really needed, so a short joint vector is not an error when
        # the bridge already supplies the gripper value.
        key = f"robot__actual__grippers__{side}::yam_hand"
        if key in ee_poses:
            return float(np.asarray(ee_poses[key]).flat[0])
        return float(joints[6])

    grip_r = _gripper_value("right", q_r)
    grip_l = _gripper_value("left", q_l)

    actual = PosesAndGrippers(
        poses=poses,
        grippers={_RIGHT_GRIPPER: grip_r, _LEFT_GRIPPER: grip_l},
        joint_position={_RIGHT_ARM: q_r[:7], _LEFT_ARM: q_l[:7]},
        timestamp_data=timestamp,
        timestamp_received=time.monotonic(),
    )
    desired = PosesAndGrippers(
        # Separate dict so that editing one state's pose mapping cannot leak
        # into the other (the gripper/joint dicts are already separate).
        poses=dict(poses),
        grippers={_RIGHT_GRIPPER: grip_r, _LEFT_GRIPPER: grip_l},
        joint_position={_RIGHT_ARM: q_r[:7], _LEFT_ARM: q_l[:7]},
        timestamp_data=timestamp,
    )
    robot = PosesAndGrippersActualAndDesired(
        actual=actual, desired=desired, version=CURRENT_VERSION
    )

    # Camera images
    visuo = {}
    for cam_name, cam_data in images.items():
        shape = cam_data["shape"]
        # frombuffer yields a read-only view of the message bytes; copy() gives
        # the observation its own writable array.
        image = np.frombuffer(cam_data["data"], dtype=np.uint8).reshape(shape).copy()
        # Dummy intrinsics/extrinsics — model doesn't use them
        K = np.eye(3, dtype=np.float64)
        X_TC = RigidTransform()
        rgb = CameraRgbImage(array=image, K=K, X_TC=X_TC, timestamp=timestamp)
        visuo[cam_name] = CameraImageSet(rgb=rgb, depth=None)

    return MultiarmObservation(
        robot=robot,
        visuo=visuo,
        timestamp_packaged=timestamp,
        language_instruction=data.get("prompt") or "",
    )


def _action_to_array(action_pg) -> list:
    """Flatten a PosesAndGrippers action into a list of float32-rounded floats.

    The layout depends on what the action carries:

    * Joint-space model (``joint_position`` is non-empty and contains the left
      arm): ``[l(7), r(7)]`` = 14D, taking at most seven values per arm.
    * EE pose model (otherwise):
      ``[l_xyz(3), r_xyz(3), l_rot6d(6), r_rot6d(6), l_grip(1), r_grip(1)]`` = 20D.
    """
    from vla_foundry.data.robotics.utils import get_rot_6d, get_xyz

    # Joint-space branch: left arm first, then right arm.
    joints = action_pg.joint_position
    if joints and _LEFT_ARM in joints:
        per_arm = [
            np.asarray(joints[arm], dtype=np.float32).flatten()[:7]
            for arm in (_LEFT_ARM, _RIGHT_ARM)
        ]
        return np.concatenate(per_arm).astype(np.float32).tolist()

    # EE pose branch: gather poses and scalar gripper values before packing.
    left_pose = action_pg.poses[_LEFT_ARM]
    right_pose = action_pg.poses[_RIGHT_ARM]
    left_grip = float(np.asarray(action_pg.grippers[_LEFT_GRIPPER]).flat[0])
    right_grip = float(np.asarray(action_pg.grippers[_RIGHT_GRIPPER]).flat[0])

    pieces = [
        get_xyz(left_pose),
        get_xyz(right_pose),
        get_rot_6d(left_pose),
        get_rot_6d(right_pose),
        [left_grip],
        [right_grip],
    ]
    flat = np.concatenate(pieces).astype(np.float32)
    return flat.tolist()


class VlaFoundryServer:
    """WebSocket front-end around a single InferenceDiffusionPolicy.

    One instance owns one policy and one client id, so every connection
    handled by this object shares the same policy state and step counters.
    """

    def __init__(self, policy: InferenceDiffusionPolicy):
        self._policy = policy
        # Single id under which this server talks to the policy.
        self._client_id = uuid.uuid4()
        # False until the first step after construction/reset has primed the policy.
        self._initialized = False
        # Number of actions returned since construction or the last reset.
        self._step_count = 0
        # Most recent action returned by the policy (None before the first step).
        self._last_action_pg = None

    def reset(self):
        """Reset the policy for this server's client id and clear local state."""
        # The value 0 is handed to reset_batch unchanged; its meaning is
        # defined by the policy.
        self._policy.reset_batch({self._client_id: 0})
        self._initialized = False
        self._step_count = 0
        self._last_action_pg = None
        logger.info("Policy reset.")

    def _set_past_action(self):
        """Write the last executed action into the t-1 slot of the action buffer.

        Does nothing when the policy has no data adapter for this client, when
        the adapter keeps no past timesteps, when no action has been returned
        yet, or when the converted action dict is empty.
        """
        import copy

        from vla_foundry.inference.robotics.chiral_inference_client import (
            _poses_and_grippers_to_action_dict,
        )

        da = self._policy.data_adapter.get(self._client_id)
        if da is None or da.num_past_timesteps == 0 or self._last_action_pg is None:
            return

        action_dict = _poses_and_grippers_to_action_dict(self._last_action_pg, da.action_fields)
        if not action_dict:
            return

        # Match shapes to the anchor slot so np.stack doesn't break
        ref = da.action_buffer[da.num_past_timesteps]
        for field, value in action_dict.items():
            if field in ref:
                action_dict[field] = np.asarray(value, dtype=np.float64).reshape(
                    np.asarray(ref[field]).shape
                )
        da.action_buffer[da.num_past_timesteps - 1] = copy.deepcopy(action_dict)

    def step(self, data: dict) -> list:
        """Run one policy step for an incoming observation dict.

        Args:
            data: observation payload in the format accepted by ``_unpack_obs``.

        Returns:
            The action as a flat list (see ``_action_to_array``).
        """
        multiarm_obs = _unpack_obs(data)

        if not self._initialized:
            # Force the first step to trigger inference by setting the counter
            # to open_loop_steps - 1. This fills the action buffer immediately.
            self._policy.current_open_loop_step[self._client_id] = self._policy.open_loop_steps - 1
            self._policy.step_batch({self._client_id: multiarm_obs})
            self._initialized = True

        # Write last executed action into t-1 slot before inference triggers
        cur_step = self._policy.current_open_loop_step.get(self._client_id, 0)
        if cur_step % self._policy.open_loop_steps == self._policy.open_loop_steps - 1:
            self._set_past_action()

        action_pg = self._policy.step(multiarm_obs, self._client_id)
        self._last_action_pg = action_pg
        self._step_count += 1
        return _action_to_array(action_pg)

    async def handle_client(self, websocket):
        """Serve one WebSocket connection until the peer goes away.

        Each message is a msgpack-encoded map with a ``"type"`` field:

        * ``"reset"`` -> resets the policy, replies ``{"type": "reset_ack"}``.
        * ``"step"``  -> runs ``step`` on the message, replies
          ``{"type": "action", "action": [...]}``.

        Anything else is logged and ignored; no reply is sent for it.
        """
        logger.info("Client connected.")
        try:
            async for message in websocket:
                data = msgpack.unpackb(message, raw=False)
                # A payload that is not a map has no type; treat it as unknown
                # instead of failing on ``.get``.
                msg_type = data.get("type") if isinstance(data, dict) else None

                if msg_type == "reset":
                    self.reset()
                    await websocket.send(msgpack.packb({"type": "reset_ack"}))

                elif msg_type == "step":
                    # NOTE: step() runs synchronously, so the event loop does
                    # not service other work until it returns.
                    t0 = time.perf_counter()
                    action = self.step(data)
                    elapsed_ms = (time.perf_counter() - t0) * 1000

                    # Log once per open-loop chunk; follows the configured
                    # chunk length rather than a fixed number.
                    if self._step_count % self._policy.open_loop_steps == 0:
                        logger.info(f"step={self._step_count}  infer_ms={elapsed_ms:.1f}")

                    await websocket.send(msgpack.packb({"type": "action", "action": action}))

                else:
                    # Leave a trace so a client waiting on a reply can be diagnosed.
                    logger.warning("Ignoring message with unsupported type: %r", msg_type)

        except websockets.ConnectionClosed:
            logger.info("Client disconnected.")


async def serve(server: VlaFoundryServer, host: str, port: int):
    """Run the WebSocket endpoint for ``server`` until this task is cancelled.

    Connections are dispatched to ``server.handle_client``; ``max_size`` is
    raised to 100_000_000 for incoming messages.
    """
    logger.info(f"Serving on ws://{host}:{port}")
    endpoint = websockets.serve(server.handle_client, host, port, max_size=100_000_000)
    async with endpoint:
        # Nobody ever resolves this future, so awaiting it keeps the server up.
        never_done = asyncio.Future()
        await never_done


def main():
    """Parse command-line options, load the policy, and serve it over WebSocket."""
    # (flag, add_argument keyword arguments), in the order they are registered.
    cli_options = (
        ("--checkpoint_directory", {"type": str, "required": True}),
        ("--checkpoint_name", {"type": str, "default": None}),
        ("--num_flow_steps", {"type": int, "default": 10}),
        ("--open_loop_steps", {"type": int, "default": 26}),
        ("--guidance_scale", {"type": float, "default": 1.0}),
        ("--sigma_d_obs", {"type": float, "default": 0.2}),
        ("--host", {"type": str, "default": "0.0.0.0"}),
        ("--port", {"type": int, "default": 8200}),
    )
    parser = argparse.ArgumentParser(description="VLA Foundry Policy Server")
    for flag, spec in cli_options:
        parser.add_argument(flag, **spec)
    args = parser.parse_args()

    logger.info(f"Loading policy from {args.checkpoint_directory} ...")
    # The device and lag compensation are fixed here; the rest comes from the CLI.
    policy_kwargs = {
        "checkpoint_directory": args.checkpoint_directory,
        "checkpoint_name": args.checkpoint_name,
        "device": "cuda",
        "num_flow_steps": args.num_flow_steps,
        "open_loop_steps": args.open_loop_steps,
        "lag_compensation": 0.0,
        "guidance_scale": args.guidance_scale,
        "sigma_d_obs": args.sigma_d_obs,
    }
    policy = InferenceDiffusionPolicy(**policy_kwargs)
    logger.info("Policy loaded. Starting server...")

    # Blocks until the event loop is stopped or interrupted.
    asyncio.run(serve(VlaFoundryServer(policy), args.host, args.port))


# Run the server only when this file is executed as a script.
if __name__ == "__main__":
    main()
