#!/usr/bin/env bash
#
# Launches vla_foundry/main.py under torchrun on a single node with 8
# processes (--nnodes=1, --nproc_per_node=8), configured with
# --model.type stable_diffusion and the image_caption data settings below.
#
# Usage: <this script> [extra main.py arguments...]
#   Every argument given to this script is forwarded verbatim, after the
#   defaults listed here, via "$@".
#   NOTE(review): presumably a later flag overrides an earlier one with the
#   same name -- confirm against the argument parser in main.py.
#
# The torchrun binary, main.py and the unet.yaml preset are relative paths,
# so they are resolved against the current working directory.
#
# The bracketed list values are single-quoted on purpose: an unquoted
# [1.0] or [image_caption] is a shell glob bracket expression and would be
# replaced by any matching one-character file name in the working directory
# (for example a file called "1" or "a"). The quoted text is exactly what the
# unquoted form handed to main.py when no file matched, so the received
# arguments are unchanged.
.venv/bin/torchrun --nproc_per_node=8 --nnodes=1 vla_foundry/main.py \
  --model.type stable_diffusion \
  --model.use_diffusers_unet False \
  --model.use_diffusers_scheduler False \
  --model.unet "include vla_foundry/config_presets/models/unet.yaml" \
  --model.unet.image_size 128 \
  --distributed.fsdp True \
  --data.type image_caption \
  --data.processor stable_diffusion \
  --data.dataset_manifest '[s3://tri-ml-datasets/datasets/datacompdr_1b/manifest.jsonl]' \
  --data.dataset_modality '[image_caption]' \
  --data.dataset_weighting '[1.0]' \
  --data.seq_len 64 \
  --hparams.loss_function mse \
  --hparams.per_gpu_batch_size 16 \
  --hparams.global_batch_size 1024 \
  --hparams.lr 1e-3 \
  --hparams.lr_cooldown_end 1e-6 \
  --total_train_samples 50_000_000 \
  --num_checkpoints 10 \
  --remote_sync s3://tri-ml-datasets/vla_foundry_scratch/models/stable_diffusion \
  --wandb True \
  "$@"