#!/bin/bash
# Train Cosmos Policy with multi-frame future prediction (4 consecutive frames).
# NOTE: currently pinned to a single GPU (GPU 4); multi-GPU (4,6,7,8) is
# disabled pending NCCL debugging — see CUDA_VISIBLE_DEVICES below.

set -euo pipefail

# Activate conda environment (hook needed for non-interactive shells).
eval "$(conda shell.bash hook)"
conda activate cosmos_policy_env

# Library paths so the runtime can locate pip-installed cuDNN / nvJitLink.
# ${LD_LIBRARY_PATH:-} keeps 'set -u' happy when the variable is unset.
export LD_LIBRARY_PATH="/data2/cameron/miniconda3/envs/cosmos_policy_env/lib/python3.10/site-packages/nvidia/cudnn/lib:/data2/cameron/miniconda3/envs/cosmos_policy_env/lib/python3.10/site-packages/nvidia/nvjitlink/lib:${LD_LIBRARY_PATH:-}"

# Dataset location
export BASE_DATASETS_DIR="/data/cameron/vidgen/cosmos-policy"

# Use GPU 4 (single GPU for now - multi-GPU needs NCCL debugging)
export CUDA_VISIBLE_DEVICES=4
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True

# wandb setup
export WANDB_PROJECT="cosmos-policy-multi-frame"

# Run configuration. GRAD_ACCUM is the single source of truth for both the
# banner below and the trainer flag, so they can no longer drift apart
# (the banner previously claimed 8 while the trainer was passed 1).
readonly NGPUS=1
readonly PER_GPU_BATCH=4
readonly GRAD_ACCUM=1
readonly MASTER_PORT=12345

echo "============================================================"
echo "Training Cosmos Policy - Multi-Frame Future Prediction"
echo "  GPUs: $CUDA_VISIBLE_DEVICES ($NGPUS GPUs)"
echo "  Batch size: $PER_GPU_BATCH per GPU x $NGPUS GPUs x $GRAD_ACCUM grad_accum = $((PER_GPU_BATCH * NGPUS * GRAD_ACCUM))"
echo "  Future frames: t+7, t+14, t+21, t+28 (7-frame skip)"
echo "============================================================"

cd /data/cameron/vidgen/cosmos-policy || exit 1

# Use the mock wrapper that sets up transformer_engine/flash_attn mocks
# before importing the training code.
torchrun \
    --nproc_per_node="$NGPUS" \
    --master_port="$MASTER_PORT" \
    cosmos_policy/scripts/train_with_mocks.py \
    --config=cosmos_policy/config/config.py \
    -- \
    experiment="cosmos_predict2_2b_480p_libero_multi_frame" \
    trainer.grad_accum_iter="$GRAD_ACCUM" \
    optimizer=adamw
