import dataclasses
import enum
import logging
import socket
import h5py
import tyro
import numpy as np
from openpi.policies import policy as _policy
from openpi.policies import policy_config as _policy_config
from openpi.serving import websocket_policy_server
from openpi.training import config as _config



@dataclasses.dataclass
class Args:
    """Command-line options for the replay policy server."""

    # Path of the HDF5 file to open; handed to `h5py.File` by ReplayPolicy.
    dataset: str

    # Demo number to replay; selects the `data/demo_<demo>` group in the file.
    demo: int = 21

    # Port to serve the policy on.
    port: int = 8000

class ReplayPolicy(_policy.Policy):
    """Policy that replays one stored demo from an HDF5 dataset, step by step.

    The incoming observation is ignored: each call to `infer` returns the next
    stored action together with the stored camera frames for the same step.
    """

    def __init__(self, dataset: str, demo: int):
        """Load the actions and camera frames of one demo into memory.

        Args:
            dataset: Path of the HDF5 file to read.
            demo: Demo number; the data is read from the `data/demo_<demo>` group.
        """
        # NOTE(review): the base `Policy.__init__` is not called here; presumably
        # it needs a model that a replay policy does not have -- confirm.
        self.demo = demo

        # `[()]` reads each dataset fully into memory, so the file can be closed
        # as soon as loading is done instead of staying open for the lifetime of
        # the server.
        with h5py.File(dataset, "r") as hdf5_file:
            demo_group = hdf5_file[f"data/demo_{self.demo}"]
            self.actions = demo_group["action/droid_joint_pos_action"][()]
            self.external_camera_rgb = demo_group["obs/vision/external_camera"][()]
            self.wrist_camera_rgb = demo_group["obs/vision/wrist_camera"][()]

        # Index of the next step to replay.
        self.idx = 0

    def infer(self, obs: dict, *, noise: np.ndarray | None = None) -> dict:
        """Return the next stored action and the camera frames of that step.

        Args:
            obs: Ignored; accepted for compatibility with the policy interface.
            noise: Ignored.

        Returns:
            A dict with `actions` (the stored action reshaped to one row),
            `external_camera` and `wrist_camera`, all taken from the same step.

        Raises:
            IndexError: If every stored step has already been replayed.
        """
        step = self.idx
        num_steps = len(self.actions)
        if step >= num_steps:
            raise IndexError(
                f"demo_{self.demo} has {num_steps} steps; nothing left to replay at step {step}"
            )

        # Use one index for all three arrays so the frames belong to the same
        # step as the action, and the final step does not run off the end.
        outputs = {
            "actions": self.actions[step].reshape(1, -1),
            "external_camera": self.external_camera_rgb[step],
            "wrist_camera": self.wrist_camera_rgb[step],
        }
        self.idx = step + 1
        return outputs




def create_policy(args: Args) -> _policy.Policy:
    """Build the replay policy described by the parsed arguments."""
    replay_policy = ReplayPolicy(dataset=args.dataset, demo=args.demo)
    return replay_policy


def main(args: Args) -> None:
    """Create the replay policy and serve it over a websocket server.

    Args:
        args: Parsed command-line options; `args.port` is the port to listen on.
    """
    policy = create_policy(args)

    hostname = socket.gethostname()
    try:
        local_ip = socket.gethostbyname(hostname)
    except OSError:
        # The address is only used in the log line below, so a host name that
        # does not resolve (socket.gaierror) must not stop the server.
        local_ip = "unknown"
    logging.info("Creating server (host: %s, ip: %s)", hostname, local_ip)

    # NOTE(review): no `metadata` is passed to the server. ReplayPolicy never runs
    # the base Policy initialiser, so `policy.metadata` may not be usable --
    # confirm before enabling it.
    server = websocket_policy_server.WebsocketPolicyServer(
        policy=policy,
        host="0.0.0.0",
        port=args.port,
    )
    server.serve_forever()


if __name__ == "__main__":
    # force=True removes any handlers already attached to the root logger, so
    # this configuration always takes effect.
    logging.basicConfig(force=True, level=logging.INFO)
    cli_args = tyro.cli(Args)
    main(cli_args)
