"""Fine-tune Stable Video Diffusion (SVD) on LIBERO videos with LoRA.

Loads the pretrained SVD img2vid model, adds LoRA adapters to the UNet,
and trains on LIBERO parsed frames. VAE and CLIP image encoder are frozen.

Usage:
    CUDA_VISIBLE_DEVICES=4 python video_training/svd_finetune/train.py \
        --data-root /data/libero/parsed_libero/libero_spatial \
        --svd-path /data/cameron/vidgen/Ctrl-World/checkpoints/stable-video-diffusion-img2vid \
        --log_wandb --run-name svd_libero_spatial
"""

import argparse
import sys
import tempfile
from pathlib import Path

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from einops import rearrange
from tqdm import tqdm

from diffusers import AutoencoderKLTemporalDecoder, UNetSpatioTemporalConditionModel, EulerDiscreteScheduler
from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor
from peft import LoraConfig, get_peft_model

# UVA dataset for loading LIBERO frames
UVA_ROOT = Path(__file__).resolve().parent.parent / "unified_video_action"
sys.path.insert(0, str(UVA_ROOT))
from simple_uva.dataset import LiberoVideoDataset, collate_batch

SVD_DEFAULT_PATH = "/data/cameron/vidgen/Ctrl-World/checkpoints/stable-video-diffusion-img2vid"
NUM_FRAMES = 14  # SVD generates 14 frames
IMG_SIZE = 256   # resize LIBERO frames


def encode_image_clip(image_encoder, feature_extractor, pixel_values, device):
    """Encode conditioning image through CLIP.

    Args:
        pixel_values: (B, 3, H, W) in [-1, 1]
    Returns:
        image_embeddings: (B, 1, 1024) CLIP embeddings
    """
    # Convert from [-1,1] to [0,1] for CLIP processor
    images_01 = (pixel_values + 1.0) / 2.0
    # CLIP expects its own normalization. The feature_extractor operates on PIL
    # images, so replicate its mean/std normalization directly on the tensors
    # (feature_extractor is accepted for API symmetry but goes unused here).
    clip_mean = torch.tensor([0.48145466, 0.4578275, 0.40821073], device=device).view(1, 3, 1, 1)
    clip_std = torch.tensor([0.26862954, 0.26130258, 0.27577711], device=device).view(1, 3, 1, 1)
    # Resize to 224x224 for CLIP
    images_clip = F.interpolate(images_01, size=(224, 224), mode="bilinear", align_corners=False)
    images_clip = (images_clip - clip_mean) / clip_std

    with torch.no_grad():
        image_embeddings = image_encoder(images_clip).image_embeds  # (B, 1024)
    return image_embeddings.unsqueeze(1)  # (B, 1, 1024)


def main():
    p = argparse.ArgumentParser(description="Fine-tune SVD on LIBERO with LoRA")
    p.add_argument("--data-root", type=str, default="/data/libero/parsed_libero/libero_spatial")
    p.add_argument("--svd-path", type=str, default=SVD_DEFAULT_PATH)
    p.add_argument("--batch-size", type=int, default=1)
    p.add_argument("--gradient-accumulation", type=int, default=4)
    p.add_argument("--lr", type=float, default=1e-4)
    p.add_argument("--epochs", type=int, default=100)
    p.add_argument("--workers", type=int, default=4)
    p.add_argument("--device", type=str, default="cuda")
    p.add_argument("--run-name", type=str, default="svd_libero_spatial")
    p.add_argument("--log_wandb", action="store_true")
    p.add_argument("--vis-every", type=int, default=200)
    p.add_argument("--checkpoint-every", type=int, default=1000)
    p.add_argument("--checkpoint-dir", type=str, default="video_training/svd_finetune/checkpoints")
    p.add_argument("--frame-stride", type=int, default=3)
    p.add_argument("--lora-rank", type=int, default=16)
    p.add_argument("--mixed-precision", action="store_true", default=True)
    args = p.parse_args()

    device = torch.device(args.device)
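    # fp16 here applies to the frozen VAE/CLIP encoders; the trainable UNet is
    # loaded in fp32 below so GradScaler can unscale its LoRA gradients.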
    dtype = torch.float16 if args.mixed_precision else torch.float32
    ckpt_dir = Path(args.checkpoint_dir)
    ckpt_dir.mkdir(parents=True, exist_ok=True)

    if args.log_wandb:
        import wandb
        wandb.init(project="svd_finetune", config=vars(args), name=args.run_name, mode="online")

    # --- Load SVD components ---
    svd_path = Path(args.svd_path)
    print("Loading SVD components...")

    vae = AutoencoderKLTemporalDecoder.from_pretrained(str(svd_path), subfolder="vae", torch_dtype=dtype)
    vae.to(device).eval()
    for param in vae.parameters():
        param.requires_grad = False
    print(f"  VAE loaded (frozen)")

    image_encoder = CLIPVisionModelWithProjection.from_pretrained(str(svd_path), subfolder="image_encoder", torch_dtype=dtype)
    image_encoder.to(device).eval()
    for param in image_encoder.parameters():
        param.requires_grad = False
    feature_extractor = CLIPImageProcessor.from_pretrained(str(svd_path), subfolder="feature_extractor")
    print(f"  CLIP image encoder loaded (frozen)")

    # Load the trainable UNet in fp32: GradScaler raises on fp16 parameter
    # gradients, and autocast still runs the forward pass in fp16.
    unet = UNetSpatioTemporalConditionModel.from_pretrained(str(svd_path), subfolder="unet", torch_dtype=torch.float32)
    unet.to(device)
    print(f"  UNet loaded: {sum(p.numel() for p in unet.parameters()):,} params")

    noise_scheduler = EulerDiscreteScheduler.from_pretrained(str(svd_path), subfolder="scheduler")
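    # (Only the scheduler's sigma schedule is used, for the forward-noising math
    # below; there is no sampling loop in this script.)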

    # --- Add LoRA to UNet ---
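    # Adapters target the attention q/k/v/out projections in both the spatial and
    # temporal blocks; lora_alpha == r gives an effective LoRA scale of 1.0.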
    lora_config = LoraConfig(
        r=args.lora_rank,
        lora_alpha=args.lora_rank,
        target_modules=["to_q", "to_k", "to_v", "to_out.0"],
        lora_dropout=0.0,
    )
    unet = get_peft_model(unet, lora_config)
    unet.print_trainable_parameters()
    unet.train()

    # --- Dataset ---
    dataset = LiberoVideoDataset(
        root=args.data_root,
        num_frames=NUM_FRAMES,
        size=IMG_SIZE,
        frame_stride=args.frame_stride,
    )
    print(f"Dataset: {len(dataset)} episodes ({NUM_FRAMES} frames, stride {args.frame_stride})")
    loader = DataLoader(
        dataset, batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, collate_fn=collate_batch,
        pin_memory=True, persistent_workers=args.workers > 0,
        prefetch_factor=4 if args.workers > 0 else None,
    )

    # --- Optimizer ---
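    # get_peft_model left only the LoRA adapter weights with requires_grad=True,
    # so this picks up just the adapters.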
    trainable_params = [p for p in unet.parameters() if p.requires_grad]
    opt = torch.optim.AdamW(trainable_params, lr=args.lr, weight_decay=1e-4)
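    # GradScaler guards against fp16 gradient underflow when autocast is enabled.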
    scaler = torch.cuda.amp.GradScaler(enabled=args.mixed_precision)

    # --- Training ---
    global_step = 0
    for epoch in range(args.epochs):
        pbar = tqdm(loader, desc=f"epoch {epoch}")
        for batch_idx, batch in enumerate(pbar):
            # batch: (B, 3, T, H, W) in [-1, 1]
            batch = batch.to(device, non_blocking=True)
            B, C, T, H, W = batch.shape

            with torch.cuda.amp.autocast(enabled=args.mixed_precision):
                # 1. Encode conditioning image through CLIP
                cond_image = batch[:, :, 0]  # (B, 3, H, W)
                image_embeddings = encode_image_clip(
                    image_encoder, feature_extractor, cond_image, device
                )  # (B, 1, 1024)

                # 2. Encode all frames through VAE
                frames = rearrange(batch, "b c t h w -> (b t) c h w")
                with torch.no_grad():
                    latents = vae.encode(frames.to(dtype)).latent_dist.sample()  # (B*T, 4, H_lat, W_lat)
                latents = rearrange(latents, "(b t) c h w -> b t c h w", b=B)  # SVD uses (B, T, C, H, W)

                # Conditioning: first-frame latent repeated for all T frames. The SVD
                # pipeline concatenates the image latent *unscaled* (no scaling_factor),
                # so take it before scaling the denoising targets.
                cond_latent = latents[:, 0:1].expand(-1, T, -1, -1, -1)  # (B, T, 4, H_lat, W_lat)

                # Denoising targets live in the scaled latent space.
                latents = latents * vae.config.scaling_factor

                # 3. Sample noise and timesteps
                noise = torch.randn_like(latents)  # (B, T, 4, H_lat, W_lat)
                timesteps = torch.randint(
                    0, noise_scheduler.config.num_train_timesteps, (B,),
                    device=device, dtype=torch.long,
                )

                # 4. Add noise (forward diffusion): x_t = x_0 + sigma_t * eps.
                # EulerDiscreteScheduler stores sigmas in descending order, so
                # timestep t maps to sigma index (num_train_timesteps - 1 - t);
                # indexing by t directly pairs each timestep with the wrong sigma.
                sigma_idx = noise_scheduler.config.num_train_timesteps - 1 - timesteps.cpu()
                sigmas = noise_scheduler.sigmas[sigma_idx].view(B, 1, 1, 1, 1).to(device)  # fp32
                noisy_latents = latents + noise * sigmas  # (B, T, 4, H_lat, W_lat)
                # Match inference: the pipeline's scale_model_input divides the noisy
                # sample by sqrt(sigma^2 + 1). Keep sigma math in fp32 to avoid overflow.
                unet_sample = (noisy_latents / ((sigmas**2 + 1.0) ** 0.5)).to(latents.dtype)

                # 5. Concatenate conditioning along the channel dim: (B, T, 8, H_lat, W_lat)
                unet_input = torch.cat([unet_sample, cond_latent], dim=2)
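                # SVD micro-conditioning is [fps, motion_bucket_id, noise_aug_strength];
                # the pipeline passes fps - 1, so 6.0 corresponds to 7 fps, 127 is the
                # default motion bucket, and 0.0 reflects no conditioning-noise aug.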
                added_time_ids = torch.tensor(
                    [[6.0, 127.0, 0.0]] * B, device=device, dtype=dtype
                )
                model_pred = unet(
                    unet_input, timesteps,
                    encoder_hidden_states=image_embeddings,
                    added_time_ids=added_time_ids,
                ).sample

                # 6. Loss: predict noise
                loss = F.mse_loss(model_pred, noise)
                loss = loss / args.gradient_accumulation

            scaler.scale(loss).backward()

            if (batch_idx + 1) % args.gradient_accumulation == 0:
                scaler.step(opt)
                scaler.update()
                opt.zero_grad()

            pbar.set_postfix(loss=f"{loss.item() * args.gradient_accumulation:.4f}", step=global_step)

            if args.log_wandb:
                import wandb
                wandb.log({"train/loss": loss.item() * args.gradient_accumulation}, step=global_step)

            # --- Visualization ---
            if global_step % args.vis_every == 0 and args.log_wandb and global_step > 0:
                import wandb
                import torchvision
                unet.eval()
                with torch.no_grad(), torch.cuda.amp.autocast(enabled=args.mixed_precision):
                    # Quick vis: start from pure noise at the largest sigma and take a
                    # single denoising step (x0 = x_t - sigma * eps_pred) as a rough
                    # sanity check; this is not a full sampling loop.
                    cond_img = batch[:1, :, 0]  # (1, 3, H, W)
                    emb = encode_image_clip(image_encoder, feature_extractor, cond_img, device)
                    cond_lat = vae.encode(cond_img.to(dtype)).latent_dist.sample()  # unscaled, as in training
                    cond_rep = cond_lat.unsqueeze(1).expand(-1, T, -1, -1, -1)  # (1, T, 4, H_l, W_l)
                    sigma_max = noise_scheduler.sigmas[0].to(device)  # fp32 to avoid fp16 overflow
                    x_t = torch.randn(1, T, 4, cond_lat.shape[2], cond_lat.shape[3], device=device) * sigma_max
                    z = torch.cat([(x_t / ((sigma_max**2 + 1.0) ** 0.5)).to(dtype), cond_rep], dim=2)
                    t_vis = torch.tensor([noise_scheduler.config.num_train_timesteps - 1], device=device)
                    added_ids = torch.tensor([[6.0, 127.0, 0.0]], device=device, dtype=dtype)
                    eps = unet(z, t_vis, encoder_hidden_states=emb, added_time_ids=added_ids).sample
                    x0 = x_t - sigma_max * eps.float()  # (1, T, 4, H_l, W_l)
                    # Decode from the scaled latent space back to pixels
                    x0_flat = rearrange(x0, "b t c h w -> (b t) c h w") / vae.config.scaling_factor
                    decoded = vae.decode(x0_flat.to(dtype), num_frames=T).sample  # (T, 3, H, W)
                    decoded = ((decoded.float().cpu() + 1.0) / 2.0).clamp(0, 1)
                    frames_np = (decoded.permute(0, 2, 3, 1).numpy() * 255).astype("uint8")
                    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as f:
                        tmp_path = f.name
                    torchvision.io.write_video(tmp_path, torch.from_numpy(frames_np), fps=4)
                    wandb.log({"vis/predicted_video": wandb.Video(tmp_path, format="mp4")}, step=global_step)
                    Path(tmp_path).unlink(missing_ok=True)
                unet.train()

            # --- Checkpoint ---
            if global_step > 0 and global_step % args.checkpoint_every == 0:
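                # save_pretrained on the PeftModel writes only the LoRA adapter
                # weights, not the full UNet.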
                unet.save_pretrained(str(ckpt_dir / "unet_lora"))
                torch.save({
                    "step": global_step,
                    "optimizer": opt.state_dict(),
                    "scaler": scaler.state_dict(),  # needed for a faithful resume
                }, ckpt_dir / "latest.pt")
                print(f"Saved checkpoint at step {global_step}")

            global_step += 1

        print(f"Epoch {epoch} done, step={global_step}")

    if args.log_wandb:
        import wandb
        wandb.finish()
    print(f"Done. Checkpoints at {ckpt_dir}")


if __name__ == "__main__":
    main()
