"""Train DinoVolumeQuery2ViewDualFrustum on LIBERO (dual-camera: BEV + wrist).

Defaults to libero_spatial task 0 with a 1D PCA rotation parameterisation
(PC1 explains ≈99% of the rotation variance on LIBERO).
"""
import os, sys, time, math, argparse, json
import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from pathlib import Path
from tqdm import tqdm
from scipy.spatial.transform import Rotation as R
import wandb
import cv2

sys.path.insert(0, os.path.dirname(__file__))
from data_libero_2view import CachedTrajectory2ViewDataset


def _denorm_rgb(rgb_t):
    """Undo ImageNet mean/std normalisation of a channel-first RGB tensor.

    The tensor is moved to CPU, de-normalised per channel, clipped to [0, 1]
    and returned as a channel-last (H, W, 3) NumPy float array.
    """
    arr = rgb_t.cpu().numpy()
    imagenet_mean = np.array([0.485, 0.456, 0.406]).reshape(3, 1, 1)
    imagenet_std = np.array([0.229, 0.224, 0.225]).reshape(3, 1, 1)
    restored = np.clip(arr * imagenet_std + imagenet_mean, 0, 1)
    return np.transpose(restored, (1, 2, 0))


def _joint_pca_two_features(F_bev, F_wrist):
    """Colour two feature maps with one shared 3-component PCA basis.

    Both maps are flattened to (pixels, C), stacked, mean-centred and decomposed
    with a single SVD, so the two PCA images share one colour space. Each
    component is rescaled to [0, 1] using its 2nd/98th percentiles and clipped.

    Args:
        F_bev, F_wrist: (C, H, W) tensors on any device. They must share the
            channel count C; their spatial sizes may differ.

    Returns:
        (pca_bev_rgb, pca_wrist_rgb, ev_top3): float arrays of shape
        (Hb, Wb, 3) and (Hw, Ww, 3) in [0, 1], plus the fraction of variance
        explained by the top three components (0.0 for constant features).

    Raises:
        ValueError: if the two maps have different channel counts.
    """
    n_ch = F_bev.shape[0]
    if F_wrist.shape[0] != n_ch:
        raise ValueError(
            f"channel mismatch: F_bev has {n_ch} channels, F_wrist has {F_wrist.shape[0]}")
    Hb, Wb = F_bev.shape[1], F_bev.shape[2]
    Hw, Ww = F_wrist.shape[1], F_wrist.shape[2]

    def _flatten(feat):
        # (C, H, W) -> (H*W, C). Half-precision maps are upcast because NumPy has
        # no bfloat16 and np.linalg rejects float16.
        feat = feat.detach()
        if feat.dtype in (torch.float16, torch.bfloat16):
            feat = feat.float()
        return feat.cpu().numpy().transpose(1, 2, 0).reshape(-1, n_ch)

    joint = np.concatenate([_flatten(F_bev), _flatten(F_wrist)], axis=0)
    centred = joint - joint.mean(0, keepdims=True)
    _, sv, vt = np.linalg.svd(centred, full_matrices=False)
    pcs = centred @ vt[:3].T
    if pcs.shape[1] < 3:
        # Fewer than three components exist (tiny C or pixel count): zero-fill.
        pcs = np.pad(pcs, ((0, 0), (0, 3 - pcs.shape[1])))

    var = sv ** 2
    total_var = var.sum()
    ev_top3 = float((var / total_var)[:3].sum()) if total_var > 0 else 0.0

    lo, hi = np.percentile(pcs, [2, 98], axis=0)
    pcs_n = np.clip((pcs - lo) / (hi - lo + 1e-8), 0, 1)
    n_bev = Hb * Wb
    # Each view is reshaped with its OWN spatial size.
    pca_b = pcs_n[:n_bev].reshape(Hb, Wb, 3)
    pca_w = pcs_n[n_bev:].reshape(Hw, Ww, 3)
    return pca_b, pca_w, ev_top3


def make_2view_panel(rgb_bev_t, rgb_wrist_t, F_bev, F_wrist, gt_pix_arr, pred_pix_arr, ep_step_label=""):
    """Render a captioned 2x2 uint8 RGB panel for logging.

    Layout, below a 28-px light-grey caption bar:
        top-left:     BEV RGB with the trajectory drawn on it. Predicted points
                      are filled dots whose HSV hue runs 0 -> 170 over time;
                      GT points are white rings.
        top-right:    wrist RGB.
        bottom-left:  joint-PCA colouring of F_bev.
        bottom-right: joint-PCA colouring of F_wrist.
    Both PCA maps are nearest-neighbour resized to the BEV image size.
    """
    def to_uint8(img_t):
        return (_denorm_rgb(img_t) * 255).astype(np.uint8).copy()

    bev_img = to_uint8(rgb_bev_t)
    wrist_img = to_uint8(rgb_wrist_t)

    n_steps = gt_pix_arr.shape[0]
    denom = max(n_steps - 1, 1)
    for step in range(n_steps):
        hsv = np.uint8([[[int(step / denom * 170), 255, 255]]])
        colour = cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB)[0, 0].tolist()
        pred_xy = (int(pred_pix_arr[step, 0]), int(pred_pix_arr[step, 1]))
        gt_xy = (int(gt_pix_arr[step, 0]), int(gt_pix_arr[step, 1]))
        cv2.circle(bev_img, pred_xy, 4, colour, -1)
        cv2.circle(bev_img, gt_xy, 4, (255, 255, 255), 1)

    pca_bev, pca_wrist, ev_top3 = _joint_pca_two_features(F_bev, F_wrist)
    out_h, out_w = bev_img.shape[:2]
    pca_tiles = [
        cv2.resize((pca_map * 255).astype(np.uint8), (out_w, out_h), interpolation=cv2.INTER_NEAREST)
        for pca_map in (pca_bev, pca_wrist)
    ]
    grid = np.vstack([np.hstack([bev_img, wrist_img]), np.hstack(pca_tiles)])

    caption = f"{ep_step_label} | BEV+traj | wrist | F_bev PCA | F_wrist PCA  (joint EV={ev_top3:.0%})"
    header = np.full((28, grid.shape[1], 3), 240, dtype=np.uint8)
    cv2.putText(header, caption, (8, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (30, 30, 30), 1, cv2.LINE_AA)
    return np.vstack([header, grid])


from model_dino_volume_query_2view import (DinoVolumeQuery2View, N_HEIGHT_BINS,
                                            N_GRIPPER_BINS, N_ROT_BINS, PRED_SIZE,
                                            build_bev_world_xyz_table)
from model_dino_volume_query_dualfrustum import (
    DinoVolumeQuery2ViewDualFrustum, build_wrist_world_xyz_table_batched,
)

# Input image side length in pixels. It is passed as image_size to the dataset
# and the model, and main() uses it to convert pixel coordinates to and from
# the PRED_SIZE prediction grid.
LIBERO_IMG = 448


def discretize(values, lo, hi, n_bins):
    """Quantise a tensor of continuous values into integer bins 0..n_bins-1.

    Values are mapped linearly from [lo, hi] onto [0, n_bins - 1] and then
    truncated. Out-of-range values saturate at the end bins, so only values
    at or above ``hi`` reach the last bin. A non-positive range is replaced
    by 1e-8 to avoid division by zero.
    """
    span = max(hi - lo, 1e-8)
    unit = ((values - lo) / span).clamp(0, 1)
    idx = (unit * (n_bins - 1)).long()
    return idx.clamp(0, n_bins - 1)


def compute_dataset_stats(ds):
    """Collect normalisation ranges over every demo in ``ds.demos``.

    Each demo is read for ``eef_pos`` (column 2 is used as height),
    ``gripper`` and ``eef_quat`` (converted with scipy ``Rotation.from_quat``).
    The height and gripper ranges are each widened by 5% of their span on
    both sides.

    Returns:
        dict with ``min_height``/``max_height``, ``min_grip``/``max_grip``
        (floats) and ``eulers_all``, an (N, 3) array of 'xyz' Euler angles
        for all frames.

    Raises:
        ValueError: if ``ds.demos`` is empty.
    """
    demos = list(ds.demos)
    if not demos:
        raise ValueError("compute_dataset_stats: dataset has no demos")

    heights = np.concatenate([demo['eef_pos'][:, 2] for demo in demos])
    grips = np.concatenate([demo['gripper'] for demo in demos])
    # One vectorised scipy call per demo instead of one Rotation object per frame.
    eulers = np.concatenate(
        [R.from_quat(np.asarray(demo['eef_quat']).reshape(-1, 4)).as_euler('xyz') for demo in demos],
        axis=0,
    )

    def _padded_range(x, frac=0.05):
        lo, hi = float(x.min()), float(x.max())
        pad = (hi - lo) * frac
        return lo - pad, hi + pad

    min_h, max_h = _padded_range(heights)
    min_g, max_g = _padded_range(grips)
    return {
        "min_height": min_h, "max_height": max_h,
        "min_grip":   min_g, "max_grip":   max_g,
        "eulers_all": eulers,
    }


def _quat_to_euler_xyz(quat):
    """Convert a (..., 4) array of quaternions into (..., 3) 'xyz' Euler angles.

    Vectorised equivalent of calling ``R.from_quat(q).as_euler('xyz')`` on each
    quaternion individually, using scipy's quaternion convention.
    """
    quat = np.asarray(quat)
    lead_shape = quat.shape[:-1]
    eulers = R.from_quat(quat.reshape(-1, 4)).as_euler('xyz')
    return eulers.reshape(*lead_shape, 3)


def _bev_argmax_pixels(vol, image_size):
    """Decode the argmax of the BEV-anchored volume slot into image pixels.

    Args:
        vol: (B, T, Z, 2, H, W) volume logits; anchor index 0 is the BEV slot.
        image_size: input image side length, used to scale grid cells to pixels.

    Returns:
        (B, T, 2) float tensor of (x, y) pixel coordinates (top-left corner of
        the winning grid cell). The height bin is discarded.
    """
    B, T = vol.shape[:2]
    H, W = vol.shape[-2:]
    flat = vol[:, :, :, 0].reshape(B, T, -1).argmax(dim=-1)   # index over (Z, H, W)
    pyx = flat % (H * W)                                        # drop the height bin
    py = (pyx // W).float() / (H / image_size)
    px = (pyx % W).float() / (W / image_size)
    return torch.stack([px, py], dim=-1)


def main():
    """Parse CLI args, train the dual-frustum 2-view model, validate, and checkpoint."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--cache_root", type=str, default="/data/libero/parsed_libero_2view")
    parser.add_argument("--benchmark", type=str, default="libero_spatial")
    parser.add_argument("--task_ids",  type=str, default="0")
    parser.add_argument("--max_demos", type=int, default=0)
    parser.add_argument("--n_window",  type=int, default=8)
    parser.add_argument("--frame_stride", type=int, default=3)
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument("--lr", type=float, default=5e-5)
    parser.add_argument("--epochs", type=int, default=50)
    parser.add_argument("--num_workers", type=int, default=2)
    parser.add_argument("--grad_clip", type=float, default=1.0)
    parser.add_argument("--gripper_loss_weight",  type=float, default=0.5)
    parser.add_argument("--rotation_loss_weight", type=float, default=0.5)
    parser.add_argument("--log_scalars_every", type=int, default=10)
    parser.add_argument("--vis_every_steps",   type=int, default=100)
    parser.add_argument("--save_every_epochs", type=int, default=5)
    parser.add_argument("--val_every_epochs",  type=int, default=1,
                        help="evaluate on the held-out split every N epochs (0 disables)")
    parser.add_argument("--rot_pca_path", type=str,
                        default="/data/cameron/para/libero/rotation_pca_basis_libero_spatial_t0.npz")
    parser.add_argument("--resume_from", type=str, default="")
    parser.add_argument("--run_name", type=str, default="libero_2view_v0")
    parser.add_argument("--wandb_project", type=str, default="para_libero")
    parser.add_argument("--wandb_mode", type=str, default="online")
    parser.add_argument("--fusion_mode", type=str, default="sum", choices=["sum", "max", "poe"],
                        help="2view fusion: sum (raw scores), max (logsumexp over views), poe (product of experts)")
    args = parser.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    script_dir = Path(__file__).parent
    ckpt_dir = script_dir / "checkpoints" / args.run_name
    ckpt_dir.mkdir(parents=True, exist_ok=True)
    wandb.init(project=args.wandb_project, name=args.run_name, mode=args.wandb_mode, config=vars(args))

    task_ids = [int(t) for t in args.task_ids.split(",") if t.strip()]
    print(f"Loading 2view cache: {args.cache_root}/{args.benchmark} task_ids={task_ids}")
    full = CachedTrajectory2ViewDataset(
        cache_root=args.cache_root, benchmark_name=args.benchmark,
        task_ids=task_ids,
        image_size=LIBERO_IMG, n_window=args.n_window, frame_stride=args.frame_stride,
        max_demos=args.max_demos,
    )
    stats = compute_dataset_stats(full)
    print(f"  height range: [{stats['min_height']:.3f}, {stats['max_height']:.3f}]")

    # 1D PCA rotation basis. Copy the arrays out so the .npz file handle is closed.
    with np.load(args.rot_pca_path) as pca_basis:
        rot_mean_np = np.asarray(pca_basis['mean'])
        rot_axis_np = np.asarray(pca_basis['principal_axis'])
        rot_pca_min = float(pca_basis['pca_min'])
        rot_pca_max = float(pca_basis['pca_max'])
        rot_ev_pc1 = float(pca_basis['ev_ratio_pc1'])
    print(f"  rot PCA1 EV={rot_ev_pc1:.3f}, axis={rot_axis_np}")
    # Moved to the device once instead of on every training step.
    rot_mean = torch.tensor(rot_mean_np, dtype=torch.float32, device=device)
    rot_axis = torch.tensor(rot_axis_np, dtype=torch.float32, device=device)
    rot_span = max(rot_pca_max - rot_pca_min, 1e-8)

    n = len(full)
    n_val = max(1, int(n * 0.05))
    train_ds, val_ds = random_split(full, [n - n_val, n_val],
                                    generator=torch.Generator().manual_seed(42))
    loader_kwargs = dict(batch_size=args.batch_size, num_workers=args.num_workers,
                         pin_memory=device.type == "cuda",   # pinning only helps host->GPU copies
                         persistent_workers=args.num_workers > 0)
    train_loader = DataLoader(train_ds, shuffle=True, drop_last=True, **loader_kwargs)
    val_loader = DataLoader(val_ds, shuffle=False, **loader_kwargs)
    print(f"  train={len(train_ds)} val={len(val_ds)}")

    # Per-sample BEV table — re-built per batch to handle varied viewpoints
    from model_dino_volume_query_2view import build_bev_world_xyz_table_batched
    bev_xyz_static = None
    bev_per_sample = True  # set to False to fall back to per-batch-broadcast single table
    if not bev_per_sample:
        bev_K_norm    = torch.tensor(full.demos[0]['bev_K_norm'],     dtype=torch.float32, device=device)
        bev_extrinsic = torch.tensor(full.demos[0]['bev_extrinsic'],  dtype=torch.float32, device=device)
        bev_xyz_static = build_bev_world_xyz_table(
            bev_K_norm, bev_extrinsic,
            N_HEIGHT_BINS, stats['min_height'], stats['max_height'],
            PRED_SIZE, PRED_SIZE, LIBERO_IMG, device,
        )
        print(f"  bev_xyz_table (static): {tuple(bev_xyz_static.shape)}")
    else:
        print("  bev_xyz_table: per-sample (computed per batch)")

    print("Building DinoVolumeQuery2ViewDualFrustum (dual-frustum volume sampling)...")
    # NOTE(review): --fusion_mode is recorded in the wandb config and the checkpoint but is
    # not passed to the model here; confirm whether the constructor should receive it.
    model = DinoVolumeQuery2ViewDualFrustum(
        n_window=args.n_window, n_height_bins=N_HEIGHT_BINS,
        n_gripper_bins=N_GRIPPER_BINS, n_rot_bins=N_ROT_BINS,
        image_size=LIBERO_IMG, pred_size=PRED_SIZE,
        rotation_mode='1d_pca',
    ).to(device)
    n_t = sum(prm.numel() for prm in model.parameters() if prm.requires_grad)
    print(f"  trainable: {n_t:,}")
    wandb.config.update({"trainable_params": n_t})

    if args.resume_from:
        # weights_only=False unpickles arbitrary objects: only resume from trusted files.
        # Only matching model weights are restored; optimizer state and global_step restart.
        sd = torch.load(args.resume_from, map_location=device, weights_only=False)
        ckpt_sd = sd["model_state_dict"]
        cur = model.state_dict()
        loaded = {k: v for k, v in ckpt_sd.items() if k in cur and cur[k].shape == v.shape}
        model.load_state_dict(loaded, strict=False)
        print(f"  resumed: {len(loaded)} matching keys")

    opt = optim.AdamW((prm for prm in model.parameters() if prm.requires_grad),
                      lr=args.lr, weight_decay=1e-4)

    def run_batch(batch):
        """Move one batch to the device, run the model, and compute GT bins and losses.

        Returns a dict holding the device inputs/targets needed for metrics and
        visualisation, the raw model output ``out``, and the loss tensors.
        """
        rgb_bev    = batch["rgb_bev"].to(device, non_blocking=True)
        rgb_wrist  = batch["rgb_wrist"].to(device, non_blocking=True)
        traj2d_bev = batch["trajectory_2d_bev"].to(device, non_blocking=True)
        traj3d     = batch["trajectory_3d"].to(device, non_blocking=True)
        grip_v     = batch["trajectory_gripper"].to(device, non_blocking=True)
        quat       = batch["trajectory_quat"].cpu().numpy()
        wrist_ext  = batch["wrist_extrinsic"].to(device, non_blocking=True)     # (B, 4, 4)
        wrist_K    = batch["wrist_K_norm"].to(device, non_blocking=True)         # (B, 3, 3)
        bev_K      = batch["bev_K_norm"].to(device, non_blocking=True)           # (B, 3, 3)
        bev_ext    = batch["bev_extrinsic"].to(device, non_blocking=True)        # (B, 4, 4)
        B, T, _ = traj2d_bev.shape

        if bev_per_sample:
            bev_xyz_table = build_bev_world_xyz_table_batched(
                bev_K, bev_ext,
                N_HEIGHT_BINS, stats['min_height'], stats['max_height'],
                PRED_SIZE, PRED_SIZE, LIBERO_IMG,
            )                                                                    # (B, Z, H, W, 3)
        else:
            bev_xyz_table = bev_xyz_static

        # GT bins
        gz = discretize(traj3d[..., 2], stats["min_height"], stats["max_height"], N_HEIGHT_BINS)
        gg = discretize(grip_v, stats["min_grip"], stats["max_grip"], N_GRIPPER_BINS)
        euler_t = torch.tensor(_quat_to_euler_xyz(quat), dtype=torch.float32, device=device)
        proj = (euler_t - rot_mean) @ rot_axis
        gr_norm = (proj - rot_pca_min) / rot_span
        gr = (gr_norm.clamp(0, 1) * (N_ROT_BINS - 1)).long().clamp(0, N_ROT_BINS - 1)

        # Build wrist-anchored xyz table per batch
        wrist_xyz_table = build_wrist_world_xyz_table_batched(
            wrist_K, wrist_ext,
            N_HEIGHT_BINS, stats['min_height'], stats['max_height'],
            PRED_SIZE, PRED_SIZE, LIBERO_IMG,
        )

        out = model(rgb_bev, rgb_wrist, traj2d_bev[:, 0, :], bev_xyz_table, wrist_K, wrist_ext,
                    bev_K, bev_ext, wrist_xyz_table)
        vol = out["volume_logits"]       # (B, T, Z, 2, H, W) — anchor 0=bev, 1=wrist
        grip_logits = out["gripper_logits"]
        rot_logits = out["rotation_logits"]
        Z = vol.shape[2]
        H, W = vol.shape[-2:]

        gx = (traj2d_bev[..., 0] * (W / LIBERO_IMG)).long().clamp(0, W - 1)
        gy = (traj2d_bev[..., 1] * (H / LIBERO_IMG)).long().clamp(0, H - 1)
        # GT lives in the BEV-anchored slot (anchor=0). Flattened index over (Z, 2, H, W):
        #   z * (2*H*W) + anchor * (H*W) + y * W + x, with anchor = 0.
        tgt_flat = gz.clamp(0, Z - 1) * (2 * H * W) + gy * W + gx

        volume_loss = F.cross_entropy(vol.reshape(B * T, Z * 2 * H * W), tgt_flat.reshape(-1))
        gripper_loss = F.cross_entropy(grip_logits.reshape(B * T, -1), gg.reshape(-1))
        rotation_loss = F.cross_entropy(rot_logits.reshape(B * T, -1), gr.reshape(-1))
        total = (volume_loss + args.gripper_loss_weight * gripper_loss
                 + args.rotation_loss_weight * rotation_loss)
        return {"rgb_bev": rgb_bev, "rgb_wrist": rgb_wrist, "gt_pix_bev": traj2d_bev,
                "gg": gg, "gr": gr, "out": out, "vol": vol,
                "grip_logits": grip_logits, "rot_logits": rot_logits,
                "volume_loss": volume_loss, "gripper_loss": gripper_loss,
                "rotation_loss": rotation_loss, "total": total}

    def batch_metrics(res):
        """Return (pred_pix, mean BEV pixel error, gripper acc, rotation acc) for a run_batch result."""
        pred_pix = _bev_argmax_pixels(res["vol"], LIBERO_IMG)
        pix_err = (pred_pix - res["gt_pix_bev"]).norm(dim=-1).mean()
        grip_acc = (res["grip_logits"].argmax(-1) == res["gg"]).float().mean()
        rot_acc = (res["rot_logits"].argmax(-1) == res["gr"]).float().mean()
        return pred_pix, pix_err, grip_acc, rot_acc

    @torch.no_grad()
    def evaluate():
        """Average losses and metrics over the held-out split, weighted by batch size."""
        model.eval()
        sums = {"volume_loss": 0.0, "gripper_loss": 0.0, "rotation_loss": 0.0, "total": 0.0,
                "pix_argmax": 0.0, "grip_acc": 0.0, "rot_acc": 0.0}
        n_seen = 0
        try:
            for batch in val_loader:
                res = run_batch(batch)
                _, pix_err, grip_acc, rot_acc = batch_metrics(res)
                bs = res["rgb_bev"].shape[0]
                values = {"volume_loss": res["volume_loss"], "gripper_loss": res["gripper_loss"],
                          "rotation_loss": res["rotation_loss"], "total": res["total"],
                          "pix_argmax": pix_err, "grip_acc": grip_acc, "rot_acc": rot_acc}
                for key, val in values.items():
                    sums[key] += val.item() * bs
                n_seen += bs
        finally:
            model.train()
        return {f"val/{key}": s / max(n_seen, 1) for key, s in sums.items()}

    global_step = 0
    for epoch in range(args.epochs):
        model.train()
        last_losses = None   # (volume, gripper, rotation) of the latest batch, for the epoch summary
        pbar = tqdm(train_loader, desc=f"Epoch {epoch}", leave=False)
        for batch in pbar:
            res = run_batch(batch)
            total = res["total"]

            opt.zero_grad()
            total.backward()
            if args.grad_clip > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
            opt.step()

            v_loss = res["volume_loss"].item()
            g_loss = res["gripper_loss"].item()
            r_loss = res["rotation_loss"].item()
            last_losses = (v_loss, g_loss, r_loss)
            pbar.set_postfix(v=f"{v_loss:.2f}", g=f"{g_loss:.2f}", r=f"{r_loss:.2f}")
            global_step += 1

            if args.log_scalars_every > 0 and global_step % args.log_scalars_every == 0:
                # Pixel error only considers the BEV-anchored slot (anchor=0).
                with torch.no_grad():
                    _, train_pix, train_grip_acc, train_rot_acc = batch_metrics(res)
                wandb.log({"train/volume_loss":   v_loss,
                           "train/gripper_loss":  g_loss,
                           "train/rotation_loss": r_loss,
                           "train/total":         total.item(),
                           "train/pix_argmax":    train_pix.item(),
                           "train/grip_acc":      train_grip_acc.item(),
                           "train/rot_acc":       train_rot_acc.item(),
                           "epoch": epoch}, step=global_step)

            if args.vis_every_steps > 0 and global_step % args.vis_every_steps == 0:
                # Sample 0 from this batch — use its predictions + features for the panel.
                with torch.no_grad():
                    pred_pix0 = _bev_argmax_pixels(res["vol"], LIBERO_IMG)[0].cpu().numpy()
                    gt_pix0 = res["gt_pix_bev"][0].cpu().numpy()
                    panel = make_2view_panel(
                        res["rgb_bev"][0], res["rgb_wrist"][0],
                        res["out"]["pixel_feats"][0], res["out"]["pixel_feats_wrist"][0],
                        gt_pix0, pred_pix0,
                        ep_step_label=f"ep {epoch} step {global_step}",
                    )
                wandb.log({"vis/2view_panel": wandb.Image(panel)}, step=global_step)

        if last_losses is None:
            # Possible when the train split is smaller than batch_size (drop_last=True).
            print(f"Epoch {epoch}: no training batches "
                  f"(train split smaller than batch_size={args.batch_size} with drop_last=True)")
        else:
            v_loss, g_loss, r_loss = last_losses
            print(f"Epoch {epoch}: g={g_loss:.3f}  r={r_loss:.3f}  v={v_loss:.3f}")

        if args.val_every_epochs > 0 and (epoch + 1) % args.val_every_epochs == 0:
            val_logs = evaluate()
            print("  val: " + "  ".join(f"{k.split('/', 1)[1]}={v:.3f}" for k, v in val_logs.items()))
            wandb.log(val_logs, step=global_step)

        is_save = ((epoch + 1) % max(1, args.save_every_epochs) == 0) or (epoch + 1 == args.epochs)
        if is_save:
            ckpt = {"epoch": epoch, "global_step": global_step,
                    "model_state_dict": model.state_dict(),
                    "args": vars(args),
                    "rotation_mode": "1d_pca", "n_rot_bins": N_ROT_BINS,
                    "fusion_mode": args.fusion_mode,
                    "min_height": stats["min_height"], "max_height": stats["max_height"],
                    "min_grip":   stats["min_grip"],   "max_grip":   stats["max_grip"],
                    "rot_pca_mean": rot_mean_np,
                    "rot_pca_axis": rot_axis_np,
                    "rot_pca_min":  rot_pca_min, "rot_pca_max":  rot_pca_max,
                    "n_window": args.n_window, "image_size": LIBERO_IMG,
                    "bev_K_norm": full.demos[0]['bev_K_norm'],
                    "bev_extrinsic": full.demos[0]['bev_extrinsic']}
            torch.save(ckpt, ckpt_dir / "latest.pth")
    wandb.finish()
    print(f"Done. {ckpt_dir}")


# Script entry point: parse CLI args and run training only when executed directly.
if __name__ == "__main__":
    main()
