"""
DINO-based future frame predictor for LIBERO.
PARA-style: frozen DINOv2 encoder → bilinear upsample → Conv decoder.
Separate streams for agentview and wrist (no cross-communication).
Predicts 4 future frames at t+7, t+14, t+21, t+28.
"""

import os
import random
import time
from pathlib import Path

import h5py
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torch.utils.data import Dataset, DataLoader

import wandb

# ─── Config ───────────────────────────────────────────────────────────────────
# Hardware: GPU is a physical device index. It is the only device made visible
# to this process below, so torch sees it as cuda:0. CUDA reads this variable
# at initialization, so it must be set before any CUDA call.
GPU = 4

# Image / DINOv2 geometry
IMAGE_SIZE = 224
PATCH_SIZE = 14  # DINOv2 ViT-S/14
NUM_PATCHES = IMAGE_SIZE // PATCH_SIZE  # 16 patches per side
DINO_DIM = 384  # ViT-S feature dim

# Prediction targets: step offsets of the predicted frames relative to the current one
FUTURE_OFFSETS = (7, 14, 21, 28)
NUM_FUTURE = len(FUTURE_OFFSETS)

# Optimization and logging (the *_EVERY values are iteration counts)
BATCH_SIZE = 16
LR = 1e-3
MAX_ITER = 10000
LOG_EVERY = 50
VIDEO_LOG_EVERY = 200
SAVE_EVERY = 2000
NUM_WORKERS = 4

# Paths
DATA_DIR = "/data/cameron/vidgen/cosmos-policy/LIBERO-Cosmos-Policy/success_only"
CKPT_DIR = "/data/cameron/vidgen/cosmos-policy/checkpoints/dino_frame_predictor"
DEBUG_DIR = "/data/cameron/vidgen/cosmos-policy/rollout_outputs/dino_debug"

# Import-time side effects: pin the GPU and make sure output directories exist.
os.environ["CUDA_VISIBLE_DEVICES"] = str(GPU)
Path(CKPT_DIR).mkdir(parents=True, exist_ok=True)
Path(DEBUG_DIR).mkdir(parents=True, exist_ok=True)


# ─── Dataset ──────────────────────────────────────────────────────────────────
def decode_jpeg(jpeg_bytes):
    """Decode one encoded JPEG into an RGB uint8 numpy array of shape (H, W, 3).

    Args:
        jpeg_bytes: the encoded image; anything ``io.BytesIO`` accepts
            (``bytes`` or a buffer-protocol object such as a uint8 numpy array).

    Returns:
        np.ndarray with dtype uint8 and shape (H, W, 3).
    """
    from io import BytesIO

    # Force RGB so grayscale/CMYK JPEGs still yield exactly three channels;
    # callers permute (H, W, C) -> (C, H, W) and assume C == 3.
    # The context manager releases the decoder's resources promptly.
    with Image.open(BytesIO(jpeg_bytes)) as img:
        return np.array(img.convert("RGB"))


class LIBEROFrameDataset(Dataset):
    """LIBERO dataset for future-frame prediction on the agentview and wrist cameras.

    Every ``<data_dir>/<suite>/*.hdf5`` file is scanned for ``data/<demo>`` groups.
    Each item comes from one episode. A start step ``t`` is sampled uniformly at
    random, and the item holds the current frames at ``t`` plus the frames at
    ``t + offset`` for each offset in ``future_offsets``, for both cameras.
    Frames are resized to ``image_size`` x ``image_size`` and scaled to [0, 1].
    """

    # Items per episode per epoch; every item re-samples its start step, so this
    # is the number of random starts each episode gets per epoch.
    _OVERSAMPLE = 20

    def __init__(self, data_dir, future_offsets=(7, 14, 21, 28), image_size=224):
        self.future_offsets = future_offsets
        self.image_size = image_size
        # One entry per usable episode: (hdf5_path, demo_group, max_start, num_steps)
        self.samples = []

        for suite_dir in sorted(Path(data_dir).iterdir()):
            if not suite_dir.is_dir():
                continue
            for hdf5_file in sorted(suite_dir.glob("*.hdf5")):
                with h5py.File(str(hdf5_file), "r") as f:
                    for demo_key in sorted(f["data"].keys()):
                        # Episode length is taken from the number of recorded actions.
                        num_steps = f[f"data/{demo_key}/actions"].shape[0]
                        max_start = self._max_start_index(num_steps, future_offsets)
                        if max_start is not None:
                            self.samples.append(
                                (str(hdf5_file), f"data/{demo_key}", max_start, num_steps)
                            )

        print(f"LIBEROFrameDataset: {len(self.samples)} episodes, future_offsets={future_offsets}")

    @staticmethod
    def _max_start_index(num_steps, future_offsets):
        """Return the largest start step t with t + max(future_offsets) <= num_steps - 1.

        Returns None when the episode is too short for even t = 0. An episode of
        exactly ``1 + max(future_offsets)`` steps is usable, with t = 0 as its only start.
        """
        max_start = num_steps - 1 - max(future_offsets)
        return max_start if max_start >= 0 else None

    def __len__(self):
        return len(self.samples) * self._OVERSAMPLE

    def _to_tensor(self, img):
        """Resize an HxWxC uint8 array and return a (C, S, S) float tensor in [0, 1], S = image_size."""
        resized = Image.fromarray(img).resize((self.image_size, self.image_size), Image.BILINEAR)
        return torch.from_numpy(np.array(resized)).permute(2, 0, 1).float() / 255.0

    def __getitem__(self, idx):
        # `idx` only selects the episode; the start step is re-sampled on every call.
        hdf5_path, demo_key, max_start, num_steps = self.samples[idx % len(self.samples)]
        t = random.randint(0, max_start)
        # The clamp is defensive: max_start already guarantees t + offset <= num_steps - 1.
        future_steps = [min(t + offset, num_steps - 1) for offset in self.future_offsets]

        with h5py.File(hdf5_path, "r") as f:
            agent_ds = f[f"{demo_key}/obs/agentview_rgb_jpeg"]
            wrist_ds = f[f"{demo_key}/obs/eye_in_hand_rgb_jpeg"]

            # Current frames for both cameras
            agent_img = decode_jpeg(agent_ds[t])
            wrist_img = decode_jpeg(wrist_ds[t])

            # Future frames for both cameras
            future_agent = [decode_jpeg(agent_ds[fi]) for fi in future_steps]
            future_wrist = [decode_jpeg(wrist_ds[fi]) for fi in future_steps]

        return {
            "agent_current": self._to_tensor(agent_img),  # (3, S, S)
            "wrist_current": self._to_tensor(wrist_img),  # (3, S, S)
            "agent_future": torch.stack([self._to_tensor(im) for im in future_agent]),  # (len(future_offsets), 3, S, S)
            "wrist_future": torch.stack([self._to_tensor(im) for im in future_wrist]),  # (len(future_offsets), 3, S, S)
        }


# ─── Model ────────────────────────────────────────────────────────────────────
class ConvDecoder(nn.Module):
    """PARA-style decoder: bilinear upsample -> 3 Conv2d layers -> bilinear upsample.

    The first two 3x3 convs are each followed by GELU then BatchNorm. The last
    conv projects to ``3 * num_frames`` channels, which are reshaped into
    ``num_frames`` RGB frames and squashed to [0, 1] with a sigmoid.
    """

    def __init__(self, in_dim=384, hidden_dim=256, num_frames=4, upsample_size=64, output_size=None):
        """
        Args:
            in_dim: channels of the input feature map.
            hidden_dim: channels of the two hidden conv layers.
            num_frames: number of RGB frames to predict.
            upsample_size: size (int or (h, w)) the features are bilinearly
                resized to before the convs.
            output_size: size (int or (h, w)) of the predicted frames. ``None``
                (the default) uses the module-level ``IMAGE_SIZE``, resolved at
                forward time.
        """
        super().__init__()
        self.upsample_size = upsample_size
        self.output_size = output_size
        self.num_frames = num_frames
        self.convs = nn.Sequential(
            nn.Conv2d(in_dim, hidden_dim, 3, padding=1),
            nn.GELU(),
            nn.BatchNorm2d(hidden_dim),
            nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1),
            nn.GELU(),
            nn.BatchNorm2d(hidden_dim),
            nn.Conv2d(hidden_dim, 3 * num_frames, 3, padding=1),  # 3 RGB channels per predicted frame
        )

    def forward(self, features):
        """
        features: (B, C, h, w) feature map (e.g. DINO patch features)
        returns: (B, num_frames, 3, H, W) predicted frames in [0, 1], where (H, W) is output_size
        """
        out_size = IMAGE_SIZE if self.output_size is None else self.output_size
        # Bilinear upsample to the intermediate resolution the convs run at
        x = F.interpolate(features, size=self.upsample_size, mode="bilinear", align_corners=False)
        x = self.convs(x)  # (B, 3 * num_frames, upsample_size, upsample_size)
        # Upsample to the final frame resolution
        x = F.interpolate(x, size=out_size, mode="bilinear", align_corners=False)
        # Split channels into frames; take H, W from the tensor rather than a global
        batch, _, height, width = x.shape
        x = x.view(batch, self.num_frames, 3, height, width)
        return torch.sigmoid(x)


class DINOFramePredictor(nn.Module):
    """
    Frozen DINOv2 ViT-S/14 encoder (shared by both views) -> a separate ConvDecoder
    per view (agentview + wrist). No cross-communication between views.

    Only the decoders are trainable. ``train()`` is overridden so the frozen
    encoder always stays in eval mode, even after ``model.train()``.
    """

    def __init__(self):
        super().__init__()
        # Frozen DINOv2 encoder (shared between views)
        self.encoder = torch.hub.load("facebookresearch/dinov2", "dinov2_vits14", pretrained=True)
        self.encoder.eval()
        for p in self.encoder.parameters():
            p.requires_grad = False

        # Separate decoders for each view
        self.agent_decoder = ConvDecoder(in_dim=DINO_DIM, num_frames=NUM_FUTURE)
        self.wrist_decoder = ConvDecoder(in_dim=DINO_DIM, num_frames=NUM_FUTURE)

        # ImageNet mean/std used to normalize inputs for DINO
        self.register_buffer("mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer("std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

    def train(self, mode=True):
        """Set train/eval mode on the decoders while pinning the frozen encoder to eval mode.

        nn.Module.train() recurses into every submodule. Without this override,
        ``model.train()`` would put the "frozen" encoder back into training mode.
        """
        super().train(mode)
        self.encoder.eval()
        return self

    def extract_features(self, images):
        """Extract DINO patch features as a spatial grid.

        images: (B, 3, H, W) in [0, 1]
        returns: (B, D, h, w) patch features, with h = w = sqrt(num_patches),
                 so the patch grid is assumed square (e.g. (B, 384, 16, 16) for 224x224 input)
        """
        x = (images - self.mean) / self.std
        # Patch tokens only (CLS token excluded); no grads through the frozen encoder
        with torch.no_grad():
            patch_tokens = self.encoder.forward_features(x)["x_norm_patchtokens"]  # (B, N, D)
        B, N, D = patch_tokens.shape
        h = w = int(N ** 0.5)
        return patch_tokens.transpose(1, 2).reshape(B, D, h, w)

    def forward(self, agent_current, wrist_current):
        """
        agent_current: (B, 3, H, W) in [0, 1]
        wrist_current: (B, 3, H, W) in [0, 1]
        returns: agent_pred, wrist_pred, each (B, NUM_FUTURE, 3, H_out, W_out) in [0, 1]
        """
        agent_feats = self.extract_features(agent_current)
        wrist_feats = self.extract_features(wrist_current)

        agent_pred = self.agent_decoder(agent_feats)
        wrist_pred = self.wrist_decoder(wrist_feats)

        return agent_pred, wrist_pred


# ─── Training ─────────────────────────────────────────────────────────────────
def to_numpy_video(tensor):
    """Turn a (N, 3, H, W) float tensor in [0, 1] into a (N, H, W, 3) uint8 array.

    Values are clamped to [0, 1] first; scaling by 255 truncates toward zero on cast.
    """
    frames = tensor.detach().cpu().clamp(0, 1)
    channels_last = frames.permute(0, 2, 3, 1).numpy()
    scaled = channels_last * 255
    return scaled.astype(np.uint8)


def _log_sample_videos(model, iteration, agent_current, wrist_current, agent_future_gt, wrist_future_gt):
    """Log GT vs. predicted clips for the first batch sample to wandb and save debug PNGs."""
    model.eval()
    with torch.no_grad():
        ap, wp = model(agent_current[:1], wrist_current[:1])
    model.train()

    # Each clip: current frame followed by the NUM_FUTURE future frames -> (1 + NUM_FUTURE, H, W, 3)
    gt_agent_vid = to_numpy_video(torch.cat([agent_current[:1], agent_future_gt[0]], dim=0))
    gt_wrist_vid = to_numpy_video(torch.cat([wrist_current[:1], wrist_future_gt[0]], dim=0))
    pred_agent_vid = to_numpy_video(torch.cat([agent_current[:1], ap[0]], dim=0))
    pred_wrist_vid = to_numpy_video(torch.cat([wrist_current[:1], wp[0]], dim=0))

    def as_video(frames):
        # wandb.Video takes (time, channel, height, width)
        return wandb.Video(frames.transpose(0, 3, 1, 2), fps=2, format="mp4")

    wandb.log({
        "video/gt_agent": as_video(gt_agent_vid),
        "video/pred_agent": as_video(pred_agent_vid),
        "video/gt_wrist": as_video(gt_wrist_vid),
        "video/pred_wrist": as_video(pred_wrist_vid),
    }, step=iteration)

    # Index 0 is the (shared) current frame and indices 1.. are future frames, so
    # "t0" is the first future frame and "t3" the last one, for both GT and prediction.
    prefix = f"{DEBUG_DIR}/iter{iteration:06d}"
    Image.fromarray(gt_agent_vid[1]).save(f"{prefix}_gt_agent_t0.png")
    Image.fromarray(gt_agent_vid[-1]).save(f"{prefix}_gt_agent_t3.png")
    Image.fromarray(pred_agent_vid[1]).save(f"{prefix}_pred_agent_t0.png")
    Image.fromarray(pred_agent_vid[-1]).save(f"{prefix}_pred_agent_t3.png")

    print(f"  [VideoLog] Logged GT + pred videos at iter {iteration}")


def _save_checkpoint(model, optimizer, scheduler, iteration):
    """Save decoder weights plus optimizer/scheduler state (frozen encoder weights are skipped)."""
    ckpt_path = f"{CKPT_DIR}/iter_{iteration:06d}.pt"
    torch.save({
        "iteration": iteration,
        # The frozen DINOv2 weights are excluded; DINOFramePredictor reloads them via torch.hub.
        "model_state_dict": {k: v for k, v in model.state_dict().items() if not k.startswith("encoder.")},
        "optimizer_state_dict": optimizer.state_dict(),
        "scheduler_state_dict": scheduler.state_dict(),
    }, ckpt_path)
    print(f"  Saved checkpoint: {ckpt_path}")


def main():
    """Train the per-view decoders of DINOFramePredictor on LIBERO with an L1 loss.

    Runs MAX_ITER iterations (cycling the dataloader as needed). It logs scalars to
    wandb every LOG_EVERY iterations and videos/debug PNGs every VIDEO_LOG_EVERY,
    and saves checkpoints every SAVE_EVERY iterations plus at the end of training.
    """
    # CUDA_VISIBLE_DEVICES (set at import) exposes only GPU, which becomes cuda:0.
    device = torch.device("cuda:0")

    print("=" * 60)
    print("DINO Frame Predictor - PARA-style")
    print(f"  Future offsets: {FUTURE_OFFSETS}")
    print(f"  Batch size: {BATCH_SIZE}, LR: {LR}")
    print("=" * 60)

    # wandb
    wandb.init(project="dino-frame-predictor", name="para_style_4frame",
               config={"lr": LR, "batch_size": BATCH_SIZE, "future_offsets": FUTURE_OFFSETS,
                       "image_size": IMAGE_SIZE, "dino": "dinov2_vits14"})

    # Dataset
    print("Loading dataset...")
    dataset = LIBEROFrameDataset(DATA_DIR, future_offsets=FUTURE_OFFSETS, image_size=IMAGE_SIZE)
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True,
                            num_workers=NUM_WORKERS, pin_memory=True, drop_last=True,
                            # persistent_workers=True raises when num_workers == 0
                            persistent_workers=NUM_WORKERS > 0)

    # Model
    print("Loading DINOv2 + decoder...")
    model = DINOFramePredictor().to(device)
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    print(f"  Trainable params: {trainable:,} / {total:,} ({100*trainable/total:.2f}%)")
    print(f"  GPU memory: {torch.cuda.memory_allocated()/1e9:.2f} GB")

    # Optimizer over trainable (decoder) params only
    optimizer = torch.optim.AdamW(
        [p for p in model.parameters() if p.requires_grad],
        lr=LR, weight_decay=0.01
    )
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=MAX_ITER)

    # Training loop
    data_iter = iter(dataloader)
    model.train()
    t0 = time.time()

    for iteration in range(1, MAX_ITER + 1):
        # Cycle the dataloader indefinitely
        try:
            batch = next(data_iter)
        except StopIteration:
            data_iter = iter(dataloader)
            batch = next(data_iter)

        agent_current = batch["agent_current"].to(device)
        wrist_current = batch["wrist_current"].to(device)
        agent_future_gt = batch["agent_future"].to(device)   # (B, NUM_FUTURE, 3, H, W)
        wrist_future_gt = batch["wrist_future"].to(device)

        agent_pred, wrist_pred = model(agent_current, wrist_current)

        # L1 loss on both views
        loss_agent = F.l1_loss(agent_pred, agent_future_gt)
        loss_wrist = F.l1_loss(wrist_pred, wrist_future_gt)
        loss = loss_agent + loss_wrist

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()

        # ─── Logging ───
        if iteration % LOG_EVERY == 0 or iteration == 1:
            # Read each scalar once: every .item() forces a host-device sync.
            loss_val = loss.item()
            agent_val = loss_agent.item()
            wrist_val = loss_wrist.item()
            lr_now = scheduler.get_last_lr()[0]
            iter_per_sec = iteration / (time.time() - t0)
            print(f"  iter {iteration:5d} | loss {loss_val:.4f} (agent {agent_val:.4f} wrist {wrist_val:.4f}) | {iter_per_sec:.1f} it/s | lr {lr_now:.2e}")
            wandb.log({
                "train/loss": loss_val,
                "train/loss_agent": agent_val,
                "train/loss_wrist": wrist_val,
                "train/lr": lr_now,
            }, step=iteration)

        # ─── Video logging ───
        if iteration % VIDEO_LOG_EVERY == 0 or iteration == 1:
            _log_sample_videos(model, iteration, agent_current, wrist_current,
                               agent_future_gt, wrist_future_gt)

        # ─── Checkpoint ───
        if iteration % SAVE_EVERY == 0:
            _save_checkpoint(model, optimizer, scheduler, iteration)

    # Make sure the final weights are saved even if MAX_ITER is not a multiple of SAVE_EVERY.
    if MAX_ITER > 0 and MAX_ITER % SAVE_EVERY != 0:
        _save_checkpoint(model, optimizer, scheduler, MAX_ITER)

    print("\n" + "=" * 60)
    print("Training done!")
    print("=" * 60)
    wandb.finish()


if __name__ == "__main__":
    main()
