"""Fine-tune Stable Video Diffusion (SVD) on LIBERO videos with LoRA.

Loads the pretrained SVD img2vid model, adds LoRA adapters to the UNet,
and trains on LIBERO parsed frames. VAE and CLIP image encoder are frozen.

Usage:
    CUDA_VISIBLE_DEVICES=4 python video_training/svd_finetune/train.py \
        --data-root /data/libero/parsed_libero/libero_spatial \
        --svd-path /data/cameron/vidgen/Ctrl-World/checkpoints/stable-video-diffusion-img2vid \
        --log_wandb --run-name svd_libero_spatial
"""

import argparse
import random
import sys
import tempfile
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision.io import read_image
from einops import rearrange
from tqdm import tqdm

from diffusers import AutoencoderKLTemporalDecoder, UNetSpatioTemporalConditionModel, EulerDiscreteScheduler
from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor
from peft import LoraConfig, get_peft_model

# UVA dataset for loading LIBERO frames
UVA_ROOT = Path(__file__).resolve().parent.parent / "unified_video_action"
sys.path.insert(0, str(UVA_ROOT))
from simple_uva.dataset import collate_batch, _natural_sort_key

SVD_DEFAULT_PATH = "/data/cameron/vidgen/Ctrl-World/checkpoints/stable-video-diffusion-img2vid"
NUM_FRAMES = 14  # SVD generates 14 frames
IMG_HEIGHT = 512   # Reduced from native 576x1024 to fit in GPU memory
IMG_WIDTH = 512
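# The SVD VAE downsamples 8x spatially, so 512x512 frames become 4x64x64 latents.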


class LiberoVideoDatasetRect(Dataset):
    """Like LiberoVideoDataset but outputs non-square (H, W) frames for SVD."""

    def __init__(self, root, num_frames=NUM_FRAMES, height=IMG_HEIGHT, width=IMG_WIDTH,
                 frame_stride=1, max_samples=None):
        self.root = Path(root).resolve()
        self.num_frames = num_frames
        self.height = height
        self.width = width
        self.frame_stride = frame_stride
        self.min_frames = (num_frames - 1) * frame_stride + 1
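        # e.g. NUM_FRAMES=14 at stride 3 spans (14 - 1) * 3 + 1 = 40 source frames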
        self.episodes = []
        for task_dir in sorted(self.root.iterdir()):
            if not task_dir.is_dir():
                continue
            for demo_dir in sorted(task_dir.iterdir()):
                if not demo_dir.is_dir():
                    continue
                frames_dir = demo_dir / "frames"
                if not frames_dir.is_dir():
                    continue
                pngs = sorted(frames_dir.glob("*.png"), key=_natural_sort_key)
                if len(pngs) >= self.min_frames:
                    self.episodes.append(pngs)
        if max_samples is not None:
            self.episodes = self.episodes[:max_samples]

    def __len__(self):
        return len(self.episodes)

    def __getitem__(self, idx):
        pngs = self.episodes[idx]
        # Random temporal crop: any start index that leaves room for num_frames
        # at the configured stride.
        max_start = len(pngs) - self.min_frames
        start = random.randint(0, max_start)
        frames = []
        for k in range(self.num_frames):
            p = pngs[start + k * self.frame_stride]
            img = read_image(str(p))
            if img.shape[0] == 4:
                img = img[:3]
            img = img.float() / 255.0
            img = F.interpolate(img.unsqueeze(0), size=(self.height, self.width),
                                mode="bilinear", align_corners=False).squeeze(0)
            frames.append(img)
        out = torch.stack(frames, dim=1)  # (C, T, H, W)
        out = out * 2.0 - 1.0
        return out.unsqueeze(0)  # (1, C, T, H, W)


def encode_image_clip(image_encoder, feature_extractor, pixel_values, device):
    """Encode conditioning image through CLIP.

    Args:
        pixel_values: (B, 3, H, W) in [-1, 1]
    Returns:
        image_embeddings: (B, 1, 1024) CLIP embeddings
    """
    # Convert from [-1,1] to [0,1] for CLIP processor
    images_01 = (pixel_values + 1.0) / 2.0
    # CLIP expects specific normalization — use feature_extractor
    # But since we already have tensors, do manual CLIP normalization
    clip_mean = torch.tensor([0.48145466, 0.4578275, 0.40821073], device=device).view(1, 3, 1, 1)
    clip_std = torch.tensor([0.26862954, 0.26130258, 0.27577711], device=device).view(1, 3, 1, 1)
    # Resize to 224x224 for CLIP
    images_clip = F.interpolate(images_01, size=(224, 224), mode="bilinear", align_corners=False)
    images_clip = (images_clip - clip_mean) / clip_std

    with torch.no_grad():
        image_embeddings = image_encoder(images_clip).image_embeds  # (B, 1024)
    return image_embeddings.unsqueeze(1)  # (B, 1, 1024)


def main():
    p = argparse.ArgumentParser(description="Fine-tune SVD on LIBERO with LoRA")
    p.add_argument("--data-root", type=str, default="/data/libero/parsed_libero/libero_spatial")
    p.add_argument("--svd-path", type=str, default=SVD_DEFAULT_PATH)
    p.add_argument("--batch-size", type=int, default=1)
    p.add_argument("--gradient-accumulation", type=int, default=4)
    p.add_argument("--lr", type=float, default=1e-4)
    p.add_argument("--epochs", type=int, default=100)
    p.add_argument("--workers", type=int, default=4)
    p.add_argument("--device", type=str, default="cuda")
    p.add_argument("--run-name", type=str, default="svd_libero_spatial")
    p.add_argument("--log_wandb", action="store_true")
    p.add_argument("--vis-every", type=int, default=200)
    p.add_argument("--checkpoint-every", type=int, default=1000)
    p.add_argument("--checkpoint-dir", type=str, default="video_training/svd_finetune/checkpoints")
    p.add_argument("--frame-stride", type=int, default=3)
    p.add_argument("--lora-rank", type=int, default=16)
    p.add_argument("--mixed-precision", action="store_true", default=True)
    args = p.parse_args()

    device = torch.device(args.device)
    dtype = torch.float16 if args.mixed_precision else torch.float32
    ckpt_dir = Path(args.checkpoint_dir)
    ckpt_dir.mkdir(parents=True, exist_ok=True)

    if args.log_wandb:
        import wandb
        wandb.init(project="svd_finetune", config=vars(args), name=args.run_name, mode="online")

    # --- Load SVD components ---
    svd_path = Path(args.svd_path)
    print("Loading SVD components...")

    vae = AutoencoderKLTemporalDecoder.from_pretrained(str(svd_path), subfolder="vae", torch_dtype=dtype)
    vae.to(device).eval()
    for param in vae.parameters():
        param.requires_grad = False
    print(f"  VAE loaded (frozen)")

    image_encoder = CLIPVisionModelWithProjection.from_pretrained(str(svd_path), subfolder="image_encoder", torch_dtype=dtype)
    image_encoder.to(device).eval()
    for param in image_encoder.parameters():
        param.requires_grad = False
    feature_extractor = CLIPImageProcessor.from_pretrained(str(svd_path), subfolder="feature_extractor")
    print(f"  CLIP image encoder loaded (frozen)")

    unet = UNetSpatioTemporalConditionModel.from_pretrained(str(svd_path), subfolder="unet", torch_dtype=dtype)
    unet.to(device)
    print(f"  UNet loaded: {sum(p.numel() for p in unet.parameters()):,} params")

    noise_scheduler = EulerDiscreteScheduler.from_pretrained(str(svd_path), subfolder="scheduler")

    # --- Add LoRA to UNet ---
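    # to_q/to_k/to_v/to_out.0 match the attention projections in both the
    # spatial and the temporal transformer blocks of the UNet.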
    lora_config = LoraConfig(
        r=args.lora_rank,
        lora_alpha=args.lora_rank,
        target_modules=["to_q", "to_k", "to_v", "to_out.0"],
        lora_dropout=0.0,
    )
    unet = get_peft_model(unet, lora_config)
    unet.print_trainable_parameters()
    unet.train()

    # --- Dataset ---
    dataset = LiberoVideoDatasetRect(
        root=args.data_root,
        num_frames=NUM_FRAMES,
        height=IMG_HEIGHT,
        width=IMG_WIDTH,
        frame_stride=args.frame_stride,
    )
    print(f"Dataset: {len(dataset)} episodes ({NUM_FRAMES} frames @ {IMG_HEIGHT}x{IMG_WIDTH}, stride {args.frame_stride})")
    loader = DataLoader(
        dataset, batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, collate_fn=collate_batch,
        pin_memory=True, persistent_workers=args.workers > 0,
        prefetch_factor=4 if args.workers > 0 else None,
    )

    # --- Optimizer ---
    trainable_params = [p for p in unet.parameters() if p.requires_grad]
    opt = torch.optim.AdamW(trainable_params, lr=args.lr, weight_decay=1e-4)
    scaler = torch.cuda.amp.GradScaler(enabled=args.mixed_precision)
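    # Recent PEFT versions keep LoRA weights in fp32 by default
    # (autocast_adapter_dtype=True in get_peft_model), which is what lets
    # GradScaler unscale their gradients even though the base UNet is fp16.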

    # --- Training ---
    global_step = 0
    for epoch in range(args.epochs):
        pbar = tqdm(loader, desc=f"epoch {epoch}")
        for batch_idx, batch in enumerate(pbar):
            # batch: (B, 3, T, H, W) in [-1, 1]
            batch = batch.to(device, non_blocking=True)
            B, C, T, H, W = batch.shape

            with torch.cuda.amp.autocast(enabled=args.mixed_precision):
                # 1. Encode conditioning image through CLIP → cross-attention embeddings
                cond_image = batch[:, :, 0]  # (B, 3, H, W) in [-1, 1]
                image_embeddings = encode_image_clip(
                    image_encoder, feature_extractor, cond_image, device
                )  # (B, 1, 1024)

                # 2. Encode all video frames through VAE
                frames = rearrange(batch, "b c t h w -> (b t) c h w")
                with torch.no_grad():
                    latents = vae.encode(frames.to(dtype)).latent_dist.sample()
                latents = rearrange(latents, "(b t) c h w -> b t c h w", b=B)
                latents = latents * vae.config.scaling_factor
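                # Scale latents to the unit-variance convention the UNet expects.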

                # 3. Conditioning image latent (separate from video latents):
                # SVD adds noise augmentation to cond image, encodes via VAE .mode(), repeats for T
                noise_aug = 0.02
                with torch.no_grad():
                    cond_img_aug = cond_image.to(dtype) + noise_aug * torch.randn_like(cond_image.to(dtype))
                    cond_latent = vae.encode(cond_img_aug).latent_dist.mode()
                    cond_latent = cond_latent * vae.config.scaling_factor
                cond_latent = cond_latent.unsqueeze(1).expand(-1, T, -1, -1, -1)  # (B, T, 4, h, w) in latent space

                # 4. EDM training: sample sigma from log-normal, add noise as x + sigma*noise
                # This matches SVD's actual training recipe (EDM/Karras parameterization)
                rnd_normal = torch.randn(B, device=device, dtype=dtype)
                sigma = (rnd_normal * 1.6 + 0.7).exp()  # log-normal: P_mean=0.7, P_std=1.6
                sigma = sigma.clamp(noise_scheduler.config.sigma_min,
                                    noise_scheduler.config.sigma_max)
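                # (sigma_min/sigma_max are 0.002 and 700.0 in the released SVD scheduler config)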

                noise = torch.randn_like(latents)
                sigma_bc = sigma.view(B, 1, 1, 1, 1)
                noisy_latents = latents + noise * sigma_bc

                # 5. EDM preconditioning: scale input by c_in = 1/sqrt(sigma^2 + 1)
                c_in = (1.0 / (sigma_bc ** 2 + 1).sqrt())
                scaled_noisy = noisy_latents * c_in

                # Concat conditioning along channels: (B, T, 8, H, W)
                unet_input = torch.cat([scaled_noisy, cond_latent], dim=2)
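                # Micro-conditioning [fps, motion_bucket_id, noise_aug_strength];
                # 6.0 matches the pipeline's default fps=7 (it passes fps - 1) and
                # 127 is the default motion bucket.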
                added_time_ids = torch.tensor(
                    [[6.0, 127.0, noise_aug]] * B, device=device, dtype=dtype
                )
                # SVD's scheduler uses timestep_type="continuous": the timestep the
                # UNet sees at inference is c_noise = 0.25 * ln(sigma), so condition
                # on the same quantity during training.
                c_noise = 0.25 * sigma.log()
                model_pred = unet(
                    unet_input, c_noise,
                    encoder_hidden_states=image_embeddings,
                    added_time_ids=added_time_ids,
                ).sample

                # 6. EDM loss: the model predicts F_theta, and the denoised estimate is
                # D = c_skip * noisy + c_out * F_theta
                # c_skip = 1/(sigma^2+1), c_out = -sigma/sqrt(sigma^2+1)
                # Target for F_theta = (clean - c_skip * noisy) / c_out
                c_skip = 1.0 / (sigma_bc ** 2 + 1)
                c_out = -sigma_bc / (sigma_bc ** 2 + 1).sqrt()
                target = (latents - c_skip * noisy_latents) / c_out
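                # Consistency check against EulerDiscreteScheduler.step with
                # prediction_type="v_prediction":
                #   denoised = c_skip * x_noisy + c_out * F_theta
                # so a model that predicts `target` exactly recovers the clean latents.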

                # Weight loss by 1/(sigma^2+1) for uniform SNR weighting
                weight = 1.0 / (sigma_bc ** 2 + 1)
                loss = (weight * (model_pred - target) ** 2).mean()
                loss = loss / args.gradient_accumulation

            scaler.scale(loss).backward()

            if (batch_idx + 1) % args.gradient_accumulation == 0:
                scaler.step(opt)
                scaler.update()
                opt.zero_grad()
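                # Note: gradients from a partial accumulation window at the end
                # of an epoch are not cleared and carry over into the next epoch.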

            pbar.set_postfix(loss=f"{loss.item() * args.gradient_accumulation:.4f}", step=global_step)

            if args.log_wandb:
                wandb.log({"train/loss": loss.item() * args.gradient_accumulation}, step=global_step)

            # --- Visualization: full reverse diffusion sampling ---
            if global_step % args.vis_every == 0 and args.log_wandb and global_step > 0:
                print(f"\n[vis] Generating video at step {global_step}...", flush=True)
                try:
                    import torchvision
                    unet.eval()
                    with torch.no_grad(), torch.cuda.amp.autocast(enabled=args.mixed_precision):
                        cond_img = batch[:1, :, 0].to(dtype)
                        emb = encode_image_clip(image_encoder, feature_extractor, cond_img, device)
                        # Conditioning latent with noise augmentation (matching pipeline)
                        noise_aug = 0.02
                        cond_aug = cond_img + noise_aug * torch.randn_like(cond_img)
                        cond_lat = vae.encode(cond_aug).latent_dist.mode() * vae.config.scaling_factor
                        cond_rep = cond_lat.unsqueeze(1).expand(-1, T, -1, -1, -1)
                        added_ids = torch.tensor([[6.0, 127.0, noise_aug]], device=device, dtype=dtype)

                        # Full reverse sampling with 25 steps
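                        # (single-pass, no classifier-free guidance; the SVD pipeline
                        # normally ramps guidance from 1 to 3 across frames, but this
                        # is enough for a quick training-progress check)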
                        vis_scheduler = EulerDiscreteScheduler.from_pretrained(
                            str(svd_path), subfolder="scheduler")
                        vis_scheduler.set_timesteps(25, device=device)
                        z = torch.randn(1, T, 4, cond_lat.shape[2], cond_lat.shape[3],
                                        device=device, dtype=dtype) * vis_scheduler.init_noise_sigma
                        for t_step in vis_scheduler.timesteps:
                            # Scale model input (required by Euler scheduler)
                            z_scaled = vis_scheduler.scale_model_input(z, t_step)
                            z_input = torch.cat([z_scaled, cond_rep], dim=2)
                            pred = unet(z_input, t_step.unsqueeze(0), encoder_hidden_states=emb,
                                        added_time_ids=added_ids).sample
                            z = vis_scheduler.step(pred, t_step, z).prev_sample

                        # Decode
                        z_flat = rearrange(z, "b t c h w -> (b t) c h w")
                        z_flat = z_flat / vae.config.scaling_factor
                        decoded = vae.decode(z_flat, num_frames=T).sample
                        decoded = ((decoded.float().cpu() + 1.0) / 2.0).clamp(0, 1)
                        frames_np = (decoded.permute(0, 2, 3, 1).numpy() * 255).astype("uint8")
                        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as f:
                            tmp_path = f.name
                        torchvision.io.write_video(tmp_path, torch.from_numpy(frames_np), fps=4)
                        wandb.log({"vis/predicted_video": wandb.Video(tmp_path, format="mp4")}, step=global_step)
                        Path(tmp_path).unlink(missing_ok=True)
                        print(f"[vis] Logged video at step {global_step}", flush=True)

                    unet.train()
                except Exception as e:
                    print(f"[vis] ERROR at step {global_step}: {e}", flush=True)
                    import traceback
                    traceback.print_exc()
                    unet.train()

            # --- Checkpoint ---
            if global_step > 0 and global_step % args.checkpoint_every == 0:
                try:
                    unet.save_pretrained(str(ckpt_dir / "unet_lora"))
                    torch.save({
                        "step": global_step,
                        "optimizer": opt.state_dict(),
                    }, ckpt_dir / "latest.pt")
                    print(f"[ckpt] Saved at step {global_step}", flush=True)
                except Exception as e:
                    print(f"[ckpt] ERROR at step {global_step}: {e}", flush=True)

            global_step += 1

        print(f"Epoch {epoch} done, step={global_step}")

    if args.log_wandb:
        wandb.finish()
    print(f"Done. Checkpoints at {ckpt_dir}")


if __name__ == "__main__":
    main()
