"""
Train SVD image-to-video with LoRA or full fine-tuning on robot video clips (random N-frame windows).
Usage (LoRA):
  CUDA_VISIBLE_DEVICES=5 python scripts/training/train_svd_lora.py \\
    --dataset_root /path/to/episodes --config scripts/sampling/configs/svd.yaml \\
    --ckpt checkpoints/svd.safetensors --output_dir outputs/svd_lora_robot \\
    --lora_rank 4 --steps 2000

Usage (full fine-tune: all diffusion-model params trained, VAE/conditioner frozen):
  CUDA_VISIBLE_DEVICES=5 python scripts/training/train_svd_lora.py \\
    --dataset_root /path/to/episodes --config scripts/sampling/configs/svd.yaml \\
    --ckpt checkpoints/svd.safetensors --output_dir outputs/svd_finetune \\
    --full_finetune --steps 2000 --lr 1e-5
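
Saved weights can be reloaded with standard torch APIs, e.g. for a full fine-tune:
  model.model.load_state_dict(torch.load("outputs/svd_finetune/finetune_final.pt", map_location="cpu"))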
"""
import argparse
import os
import sys
from functools import partial

import torch
from omegaconf import OmegaConf
from torch.utils.data import DataLoader
from tqdm import tqdm

# Make the project root importable before pulling in sgm/scripts modules
sys.path.insert(0, os.path.realpath(os.path.join(os.path.dirname(__file__), "../..")))

from sgm.modules.lora import inject_lora, save_lora_state_dict
from sgm.util import instantiate_from_config
from scripts.training.robot_video_dataset import RobotVideoDataset, collate_robot_video


def load_model_and_loss(config_path: str, ckpt_path: str, device: str, num_frames: int = 14, use_checkpoint: bool = False):
    config = OmegaConf.load(config_path)
    # Ensure checkpoint path
    config.model.params.ckpt_path = ckpt_path
    # Gradient checkpointing breaks backward when only LoRA has requires_grad; disable by default for training
    if "network_config" in config.model.params and "params" in config.model.params.network_config:
        config.model.params.network_config.params["use_checkpoint"] = use_checkpoint
    # Sampler guider needs num_frames (required for LinearPredictionGuider)
    if "sampler_config" in config.model.params and "params" in config.model.params.sampler_config:
        gc = config.model.params.sampler_config.params.get("guider_config", {})
        if gc and "params" in gc:
            gc.params["num_frames"] = num_frames
    # Add a training loss config if the sampling-oriented config lacks one
    if config.model.params.get("loss_fn_config") is None:
        config.model.params.loss_fn_config = {
            "target": "sgm.modules.diffusionmodules.loss.StandardDiffusionLoss",
            "params": {
                "sigma_sampler_config": {
                    "target": "sgm.modules.diffusionmodules.sigma_sampling.EDMSampling",
                    "params": {"p_mean": -1.2, "p_std": 1.2},
                },
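                # Per-sigma loss weighting for v-prediction (EDM weighting with sigma_data=1.0 in sgm)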
                "loss_weighting_config": {
                    "target": "sgm.modules.diffusionmodules.loss_weighting.VWeighting",
                },
                "loss_type": "l2",
                "batch2model_keys": ["num_video_frames", "image_only_indicator"],
            },
        }
    model = instantiate_from_config(config.model)
    model = model.to(device)
    return model, config


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset_root", type=str, default="/data/cameron/keygrip/scratch/parsed_school_long_recap")
    parser.add_argument("--config", type=str, default="scripts/sampling/configs/svd.yaml")
    parser.add_argument("--ckpt", type=str, default="checkpoints/svd.safetensors")
    parser.add_argument("--output_dir", type=str, default="outputs/svd_lora_robot")
    parser.add_argument("--lora_rank", type=int, default=4)
    parser.add_argument("--lora_scale", type=float, default=1.0)
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--steps", type=int, default=2000)
    parser.add_argument("--lr", type=float, default=None, help="Learning rate (default: 1e-4 for LoRA, 1e-5 for full fine-tune)")
    parser.add_argument("--num_frames", type=int, default=8, help="Frames per clip")
    parser.add_argument("--target_num_frames", type=int, default=8, help="Model num_frames; should match --num_frames")
    parser.add_argument("--cond_aug", type=float, default=0.02)
    parser.add_argument("--fps_id", type=int, default=6)
    parser.add_argument("--motion_bucket_id", type=int, default=127)
    parser.add_argument("--num_workers", type=int, default=4)
    parser.add_argument("--save_every", type=int, default=500)
    parser.add_argument("--height", type=int, default=256, help="Video height")
    parser.add_argument("--width", type=int, default=448, help="Video width")
    parser.add_argument("--checkpoint", action="store_true", help="Enable gradient checkpointing to save memory (can break LoRA backward)")
    parser.add_argument("--full_finetune", action="store_true", help="Fine-tune all diffusion model parameters (no LoRA); VAE and conditioner stay frozen")
    args = parser.parse_args()

    # Pick the per-mode default learning rate unless one was given explicitly
    if args.lr is None:
        args.lr = 1e-5 if args.full_finetune else 1e-4
        print(f"Using default lr={args.lr:g} (override with --lr if needed)")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    os.makedirs(args.output_dir, exist_ok=True)

    print("Loading model and loss...")
    model, config = load_model_and_loss(
        args.config, args.ckpt, device, num_frames=args.target_num_frames, use_checkpoint=args.checkpoint
    )
    loss_fn = model.loss_fn
    if loss_fn is None:
        raise RuntimeError("Model has no loss_fn; ensure loss_fn_config is in config")

    if args.full_finetune:
        # Full fine-tune: train diffusion model only; freeze first_stage and conditioner
        for p in model.first_stage_model.parameters():
            p.requires_grad = False
        for p in model.conditioner.parameters():
            p.requires_grad = False
        trainable_params = list(model.model.parameters())
        for p in trainable_params:
            p.requires_grad = True
        n_trainable = sum(p.numel() for p in trainable_params)
        print(f"Full fine-tune: training diffusion model only ({n_trainable:,} params); VAE and conditioner frozen")
    else:
        print("Injecting LoRA...")
        lora_params = inject_lora(model.model, rank=args.lora_rank, scale=args.lora_scale)
        trainable_params = lora_params
        n_trainable = sum(p.numel() for p in lora_params)
        print(f"LoRA parameters: {n_trainable:,}")

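    # Move the model (including any freshly injected LoRA modules) to the training device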
    model = model.to(device)

    print("Building dataset and dataloader...")
    print(f"  num_frames={args.num_frames}, target_num_frames={args.target_num_frames}, resolution={args.height}x{args.width}, batch_size={args.batch_size}")
    dataset = RobotVideoDataset(
        dataset_root=args.dataset_root,
        num_frames=args.num_frames,
        target_num_frames=args.target_num_frames,
        height=args.height,
        width=args.width,
    )
    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        # functools.partial keeps the collate_fn picklable for spawn-based workers
        collate_fn=partial(
            collate_robot_video,
            cond_aug=args.cond_aug,
            fps_id=args.fps_id,
            motion_bucket_id=args.motion_bucket_id,
        ),
        pin_memory=True,
    )

    optimizer = torch.optim.AdamW(trainable_params, lr=args.lr, weight_decay=0.01)

    model.train()
    step = 0
    pbar = tqdm(total=args.steps, desc="Train")
    while step < args.steps:
        for batch in dataloader:
            if step >= args.steps:
                break
            # Move batch to device
            batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
            x = model.get_input(batch)
            B, T, C, H, W = x.shape
            # First stage encoder expects (N, C, H, W); reshape video to (B*T, C, H, W)
            x = x.reshape(B * T, C, H, W)
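            # Encode frames to latents without grads; encode_first_stage also applies the model's latent scale_factor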
            with torch.no_grad():
                x = model.encode_first_stage(x)
            # Repeat conditioning for each of B*T latent frames
            batch["cond_frames_without_noise"] = batch["cond_frames_without_noise"].repeat_interleave(T, dim=0).squeeze(1)
            batch["cond_frames"] = batch["cond_frames"].repeat_interleave(T, dim=0).squeeze(1)
            batch["fps_id"] = batch["fps_id"].repeat_interleave(T, dim=0)
            batch["motion_bucket_id"] = batch["motion_bucket_id"].repeat_interleave(T, dim=0)
            batch["cond_aug"] = batch["cond_aug"].repeat_interleave(T, dim=0)
            batch["image_only_indicator"] = batch["image_only_indicator"].repeat_interleave(T, dim=0)
            # Run conditioner once
            cond = model.conditioner(batch)
            # VideoTransformerBlock expands the (B*T, 1, dim) context to (B*S, 1, dim) internally
            batch["global_step"] = step
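            # Call loss_fn._forward with the precomputed cond; loss_fn.forward would re-run the conditioner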
            loss = loss_fn._forward(model.model, model.denoiser, cond, x, batch).mean()
            optimizer.zero_grad()
            loss.backward()
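            # Clip gradient norm for stability (small batches make gradients noisy)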
            torch.nn.utils.clip_grad_norm_(trainable_params, 1.0)
            optimizer.step()
            pbar.set_postfix(loss=loss.item())
            pbar.update(1)
            step += 1
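            # Free cached CUDA blocks each step: costs some speed but limits fragmentation-driven OOMs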
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            if step > 0 and step % args.save_every == 0:
                if args.full_finetune:
                    path = os.path.join(args.output_dir, f"finetune_step_{step}.pt")
                    torch.save(model.model.state_dict(), path)
                    tqdm.write(f"Saved {path}")
                else:
                    path = os.path.join(args.output_dir, f"lora_step_{step}.pt")
                    save_lora_state_dict(model.model, path, rank=args.lora_rank)
                    tqdm.write(f"Saved {path}")
    pbar.close()
    if args.full_finetune:
        path = os.path.join(args.output_dir, "finetune_final.pt")
        torch.save(model.model.state_dict(), path)
        print("Done. Full fine-tune weights saved to", path)
    else:
        path = os.path.join(args.output_dir, "lora_final.pt")
        save_lora_state_dict(model.model, path, rank=args.lora_rank)
        print("Done. LoRA weights saved to", path)


if __name__ == "__main__":
    main()
