"""Train DinoVolumeQuery on LIBERO data (CachedTrajectoryDataset).

Mirror of train_dino_volume_query.py but adapted to libero's dataset interface.
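Defaults to libero_spatial task 0 with a 1D PCA rotation target: the end-effector Euler angles
are projected onto their first principal component (PC1 explains ≈ 99.4% of the variance on
LIBERO, so the projection is essentially lossless).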

Writes the same checkpoint format as the original script, with the PCA basis and the
height/gripper ranges embedded so that the adapted libero/eval_libero_query.py can decode
predictions at deploy time.
"""
import os, sys, time, math, argparse, json
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from pathlib import Path
from tqdm import tqdm
from scipy.spatial.transform import Rotation as R
import wandb
import cv2

sys.path.insert(0, os.path.dirname(__file__))
from data import CachedTrajectoryDataset
from model_dino_volume_query import (DinoVolumeQuery, IMG_SIZE as MODEL_IMG_SIZE,
                                     N_HEIGHT_BINS, N_GRIPPER_BINS, N_ROT_BINS, PRED_SIZE)

# Square image side (pixels) used for LIBERO. It is passed as image_size to both the
# dataset and the model below, and all trajectory_2d pixel coordinates are in this space.
# It overrides the model module's IMG_SIZE default (504 per the original note; that
# constant is imported as MODEL_IMG_SIZE but not otherwise used in this file).
LIBERO_IMG = 448


# ---------------- Discretisation helpers ----------------

def discretize(values, lo, hi, n_bins):
    """Quantise a tensor of continuous values into ``n_bins`` integer bins over [lo, hi].

    Values are linearly mapped to [0, 1], clamped, scaled by ``n_bins - 1`` and truncated,
    so ``lo`` maps to bin 0 and ``hi`` to bin ``n_bins - 1``. Out-of-range values saturate
    at the end bins.

    Args:
        values: torch tensor of any shape.
        lo, hi: range endpoints (floats). A zero-width range is guarded by a 1e-8 span.
        n_bins: number of bins.

    Returns:
        int64 tensor with the same shape as ``values``.
    """
    span = max(hi - lo, 1e-8)  # avoid division by zero on a degenerate range
    top_bin = n_bins - 1
    unit = torch.clamp((values - lo) / span, 0, 1)
    return torch.clamp((unit * top_bin).long(), 0, top_bin)


def compute_dataset_stats(ds):
    """Pre-scan every demo for end-effector height, gripper and Euler-angle statistics.

    Args:
        ds: object exposing ``demos``, a sequence of dicts with ``'eef_pos'`` (N, >=3; column
            2 is used as height), ``'gripper'`` and ``'eef_quat'`` (N, 4). Quaternions are
            passed to ``scipy.spatial.transform.Rotation.from_quat``, which assumes
            scalar-last (x, y, z, w) order.

    Returns:
        dict with ``min_height``/``max_height`` (z range padded by 5% of its span on each
        side), ``min_grip``/``max_grip`` (same padding for gripper values) and
        ``eulers_for_pca``, a (total_frames, 3) array of 'xyz' Euler angles.

    Raises:
        ValueError: if the dataset has no demos or no frames.
    """
    demos = list(ds.demos)
    if not demos:
        raise ValueError("compute_dataset_stats: dataset has no demos")

    heights, grips, eulers = [], [], []
    for demo in demos:
        heights.append(np.asarray(demo['eef_pos'])[:, 2])
        grips.append(np.asarray(demo['gripper']))
        quats = np.asarray(demo['eef_quat'], dtype=np.float64).reshape(-1, 4)
        # One vectorised scipy call per demo instead of one Rotation object per frame.
        # Empty demos are skipped explicitly rather than handed to scipy.
        eulers.append(R.from_quat(quats).as_euler('xyz') if len(quats) else np.empty((0, 3)))

    z = np.concatenate(heights)
    g = np.concatenate(grips)
    e = np.concatenate(eulers, axis=0)
    if z.size == 0:
        raise ValueError("compute_dataset_stats: dataset contains no frames")

    z_lo, z_hi = float(z.min()), float(z.max())
    z_pad = (z_hi - z_lo) * 0.05
    g_lo, g_hi = float(g.min()), float(g.max())
    g_pad = (g_hi - g_lo) * 0.05
    return {
        "min_height": z_lo - z_pad, "max_height": z_hi + z_pad,
        "min_grip":   g_lo - g_pad, "max_grip":   g_hi + g_pad,
        "eulers_for_pca": e,
    }


def piano_roll(logits, gt_bins, valid=None, cell_h=6, cell_w=6, gt_color=(0, 0, 255)):
    """Render a (T, n_bins) logit matrix as a "piano roll" RGB image.

    Each row is one timestep. Its softmax is normalised by the row maximum (so the most
    likely bin is brightest) and coloured with OpenCV's VIRIDIS colormap. The ground-truth
    bin of each row is painted ``gt_color``, a BGR triple because it is written before the
    final BGR->RGB conversion; the default (0, 0, 255) shows as red. Rows where ``valid``
    is False are darkened to 30% brightness. Each cell is upscaled to
    ``cell_h`` x ``cell_w`` pixels with nearest-neighbour interpolation.

    Args:
        logits: (T, n) torch tensor of per-timestep logits. It may require grad or use a
            reduced-precision dtype.
        gt_bins: length-T sequence of integer ground-truth bin indices.
        valid: optional length-T boolean mask.
        cell_h, cell_w: output pixel size of one cell.
        gt_color: BGR colour of the ground-truth marker.

    Returns:
        uint8 RGB array of shape (T * cell_h, n * cell_w, 3).
    """
    T, n = logits.shape
    # detach()/float(): .numpy() rejects tensors that require grad and has no bfloat16 dtype.
    p = torch.softmax(logits.detach().float(), dim=-1).cpu().numpy()
    p_norm = p / (p.max(axis=1, keepdims=True) + 1e-8)
    img8 = (p_norm * 255).astype(np.uint8)
    img = cv2.applyColorMap(img8, cv2.COLORMAP_VIRIDIS)  # BGR output
    for t in range(T):
        gt = int(gt_bins[t])
        img[t, gt] = gt_color
        if valid is not None and not valid[t]:
            img[t] = (img[t].astype(np.float32) * 0.3).astype(np.uint8)
    img = cv2.resize(img, (n * cell_w, T * cell_h), interpolation=cv2.INTER_NEAREST)
    return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)


def rainbow_overlay(rgb_chw, pred_pix, gt_pix, img_size=LIBERO_IMG):
    """Draw predicted and ground-truth pixel tracks on a de-normalised RGB frame.

    The (3, H, W) ImageNet-normalised tensor is un-normalised, clipped to [0, 1] and
    converted to an HxWx3 uint8 RGB image. For each timestep t, the prediction is drawn
    as a filled radius-4 circle whose hue sweeps from 0 to 170 (OpenCV HSV units) over
    the track. The ground truth is drawn as a white radius-4 ring.

    Args:
        rgb_chw: (3, H, W) torch tensor, ImageNet-normalised.
        pred_pix: (T, 2) array of predicted (x, y) pixels.
        gt_pix: (T', 2) array of ground-truth (x, y) pixels, indexed with the first T rows.
        img_size: accepted for API compatibility; not used.

    Returns:
        uint8 RGB image with the overlays drawn.
    """
    imagenet_mean = np.array([0.485, 0.456, 0.406]).reshape(3, 1, 1)
    imagenet_std = np.array([0.229, 0.224, 0.225]).reshape(3, 1, 1)
    hwc = np.clip(rgb_chw.cpu().numpy() * imagenet_std + imagenet_mean, 0, 1).transpose(1, 2, 0)
    # cv2 drawing needs a C-contiguous buffer; the transposed array is not.
    canvas = np.ascontiguousarray((hwc * 255).astype(np.uint8))

    n_steps = pred_pix.shape[0]
    denom = max(n_steps - 1, 1)
    for step in range(n_steps):
        hsv_pixel = np.uint8([[[int(step / denom * 170), 255, 255]]])
        colour = cv2.cvtColor(hsv_pixel, cv2.COLOR_HSV2RGB)[0, 0].tolist()
        pred_xy = (int(pred_pix[step, 0]), int(pred_pix[step, 1]))
        gt_xy = (int(gt_pix[step, 0]), int(gt_pix[step, 1]))
        cv2.circle(canvas, pred_xy, 4, colour, -1)              # filled: prediction
        cv2.circle(canvas, gt_xy, 4, (255, 255, 255), 1)         # white ring: ground truth
    return canvas


def _quats_to_euler(quats):
    """Convert quaternions of shape (..., 4) to 'xyz' Euler angles of shape (..., 3).

    Vectorised equivalent of ``np.stack([R.from_quat(q).as_euler('xyz') for q in quats])``:
    one scipy call instead of one ``Rotation`` object per quaternion. ``R.from_quat``
    assumes scalar-last (x, y, z, w) order.
    """
    quats = np.asarray(quats, dtype=np.float64)
    lead_shape = quats.shape[:-1]
    return R.from_quat(quats.reshape(-1, 4)).as_euler('xyz').reshape(*lead_shape, 3)


def _decode_argmax_pix(vol_logits):
    """Decode the per-timestep argmax of volume logits into pixel coordinates.

    Args:
        vol_logits: tensor indexed as (B, T, Z, H, W).

    Returns:
        (B, T, 2) float tensor of (x, y) in LIBERO_IMG pixel space. The height axis is
        discarded. The winning (y, x) cell is mapped back to its top-left pixel, mirroring
        the floor used when volume targets are encoded in ``main``.
    """
    B, T, _, H, W = vol_logits.shape
    flat = vol_logits.reshape(B, T, -1).argmax(dim=-1)
    pyx = flat % (H * W)
    py = (pyx // W).float() / (H / LIBERO_IMG)
    px = (pyx % W).float() / (W / LIBERO_IMG)
    return torch.stack([px, py], dim=-1)


def main():
    """Train DinoVolumeQuery on cached LIBERO trajectories and write ``latest.pth`` checkpoints.

    Each step discretises the height (z), gripper and 1D-PCA rotation targets. It then
    trains a flattened Z*H*W volume classifier plus gripper and rotation heads with
    cross-entropy, and logs scalars and visualisations to wandb. The held-out split is
    evaluated every ``--val_every_epochs`` epochs.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--cache_root", type=str, default="/data/libero/parsed_libero")
    p.add_argument("--benchmark", type=str, default="libero_spatial")
    p.add_argument("--task_ids", type=str, default="0", help="comma-separated, or 'all'")
    p.add_argument("--max_demos", type=int, default=0)
    p.add_argument("--n_window", type=int, default=8)
    p.add_argument("--frame_stride", type=int, default=3)
    p.add_argument("--batch_size", type=int, default=16)
    p.add_argument("--lr", type=float, default=5e-5)
    p.add_argument("--epochs", type=int, default=50)
    p.add_argument("--num_workers", type=int, default=2)
    p.add_argument("--val_split", type=float, default=0.05)
    p.add_argument("--val_every_epochs", type=int, default=1,
                   help="evaluate the held-out split every N epochs (0 disables)")
    p.add_argument("--grad_clip", type=float, default=1.0)
    p.add_argument("--gripper_loss_weight",  type=float, default=0.5)
    p.add_argument("--rotation_loss_weight", type=float, default=0.5)
    p.add_argument("--vis_every_steps",  type=int, default=50)
    p.add_argument("--log_scalars_every", type=int, default=10)
    p.add_argument("--save_every_epochs", type=int, default=5)
    p.add_argument("--resume_from", type=str, default="")
    p.add_argument("--rot_pca_path", type=str, default="/data/cameron/para/libero/rotation_pca_basis_libero_spatial_t0.npz")
    p.add_argument("--run_name", type=str, default="libero_query_v0")
    p.add_argument("--wandb_project", type=str, default="para_libero")
    p.add_argument("--wandb_mode", type=str, default="online")
    args = p.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    script_dir = Path(__file__).parent
    ckpt_dir = script_dir / "checkpoints" / args.run_name
    ckpt_dir.mkdir(parents=True, exist_ok=True)
    wandb.init(project=args.wandb_project, name=args.run_name, mode=args.wandb_mode, config=vars(args))

    task_ids = None if args.task_ids == "all" else [int(t) for t in args.task_ids.split(",") if t.strip()]
    print(f"Loading libero cache: {args.cache_root}/{args.benchmark} task_ids={task_ids}")
    full = CachedTrajectoryDataset(
        cache_root=args.cache_root, benchmark_name=args.benchmark, task_ids=task_ids,
        image_size=LIBERO_IMG, n_window=args.n_window, frame_stride=args.frame_stride,
        max_demos=args.max_demos,
    )
    stats = compute_dataset_stats(full)
    # 1D PCA rotation basis, loaded from a precomputed .npz (mean, principal_axis, range).
    pca_basis = np.load(args.rot_pca_path)
    print(f"  PCA basis: PC1 EV={float(pca_basis['ev_ratio_pc1']):.4f}, "
          f"axis={pca_basis['principal_axis']}")
    rot_mean = torch.tensor(pca_basis['mean'], dtype=torch.float32)
    rot_axis = torch.tensor(pca_basis['principal_axis'], dtype=torch.float32)
    # Device copies hoisted out of the training loop; the CPU copies serve the vis path.
    rot_mean_dev, rot_axis_dev = rot_mean.to(device), rot_axis.to(device)
    rot_pca_min = float(pca_basis['pca_min'])
    rot_pca_max = float(pca_basis['pca_max'])
    print(f"  Height range: [{stats['min_height']:.3f}, {stats['max_height']:.3f}]")
    print(f"  Grip range:   [{stats['min_grip']:.3f}, {stats['max_grip']:.3f}]")

    n = len(full); n_val = max(1, int(n * args.val_split))
    train_ds, val_ds = random_split(full, [n - n_val, n_val],
                                    generator=torch.Generator().manual_seed(42))
    print(f"  Train: {n - n_val}  Val: {n_val}")

    train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True,
                              num_workers=args.num_workers, pin_memory=True, drop_last=True,
                              persistent_workers=args.num_workers > 0)
    val_loader = DataLoader(val_ds, batch_size=args.batch_size, shuffle=False,
                            num_workers=args.num_workers, pin_memory=True,
                            persistent_workers=args.num_workers > 0)

    print("Building model (DinoVolumeQuery, 1D PCA rotation, libero image_size=448)...")
    model = DinoVolumeQuery(
        n_window=args.n_window, n_height_bins=N_HEIGHT_BINS,
        n_gripper_bins=N_GRIPPER_BINS, n_rot_bins=N_ROT_BINS,
        image_size=LIBERO_IMG, pred_size=PRED_SIZE,
        use_eef=True, rotation_mode='1d_pca',
    ).to(device)
    n_t = sum(param.numel() for param in model.parameters() if param.requires_grad)
    print(f"Trainable: {n_t:,}")
    wandb.config.update({"trainable_params": n_t})

    if args.resume_from:
        print(f"Resuming from {args.resume_from}")
        # weights_only=False unpickles arbitrary objects: only resume from trusted checkpoints.
        sd_full = torch.load(args.resume_from, map_location=device, weights_only=False)
        ckpt_sd = sd_full["model_state_dict"]
        cur_sd = model.state_dict()
        # Partial load: keep only keys whose names and shapes match the current model.
        loaded = {k: v for k, v in ckpt_sd.items() if k in cur_sd and cur_sd[k].shape == v.shape}
        missing, unexpected = model.load_state_dict(loaded, strict=False)
        print(f"  loaded={len(loaded)} keys; missing={len(missing)}, unexpected={len(unexpected)}")

    opt = optim.AdamW(filter(lambda param: param.requires_grad, model.parameters()),
                      lr=args.lr, weight_decay=1e-4)

    def rotation_bins(quats, mean, axis):
        """Bin (..., 4) quaternions by projecting their 'xyz' Euler angles onto PC1.

        The result lives on ``mean.device``.
        """
        euler = torch.tensor(_quats_to_euler(quats), dtype=torch.float32, device=mean.device)
        return discretize((euler - mean) @ axis, rot_pca_min, rot_pca_max, N_ROT_BINS)

    def forward_batch(batch):
        """Move one batch to ``device``, build discrete targets, run the model and compute losses."""
        rgb    = batch["rgb"].to(device, non_blocking=True)                # (B, 3, 448, 448) ImageNet-normed
        traj2d = batch["trajectory_2d"].to(device, non_blocking=True)      # (B, T, 2) in 448-px space
        traj3d = batch["trajectory_3d"].to(device, non_blocking=True)      # (B, T, 3)
        grip_v = batch["trajectory_gripper"].to(device, non_blocking=True) # (B, T)
        quat   = batch["trajectory_quat"].cpu().numpy()                    # (B, T, 4)
        B, T, _ = traj2d.shape

        gz = discretize(traj3d[..., 2], stats["min_height"], stats["max_height"], N_HEIGHT_BINS)
        gg = discretize(grip_v, stats["min_grip"], stats["max_grip"], N_GRIPPER_BINS)
        gr = rotation_bins(quat, rot_mean_dev, rot_axis_dev)

        # The model is conditioned on the trajectory's first pixel and supervised on all T
        # pixels (per the original note, libero's t=0 is the current pose).
        out = model(rgb, start_pix=traj2d[:, 0, :])
        vol = out["volume_logits"]                 # (B, T, Z, H, W)
        grip_logits = out["gripper_logits"]        # (B, T, n_grip)
        rot_logits = out["rotation_logits"]        # (B, T, n_rot)
        Z = vol.shape[2]
        H, W = vol.shape[-2:]

        # Flatten (height bin, row, col) into a single class index over the Z*H*W volume.
        gx = (traj2d[..., 0] * (W / LIBERO_IMG)).long().clamp(0, W - 1)
        gy = (traj2d[..., 1] * (H / LIBERO_IMG)).long().clamp(0, H - 1)
        tgt_flat = gz.clamp(0, Z - 1) * (H * W) + gy * W + gx  # (B, T)

        volume_loss = F.cross_entropy(vol.reshape(B * T, Z * H * W), tgt_flat.reshape(-1))
        gripper_loss = F.cross_entropy(grip_logits.reshape(B * T, -1), gg.reshape(-1))
        rotation_loss = F.cross_entropy(rot_logits.reshape(B * T, -1), gr.reshape(-1))
        total = (volume_loss + args.gripper_loss_weight * gripper_loss
                 + args.rotation_loss_weight * rotation_loss)
        return {"vol": vol, "grip_logits": grip_logits, "rot_logits": rot_logits,
                "gt_pix": traj2d, "gg": gg, "gr": gr,
                "volume_loss": volume_loss, "gripper_loss": gripper_loss,
                "rotation_loss": rotation_loss, "total": total}

    @torch.no_grad()
    def evaluate():
        """Return sample-weighted mean losses/metrics over ``val_loader`` ({} if it yields nothing)."""
        model.eval()
        sums, n_seen = {}, 0
        for batch in val_loader:
            res = forward_batch(batch)
            bsz = res["gt_pix"].shape[0]
            metrics = {
                "val/volume_loss":   res["volume_loss"],
                "val/gripper_loss":  res["gripper_loss"],
                "val/rotation_loss": res["rotation_loss"],
                "val/total":         res["total"],
                "val/pix_argmax":    (_decode_argmax_pix(res["vol"]) - res["gt_pix"]).norm(dim=-1).mean(),
                "val/grip_acc":      (res["grip_logits"].argmax(-1) == res["gg"]).float().mean(),
                "val/rot_acc":       (res["rot_logits"].argmax(-1) == res["gr"]).float().mean(),
            }
            # Weight by batch size: the last val batch can be smaller (drop_last is not set).
            for key, value in metrics.items():
                sums[key] = sums.get(key, 0.0) + value.item() * bsz
            n_seen += bsz
        model.train()
        return {key: total / n_seen for key, total in sums.items()} if n_seen else {}

    def log_visualisation(step):
        """Log a keypoint overlay and gripper/rotation piano rolls for one random train sample."""
        model.eval()
        with torch.no_grad():
            idx = int(torch.randint(0, len(train_ds), (1,)).item())
            s = train_ds[idx]
            vo = model(s["rgb"].unsqueeze(0).to(device),
                       start_pix=s["trajectory_2d"][0].unsqueeze(0).to(device))
            v_pred_pix = _decode_argmax_pix(vo["volume_logits"])[0].cpu().numpy()
            v_gg = discretize(s["trajectory_gripper"], stats["min_grip"], stats["max_grip"], N_GRIPPER_BINS)
            v_gr = rotation_bins(s["trajectory_quat"].cpu().numpy(), rot_mean, rot_axis)
        viz_kp = rainbow_overlay(s["rgb"], v_pred_pix, s["trajectory_2d"].cpu().numpy())
        viz_gd = piano_roll(vo["gripper_logits"][0], v_gg.numpy())
        viz_rd = piano_roll(vo["rotation_logits"][0], v_gr.numpy())
        wandb.log({"vis/keypoints":    wandb.Image(viz_kp),
                   "vis/gripper_dist": wandb.Image(viz_gd),
                   "vis/rotation_dist": wandb.Image(viz_rd)}, step=step)
        model.train()

    global_step = 0
    for epoch in range(args.epochs):
        model.train()
        last_losses = None  # (volume, gripper, rotation) of the epoch's final step
        pbar = tqdm(train_loader, desc=f"Epoch {epoch}", leave=False)
        for batch in pbar:
            res = forward_batch(batch)

            opt.zero_grad()
            res["total"].backward()
            if args.grad_clip > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
            opt.step()

            last_losses = (res["volume_loss"].item(), res["gripper_loss"].item(),
                           res["rotation_loss"].item())
            pbar.set_postfix(v=f"{last_losses[0]:.2f}",
                             g=f"{last_losses[1]:.2f}",
                             r=f"{last_losses[2]:.2f}")
            global_step += 1

            if args.log_scalars_every > 0 and global_step % args.log_scalars_every == 0:
                with torch.no_grad():
                    train_pix = (_decode_argmax_pix(res["vol"]) - res["gt_pix"]).norm(dim=-1).mean()
                    train_grip_acc = (res["grip_logits"].argmax(-1) == res["gg"]).float().mean()
                    train_rot_acc = (res["rot_logits"].argmax(-1) == res["gr"]).float().mean()
                wandb.log({"train/volume_loss":   last_losses[0],
                           "train/gripper_loss":  last_losses[1],
                           "train/rotation_loss": last_losses[2],
                           "train/total":         res["total"].item(),
                           "train/pix_argmax":    train_pix.item(),
                           "train/grip_acc":      train_grip_acc.item(),
                           "train/rot_acc":       train_rot_acc.item(),
                           "epoch": epoch}, step=global_step)

            if args.vis_every_steps > 0 and global_step % args.vis_every_steps == 0:
                log_visualisation(global_step)

        res = None  # release the last training batch's activations before validation

        val_metrics = {}
        if args.val_every_epochs > 0 and (epoch + 1) % args.val_every_epochs == 0:
            val_metrics = evaluate()

        # End-of-epoch summary (guarded: with drop_last=True an epoch can yield no batches).
        if last_losses is None:
            print(f"Epoch {epoch}: no training batches (train split smaller than batch_size with drop_last=True)")
        else:
            v_last, g_last, r_last = last_losses
            print(f"Epoch {epoch}: g={g_last:.3f}  r={r_last:.3f}  v={v_last:.3f}")
        if val_metrics:
            print(f"  val: v={val_metrics['val/volume_loss']:.3f}  "
                  f"g={val_metrics['val/gripper_loss']:.3f}  "
                  f"r={val_metrics['val/rotation_loss']:.3f}  "
                  f"pix={val_metrics['val/pix_argmax']:.2f}")
        wandb.log({"epoch_end/epoch": epoch, **val_metrics}, step=global_step)

        is_save = ((epoch + 1) % max(1, args.save_every_epochs) == 0) or (epoch + 1 == args.epochs)
        if is_save:
            ckpt = {"epoch": epoch, "global_step": global_step,
                    "model_state_dict": model.state_dict(),
                    "args": vars(args),
                    "rotation_mode": "1d_pca",
                    "n_rot_bins": N_ROT_BINS,
                    "min_height": stats["min_height"], "max_height": stats["max_height"],
                    "min_grip":   stats["min_grip"],   "max_grip":   stats["max_grip"],
                    "rot_pca_mean": np.asarray(pca_basis['mean']),
                    "rot_pca_axis": np.asarray(pca_basis['principal_axis']),
                    "rot_pca_min":  rot_pca_min,
                    "rot_pca_max":  rot_pca_max,
                    "image_size":   LIBERO_IMG,
                    "n_window":     args.n_window,
                    "frame_stride": args.frame_stride}
            torch.save(ckpt, ckpt_dir / "latest.pth")

    wandb.finish()
    print(f"Done. Checkpoints: {ckpt_dir}")


if __name__ == "__main__":
    # Script entry point: main() parses the CLI arguments and runs training.
    main()
