"""Train VideoLatentModel: VidTok latents as targets (no grad), MSE loss, wandb, vis every ~100 iters."""

import argparse
import sys
from pathlib import Path

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision.utils import make_grid
import wandb

# Make both the vidgen repo root (for the `our_vid_model` package) and the
# sibling VidTok checkout (for its `scripts` package) importable before the
# project-local imports below. NOTE(review): inserted at position 0, so these
# shadow any same-named installed packages — presumably intentional.
vidgen_root = Path(__file__).resolve().parents[1]
VidTok_root = vidgen_root / "VidTok"
sys.path.insert(0, str(vidgen_root))
sys.path.insert(0, str(VidTok_root))

from our_vid_model.model import VideoLatentModel
from our_vid_model.dataset import DroidVideoDataset, collate_batch


def load_vidtok(config_path: str, ckpt_path: str, device: torch.device):
    """Load a pretrained VidTok tokenizer onto *device* in eval mode.

    The import is deferred because VidTok's `scripts` package is only
    importable after the sys.path setup at module import time.
    """
    from scripts.inference_evaluate import load_model_from_config

    vidtok = load_model_from_config(config_path, ckpt_path)
    return vidtok.to(device).eval()


def _log_visuals(vidtok, x, z_gt, z_pred, step):
    """Decode GT and predicted latents for the first sample and log frame grids.

    Args:
        vidtok: frozen VidTok model used only for decoding here.
        x: input video batch; assumed (B, 3, T, H, W) in [-1, 1] — TODO confirm
           against DroidVideoDataset.
        z_gt: VidTok latents encoded from ``x`` (regression targets).
        z_pred: model-predicted latents, same shape as ``z_gt``.
        step: global step used as the wandb log step.
    """
    use_amp = x.device.type == "cuda"
    with torch.no_grad():
        with torch.autocast(device_type=x.device.type, dtype=torch.float16, enabled=use_amp):
            x_recon_gt = vidtok.decode(z_gt[:1])
            x_recon_pred = vidtok.decode(z_pred[:1])
        x_recon_gt = x_recon_gt.float()
        x_recon_pred = x_recon_pred.float()

    def to_01(t):
        # (1, 3, T, H, W) in [-1, 1] -> (T, 3, H, W) in [0, 1] for image grids.
        return ((t[0].permute(1, 0, 2, 3) + 1.0) / 2.0).clamp(0, 1)

    grids = {
        "vis/input_frames": make_grid(to_01(x[:1]), nrow=4),
        "vis/gt_recon_frames": make_grid(to_01(x_recon_gt), nrow=4),
        "vis/pred_recon_frames": make_grid(to_01(x_recon_pred), nrow=4),
    }
    wandb.log(
        {key: wandb.Image(g.permute(1, 2, 0).cpu().numpy()) for key, g in grids.items()},
        step=step,
    )


def main():
    """Train VideoLatentModel to regress frozen VidTok latents with MSE loss.

    Parses CLI arguments, loads the frozen VidTok tokenizer (targets, no grad),
    builds the trainable model and DROID video loader, then runs the step-count
    training loop, logging the loss every step and decoded-frame visualizations
    every ``--vis-every`` steps to wandb.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--keygrip", type=str, default="../keygrip", help="keygrip repo root for DINO")
    p.add_argument("--vidtok", type=str, default=None, help="VidTok repo root (default: ../VidTok)")
    p.add_argument("--vidtok-config", type=str, default=None)
    p.add_argument("--vidtok-ckpt", type=str, default=None)
    p.add_argument("--data-root", type=str, default="/data/weiduoyuan/droid_raw/1.0.1")
    p.add_argument("--batch-size", type=int, default=4)
    p.add_argument("--workers", type=int, default=4)
    p.add_argument("--lr", type=float, default=1e-4)
    p.add_argument("--steps", type=int, default=10000)
    p.add_argument("--vis-every", type=int, default=100)
    p.add_argument("--device", type=str, default="cuda")
    p.add_argument("--wandb-project", type=str, default="our_vid_model")
    args = p.parse_args()

    vidtok_root = Path(args.vidtok or str(VidTok_root)).resolve()
    keygrip_root = Path(args.keygrip).resolve()
    device = torch.device(args.device)
    # Mixed precision only on CUDA; on CPU/MPS autocast-fp16 here would be
    # wrong (the original hard-coded device_type="cuda" regardless of --device).
    use_amp = device.type == "cuda"

    # Relative --vidtok-config / --vidtok-ckpt paths are resolved against the
    # VidTok repo root.
    config_path = args.vidtok_config or str(vidtok_root / "configs/vidtok_v1_1/vidtok_kl_causal_288_8chn_v1_1.yaml")
    ckpt_path = args.vidtok_ckpt or str(vidtok_root / "checkpoints/vidtok_kl_causal_288_8chn_v1_1.ckpt")
    if not Path(config_path).is_absolute():
        config_path = str(vidtok_root / config_path)
    if not Path(ckpt_path).is_absolute():
        ckpt_path = str(vidtok_root / ckpt_path)

    wandb.init(project=args.wandb_project, config=vars(args))

    # VidTok stays frozen (eval mode, encode under no_grad); only `model` trains.
    vidtok = load_vidtok(config_path, ckpt_path, device)
    model = VideoLatentModel(keygrip_root).to(device)
    model.train()
    opt = torch.optim.AdamW(model.parameters(), lr=args.lr)

    dataset = DroidVideoDataset(args.data_root, num_frames=8, sample_fps=4.0, size=256)
    loader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.workers,
        collate_fn=collate_batch,
        pin_memory=True,
    )

    global_step = 0
    while global_step < args.steps:
        saw_batch = False  # guards against an empty loader spinning forever
        for batch in loader:
            saw_batch = True
            if global_step >= args.steps:
                break
            x = batch.to(device)
            # Encode the frozen targets; no grad through VidTok.
            with torch.no_grad(), torch.autocast(device_type=device.type, dtype=torch.float16, enabled=use_amp):
                z_gt, _ = vidtok.encode(x, return_reg_log=True)
            z_gt = z_gt.float()

            z_pred = model(x)
            loss = F.mse_loss(z_pred, z_gt)
            opt.zero_grad(set_to_none=True)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()
            wandb.log({"train/loss": loss.item()}, step=global_step)

            if global_step > 0 and global_step % args.vis_every == 0:
                _log_visuals(vidtok, x, z_gt, z_pred, global_step)
            global_step += 1
        if not saw_batch:
            # Dataset produced no batches; bail out rather than loop forever.
            break

    wandb.finish()


# Script entry point: run training only when executed directly, not on import.
if __name__ == "__main__":
    main()
