"""Train DinoVolumeFiLM on libero with long context.

Long-context variant: n_window=48 with frame_stride=2 covers ~96 frames per sample,
which is most or all of a typical libero demo (80-100 frames). The FiLM volume head
replaces the rank-1 bilinear scoring with a (t,z)-conditioned bottleneck MLP per
voxel, modulated AdaLN-Zero style (Peebles & Xie, DiT). Gripper and rotation are
read from the MLP's penultimate features at the per-(b,t) argmax voxel; this script
always passes the ground-truth voxel as kp_zyx (teacher forcing).

Headline metric: train/pix_argmax (the validation split is multimodal, so val metrics are less informative).
"""
import os, sys, time, json, argparse
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from pathlib import Path
from tqdm import tqdm
import wandb
import cv2

sys.path.insert(0, os.path.dirname(__file__))
from data import CachedTrajectoryDataset
from model_dino_volume_film import (DinoVolumeFiLM, IMG_SIZE, N_WINDOW, N_HEIGHT_BINS,
                                     N_GRIPPER_BINS, N_ROT_BINS, PRED_SIZE)


# ---------------- Discretisation helpers ----------------

def discretize(values, lo, hi, n_bins):
    """Map continuous tensor ``values`` onto integer bin indices in ``[0, n_bins - 1]``.

    ``values`` is rescaled linearly so that ``lo`` -> 0 and ``hi`` -> 1, clamped
    to the unit interval, multiplied by ``n_bins - 1`` and truncated to ``long``.
    If ``hi - lo`` is below 1e-8, 1e-8 is used as the denominator instead.
    Because of the truncation, only values at or above ``hi`` reach the top bin.
    """
    span = max(hi - lo, 1e-8)
    unit = ((values - lo) / span).clamp(0, 1)
    top_bin = n_bins - 1
    return (unit * top_bin).long().clamp(0, top_bin)


def compute_dataset_stats(ds):
    """One-shot scan of the cached dataset for z/grip/rot ranges.

    Args:
        ds: dataset exposing ``demos``, an iterable of dicts whose ``"eef_pos"``
            is indexed as ``[:, 2]`` for height (z) and whose ``"gripper"``
            holds the gripper values.

    Returns:
        dict with ``min_height``/``max_height`` and ``min_grip``/``max_grip``
        (the observed range widened by 5% of its span on each side) plus fixed
        ``min_rot``/``max_rot`` lists of -/+3.14159 for the three euler axes.

    Raises:
        ValueError: if ``ds.demos`` is empty.
    """
    demos = list(ds.demos)
    if not demos:
        # np.concatenate would also raise ValueError, but with an opaque message.
        raise ValueError("compute_dataset_stats: ds.demos is empty; cannot compute ranges")

    def _padded_range(values, frac=0.05):
        # Observed [min, max] widened by `frac` of its span on each side.
        lo, hi = float(values.min()), float(values.max())
        pad = (hi - lo) * frac
        return lo - pad, hi + pad

    z = np.concatenate([demo["eef_pos"][:, 2] for demo in demos])
    g = np.concatenate([demo["gripper"] for demo in demos])
    min_h, max_h = _padded_range(z)
    min_g, max_g = _padded_range(g)
    return {"min_height": min_h, "max_height": max_h,
            "min_grip":   min_g, "max_grip":   max_g,
            "min_rot": [-3.14159] * 3, "max_rot": [3.14159] * 3}


# ---------------- Viz helpers ----------------

def rainbow_overlay(rgb, pred_pix, gt_pix, valid, img_size=IMG_SIZE):
    """Draw predicted vs. ground-truth 2-D keypoints on a de-normalised image.

    rgb: (3, H, W) tensor normalised with the (ImageNet) mean/std below.
    pred_pix, gt_pix: (T, 2) arrays of (x, y) pixel coordinates.
    valid: (T,) booleans; frames where it is falsy are skipped.
    Predictions are filled radius-4 dots whose OpenCV hue ramps 0 -> 170 over
    time; ground truth is an outlined white radius-4 circle. ``img_size`` is
    currently unused. Returns an (H, W, 3) uint8 RGB array.
    """
    mean = np.array([0.485, 0.456, 0.406])[:, None, None]
    std  = np.array([0.229, 0.224, 0.225])[:, None, None]
    chw = np.clip(rgb.cpu().numpy() * std + mean, 0, 1)
    # The transpose is a strided view; cv2 drawing needs a C-contiguous buffer.
    canvas = np.ascontiguousarray((chw.transpose(1, 2, 0) * 255).astype(np.uint8))
    n_steps = pred_pix.shape[0]
    denom = max(n_steps - 1, 1)
    for step in range(n_steps):
        if valid[step]:
            hue = int(step / denom * 170)
            hsv = np.uint8([[[hue, 255, 255]]])
            colour = cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB)[0, 0].tolist()
            pred_xy = (int(pred_pix[step, 0]), int(pred_pix[step, 1]))
            gt_xy = (int(gt_pix[step, 0]), int(gt_pix[step, 1]))
            cv2.circle(canvas, pred_xy, 4, colour, -1)
            cv2.circle(canvas, gt_xy, 4, (255, 255, 255), 1)
    return canvas


def marginal_heatmap(vol_logits):
    """Render each timestep's (H, W) marginal of a volume distribution as a tile grid.

    vol_logits: (T, Z, H, W) logits. For each t, a softmax is taken over the
    joint Z*H*W cells and then summed over Z, giving the (y, x) marginal. Each
    map is scaled by its own max to [0, 255], colour-mapped with JET, resized
    to 96x96 (nearest) and placed row-major in a grid 8 tiles wide; unused grid
    cells stay black. Returns an RGB uint8 array.
    """
    tile, n_cols = 96, 8
    T, Z, H, W = vol_logits.shape
    probs = torch.softmax(vol_logits.reshape(T, Z * H * W), dim=-1)
    marginal = probs.reshape(T, Z, H, W).sum(dim=1)                  # (T, H, W)
    n_rows = -(-T // n_cols)                                          # ceil(T / n_cols)
    grid = np.zeros((n_rows * tile, n_cols * tile, 3), dtype=np.uint8)
    for t in range(T):
        heat = marginal[t].cpu().numpy()
        scaled = (heat / (heat.max() + 1e-8) * 255).astype(np.uint8)
        coloured = cv2.applyColorMap(scaled, cv2.COLORMAP_JET)
        r, c = divmod(t, n_cols)
        grid[r * tile:(r + 1) * tile, c * tile:(c + 1) * tile] = cv2.resize(
            coloured, (tile, tile), interpolation=cv2.INTER_NEAREST)
    # applyColorMap produces BGR; convert the assembled grid once.
    return cv2.cvtColor(grid, cv2.COLOR_BGR2RGB)


def gripper_strip(grip_logits, gt_grip, valid):
    """Stack one bar-chart row per timestep showing the gripper-bin distribution.

    grip_logits: (T, n_bins) logits; gt_grip: (T,) ground-truth bin indices;
    valid: (T,) booleans. Every row is 40 px tall and 5*n_bins + 20 px wide.
    Valid rows draw softmax(grip_logits[t]) as bars scaled so the tallest is
    32 px, plus an outlined rectangle at the ground-truth bin; invalid rows are
    left black. Returns an RGB uint8 array of shape (40*T, 5*n_bins + 20, 3).
    """
    n_steps, n_bins = grip_logits.shape
    width = n_bins * 5 + 20
    strips = []
    for step in range(n_steps):
        row = np.zeros((40, width, 3), dtype=np.uint8)
        if valid[step]:
            probs = torch.softmax(grip_logits[step], dim=-1).cpu().numpy()
            peak = probs.max() + 1e-8
            for k, pk in enumerate(probs):
                bar_h = int(pk / peak * 32)
                x0 = 10 + k * 5
                cv2.rectangle(row, (x0, 35 - bar_h), (x0 + 2, 35), (200, 150, 80), -1)
            x_gt = 10 + int(gt_grip[step]) * 5
            cv2.rectangle(row, (x_gt, 2), (x_gt + 2, 35), (0, 0, 255), 1)
        strips.append(row)
    # Colours above are OpenCV BGR; convert the stacked image once.
    return cv2.cvtColor(np.concatenate(strips, axis=0), cv2.COLOR_BGR2RGB)


# ---------------- Main ----------------

def _discretize_targets(traj2d, traj3d, grip_v, stats):
    """Convert continuous trajectory targets into the model's integer bins.

    Args:
        traj2d: (..., 2) pixel (x, y) coordinates in IMG_SIZE space.
        traj3d: (..., 3) positions; only component 2 (height) is used.
        grip_v: (...) continuous gripper values.
        stats: dict from ``compute_dataset_stats`` with height/grip ranges.

    Returns:
        (gx, gy, gz, gg) long tensors: column and row on the PRED_SIZE grid
        (clamped to the grid), height bin and gripper bin.
    """
    scale = PRED_SIZE / IMG_SIZE
    gx = (traj2d[..., 0] * scale).long().clamp(0, PRED_SIZE - 1)
    gy = (traj2d[..., 1] * scale).long().clamp(0, PRED_SIZE - 1)
    gz = discretize(traj3d[..., 2], stats["min_height"], stats["max_height"], N_HEIGHT_BINS)
    gg = discretize(grip_v, stats["min_grip"], stats["max_grip"], N_GRIPPER_BINS)
    return gx, gy, gz, gg


def _argmax_pix(vol):
    """Decode volume logits to IMG_SIZE-space (x, y) pixels via the joint argmax.

    vol: (..., Z, H, W) logits. The argmax is taken over the flattened Z*H*W
    cells, the height index is dropped, and the (row, col) cell is mapped to
    its *centre* in IMG_SIZE pixels. Mapping to the cell's top-left corner
    would bias every decode by half a cell per axis, because the ground-truth
    cell is chosen by truncating the continuous coordinate.
    Returns a float tensor of shape (..., 2).
    """
    h, w = vol.shape[-2:]
    flat_idx = vol.reshape(*vol.shape[:-3], -1).argmax(dim=-1)
    yx = flat_idx % (h * w)
    py = ((yx // w).float() + 0.5) * (IMG_SIZE / h)
    px = ((yx % w).float() + 0.5) * (IMG_SIZE / w)
    return torch.stack([px, py], dim=-1)


def _log_train_viz(model, train_ds, stats, device, global_step):
    """Log keypoint/heatmap/gripper visualisations for one random training sample.

    Runs a teacher-forced forward pass in eval mode under ``no_grad``. Logs the
    images plus ``train/pix_argmax`` (the mean per-frame pixel error of the
    argmax decode) to wandb at ``global_step``. Train mode is restored on exit.
    """
    model.eval()
    try:
        with torch.no_grad():
            idx = int(torch.randint(0, len(train_ds), (1,)).item())
            s = train_ds[idx]
            v_gx, v_gy, v_gz, v_gg = _discretize_targets(
                s["trajectory_2d"].unsqueeze(0).to(device),
                s["trajectory_3d"].unsqueeze(0).to(device),
                s["trajectory_gripper"].unsqueeze(0).to(device),
                stats)
            v_kp = torch.stack([v_gz, v_gy, v_gx], dim=-1)            # teacher-forced kp
            vo = model(s["rgb"].unsqueeze(0).to(device), kp_zyx=v_kp)
            v_vol = vo["volume_logits"][0]                           # (T, Z, H, W)
            v_grip_logits = vo["gripper_logits"][0]                  # (T, n_grip)
            v_pred_pix = _argmax_pix(v_vol).cpu().numpy()
        v_pix = s["trajectory_2d"].cpu().numpy()
        v_grip = v_gg[0].cpu().numpy()
        # Size the mask from the sample itself rather than assuming args.n_window.
        valid = np.ones(len(v_pix), dtype=bool)
        train_pix_err = float(np.linalg.norm(v_pred_pix - v_pix, axis=-1).mean())
        wandb.log({
            "vis/keypoints":     wandb.Image(rainbow_overlay(s["rgb"], v_pred_pix, v_pix, valid)),
            "vis/heatmap":       wandb.Image(marginal_heatmap(v_vol)),
            "vis/gripper_dist":  wandb.Image(gripper_strip(v_grip_logits, v_grip, valid)),
            "train/pix_argmax":  train_pix_err,
        }, step=global_step)
    finally:
        model.train()


def _evaluate(model, val_loader, stats, device):
    """Run a teacher-forced validation pass.

    Returns:
        (val_volume_ce, val_pix_err, val_grip_acc), each averaged over every
        (sample, timestep) pair rather than per batch, so a short final batch
        is not over-weighted. All three are NaN if the loader yields nothing.
    """
    model.eval()
    ce_sum = pix_sum = grip_sum = 0.0
    count = 0
    with torch.no_grad():
        for batch in val_loader:
            rgb    = batch["rgb"].to(device)
            traj2d = batch["trajectory_2d"].to(device)
            traj3d = batch["trajectory_3d"].to(device)
            grip_v = batch["trajectory_gripper"].to(device)
            B, T, _ = traj2d.shape

            gx, gy, gz, gg = _discretize_targets(traj2d, traj3d, grip_v, stats)
            out = model(rgb, kp_zyx=torch.stack([gz, gy, gx], dim=-1))
            vol = out["volume_logits"]
            Hv, Wv = vol.shape[-2:]
            tgt = gz * (Hv * Wv) + gy * Wv + gx
            ce_sum += F.cross_entropy(vol.reshape(B * T, -1), tgt.reshape(-1),
                                      reduction="sum").item()
            pix_sum += (_argmax_pix(vol) - traj2d).norm(dim=-1).sum().item()
            grip_sum += (out["gripper_logits"].argmax(dim=-1) == gg).float().sum().item()
            count += B * T
    if count == 0:
        nan = float("nan")
        return nan, nan, nan
    return ce_sum / count, pix_sum / count, grip_sum / count


def main():
    """Train DinoVolumeFiLM on a cached LIBERO benchmark with wandb logging.

    Parses CLI args, builds the train/val split, and trains with teacher-forced
    keypoints. The loss is volume CE plus weighted gripper and rotation CE.
    Visualisations are logged every ``--vis_every_steps`` steps (disabled when
    <= 0). The val split is evaluated each epoch. ``latest.pth`` is written
    every ``--save_every_epochs`` epochs and after the final epoch.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--cache_root", type=str, default="/data/libero/parsed_libero")
    parser.add_argument("--benchmark", type=str, default="libero_spatial")
    parser.add_argument("--task_ids", type=str, default="0",
                        help="Comma-separated task ids. Use 'all' for all tasks.")
    parser.add_argument("--max_demos", type=int, default=0,
                        help="Max demos per task (0 = all).")
    parser.add_argument("--n_window", type=int, default=48)
    parser.add_argument("--frame_stride", type=int, default=2)
    parser.add_argument("--batch_size", type=int, default=2)
    parser.add_argument("--lr", type=float, default=5e-5)
    parser.add_argument("--epochs", type=int, default=50)
    parser.add_argument("--num_workers", type=int, default=2)
    parser.add_argument("--val_split", type=float, default=0.05)
    parser.add_argument("--grad_clip", type=float, default=1.0)
    parser.add_argument("--gripper_loss_weight",  type=float, default=0.5)
    parser.add_argument("--rotation_loss_weight", type=float, default=0.5)
    parser.add_argument("--freeze_backbone", action="store_true")
    parser.add_argument("--vis_every_steps",  type=int, default=50)
    parser.add_argument("--save_every_epochs", type=int, default=5)
    parser.add_argument("--run_name", type=str, default="film_volume_v0")
    parser.add_argument("--wandb_project", type=str, default="para_libero")
    parser.add_argument("--wandb_mode", type=str, default="online")
    parser.add_argument("--use_checkpoint", type=int, default=1)
    args = parser.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    script_dir = Path(__file__).parent
    ckpt_dir = script_dir / "checkpoints" / args.run_name
    ckpt_dir.mkdir(parents=True, exist_ok=True)

    wandb.init(project=args.wandb_project, name=args.run_name, mode=args.wandb_mode, config=vars(args))

    task_ids = None if args.task_ids == "all" else [int(t) for t in args.task_ids.split(",") if t.strip()]
    print(f"Loading libero cache: {args.cache_root}/{args.benchmark} task_ids={task_ids}")
    full = CachedTrajectoryDataset(
        cache_root=args.cache_root,
        benchmark_name=args.benchmark,
        task_ids=task_ids,
        image_size=IMG_SIZE,
        n_window=args.n_window,
        frame_stride=args.frame_stride,
        max_demos=args.max_demos,
    )
    stats = compute_dataset_stats(full)
    print(f"Stats: min_h={stats['min_height']:.4f} max_h={stats['max_height']:.4f} "
          f"min_g={stats['min_grip']:.4f} max_g={stats['max_grip']:.4f}")
    # Use a context manager so the stats file is flushed and closed deterministically.
    with open(ckpt_dir / "dataset_stats.json", "w") as f:
        json.dump(stats, f, indent=2)

    n = len(full)
    n_val = max(1, int(n * args.val_split))
    train_ds, val_ds = random_split(full, [n - n_val, n_val], generator=torch.Generator().manual_seed(42))
    print(f"Train: {len(train_ds)}  Val: {len(val_ds)}")

    # Pinned memory only helps (and only avoids a warning) when copying to CUDA.
    pin = device.type == "cuda"
    train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True,
                              num_workers=args.num_workers, pin_memory=pin, drop_last=True,
                              persistent_workers=args.num_workers > 0)
    val_loader = DataLoader(val_ds, batch_size=args.batch_size, shuffle=False,
                            num_workers=args.num_workers, pin_memory=pin,
                            persistent_workers=args.num_workers > 0)

    print("Building model...")
    model = DinoVolumeFiLM(
        n_window=args.n_window, n_height_bins=N_HEIGHT_BINS,
        n_gripper_bins=N_GRIPPER_BINS, n_rot_bins=N_ROT_BINS,
        image_size=IMG_SIZE, pred_size=PRED_SIZE,
        freeze_backbone=args.freeze_backbone,
        use_checkpoint=bool(args.use_checkpoint),
    ).to(device)
    n_t = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Trainable: {n_t:,}")
    wandb.config.update({"trainable_params": n_t})

    opt = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr, weight_decay=1e-4)
    global_step = 0

    for epoch in range(args.epochs):
        model.train()
        pbar = tqdm(train_loader, desc=f"Epoch {epoch}", leave=False)
        for batch in pbar:
            rgb    = batch["rgb"].to(device, non_blocking=True)                  # (B, 3, IMG, IMG)
            traj2d = batch["trajectory_2d"].to(device, non_blocking=True)        # (B, T, 2) in IMG space
            traj3d = batch["trajectory_3d"].to(device, non_blocking=True)        # (B, T, 3)
            grip_v = batch["trajectory_gripper"].to(device, non_blocking=True)   # (B, T)
            euler  = batch["trajectory_euler"].to(device, non_blocking=True)     # (B, T, 3)
            B, T, _ = traj2d.shape

            # Discretise targets
            gx, gy, gz, gg = _discretize_targets(traj2d, traj3d, grip_v, stats)
            gr = torch.stack([discretize(euler[..., a], stats["min_rot"][a], stats["max_rot"][a], N_ROT_BINS)
                              for a in range(3)], dim=-1)                        # (B, T, 3)
            kp_zyx = torch.stack([gz, gy, gx], dim=-1)                           # (B, T, 3) — teacher force

            out = model(rgb, kp_zyx=kp_zyx)
            vol         = out["volume_logits"]                                   # (B, T, Z, H, W)
            grip_logits = out["gripper_logits"]                                  # (B, T, n_grip)
            rot_logits  = out["rotation_logits"]                                 # (B, T, 3, n_rot)
            Z = vol.shape[2]; H, W = vol.shape[-2:]

            # Volume CE per-t over flat (Z*H*W)
            tgt_flat = gz * (H * W) + gy * W + gx                                # (B, T)
            volume_loss = F.cross_entropy(vol.reshape(B * T, Z * H * W),
                                          tgt_flat.reshape(-1))
            # Gripper CE
            gripper_loss = F.cross_entropy(grip_logits.reshape(B * T, -1),
                                           gg.reshape(-1))
            # Rotation CE — mean over the three euler axes
            rotation_loss = sum(
                F.cross_entropy(rot_logits[..., a, :].reshape(B * T, -1),
                                gr[..., a].reshape(-1))
                for a in range(3)
            ) / 3.0

            total = volume_loss + args.gripper_loss_weight * gripper_loss \
                                + args.rotation_loss_weight * rotation_loss

            opt.zero_grad(); total.backward()
            if args.grad_clip > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
            opt.step()
            pbar.set_postfix(v=f"{volume_loss.item():.2f}",
                             g=f"{gripper_loss.item():.2f}",
                             r=f"{rotation_loss.item():.2f}")
            global_step += 1

            if global_step % 10 == 0:
                wandb.log({"train/volume_loss":   volume_loss.item(),
                           "train/gripper_loss":  gripper_loss.item(),
                           "train/rotation_loss": rotation_loss.item(),
                           "train/total":         total.item(),
                           "epoch": epoch}, step=global_step)

            # Guard: a non-positive interval disables viz instead of crashing on `% 0`.
            if args.vis_every_steps > 0 and global_step % args.vis_every_steps == 0:
                _log_train_viz(model, train_ds, stats, device, global_step)

        # End-of-epoch val
        v_v, v_p, v_g = _evaluate(model, val_loader, stats, device)
        print(f"Epoch {epoch}: val_v={v_v:.3f}  val_pix={v_p:.1f}px  val_grip_acc={v_g:.3f}")
        wandb.log({"epoch_end/val_v": v_v, "epoch_end/val_pix": v_p, "epoch_end/val_grip_acc": v_g,
                   "epoch_end/epoch": epoch}, step=global_step)

        is_save = ((epoch + 1) % max(1, args.save_every_epochs) == 0) or (epoch + 1 == args.epochs)
        if is_save:
            ckpt = {"epoch": epoch, "global_step": global_step,
                    "model_state_dict": model.state_dict(),
                    "args": vars(args), "stats": stats}
            torch.save(ckpt, ckpt_dir / "latest.pth")

    wandb.finish()
    print(f"Done. Checkpoints: {ckpt_dir}")


# Script entry point: run training only when executed directly, not when imported.
if __name__ == "__main__":
    main()
