"""Train point-track heatmap predictor on RTX tracks: first frame -> predict next N_WINDOW heatmaps for N_QUERY_POINTS."""
import torch
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
import numpy as np
from pathlib import Path
from tqdm import tqdm
import argparse
import os
import sys
import tempfile

sys.path.insert(0, os.path.dirname(__file__))

from data import (
    RTXPointTrackDataset,
    N_WINDOW_POINT_TRACK,
    N_QUERY_POINTS,
    HEATMAP_SIZE,
    MOTION_PERCENTILE,
)
from model import PointTrackHeatmapPredictor

# ---------------------------------------------------------------------------
# Training configuration. BATCH_SIZE, LEARNING_RATE, NUM_EPOCHS and VIS_INTERVAL
# are used as argparse defaults in main(), so they can be overridden from the CLI.
# ---------------------------------------------------------------------------
BATCH_SIZE = 4
LEARNING_RATE = 1e-4
NUM_EPOCHS = 30

# How many query points get a pred/target heatmap strip and a colored circle on the GIF.
N_VIS_QUERIES = 5
# Visualize every N training iterations (checked inside the train loop); 0 disables it.
VIS_INTERVAL = 50
# One (R, G, B) color per visualized query: red, green, blue, yellow, magenta.
COLORS_5 = [
    (255, 0, 0),
    (0, 255, 0),
    (0, 0, 255),
    (255, 255, 0),
    (255, 0, 255),
]


def collate_rtx_point_track(batch):
    """Merge a list of per-sample dicts into a single batch dict.

    Fixed-shape tensor fields are stacked along a new leading batch dimension.
    ``frames_vis``, ``tracks_vis`` and ``pt_name`` are kept as plain Python lists,
    because the visualization frames/tracks may have a different H, W per sample.
    """
    stacked_keys = ("rgb", "query_start_2d", "target_heatmap_indices", "visibility")
    listed_keys = ("frames_vis", "tracks_vis", "pt_name")
    collated = {key: torch.stack([sample[key] for sample in batch]) for key in stacked_keys}
    collated.update({key: [sample[key] for sample in batch] for key in listed_keys})
    return collated


def compute_heatmap_loss(logits, target_indices, visibility):
    """Cross-entropy between heatmap logits and flat target cell indices, masked by visibility.

    Each (H, W) heatmap is treated as one H*W-way classification problem
    (64x64 with this project's HEATMAP_SIZE).

    Args:
        logits: (B, Q, T, H, W) unnormalized scores.
        target_indices: (B, Q, T) flat target cell index in [0, H*W - 1]; out-of-range
            values are clamped into that range. Any integer (or integral float) dtype.
        visibility: (B, Q, T) mask; only truthy entries contribute. Any dtype is
            accepted (bool, uint8, 0/1 float) and is converted to bool.

    Returns:
        Scalar tensor: mean cross-entropy over visible entries. When nothing is visible,
        returns a zero that is still connected to ``logits``, so ``backward()`` works.
    """
    B, Q, T, H, W = logits.shape
    num_cells = H * W
    logits_flat = logits.reshape(B, Q, T, num_cells)  # (B, Q, T, H*W)
    # Boolean-mask indexing needs a real bool mask; float masks would raise an error.
    mask = visibility.to(torch.bool)  # (B, Q, T)
    if not mask.any():
        # Keep the autograd graph connected so callers can call loss.backward() unconditionally.
        return logits_flat.sum() * 0.0
    logits_masked = logits_flat[mask]  # (N, H*W)
    # cross_entropy expects int64 class indices.
    targets_masked = target_indices.clamp(0, num_cells - 1).long()[mask]  # (N,)
    return F.cross_entropy(logits_masked, targets_masked, reduction="mean")


def build_gt_tracks_gif(frames_vis, tracks_vis, query_indices_to_circle, colors):
    """Draw GT tracks on frames and circle the selected query points in distinct colors.

    On frame t, each query's segment from its position at t to its position at t+1 is
    drawn: gray by default, or in the query's highlight color. Then every highlighted
    query gets a circle at its position on every frame where it lies inside the image.
    Points with non-finite coordinates are skipped.

    Args:
        frames_vis: (T, H, W, 3) uint8 frames (array-like). Copied; the input is not modified.
        tracks_vis: (T, Q, 2) float (x, y) positions in original image coordinates.
        query_indices_to_circle: list of query indices to highlight (e.g. [0, 1, 2, 3, 4]).
        colors: list of (R, G, B) tuples; ``colors[i]`` is used for
            ``query_indices_to_circle[i]``. Highlighted indices beyond ``len(colors)``,
            or outside [0, Q), are ignored and drawn like any other query.

    Returns:
        (T, H, W, 3) uint8 numpy array.
    """
    import cv2

    out = np.array(frames_vis, dtype=np.uint8).copy()
    T, H, W, _ = out.shape
    tracks = np.asarray(tracks_vis, dtype=np.float64)
    Q = tracks.shape[1]
    gray = (128, 128, 128)

    # Resolve query -> highlight color once. zip() stops at len(colors), so the line loop
    # and the circle loop apply the same limit (the original line loop could IndexError).
    highlight = {}
    for q, color in zip(query_indices_to_circle, colors):
        if 0 <= q < Q:
            highlight.setdefault(q, color)  # first occurrence wins, as with list.index()

    def _pixel(t, q):
        """Rounded integer (x, y) of query q at frame t, or None if non-finite."""
        x, y = tracks[t, q, 0], tracks[t, q, 1]
        if not (np.isfinite(x) and np.isfinite(y)):
            return None
        return int(round(x)), int(round(y))

    for t in range(T - 1):
        for q in range(Q):
            pt0, pt1 = _pixel(t, q), _pixel(t + 1, q)
            if pt0 is None or pt1 is None:
                continue
            cv2.line(out[t], pt0, pt1, highlight.get(q, gray), 2)

    # Circle the highlighted queries on every frame so they stay visible throughout the GIF.
    for t in range(T):
        for q, color in highlight.items():
            pt = _pixel(t, q)
            if pt is not None and 0 <= pt[0] < W and 0 <= pt[1] < H:
                cv2.circle(out[t], pt, 8, color, 2)
    return out


def indices_to_onehot_heatmap(indices, size):
    """Turn a (Q, T) array of flat cell indices into (Q, T, size, size) one-hot float32 maps.

    Each index is truncated to int and clamped into [0, size*size - 1], then decoded
    row-major as (row, col) = divmod(index, size).
    """
    n_queries, n_steps = indices.shape
    last_cell = size * size - 1
    heatmaps = np.zeros((n_queries, n_steps, size, size), dtype=np.float32)
    for q, t in np.ndindex(n_queries, n_steps):
        flat = min(max(int(indices[q, t].item()), 0), last_cell)
        row, col = divmod(flat, size)
        heatmaps[q, t, row, col] = 1.0
    return heatmaps


def run_visualization(model, device, batch, global_step, use_wandb):
    """Log a GT-tracks GIF and pred/target heatmaps for the first sample of ``batch`` to wandb.

    Only the first sample of ``batch`` (a dict from ``collate_rtx_point_track``) is used.
    The following are logged at ``step=global_step``:
      * ``val/gt_tracks_gif``: GT tracks with the first
        ``min(N_VIS_QUERIES, N_QUERY_POINTS)`` queries circled. If the GIF could not be
        written, ``val/gt_tracks_gif_note`` (the first frame) is logged instead.
      * ``val/heatmaps_query{q}``: predicted softmax heatmaps (top row, one tile per
        window step) above the one-hot target heatmaps (bottom row).

    This is a no-op when ``use_wandb`` is False, since nothing would be logged. The
    model's previous train/eval mode is restored, and the temporary GIF file is removed,
    even if an error occurs.
    """
    if not use_wandb:
        return

    import wandb
    from PIL import Image

    was_training = model.training
    model.eval()
    gif_path = None
    try:
        rgb = batch["rgb"][:1].to(device)
        query_start_2d = batch["query_start_2d"][:1].to(device)
        target_indices = batch["target_heatmap_indices"][:1]
        frames_vis = batch["frames_vis"][0]
        tracks_vis = batch["tracks_vis"][0]
        with torch.no_grad():
            logits = model(rgb, query_start_2d)

        query_indices_to_circle = list(range(min(N_VIS_QUERIES, N_QUERY_POINTS)))
        gt_gif = build_gt_tracks_gif(
            frames_vis.numpy(), tracks_vis.numpy(), query_indices_to_circle, COLORS_5
        )
        # Softmax over all cells of each (query, step) heatmap. reshape (not view) so that
        # non-contiguous model outputs are accepted.
        pred_probs = F.softmax(
            logits.reshape(1, N_QUERY_POINTS, N_WINDOW_POINT_TRACK, -1), dim=-1
        )
        pred_probs = (
            pred_probs.reshape(1, N_QUERY_POINTS, N_WINDOW_POINT_TRACK, HEATMAP_SIZE, HEATMAP_SIZE)
            .cpu()
            .numpy()
        )
        target_heatmaps = indices_to_onehot_heatmap(target_indices[0].numpy(), HEATMAP_SIZE)

        # Save the GIF to a temp file so wandb.Video(path) avoids a moviepy encode of a raw array.
        with tempfile.NamedTemporaryFile(suffix=".gif", delete=False) as f:
            gif_path = f.name
        try:
            # gt_gif: (T, H, W, 3) uint8; PIL duration is in ms, so 500 ms per frame = 2 fps.
            frames = [Image.fromarray(frame) for frame in gt_gif]
            frames[0].save(gif_path, save_all=True, append_images=frames[1:], duration=500, loop=0)
            vis_log = {"val/gt_tracks_gif": wandb.Video(gif_path, fps=2, format="gif")}
        except Exception as e:
            # Best effort: fall back to the first frame, with the error in the caption.
            vis_log = {"val/gt_tracks_gif_note": wandb.Image(gt_gif[0], caption=f"First frame (video save failed: {e})")}

        for q in query_indices_to_circle:
            pred_q = pred_probs[0, q]
            targ_q = target_heatmaps[q]
            # Tile the window steps horizontally: predictions on top, targets below.
            row_pred = np.concatenate([pred_q[t] for t in range(N_WINDOW_POINT_TRACK)], axis=1)
            row_targ = np.concatenate([targ_q[t] for t in range(N_WINDOW_POINT_TRACK)], axis=1)
            combined = np.concatenate([row_pred, row_targ], axis=0)
            combined = np.clip(combined * 255, 0, 255).astype(np.uint8)
            combined = np.repeat(combined[:, :, np.newaxis], 3, axis=-1)
            vis_log[f"val/heatmaps_query{q}"] = wandb.Image(combined, caption=f"Query {q}: pred (top) target (bottom)")
        wandb.log(vis_log, step=global_step)
    finally:
        # Remove the temp GIF only after wandb.log() has returned (or on any earlier
        # failure), so it never leaks.
        if gif_path is not None:
            try:
                os.unlink(gif_path)
            except OSError:
                pass
        model.train(was_training)


def train_epoch(model, loader, optimizer, device, global_step, val_ds, vis_batch_size, collate_fn, vis_interval, use_wandb):
    """Train for one epoch, periodically visualizing one random validation sample.

    Args:
        model, loader, optimizer, device: usual training objects.
        global_step: iteration count before this epoch; the per-iteration step is
            ``global_step + batch_idx``.
        val_ds: dataset sampled for visualization.
        vis_batch_size: kept for backward compatibility. Visualization only reads the
            first sample of the batch it is given, so exactly one validation sample is
            loaded and this value is no longer used.
        collate_fn: turns a list of samples into a batch (e.g. ``collate_rtx_point_track``).
        vis_interval: visualize when ``step % vis_interval == 0``; 0 disables it.
        use_wandb: forwarded to ``run_visualization``.

    Returns:
        (avg_loss, updated_global_step): ``avg_loss`` is the sample-weighted mean of the
        per-batch losses; ``updated_global_step`` is ``global_step + len(loader)``.
    """
    model.train()
    total_loss = 0.0
    n = 0
    n_val = len(val_ds)
    for batch_idx, batch in enumerate(tqdm(loader, desc="Train")):
        step = global_step + batch_idx
        rgb = batch["rgb"].to(device)
        query_start_2d = batch["query_start_2d"].to(device)
        target_indices = batch["target_heatmap_indices"].to(device)
        visibility = batch["visibility"].to(device)

        logits = model(rgb, query_start_2d)
        loss = compute_heatmap_loss(logits, target_indices, visibility)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_size = rgb.shape[0]
        total_loss += loss.item() * batch_size
        n += batch_size

        if vis_interval > 0 and step % vis_interval == 0 and n_val > 0:
            # Pick a random val sample so we don't always log the same one. Only one is
            # loaded because run_visualization uses just the first sample of its batch.
            vis_idx = int(np.random.randint(n_val))
            vis_batch = collate_fn([val_ds[vis_idx]])
            run_visualization(model, device, vis_batch, step + 1, use_wandb)

    return total_loss / max(n, 1), global_step + len(loader)


@torch.no_grad()
def validate(model, loader, device):
    """Return the sample-weighted mean heatmap loss of ``model`` over ``loader``.

    Switches the model to eval mode and leaves it there. Gradients are disabled by the
    ``torch.no_grad`` decorator. An empty loader yields 0.0.
    """
    model.eval()
    weighted_sum = 0.0
    seen = 0
    field_names = ("rgb", "query_start_2d", "target_heatmap_indices", "visibility")
    for batch in loader:
        rgb, queries, targets, visible = (batch[name].to(device) for name in field_names)
        batch_loss = compute_heatmap_loss(model(rgb, queries), targets, visible)
        batch_size = rgb.shape[0]
        weighted_sum += batch_loss.item() * batch_size
        seen += batch_size
    return weighted_sum / seen if seen else 0.0


def main():
    """CLI entry point: build data and model, then train, validate and checkpoint each epoch.

    Checkpoints are written to ``point_track_pretraining/checkpoints/<run_name>/``
    (relative to the current working directory): ``latest.pth`` after every epoch, and
    ``best.pth`` whenever the validation loss improves.

    Raises:
        ValueError: if the dataset is too small (or ``--val_frac`` too large) to leave
            at least one training sample after the validation split.
    """
    parser = argparse.ArgumentParser(description="Point-track pretraining on RTX tracks")
    parser.add_argument("--tracks_root", type=str, default="/data/RTX/tracks", help="Directory of .pt track files")
    parser.add_argument("--batch_size", type=int, default=BATCH_SIZE)
    parser.add_argument("--lr", type=float, default=LEARNING_RATE)
    parser.add_argument("--epochs", type=int, default=NUM_EPOCHS)
    parser.add_argument("--run_name", type=str, default="point_track_pretraining")
    parser.add_argument("--max_train", type=int, default=None, help="Cap train samples")
    parser.add_argument("--val_frac", type=float, default=0.05)
    parser.add_argument("--wandb", action="store_true", help="Log to wandb")
    parser.add_argument("--freeze_backbone", action="store_true")
    parser.add_argument("--vis_interval", type=int, default=VIS_INTERVAL, help="Visualize every N iterations (0=disabled)")
    parser.add_argument("--motion_percentile", type=int, default=MOTION_PERCENTILE, help="Use tracks with motion >= this percentile (85=top 15%%); 0 to disable filter")
    args = parser.parse_args()

    checkpoint_dir = Path("point_track_pretraining/checkpoints") / args.run_name
    checkpoint_dir.mkdir(parents=True, exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Dataset
    full_ds = RTXPointTrackDataset(
        tracks_root=args.tracks_root,
        image_size=448,
        n_window=N_WINDOW_POINT_TRACK,
        n_query=N_QUERY_POINTS,
        heatmap_size=HEATMAP_SIZE,
        max_samples=args.max_train,
        motion_percentile=args.motion_percentile if args.motion_percentile > 0 else None,
    )
    n_total = len(full_ds)
    n_val = max(1, int(n_total * args.val_frac))
    n_train = n_total - n_val
    if n_train < 1:
        # Fail fast: otherwise the split/loaders either error obscurely (empty dataset)
        # or the run silently trains on nothing (one sample, or val_frac >= 1).
        raise ValueError(
            f"Not enough samples to split: {n_total} total, {n_val} reserved for validation "
            f"(val_frac={args.val_frac}); need at least one training sample."
        )
    train_ds, val_ds = torch.utils.data.random_split(
        full_ds, [n_train, n_val], generator=torch.Generator().manual_seed(42)
    )
    train_loader = DataLoader(
        train_ds, batch_size=args.batch_size, shuffle=True, num_workers=16, collate_fn=collate_rtx_point_track
    )
    val_loader = DataLoader(
        val_ds, batch_size=args.batch_size, shuffle=False, num_workers=8, collate_fn=collate_rtx_point_track
    )

    # Model
    model = PointTrackHeatmapPredictor(
        target_size=448,
        n_window=N_WINDOW_POINT_TRACK,
        n_query=N_QUERY_POINTS,
        heatmap_size=HEATMAP_SIZE,
        freeze_backbone=args.freeze_backbone,
    )
    model = model.to(device)
    optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=1e-4)

    # wandb is imported only when requested; None means "logging disabled".
    wandb = None
    if args.wandb:
        import wandb
        wandb.init(project="point_track_pretraining", name=args.run_name, config=vars(args))

    best_val = float("inf")
    global_step = 0
    for epoch in range(args.epochs):
        train_loss, global_step = train_epoch(
            model, train_loader, optimizer, device,
            global_step, val_ds, args.batch_size, collate_rtx_point_track,
            args.vis_interval, args.wandb,
        )
        val_loss = validate(model, val_loader, device)
        print(f"Epoch {epoch}  train_loss={train_loss:.4f}  val_loss={val_loss:.4f}")

        if wandb is not None:
            wandb.log({"train_loss": train_loss, "val_loss": val_loss, "epoch": epoch}, step=global_step)

        # Build the checkpoint once; the same content goes to latest.pth and, if improved, best.pth.
        checkpoint = {
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "val_loss": val_loss,
        }
        torch.save(checkpoint, checkpoint_dir / "latest.pth")
        if val_loss < best_val:
            best_val = val_loss
            torch.save(checkpoint, checkpoint_dir / "best.pth")
            print(f"  Saved best (val_loss={val_loss:.4f})")

    print(f"Done. Best val loss: {best_val:.4f}")
    if wandb is not None:
        # Close the run explicitly once training has completed.
        wandb.finish()


if __name__ == "__main__":
    # Start training only when executed as a script, not when imported as a module.
    main()
