"""Joint SVD video diffusion + PARA action prediction training.

During each training step:
  1. UNet forward pass with random noise level (standard diffusion training)
  2. Hook into up_block_1 (1280ch, 20x36) and up_block_2 (640ch, 40x72)
  3. Project each to 128 channels via a 1x1 conv, concat → 256ch
  4. Bilinear upsample to 64x64, conv refinement, PARA heads
  5. Loss = EMA-weighted sum of (volume_loss, gripper_loss, diffusion_loss);
     rotation_loss is computed and logged but not added to the objective

Usage:
    CUDA_VISIBLE_DEVICES=0,3,5,8 python train_svd_para_joint.py
"""

import argparse
import json
import os
import random
import shutil
import sys
from pathlib import Path

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from torch.utils.data import Dataset, DataLoader
import torchvision

from accelerate import Accelerator, DistributedDataParallelKwargs
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from tqdm.auto import tqdm
from PIL import Image

from diffusers import AutoencoderKLTemporalDecoder
from diffusers.optimization import get_scheduler
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

sys.path.insert(0, str(Path(__file__).parent))
sys.path.insert(0, "/data/cameron/para_videopolicy")
from svd.pipelines import StableVideoDiffusionPipeline
from svd.models import UNetSpatioTemporalConditionModel
from data import CachedTrajectoryDataset

import wandb

logger = get_logger(__name__, log_level="INFO")

# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
PARA_OUT_SIZE = 64
N_HEIGHT_BINS = 32
N_GRIPPER_BINS = 32
N_ROT_BINS = 32
PRERENDER_SIZE = 448
EMA_ALPHA = 0.01  # EMA smoothing for loss weighting
PROJ_DIM = 128    # project each UNet feature block to this dim before concat

# Dataset stats (placeholders; overwritten in main() from dataset_stats.json)
MIN_HEIGHT = MAX_HEIGHT = 0.0
MIN_GRIPPER = MAX_GRIPPER = 0.0
MIN_ROT = MAX_ROT = None


# ---------------------------------------------------------------------------
# PARA Heads on UNet features
# ---------------------------------------------------------------------------
class ParaHeadsOnUNet(nn.Module):
    """PARA heads that attach to SVD UNet's up_block_1 and up_block_2."""

    def __init__(self, n_window=1, n_height_bins=N_HEIGHT_BINS, proj_dim=PROJ_DIM):
        super().__init__()
        self.n_height_bins = n_height_bins
        # n_window is currently unused; kept in the signature for its call sites

        # Project each block to PROJ_DIM, then concat → 2*PROJ_DIM = 256
        self.proj_block1 = nn.Conv2d(1280, proj_dim, 1)  # up_block_1: 1280ch, 20x36
        self.proj_block2 = nn.Conv2d(640, proj_dim, 1)   # up_block_2: 640ch, 40x72

        D = proj_dim * 2  # 256
        self.feature_convs = nn.Sequential(
            nn.Conv2d(D, D, 3, padding=1), nn.GELU(),
            nn.Conv2d(D, D, 3, padding=1), nn.GELU(),
            nn.Conv2d(D, D, 3, padding=1), nn.GELU(),
        )

        self.volume_head = nn.Conv2d(D, n_height_bins, 1)
        self.gripper_head = nn.Conv2d(D, N_GRIPPER_BINS, 1)
        self.rotation_head = nn.Conv2d(D, 3 * N_ROT_BINS, 1)

    def forward(self, feat_block1, feat_block2, query_pixels=None):
        """
        Args:
            feat_block1: (B*T, 1280, 20, 36) from up_block_1
            feat_block2: (B*T, 640, 40, 72) from up_block_2
            query_pixels: (B*T, 2) in PARA_OUT_SIZE coords [x, y]
        Returns:
            volume_logits: (B*T, N_HEIGHT_BINS, P, P)
            gripper_logits: (B*T, N_GRIPPER_BINS) or None
            rotation_logits: (B*T, 3, N_ROT_BINS) or None
        """
        P = PARA_OUT_SIZE

        # Project and upsample both to P×P
        f1 = self.proj_block1(feat_block1)  # (B*T, 128, 20, 36)
        f1 = F.interpolate(f1, size=(P, P), mode='bilinear', align_corners=False)

        f2 = self.proj_block2(feat_block2)  # (B*T, 128, 40, 72)
        f2 = F.interpolate(f2, size=(P, P), mode='bilinear', align_corners=False)

        # Concat → 256ch
        feats = torch.cat([f1, f2], dim=1)  # (B*T, 256, P, P)
        feats = self.feature_convs(feats)

        vol = self.volume_head(feats)  # (B*T, Nh, P, P)

        gripper_logits = rotation_logits = None
        if query_pixels is not None:
            BT = feats.shape[0]
            px = query_pixels[:, 0].long().clamp(0, P - 1)
            py = query_pixels[:, 1].long().clamp(0, P - 1)
            idx = torch.arange(BT, device=feats.device)

            # Gripper: query the logit map at the GT pixel. feats.detach()
            # keeps gripper/rotation gradients out of the shared trunk, which
            # is trained by the volume loss only.
            grip_map = self.gripper_head(feats.detach())  # (BT, Ng, P, P)
            gripper_logits = grip_map[idx, :, py, px]  # (BT, Ng)

            # Rotation: per-axis logits at the GT pixel
            rot_map = self.rotation_head(feats.detach())  # (BT, 3*Nr, P, P)
            rot_at_px = rot_map[idx, :, py, px]  # (BT, 3*Nr)
            rotation_logits = rot_at_px.view(BT, 3, N_ROT_BINS)

        return vol, gripper_logits, rotation_logits
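
# Hypothetical shape check for ParaHeadsOnUNet (B*T = 7, the default window):
#   heads = ParaHeadsOnUNet()
#   f1, f2 = torch.randn(7, 1280, 20, 36), torch.randn(7, 640, 40, 72)
#   vol, grip, rot = heads(f1, f2, query_pixels=torch.zeros(7, 2))
#   # vol: (7, 32, 64, 64), grip: (7, 32), rot: (7, 3, 32)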


# ---------------------------------------------------------------------------
# Loss functions
# ---------------------------------------------------------------------------
def discretize(values, min_v, max_v, n_bins):
    """Map continuous values to integer bin indices in [0, n_bins - 1]."""
    norm = (values - min_v) / (max_v - min_v + 1e-8)
    return (norm.clamp(0, 1) * (n_bins - 1)).long().clamp(0, n_bins - 1)
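# E.g. discretize(torch.tensor([0.5]), 0.0, 1.0, 32) -> tensor([15]):
# 0.5 * 31 = 15.5, truncated by .long(); out-of-range values clamp to bin 0 / 31.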

def compute_volume_loss(vol_logits, traj_2d, target_h_bins):
    """vol_logits: (BT, Nh, H, W), traj_2d: (BT, 2), target_h_bins: (BT,)"""
    BT, Nh, H, W = vol_logits.shape
    px = traj_2d[:, 0].long().clamp(0, W - 1)
    py = traj_2d[:, 1].long().clamp(0, H - 1)
    h_bin = target_h_bins.clamp(0, Nh - 1)
    logits_flat = vol_logits.reshape(BT, -1)
    target_idx = (h_bin * (H * W) + py * W + px).long()
    return F.cross_entropy(logits_flat, target_idx)
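# A single cross-entropy over the joint (height, y, x) space of Nh*H*W classes:
# e.g. with Nh=32 and H=W=64 there are 131072 classes, and a target of
# (h_bin=3, py=10, px=5) maps to class index 3*4096 + 10*64 + 5 = 12933.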

def compute_gripper_loss(logits, target, min_g, max_g):
    target_bins = discretize(target, min_g, max_g, N_GRIPPER_BINS)
    return F.cross_entropy(logits, target_bins)

def compute_rotation_loss(logits, target_euler, min_r, max_r):
    """logits: (BT, 3, Nr), target_euler: (BT, 3)"""
    losses = []
    for axis in range(3):
        bins = discretize(target_euler[:, axis], min_r[axis], max_r[axis], N_ROT_BINS)
        losses.append(F.cross_entropy(logits[:, axis, :], bins))
    return torch.stack(losses).mean()


# ---------------------------------------------------------------------------
# Diffusion helpers (from train_svd.py)
# ---------------------------------------------------------------------------
def rand_log_normal(shape, loc=0., scale=1., device='cpu', dtype=torch.float32):
    u = torch.rand(shape, dtype=dtype, device=device) * (1 - 2e-7) + 1e-7
    return torch.distributions.Normal(loc, scale).icdf(u).exp()
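# Draws sigma ~ exp(N(loc, scale)) via the inverse CDF of a clamped uniform
# (the 1e-7 offsets keep icdf finite). With loc=0.7, scale=1.6 as used below,
# the median training sigma is exp(0.7) ≈ 2.0.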

def tensor_to_vae_latent(t, vae):
    """t: (B, C, T, H, W) -> latents (B, C_latent, T, h, w)"""
    b, c, f, h, w = t.shape
    video = t.permute(0, 2, 1, 3, 4).reshape(b * f, c, h, w)  # (B*T, C, H, W)
    latents = vae.encode(video).latent_dist.sample()
    latents = latents.reshape(b, f, *latents.shape[1:])
    latents = latents.permute(0, 2, 1, 3, 4)  # (B, C_latent, T, h, w)
    # Scale as in train_svd.py; the conditioning path below divides this
    # factor back out, which only balances if it is applied here.
    return latents * vae.config.scaling_factor

def _resize_with_antialiasing(input, size, interpolation="bicubic", align_corners=True):
    """Gaussian blur (sigma matched to the downscale factor), then interpolate."""
    h, w = input.shape[-2:]
    factors = (h / size[0], w / size[1])
    sigmas = (max((factors[0] - 1.0) / 2.0, 0.001), max((factors[1] - 1.0) / 2.0, 0.001))
    ks = int(max(2.0 * 2 * sigmas[0], 3)), int(max(2.0 * 2 * sigmas[1], 3))
    if (ks[0] % 2) == 0: ks = ks[0] + 1, ks[1]
    if (ks[1] % 2) == 0: ks = ks[0], ks[1] + 1
    input = torch.nn.functional.pad(input, [ks[1]//2]*2 + [ks[0]//2]*2, mode="reflect")
    kernel = torch.ones(1, dtype=input.dtype, device=input.device)
    for s, k in zip(sigmas, ks):
        coord = torch.arange(k, dtype=input.dtype, device=input.device) - (k - 1) / 2
        g = torch.exp(-coord**2 / (2*s**2))
        g /= g.sum()
        kernel = kernel.unsqueeze(-1) * g.unsqueeze(0)
    kernel = kernel.expand(input.shape[-3], -1, -1).unsqueeze(1)
    input = F.conv2d(input, kernel, groups=input.shape[-3])
    return F.interpolate(input, size=size, mode=interpolation, align_corners=align_corners)


# ---------------------------------------------------------------------------
# Dataset: video frames + PARA trajectory
# ---------------------------------------------------------------------------
class VideoParaDataset(Dataset):
    """Wraps CachedTrajectoryDataset to provide video frames for SVD + PARA annotations."""

    def __init__(self, cache_root, benchmark, task_ids, width=576, height=320,
                 sample_frames=7, image_size=448, frame_stride=1):
        self.para_dataset = CachedTrajectoryDataset(
            cache_root=cache_root, benchmark_name=benchmark,
            task_ids=task_ids, image_size=image_size,
            n_window=sample_frames, frame_stride=frame_stride,
        )
        self.width = width
        self.height = height
        self.sample_frames = sample_frames
        self.image_size = image_size

    def __len__(self):
        return len(self.para_dataset)

    def __getitem__(self, idx):
        sample = self.para_dataset[idx]

        # Video frames for SVD: raw frames are (T, 448, 448, 3) float in [0, 1]
        rgb_frames = sample['rgb_frames_raw']
        frames = rgb_frames.permute(0, 3, 1, 2)  # (T, 3, H, W)
        frames_resized = F.interpolate(frames, size=(self.height, self.width),
                                       mode='bilinear', align_corners=False)
        video = frames_resized * 2.0 - 1.0  # to [-1, 1]
        # (T, 3, H, W) -> (3, T, H, W) for SVD convention
        video = video.permute(1, 0, 2, 3)

        return {
            'pixel_values': video,  # (3, T, H, W) in [-1, 1]
            'trajectory_2d': sample['trajectory_2d'],  # (T, 2) in 448 space
            'trajectory_3d': sample['trajectory_3d'],
            'trajectory_gripper': sample['trajectory_gripper'],
            'trajectory_euler': sample['trajectory_euler'],
            'rgb_frames_raw': sample['rgb_frames_raw'],  # for visualization
        }


# ---------------------------------------------------------------------------
# Dataset stats
# ---------------------------------------------------------------------------
def compute_dataset_stats(dataset, n=500):
    rng = random.Random(42)
    indices = rng.sample(range(len(dataset)), min(n, len(dataset)))
    all_h, all_g, all_e = [], [], []
    for idx in tqdm(indices, desc="Stats", leave=False):
        try:
            s = dataset.para_dataset[idx]
        except Exception:
            continue
        all_h.extend(s['trajectory_3d'].numpy()[:, 2].tolist())
        all_g.extend(s['trajectory_gripper'].numpy().tolist())
        all_e.append(s['trajectory_euler'].numpy())
    h, g, e = np.array(all_h), np.array(all_g), np.concatenate(all_e, 0)
    return {"min_height": float(h.min()), "max_height": float(h.max()),
            "min_gripper": float(g.min()), "max_gripper": float(g.max()),
            "min_rot": e.min(0).tolist(), "max_rot": e.max(0).tolist()}


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def main():
    p = argparse.ArgumentParser()
    p.add_argument("--pretrained", type=str,
                   default="checkpoints/stable-video-diffusion-img2vid-xt-1-1")
    p.add_argument("--pretrain_unet", type=str,
                   default="output_libero_ood_objpos/checkpoint-31500/unet")
    p.add_argument("--cache_root", type=str, default="/data/libero/ood_objpos_task0")
    p.add_argument("--benchmark", type=str, default="libero_spatial")
    p.add_argument("--task_ids", type=str, default="0")
    p.add_argument("--width", type=int, default=576)
    p.add_argument("--height", type=int, default=320)
    p.add_argument("--num_frames", type=int, default=7)
    p.add_argument("--frame_stride", type=int, default=1)
    p.add_argument("--batch_size", type=int, default=1)
    p.add_argument("--lr", type=float, default=5e-5)
    p.add_argument("--freeze_unet", action="store_true", help="Freeze UNet, train only PARA heads")
    p.add_argument("--max_steps", type=int, default=999999)
    p.add_argument("--output_dir", type=str, default="output_svd_para_joint")
    p.add_argument("--ckpt_every", type=int, default=1000)
    p.add_argument("--vis_every", type=int, default=200)
    p.add_argument("--seed", type=int, default=123)
    args = p.parse_args()

    # Accelerator; find_unused_parameters=True because the rotation head gets
    # no gradient (its loss is logged but not optimized).
    ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
    accelerator = Accelerator(mixed_precision="bf16", gradient_accumulation_steps=1,
                              kwargs_handlers=[ddp_kwargs])
    device = accelerator.device
    set_seed(args.seed)

    if accelerator.is_main_process:
        os.makedirs(args.output_dir, exist_ok=True)
        wandb.init(project="svd_para_joint", name=Path(args.output_dir).name,
                   config=vars(args))

    # --- Models ---
    weight_dtype = torch.bfloat16

    # Frozen: VAE, image encoder, feature extractor
    vae = AutoencoderKLTemporalDecoder.from_pretrained(
        args.pretrained, subfolder="vae").to(device, dtype=weight_dtype)
    vae.requires_grad_(False)

    image_encoder = CLIPVisionModelWithProjection.from_pretrained(
        args.pretrained, subfolder="image_encoder").to(device, dtype=weight_dtype)
    image_encoder.requires_grad_(False)

    feature_extractor = CLIPImageProcessor.from_pretrained(
        args.pretrained, subfolder="feature_extractor")

    # Trainable: UNet + PARA heads
    unet = UNetSpatioTemporalConditionModel.from_pretrained(
        args.pretrain_unet, subfolder="unet" if "unet" not in args.pretrain_unet else None,
    ).to(device, dtype=weight_dtype)
    unet.enable_gradient_checkpointing()

    para_heads = ParaHeadsOnUNet(n_window=1).to(device, dtype=torch.float32)

    # Register hooks for feature capture
    captured = {}
    def make_hook(name):
        def hook_fn(module, inp, out):
            captured[name] = out[0] if isinstance(out, tuple) else out
        return hook_fn
    unet.up_blocks[1].register_forward_hook(make_hook("up_block_1"))
    unet.up_blocks[2].register_forward_hook(make_hook("up_block_2"))
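    # Hooks survive accelerator.prepare(): DDP wraps the UNet but reuses the
    # same up_block module objects, so the hooks keep firing every forward.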

    # CLIP image encoding helper
    def encode_image(pixel_values):
        # pixel_values: (B, 3, H, W) in [-1, 1]
        pixel_values = _resize_with_antialiasing(pixel_values, (224, 224))
        pixel_values = (pixel_values + 1.0) / 2.0  # to [0, 1] for CLIP stats
        pixel_values = torchvision.transforms.functional.normalize(
            pixel_values, [0.48145466, 0.4578275, 0.40821073],
            [0.26862954, 0.26130258, 0.27577711])
        # The caller passes fp32 (for the antialiased resize); cast back to the
        # frozen encoder's dtype to avoid an fp32-input/bf16-weight mismatch.
        pixel_values = pixel_values.to(dtype=weight_dtype)
        image_embeddings = image_encoder(pixel_values).image_embeds
        return image_embeddings.unsqueeze(1)

    # --- Dataset ---
    task_ids = [int(x) for x in args.task_ids.split(",")]
    dataset = VideoParaDataset(
        cache_root=args.cache_root, benchmark=args.benchmark,
        task_ids=task_ids, width=args.width, height=args.height,
        sample_frames=args.num_frames, frame_stride=args.frame_stride,
    )

    # Stats
    stats_path = Path(args.output_dir) / "dataset_stats.json"
    if stats_path.exists():
        stats = json.load(open(stats_path))
    else:
        stats = compute_dataset_stats(dataset)
        if accelerator.is_main_process:
            json.dump(stats, open(stats_path, "w"), indent=2)

    global MIN_HEIGHT, MAX_HEIGHT, MIN_GRIPPER, MAX_GRIPPER, MIN_ROT, MAX_ROT
    MIN_HEIGHT, MAX_HEIGHT = stats["min_height"], stats["max_height"]
    MIN_GRIPPER, MAX_GRIPPER = stats["min_gripper"], stats["max_gripper"]
    MIN_ROT, MAX_ROT = stats["min_rot"], stats["max_rot"]
    min_r_t = torch.tensor(MIN_ROT, device=device)
    max_r_t = torch.tensor(MAX_ROT, device=device)
    coord_scale = PARA_OUT_SIZE / PRERENDER_SIZE

    logger.info(f"Dataset: {len(dataset)} samples, Height: [{MIN_HEIGHT:.3f}, {MAX_HEIGHT:.3f}]")

    dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True,
                            num_workers=4, pin_memory=True, drop_last=True)

    # --- Optimizer (separate LRs: UNet barely moves, PARA heads learn fast) ---
    if args.freeze_unet:
        unet.requires_grad_(False)
        unet.eval()
        all_params = [{"params": para_heads.parameters(), "lr": args.lr * 2.0}]
        logger.info("UNet FROZEN — training PARA heads only")
    else:
        all_params = [
            {"params": unet.parameters(), "lr": args.lr * 0.02},       # 1e-6 at the default --lr
            {"params": para_heads.parameters(), "lr": args.lr * 2.0},  # 1e-4 at the default --lr
        ]
    optimizer = torch.optim.AdamW(all_params, weight_decay=1e-4)
    lr_scheduler = get_scheduler("constant", optimizer=optimizer, num_warmup_steps=0,
                                  num_training_steps=args.max_steps)

    unet, para_heads, optimizer, dataloader, lr_scheduler = accelerator.prepare(
        unet, para_heads, optimizer, dataloader, lr_scheduler)

    # --- EMA loss weighting ---
    loss_emas = {}

    global_step = 0
    progress_bar = tqdm(range(args.max_steps), disable=not accelerator.is_main_process)

    for epoch in range(99999):
        for batch in dataloader:
            if global_step >= args.max_steps:
                break

            with accelerator.accumulate(unet, para_heads):
                pixel_values = batch['pixel_values'].to(weight_dtype)  # (B, 3, T, H, W)
                B, C, T, H, W = pixel_values.shape

                # PARA targets (per frame)
                traj_2d = batch['trajectory_2d'].to(device)[:, :T]  # (B, T, 2) in 448 space
                traj_3d = batch['trajectory_3d'].to(device)[:, :T]
                traj_gripper = batch['trajectory_gripper'].to(device)[:, :T]
                traj_euler = batch['trajectory_euler'].to(device)[:, :T]

                traj_para = (traj_2d * coord_scale).clamp(0, PARA_OUT_SIZE - 1.001)  # (B, T, 2)
                target_h_bins = discretize(traj_3d[:, :, 2], MIN_HEIGHT, MAX_HEIGHT, N_HEIGHT_BINS)

                # --- SVD diffusion forward ---
                conditional_pixel_values = pixel_values[:, :, 0:1]  # first frame
                latents = tensor_to_vae_latent(pixel_values, vae)

                noise = torch.randn_like(latents)
                bsz = latents.shape[0]

                # SVD noise augmentation: perturb the conditioning frame with a
                # small sigma before encoding it (the same sigma is fed to the
                # UNet via added_time_ids below).
                cond_sigmas = rand_log_normal(shape=[bsz], loc=-3.0, scale=0.5).to(latents)
                cond_sigmas_5d = cond_sigmas[:, None, None, None, None]
                conditional_pixel_values_noised = \
                    torch.randn_like(conditional_pixel_values) * cond_sigmas_5d + conditional_pixel_values
                conditional_latents = tensor_to_vae_latent(conditional_pixel_values_noised, vae)
                conditional_latents = conditional_latents[:, :, 0:1]  # (B, C, 1, h, w)
                # Undo the scaling_factor applied in tensor_to_vae_latent: the
                # conditioning latent is concatenated unscaled.
                conditional_latents = conditional_latents / vae.config.scaling_factor

                sigmas = rand_log_normal(shape=[bsz], loc=0.7, scale=1.6).to(latents.device)
                sigmas_5d = sigmas[:, None, None, None, None]
                noisy_latents = latents + noise * sigmas_5d
                timesteps = 0.25 * sigmas.log()  # EDM c_noise

                inp_noisy_latents = noisy_latents / ((sigmas_5d**2 + 1) ** 0.5)

                image_embeddings = encode_image(
                    pixel_values[:, :, 0].float())

                added_time_ids = torch.stack([
                    torch.tensor([6.0]).expand(bsz),
                    torch.tensor([127.0]).expand(bsz),
                    cond_sigmas.cpu(),
                ], dim=1).to(device, dtype=weight_dtype)
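                # SVD micro-conditioning triplet (fps, motion_bucket_id,
                # noise_aug_strength); 6.0 and 127.0 match the released
                # pipeline's defaults, and the third entry is the sigma used to
                # noise the conditioning frame above.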

                # Concat conditioning latent
                conditional_latents_expanded = conditional_latents.expand(-1, -1, T, -1, -1)
                inp = torch.cat([inp_noisy_latents, conditional_latents_expanded], dim=1)
                inp = inp.permute(0, 2, 1, 3, 4)  # (B, T, 8, h, w)

                # UNet forward
                captured.clear()
                model_pred = unet(inp, timesteps, encoder_hidden_states=image_embeddings,
                                  added_time_ids=added_time_ids).sample

                # Diffusion loss, EDM denoiser parameterization as in train_svd.py:
                # D(x; sigma) = c_skip * x + c_out * F(x), trained against the
                # clean latents with weight lambda(sigma) = (1 + sigma^2) / sigma^2.
                sigmas_bc = sigmas_5d.permute(0, 2, 1, 3, 4)  # align with (B, T, C, h, w)
                c_out = -sigmas_bc / ((sigmas_bc**2 + 1)**0.5)
                c_skip = 1 / (sigmas_bc**2 + 1)
                denoised = model_pred * c_out + c_skip * noisy_latents.permute(0, 2, 1, 3, 4)
                target_latents = latents.permute(0, 2, 1, 3, 4)
                weighting = (1 + sigmas_bc**2) / sigmas_bc**2
                diffusion_loss = (weighting * (denoised - target_latents)**2).mean()

                # --- PARA forward on captured UNet features ---
                feat1 = captured["up_block_1"].float()  # (B*T, 1280, 20, 36)
                feat2 = captured["up_block_2"].float()  # (B*T, 640, 40, 72)

                # Flatten PARA targets to (B*T, ...)
                traj_para_flat = traj_para.reshape(B * T, 2)
                target_h_flat = target_h_bins.reshape(B * T)
                traj_grip_flat = traj_gripper.reshape(B * T)
                traj_euler_flat = traj_euler.reshape(B * T, 3)

                vol_logits, grip_logits, rot_logits = para_heads(
                    feat1, feat2, query_pixels=traj_para_flat)

                volume_loss = compute_volume_loss(vol_logits, traj_para_flat, target_h_flat)
                gripper_loss = compute_gripper_loss(grip_logits, traj_grip_flat,
                                                    MIN_GRIPPER, MAX_GRIPPER)
                rotation_loss = compute_rotation_loss(rot_logits, traj_euler_flat,
                                                      min_r_t, max_r_t)

                # --- EMA adaptive loss weighting ---
                raw = {'vol': volume_loss.item(), 'grip': gripper_loss.item(),
                       'diff': diffusion_loss.item()}
                for k in raw:
                    if k not in loss_emas:
                        loss_emas[k] = raw[k]
                    loss_emas[k] = (1 - EMA_ALPHA) * loss_emas[k] + EMA_ALPHA * raw[k]

                active = [k for k in raw if loss_emas.get(k, 0) > 1e-10]
                if active:
                    inv_sum = sum(1.0 / (loss_emas[k] + 1e-8) for k in active)
                    n = len(active)
                    weights = {k: (n / inv_sum) / (loss_emas[k] + 1e-8) for k in active}
                else:
                    weights = {k: 1.0 for k in raw}
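                # By construction w_k * EMA_k = n / inv_sum for every active
                # loss, so each term contributes equally to total_loss on
                # average. E.g. EMAs {diff: 1.0, vol: 0.5} give weights
                # {diff: 2/3, vol: 4/3}.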

                total_loss = (weights.get('vol', 1) * volume_loss +
                              weights.get('grip', 1) * gripper_loss +
                              weights.get('diff', 1) * diffusion_loss)

                accelerator.backward(total_loss)
                if accelerator.sync_gradients:
                    all_p_flat = list(unet.parameters()) + list(para_heads.parameters())
                    accelerator.clip_grad_norm_(all_p_flat, 1.0)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            # --- Logging ---
            if accelerator.is_main_process:
                if global_step % 5 == 0:
                    wandb.log({
                        "train/total_loss": total_loss.item(),
                        "train/volume_loss": volume_loss.item(),
                        "train/gripper_loss": gripper_loss.item(),
                        "train/rotation_loss": rotation_loss.item(),
                        "train/diffusion_loss": diffusion_loss.item(),
                        "train/w_vol": weights.get('vol', 1),
                        "train/w_grip": weights.get('grip', 1),
                        "train/w_diff": weights.get('diff', 1),
                    }, step=global_step)

                progress_bar.set_postfix(
                    vol=f"{volume_loss.item():.3f}",
                    grip=f"{gripper_loss.item():.3f}",
                    diff=f"{diffusion_loss.item():.3f}",
                )

                # --- Visualization ---
                if (global_step % args.vis_every == 0 and global_step > 0) or global_step == 5:
                    try:
                        unet_unwrapped = accelerator.unwrap_model(unet)
                        para_unwrapped = accelerator.unwrap_model(para_heads)

                        # --- 1) Training heatmaps on GT frames (first sample in batch) ---
                        vol_det = vol_logits[:T].detach().float()
                        vol_probs = F.softmax(vol_det.reshape(T, -1), dim=1)
                        vol_probs = vol_probs.view(T, N_HEIGHT_BINS, PARA_OUT_SIZE, PARA_OUT_SIZE)
                        heatmaps_train = vol_probs.max(dim=1)[0].cpu().numpy()
                        pred_flat = vol_probs.max(dim=1)[0].view(T, -1).argmax(dim=1)
                        pred_py_train = pred_flat // PARA_OUT_SIZE
                        pred_px_train = pred_flat % PARA_OUT_SIZE
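                        # heatmaps_train: per-pixel max over height bins of the
                        # joint softmax (a 2-D placement-confidence map);
                        # pred_px/py decode the flat spatial argmax to (x, y).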

                        # --- 2) Generate video with fresh fp16 pipeline ---
                        import tempfile
                        tmp_dir = tempfile.mkdtemp()
                        unet_unwrapped.save_pretrained(os.path.join(tmp_dir, "unet"))
                        fresh_unet = UNetSpatioTemporalConditionModel.from_pretrained(
                            tmp_dir, subfolder="unet", torch_dtype=torch.float16)
                        pipe = StableVideoDiffusionPipeline.from_pretrained(
                            args.pretrained, unet=fresh_unet,
                            torch_dtype=torch.float16, variant="fp16")
                        pipe.to(device)

                        # Hook for PARA features on generated video
                        gen_captured = {}
                        def gen_hook(name):
                            def fn(mod, inp, out):
                                gen_captured[name] = (out[0] if isinstance(out, tuple) else out).detach().float()
                            return fn
                        h1 = fresh_unet.up_blocks[1].register_forward_hook(gen_hook("ub1"))
                        h2 = fresh_unet.up_blocks[2].register_forward_hook(gen_hook("ub2"))

                        first_frame_raw = batch['rgb_frames_raw'][0, 0].cpu().numpy()
                        first_pil = Image.fromarray((first_frame_raw * 255).astype(np.uint8)).resize(
                            (args.width, args.height))

                        with torch.inference_mode():
                            gen_pil = pipe(first_pil, height=args.height, width=args.width,
                                         num_frames=args.num_frames, decode_chunk_size=4,
                                         num_inference_steps=25).frames[0]

                        h1.remove(); h2.remove()

                        # Run PARA on generated video features
                        gen_heatmaps = None
                        gen_pred_px = gen_pred_py = None
                        if "ub1" in gen_captured and "ub2" in gen_captured:
                            with torch.no_grad():
                                gen_vol, _, _ = para_unwrapped(gen_captured["ub1"], gen_captured["ub2"])
                            n_gen = min(gen_vol.shape[0], len(gen_pil), T)
                            gv_probs = F.softmax(gen_vol[:n_gen].reshape(n_gen, -1), dim=1)
                            gv_probs = gv_probs.view(n_gen, N_HEIGHT_BINS, PARA_OUT_SIZE, PARA_OUT_SIZE)
                            gen_heatmaps = gv_probs.max(dim=1)[0].cpu().numpy()
                            gf = gv_probs.max(dim=1)[0].view(n_gen, -1).argmax(dim=1)
                            gen_pred_py = gf // PARA_OUT_SIZE
                            gen_pred_px = gf % PARA_OUT_SIZE

                        del pipe, fresh_unet
                        shutil.rmtree(tmp_dir, ignore_errors=True)
                        gen_captured.clear()
                        torch.cuda.empty_cache()

                        # --- 3) Build visualization: GT w/ heatmap | Gen w/ heatmap ---
                        vis_frames = []
                        n_vis = min(T, len(gen_pil), 7)
                        for t in range(n_vis):
                            gt_x = int(traj_2d[0, t, 0].item())
                            gt_y = int(traj_2d[0, t, 1].item())

                            # Left: GT frame + training heatmap
                            gt_frame = (batch['rgb_frames_raw'][0, t].cpu().numpy() * 255).astype(np.uint8).copy()
                            ht = heatmaps_train[t]
                            ht_n = (ht - ht.min()) / (ht.max() - ht.min() + 1e-8)
                            ht_up = cv2.resize(ht_n, (448, 448))
                            ht_c = cv2.applyColorMap((ht_up * 255).astype(np.uint8), cv2.COLORMAP_JET)
                            ht_c = cv2.cvtColor(ht_c, cv2.COLOR_BGR2RGB)
                            left = (gt_frame * 0.5 + ht_c * 0.5).astype(np.uint8)
                            cv2.circle(left, (gt_x, gt_y), 8, (0, 255, 255), 3)
                            px_t = int(pred_px_train[t].item() / coord_scale)
                            py_t = int(pred_py_train[t].item() / coord_scale)
                            cv2.circle(left, (px_t, py_t), 8, (255, 0, 0), 3)
                            cv2.putText(left, f"GT t={t} PRED=red GT=cyan", (5, 25),
                                       cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 2)

                            # Right: Generated frame + gen heatmap
                            gen_np = cv2.resize(np.array(gen_pil[t]), (448, 448))
                            if gen_heatmaps is not None and t < len(gen_heatmaps):
                                hg = gen_heatmaps[t]
                                hg_n = (hg - hg.min()) / (hg.max() - hg.min() + 1e-8)
                                hg_up = cv2.resize(hg_n, (448, 448))
                                hg_c = cv2.applyColorMap((hg_up * 255).astype(np.uint8), cv2.COLORMAP_JET)
                                hg_c = cv2.cvtColor(hg_c, cv2.COLOR_BGR2RGB)
                                right = (gen_np * 0.5 + hg_c * 0.5).astype(np.uint8)
                                if gen_pred_px is not None:
                                    gpx = int(gen_pred_px[t].item() / coord_scale)
                                    gpy = int(gen_pred_py[t].item() / coord_scale)
                                    cv2.circle(right, (gpx, gpy), 8, (255, 0, 0), 3)
                                cv2.circle(right, (gt_x, gt_y), 8, (0, 255, 255), 2)
                            else:
                                right = gen_np
                            cv2.putText(right, f"Gen t={t} PRED=red GT=cyan", (5, 25),
                                       cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 2)

                            combined = np.concatenate([left, right], axis=1)
                            vis_frames.append(combined)

                        if vis_frames:
                            vis_vid = np.stack(vis_frames).transpose(0, 3, 1, 2)
                            wandb.log({"vis/gt_vs_gen_heatmaps": wandb.Video(
                                vis_vid, fps=2, format="mp4")}, step=global_step)
                            logger.info(f"Logged vis at step {global_step}")

                    except Exception as e:
                        import traceback
                        logger.warning(f"Vis failed at step {global_step}: {e}\n{traceback.format_exc()}")

                # --- Checkpoint ---
                if global_step > 0 and global_step % args.ckpt_every == 0:
                    save_dir = Path(args.output_dir) / f"checkpoint-{global_step}"
                    save_dir.mkdir(parents=True, exist_ok=True)
                    unet_unwrapped = accelerator.unwrap_model(unet)
                    unet_unwrapped.save_pretrained(save_dir / "unet")
                    torch.save({
                        "para_heads": accelerator.unwrap_model(para_heads).state_dict(),
                        "optimizer": optimizer.state_dict(),
                        "stats": stats,
                        "global_step": global_step,
                    }, save_dir / "para_checkpoint.pt")
                    logger.info(f"Saved checkpoint at step {global_step}")

            global_step += 1
            progress_bar.update(1)

        if global_step >= args.max_steps:
            break

    if accelerator.is_main_process:
        wandb.finish()
    logger.info("Training complete!")


if __name__ == "__main__":
    main()
