# Training pi0.5 on YAM data

This guide walks through training a pi0.5 policy on YAM bimanual-robot data using pre-converted vla_foundry WebDataset (WDS) shards streamed directly from S3.

## Overview

```
S3 WDS tar shards ──▶ WDS loader (streams via pipe:aws s3 cp) ──▶ YAMInputs ──▶ train.py
```

The pipeline streams tar shards from S3 on the fly, so no local data download is needed. AWS credentials are resolved automatically from the environment (SSO, instance profile, etc.).

## Prerequisites

### Environment setup (one-time)

The training scripts depend on the full openpi ML stack (`jax`, `flax`, `webdataset`, …), which is **not** installed by this repo's top-level `uv sync`. The recommended setup is to reuse the sibling **Pi0.5_yam** repo's venv:

```bash
# 1. Make sure Pi0.5_yam is cloned and its venv is set up.
cd /path/to/Pi0.5_yam && uv sync && cd -

# 2. Activate it for the current shell.
source /path/to/Pi0.5_yam/.venv/bin/activate
```

### Per-shell setup (every new terminal)

After activating the venv, tell Python to prefer **this repo's** vendored `openpi` package:

```bash
cd /path/to/sim-improvement
export PYTHONPATH="$PWD/src/openpi/src:$PYTHONPATH"
```

Verify:

```bash
python -c "import openpi; print(openpi.__file__)"
# should print: /path/to/sim-improvement/src/openpi/src/openpi/__init__.py
```

### AWS credentials

The WDS loader streams from S3 via `aws s3 cp`. Make sure your shell can authenticate to AWS:

```bash
aws sts get-caller-identity                                 # confirm credentials resolve
aws s3 ls s3://tri-ml-datasets-uw2/raiden/smoketest-pi05/   # confirm you can list the bucket
```

## Quick start (3 commands)

```bash
# 1. Dry-run — validate the full pipeline (S3 streaming + transforms)
python src/openpi/scripts/compute_norm_stats.py --config-name pi05_yam_local --dry-run

# 2. Compute normalization stats (required before training)
python src/openpi/scripts/compute_norm_stats.py --config-name pi05_yam_local --max-frames 200

# 3. Train (dummy model, 4 steps; validates end-to-end)
python src/openpi/scripts/train.py pi05_yam_local
```

## Step-by-step

### Step 1 — Dry-run (validate pipeline)

```bash
python src/openpi/scripts/compute_norm_stats.py --config-name pi05_yam_local --dry-run
```

Expected output:
```
state: shape=(2, 14)  dtype=float32
actions: shape=(2, 15, 14)  dtype=float32
image/base_0_rgb: shape=(2, 384, 384, 3)
image/left_wrist_0_rgb: shape=(2, 384, 384, 3)
image/right_wrist_0_rgb: shape=(2, 384, 384, 3)
```
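If you want to assert these shapes programmatically, for example in a smoke-test script, a small checker is enough. The helper below is hypothetical and not part of `compute_norm_stats.py`. It assumes the batch is a dict with `state`, `actions`, and a nested `image` dict. That layout is inferred from the `image/...` names in the printout.

```python
import numpy as np

# Shapes printed by `compute_norm_stats.py --dry-run` above (batch size 2).
EXPECTED_DRY_RUN_SHAPES = {
    "state": (2, 14),
    "actions": (2, 15, 14),
    "image/base_0_rgb": (2, 384, 384, 3),
    "image/left_wrist_0_rgb": (2, 384, 384, 3),
    "image/right_wrist_0_rgb": (2, 384, 384, 3),
}


def flatten(batch, prefix=""):
    """Flatten nested dicts into 'a/b' keys, matching the dry-run print format."""
    out = {}
    for key, value in batch.items():
        name = f"{prefix}{key}"
        if isinstance(value, dict):
            out.update(flatten(value, prefix=f"{name}/"))
        else:
            out[name] = value
    return out


def shape_mismatches(batch, expected=EXPECTED_DRY_RUN_SHAPES):
    """Return human-readable problems; an empty list means the batch matches."""
    flat = flatten(batch)
    problems = []
    for key, shape in expected.items():
        if key not in flat:
            problems.append(f"missing key: {key}")
        elif tuple(flat[key].shape) != shape:
            problems.append(f"{key}: got {tuple(flat[key].shape)}, expected {shape}")
    return problems


# Example: a synthetic batch with the expected layout passes the check.
IMAGE_SLOTS = ("base_0_rgb", "left_wrist_0_rgb", "right_wrist_0_rgb")
batch = {
    "state": np.zeros((2, 14), np.float32),
    "actions": np.zeros((2, 15, 14), np.float32),
    "image": {slot: np.zeros((2, 384, 384, 3), np.uint8) for slot in IMAGE_SLOTS},
}
print(shape_mismatches(batch))  # []
```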

To dry-run the full `train.py` pipeline, including model transforms (resize, tokenize, pad):

```bash
python src/openpi/scripts/train.py pi05_yam_local --dry-run
```

Expected output (after model transforms):
```
state: (2, 32)
actions: (2, 15, 32)
image/{base_0_rgb,left_wrist_0_rgb,right_wrist_0_rgb}: (2, 224, 224, 3)
```
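The change from 14 to 32 in `state` and `actions` is padding up to the model's fixed action width. The numpy sketch below makes two assumptions. First, zeros are appended after the 14 YAM dims. Second, the extra dims are sliced off again on the way out (presumably in `YAMOutputs`). The names `pad_last_dim` and `unpad_actions` are hypothetical. The 384 to 224 image resize is a separate transform and is not shown.

```python
import numpy as np

MODEL_ACTION_DIM = 32  # padded width seen in the train.py dry-run output
YAM_ACTION_DIM = 14    # 7 values per arm x 2 arms


def pad_last_dim(x: np.ndarray, target_dim: int) -> np.ndarray:
    """Zero-pad the last axis of x up to target_dim (no-op if already that wide)."""
    current = x.shape[-1]
    if current > target_dim:
        raise ValueError(f"last dim {current} exceeds target {target_dim}")
    pad_width = [(0, 0)] * (x.ndim - 1) + [(0, target_dim - current)]
    return np.pad(x, pad_width)  # constant mode, fills with 0, keeps dtype


def unpad_actions(actions: np.ndarray, dim: int = YAM_ACTION_DIM) -> np.ndarray:
    """Inference direction: keep only the first `dim` (YAM) action dims."""
    return actions[..., :dim]


state = pad_last_dim(np.ones((2, 14), np.float32), MODEL_ACTION_DIM)
actions = pad_last_dim(np.ones((2, 15, 14), np.float32), MODEL_ACTION_DIM)
print(state.shape, actions.shape)    # (2, 32) (2, 15, 32)
print(unpad_actions(actions).shape)  # (2, 15, 14)
```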

### Step 2 — Compute normalization stats

```bash
python src/openpi/scripts/compute_norm_stats.py --config-name pi05_yam_local --max-frames 200
# → assets/pi05_yam_local/yam_local/norm_stats.json
```

Use `--max-frames` to cap the number of samples processed (faster for validation; omit for production).
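Conceptually, `--max-frames` just truncates the sample stream before per-dimension statistics are accumulated. The sketch below assumes the stats are a per-dimension mean, std, and 1st/99th percentiles, which is our understanding of openpi's `NormStats`. It is not the script's code, and it does not reproduce the exact `norm_stats.json` schema. `compute_stats` is a hypothetical helper.

```python
import itertools
import json

import numpy as np


def compute_stats(arrays, max_frames=None):
    """Per-dimension statistics over the first `max_frames` arrays (None = all).

    Each array has shape (..., D), e.g. (14,) for state or (15, 14) for an action
    chunk; leading axes are pooled so stats are per last-axis dimension.
    """
    chunks = [
        np.asarray(a, dtype=np.float64).reshape(-1, np.shape(a)[-1])
        for a in itertools.islice(arrays, max_frames)  # stop early, like --max-frames
    ]
    if not chunks:
        raise ValueError("no frames to compute statistics over")
    values = np.concatenate(chunks, axis=0)  # (N, D)
    return {
        "mean": values.mean(axis=0).tolist(),
        "std": values.std(axis=0).tolist(),
        "q01": np.quantile(values, 0.01, axis=0).tolist(),
        "q99": np.quantile(values, 0.99, axis=0).tolist(),
    }


# Example: a long stream of synthetic 14-D states, capped at 200 frames.
rng = np.random.default_rng(0)
states = (rng.normal(size=14) for _ in range(10_000))
stats = {"state": compute_stats(states, max_frames=200)}
print(len(stats["state"]["mean"]))  # 14
payload = json.dumps(stats)  # JSON-serializable; the real file layout may differ
```

Because the input is a generator, only the first 200 samples are ever pulled. With S3 streaming, that means only the shards needed for those samples get downloaded.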

### Step 3 — Local training (validation)

```bash
python src/openpi/scripts/train.py pi05_yam_local
```

This runs 4 training steps with a tiny dummy model. It is fast and confirms the full loop works end to end: data streaming → transforms → forward/backward pass → checkpoint save.

Expected log output:
```
Step 0: grad_norm=3.27, loss=2.21, param_norm=477.80
Step 1: grad_norm=1.82, loss=2.11, param_norm=477.80
Step 2: grad_norm=1.62, loss=2.04, param_norm=477.80
Step 3: grad_norm=1.56, loss=2.08, param_norm=477.80
```

### Step 4 — SageMaker (production)

```bash
python sagemaker/launch_training.py --config pi05_yam \
    --user <your-username> \
    --instance_type p5
```

The `pi05_yam` config uses the full pi0.5 model with pretrained weights from `gs://openpi-assets/checkpoints/pi05_droid_jointpos/params`.

## Configs at a glance

| Config | Data source | Model | Use case |
|---|---|---|---|
| `pi05_yam` | S3 WDS shards | Full pi0.5 + pretrained weights | SageMaker production |
| `pi05_yam_local` | S3 WDS shards | **Dummy** pi0.5 variants, `NoOpWeightLoader` | Local validation (small GPU) |
| `debug_pi05_yam` | `FakeDataConfig` (synthetic) | Dummy pi0.5 | Model-wiring smoke test (no data) |

## Data format (vla_foundry WDS shards on S3)

Current training data:
```
s3://tri-ml-datasets-uw2/raiden/smoketest-pi05/Sort_objects_lf/shards/
├── manifest.jsonl          # 894 shards × 100 samples = ~89,400 training samples
├── shard_000000.tar
├── shard_000001.tar
└── ...
```
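The sketch below parses `manifest.jsonl` and builds WebDataset `pipe:` URLs, showing how the manifest maps to streamed shards. The manifest field names `shard` and `num_sequences` are an assumption, so verify them against the real file. `read_manifest` and `pipe_urls` are hypothetical helpers.

```python
import json

# Prefix from the data-format listing above.
S3_PREFIX = "s3://tri-ml-datasets-uw2/raiden/smoketest-pi05/Sort_objects_lf/shards"


def read_manifest(lines):
    """Parse manifest.jsonl lines into (shard_name, num_samples) pairs.

    ASSUMPTION: each line looks like {"shard": "shard_000000", "num_sequences": 100}.
    """
    entries = []
    for line in lines:
        line = line.strip()
        if not line:
            continue  # tolerate blank lines
        record = json.loads(line)
        name = record["shard"].removesuffix(".tar")  # accept names with or without .tar
        entries.append((name, int(record["num_sequences"])))
    return entries


def pipe_urls(entries, prefix=S3_PREFIX):
    """WebDataset 'pipe:' URLs: each shard is streamed to stdout by the AWS CLI."""
    return [f"pipe:aws s3 cp {prefix}/{name}.tar -" for name, _ in entries]


manifest_lines = [
    '{"shard": "shard_000000", "num_sequences": 100}',
    '{"shard": "shard_000001.tar", "num_sequences": 100}',
]
entries = read_manifest(manifest_lines)
print(sum(n for _, n in entries))  # 200
print(pipe_urls(entries)[0])
# pipe:aws s3 cp s3://tri-ml-datasets-uw2/raiden/smoketest-pi05/Sort_objects_lf/shards/shard_000000.tar -
```

The `pipe:` prefix makes WebDataset run the rest of the string as a shell command and read the tar bytes from its stdout. That is why no local copy of the data is needed.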

Each tar sample contains:
- `{uuid}.{camera}_t0.jpg` — 384×384 JPEG (anchor frame)
- `{uuid}.{camera}_t-1.jpg` — previous frame (not used by current config)
- `{uuid}.lowdim.npz` — temporal window of 21 steps (`past=1, future=19`)
- `{uuid}.metadata.json` — `anchor_relative_idx`, episode info
- `{uuid}.language_instructions.json` — `{"original": [...]}`
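WebDataset assembles these files into one sample by splitting each member name at the first dot of its basename. The part before the dot is the sample key (`{uuid}`), and the rest is the field name (`scene_camera_t0.jpg`, `lowdim.npz`, and so on). The stdlib sketch below mimics that grouping on a streamed tar. It is illustrative only and assumes the files of one sample are stored contiguously, as WebDataset expects. `split_key`, `iter_samples`, and `make_tar` are hypothetical helpers.

```python
import io
import tarfile


def split_key(member_name):
    """Split 'dir/{uuid}.{field}' at the first dot of the basename."""
    dirname, _, basename = member_name.rpartition("/")
    key, _, field = basename.partition(".")
    return (f"{dirname}/{key}" if dirname else key), field


def iter_samples(fileobj):
    """Yield {"__key__": key, field: bytes, ...} from an uncompressed tar stream."""
    current_key, current = None, {}
    with tarfile.open(fileobj=fileobj, mode="r|") as tar:  # streaming, no seeking
        for member in tar:
            if not member.isfile():
                continue
            key, field = split_key(member.name)
            if key != current_key and current:
                yield {"__key__": current_key, **current}  # previous sample is done
                current = {}
            current_key = key
            current[field] = tar.extractfile(member).read()
    if current:
        yield {"__key__": current_key, **current}


def make_tar(files):
    """Build an in-memory tar from (name, bytes) pairs, for demonstration."""
    buf = io.BytesIO()
    with tarfile.open(fileobj=buf, mode="w") as tar:
        for name, data in files:
            info = tarfile.TarInfo(name)
            info.size = len(data)
            tar.addfile(info, io.BytesIO(data))
    buf.seek(0)
    return buf


samples = list(iter_samples(make_tar([
    ("abc.scene_camera_t0.jpg", b"jpeg"),
    ("abc.lowdim.npz", b"npz"),
    ("def.metadata.json", b"{}"),
])))
print([s["__key__"] for s in samples])  # ['abc', 'def']
```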

### Lowdim fields used for training

| Field | Shape | Role |
|---|---|---|
| `robot__actual__joint_position__left::yam` | (21, 7) | Left arm state (proprioception) |
| `robot__actual__joint_position__right::yam` | (21, 7) | Right arm state (proprioception) |
| `robot__desired__joint_position__left::yam` | (21, 7) | Left arm action (desired joints) |
| `robot__desired__joint_position__right::yam` | (21, 7) | Right arm action (desired joints) |

`YAMInputs` concatenates left+right → 14D, takes the anchor timestep for state, and crops actions to `action_horizon=15` steps (anchor + 14 future).
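The numpy sketch below illustrates the state/action logic. `yam_state_and_actions` is a hypothetical stand-in for the relevant part of `YAMInputs`, not its actual code. It makes two assumptions. First, the window has already been cropped so the anchor sits at index `past` (1). Second, the arms are concatenated left then right.

```python
import numpy as np

# Field names from the table above.
STATE_KEYS = (
    "robot__actual__joint_position__left::yam",
    "robot__actual__joint_position__right::yam",
)
ACTION_KEYS = (
    "robot__desired__joint_position__left::yam",
    "robot__desired__joint_position__right::yam",
)


def yam_state_and_actions(lowdim, anchor_idx, action_horizon=15):
    """Return state (14,) and actions (action_horizon, 14) for one sample.

    lowdim: mapping field -> (T, 7) array for the sample's temporal window.
    anchor_idx: index of the anchor timestep inside that window.
    """
    state_window = np.concatenate([lowdim[k] for k in STATE_KEYS], axis=-1)    # (T, 14)
    action_window = np.concatenate([lowdim[k] for k in ACTION_KEYS], axis=-1)  # (T, 14)
    state = state_window[anchor_idx]                                 # anchor only
    actions = action_window[anchor_idx : anchor_idx + action_horizon]  # anchor + future
    if actions.shape[0] != action_horizon:
        raise ValueError(
            f"window too short: need {action_horizon} steps from index {anchor_idx}, "
            f"got {actions.shape[0]}"
        )
    return state.astype(np.float32), actions.astype(np.float32)


# Example: a cropped 16-step window (past=1, future=14), anchor at index 1.
window = np.arange(16 * 7, dtype=np.float32).reshape(16, 7)
lowdim = {key: window for key in STATE_KEYS + ACTION_KEYS}
state, actions = yam_state_and_actions(lowdim, anchor_idx=1)
print(state.shape, actions.shape)  # (14,) (15, 14)
```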

## Key design decisions

- **Action space:** 14D absolute joint targets — `[left_j0..5, left_grip, right_j0..5, right_grip]`. No delta conversion applied during training.
- **Cameras:** `scene_camera`, `left_wrist_camera`, `right_wrist_camera` → model slots `base_0_rgb`, `left_wrist_0_rgb`, `right_wrist_0_rgb`.
- **Temporal config:** `lowdim_past_timesteps=1`, `lowdim_future_timesteps=14`. The shards store 21-step windows; the loader crops to 16 (past+1+future). Actions are sliced from anchor onward (15 steps).
- **Image handling:** Shards store 384×384 JPEGs. Model transforms resize to 224×224.
- **S3 streaming:** No local download needed. The WDS loader uses `pipe:aws s3 cp {shard}.tar -`.
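The numpy sketch below covers the temporal crop and camera renaming described above. It rests on two assumptions. First, `anchor_relative_idx` from `metadata.json` is the anchor's index within the stored 21-step window (1 for these shards, since `past=1`). Second, image fields are keyed `{camera}_t0.jpg` after WebDataset grouping. `crop_window` and `anchor_images` are hypothetical helpers.

```python
import numpy as np

# Camera-to-slot mapping from the list above.
CAMERA_TO_SLOT = {
    "scene_camera": "base_0_rgb",
    "left_wrist_camera": "left_wrist_0_rgb",
    "right_wrist_camera": "right_wrist_0_rgb",
}


def crop_window(window, anchor_relative_idx, past=1, future=14):
    """Crop a stored (T, ...) window to past + 1 + future steps around the anchor.

    Returns (cropped, anchor_idx_in_cropped); the new anchor index equals `past`.
    """
    start = anchor_relative_idx - past
    stop = anchor_relative_idx + future + 1
    if start < 0 or stop > window.shape[0]:
        raise ValueError(
            f"cannot take steps [{start}, {stop}) from a window of length {window.shape[0]}"
        )
    return window[start:stop], past


def anchor_images(sample):
    """Select each camera's t0 JPEG bytes and rename them to the model's image slots."""
    return {slot: sample[f"{camera}_t0.jpg"] for camera, slot in CAMERA_TO_SLOT.items()}


# Example: a stored 21-step window (past=1, future=19) with the anchor at index 1.
stored = np.arange(21 * 7, dtype=np.float32).reshape(21, 7)
cropped, anchor = crop_window(stored, anchor_relative_idx=1)
print(cropped.shape, anchor)  # (16, 7) 1
```

With the anchor at index 1, cropping to `past=1, future=14` keeps stored indices 0 through 15 (16 steps). The anchor stays at index 1, which is where the 15-step action slice starts. The `t-1` frames are ignored because only `_t0.jpg` keys are looked up.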

## Troubleshooting

- **`ModuleNotFoundError: No module named 'jax'`** — wrong interpreter. See [Prerequisites](#environment-setup-one-time).
- **`ValueError: Config 'pi05_yam_local' not found`** — wrong `openpi` on sys.path. See [Per-shell setup](#per-shell-setup-every-new-terminal).
- **`fatal error: An error occurred (403) when calling the HeadObject operation`** — AWS credentials are missing, expired, or lack read access to the bucket. Run `aws sts get-caller-identity` to check which identity (if any) is active.
- **`Norm stats not found`** — run Step 2 before training.
- **Shard count not divisible by num_workers** — set `num_workers=0` in the config, or use a manifest subset whose shard count is a multiple of `num_workers`.
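For the `num_workers` case, one way to build such a subset is to drop the trailing remainder of shards. This minimal sketch assumes the constraint is simply that every worker must receive the same number of shards. `usable_shards` and `shards_for_worker` are hypothetical helpers, not loader options.

```python
def usable_shards(shards, num_workers):
    """Drop trailing shards so every worker gets the same number.

    num_workers=0 means loading in the main process, so nothing is dropped.
    """
    if num_workers <= 0:
        return list(shards)
    usable = len(shards) - len(shards) % num_workers
    if usable == 0:
        raise ValueError(f"{len(shards)} shards is fewer than num_workers={num_workers}")
    return list(shards[:usable])


def shards_for_worker(shards, worker_id, num_workers):
    """Round-robin split: worker i gets shards i, i + num_workers, i + 2*num_workers, ..."""
    return shards[worker_id::num_workers]


shards = [f"shard_{i:06d}" for i in range(894)]
subset = usable_shards(shards, num_workers=4)
print(len(subset), len(shards_for_worker(subset, 0, 4)))  # 892 223
```

For the 894-shard manifest, `num_workers=4` keeps 892 shards (223 per worker) and `num_workers=8` keeps 888 (111 per worker).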

## Converting raw YAM dumps (alternative path)

If you have raw per-frame YAM data (not pre-converted vla_foundry shards), you can convert it yourself:

```bash
python scripts/convert_yam_raw_to_wds.py \
    --input /path/to/raw-yam-dump \
    --output /path/to/yam-wds \
    --samples-per-shard 10 --stride 15 --past 0 --future 14
```

Note: shards produced by `convert_yam_raw_to_wds.py` use different field names (`yam__joint_pos`, `yam__action_joints`) than the pre-converted S3 shards. Use the S3 shards for training unless you have a specific reason to convert yourself.
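The `--stride`, `--past`, `--future`, and `--samples-per-shard` flags can be read as the arithmetic below. This is a sketch of our reading of those flags, not the script's implementation. It assumes anchors start at frame `past` and advance by `stride`. It also assumes that windows running past the end of the episode are dropped. `window_anchors` and `shard_groups` are hypothetical helpers.

```python
def window_anchors(num_frames, stride=15, past=0, future=14):
    """Anchor frame indices whose [anchor - past, anchor + future] window fits."""
    first = past
    last = num_frames - 1 - future
    return list(range(first, last + 1, stride)) if last >= first else []


def shard_groups(items, samples_per_shard=10):
    """Chunk samples into consecutive groups, one group per output tar shard."""
    return [items[i : i + samples_per_shard] for i in range(0, len(items), samples_per_shard)]


print(window_anchors(100))  # [0, 15, 30, 45, 60, 75]
print([len(g) for g in shard_groups(list(range(25)))])  # [10, 10, 5]
```

With `--past 0 --future 14`, each window spans 15 frames, so `--stride 15` tiles an episode with non-overlapping windows. A smaller stride would produce overlapping samples.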

## Files

| File | Purpose |
|---|---|
| [src/openpi/src/openpi/policies/yam_policy.py](src/openpi/src/openpi/policies/yam_policy.py) | `YAMInputs` / `YAMOutputs` transforms + field-name constants |
| [src/openpi/src/openpi/training/misc/sim_improvement_config.py](src/openpi/src/openpi/training/misc/sim_improvement_config.py) | `pi05_yam`, `pi05_yam_local`, `debug_pi05_yam` configs |
| [scripts/convert_yam_raw_to_wds.py](scripts/convert_yam_raw_to_wds.py) | Raw YAM dump → WDS shards (alternative to using pre-converted S3 data) |
| [src/openpi/scripts/compute_norm_stats.py](src/openpi/scripts/compute_norm_stats.py) | Norm-stats computation + `--dry-run` flag |
| [src/openpi/scripts/train.py](src/openpi/scripts/train.py) | Training entry + `--dry-run` flag |
