{
    "cells": [
        {
            "cell_type": "code",
            "execution_count": 1,
            "metadata": {},
            "outputs": [],
            "source": [
                "import dataclasses\n",
                "\n",
                "import jax\n",
                "\n",
                "from openpi.models import model as _model\n",
                "from openpi.policies import droid_policy\n",
                "from openpi.policies import policy_config as _policy_config\n",
                "from openpi.shared import download\n",
                "from openpi.training import config as _config\n",
                "from openpi.training import data_loader as _data_loader"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "# Policy inference\n",
                "\n",
                "The following example shows how to create a policy from a checkpoint and run inference on a dummy example."
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "config = _config.get_config(\"pi0_fast_droid\")\n",
                "checkpoint_dir = download.maybe_download(\"gs://openpi-assets/checkpoints/pi0_fast_droid\")\n",
                "\n",
                "# Create a trained policy.\n",
                "policy = _policy_config.create_trained_policy(config, checkpoint_dir)\n",
                "\n",
                "# Run inference on a dummy example. This example corresponds to observations produced by the DROID runtime.\n",
                "example = droid_policy.make_droid_example()\n",
                "result = policy.infer(example)\n",
                "\n",
                "# Delete the policy to free up memory.\n",
                "del policy\n",
                "\n",
                "print(\"Actions shape:\", result[\"actions\"].shape)"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "# Working with a live model\n",
                "\n",
                "\n",
                "The following example shows how to create a live model from a checkpoint and compute the training loss. First, we demonstrate this with fake data.\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "config = _config.get_config(\"pi0_aloha_sim\")\n",
                "\n",
                "checkpoint_dir = download.maybe_download(\"gs://openpi-assets/checkpoints/pi0_aloha_sim\")\n",
                "key = jax.random.key(0)\n",
                "\n",
                "# Create a model from the checkpoint.\n",
                "model = config.model.load(_model.restore_params(checkpoint_dir / \"params\"))\n",
                "\n",
                "# We can create fake observations and actions to test the model.\n",
                "obs, act = config.model.fake_obs(), config.model.fake_act()\n",
                "\n",
                "# Compute the training loss on the fake observations and actions.\n",
                "loss = model.compute_loss(key, obs, act)\n",
                "print(\"Loss shape:\", loss.shape)"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "Now, we are going to create a data loader and use a real batch of training data to compute the loss."
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# Reduce the batch size to reduce memory usage.\n",
                "config = dataclasses.replace(config, batch_size=2)\n",
                "\n",
                "# Load a single batch of data. This is the same data that will be used during training, except unnormalized.\n",
                "# NOTE: In order to make this example self-contained, we are skipping the normalization step\n",
                "# since it requires the normalization statistics to be generated using `compute_norm_stats`.\n",
                "loader = _data_loader.create_data_loader(config, num_batches=1, skip_norm_stats=True)\n",
                "obs, act = next(iter(loader))\n",
                "\n",
                "# Compute the training loss on the batch from the data loader.\n",
                "loss = model.compute_loss(key, obs, act)\n",
                "\n",
                "# Delete the model to free up memory.\n",
                "del model\n",
                "\n",
                "print(\"Loss shape:\", loss.shape)"
            ]
        }
    ],
    "metadata": {
        "kernelspec": {
            "display_name": ".venv",
            "language": "python",
            "name": "python3"
        },
        "language_info": {
            "codemirror_mode": {
                "name": "ipython",
                "version": 3
            },
            "file_extension": ".py",
            "mimetype": "text/x-python",
            "name": "python",
            "nbconvert_exporter": "python",
            "pygments_lexer": "ipython3",
            "version": "3.11.9"
        }
    },
    "nbformat": 4,
    "nbformat_minor": 2
}
