"""Joint SVD video diffusion + Global Action Regression training.

Baseline comparison for PARA: same SVD UNet features (up_block_1 + up_block_2),
same projection + conv layers, but replaces spatial heatmap prediction with
global average pooling + MLP regression to (x, y, z, gripper) directly.

Usage:
    accelerate launch --config_file ... train_svd_global_action_regressor.py ...
"""

import argparse
import json
import math
import os
import random
import sys
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from torch.utils.data import Dataset, DataLoader
import torchvision

import accelerate
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
from tqdm.auto import tqdm

import diffusers
from diffusers import AutoencoderKLTemporalDecoder
from diffusers.optimization import get_scheduler
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

sys.path.insert(0, str(Path(__file__).parent))
sys.path.insert(0, "/data/cameron/para_videopolicy")
from svd.models import UNetSpatioTemporalConditionModel
from data import CachedTrajectoryDataset

import wandb

logger = get_logger(__name__, log_level="INFO")

# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
PRERENDER_SIZE = 448
EMA_ALPHA = 0.01
PROJ_DIM = 128
N_GRIPPER_BINS = 32  # keep binned gripper for fair comparison

# Dataset stats (filled at runtime)
MIN_POS = MAX_POS = None  # 3D position bounds
MIN_GRIPPER = MAX_GRIPPER = 0.0


# ---------------------------------------------------------------------------
# Global Action Regression Head
# ---------------------------------------------------------------------------
class GlobalActionHead(nn.Module):
    """Global regression head on SVD UNet features.

    Same feature extraction as PARA (proj + concat + conv), but replaces
    spatial heatmap with global avg pool + MLP for direct (x,y,z,gripper).
    """

    def __init__(self, proj_dim=PROJ_DIM, hidden_dim=256):
        super().__init__()
        # Same projection as PARA
        self.proj_block1 = nn.Conv2d(1280, proj_dim, 1)
        self.proj_block2 = nn.Conv2d(640, proj_dim, 1)

        D = proj_dim * 2  # 256
        # Same conv refinement as PARA
        self.feature_convs = nn.Sequential(
            nn.Conv2d(D, D, 3, padding=1), nn.GELU(),
            nn.Conv2d(D, D, 3, padding=1), nn.GELU(),
            nn.Conv2d(D, D, 3, padding=1), nn.GELU(),
        )

        # Global avg pool → MLP
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.position_mlp = nn.Sequential(
            nn.Linear(D, hidden_dim), nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim), nn.GELU(),
            nn.Linear(hidden_dim, 3),  # (x, y, z) in normalized coords
        )
        self.gripper_mlp = nn.Sequential(
            nn.Linear(D, hidden_dim), nn.GELU(),
            nn.Linear(hidden_dim, N_GRIPPER_BINS),
        )

    def forward(self, feat_block1, feat_block2):
        """
        Args:
            feat_block1: (B*T, 1280, 20, 36)
            feat_block2: (B*T, 640, 40, 72)
        Returns:
            pos_pred: (B*T, 3) predicted (x, y, z), regressed against targets
                normalized to [0, 1] (the raw output is unconstrained)
            grip_logits: (B*T, N_GRIPPER_BINS)
        """
        P = 64  # upsample target (same as PARA)

        f1 = self.proj_block1(feat_block1)
        f1 = F.interpolate(f1, size=(P, P), mode='bilinear', align_corners=False)

        f2 = self.proj_block2(feat_block2)
        f2 = F.interpolate(f2, size=(P, P), mode='bilinear', align_corners=False)

        feats = torch.cat([f1, f2], dim=1)  # (B*T, 256, P, P)
        feats = self.feature_convs(feats)

        pooled = self.pool(feats).squeeze(-1).squeeze(-1)  # (B*T, 256)

        pos_pred = self.position_mlp(pooled)  # (B*T, 3)
        grip_logits = self.gripper_mlp(pooled)  # (B*T, N_GRIPPER_BINS)

        return pos_pred, grip_logits
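
# Shape sanity sketch (feature sizes taken from forward()'s docstring; a smoke
# check, not part of training):
#   head = GlobalActionHead()
#   f1 = torch.randn(2, 1280, 20, 36)   # stands in for up_block_1 features
#   f2 = torch.randn(2, 640, 40, 72)    # stands in for up_block_2 features
#   pos, grip = head(f1, f2)            # pos: (2, 3), grip: (2, N_GRIPPER_BINS)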


# ---------------------------------------------------------------------------
# Loss functions
# ---------------------------------------------------------------------------
def discretize(values, min_v, max_v, n_bins):
    norm = (values - min_v) / (max_v - min_v + 1e-8)
    return (norm.clamp(0, 1) * (n_bins - 1)).long().clamp(0, n_bins - 1)
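# e.g. min_v=0.0, max_v=1.0, n_bins=32: a value of 0.52 normalizes to 0.52,
# scales to 0.52 * 31 = 16.12, and truncates via .long() to bin 16.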

def compute_position_loss(pred_pos, target_pos, min_pos, max_pos):
    """L1 loss on normalized 3D positions."""
    target_norm = (target_pos - min_pos) / (max_pos - min_pos + 1e-8)
    return F.l1_loss(pred_pos, target_norm)

def compute_gripper_loss(logits, target, min_g, max_g):
    target_bins = discretize(target, min_g, max_g, N_GRIPPER_BINS)
    return F.cross_entropy(logits, target_bins)
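
# Hedged inverse of discretize() for inference-time decoding. Hypothetical
# helper, not called anywhere in this script: it maps predicted bin indices
# (e.g. grip_logits.argmax(-1)) back to continuous values, up to the
# quantization error of the binning above.
def undiscretize(bins, min_v, max_v, n_bins=N_GRIPPER_BINS):
    norm = bins.float() / (n_bins - 1)
    return norm * (max_v - min_v) + min_v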


# ---------------------------------------------------------------------------
# Diffusion helpers
# ---------------------------------------------------------------------------
def rand_log_normal(shape, loc=0., scale=1., device='cpu', dtype=torch.float32):
    u = torch.rand(shape, dtype=dtype, device=device) * (1 - 2e-7) + 1e-7
    return (torch.erfinv(2 * u - 1) * math.sqrt(2.0) * scale + loc).exp()
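# Inverse-CDF sampling: for u ~ U(0,1), erfinv(2u - 1) * sqrt(2) is standard
# normal, so the result is exp(loc + scale * z), i.e. log-normal. The tiny
# offsets on u keep erfinv away from its singularities at +/-1.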

def tensor_to_vae_latent(t, vae):
    # (B, C, F, H, W) -> encode each frame independently -> (B, C_lat, F, h, w)
    b, c, f, h, w = t.shape
    video = t.permute(0, 2, 1, 3, 4).reshape(b * f, c, h, w)
    latents = vae.encode(video).latent_dist.sample()
    latents = latents.reshape(b, f, *latents.shape[1:])
    latents = latents.permute(0, 2, 1, 3, 4)
    # Scale into the diffusion latent space; the conditioning path in the
    # training loop divides this back out to recover raw latents.
    return latents * vae.config.scaling_factor

def _resize_with_antialiasing(input, size, interpolation="bicubic", align_corners=True):
    """Gaussian-blur (anti-alias) then resize, mirroring the SVD pipeline helper."""
    h, w = input.shape[-2:]
    factors = (h / size[0], w / size[1])
    sigmas = (max((factors[0] - 1.0) / 2.0, 0.001), max((factors[1] - 1.0) / 2.0, 0.001))
    ks = int(max(2.0 * 2 * sigmas[0], 3)), int(max(2.0 * 2 * sigmas[1], 3))
    if (ks[0] % 2) == 0: ks = ks[0] + 1, ks[1]
    if (ks[1] % 2) == 0: ks = ks[0], ks[1] + 1
    input = F.pad(input, [ks[1]//2]*2 + [ks[0]//2]*2, mode="reflect")
    kernel = torch.ones(1, dtype=input.dtype, device=input.device)
    for s, k in zip(sigmas, ks):
        coord = torch.arange(k, dtype=input.dtype, device=input.device) - (k - 1) / 2
        g = torch.exp(-coord**2 / (2*s**2))
        g /= g.sum()
        kernel = kernel.unsqueeze(-1) * g.unsqueeze(0)
    kernel = kernel.expand(input.shape[-3], -1, -1).unsqueeze(1)
    input = F.conv2d(input, kernel, groups=input.shape[-3])
    return F.interpolate(input, size=size, mode=interpolation, align_corners=align_corners)


# ---------------------------------------------------------------------------
# SVD Joint Dataset wrapper (same as PARA version)
# ---------------------------------------------------------------------------
class SVDGlobalDataset(Dataset):
    """Wraps CachedTrajectoryDataset for SVD + global action regression."""

    def __init__(self, cache_root, benchmark, task_ids, width=576, height=320,
                 sample_frames=7, image_size=PRERENDER_SIZE, frame_stride=1):
        self.para_dataset = CachedTrajectoryDataset(
            cache_root=cache_root, benchmark_name=benchmark,
            task_ids=task_ids, image_size=image_size,
            n_window=sample_frames, frame_stride=frame_stride,
        )
        self.width = width
        self.height = height
        self.sample_frames = sample_frames

    def __len__(self):
        return len(self.para_dataset)

    def __getitem__(self, idx):
        sample = self.para_dataset[idx]
        rgb_frames = sample['rgb_frames_raw']  # float [0,1], (T, H, W, 3) or (T, 3, H, W)

        if rgb_frames.ndim == 4 and rgb_frames.shape[-1] == 3:
            frames = rgb_frames.permute(0, 3, 1, 2)  # channels-last -> (T, 3, H, W)
        else:
            frames = rgb_frames

        frames_resized = F.interpolate(frames, size=(self.height, self.width),
                                       mode='bilinear', align_corners=False)
        video = frames_resized * 2.0 - 1.0
        video = video.permute(1, 0, 2, 3)  # (3, T, H, W)

        return {
            'pixel_values': video,
            'trajectory_3d': sample['trajectory_3d'],
            'trajectory_gripper': sample['trajectory_gripper'],
            'rgb_frames_raw': rgb_frames,
        }
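
# Item sketch (shapes follow the transforms above; arguments are the script's
# CLI defaults; a smoke check, not used during training):
#   ds = SVDGlobalDataset("/data/libero/ood_objpos_v3", "libero_spatial", [0])
#   item = ds[0]
#   item['pixel_values'].shape    # (3, sample_frames, height, width), in [-1, 1]
#   item['trajectory_3d'].shape   # (n_frames, 3); sliced to the first T frames
#                                 # in the training loop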


# ---------------------------------------------------------------------------
# Dataset stats
# ---------------------------------------------------------------------------
def compute_dataset_stats(dataset, n=500):
    rng = random.Random(42)
    indices = rng.sample(range(len(dataset)), min(n, len(dataset)))
    all_pos, all_g = [], []
    for idx in tqdm(indices, desc="Stats", leave=False):
        try:
            s = dataset.para_dataset[idx]
        except Exception:
            continue
        all_pos.append(s['trajectory_3d'].numpy())
        all_g.extend(s['trajectory_gripper'].numpy().tolist())
    pos = np.concatenate(all_pos, 0)
    g = np.array(all_g)
    return {
        "min_pos": pos.min(0).tolist(), "max_pos": pos.max(0).tolist(),
        "min_gripper": float(g.min()), "max_gripper": float(g.max()),
    }
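
# Shape of the resulting dataset_stats.json (keys from the dict above; values
# depend on the cache):
# {
#   "min_pos": [x_min, y_min, z_min], "max_pos": [x_max, y_max, z_max],
#   "min_gripper": g_min, "max_gripper": g_max
# }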


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def main():
    p = argparse.ArgumentParser()
    p.add_argument("--pretrained", type=str, default="checkpoints/stable-video-diffusion-img2vid-xt-1-1")
    p.add_argument("--pretrain_unet", type=str, default=None)
    p.add_argument("--cache_root", type=str, default="/data/libero/ood_objpos_v3")
    p.add_argument("--task_ids", type=str, default="0")
    p.add_argument("--width", type=int, default=576)
    p.add_argument("--height", type=int, default=320)
    p.add_argument("--num_frames", type=int, default=7)
    p.add_argument("--frame_stride", type=int, default=1)
    p.add_argument("--batch_size", type=int, default=1)
    p.add_argument("--lr", type=float, default=5e-5)
    p.add_argument("--freeze_unet", action="store_true")
    p.add_argument("--max_steps", type=int, default=999999)
    p.add_argument("--output_dir", type=str, default="output_svd_global_action")
    p.add_argument("--ckpt_every", type=int, default=1000)
    p.add_argument("--vis_every", type=int, default=200)
    p.add_argument("--seed", type=int, default=123)
    args = p.parse_args()

    project_config = ProjectConfiguration(project_dir=args.output_dir)
    kwargs = accelerate.DistributedDataParallelKwargs(find_unused_parameters=True)
    accelerator = Accelerator(
        project_config=project_config,
        kwargs_handlers=[kwargs],
    )

    if accelerator.is_main_process:
        os.makedirs(args.output_dir, exist_ok=True)
        wandb.init(project="svd_global_action", config=vars(args))

    set_seed(args.seed)
    device = accelerator.device
    weight_dtype = torch.bfloat16

    # --- Load models ---
    noise_scheduler = diffusers.EulerDiscreteScheduler.from_pretrained(
        args.pretrained, subfolder="scheduler")
    vae = AutoencoderKLTemporalDecoder.from_pretrained(
        args.pretrained, subfolder="vae").to(device, dtype=weight_dtype)
    vae.requires_grad_(False)
    vae.eval()

    image_encoder = CLIPVisionModelWithProjection.from_pretrained(
        args.pretrained, subfolder="image_encoder").to(device, dtype=weight_dtype)
    image_encoder.requires_grad_(False)
    image_encoder.eval()

    feature_extractor = CLIPImageProcessor.from_pretrained(
        args.pretrained, subfolder="feature_extractor")

    # Load a fine-tuned UNet directory directly, or fall back to the "unet"
    # subfolder of the base checkpoint.
    unet_path = args.pretrain_unet or os.path.join(args.pretrained, "unet")
    unet = UNetSpatioTemporalConditionModel.from_pretrained(
        unet_path, subfolder="unet" if "unet" not in unet_path else None,
    ).to(device, dtype=weight_dtype)
    unet.enable_gradient_checkpointing()
    unet.train()

    # --- Global action head ---
    action_head = GlobalActionHead().to(device)
    action_head.train()

    # --- Feature hooks ---
    captured = {}
    def hook(name):
        def fn(mod, inp, out):
            captured[name] = (out[0] if isinstance(out, tuple) else out).detach().float()
        return fn
    unet.up_blocks[1].register_forward_hook(hook("up_block_1"))
    unet.up_blocks[2].register_forward_hook(hook("up_block_2"))
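    # The hooks capture features with .detach(), so the action losses never
    # backpropagate into the UNet: the UNet is shaped by the diffusion loss
    # alone (or left untouched with --freeze_unet).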

    # --- CLIP image encoder helper ---
    @torch.no_grad()
    def encode_image(pixel_values_rgb):
        pv = _resize_with_antialiasing(pixel_values_rgb, (224, 224))
        pv = (pv + 1.0) / 2.0
        pv = torchvision.transforms.functional.normalize(
            pv, [0.48145466, 0.4578275, 0.40821073], [0.26862954, 0.26130258, 0.27577711])
        return image_encoder(pv.to(device, weight_dtype)).image_embeds.unsqueeze(1)

    # --- Dataset ---
    task_ids = [int(x) for x in args.task_ids.split(",")]
    benchmark = "libero_spatial"
    dataset = SVDGlobalDataset(
        cache_root=args.cache_root, benchmark=benchmark, task_ids=task_ids,
        width=args.width, height=args.height,
        sample_frames=args.num_frames, frame_stride=args.frame_stride,
    )

    stats_path = Path(args.output_dir) / "dataset_stats.json"
    if stats_path.exists():
        stats = json.loads(stats_path.read_text())
    else:
        stats = compute_dataset_stats(dataset)
        if accelerator.is_main_process:
            stats_path.parent.mkdir(exist_ok=True, parents=True)
            stats_path.write_text(json.dumps(stats, indent=2))

    global MIN_POS, MAX_POS, MIN_GRIPPER, MAX_GRIPPER
    MIN_POS = torch.tensor(stats["min_pos"], device=device, dtype=torch.float32)
    MAX_POS = torch.tensor(stats["max_pos"], device=device, dtype=torch.float32)
    MIN_GRIPPER, MAX_GRIPPER = stats["min_gripper"], stats["max_gripper"]

    logger.info(f"Dataset: {len(dataset)} samples, Pos: [{MIN_POS.cpu().numpy()}, {MAX_POS.cpu().numpy()}]")

    dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True,
                            num_workers=4, pin_memory=True, drop_last=True)

    # --- Optimizer ---
    if args.freeze_unet:
        unet.requires_grad_(False)
        unet.eval()
        all_params = [{"params": action_head.parameters(), "lr": args.lr * 2.0}]
        logger.info("UNet FROZEN — training action head only")
    else:
        all_params = [
            {"params": unet.parameters(), "lr": args.lr * 0.02},
            {"params": action_head.parameters(), "lr": args.lr * 2.0},
        ]
    optimizer = torch.optim.AdamW(all_params, weight_decay=1e-4)
    lr_scheduler = get_scheduler("constant", optimizer=optimizer, num_warmup_steps=0,
                                  num_training_steps=args.max_steps)

    unet, action_head, optimizer, dataloader, lr_scheduler = accelerator.prepare(
        unet, action_head, optimizer, dataloader, lr_scheduler)

    # --- EMA loss weighting ---
    loss_emas = {}

    global_step = 0
    progress_bar = tqdm(range(args.max_steps), disable=not accelerator.is_main_process)

    for epoch in range(99999):
        for batch in dataloader:
            if global_step >= args.max_steps:
                break

            with accelerator.accumulate(unet, action_head):
                pixel_values = batch['pixel_values'].to(weight_dtype)
                B, C, T, H, W = pixel_values.shape

                # Action targets
                traj_3d = batch['trajectory_3d'].to(device)[:, :T]  # (B, T, 3)
                traj_gripper = batch['trajectory_gripper'].to(device)[:, :T]

                # --- SVD diffusion forward ---
                conditional_pixel_values = pixel_values[:, :, 0:1]
                latents = tensor_to_vae_latent(pixel_values, vae)

                noise = torch.randn_like(latents)
                bsz = latents.shape[0]

                cond_sigmas = rand_log_normal(shape=[bsz], loc=-3.0, scale=0.5).to(latents)
                cond_sigmas_5d = cond_sigmas[:, None, None, None, None]
                conditional_pixel_values_noised = \
                    torch.randn_like(conditional_pixel_values) * cond_sigmas_5d + conditional_pixel_values
                conditional_latents = tensor_to_vae_latent(conditional_pixel_values_noised, vae)
                conditional_latents = conditional_latents[:, :, 0:1]
                # Undo the latent scaling: the UNet is conditioned on raw
                # (unscaled) first-frame latents.
                conditional_latents = conditional_latents / vae.config.scaling_factor

                sigmas = rand_log_normal(shape=[bsz], loc=0.7, scale=1.6).to(latents.device)
                sigmas_5d = sigmas[:, None, None, None, None]
                noisy_latents = latents + noise * sigmas_5d
                # EDM noise conditioning: c_noise(sigma) = 0.25 * ln(sigma).
                timesteps = (0.25 * sigmas.log()).to(device, dtype=weight_dtype)

                inp_noisy_latents = noisy_latents / ((sigmas_5d**2 + 1) ** 0.5)

                image_embeddings = encode_image(pixel_values[:, :, 0])

                added_time_ids = torch.stack([
                    torch.tensor([6.0]).expand(bsz),
                    torch.tensor([127.0]).expand(bsz),
                    cond_sigmas.cpu(),
                ], dim=1).to(device, dtype=weight_dtype)
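                # added_time_ids is SVD's micro-conditioning triple
                # (fps=6, motion_bucket_id=127, noise_aug_strength), with the
                # sampled conditioning-noise sigma standing in for the
                # augmentation strength.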

                conditional_latents_expanded = conditional_latents.expand(-1, -1, T, -1, -1)
                inp = torch.cat([inp_noisy_latents, conditional_latents_expanded], dim=1)
                inp = inp.permute(0, 2, 1, 3, 4)

                # UNet forward
                captured.clear()
                model_pred = unet(inp, timesteps, encoder_hidden_states=image_embeddings,
                                  added_time_ids=added_time_ids).sample

                # Diffusion loss
                # sigmas_5d is (B, 1, 1, 1, 1), so it broadcasts over the
                # frame-first (B, T, C, H, W) layout without any permute.
                sigmas_bc = sigmas_5d
                c_out = -sigmas_bc / ((sigmas_bc**2 + 1)**0.5)
                c_skip = 1 / (sigmas_bc**2 + 1)
                denoised = model_pred * c_out + c_skip * noisy_latents.permute(0, 2, 1, 3, 4)
                target_latents = latents.permute(0, 2, 1, 3, 4)
                weighting = (sigmas_bc**2 + 1) / sigmas_bc**2
                diffusion_loss = (weighting * (denoised - target_latents)**2).mean()
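                # The above is EDM-style preconditioning with sigma_data = 1
                # (the negative c_out matches SVD's v-parameterization):
                #   denoised = c_skip(sigma) * x_noisy + c_out(sigma) * F(x),
                #   c_skip = 1/(sigma^2+1), c_out = -sigma/sqrt(sigma^2+1),
                # trained with weight (sigma^2+1)/sigma^2 on the squared error
                # to the clean latents.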

                # --- Global action head forward ---
                feat1 = captured["up_block_1"].float()
                feat2 = captured["up_block_2"].float()

                traj_3d_flat = traj_3d.reshape(B * T, 3)
                traj_grip_flat = traj_gripper.reshape(B * T)

                pos_pred, grip_logits = action_head(feat1, feat2)

                position_loss = compute_position_loss(pos_pred, traj_3d_flat, MIN_POS, MAX_POS)
                gripper_loss = compute_gripper_loss(grip_logits, traj_grip_flat,
                                                    MIN_GRIPPER, MAX_GRIPPER)

                # --- EMA adaptive loss weighting ---
                raw = {'pos': position_loss.item(), 'grip': gripper_loss.item(),
                       'diff': diffusion_loss.item()}
                for k in raw:
                    if k not in loss_emas:
                        loss_emas[k] = raw[k]
                    loss_emas[k] = (1 - EMA_ALPHA) * loss_emas[k] + EMA_ALPHA * raw[k]

                active = [k for k in raw if loss_emas.get(k, 0) > 1e-10]
                if active:
                    inv_sum = sum(1.0 / (loss_emas[k] + 1e-8) for k in active)
                    n = len(active)
                    weights = {k: (n / inv_sum) / (loss_emas[k] + 1e-8) for k in active}
                else:
                    weights = {k: 1.0 for k in raw}

                total_loss = (weights.get('pos', 1) * position_loss +
                              weights.get('grip', 1) * gripper_loss +
                              weights.get('diff', 1) * diffusion_loss)
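                # Worked example of the inverse-EMA weighting: with EMAs
                # pos=0.02, grip=2.0, diff=0.2, the inverse magnitudes are
                # 50, 0.5, 5 (sum 55.5), so weight_k = (3 / 55.5) / ema_k.
                # Weights sum to exactly n = 3 and each weighted term lands
                # near the common scale 3 / 55.5 ~= 0.054.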

                accelerator.backward(total_loss)
                if accelerator.sync_gradients:
                    all_p_flat = list(unet.parameters()) + list(action_head.parameters())
                    accelerator.clip_grad_norm_(all_p_flat, 1.0)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            # --- Logging ---
            if accelerator.is_main_process:
                if global_step % 5 == 0:
                    wandb.log({
                        "train/total_loss": total_loss.item(),
                        "train/position_loss": position_loss.item(),
                        "train/gripper_loss": gripper_loss.item(),
                        "train/diffusion_loss": diffusion_loss.item(),
                        "train/w_pos": weights.get('pos', 1),
                        "train/w_grip": weights.get('grip', 1),
                        "train/w_diff": weights.get('diff', 1),
                    }, step=global_step)

                progress_bar.set_postfix(
                    pos=f"{position_loss.item():.4f}",
                    grip=f"{gripper_loss.item():.3f}",
                    diff=f"{diffusion_loss.item():.3f}",
                )

                # --- Checkpoint ---
                if global_step > 0 and global_step % args.ckpt_every == 0:
                    save_dir = Path(args.output_dir) / f"checkpoint-{global_step}"
                    save_dir.mkdir(exist_ok=True)
                    unet_unwrapped = accelerator.unwrap_model(unet)
                    unet_unwrapped.save_pretrained(save_dir / "unet")
                    torch.save({
                        "action_head": accelerator.unwrap_model(action_head).state_dict(),
                        "optimizer": optimizer.state_dict(),
                        "stats": stats,
                        "global_step": global_step,
                    }, save_dir / "action_checkpoint.pt")
                    logger.info(f"Saved checkpoint at step {global_step}")

            global_step += 1
            progress_bar.update(1)

        if global_step >= args.max_steps:
            break

    if accelerator.is_main_process:
        wandb.finish()
    logger.info("Training complete!")


if __name__ == "__main__":
    main()
