{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pathlib\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "record_path = pathlib.Path(\"../policy_records\")\n",
    "num_steps = len(list(record_path.glob(\"step_*.npy\")))\n",
    "\n",
    "records = []\n",
    "for i in range(num_steps):\n",
    "    record = np.load(record_path / f\"step_{i}.npy\", allow_pickle=True).item()\n",
    "    records.append(record)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"length of records\", len(records))\n",
    "print(\"keys in records\", records[0].keys())\n",
    "\n",
    "for k in records[0]:\n",
    "    print(f\"{k} shape: {records[0][k].shape}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from PIL import Image\n",
    "\n",
    "\n",
    "def get_image(step: int, idx: int = 0):\n",
    "    img = (255 * records[step][\"inputs/image\"]).astype(np.uint8)\n",
    "    return img[idx].transpose(1, 2, 0)\n",
    "\n",
    "\n",
    "def show_image(step: int, idx_lst: list[int]):\n",
    "    imgs = [get_image(step, idx) for idx in idx_lst]\n",
    "    return Image.fromarray(np.hstack(imgs))\n",
    "\n",
    "\n",
    "for i in range(2):\n",
    "    display(show_image(i, [0]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "\n",
    "def get_axis(name, axis):\n",
    "    return np.array([record[name][axis] for record in records])\n",
    "\n",
    "\n",
    "# qpos is [..., 14] of type float:\n",
    "# 0-5: left arm joint angles\n",
    "# 6: left arm gripper\n",
    "# 7-12: right arm joint angles\n",
    "# 13: right arm gripper\n",
    "names = [(\"left_joint\", 6), (\"left_gripper\", 1), (\"right_joint\", 6), (\"right_gripper\", 1)]\n",
    "\n",
    "\n",
    "def make_data():\n",
    "    cur_dim = 0\n",
    "    in_data = {}\n",
    "    out_data = {}\n",
    "    for name, dim_size in names:\n",
    "        for i in range(dim_size):\n",
    "            in_data[f\"{name}_{i}\"] = get_axis(\"inputs/qpos\", cur_dim)\n",
    "            out_data[f\"{name}_{i}\"] = get_axis(\"outputs/qpos\", cur_dim)\n",
    "            cur_dim += 1\n",
    "    return pd.DataFrame(in_data), pd.DataFrame(out_data)\n",
    "\n",
    "\n",
    "in_data, out_data = make_data()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for name in in_data.columns:\n",
    "    data = pd.DataFrame({f\"in_{name}\": in_data[name], f\"out_{name}\": out_data[name]})\n",
    "    data.plot()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
