"""Train student to match teacher DINO features. Optional wandb with --log_wandb. PCA vis of GT vs pred."""

import argparse
import sys
import time
from pathlib import Path

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision.utils import make_grid
import numpy as np
from tqdm import tqdm

vidgen_root = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(vidgen_root))

from dino_vid_model.model import StudentDinoVideo  # TeacherDinoVideo commented out (teacher not used)
from dino_vid_model.dataset import DroidVideoDataset, CachedClipDataset, collate_batch

NUM_FRAMES = 8


def pca_vis(feats, n_components=3, fit_ref=None):
    """Project per-pixel features to an RGB image via PCA for visualization.

    Args:
        feats: (B, T, D, H, W) feature tensor with D >= n_components.
        n_components: number of principal components kept (3 -> RGB).
        fit_ref: optional (B', T', D, H', W') tensor. When given, the PCA
            basis (mean + components) is fit on it instead of on ``feats``,
            so e.g. predictions can be projected in the GT feature basis.

    Returns:
        (B, T, n_components, H, W) float CPU tensor, min-max normalized to
        [0, 1] jointly over the whole batch/video. A degenerate (constant)
        projection maps to all zeros instead of leaking raw values.
    """
    B, T, D, H, W = feats.shape

    def _flatten(t):
        # (B, T, D, H, W) -> (N, D) with the feature dim last, as float64 numpy.
        return t.permute(0, 1, 3, 4, 2).reshape(-1, t.shape[2]).double().cpu().numpy()

    x = _flatten(feats)
    # Fit the basis on the reference cloud if provided, else on feats itself.
    ref = _flatten(fit_ref) if fit_ref is not None else x
    mean = ref.mean(axis=0)
    # Right singular vectors of the centered reference = principal directions.
    _, _, Vh = np.linalg.svd(ref - mean, full_matrices=False)
    out = (x - mean) @ Vh[:n_components].T

    # (N, C) -> (B, T, C, H, W)
    out = out.reshape(B, T, H, W, n_components).transpose(0, 1, 4, 2, 3)
    lo, hi = out.min(), out.max()
    out = out - lo  # constant projections become exactly zero
    if hi > lo:
        out = out / (hi - lo)
    return torch.from_numpy(out).float()


def _parse_args():
    """CLI arguments for the training run."""
    p = argparse.ArgumentParser()
    p.add_argument("--keygrip", type=str, default="/data/cameron/keygrip")
    p.add_argument("--data-root", type=str, default="/data/weiduoyuan/droid_raw/1.0.1")
    p.add_argument("--cache-dir", type=str, default=None, help="Use pre-extracted .pt clips (run precache_clips.py first) for fast loading")
    p.add_argument("--batch-size", type=int, default=4)
    p.add_argument("--workers", type=int, default=8, help="DataLoader workers (increase for large batch)")
    p.add_argument("--lr", type=float, default=1e-4)
    p.add_argument("--steps", type=int, default=10000)
    p.add_argument("--vis-every", type=int, default=100)
    p.add_argument("--device", type=str, default="cuda")
    p.add_argument("--log_wandb", action="store_true", help="Log to wandb")
    p.add_argument("--checkpoint-dir", type=str, default="checkpoints", help="Save student checkpoint every --checkpoint-every steps")
    p.add_argument("--checkpoint-every", type=int, default=1000, help="Save interval when --checkpoint-dir is set")
    p.add_argument("--rgb-loss-weight", type=float, default=1.0, help="Weight for RGB L1 reconstruction loss at 256x256")
    p.add_argument("--t-star", type=float, default=0.9, help="MIP t* (noise/interpolant level for step-2)")
    p.add_argument("--name", type=str, default="dino_vid_model", help="Name of the run")
    # Debug flag replacing the previous hard-coded overfit=True leftover.
    p.add_argument("--overfit", action="store_true", help="Debug: train on the first batch every step")
    return p.parse_args()


def _build_loader(dataset, args):
    """DataLoader with the shared settings used for both fresh and rescanned datasets."""
    return DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.workers,
        collate_fn=collate_batch,
        pin_memory=True,
        persistent_workers=args.workers > 0,
        prefetch_factor=4 if args.workers > 0 else None,
    )


def _log_visualizations(student, x, a1_hat, args, global_step):
    """Log input / reconstruction / PCA grids for the first batch element to wandb.

    Only called when --log_wandb is set, so the local wandb import is safe.
    """
    import wandb
    print(f"Visualizing at step {global_step}")
    with torch.no_grad():
        z_pred_vis = a1_hat[:1].detach().cpu()
        # Deterministic MIP inference to visualize both steps.
        infer0, infer1 = student.mip_infer_steps(x[:1], t_star=args.t_star)
    # PCA of the step-2 prediction; fit on itself (no GT reference basis here).
    pca_pred = pca_vis(z_pred_vis, fit_ref=None)[0]
    grid_pred = make_grid(pca_pred, nrow=4)
    # Input video: (1, 3, T, H, W) in [-1, 1] -> (T, 3, H, W) frames in [0, 1].
    inp_rgb = ((x[:1][0].permute(1, 0, 2, 3) + 1.0) / 2.0).clamp(0, 1)
    grid_rgb = make_grid(inp_rgb, nrow=4)
    # Predictions are already (T, 3, H, W) per sample; no permute needed.
    x_rec_vis = ((z_pred_vis[0] + 1.0) / 2.0).clamp(0, 1)
    grid_rgb_rec = make_grid(x_rec_vis, nrow=4)
    infer0_vis = ((infer0[0].detach().cpu() + 1.0) / 2.0).clamp(0, 1)
    infer1_vis = ((infer1[0].detach().cpu() + 1.0) / 2.0).clamp(0, 1)
    grid_rgb_step0 = make_grid(infer0_vis, nrow=4)
    grid_rgb_step1 = make_grid(infer1_vis, nrow=4)
    wandb.log({
        "vis/input_rgb": wandb.Image(grid_rgb.permute(1, 2, 0).cpu().numpy()),
        "vis/recon_rgb": wandb.Image(grid_rgb_rec.permute(1, 2, 0).cpu().numpy()),
        "vis/recon_step0": wandb.Image(grid_rgb_step0.permute(1, 2, 0).cpu().numpy()),
        "vis/recon_step1": wandb.Image(grid_rgb_step1.permute(1, 2, 0).cpu().numpy()),
        "vis/pca_pred": wandb.Image(grid_pred.permute(1, 2, 0).cpu().numpy()),
    }, step=global_step)


def main():
    """Entry point: parse args, build student + data, run the MIP 2-step training loop."""
    args = _parse_args()

    keygrip_root = Path(args.keygrip).resolve()
    device = torch.device(args.device)

    if args.log_wandb:
        # Local import keeps wandb optional; every later use in this function
        # is guarded by args.log_wandb, so binding it once here is enough.
        import wandb
        wandb.init(project="dino_vid_model", config=vars(args), name=args.name)

    # teacher = TeacherDinoVideo(keygrip_root).to(device)
    student = StudentDinoVideo(keygrip_root).to(device)
    opt = torch.optim.AdamW(student.parameters(), lr=args.lr)

    if args.cache_dir:
        dataset = CachedClipDataset(args.cache_dir)
        print(f"Dataset: {len(dataset)} samples (cache; will rescan each epoch)")
    else:
        dataset = DroidVideoDataset(args.data_root, num_frames=NUM_FRAMES, sample_fps=4.0, size=256)
        print(f"Dataset: {len(dataset)} videos")
    loader = _build_loader(dataset, args)

    global_step = 0
    cached_batch = None  # first batch, reused every step under --overfit
    pbar = tqdm(total=args.steps, desc="steps", unit="step")
    while global_step < args.steps:
        if args.cache_dir:
            # Rescan the cache dir each epoch so clips written by a concurrent
            # precache_clips.py run become visible without restarting.
            dataset = CachedClipDataset(args.cache_dir)
            if len(dataset) == 0:
                print("Cache empty, waiting 10s...")
                time.sleep(10)
                continue
            loader = _build_loader(dataset, args)
            print(f"Rescan cache: {len(dataset)} samples")
        for x in loader:
            if args.overfit:
                if cached_batch is None:
                    cached_batch = x
                x = cached_batch
            if global_step >= args.steps:
                break
            x = x.to(device, non_blocking=True)
            # Target video for reconstruction supervision (B, T, 3, H, W) in [-1, 1]
            x_target = x.permute(0, 2, 1, 3, 4)
            # MIP 2-step predictions (supervise both steps)
            a0_hat, a1_hat = student.mip_train_preds(x, x_target, t_star=args.t_star)
            rgb_loss0 = F.mse_loss(a0_hat.float(), x_target.float())
            rgb_loss1 = F.mse_loss(a1_hat.float(), x_target.float())
            rgb_loss = rgb_loss0 + rgb_loss1
            loss = args.rgb_loss_weight * rgb_loss
            opt.zero_grad()
            loss.backward()
            opt.step()

            if args.log_wandb:
                wandb.log(
                    {
                        "train/loss_total": loss.item(),
                        "train/loss_rgb": rgb_loss.item(),
                        "train/loss_rgb0": rgb_loss0.item(),
                        "train/loss_rgb1": rgb_loss1.item(),
                    },
                    step=global_step,
                )

            if global_step % args.vis_every == 0 and args.log_wandb:
                _log_visualizations(student, x, a1_hat, args, global_step)
            global_step += 1
            pbar.update(1)
            if global_step > 0 and global_step % args.checkpoint_every == 0:
                ckpt_dir = Path(args.checkpoint_dir)
                ckpt_dir.mkdir(parents=True, exist_ok=True)
                ckpt_path = ckpt_dir / f"run_{args.name}_latest.pt"
                torch.save(
                    {"step": global_step, "student": student.state_dict(), "optimizer": opt.state_dict()},
                    ckpt_path,
                )
                # Report the actual filename written (was previously wrong).
                print(f"Saved checkpoint {ckpt_path.name}")
    pbar.close()

    if args.log_wandb:
        wandb.finish()


if __name__ == "__main__":
    main()
