"""Train the volume AR model on libero_spatial task 0 (bowl pickup-and-place).

Loss = per-timestep voxel CE + grip BCE + per-axis rot CE.
Wandb visualization (every --vis_every_steps): per-timestep rainbow-colored predicted EEF
pixels overlaid on the current RGB frame, with lines connecting them and a faded GT polyline.
"""
import argparse, io, os, sys, time
from pathlib import Path

import cv2
import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm
import wandb

sys.path.insert(0, os.path.dirname(__file__))

from robot_volume import (
    voxel_centers_world, voxel_idx_to_world,
    N_PAST_EEF, T_FUTURE, N_ROT_BINS, MIN_ROT, MAX_ROT, IMAGE_SIZE, N_VOX,
)
from data_volume_ar import VolumeWindowDataset
from model_volume_smooth import SmoothVolumeARModel as VolumeARModel
from robot_volume import world_to_pixel_torch

# Relative weights of the three loss terms combined in the training loop:
#   loss = W_VOXEL * voxel CE + W_GRIP * gripper BCE + W_ROT * mean per-axis rotation CE
W_VOXEL = 1.0
W_GRIP  = 5.0
W_ROT   = 0.5


def discretize_rotation(euler):
    """Quantize Euler angles into integer rotation-bin indices.

    Each angle is rescaled from the [MIN_ROT, MAX_ROT] range to [0, 1]
    (out-of-range values are clamped), multiplied by ``N_ROT_BINS - 1`` and
    truncated to an integer.

    Args:
        euler: float tensor of angles; the training loop passes (B, T, 3).

    Returns:
        Long tensor of the same shape with values in [0, N_ROT_BINS - 1].
    """
    lo = torch.tensor(MIN_ROT, device=euler.device, dtype=torch.float32)
    hi = torch.tensor(MAX_ROT, device=euler.device, dtype=torch.float32)
    # The epsilon keeps the division finite if an axis has zero range.
    span = hi - lo + 1e-8
    unit = torch.clamp((euler - lo) / span, 0, 1)
    bins = (unit * (N_ROT_BINS - 1)).long()
    return torch.clamp(bins, 0, N_ROT_BINS - 1)


def rainbow_color(t, T):
    """Map timestep ``t`` in [0, T-1] to a BGR colour sweeping red → violet.

    The hue is swept over 0°..270° only. Sweeping the full hue circle would
    make the last timestep wrap back to red and become indistinguishable from
    the first one.

    Args:
        t: timestep index; values outside [0, T-1] are clamped to the ends.
        T: total number of timesteps. ``T <= 1`` always yields red.

    Returns:
        ``(b, g, r)`` tuple of Python ints in [0, 255].
    """
    import colorsys  # stdlib; imported locally so the helper stays self-contained

    max_hue = 0.75  # fraction of the hue circle: 270° == violet
    frac = t / max(T - 1, 1)
    frac = min(max(frac, 0.0), 1.0)
    r, g, b = colorsys.hsv_to_rgb(max_hue * frac, 1.0, 1.0)
    return (int(round(b * 255)), int(round(g * 255)), int(round(r * 255)))


def _drawable_points(pix, limit=1e6):
    """Convert an (N, 2) array of pixel coordinates into OpenCV-ready points.

    Returns a list with one entry per row: an ``(x, y)`` tuple of Python ints
    (truncated toward zero), or ``None`` when the row is non-finite or so far
    off-screen (|coord| >= ``limit``) that it cannot be drawn safely.
    """
    points = []
    for x, y in np.asarray(pix, dtype=np.float64).reshape(-1, 2):
        if np.isfinite(x) and np.isfinite(y) and abs(x) < limit and abs(y) < limit:
            points.append((int(x), int(y)))
        else:
            points.append(None)
    return points


def make_rainbow_overlay(rgb_chw_norm, pred_pix, gt_pix, voxel_logits=None, target_size=IMAGE_SIZE):
    """Draw predicted and ground-truth EEF pixel tracks on top of an RGB frame.

    Args:
        rgb_chw_norm: (3, H, W) ImageNet-normalized RGB tensor.
        pred_pix: (T, 2) numpy array of predicted pixel coordinates (x, y).
        gt_pix: (T, 2) numpy array of ground-truth pixel coordinates (x, y).
        voxel_logits: accepted for interface compatibility; currently unused
            (no heatmap is rendered).
        target_size: accepted for interface compatibility; currently unused.

    Returns:
        (H, W, 3) uint8 RGB image for wandb. Points that are non-finite or far
        off-screen are skipped, together with any segment touching them.
    """
    # Undo the ImageNet normalization and move to HWC uint8.
    mean = np.array([0.485, 0.456, 0.406]).reshape(3, 1, 1)
    std  = np.array([0.229, 0.224, 0.225]).reshape(3, 1, 1)
    img = rgb_chw_norm.detach().cpu().numpy() * std + mean
    img = np.clip(img.transpose(1, 2, 0), 0, 1)
    vis = (img * 255).astype(np.uint8).copy()   # .copy() gives OpenCV a contiguous buffer

    T = pred_pix.shape[0]
    gt_pts = _drawable_points(gt_pix[:T])
    pred_pts = _drawable_points(pred_pix)

    # Faded GT polyline (light grey, so channel order does not matter)
    for start, end in zip(gt_pts, gt_pts[1:]):
        if start is not None and end is not None:
            cv2.line(vis, start, end, (220, 220, 220), 2, cv2.LINE_AA)
    for pt in gt_pts:
        if pt is not None:
            cv2.circle(vis, pt, 4, (240, 240, 240), 1, cv2.LINE_AA)

    # rainbow_color returns BGR but `vis` is RGB, so reverse the channel order;
    # otherwise red and blue are swapped in the logged image.
    colors = [tuple(rainbow_color(i, T))[::-1] for i in range(T)]

    # Rainbow predicted polyline + keypoints
    for i in range(1, T):
        start, end = pred_pts[i - 1], pred_pts[i]
        if start is not None and end is not None:
            cv2.line(vis, start, end, colors[i], 2, cv2.LINE_AA)
    for i, pt in enumerate(pred_pts):
        if pt is None:
            continue
        cv2.drawMarker(vis, pt, colors[i], cv2.MARKER_CROSS, 14, 2, cv2.LINE_AA)
        cv2.putText(vis, str(i), (pt[0] + 6, pt[1] - 6),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.45, colors[i], 1, cv2.LINE_AA)
    return vis


def log_vis_batch(model, batch, device, step, voxel_centers, split="val"):
    """Run the model on the FIRST sample of a batch, build the rainbow overlay, log to wandb.

    Args:
        model: network called as ``model(rgb, past, cur, w2c, target_voxel_idx=None)``;
            its output dict must provide ``"pred_voxel_idx"``.
        batch: dict with keys ``rgb``, ``past_eef_world``, ``current_eef_world``,
            ``world_to_camera`` and ``target_eef_world``; only index 0 is used.
        device: device the inputs are moved to before the forward pass.
        step: wandb step the image is logged at.
        voxel_centers: lookup table indexed by predicted voxel index to get
            world coordinates.
        split: suffix of the wandb key (``vis/{split}_rainbow``).

    The model is put in eval mode for the forward pass and then returned to
    whatever mode it was in on entry, even if projection or logging raises.
    """
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            rgb = batch["rgb"][:1].to(device)
            past = batch["past_eef_world"][:1].to(device)
            cur = batch["current_eef_world"][:1].to(device)
            w2c = batch["world_to_camera"][:1].to(device)
            # No target indices are passed, so the model predicts on its own.
            out = model(rgb, past, cur, w2c, target_voxel_idx=None)
            pred_idx = out["pred_voxel_idx"][0]                                      # (T,)
            # Lookup voxel center world coords for each predicted index
            pred_world = voxel_centers[pred_idx]                                     # (T, 3)
            gt_world = batch["target_eef_world"][0].to(device)                       # (T, 3)
            # Project both to pixel
            pred_pix = world_to_pixel_torch(pred_world.unsqueeze(0), w2c)[0].cpu().numpy()
            gt_pix   = world_to_pixel_torch(gt_world.unsqueeze(0),   w2c)[0].cpu().numpy()
            overlay = make_rainbow_overlay(batch["rgb"][0], pred_pix, gt_pix)
        wandb.log({f"vis/{split}_rainbow": wandb.Image(overlay)}, step=step)
    finally:
        # Restore the caller's mode instead of unconditionally forcing train().
        model.train(was_training)


def _masked_mean(per_item, mask, denom):
    """Sum ``per_item`` over the entries selected by ``mask`` and divide by ``denom``."""
    return (per_item * mask).sum() / denom


def _masked_voxel_ce(v_logits, tgt_v, mask, denom):
    """Masked per-timestep voxel cross-entropy shared by the train and validation loops."""
    B, T = tgt_v.shape
    # Reshape (B, V, T) → (B*T, V) so every timestep is one classification over voxels.
    ce = F.cross_entropy(v_logits.permute(0, 2, 1).reshape(B * T, N_VOX),
                         tgt_v.reshape(-1), reduction='none')
    return _masked_mean(ce, mask, denom)


def main():
    """Train the volume AR model, validate after every epoch and write checkpoints.

    Reads hyper-parameters from the command line, splits the dataset into
    train/val with a fixed seed, optimizes the weighted sum of voxel, gripper
    and rotation losses, logs scalars and overlay images to wandb, and saves
    ``latest.pth`` every epoch plus ``best.pth`` whenever the validation voxel
    loss improves.

    Raises:
        ValueError: if the dataset is too small to yield a non-empty train
            split, or the train split yields no full batch.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--cache_root", type=str, default="/data/libero/parsed_libero")
    parser.add_argument("--benchmark", type=str, default="libero_spatial")
    parser.add_argument("--task_id", type=int, default=0)
    parser.add_argument("--max_demos", type=int, default=50)
    parser.add_argument("--frame_stride", type=int, default=3)
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument("--lr", type=float, default=2e-4)
    parser.add_argument("--epochs", type=int, default=10)
    parser.add_argument("--val_split", type=float, default=0.05)
    parser.add_argument("--num_workers", type=int, default=8)
    parser.add_argument("--vis_every_steps", type=int, default=100)
    parser.add_argument("--log_scalars_every", type=int, default=10)
    parser.add_argument("--run_name", type=str, default="volume_ar_v0_libero_spatial_t0")
    parser.add_argument("--wandb_project", type=str, default="para_libero")
    parser.add_argument("--wandb_mode", type=str, default="online")
    args = parser.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    script_dir = Path(__file__).parent
    ckpt_dir = script_dir / "checkpoints" / args.run_name
    ckpt_dir.mkdir(parents=True, exist_ok=True)

    wandb.init(project=args.wandb_project, name=args.run_name, mode=args.wandb_mode, config=vars(args))

    print("Loading dataset...")
    full = VolumeWindowDataset(args.cache_root, args.benchmark, args.task_id,
                                image_size=IMAGE_SIZE, frame_stride=args.frame_stride,
                                max_demos=args.max_demos)
    n = len(full)
    n_val = max(1, int(n * args.val_split))
    n_tr = n - n_val
    if n_tr < 1:
        raise ValueError(
            f"Dataset has {n} samples; cannot build a non-empty train split "
            f"with val_split={args.val_split}.")
    # Fixed seed keeps the train/val partition identical across runs.
    train_ds, val_ds = torch.utils.data.random_split(
        full, [n_tr, n_val], generator=torch.Generator().manual_seed(42))
    print(f"  Train: {n_tr}  Val: {n_val}")
    pin = device.type == "cuda"   # pinning is only useful for host→GPU copies
    train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True,
                              num_workers=args.num_workers, pin_memory=pin, drop_last=True)
    val_loader   = DataLoader(val_ds,   batch_size=args.batch_size, shuffle=False,
                              num_workers=args.num_workers, pin_memory=pin)
    if len(train_loader) == 0:
        # drop_last=True discards the only (partial) batch, so nothing would be trained.
        raise ValueError(
            f"Train split has {n_tr} samples, fewer than batch_size={args.batch_size}; "
            f"no training batches would be produced.")

    print("Building model...")
    model = VolumeARModel().to(device)
    n_tr_p = sum(prm.numel() for prm in model.parameters() if prm.requires_grad)
    n_tot  = sum(prm.numel() for prm in model.parameters())
    print(f"Trainable: {n_tr_p:,} / {n_tot:,}")
    wandb.config.update({"trainable_params": n_tr_p, "total_params": n_tot})

    opt = optim.AdamW((prm for prm in model.parameters() if prm.requires_grad),
                      lr=args.lr, weight_decay=1e-4)

    voxel_centers = voxel_centers_world().to(device)                      # (V, 3)

    # The val loader is not shuffled, so its first batch never changes. Fetch it
    # once on first use instead of building a new DataLoader iterator (and its
    # worker processes) at every visualization step.
    vis_val_batch = None

    best_val = float("inf")
    global_step = 0
    t_start = time.time()
    for epoch in range(args.epochs):
        model.train()
        pbar = tqdm(train_loader, desc=f"Epoch {epoch} train", leave=False)
        ep_losses = []
        for batch in pbar:
            rgb   = batch["rgb"].to(device, non_blocking=True)
            past  = batch["past_eef_world"].to(device, non_blocking=True)
            cur   = batch["current_eef_world"].to(device, non_blocking=True)
            w2c   = batch["world_to_camera"].to(device, non_blocking=True)
            tgt_v = batch["target_voxel_idx"].to(device, non_blocking=True)        # (B, T)
            tgt_g = batch["target_grip"].to(device, non_blocking=True)             # (B, T)
            tgt_e = batch["target_rot_euler"].to(device, non_blocking=True)        # (B, T, 3)
            valid = batch["valid_mask"].to(device, non_blocking=True)              # (B, T) bool
            B, T = tgt_v.shape

            # Ground-truth voxel indices are passed to the model during training.
            out = model(rgb, past, cur, w2c, target_voxel_idx=tgt_v)
            v_logits = out["voxel_logits"]                                          # (B, V, T)
            mask = valid.reshape(-1).float()                                        # (B*T,)
            # clamp_min avoids a 0/0 when a batch has no valid timestep.
            denom = mask.sum().clamp_min(1.0)
            v_loss = _masked_voxel_ce(v_logits, tgt_v, mask, denom)

            tgt_g_bin = (tgt_g > 0).float()
            g_loss = _masked_mean(
                F.binary_cross_entropy_with_logits(out["grip_logit"].reshape(-1),
                                                   tgt_g_bin.reshape(-1), reduction='none'),
                mask, denom)

            tgt_r_bin = discretize_rotation(tgt_e)                                  # (B, T, 3)
            r_losses = [
                _masked_mean(
                    F.cross_entropy(out["rot_logits"][:, :, ax].reshape(B * T, N_ROT_BINS),
                                    tgt_r_bin[:, :, ax].reshape(-1), reduction='none'),
                    mask, denom)
                for ax in range(3)
            ]
            r_loss = torch.stack(r_losses).mean()
            loss = W_VOXEL * v_loss + W_GRIP * g_loss + W_ROT * r_loss

            opt.zero_grad()
            loss.backward()
            opt.step()
            ep_losses.append(loss.item())
            pbar.set_postfix(loss=f"{loss.item():.3f}", v=f"{v_loss.item():.3f}",
                              g=f"{g_loss.item():.3f}", r=f"{r_loss.item():.3f}")
            global_step += 1
            if global_step % args.log_scalars_every == 0:
                # Per-timestep voxel-distance error (world meters) at argmax
                with torch.no_grad():
                    pred_idx = v_logits.argmax(dim=1)                              # (B, T)
                    pred_w   = voxel_centers[pred_idx]                              # (B, T, 3)
                    err_m    = (pred_w - batch["target_eef_world"].to(device)).norm(dim=-1) * valid.float()
                    err_m    = err_m.sum() / denom
                wandb.log({"train/loss": loss.item(),
                           "train/voxel_loss": v_loss.item(),
                           "train/grip_loss": g_loss.item(),
                           "train/rot_loss": r_loss.item(),
                           "train/voxel_err_m": err_m.item(),
                           "epoch": epoch}, step=global_step)
            if args.vis_every_steps > 0 and global_step % args.vis_every_steps == 0:
                if vis_val_batch is None:
                    vis_val_batch = next(iter(val_loader))
                log_vis_batch(model, vis_val_batch, device, global_step,
                               voxel_centers, split="val")
                log_vis_batch(model, batch, device, global_step, voxel_centers, split="train")

        # Validation (voxel loss and argmax position error only)
        model.eval()
        val_losses = []
        val_errs = []
        with torch.no_grad():
            for batch in val_loader:
                rgb   = batch["rgb"].to(device)
                past  = batch["past_eef_world"].to(device)
                cur   = batch["current_eef_world"].to(device)
                w2c   = batch["world_to_camera"].to(device)
                tgt_v = batch["target_voxel_idx"].to(device)
                valid = batch["valid_mask"].to(device)
                out = model(rgb, past, cur, w2c, target_voxel_idx=tgt_v)
                v_logits = out["voxel_logits"]
                mask = valid.reshape(-1).float()
                denom = mask.sum().clamp_min(1.0)
                v_loss = _masked_voxel_ce(v_logits, tgt_v, mask, denom)
                val_losses.append(v_loss.item())
                pred_idx = v_logits.argmax(dim=1)
                pred_w   = voxel_centers[pred_idx]
                err = (pred_w - batch["target_eef_world"].to(device)).norm(dim=-1)
                val_errs.append((err * valid.float()).sum().item() / denom.item())
        tr_loss = float(np.mean(ep_losses)) if ep_losses else float("inf")
        val_loss = float(np.mean(val_losses)) if val_losses else float("inf")
        val_err  = float(np.mean(val_errs)) if val_errs else float("inf")
        print(f"Epoch {epoch}: tr={tr_loss:.4f}  val={val_loss:.4f}  val_err={val_err*1000:.1f}mm")
        wandb.log({"epoch_end/train_loss": tr_loss, "epoch_end/val_voxel_loss": val_loss,
                   "epoch_end/val_voxel_err_mm": val_err * 1000, "epoch": epoch,
                   "elapsed_min": (time.time() - t_start) / 60.0}, step=global_step)

        ckpt = {"epoch": epoch, "global_step": global_step,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": opt.state_dict(),
                "val_loss": val_loss, "args": vars(args)}
        torch.save(ckpt, ckpt_dir / "latest.pth")
        if val_loss < best_val:
            best_val = val_loss
            torch.save(ckpt, ckpt_dir / "best.pth")
            print(f"  ✓ best (val_loss={val_loss:.4f})")

    wandb.finish()
    print(f"Done. Checkpoints: {ckpt_dir}")


# Start training only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
