"""Train SmoothVolumeARModelV2 on the smith300 first_mobile_collection dataset.

Uses in-memory Smith300VolumeDataset so DataLoader workers can be 0 (we don't pay the
fork-copy cost of a 3.7 GB tensor). Big batch (default 64) to saturate GPU memory.
"""
import argparse, io, os, sys, time
from pathlib import Path

import cv2
import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm
import wandb

# Put this script's directory at the front of sys.path before the project
# imports below (presumably so robot_volume, data_smith300_volume and
# model_volume_smooth_v2 resolve from alongside this file - confirm layout).
sys.path.insert(0, os.path.dirname(__file__))

from robot_volume import (
    voxel_centers_world, voxel_idx_to_world, world_to_pixel_torch,
    N_PAST_EEF, T_FUTURE, N_ROT_BINS, MIN_ROT, MAX_ROT, IMAGE_SIZE, N_VOX,
)
from data_smith300_volume import Smith300VolumeDataset
from model_volume_smooth_v2 import SmoothVolumeARModelV2

# Relative weights of the three loss terms that main() sums into the training
# loss: voxel cross-entropy, gripper binary cross-entropy, and the mean of the
# per-axis rotation-bin cross-entropies.
W_VOXEL = 1.0
W_GRIP  = 5.0
W_ROT   = 0.5


def discretize_rotation(euler):
    """Map Euler-angle values to integer rotation-bin indices.

    Each value is normalised against the MIN_ROT / MAX_ROT bounds, clamped to
    [0, 1], scaled to ``N_ROT_BINS - 1`` and truncated to an integer.

    Args:
        euler: tensor of rotation values; must broadcast against MIN_ROT and
            MAX_ROT (presumably the last dimension is the 3 Euler axes -
            confirm against robot_volume).

    Returns:
        ``torch.long`` tensor of bin indices in ``[0, N_ROT_BINS - 1]`` with
        the broadcast shape of the input.
    """
    lower = torch.tensor(MIN_ROT, device=euler.device, dtype=torch.float32)
    upper = torch.tensor(MAX_ROT, device=euler.device, dtype=torch.float32)
    # The epsilon guards against division by zero when a bound pair coincides.
    span = upper - lower + 1e-8
    unit = torch.clamp((euler - lower) / span, 0, 1)
    top_bin = N_ROT_BINS - 1
    # .long() truncates, so only values at the upper bound reach the top bin.
    bins = (unit * top_bin).long()
    return bins.clamp(0, top_bin)


def rainbow_color(t, T):
    """Return a distinct, fully saturated colour for step ``t`` out of ``T``.

    Args:
        t: step index, normally in ``[0, T - 1]``.
        T: total number of steps.

    Returns:
        A 3-tuple of Python ints in the channel order produced by
        ``cv2.COLOR_HSV2BGR`` (blue, green, red).
    """
    # OpenCV's 8-bit hue channel spans [0, 180) and is cyclic, so hue 180 is
    # the same colour as hue 0. Scaling by t / (T - 1) gave the last step hue
    # 180, i.e. the same colour as the first step. Dividing by T samples T
    # evenly spaced hues that never reach the wrap-around point; the modulo
    # keeps an out-of-range ``t`` on the wheel as well.
    hue = int(180.0 * t / max(T, 1)) % 180
    hsv_pixel = np.uint8([[[hue, 255, 255]]])
    bgr_pixel = cv2.cvtColor(hsv_pixel, cv2.COLOR_HSV2BGR)[0, 0]
    return tuple(int(channel) for channel in bgr_pixel)


def make_rainbow_overlay(rgb_chw_norm, pred_pix, gt_pix, image_size=IMAGE_SIZE):
    """Draw ground-truth and predicted pixel trajectories on top of an image.

    The ground-truth path is drawn first in light grey (line segments plus a
    small circle per step); the predicted path is drawn on top, coloured per
    step via ``rainbow_color`` with a cross marker and the step index as text.

    Args:
        rgb_chw_norm: channels-first image tensor that was normalised with
            the per-channel mean/std constants below (the commonly used
            ImageNet statistics).
        pred_pix: array indexed ``[step, xy]`` of predicted pixel positions;
            its first dimension sets the number of steps drawn.
        gt_pix: array indexed ``[step, xy]`` of ground-truth pixel positions,
            with at least as many rows as ``pred_pix``.
        image_size: accepted for interface compatibility; not used here.

    Returns:
        ``uint8`` array in height x width x channel layout with the overlay.
    """
    channel_mean = np.array([0.485, 0.456, 0.406]).reshape(3, 1, 1)
    channel_std = np.array([0.229, 0.224, 0.225]).reshape(3, 1, 1)

    # Undo the normalisation, move channels last and convert to 8-bit.
    denorm = rgb_chw_norm.detach().cpu().numpy() * channel_std + channel_mean
    hwc = np.clip(denorm.transpose(1, 2, 0), 0, 1)
    vis = (hwc * 255).astype(np.uint8).copy()  # copy -> contiguous, writable canvas

    def as_point(xy):
        # OpenCV drawing calls want integer pixel coordinates.
        return tuple(np.int32(xy))

    n_steps = pred_pix.shape[0]
    gt_line_color = (220, 220, 220)
    gt_dot_color = (240, 240, 240)

    # Ground truth: segments first, then a circle at every step.
    for step in range(1, n_steps):
        cv2.line(vis, as_point(gt_pix[step - 1]), as_point(gt_pix[step]),
                 gt_line_color, 2, cv2.LINE_AA)
    for step in range(n_steps):
        cv2.circle(vis, as_point(gt_pix[step]), 4, gt_dot_color, 1, cv2.LINE_AA)

    # Prediction: one colour per step, drawn over the ground truth.
    step_colors = [rainbow_color(step, n_steps) for step in range(n_steps)]
    for step in range(1, n_steps):
        cv2.line(vis, as_point(pred_pix[step - 1]), as_point(pred_pix[step]),
                 step_colors[step], 2, cv2.LINE_AA)
    for step, color in enumerate(step_colors):
        cv2.drawMarker(vis, as_point(pred_pix[step]), color, cv2.MARKER_CROSS,
                       14, 2, cv2.LINE_AA)
        label_origin = (int(pred_pix[step, 0]) + 6, int(pred_pix[step, 1]) - 6)
        cv2.putText(vis, str(step), label_origin,
                    cv2.FONT_HERSHEY_SIMPLEX, 0.45, color, 1, cv2.LINE_AA)
    return vis


def log_vis_batch(model, batch, device, step, voxel_centers, split):
    """Log a trajectory overlay for the first sample of ``batch`` to wandb.

    The model is run without target voxel indices (``target_voxel_idx=None``)
    and its ``pred_voxel_idx`` output is mapped to world positions through
    ``voxel_centers``. Predicted and ground-truth positions are projected to
    pixels with ``world_to_pixel_torch`` and drawn with
    ``make_rainbow_overlay``.

    Args:
        model: the network; called as ``model(rgb, past, cur, w2c, ...)``.
        batch: dict with keys ``rgb``, ``past_eef_world``,
            ``current_eef_world``, ``world_to_camera`` and
            ``target_eef_world``. Only the first sample is used.
        device: device the inputs are moved to.
        step: wandb step the image is logged at.
        voxel_centers: tensor indexed by voxel id that yields world positions.
        split: label used in the wandb key ``vis/{split}_rainbow``.

    The model's train/eval mode is restored to whatever it was on entry, even
    if visualisation raises, so calling this from an evaluation context does
    not silently switch the model back to training mode.
    """
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            # Slice with [:1] to keep the batch dimension the model expects.
            rgb = batch["rgb"][:1].to(device)
            past = batch["past_eef_world"][:1].to(device)
            cur = batch["current_eef_world"][:1].to(device)
            w2c = batch["world_to_camera"][:1].to(device)
            out = model(rgb, past, cur, w2c, target_voxel_idx=None)
            pred_idx = out["pred_voxel_idx"][0]
            pred_world = voxel_centers[pred_idx]
            gt_world = batch["target_eef_world"][0].to(device)
            pred_pix = world_to_pixel_torch(pred_world.unsqueeze(0), w2c)[0].cpu().numpy()
            gt_pix = world_to_pixel_torch(gt_world.unsqueeze(0), w2c)[0].cpu().numpy()
            overlay = make_rainbow_overlay(batch["rgb"][0], pred_pix, gt_pix)
        wandb.log({f"vis/{split}_rainbow": wandb.Image(overlay)}, step=step)
    finally:
        # Restore the caller's mode rather than forcing train mode.
        model.train(was_training)


def _masked_voxel_ce_sum(v_logits, tgt_v, mask):
    """Sum the per-timestep voxel cross-entropy over valid timesteps.

    Args:
        v_logits: voxel logits indexed (B, V, T).
        tgt_v: target voxel indices of shape (B, T).
        mask: flattened (B*T,) float mask; 1.0 marks a valid timestep.

    Returns:
        Scalar tensor holding the *sum* (not the mean) of the masked
        cross-entropy terms, so the caller picks the normaliser.
    """
    B, T = tgt_v.shape
    per_step = F.cross_entropy(v_logits.permute(0, 2, 1).reshape(B * T, N_VOX),
                               tgt_v.reshape(-1), reduction='none')
    return (per_step * mask).sum()


def _run_validation(model, val_loader, voxel_centers, device):
    """Compute voxel loss and position error over the whole validation loader.

    The model is called with the ground-truth ``target_voxel_idx``, exactly as
    in the training step. Both metrics are averaged over every valid timestep
    in the split, so a smaller final batch (or a batch with few valid steps)
    is not over-weighted the way a mean of per-batch means would be.

    Returns:
        ``(val_loss, val_err)`` as Python floats; ``(inf, inf)`` when the
        loader yields no batches. Leaves the model in eval mode.
    """
    model.eval()
    loss_sum = 0.0
    err_sum = 0.0
    n_valid = 0.0
    saw_batch = False
    with torch.no_grad():
        for batch in val_loader:
            saw_batch = True
            rgb = batch["rgb"].to(device)
            past = batch["past_eef_world"].to(device)
            cur = batch["current_eef_world"].to(device)
            w2c = batch["world_to_camera"].to(device)
            tgt_v = batch["target_voxel_idx"].to(device)
            valid = batch["valid_mask"].to(device)

            out = model(rgb, past, cur, w2c, target_voxel_idx=tgt_v)
            v_logits = out["voxel_logits"]
            mask = valid.reshape(-1).float()
            loss_sum += _masked_voxel_ce_sum(v_logits, tgt_v, mask).item()

            pred_w = voxel_centers[v_logits.argmax(dim=1)]
            err = (pred_w - batch["target_eef_world"].to(device)).norm(dim=-1)
            err_sum += (err * valid.float()).sum().item()
            n_valid += mask.sum().item()

    if not saw_batch:
        return float("inf"), float("inf")
    # Same guard as the training step: avoid 0/0 when nothing is valid.
    denom = max(n_valid, 1.0)
    return loss_sum / denom, err_sum / denom


def main():
    """Train SmoothVolumeARModelV2 and checkpoint it after every epoch.

    Steps: parse CLI arguments, start a wandb run, load the in-memory dataset
    and split it into train/val with a fixed seed, then for each epoch run the
    training loop (weighted sum of voxel, gripper and rotation losses), run
    validation, and write ``latest.pth`` plus ``best.pth`` (lowest validation
    voxel loss) under ``checkpoints/<run_name>`` next to this script.

    Raises:
        ValueError: if the dataset has fewer than 2 samples, since both splits
            need at least one sample.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--root_dir", type=str, default="/data/cameron/mac_robot_datasets/first_mobile_collection")
    parser.add_argument("--frame_stride", type=int, default=1)
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--lr", type=float, default=5e-4)
    parser.add_argument("--epochs", type=int, default=60)
    parser.add_argument("--val_split", type=float, default=0.05)
    parser.add_argument("--vis_every_steps", type=int, default=50)
    parser.add_argument("--log_scalars_every", type=int, default=10)
    parser.add_argument("--run_name", type=str, default="smith300_volume_smooth_v2")
    parser.add_argument("--wandb_project", type=str, default="para_libero")
    parser.add_argument("--wandb_mode", type=str, default="online")
    args = parser.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    script_dir = Path(__file__).parent
    ckpt_dir = script_dir / "checkpoints" / args.run_name
    ckpt_dir.mkdir(parents=True, exist_ok=True)

    wandb.init(project=args.wandb_project, name=args.run_name, mode=args.wandb_mode, config=vars(args))

    print("Loading in-memory dataset...")
    full = Smith300VolumeDataset(root_dir=args.root_dir, image_size=IMAGE_SIZE,
                                 frame_stride=args.frame_stride)
    n = len(full)
    if n < 2:
        raise ValueError(
            f"Dataset at {args.root_dir} has {n} sample(s); need at least 2 "
            "to form non-empty train and val splits.")
    # Always hold out at least one validation sample.
    n_val = max(1, int(n * args.val_split))
    n_tr = n - n_val
    # Fixed seed keeps the train/val membership identical across runs.
    train_ds, val_ds = torch.utils.data.random_split(
        full, [n_tr, n_val], generator=torch.Generator().manual_seed(42))
    print(f"  Train: {n_tr}  Val: {n_val}")

    # num_workers=0: the dataset is fully in memory, so worker processes would
    # each hold their own copy of the ~3.7 GB tensor and exhaust RAM.
    # Pinned host memory only helps host-to-GPU copies, so skip it on CPU.
    pin = device.type == "cuda"
    train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True,
                              num_workers=0, pin_memory=pin, drop_last=True)
    val_loader = DataLoader(val_ds, batch_size=args.batch_size, shuffle=False,
                            num_workers=0, pin_memory=pin)

    print("Building model...")
    model = SmoothVolumeARModelV2().to(device)
    n_t = sum(w.numel() for w in model.parameters() if w.requires_grad)
    n_a = sum(w.numel() for w in model.parameters())
    print(f"Trainable: {n_t:,} / {n_a:,}")
    wandb.config.update({"trainable_params": n_t, "total_params": n_a})

    # Only hand trainable parameters to the optimiser.
    opt = optim.AdamW([w for w in model.parameters() if w.requires_grad],
                      lr=args.lr, weight_decay=1e-4)

    voxel_centers = voxel_centers_world().to(device)

    best_val = float("inf")
    global_step = 0
    t_start = time.time()
    for epoch in range(args.epochs):
        model.train()
        pbar = tqdm(train_loader, desc=f"Epoch {epoch} train", leave=False)
        ep_losses = []
        for batch in pbar:
            rgb   = batch["rgb"].to(device, non_blocking=True)
            past  = batch["past_eef_world"].to(device, non_blocking=True)
            cur   = batch["current_eef_world"].to(device, non_blocking=True)
            w2c   = batch["world_to_camera"].to(device, non_blocking=True)
            tgt_v = batch["target_voxel_idx"].to(device, non_blocking=True)
            tgt_g = batch["target_grip"].to(device, non_blocking=True)
            tgt_e = batch["target_rot_euler"].to(device, non_blocking=True)
            valid = batch["valid_mask"].to(device, non_blocking=True)
            B, T = tgt_v.shape

            out = model(rgb, past, cur, w2c, target_voxel_idx=tgt_v)
            v_logits = out["voxel_logits"]                                  # (B, V, T)
            mask = valid.reshape(-1).float()
            # Every loss is a mean over valid timesteps; clamp avoids 0/0.
            denom = mask.sum().clamp_min(1.0)

            v_loss = _masked_voxel_ce_sum(v_logits, tgt_v, mask) / denom

            # Gripper target is binarised by sign before the BCE.
            tgt_g_bin = (tgt_g > 0).float()
            g_loss = (F.binary_cross_entropy_with_logits(
                out["grip_logit"].reshape(-1), tgt_g_bin.reshape(-1),
                reduction='none') * mask).sum() / denom

            # One classification loss per rotation axis, then averaged.
            tgt_r_bin = discretize_rotation(tgt_e)
            r_losses = [
                (F.cross_entropy(out["rot_logits"][:, :, ax].reshape(B * T, N_ROT_BINS),
                                 tgt_r_bin[:, :, ax].reshape(-1),
                                 reduction='none') * mask).sum() / denom
                for ax in range(3)
            ]
            r_loss = torch.stack(r_losses).mean()
            loss = W_VOXEL * v_loss + W_GRIP * g_loss + W_ROT * r_loss

            opt.zero_grad()
            loss.backward()
            opt.step()

            # Read each scalar once; every .item() forces a device sync.
            loss_val = loss.item()
            v_val = v_loss.item()
            g_val = g_loss.item()
            r_val = r_loss.item()
            ep_losses.append(loss_val)
            pbar.set_postfix(loss=f"{loss_val:.3f}", v=f"{v_val:.3f}",
                             g=f"{g_val:.3f}", r=f"{r_val:.3f}")
            global_step += 1

            if global_step % args.log_scalars_every == 0:
                with torch.no_grad():
                    # Distance between the argmax voxel centre and the target.
                    pred_idx = v_logits.argmax(dim=1)
                    pred_w = voxel_centers[pred_idx]
                    err_m = (pred_w - batch["target_eef_world"].to(device)).norm(dim=-1) * valid.float()
                    err_m = err_m.sum() / denom
                wandb.log({"train/loss": loss_val, "train/voxel_loss": v_val,
                           "train/grip_loss": g_val, "train/rot_loss": r_val,
                           "train/voxel_err_m": err_m.item(), "epoch": epoch}, step=global_step)

            if args.vis_every_steps > 0 and global_step % args.vis_every_steps == 0:
                # A fresh iterator over the unshuffled val loader yields its
                # first batch each time.
                log_vis_batch(model, next(iter(val_loader)), device, global_step,
                              voxel_centers, "val")
                log_vis_batch(model, batch, device, global_step, voxel_centers, "train")

        # Validation (leaves the model in eval mode; the next epoch switches
        # it back to train mode).
        val_loss, val_err = _run_validation(model, val_loader, voxel_centers, device)

        tr_loss = float(np.mean(ep_losses)) if ep_losses else float("inf")
        print(f"Epoch {epoch}: tr={tr_loss:.4f}  val={val_loss:.4f}  val_err={val_err*1000:.1f}mm")
        wandb.log({"epoch_end/train_loss": tr_loss, "epoch_end/val_voxel_loss": val_loss,
                   "epoch_end/val_voxel_err_mm": val_err * 1000, "epoch": epoch,
                   "elapsed_min": (time.time() - t_start) / 60.0}, step=global_step)

        ckpt = {"epoch": epoch, "global_step": global_step,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": opt.state_dict(),
                "val_loss": val_loss, "args": vars(args)}
        torch.save(ckpt, ckpt_dir / "latest.pth")
        if val_loss < best_val:
            best_val = val_loss
            torch.save(ckpt, ckpt_dir / "best.pth")
            print(f"  ✓ best (val_loss={val_loss:.4f})")

    wandb.finish()
    print(f"Done. Checkpoints: {ckpt_dir}")


# Run training only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
