import numpy as np
import mediapy as mp
from openpi_client import websocket_client_policy, image_tools

class DroidJointPosClient():
    def __init__(self, host: str, port: int, open_loop_horizon: int) -> None:
        self.client = websocket_client_policy.WebsocketClientPolicy(
            host=host, port=port
        )
        self.actions_from_chunk_completed = 0
        self.pred_action_chunk = None
        self.open_loop_horizon = open_loop_horizon

    @property
    def rerender(self) -> bool:
        return (
            self.actions_from_chunk_completed == 0
            or self.actions_from_chunk_completed >= self.open_loop_horizon
        )

    def visualize(self, request: dict):
        """
        Return the camera views how the model sees it
        """
        curr_obs = self._extract_observation(request)
        base_img = image_tools.resize_with_pad(curr_obs["right_image"], 224, 224)
        wrist_img = image_tools.resize_with_pad(curr_obs["wrist_image"], 224, 224)
        combined = np.concatenate([base_img, wrist_img], axis=1)
        return combined

    def reset(self):
        self.actions_from_chunk_completed = 0
        self.pred_action_chunk = None

    def infer(
        self, obs: dict, instruction: str, return_viz: bool = False
    ) -> tuple[np.ndarray, np.ndarray | None]:
        """
        Infer the next action from the policy in a server-client setup
        """
        both = None
        ret = {}
        if (
            self.actions_from_chunk_completed == 0
            or self.actions_from_chunk_completed >= self.open_loop_horizon
        ):
            curr_obs = self._extract_observation(obs)

            self.actions_from_chunk_completed = 0
            exterior_image = image_tools.resize_with_pad(
                curr_obs["right_image"], 224, 224
            )
            wrist_image = image_tools.resize_with_pad(curr_obs["wrist_image"], 224, 224)
            request_data = {
                "observation/exterior_image_1_left": exterior_image,
                "observation/wrist_image_left": wrist_image,
                "observation/joint_position": curr_obs["joint_position"],
                "observation/gripper_position": curr_obs["gripper_position"],
                # "observation/state": curr_obs["joint_position"],
                "prompt": instruction,
            }
            server_response = self.client.infer(request_data)
            self.pred_action_chunk = server_response["actions"]
            both = np.concatenate([exterior_image, wrist_image], axis=1)

        if return_viz and both is None:
            curr_obs = self._extract_observation(obs)
            both = np.concatenate(
                [
                    image_tools.resize_with_pad(curr_obs["right_image"], 224, 224),
                    image_tools.resize_with_pad(curr_obs["wrist_image"], 224, 224),
                ],
                axis=1,
            )

        if self.pred_action_chunk is None:
            raise ValueError("No action chunk predicted")

        action = self.pred_action_chunk[self.actions_from_chunk_completed]
        self.actions_from_chunk_completed += 1

        # binarize gripper action
        # if action[-1].item() > 0.5:
        #     action = np.concatenate([action[:-1], np.ones((1,))])
        # else:
        #     action = np.concatenate([action[:-1], np.zeros((1,))])
        # action = action[:-1]
        # print(f"gripper: {action[-1]}")

        return action, both

    def _extract_observation(self, obs_dict):
        # Assign images
        # right_image = obs_dict["vision"]["external_camera"].clone().detach().cpu().numpy()[0]
        # wrist_image = obs_dict["vision"]["wrist_camera"].clone().detach().cpu().numpy()[0]
        right_image = obs_dict["splat"]["external_cam"]
        wrist_image = obs_dict["splat"]["wrist_cam"]


        # Capture proprioceptive state
        # vision_obs = obs_dict["vision"]
        vision_obs = obs_dict["policy"]
        joint_position = vision_obs["arm_joint_pos"].clone().detach().cpu().numpy()[0]
        gripper_position = vision_obs["gripper_pos"].clone().detach().cpu().numpy()[0]

        return {
            "right_image": right_image,
            "wrist_image": wrist_image,
            "gripper_position": gripper_position,
            "joint_position": joint_position,
        }