# SVD LoRA fine-tuning on robot video clips

Train LoRA adapters for the Stable Video Diffusion (SVD) image-to-video model on random N-frame windows from robot episodes.
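The idea of a random N-frame window can be illustrated with a short shell sketch. It assumes the episode layout described under Setup (zero-padded `000000.png`, `000001.png`, ...) and picks a uniformly random start index in `0 .. total-N`. The function name `sample_window` is hypothetical, and this is not necessarily how `train_svd_lora.py` samples windows internally.

```bash
#!/usr/bin/env bash
# Hypothetical sketch: print the frame paths of one random N-frame window
# from an episode folder laid out as 000000.png, 000001.png, ...
sample_window() {
  local episode_dir=$1 n=$2 total start i
  # Count the PNG frames in the episode (wc may pad with spaces; (( )) ignores that).
  total=$(find "$episode_dir" -maxdepth 1 -name '*.png' | wc -l)
  if (( total < n )); then
    echo "episode has $total frames, need at least $n" >&2
    return 1
  fi
  # Valid start indices are 0 .. total-n inclusive.
  # $RANDOM is 0..32767, so this assumes episodes shorter than 32768 frames.
  start=$(( RANDOM % (total - n + 1) ))
  for (( i = start; i < start + n; i++ )); do
    printf '%s/%06d.png\n' "$episode_dir" "$i"
  done
}
```

For example, `sample_window /path/to/parsed_episodes/<episode> 14` prints 14 consecutive frame paths, and fails if the episode is shorter than the window.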

## Setup

- Install the dependencies and make sure the base SVD checkpoint exists at `checkpoints/svd.safetensors`.
- Dataset: a root directory of episode folders, each containing consecutively numbered frames (`000000.png`, `000001.png`, ...).
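Before training, it can help to check that every episode has contiguous frames and is at least as long as the training window. The sketch below is a hypothetical helper (`check_dataset` is not part of the repo); it assumes the layout above and that the minimum length you care about equals `--num_frames`.

```bash
#!/usr/bin/env bash
# Hypothetical helper: check that every episode folder under a dataset root
# has contiguous frames 000000.png, 000001.png, ... and at least MIN_FRAMES of them.
# Prints one line per problem episode; returns 1 if any problem was found.
check_dataset() {
  local root=$1 min_frames=$2 bad=0 ep n i
  for ep in "$root"/*/; do
    [ -d "$ep" ] || continue          # skips the literal glob when root is empty
    ep=${ep%/}
    n=$(find "$ep" -maxdepth 1 -name '*.png' | wc -l)
    n=$(( n ))                        # normalize wc output (may have leading spaces)
    if (( n < min_frames )); then
      echo "SHORT  $ep ($n frames < $min_frames)"
      bad=1
      continue
    fi
    # Frames must be numbered 0..n-1 with no gaps.
    for (( i = 0; i < n; i++ )); do
      if [ ! -f "$ep/$(printf '%06d' "$i").png" ]; then
        echo "GAP    $ep (missing $(printf '%06d' "$i").png)"
        bad=1
        break
      fi
    done
  done
  return "$bad"
}
```

Usage: `check_dataset /path/to/parsed_episodes 14 && echo OK`.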

## Usage

```bash
# Single GPU (e.g. GPU 5)
CUDA_VISIBLE_DEVICES=5 python scripts/training/train_svd_lora.py \
  --dataset_root /path/to/parsed_episodes \
  --ckpt checkpoints/svd.safetensors \
  --output_dir outputs/svd_lora_robot \
  --lora_rank 4 --batch_size 1 --steps 2000 --save_every 500

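# GPUs 5–9 visible: the script runs as a single process and trains only on the
# first visible device (GPU 5); multi-GPU (DDP) training is not implemented yet.
# Note: 14-frame windows run out of memory on a 48GB GPU at 576×1024 (see Known issues).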
CUDA_VISIBLE_DEVICES=5,6,7,8,9 python scripts/training/train_svd_lora.py \
  --dataset_root /data/cameron/keygrip/scratch/parsed_longer_school_slowgrasp_feb10 \
  --output_dir outputs/svd_lora_robot \
  --num_frames 14 --target_num_frames 14 \
  --batch_size 1 --steps 2000
```

## Inference with LoRA

```bash
python scripts/sampling/simple_video_sample.py \
  --input_path /path/to/first_frame.png \
  --lora_path outputs/svd_lora_robot/lora_final.pt
```

## Known issues and status

- **8-frame training**: Can fail with a batch-size mismatch (64 vs. 8) in a downstream ResBlock. Context expansion is fixed for the attention path (conditioning and `VideoTransformerBlock`), but the layer that produces a batch of 64 has not been identified yet.
- **14-frame training**: The forward pass gets through the attention path but runs out of memory (OOM) on a 48GB GPU at 576×1024. Use `--num_frames 14 --target_num_frames 14` only with more VRAM, gradient checkpointing, or a lower resolution.
- **Conditioning**: SVD uses different batch/spatial layouts in different blocks. When the context batch size differs from the hidden-state batch size, the context is expanded in `VideoTransformerBlock` and in `MemoryEfficientCrossAttention`, so that training can run through the attention path.

## Outputs

- `outputs/svd_lora_robot/lora_final.pt`: LoRA weights saved at the end of training. Intermediate checkpoints `lora_step_*.pt` are also written when `--save_every` is set.
- Load them with `--lora_path` in `simple_video_sample.py`. The base model is still required and is loaded from `ckpt_path` in the sampling config.
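To pick a checkpoint for sampling, a small hypothetical helper (`latest_lora`, not part of the repo) can prefer `lora_final.pt` and otherwise fall back to the `lora_step_*.pt` file with the largest step. It assumes the step files are named with a plain integer step, e.g. `lora_step_500.pt`. Steps are compared numerically because a lexical sort would place `lora_step_1000.pt` before `lora_step_500.pt`.

```bash
#!/usr/bin/env bash
# Hypothetical helper: print the newest LoRA checkpoint in an output dir,
# preferring lora_final.pt, otherwise the lora_step_<N>.pt with the largest N.
latest_lora() {
  local dir=$1 f step best="" best_step=-1
  if [ -f "$dir/lora_final.pt" ]; then
    echo "$dir/lora_final.pt"
    return 0
  fi
  for f in "$dir"/lora_step_*.pt; do
    [ -f "$f" ] || continue            # skips the literal glob when nothing matches
    step=${f##*lora_step_}             # strip everything up to the step number
    step=${step%.pt}
    [[ $step =~ ^[0-9]+$ ]] || continue
    if (( 10#$step > best_step )); then  # numeric, not lexical, comparison
      best_step=$(( 10#$step ))
      best=$f
    fi
  done
  [ -n "$best" ] || return 1           # no checkpoint found
  echo "$best"
}
```

For example, pass `--lora_path "$(latest_lora outputs/svd_lora_robot)"` to `simple_video_sample.py`.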
