"""Train SVD Video Diffusion + PARA action heads on LIBERO.

Frozen SVD UNet as feature extractor: extract up_block_2 features at mid and late
denoising timesteps, concatenate them channel-wise, resize to 64x64, and feed them
to the PARA heads.

Usage:
    CUDA_VISIBLE_DEVICES=4 python train_svd_para.py \
        --cache_root /data/libero/ood_objpos_task0 \
        --svd_checkpoint /data/cameron/vidgen/svd_motion_lora/Motion-LoRA/output_libero_7f/checkpoint-46000/unet \
        --log_wandb --run_name svd_para_ood_objpos
"""

import argparse
import io
import json
import random
import sys
from pathlib import Path

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm
from PIL import Image

sys.path.insert(0, str(Path(__file__).parent))
# Add Motion-LoRA path for SVD models
SVD_ROOT = Path("/data/cameron/vidgen/svd_motion_lora/Motion-LoRA")
sys.path.insert(0, str(SVD_ROOT))

from data import CachedTrajectoryDataset

# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
PRERENDER_SIZE = 448        # pre-rendered LIBERO images are 448x448
SVD_SIZE = (320, 576)       # SVD training resolution (H, W)
PARA_OUT_SIZE = 64          # PARA head output spatial size
N_HEIGHT_BINS = 32
N_GRIPPER_BINS = 32
N_ROT_BINS = 32
N_WINDOW = 4                # Predict 4 future timesteps

# SVD UNet feature dimensions (from up_block_2)
SVD_FEAT_DIM = 640          # up_block_2 channel dim
SVD_FEAT_H, SVD_FEAT_W = 40, 72  # spatial dims at 320x576 input (1/8 resolution)

# Denoising timesteps for feature extraction
# Using EulerDiscreteScheduler with 25 steps
MID_STEP = 12               # middle denoising step
LATE_STEP = 23              # late denoising step (low noise)

# Combined feature dim: mid + late timestep features
COMBINED_FEAT_DIM = SVD_FEAT_DIM * 2  # 1280

# Dataset stats: defaults, overwritten at runtime in main()
MIN_HEIGHT = 0.0
MAX_HEIGHT = 1.0
MIN_GRIPPER = -1.0
MAX_GRIPPER = 1.0
MIN_ROT = [-3.14159, -3.14159, -3.14159]
MAX_ROT = [3.14159, 3.14159, 3.14159]

GRIPPER_LOSS_WEIGHT = 5.0
ROTATION_LOSS_WEIGHT = 0.5


# ---------------------------------------------------------------------------
# SVD Feature Extractor (frozen)
# ---------------------------------------------------------------------------
class SVDFeatureExtractor(nn.Module):
    """Frozen SVD UNet that extracts up_block_2 features at two noise levels."""

    def __init__(self, svd_base_path, svd_unet_path=None, device="cuda"):
        super().__init__()
        from svd.models import UNetSpatioTemporalConditionModel
        from diffusers import AutoencoderKLTemporalDecoder, EulerDiscreteScheduler
        from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

        # Load UNet (fine-tuned or base)
        unet_path = svd_unet_path or svd_base_path
        print(f"Loading SVD UNet from: {unet_path}")
        self.unet = UNetSpatioTemporalConditionModel.from_pretrained(
            unet_path, subfolder="unet" if "unet" not in str(unet_path) else None,
            torch_dtype=torch.float16
        )
        self.unet.eval()
        for p in self.unet.parameters():
            p.requires_grad = False

        # Load VAE (always from base)
        self.vae = AutoencoderKLTemporalDecoder.from_pretrained(
            svd_base_path, subfolder="vae", torch_dtype=torch.float16
        )
        self.vae.eval()
        for p in self.vae.parameters():
            p.requires_grad = False

        # Load CLIP image encoder (always from base)
        self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(
            svd_base_path, subfolder="image_encoder", torch_dtype=torch.float16
        )
        self.image_encoder.eval()
        for p in self.image_encoder.parameters():
            p.requires_grad = False

        self.feature_extractor = CLIPImageProcessor.from_pretrained(
            svd_base_path, subfolder="feature_extractor"
        )

        # Scheduler for noise levels
        self.scheduler = EulerDiscreteScheduler.from_pretrained(svd_base_path, subfolder="scheduler")
        self.scheduler.set_timesteps(25, device=device)

        # Hook to capture features
        self._captured = {}
        self.unet.up_blocks[2].register_forward_hook(self._make_hook("up_block_2"))

    def _make_hook(self, name):
        def hook_fn(module, input, output):
            if isinstance(output, tuple):
                self._captured[name] = output[0].detach()
            else:
                self._captured[name] = output.detach()
        return hook_fn
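
    # Note: a forward hook fires on every UNet call, so _captured always holds
    # the up_block_2 activation from the most recent denoising pass; it is
    # cleared before each UNet forward in extract_features below.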

    @torch.no_grad()
    def extract_features(self, image_tensor, num_frames=1):
        """
        Extract UNet features at mid and late denoising timesteps.

        Args:
            image_tensor: (B, 3, H, W) in [0, 1], at SVD resolution
            num_frames: temporal length the latent is tiled to before the UNet
                forward (features are read back from the first frame)
        Returns:
            features: (B, COMBINED_FEAT_DIM, feat_H, feat_W) = (B, 1280, 40, 72)
        """
        device = image_tensor.device
        B = image_tensor.shape[0]

        # Normalize to [-1, 1] for VAE
        img_norm = image_tensor * 2.0 - 1.0

        # Encode with VAE
        latent = self.vae.encode(img_norm.half()).latent_dist.sample()
        latent = latent * self.vae.config.scaling_factor
        # Repeat for temporal dim
        latent = latent.unsqueeze(2).repeat(1, 1, num_frames, 1, 1)  # (B, C, T, h, w)

        # Conditioning latent (SVD conditions on the unscaled VAE latent)
        cond_latent = latent / self.vae.config.scaling_factor

        # CLIP image embedding
        clip_input = []
        for i in range(B):
            img = torch.clamp(image_tensor[i], 0, 1)
            img_np = (img.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
            clip_input.append(Image.fromarray(img_np))
        clip_processed = self.feature_extractor(images=clip_input, return_tensors="pt").pixel_values
        clip_processed = clip_processed.to(device, dtype=torch.float16)
        image_embeddings = self.image_encoder(clip_processed).image_embeds.unsqueeze(1)

        # Added time IDs: (fps, motion_bucket_id, noise_aug_strength)
        added_time_ids = torch.tensor([[7, 127, 0.02]], device=device, dtype=torch.float16).repeat(B, 1)

        # Generate noise
        noise = torch.randn_like(latent)

        # Extract features at mid and late timesteps
        all_feats = []
        for step_idx in [MID_STEP, LATE_STEP]:
            sigma = self.scheduler.sigmas[step_idx]
            t = self.scheduler.timesteps[step_idx]

            noisy_latent = latent + noise * sigma
            # Matches EulerDiscreteScheduler.scale_model_input
            scaled_input = noisy_latent / ((sigma**2 + 1) ** 0.5)

            # Concat with conditioning latent
            noisy_5d = scaled_input.permute(0, 2, 1, 3, 4)  # (B, T, C, h, w)
            cond_5d = cond_latent.permute(0, 2, 1, 3, 4)
            unet_input = torch.cat([noisy_5d, cond_5d], dim=2)  # (B, T, 8, h, w)

            self._captured.clear()
            _ = self.unet(
                unet_input, t.unsqueeze(0),
                encoder_hidden_states=image_embeddings,
                added_time_ids=added_time_ids,
            ).sample

            feat = self._captured["up_block_2"]  # (B*T, C, H, W) or (T, C, H, W)
            if feat.shape[0] == B * num_frames:
                feat = feat.view(B, num_frames, *feat.shape[1:])[:, 0]  # take first frame
            elif feat.shape[0] == num_frames:
                feat = feat[:1]  # take first frame, single batch
            all_feats.append(feat.float())

        # Concat mid + late features
        combined = torch.cat(all_feats, dim=1)  # (B, 1280, 40, 72)
        return combined
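
# A minimal usage sketch for the extractor (illustrative checkpoint path;
# assumes an SVD-XT checkpoint in diffusers layout):
#
#   extractor = SVDFeatureExtractor(svd_base_path="checkpoints/svd-xt", device="cuda").to("cuda")
#   imgs = torch.rand(2, 3, *SVD_SIZE, device="cuda")  # RGB frames in [0, 1]
#   feats = extractor.extract_features(imgs)           # (2, 1280, 40, 72)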


# ---------------------------------------------------------------------------
# PARA Heads (on SVD features)
# ---------------------------------------------------------------------------
class ParaHeads(nn.Module):
    """Volume + gripper + rotation heads on SVD diffusion features."""

    def __init__(self, feat_dim=COMBINED_FEAT_DIM, para_out_size=PARA_OUT_SIZE,
                 n_window=N_WINDOW, n_height_bins=N_HEIGHT_BINS):
        super().__init__()
        D = feat_dim
        self.para_out_size = para_out_size
        self.n_height_bins = n_height_bins
        self.n_window = n_window

        # Feature processing: resize (40, 72) features back to square (64, 64)
        # (undoing the 448 -> 320x576 aspect warp), then refine with convs
        self.feature_net = nn.Sequential(
            nn.Conv2d(D, 512, 3, padding=1), nn.GELU(),
            nn.Conv2d(512, 512, 3, padding=1), nn.GELU(),
            nn.Conv2d(512, 512, 3, padding=1), nn.GELU(),
        )

        self.volume_head = nn.Conv2d(512, n_window * n_height_bins, 1)
        self.gripper_head = nn.Conv2d(512, n_window * N_GRIPPER_BINS, 1)
        self.rotation_head = nn.Conv2d(512, n_window * 3 * N_ROT_BINS, 1)

    def forward(self, features, query_pixels=None):
        """
        Args:
            features: (B, D, H_feat, W_feat) from SVD extractor
            query_pixels: (B, N_WINDOW, 2) in PARA_OUT_SIZE coords [x, y]
        Returns:
            volume_logits: (B, N_WINDOW, N_HEIGHT_BINS, P, P)
            feats: (B, 512, P, P)
            gripper_logits: (B, N_WINDOW, N_GRIPPER_BINS) or None
            rotation_logits: (B, N_WINDOW, 3, N_ROT_BINS) or None
        """
        B = features.shape[0]
        P = self.para_out_size

        # Resize to PARA output size
        x = F.interpolate(features, size=(P, P), mode='bilinear', align_corners=False)
        feats = self.feature_net(x)  # (B, 512, P, P)

        vol = self.volume_head(feats)  # (B, N*Nh, P, P)
        vol = vol.view(B, self.n_window, self.n_height_bins, P, P)

        gripper_logits = rotation_logits = None
        if query_pixels is not None:
            N = self.n_window
            px = query_pixels[..., 0].long().clamp(0, P - 1)  # (B, N)
            py = query_pixels[..., 1].long().clamp(0, P - 1)

            # Gripper: index feature map at query pixels
            # (detach: only the volume loss shapes the shared feature_net)
            grip_map = self.gripper_head(feats.detach())  # (B, N*Ng, P, P)
            grip_map = grip_map.view(B, N, N_GRIPPER_BINS, P, P)
            batch_idx = torch.arange(B, device=feats.device).view(B, 1).expand(B, N)
            time_idx = torch.arange(N, device=feats.device).view(1, N).expand(B, N)
            gripper_logits = grip_map[batch_idx, time_idx, :, py, px]  # (B, N, Ng)

            # Rotation
            rot_map = self.rotation_head(feats.detach())  # (B, N*3*Nr, P, P)
            rot_map = rot_map.view(B, N, 3, N_ROT_BINS, P, P)
            rotation_logits = rot_map[batch_idx, time_idx, :, :, py, px]  # (B, N, 3, Nr)

        return vol, feats, gripper_logits, rotation_logits
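
# A minimal forward-pass sketch (random tensors, shape check only):
#   heads = ParaHeads()
#   x = torch.rand(2, COMBINED_FEAT_DIM, SVD_FEAT_H, SVD_FEAT_W)
#   q = torch.randint(0, PARA_OUT_SIZE, (2, N_WINDOW, 2)).float()
#   vol, feats, grip, rot = heads(x, query_pixels=q)
#   # vol: (2, 4, 32, 64, 64)  grip: (2, 4, 32)  rot: (2, 4, 3, 32)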


# ---------------------------------------------------------------------------
# Loss functions
# ---------------------------------------------------------------------------
def discretize_height(height_values, min_h, max_h, n_bins=N_HEIGHT_BINS):
    normalized = (height_values - min_h) / (max_h - min_h + 1e-8)
    return (normalized.clamp(0, 1) * (n_bins - 1)).long().clamp(0, n_bins - 1)

def discretize_gripper(gripper_values, min_g, max_g):
    normalized = (gripper_values - min_g) / (max_g - min_g + 1e-8)
    return (normalized.clamp(0, 1) * (N_GRIPPER_BINS - 1)).long().clamp(0, N_GRIPPER_BINS - 1)

def discretize_rotation(euler_values, min_r, max_r):
    normalized = (euler_values - min_r) / (max_r - min_r + 1e-8)
    return (normalized.clamp(0, 1) * (N_ROT_BINS - 1)).long().clamp(0, N_ROT_BINS - 1)
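
# Worked example: with a [0, 1] height range and 32 bins,
# discretize_height(torch.tensor([0.5]), 0.0, 1.0) -> floor(0.5 * 31) = bin 15;
# the inverse map via bin centers is torch.linspace(0, 1, 32)[15] ~= 0.484.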

def compute_volume_loss(pred, traj_2d, target_h_bins):
    B, N, Nh, H, W = pred.shape
    px = traj_2d[:, :, 0].long().clamp(0, W - 1)
    py = traj_2d[:, :, 1].long().clamp(0, H - 1)
    h_bin = target_h_bins.clamp(0, Nh - 1)
    losses = []
    for t in range(N):
        logits_flat = pred[:, t].reshape(B, -1)
        target_idx = (h_bin[:, t] * (H * W) + py[:, t] * W + px[:, t]).long()
        losses.append(F.cross_entropy(logits_flat, target_idx, reduction='mean'))
    return torch.stack(losses).mean()
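
# The volume target is a single class index into the flattened (Nh*H*W) grid;
# e.g. with Nh=32, H=W=64, a target at (h_bin=5, py=10, px=20) maps to
# 5*64*64 + 10*64 + 20 = 21140, so each timestep is one cross-entropy over
# 131072 joint (height, y, x) classes.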

def compute_gripper_loss(pred, target, min_g, max_g):
    target_bins = discretize_gripper(target, min_g, max_g)
    B, N, Ng = pred.shape
    return F.cross_entropy(pred.reshape(B * N, Ng), target_bins.reshape(B * N))

def compute_rotation_loss(pred, target_euler, min_r, max_r):
    target_bins = discretize_rotation(target_euler, min_r, max_r)
    B, N, _, Nr = pred.shape
    losses = []
    for axis in range(3):
        logits = pred[:, :, axis, :].reshape(B * N, Nr)
        target = target_bins[:, :, axis].reshape(B * N)
        losses.append(F.cross_entropy(logits, target))
    return torch.stack(losses).mean()

def extract_pred_2d_and_height(volume_logits, min_h, max_h):
    B, N, Nh, H, W = volume_logits.shape
    device = volume_logits.device
    pred_2d = torch.zeros(B, N, 2, device=device)
    pred_h_bins = torch.zeros(B, N, device=device, dtype=torch.long)
    for t in range(N):
        vol_t = volume_logits[:, t]
        max_over_h, _ = vol_t.max(dim=1)
        flat_idx = max_over_h.view(B, -1).argmax(dim=1)
        py = flat_idx // W
        px = flat_idx % W
        pred_2d[:, t, 0] = px.float()
        pred_2d[:, t, 1] = py.float()
        pred_h_bins[:, t] = vol_t[torch.arange(B, device=device), :, py, px].argmax(dim=1)
    bin_centers = torch.linspace(0, 1, Nh, device=device)
    pred_height = bin_centers[pred_h_bins] * (max_h - min_h) + min_h
    return pred_2d, pred_height
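
# Decoding sketch: given vol_logits of shape (B, 4, 32, 64, 64),
#   pred_2d, pred_h = extract_pred_2d_and_height(vol_logits, MIN_HEIGHT, MAX_HEIGHT)
# yields pred_2d (B, 4, 2) in 64x64 pixel coords and pred_h (B, 4) mapped back
# to height units via the bin centers.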


# ---------------------------------------------------------------------------
# Dataset stats
# ---------------------------------------------------------------------------
def compute_dataset_stats(dataset, sample_limit=500, seed=42):
    rng = random.Random(seed)
    n = min(sample_limit, len(dataset))
    indices = rng.sample(range(len(dataset)), n)
    all_h, all_g, all_e = [], [], []
    for idx in tqdm(indices, desc="Computing stats", leave=False):
        try:
            s = dataset[idx]
        except Exception:
            continue
        all_h.extend(s['trajectory_3d'].numpy()[:, 2].tolist())
        all_g.extend(s['trajectory_gripper'].numpy().tolist())
        all_e.append(s['trajectory_euler'].numpy())
    h = np.array(all_h)
    g = np.array(all_g)
    e = np.concatenate(all_e, 0)
    return {
        "min_height": float(h.min()), "max_height": float(h.max()),
        "min_gripper": float(g.min()), "max_gripper": float(g.max()),
        "min_rot": e.min(0).tolist(), "max_rot": e.max(0).tolist(),
    }


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def main():
    p = argparse.ArgumentParser()
    p.add_argument("--cache_root", type=str, default="/data/libero/ood_objpos_task0")
    p.add_argument("--benchmark", type=str, default="libero_spatial")
    p.add_argument("--task_ids", type=str, default="0")
    p.add_argument("--svd_base", type=str,
                   default="/data/cameron/vidgen/svd_motion_lora/Motion-LoRA/checkpoints/stable-video-diffusion-img2vid-xt-1-1")
    p.add_argument("--svd_checkpoint", type=str,
                   default="/data/cameron/vidgen/svd_motion_lora/Motion-LoRA/output_libero_7f/checkpoint-46000/unet")
    p.add_argument("--batch_size", type=int, default=4)
    p.add_argument("--lr", type=float, default=1e-4)
    p.add_argument("--epochs", type=int, default=500)
    p.add_argument("--frame_stride", type=int, default=3)
    p.add_argument("--workers", type=int, default=4)
    p.add_argument("--device", type=str, default="cuda")
    p.add_argument("--run_name", type=str, default="svd_para_ood_objpos")
    p.add_argument("--log_wandb", action="store_true")
    p.add_argument("--vis_every", type=int, default=100)
    p.add_argument("--checkpoint_every", type=int, default=1000)
    args = p.parse_args()

    device = torch.device(args.device)
    ckpt_dir = Path(__file__).parent / "checkpoints" / args.run_name
    ckpt_dir.mkdir(parents=True, exist_ok=True)

    if args.log_wandb:
        import wandb
        wandb.init(project="svd_para_libero", config=vars(args), name=args.run_name)

    # --- Build SVD feature extractor (frozen) ---
    svd_extractor = SVDFeatureExtractor(
        svd_base_path=args.svd_base,
        svd_unet_path=args.svd_checkpoint,
        device=device,
    ).to(device)
    print(f"SVD feature extractor loaded (frozen)")
    print(f"  GPU memory: {torch.cuda.memory_allocated()/1e9:.2f} GB")

    # --- Build PARA heads (trainable) ---
    para_heads = ParaHeads(
        feat_dim=COMBINED_FEAT_DIM,
        para_out_size=PARA_OUT_SIZE,
        n_window=N_WINDOW,
    ).to(device)
    print(f"ParaHeads: {sum(p.numel() for p in para_heads.parameters()):,} trainable params")

    # --- Dataset ---
    task_ids = [int(x) for x in args.task_ids.split(",")]
    dataset = CachedTrajectoryDataset(
        cache_root=args.cache_root,
        benchmark_name=args.benchmark,
        task_ids=task_ids,
        image_size=PRERENDER_SIZE,
        n_window=N_WINDOW,
        frame_stride=args.frame_stride,
    )
    print(f"Dataset: {len(dataset)} samples")

    # Train/val split
    val_size = max(1, int(len(dataset) * 0.05))
    train_size = len(dataset) - val_size
    train_ds, val_ds = torch.utils.data.random_split(
        dataset, [train_size, val_size], generator=torch.Generator().manual_seed(42))

    train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True,
                              num_workers=args.workers, pin_memory=True)
    val_loader = DataLoader(val_ds, batch_size=args.batch_size, shuffle=False,
                            num_workers=args.workers, pin_memory=True)
    print(f"Train: {len(train_ds)}, Val: {len(val_ds)}")

    # --- Dataset stats ---
    stats_path = ckpt_dir / "dataset_stats.json"
    if stats_path.exists():
        with open(stats_path) as f:
            stats = json.load(f)
    else:
        stats = compute_dataset_stats(dataset)
        with open(stats_path, "w") as f:
            json.dump(stats, f, indent=2)

    global MIN_HEIGHT, MAX_HEIGHT, MIN_GRIPPER, MAX_GRIPPER, MIN_ROT, MAX_ROT
    MIN_HEIGHT, MAX_HEIGHT = stats["min_height"], stats["max_height"]
    MIN_GRIPPER, MAX_GRIPPER = stats["min_gripper"], stats["max_gripper"]
    MIN_ROT, MAX_ROT = stats["min_rot"], stats["max_rot"]
    min_r_t = torch.tensor(MIN_ROT, dtype=torch.float32, device=device)
    max_r_t = torch.tensor(MAX_ROT, dtype=torch.float32, device=device)
    print(f"Height: [{MIN_HEIGHT:.4f}, {MAX_HEIGHT:.4f}]")
    print(f"Gripper: [{MIN_GRIPPER:.4f}, {MAX_GRIPPER:.4f}]")

    coord_scale = PARA_OUT_SIZE / PRERENDER_SIZE

    # --- Optimizer (only PARA heads) ---
    opt = torch.optim.AdamW(para_heads.parameters(), lr=args.lr, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=args.epochs * len(train_loader))

    # --- Training ---
    global_step = 0
    # ImageNet de-normalization constants (hoisted out of the train loop)
    imnet_mean = torch.tensor([0.485, 0.456, 0.406], device=device).view(1, 3, 1, 1)
    imnet_std = torch.tensor([0.229, 0.224, 0.225], device=device).view(1, 3, 1, 1)
    for epoch in tqdm(range(args.epochs), desc="Epochs"):
        para_heads.train()
        pbar = tqdm(train_loader, desc=f"Epoch {epoch}", leave=False)
        for batch in pbar:
            # Get first frame; undo ImageNet normalization to recover [0, 1] RGB
            rgb = batch['rgb']  # (B, 3, 448, 448), ImageNet-normalized
            rgb_01 = rgb.to(device) * imnet_std + imnet_mean  # (B, 3, 448, 448) in [0, 1]

            # Resize to SVD resolution
            rgb_svd = F.interpolate(rgb_01, size=SVD_SIZE, mode='bilinear', align_corners=False)

            # Extract frozen SVD features
            with torch.no_grad():
                features = svd_extractor.extract_features(rgb_svd)  # (B, 1280, 40, 72)

            # Trajectory supervision
            traj_2d = batch['trajectory_2d'].to(device)[:, :N_WINDOW]
            traj_3d = batch['trajectory_3d'].to(device)[:, :N_WINDOW]
            traj_gripper = batch['trajectory_gripper'].to(device)[:, :N_WINDOW]
            traj_euler = batch['trajectory_euler'].to(device)[:, :N_WINDOW]

            traj_para = (traj_2d * coord_scale).clamp(0, PARA_OUT_SIZE - 1.001)  # keep .long() indices < P
            target_h_bins = discretize_height(traj_3d[:, :, 2], MIN_HEIGHT, MAX_HEIGHT)

            # PARA forward
            vol_logits, feats, grip_logits, rot_logits = para_heads(
                features, query_pixels=traj_para
            )

            # Losses
            vol_loss = compute_volume_loss(vol_logits, traj_para, target_h_bins)
            grip_loss = compute_gripper_loss(grip_logits, traj_gripper, MIN_GRIPPER, MAX_GRIPPER)
            rot_loss = compute_rotation_loss(rot_logits, traj_euler, min_r_t, max_r_t)
            total_loss = vol_loss + GRIPPER_LOSS_WEIGHT * grip_loss + ROTATION_LOSS_WEIGHT * rot_loss

            opt.zero_grad()
            total_loss.backward()
            torch.nn.utils.clip_grad_norm_(para_heads.parameters(), 1.0)
            opt.step()
            scheduler.step()

            pbar.set_postfix(vol=f"{vol_loss.item():.3f}", grip=f"{grip_loss.item():.3f}",
                           rot=f"{rot_loss.item():.3f}")

            if args.log_wandb:
                import wandb
                wandb.log({
                    "train/total_loss": total_loss.item(),
                    "train/volume_loss": vol_loss.item(),
                    "train/gripper_loss": grip_loss.item(),
                    "train/rotation_loss": rot_loss.item(),
                    "train/lr": scheduler.get_last_lr()[0],
                }, step=global_step)

            # Visualization
            if global_step % args.vis_every == 0 and args.log_wandb:
                import wandb
                import matplotlib
                matplotlib.use("Agg")
                import matplotlib.pyplot as plt

                para_heads.eval()
                with torch.no_grad():
                    heatmaps = F.softmax(vol_logits[:1].reshape(1, N_WINDOW, -1), dim=2)
                    heatmaps = heatmaps.view(1, N_WINDOW, N_HEIGHT_BINS, PARA_OUT_SIZE, PARA_OUT_SIZE)
                    heatmaps_2d = heatmaps.max(dim=2)[0]  # (1, N, P, P)

                    input_frame = (rgb_01[0].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
                    input_small = cv2.resize(input_frame, (PARA_OUT_SIZE, PARA_OUT_SIZE))

                    fig, axes = plt.subplots(1, N_WINDOW, figsize=(3 * N_WINDOW, 3))
                    for t in range(N_WINDOW):
                        heat = heatmaps_2d[0, t].cpu().numpy()
                        heat_norm = (heat - heat.min()) / (heat.max() - heat.min() + 1e-8)
                        overlay = input_small / 255.0 * 0.5 + np.stack(
                            [heat_norm, np.zeros_like(heat_norm), np.zeros_like(heat_norm)], axis=-1) * 0.5
                        axes[t].imshow(np.clip(overlay, 0, 1))
                        axes[t].scatter(traj_para[0, t, 0].cpu(), traj_para[0, t, 1].cpu(),
                                       c="cyan", s=40, marker="x", linewidths=2)
                        axes[t].set_title(f"t={t}")
                        axes[t].axis("off")
                    plt.tight_layout()
                    buf = io.BytesIO()
                    plt.savefig(buf, format="png", dpi=100)
                    buf.seek(0)
                    wandb.log({"vis/heatmaps": wandb.Image(Image.open(buf))}, step=global_step)
                    plt.close("all")
                para_heads.train()

            # Checkpoint
            if global_step > 0 and global_step % args.checkpoint_every == 0:
                torch.save({
                    "epoch": epoch, "global_step": global_step,
                    "para_heads_state_dict": para_heads.state_dict(),
                    "optimizer_state_dict": opt.state_dict(),
                    "stats": stats,
                }, ckpt_dir / f"checkpoint_{global_step}.pt")
                print(f"Saved checkpoint at step {global_step}")

            global_step += 1

    print("Training complete!")
    if args.log_wandb:
        import wandb
        wandb.finish()


if __name__ == "__main__":
    main()
