# CLAUDE.md

This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.

## Repository Overview

This repo contains:
- **SageMaker training infrastructure** for pi0/pi0.5 VLA models using vla_foundry WebDataset shards
- **OpenPI** (submodule at `src/openpi/`) — the core training framework from Physical Intelligence
- **RL training** scripts for robot manipulation tasks (SAC, BC+PPO)
- **Simulation and inference** code for LBM dual-arm robot environments

## Training a pi0.5 Model on SageMaker

### Config registration

All training configs live in [src/openpi/src/openpi/training/misc/sim_improvement_config.py](src/openpi/src/openpi/training/misc/sim_improvement_config.py), returned by `get_sim_improvement_configs()`. They are registered into the global config registry in [src/openpi/src/openpi/training/config.py](src/openpi/src/openpi/training/config.py).
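
The registration step can be pictured as a name-keyed lookup. The following is a minimal sketch of that pattern, not the code in `config.py`: `TrainConfigStub`, `build_registry`, and `lookup` are hypothetical names, and only the `name` field of a config is modeled.

```python
import dataclasses
import difflib


@dataclasses.dataclass(frozen=True)
class TrainConfigStub:
    """Hypothetical stand-in for a training config; only the name matters here."""

    name: str


def build_registry(*config_groups):
    """Merge several lists of configs into one dict keyed by config name."""
    registry = {}
    for group in config_groups:
        for cfg in group:
            if cfg.name in registry:
                raise ValueError(f"Duplicate config name: {cfg.name}")
            registry[cfg.name] = cfg
    return registry


def lookup(registry, name):
    """Return the config for `name`, suggesting the closest match on a typo."""
    if name not in registry:
        close = difflib.get_close_matches(name, list(registry), n=1, cutoff=0.0)
        hint = f" Did you mean '{close[0]}'?" if close else ""
        raise KeyError(f"Config '{name}' not found.{hint}")
    return registry[name]


# Stand-in for the list returned by get_sim_improvement_configs().
sim_configs = [TrainConfigStub("pi05_lbm_sim"), TrainConfigStub("pi05_yam")]
registry = build_registry(sim_configs)
```

Rejecting duplicate names at registration time is a design choice of this sketch; it surfaces a clash when the registry is built rather than when a job silently picks up the wrong config.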

The canonical pi0.5 + LBM example is `pi05_lbm_sim`:
- **Model:** `Pi0Config(action_horizon=15, pi05=True)` — pi0.5 treats the robot state as language tokens via discretization
- **Data:** `WDSDataConfig` pointing to a vla_foundry manifest JSONL on S3
- **Data transforms:** `LBMInputs` / `LBMOutputs` from [src/openpi/src/openpi/policies/lbm_policy.py](src/openpi/src/openpi/policies/lbm_policy.py) — these pack the proprioception keys and camera images into the model's expected `state / image / actions` format
- **Pretrained weights:** loaded from `gs://openpi-assets/checkpoints/pi05_droid_jointpos/params`

### Launching on SageMaker

```bash
# Build Docker image and submit to SageMaker queue
python sagemaker/launch_training.py --config pi05_lbm_sim

# Key flags (append to the command above):
#   --instance_type p5       # p4de | p5 | p5en | p6
#   --user <your-username>   # sets ECR repo name and job owner tag
#   --priority 10
```

The launcher reads `sagemaker/secrets.env` for credentials and WANDB keys, builds `sagemaker/Dockerfile.openpi`, pushes the image to ECR, and submits the job via an AWS Batch queue. Inside the container, `sagemaker/train.sh` runs:
1. `compute_norm_stats.py --config-name $CONFIG_NAME`
2. `train.py $CONFIG_NAME`

### Local training

```bash
python src/openpi/scripts/train.py pi05_lbm_sim
# or just compute norm stats
python src/openpi/scripts/compute_norm_stats.py --config-name pi05_lbm_sim
```

## vla_foundry WebDataset (WDS) Shards

The WDS data pipeline is in [src/openpi/src/openpi/training/wds_dataset.py](src/openpi/src/openpi/training/wds_dataset.py).

- **Manifest:** a JSONL at `manifest_path` with `{"shard": "<s3-url>", "num_sequences": N}` per line
- **Streaming:** shards are piped via `aws s3 cp {shard}.tar -`; decoding is done selectively (only `camera_names` keys are JPEG-decoded)
- **Sequence layout:** each sample has images (JPEG per camera per timestep) and `lowdim.npz` with action/proprioception arrays; `_crop_sequence()` extracts `[lowdim_past_timesteps + 1 + lowdim_future_timesteps]` steps around a random anchor
- **Multi-worker loading:** `WebLoader` with `num_workers` parallel workers; the shard count must be divisible by `num_workers`
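
The manifest format and the divisibility constraint above can be checked before launching a job. This is a minimal sketch under stated assumptions: `read_manifest` and `check_shard_split` are hypothetical helpers (not part of `wds_dataset.py`), the bucket name is made up, and `num_workers=0` is treated as a single in-process loader.

```python
import json


def read_manifest(lines):
    """Parse manifest JSONL lines into a list of dicts, skipping blank lines."""
    return [json.loads(line) for line in lines if line.strip()]


def check_shard_split(entries, num_workers):
    """Return shards per worker, or raise if the shards do not divide evenly."""
    workers = max(num_workers, 1)  # assumption: 0 means loading in the main process
    if len(entries) % workers != 0:
        raise ValueError(
            f"{len(entries)} shards cannot be split evenly across {workers} workers"
        )
    return len(entries) // workers


# Hypothetical manifest content; real manifests live at `manifest_path`.
manifest_lines = [
    '{"shard": "s3://example-bucket/shards/shard_000000", "num_sequences": 10}',
    '{"shard": "s3://example-bucket/shards/shard_000001", "num_sequences": 10}',
    '{"shard": "s3://example-bucket/shards/shard_000002", "num_sequences": 10}',
    '{"shard": "s3://example-bucket/shards/shard_000003", "num_sequences": 7}',
]
entries = read_manifest(manifest_lines)
total_sequences = sum(entry["num_sequences"] for entry in entries)
shards_per_worker = check_shard_split(entries, num_workers=2)
```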

Key `WDSDataConfig` fields:

| Field | Purpose |
|---|---|
| `manifest_path` | S3 path to JSONL shard manifest |
| `camera_names` | Camera keys to decode from each tar (e.g. `["wrist_right_minus", "scene_left_0"]`) |
| `action_fields` | Keys inside `lowdim.npz` to concatenate as actions |
| `proprioception_fields` | Keys inside `lowdim.npz` for robot state |
| `assets_dir` / `asset_id` | S3/GCS path for normalization stats |

## Running the YAM pi0.5 Pipeline

YAM data comes out of the robot stack as a **raw per-frame dump** (directory-per-sequence with per-frame pickles), not in vla_foundry WebDataset format. The training pipeline only understands WDS, so training on YAM data is a two-stage flow:

1. **Convert** raw YAM dumps → vla_foundry WDS tar shards (one-time, local).
2. **Train** against those WDS shards using `pi05_yam` (SageMaker) or `pi05_yam_local` (local dry-run / sanity checks).

### Raw YAM dump schema

Each raw sequence directory looks like:

```
<raw-dump>/
  <seq_id>/                   # e.g. "0000", "0001"
    rgb/<camera>/NNNNNNNNNN.png
    depth/...                 # ignored
    lowdim/NNNNNNNNNN.pkl     # dict: joints, action_joints, language_prompt, ...
    metadata.json
  metadata_shared.json
```

Per-frame `.pkl` has these keys (others are ignored):
- `joints`: `(14,)` float32 — actual joint positions
- `action_joints`: `(14,)` float32 — **absolute** commanded joints (joint_cmd format, NOT deltas)
- `language_prompt`: string — e.g. `"Sort cans shortest to tallest / left to right"`

Cameras: `scene_camera`, `left_wrist_camera`, `right_wrist_camera`.
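
A quick sanity check on one per-frame pickle can catch schema drift before conversion. The sketch below assumes only the three keys listed above; `validate_frame` and `load_frame` are hypothetical helpers, and the length check works for any sequence type, including NumPy arrays of shape `(14,)`.

```python
import pickle

REQUIRED_KEYS = ("joints", "action_joints", "language_prompt")
NUM_JOINTS = 14


def validate_frame(frame):
    """Check that a per-frame dict has the expected keys, lengths, and types."""
    missing = [key for key in REQUIRED_KEYS if key not in frame]
    if missing:
        raise KeyError(f"Frame is missing keys: {missing}")
    for key in ("joints", "action_joints"):
        if len(frame[key]) != NUM_JOINTS:
            raise ValueError(f"{key} has length {len(frame[key])}, expected {NUM_JOINTS}")
    if not isinstance(frame["language_prompt"], str):
        raise TypeError("language_prompt must be a string")
    return frame


def load_frame(path):
    """Load one lowdim/NNNNNNNNNN.pkl file and validate it."""
    with open(path, "rb") as f:
        return validate_frame(pickle.load(f))
```

Only unpickle dumps from a trusted source; `pickle.load` can execute arbitrary code embedded in the file.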

### Step 1 — Convert raw YAM dump → WDS shards

```bash
python scripts/convert_yam_raw_to_wds.py \
    --input  /path/to/raw-yam-dump \
    --output /path/to/yam-wds \
    --samples-per-shard 10 \
    --stride 15 \
    --past 0 --future 14
```

Output: `manifest.jsonl` plus `shard_NNNNNN.tar` files in `<output>`. Each tar contains one sample per training anchor:

- `{sample_id}.{camera}_t0.jpg` — anchor-frame JPEG per camera (720×1280×3, JPEG q=95)
- `{sample_id}.lowdim.npz` — two arrays of shape `(past+1+future, 14)`:
  - `yam__joint_pos` (from `.pkl::joints`)
  - `yam__action_joints` (from `.pkl::action_joints`)
- `{sample_id}.metadata.json` — `{"anchor_relative_idx": 0}`
- `{sample_id}.language_instructions.json` — `{"original": "<prompt>"}`
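
To confirm that a converted shard follows the layout above, the member names in a tar can be grouped by sample. This is a hedged sketch: `find_incomplete_samples` and `make_tar` are hypothetical helpers, and it assumes sample ids contain no dots, so that the first dot separates the id from the file suffix.

```python
import io
import tarfile
from collections import defaultdict

CAMERAS = ("scene_camera", "left_wrist_camera", "right_wrist_camera")


def expected_suffixes(cameras=CAMERAS):
    """File suffixes each sample should have, per the layout listed above."""
    suffixes = {f"{cam}_t0.jpg" for cam in cameras}
    suffixes |= {"lowdim.npz", "metadata.json", "language_instructions.json"}
    return suffixes


def find_incomplete_samples(tar_bytes):
    """Map sample_id -> sorted missing suffixes; empty dict if all are complete."""
    with tarfile.open(fileobj=io.BytesIO(tar_bytes), mode="r") as tar:
        names = tar.getnames()
    found = defaultdict(set)
    for name in names:
        # Assumption: the first dot splits the sample id from the suffix.
        sample_id, _, suffix = name.partition(".")
        found[sample_id].add(suffix)
    expected = expected_suffixes()
    return {
        sample_id: sorted(expected - suffixes)
        for sample_id, suffixes in found.items()
        if expected - suffixes
    }


def make_tar(names):
    """Build an in-memory tar with empty members, for demonstration only."""
    buf = io.BytesIO()
    with tarfile.open(fileobj=buf, mode="w") as tar:
        for name in names:
            tar.addfile(tarfile.TarInfo(name), io.BytesIO(b""))
    return buf.getvalue()
```

For a real shard, read the file with `open(path, "rb").read()` and pass the bytes to `find_incomplete_samples`.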

Flags: `--stride` is frames between anchor picks per sequence; `--samples-per-shard` should be tuned so shard count is divisible by your training `num_workers`. Depth is skipped unconditionally.
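
The interaction between `--stride`, `--past`, `--future`, and `--samples-per-shard` can be estimated ahead of time. The sketch below is an approximation, not the converter's logic: it assumes anchors are picked every `stride` frames and that only anchors whose full window fits inside the sequence are kept (the real script may pad or handle sequence ends differently). `pick_anchors` and `num_shards` are hypothetical names.

```python
import math


def pick_anchors(num_frames, stride, past, future):
    """Anchor indices whose window [anchor - past, anchor + future] fits the sequence."""
    first = past
    last = num_frames - 1 - future
    return list(range(first, last + 1, stride))


def num_shards(num_samples, samples_per_shard):
    """Number of tar shards needed to hold `num_samples` samples."""
    return math.ceil(num_samples / samples_per_shard)


# Example: 20 sequences of 100 frames each, using the flags from the command above.
anchors = pick_anchors(num_frames=100, stride=15, past=0, future=14)
total_samples = 20 * len(anchors)
shards = num_shards(total_samples, samples_per_shard=10)
```

Under these assumptions each 100-frame sequence yields 6 anchors, so 20 sequences give 120 samples and 12 shards, which divides evenly across 2, 3, 4, or 6 workers.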

### Step 2 — Point the training config at your shards

The repo ships two YAM configs in [sim_improvement_config.py](src/openpi/src/openpi/training/misc/sim_improvement_config.py):

- **`pi05_yam`** — production SageMaker run. S3 manifest + real `Pi0Config` + pretrained weights from `pi05_droid_jointpos`.
- **`pi05_yam_local`** — local dry-run. Local manifest path, dummy model variants, `NoOpWeightLoader`, `num_workers=0`.

For `pi05_yam`, set `manifest_path` to the S3 path of your converted WDS manifest. For `pi05_yam_local`, update the local `manifest_path` to your `<output>/manifest.jsonl` from step 1.

**Do not** modify `YAM_ACTION_FIELDS` / `YAM_PROPRIOCEPTION_FIELDS` / `YAM_CAMERA_NAMES` in [yam_policy.py](src/openpi/src/openpi/policies/yam_policy.py) — they match what the conversion script emits:

| Constant | Value |
|---|---|
| `YAM_PROPRIOCEPTION_FIELDS` | `["yam__joint_pos"]` |
| `YAM_ACTION_FIELDS` | `["yam__action_joints"]` |
| `YAM_CAMERA_NAMES` | `["scene_camera", "left_wrist_camera", "right_wrist_camera"]` |

### Step 3 — Dry-run the pipeline locally (recommended before SageMaker)

A dry-run reads a few batches through the full transform stack and exits without training. Two scripts accept a `--dry-run` flag:

```bash
# Data pipeline only (WDS → YAMInputs → batched tensors)
python src/openpi/scripts/compute_norm_stats.py --config-name pi05_yam_local --dry-run

# Full pipeline (adds model transforms: ResizeImages(224x224), TokenizePrompt, PadStatesAndActions)
python src/openpi/scripts/train.py pi05_yam_local --dry-run
```

Expected shapes on the `pi05_yam_local` config (`batch_size=2`, `action_horizon=15`):
- After data transforms: `state (2, 14)`, `actions (2, 15, 14)`, images `(2, 720, 1280, 3)` each
- After model transforms: `state (2, 32)`, `actions (2, 15, 32)`, images `(2, 224, 224, 3)` each (state/action padded to the model's 32D space; images resized)
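
The jump from 14 to 32 dimensions is a right-pad on the last axis, and `YAMOutputs` undoes it by slicing. The following is a minimal pure-Python sketch of that round trip for one sample; `pad_last_dim` and `unpad_last_dim` are hypothetical helpers, and zero fill is an assumption rather than a confirmed detail of `PadStatesAndActions`.

```python
def pad_last_dim(rows, target_dim, fill=0.0):
    """Right-pad each row (a list of floats) with `fill` up to `target_dim`."""
    padded = []
    for row in rows:
        if len(row) > target_dim:
            raise ValueError(f"Row of length {len(row)} exceeds target {target_dim}")
        padded.append(list(row) + [fill] * (target_dim - len(row)))
    return padded


def unpad_last_dim(rows, original_dim):
    """Keep only the first `original_dim` entries of each row."""
    return [row[:original_dim] for row in rows]


# One sample: action_horizon=15 timesteps of 14D joint commands.
actions = [[float(j) for j in range(14)] for _ in range(15)]
padded = pad_last_dim(actions, target_dim=32)     # 15 rows of length 32
restored = unpad_last_dim(padded, original_dim=14)  # back to 15 rows of length 14
```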

### Step 4 — Compute normalization stats

```bash
python src/openpi/scripts/compute_norm_stats.py --config-name pi05_yam_local
# → writes assets/pi05_yam_local/yam_local/norm_stats.json
```

For the SageMaker config, `compute_norm_stats.py --config-name pi05_yam` first tries to download cached stats from S3 (`remote_checkpoint_dir`). If none are found, it computes fresh stats and uploads them to S3.
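
The cache-then-compute behavior can be sketched as follows. This is an illustration only: a local directory stands in for S3, `get_norm_stats` is a hypothetical helper, and the stats dict is a made-up placeholder rather than the real `norm_stats.json` schema.

```python
import json
import tempfile
from pathlib import Path


def get_norm_stats(cache_path, compute_fn):
    """Return (stats, cache_hit). Compute and store the stats only on a miss."""
    cache_path = Path(cache_path)
    if cache_path.exists():
        return json.loads(cache_path.read_text()), True
    stats = compute_fn()
    cache_path.parent.mkdir(parents=True, exist_ok=True)
    cache_path.write_text(json.dumps(stats))
    return stats, False


with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "pi05_yam" / "norm_stats.json"
    # First call: nothing cached, so the stats are computed and stored.
    first, first_hit = get_norm_stats(path, lambda: {"mean": [0.0], "std": [1.0]})
    # Second call: the cached copy wins and compute_fn is never invoked.
    second, second_hit = get_norm_stats(path, lambda: {"mean": [9.9], "std": [9.9]})
```

A stale cache is the main risk of this pattern: if the shards change, the cached stats need to be removed or the cache location changed before recomputing.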

### Step 5 — Launch on SageMaker

```bash
python sagemaker/launch_training.py --config pi05_yam \
    --user <your-username> \
    --instance_type p5        # p4de | p5 | p5en | p6
```

The launcher builds `sagemaker/Dockerfile.openpi`, pushes to ECR, and submits via AWS Batch. The container runs:
1. `compute_norm_stats.py --config-name pi05_yam` (skipped if the stats already exist in S3)
2. `train.py pi05_yam`

Credentials and WANDB keys come from `sagemaker/secrets.env`.

### YAM robot specifics

- **Policy file:** [yam_policy.py](src/openpi/src/openpi/policies/yam_policy.py) — `YAMInputs` packs 3 cameras (`scene_camera→base_0_rgb`, `left_wrist_camera→left_wrist_0_rgb`, `right_wrist_camera→right_wrist_0_rgb`) + 14D joint state; `YAMOutputs` slices `actions[:, :14]`.
- **Action space:** 14D absolute joint commands `[left_j0..5, left_gripper, right_j0..5, right_gripper]` — the `joint_cmd` format straight from the robot's leader-follower teleop. **No delta conversion is applied here**; downstream tooling (e.g., Pi0.5_yam inference) is responsible for any delta transforms.
- **Pretrained weights (`pi05_yam`):** initialized from `gs://openpi-assets/checkpoints/pi05_droid_jointpos/params`.
- **Dry-run configs:**
  - `pi05_yam_local` — real WDS pipeline with dummy model variants (tests data + transforms end-to-end on CPU/small GPU).
  - `debug_pi05_yam` — `FakeDataConfig` + dummy model (tests model wiring without data).

### What changed vs. the original codebase

Files added:
- [scripts/convert_yam_raw_to_wds.py](scripts/convert_yam_raw_to_wds.py) — conversion utility for the raw YAM dump format.
- [src/openpi/src/openpi/policies/yam_policy.py](src/openpi/src/openpi/policies/yam_policy.py) — `YAMInputs` / `YAMOutputs` transforms.

Files modified:
- [src/openpi/src/openpi/training/misc/sim_improvement_config.py](src/openpi/src/openpi/training/misc/sim_improvement_config.py) — added `pi05_yam`, `pi05_yam_local`, `debug_pi05_yam` configs.
- [src/openpi/scripts/compute_norm_stats.py](src/openpi/scripts/compute_norm_stats.py) — added `--dry-run` flag and S3-cache support for norm stats.
- [src/openpi/scripts/train.py](src/openpi/scripts/train.py) — added `--dry-run` flag (intercepted before `tyro` parsing).
- [src/openpi/src/openpi/training/mds_dataset.py](src/openpi/src/openpi/training/mds_dataset.py) — made `mosaicml-streaming` import lazy so non-MDS configs can import without the dep.

## Architecture Notes

### Pi0 vs Pi0.5

In `Pi0Config`, setting `pi05=True` switches to the Pi0.5 architecture where robot state is **discretized into language tokens** (vs. continuous state injection in Pi0). Pi0.5 also uses `adaRMSNorm` to inject the flow-matching timestep into the action expert. `max_token_len` increases to 200 (from 48).
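
Discretizing a state vector into tokens amounts to binning each normalized value. The sketch below illustrates the idea only: the 256 uniform bins over `[-1, 1]` and the clamping of out-of-range values are assumptions for this example, and `discretize_state` is a hypothetical name, not the tokenizer's API.

```python
import bisect


def discretize_state(state, num_bins=256, low=-1.0, high=1.0):
    """Map each value to the index of the uniform bin that contains it."""
    width = (high - low) / num_bins
    left_edges = [low + i * width for i in range(num_bins)]
    tokens = []
    for value in state:
        # Index of the last bin whose left edge is <= value.
        index = bisect.bisect_right(left_edges, value) - 1
        # Clamp values outside [low, high] into the first or last bin.
        tokens.append(min(max(index, 0), num_bins - 1))
    return tokens


tokens = discretize_state([-1.0, 0.0, 1.0, -2.0])
```

Each resulting integer would then be rendered as text and tokenized alongside the prompt, which is why `max_token_len` has to grow for Pi0.5.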

### Policy transform convention

Each robot has `*Inputs` and `*Outputs` transform classes. `*Inputs` takes raw dataset samples (with `observation/proprioception/<key>`, `observation/images/<cam>_t0`, `actions/<key>`, `prompt`) and produces the flat `{state, image, image_mask, actions, prompt}` dict the model expects. `*Outputs` converts model action output back to robot action space.
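
The convention can be illustrated with plain dicts and lists. This is a simplified sketch modeled on the YAM mapping described earlier, not the contents of `yam_policy.py`: `example_inputs` and `example_outputs` are hypothetical functions, proprioception is assumed to be a flat vector for the anchor frame, and images are passed through untouched.

```python
CAMERA_TO_MODEL_KEY = {
    "scene_camera": "base_0_rgb",
    "left_wrist_camera": "left_wrist_0_rgb",
    "right_wrist_camera": "right_wrist_0_rgb",
}
PROPRIO_FIELDS = ["yam__joint_pos"]
ACTION_FIELDS = ["yam__action_joints"]
ACTION_DIM = 14


def example_inputs(sample):
    """Flatten a raw dataset sample into the model's expected dict."""
    state = []
    for key in PROPRIO_FIELDS:
        state.extend(sample[f"observation/proprioception/{key}"])
    image, image_mask = {}, {}
    for camera, model_key in CAMERA_TO_MODEL_KEY.items():
        image[model_key] = sample[f"observation/images/{camera}_t0"]
        image_mask[model_key] = True
    out = {
        "state": state,
        "image": image,
        "image_mask": image_mask,
        "prompt": sample["prompt"],
    }
    # Actions exist only in training data; at inference time they are absent.
    if all(f"actions/{key}" in sample for key in ACTION_FIELDS):
        per_step = zip(*(sample[f"actions/{key}"] for key in ACTION_FIELDS))
        # Concatenate the fields of each timestep into one action vector.
        out["actions"] = [sum((list(part) for part in step), []) for step in per_step]
    return out


def example_outputs(model_actions):
    """Slice padded model actions back to the robot's action dimension."""
    return [row[:ACTION_DIM] for row in model_actions]
```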

### MDS vs WDS

Two dataset formats are supported:
- **WDS** (`WDSDataConfig`): vla_foundry tar shards, streamed directly from S3 — used for LBM sim data
- **MDS** (`MDSDataConfig`): MosaicML Streaming format with deterministic shuffling and elastic resume — used for DROID/UR rollout data

## Dependencies

```bash
uv sync   # install all dependencies
```

The project uses `uv` for package management. OpenPI, rsl_rl, and polaris are git submodules under `src/`.
