"""Train UVA (simple_uva MAR) with PARA head on keygrip data.

Uses the exact dataloader from keygrip/volume_dino_tracks/train.py: RealTrajectoryDataset only,
with dataset_root = scratch/parsed_pickplace_exp1_feb9 (one parsed folder, same as keygrip train default).
Single-image samples: rgb (ImageNet normalized), trajectory_2d/3d; video = same image repeated N_FRAMES.

Same supervision as volume_dino_tracks: volume CE loss at GT (pixel, height_bin).
Viz: max-along-ray heatmap (saved to checkpoints/<run_name>/para_vis.png).
"""
import argparse
import io
import sys
import tempfile
from pathlib import Path
from types import SimpleNamespace

import torchvision
import cv2
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from einops import rearrange
from tqdm import tqdm
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

REPO_ROOT = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(REPO_ROOT))

# Keygrip volume_dino_tracks: same dataset/dataloader as train.py
KEYGRIP_VOLUME = REPO_ROOT.resolve().parents[1] / "keygrip" / "volume_dino_tracks"
sys.path.insert(0, str(KEYGRIP_VOLUME))

from model import N_HEIGHT_BINS
import model as keygrip_model_module
from data import RealTrajectoryDataset

from simple_uva.vae import AutoencoderKL
from simple_uva.model import mar_base_video_only

# keygrip/scratch/parsed_pickplace_exp1_feb9 under data/cameron (parents[2]=/data)
DEFAULT_DATASET_ROOT = REPO_ROOT.resolve().parents[2] / "cameron" / "keygrip" / "scratch" / "parsed_pickplace_exp1_feb9"
DEFAULT_CKPT = REPO_ROOT / "checkpoints" / "keygrip_continue_uva_latest.pt"
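# Latent scaling factor for the KL-16 VAE (the value MAR uses for its kl16 checkpoint).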
LATENT_SCALE = 0.2325

# MAR uses 4 frames; we supervise on first 4 waypoints
N_FRAMES = 4
PARA_OUT_SIZE = 64
PARA_LOSS_WEIGHT = 0.1  # down-weights the volume CE relative to the diffusion loss


def load_frame(episode_dir: Path, frame_idx: int, image_size: int) -> np.ndarray:
    """Load one frame as (H, W, 3) uint8, resized to image_size (matches keygrip dataset)."""
    frame_str = f"{frame_idx:06d}"
    rgb_path = episode_dir / f"{frame_str}.png"
    if not rgb_path.exists():
        raise FileNotFoundError(rgb_path)
    rgb = cv2.imread(str(rgb_path))
    if rgb is None:
        raise RuntimeError(f"Failed to read {rgb_path}")
    rgb = cv2.cvtColor(rgb, cv2.COLOR_BGR2RGB)
    if rgb.shape[0] != image_size or rgb.shape[1] != image_size:
        rgb = cv2.resize(rgb, (image_size, image_size), interpolation=cv2.INTER_LINEAR)
    return rgb


class KeygripVideoWrapper(Dataset):
    """Wrap RealTrajectoryDataset to yield N_FRAMES consecutive real video frames per sample.

    Only keeps samples whose episode contains frames frame_idx .. frame_idx + N_FRAMES - 1.
    """

    def __init__(self, real_traj_dataset, n_frames: int, image_size: int):
        self.real = real_traj_dataset
        self.n_frames = n_frames
        self.image_size = image_size
        self.valid_indices = []
        for idx in range(len(self.real)):
            episode_dir, fi = self.real.samples[idx]
            ep = Path(episode_dir)
            if all((ep / f"{fi + i:06d}.png").exists() for i in range(n_frames)):
                self.valid_indices.append(idx)

    def __len__(self):
        return len(self.valid_indices)

    def __getitem__(self, i):
        idx = self.valid_indices[i]
        sample = self.real[idx]
        ep = Path(sample["episode_dir"])
        fi = sample["frame_idx"]
        frames = []
        for t in range(self.n_frames):
            fr = load_frame(ep, fi + t, self.image_size)  # (H, W, 3) uint8
            frames.append(fr)
        video_np = np.stack(frames, axis=0).astype(np.float32) / 255.0
        video_np = video_np * 2.0 - 1.0
        video = torch.from_numpy(video_np).permute(3, 0, 1, 2)  # (C, T, H, W)
        return {
            "video": video,
            "trajectory_2d": sample["trajectory_2d"],
            "trajectory_3d": sample["trajectory_3d"],
            "heatmap_target": sample["heatmap_target"],
        }


def discretize_height(height_values, min_h, max_h, n_bins):
    """Map continuous heights to integer bins in [0, n_bins - 1].

    Heights are min-max normalized, truncated, and clamped: height == min_h lands in
    bin 0, height == max_h in bin n_bins - 1, and out-of-range heights hit the end bins.
    """
    normalized = (height_values - min_h) / (max_h - min_h + 1e-8)
    normalized = normalized.clamp(0.0, 1.0)
    bin_indices = (normalized * (n_bins - 1)).long().clamp(0, n_bins - 1)
    return bin_indices


def compute_volume_loss(pred_volume_logits, trajectory_2d, target_height_bins, sample_mask=None):
    """CE over 3D cells. trajectory_2d and logits at same resolution (e.g. 64).
    If sample_mask (B,) is provided, only average over samples where mask is True.
    """
    B, N, Nh, H, W = pred_volume_logits.shape
    px = trajectory_2d[:, :, 0].long().clamp(0, W - 1)
    py = trajectory_2d[:, :, 1].long().clamp(0, H - 1)
    h_bin = target_height_bins.clamp(0, Nh - 1)
    losses = []
    for t in range(N):
        logits_t = pred_volume_logits[:, t].reshape(B, -1)
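        # Flattened class index into the (Nh, H, W) volume: height bin major, then row, then column.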
        target_idx = (h_bin[:, t] * (H * W) + py[:, t] * W + px[:, t]).long()
        losses.append(F.cross_entropy(logits_t, target_idx, reduction="none"))
    loss_per_sample = torch.stack(losses, dim=1).mean(dim=1)
    if sample_mask is not None:
        n = sample_mask.sum().clamp(min=1)
        return (loss_per_sample * sample_mask.float()).sum() / n
    return loss_per_sample.mean()
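
# Shape sanity check for compute_volume_loss (hypothetical values):
#   logits = torch.randn(2, 4, N_HEIGHT_BINS, 64, 64)  # (B, N, Nh, H, W)
#   traj = torch.rand(2, 4, 2) * 64                    # (B, N, 2) pixel coords at volume res
#   bins = torch.zeros(2, 4, dtype=torch.long)         # (B, N) GT height bins
#   compute_volume_loss(logits, traj, bins)            # -> scalar tensor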


def build_max_along_ray_heatmaps(volume_logits):
    """volume_logits (B, T, Nh, H, W) -> (B, T, H, W): per-pixel max probability over height.

    Softmax is over the full (Nh * H * W) volume per timestep, so each heatmap value is
    the probability mass of the most likely height bin along that pixel's ray.
    """
    B, T = volume_logits.shape[:2]
    vol_probs = F.softmax(volume_logits.reshape(B, T, -1), dim=2).view(volume_logits.shape)
    return vol_probs.max(dim=2)[0]


def build_vae(repo_root, vae_ckpt, device):
    ckpt_path = Path(vae_ckpt)
    if not ckpt_path.is_absolute():
        ckpt_path = repo_root / ckpt_path
    ddconfig = SimpleNamespace(vae_embed_dim=16, ch_mult=[1, 1, 2, 2, 4])
    # A missing checkpoint falls back to a randomly initialized VAE (smoke tests only).
    vae = AutoencoderKL(autoencoder_path=str(ckpt_path) if ckpt_path.exists() else None, ddconfig=ddconfig)
    vae.to(device)
    vae.eval()
    for p in vae.parameters():
        p.requires_grad = False  # VAE is frozen; only the MAR + PARA head train
    return vae


def main():
    p = argparse.ArgumentParser(description="Train UVA + PARA head on keygrip trajectory data (RealTrajectoryDataset only)")
    p.add_argument("--dataset-root", type=str, default=str(DEFAULT_DATASET_ROOT), help="Keygrip parsed dir. Default: keygrip/scratch/parsed_pickplace_exp1_feb9.")
    p.add_argument("--vae-ckpt", type=str, default="pretrained_models/vae/kl16.ckpt")
    p.add_argument("--checkpoint", type=str, default=str(DEFAULT_CKPT), help="UVA MAR checkpoint (keygrip_continue_uva_latest.pt)")
    p.add_argument("--batch-size", type=int, default=4)
    p.add_argument("--workers", type=int, default=4)
    p.add_argument("--lr", type=float, default=1e-4)
    p.add_argument("--vis-every", type=int, default=100, help="Steps between viz updates")
    p.add_argument("--device", type=str, default="cuda")
    p.add_argument("--checkpoint-dir", type=str, default="checkpoints")
    p.add_argument("--checkpoint-every", type=int, default=1000)
    p.add_argument("--run-name", type=str, default="uva_para_keygrip")
    p.add_argument("--image-size", type=int, default=256, help="Input image size (VAE res)")
    p.add_argument("--log_wandb", action="store_true", help="Log to wandb (loss, denoising video, PARA heatmaps)")
    p.add_argument("--num-iter", type=int, default=64, help="Sampling steps for denoising video vis")
    args = p.parse_args()

    device = torch.device(args.device if torch.cuda.is_available() else "cpu")
    if args.log_wandb:
        import wandb  # lazy import: wandb is only required when --log_wandb is set
        wandb.init(project="simple_uva", config=vars(args), name=args.run_name, mode="online")

    vae = build_vae(REPO_ROOT, args.vae_ckpt, device)
    model = mar_base_video_only(
        img_size=args.image_size,
        vae_stride=16,
        patch_size=1,
        vae_embed_dim=16,
        num_sampling_steps="100",
        diffloss_d=6,
        diffloss_w=1024,
        predict_para=True,
        para_n_bins=N_HEIGHT_BINS,
        para_out_size=PARA_OUT_SIZE,
    ).to(device)

    # Load pretrained UVA checkpoint (MAR only; PARA head stays random)
    ckpt_path = Path(args.checkpoint)
    if ckpt_path.exists():
        try:
            payload = torch.load(ckpt_path, map_location=device, weights_only=False)
        except Exception:
            payload = torch.load(ckpt_path, map_location=device)
        state_dicts = payload.get("state_dicts") or {}
        sd = state_dicts.get("ema_model") or state_dicts.get("model") or payload.get("model")
        if sd is not None:
            if any(k.startswith("model.") for k in sd):
                model_sd = {k[6:]: v for k, v in sd.items() if k.startswith("model.")}
            else:
                model_sd = dict(sd)
            current = model.state_dict()
            loadable = {k: v for k, v in model_sd.items() if k in current and current[k].shape == v.shape}
            current.update(loadable)
            model.load_state_dict(current, strict=False)
            print(f"Loaded MAR from {ckpt_path} ({len(loadable)} keys); PARA head random init")
        else:
            print(f"No state_dict in {ckpt_path}; training from scratch")
    else:
        print(f"Checkpoint not found: {ckpt_path}; training from scratch")

    # Real video frames: RealTrajectoryDataset + wrapper that loads N_FRAMES consecutive frames per sample
    dataset_root = Path(args.dataset_root).resolve()
    real_traj = RealTrajectoryDataset(dataset_root=str(dataset_root), image_size=args.image_size)
    dataset = KeygripVideoWrapper(real_traj, N_FRAMES, args.image_size)
    print(f"Using KeygripVideoWrapper: {len(dataset)} samples (real {N_FRAMES}-frame video) from {dataset_root}")

    scale_2d = PARA_OUT_SIZE / args.image_size

    def collate_video(batch):
        video = torch.stack([b["video"] for b in batch])
        trajectory_2d = torch.stack([b["trajectory_2d"] for b in batch])
        trajectory_3d = torch.stack([b["trajectory_3d"] for b in batch])
        heatmap_target = torch.stack([b["heatmap_target"] for b in batch])
        return {"video": video, "trajectory_2d": trajectory_2d, "trajectory_3d": trajectory_3d, "heatmap_target": heatmap_target}

    loader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.workers,
        collate_fn=collate_video,
        pin_memory=True,
    )

    # Height range from dataset (match volume_dino_tracks)
    all_heights = []
    n_scan = min(500, len(dataset))
    for i in range(n_scan):
        t3d = dataset[i]["trajectory_3d"].numpy()  # (N_WINDOW, 3)
        all_heights.extend(t3d[:, 2].tolist())
    all_heights = np.array(all_heights)
    if len(all_heights) > 0:
        min_height = float(all_heights.min())
        max_height = float(all_heights.max())
        keygrip_model_module.MIN_HEIGHT = min_height
        keygrip_model_module.MAX_HEIGHT = max_height
        print(f"Height range: [{min_height:.6f}, {max_height:.6f}]")
    else:
        keygrip_model_module.MIN_HEIGHT = 0.0
        keygrip_model_module.MAX_HEIGHT = 1.0
        print("No trajectory data for height range; using [0, 1]")

    opt = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=1e-4)
    ckpt_dir = Path(args.checkpoint_dir) / args.run_name
    ckpt_dir.mkdir(parents=True, exist_ok=True)

    # Viz: one row of max-along-ray heatmaps (first sample, N_FRAMES timesteps), regenerated
    # every --vis-every steps and saved under the checkpoint dir; losses go to wandb.

    global_step = 0
    while True:  # train indefinitely; stop with Ctrl+C (handled in __main__)
        model.train()
        pbar = tqdm(loader)

        for batch in pbar:
            video = batch["video"].to(device)  # (B, C, T, H, W) in [-1, 1], real N_FRAMES
            trajectory_2d = batch["trajectory_2d"].to(device)
            trajectory_3d = batch["trajectory_3d"].to(device)

            B, C, T, H, W = video.shape
            traj_2d = trajectory_2d[:, :N_FRAMES]
            traj_3d = trajectory_3d[:, :N_FRAMES]
            target_height = traj_3d[:, :, 2]
            target_height_bins = discretize_height(
                target_height,
                keygrip_model_module.MIN_HEIGHT,
                keygrip_model_module.MAX_HEIGHT,
                N_HEIGHT_BINS,
            )

            # Encode real video (same as train_with_self_collected): condition on frame 0, predict frames 0..T-1
            frames = rearrange(video, "b c t h w -> (b t) c h w")
            with torch.no_grad():
                posterior = vae.encode(frames.float())
                z_vae = posterior.sample() * LATENT_SCALE
            z_vae = rearrange(z_vae, "(b t) c h w -> b t c h w", b=B, t=N_FRAMES)
            x_tokens = model.patchify(rearrange(z_vae, "b t c h w -> (b t) c h w"))
            x_tokens = rearrange(x_tokens, "(b t) s c -> b t s c", b=B, t=N_FRAMES)
            cond_tokens = x_tokens[:, :1].expand(-1, N_FRAMES, -1, -1)
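            # x_tokens: (B, T, S, C) GT latent tokens; cond_tokens repeats frame 0's tokens
            # across all T frames as the conditioning stream.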

            # Video denoising loss (same compute_loss / timestep sampling as train_with_self_collected)
            diffusion_loss = model.compute_loss(x_tokens, cond_tokens)

            volume_logits = model.forward_para(x_tokens, cond_tokens, mask=None)
            # Scale GT pixel coords to PARA output resolution (64)
            traj_64 = traj_2d * scale_2d
            traj_64 = traj_64.clamp(0, PARA_OUT_SIZE - 1.001)
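            # (kept strictly below PARA_OUT_SIZE so .long() indexing in the loss stays in range)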

            para_loss = PARA_LOSS_WEIGHT * compute_volume_loss(
                volume_logits, traj_64, target_height_bins,
                sample_mask=None,
            )
            total_loss = diffusion_loss + para_loss
            opt.zero_grad()
            total_loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()

            pbar.set_postfix(diff=diffusion_loss.item(), para=para_loss.item(), total=total_loss.item())

            if args.log_wandb:
                wandb.log({
                    "train/diffusion_loss": diffusion_loss.item(),
                    "train/para_loss": para_loss.item(),
                    "train/loss": total_loss.item(),
                }, step=global_step)

            if global_step % args.vis_every == 0:
                with torch.no_grad():
                    heatmaps = build_max_along_ray_heatmaps(volume_logits[:1])
                    heatmaps = heatmaps[0].cpu().numpy()
                    fig, axes = plt.subplots(1, N_FRAMES, figsize=(3 * N_FRAMES, 3), squeeze=False)
                    for t in range(N_FRAMES):
                        ax = axes[0, t]
                        ax.imshow(heatmaps[t], cmap="hot")
                        ax.scatter(
                            traj_64[0, t, 0].cpu().item(),
                            traj_64[0, t, 1].cpu().item(),
                            c="cyan",
                            s=30,
                            marker="x",
                        )
                        ax.set_title(f"t={t}")
                        ax.axis("off")
                    fig.tight_layout()
                    fig.savefig(ckpt_dir / "para_vis.png", dpi=100)
                    if args.log_wandb:
                        buf = io.BytesIO()
                        fig.savefig(buf, format="png", dpi=100)
                        buf.seek(0)
                        from PIL import Image
                        wandb.log({"vis/para_heatmaps": wandb.Image(Image.open(buf))}, step=global_step)
                    plt.close(fig)
                if args.log_wandb:
                    print("predicting video")
                    with torch.no_grad():
                        first_frame = video[:1, :, 0]
                        posterior0 = vae.encode(first_frame.float())
                        z0 = posterior0.sample() * LATENT_SCALE
                        cond = z0.unsqueeze(1).expand(1, N_FRAMES, -1, -1, -1)
                        tokens, _ = model.sample_tokens(bsz=1, cond=cond, num_iter=args.num_iter, cfg=1.0, temperature=0.95)
                        pred = vae.decode(tokens / LATENT_SCALE)
                        pred = pred.view(1, N_FRAMES, 3, args.image_size, args.image_size)
                    print("done predicting video")
                    pred_np = ((pred[0].cpu() + 1.0) / 2.0).clamp(0, 1)
                    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as f:
                        tmp_path = f.name
                    frames_np = (pred_np.permute(0, 2, 3, 1).numpy() * 255).astype("uint8")
                    torchvision.io.write_video(tmp_path, torch.from_numpy(frames_np), fps=4)
                    wandb.log({"vis/denoising_video": wandb.Video(tmp_path, format="mp4")}, step=global_step)
                    Path(tmp_path).unlink(missing_ok=True)
                    # GT video (real N_FRAMES consecutive frames) with the GT keypoint drawn on each frame
                    gt_video = video[:1]
                    gt_np = ((gt_video[0].cpu().permute(1, 2, 3, 0) + 1.0) / 2.0).clamp(0, 1).numpy()
                    gt_np = (gt_np * 255).astype("uint8")  # (T, H, W, 3) RGB
                    # traj_2d is in image coords, same resolution as gt_np
                    H, W = gt_np.shape[1], gt_np.shape[2]
                    for t in range(N_FRAMES):
                        x = max(0, min(W - 1, int(traj_2d[0, t, 0].item())))
                        y = max(0, min(H - 1, int(traj_2d[0, t, 1].item())))
                        # filled cyan disk (frames are RGB, so (0, 255, 255) is cyan)
                        cv2.circle(gt_np[t], (x, y), 6, (0, 255, 255), -1)
                    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as f:
                        gt_path = f.name
                    torchvision.io.write_video(gt_path, torch.from_numpy(gt_np), fps=4)
                    wandb.log({"vis/gt_video": wandb.Video(gt_path, format="mp4")}, step=global_step)
                    Path(gt_path).unlink(missing_ok=True)

            global_step += 1
            if global_step % args.checkpoint_every == 0:
                print("saving model")
                torch.save({
                    "step": global_step,
                    "model": model.state_dict(),
                    "optimizer": opt.state_dict(),
                }, ckpt_dir / "latest.pt")
                print(f"Saved {ckpt_dir / 'latest.pt'}")


if __name__ == "__main__":
    main()
