"""SVD video policy model: wraps SVDFeatureExtractor + ParaHeads behind the
same interface as TrajectoryHeatmapPredictor, for compatibility with eval.py."""

import os
import sys
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# Root of the Motion-LoRA checkout that provides `train_svd_para`; it is
# prepended to sys.path so the imports below can resolve. The default is a
# machine-specific path, so allow overriding it via the SVD_ROOT environment
# variable (an unset or empty value falls back to the default).
_DEFAULT_SVD_ROOT = "/data/cameron/vidgen/svd_motion_lora/Motion-LoRA"
SVD_ROOT = Path(os.environ.get("SVD_ROOT") or _DEFAULT_SVD_ROOT)
sys.path.insert(0, str(SVD_ROOT))

# Import from train script
from train_svd_para import SVDFeatureExtractor, ParaHeads, COMBINED_FEAT_DIM, PARA_OUT_SIZE, N_WINDOW
from train_svd_para import N_HEIGHT_BINS, N_GRIPPER_BINS, N_ROT_BINS

# Spatial size of the PARA head output maps (outputs are PRED_SIZE x PRED_SIZE).
PRED_SIZE = PARA_OUT_SIZE  # 64
# Presumably the square input resolution; matches the (B, 3, 448, 448) rgb
# input documented on SVDParaPredictor.forward. Not used in this file as shown.
IMAGE_SIZE = 448
# (height, width) that inputs are resized to via F.interpolate before SVD
# feature extraction.
SVD_SIZE = (320, 576)

# Dataset statistics. These are placeholder defaults, intended to be
# replaced by the values stored in a checkpoint.
MIN_HEIGHT, MAX_HEIGHT = 0.0, 1.0
MIN_GRIPPER, MAX_GRIPPER = -1.0, 1.0
# Three per-component rotation bounds of roughly +/- pi (3.14159).
# NOTE(review): the meaning of the three components is not established here.
MIN_ROT = [-3.14159] * 3
MAX_ROT = [3.14159] * 3
# Reference rotation quaternion; [0, 0, 0, 1] is the identity if components
# are ordered (x, y, z, w) -- NOTE(review): confirm the convention.
REF_ROTATION_QUAT = [0.0, 0.0, 0.0, 1.0]


class SVDParaPredictor(nn.Module):
    """SVD-based video policy exposing the predictor interface used by eval.py.

    A frozen ``SVDFeatureExtractor`` (run under ``torch.no_grad``) produces
    feature maps from the input frame, and ``ParaHeads`` decode those features
    into volume, gripper and rotation logits.
    """

    def __init__(self, svd_base, svd_unet, device="cuda"):
        """Construct the SVD feature extractor and the PARA heads.

        Args:
            svd_base: Forwarded to SVDFeatureExtractor as ``svd_base_path``.
            svd_unet: Forwarded to SVDFeatureExtractor as ``svd_unet_path``.
            device: Passed only to the feature extractor; ParaHeads and the
                normalization buffers are not given a device here.
        """
        super().__init__()

        # Backbone that produces the features consumed by the heads.
        self.svd_extractor = SVDFeatureExtractor(
            svd_base_path=svd_base, svd_unet_path=svd_unet, device=device
        )

        # Prediction heads sized from the training-script constants.
        self.para_heads = ParaHeads(
            feat_dim=COMBINED_FEAT_DIM, para_out_size=PARA_OUT_SIZE, n_window=N_WINDOW
        )

        self.embed_dim = 512  # ParaHeads internal dim
        self.n_window = N_WINDOW

        # ImageNet channel statistics, kept as buffers so forward() can undo
        # the input normalization on whatever device the module lives on.
        channel_mean = torch.tensor([0.485, 0.456, 0.406])
        channel_std = torch.tensor([0.229, 0.224, 0.225])
        self.register_buffer("img_mean", channel_mean.reshape(1, 3, 1, 1))
        self.register_buffer("img_std", channel_std.reshape(1, 3, 1, 1))

    def forward(self, rgb, start_keypoint_2d=None, query_pixels=None):
        """Run the SVD backbone and PARA heads on a batch of frames.

        Args:
            rgb: (B, 3, 448, 448) ImageNet-normalized images.
            start_keypoint_2d: (B, 2); accepted for interface compatibility
                but currently ignored.
            query_pixels: (B, N_WINDOW, 2) in PRED_SIZE coords; forwarded to
                the PARA heads.

        Returns:
            Tuple ``(volume_logits, gripper_logits, rotation_logits, feats)``:
                volume_logits: (B, N_WINDOW, N_HEIGHT_BINS, PRED_SIZE, PRED_SIZE)
                gripper_logits: (B, N_WINDOW, N_GRIPPER_BINS) or None
                rotation_logits: (B, N_WINDOW, 3, N_ROT_BINS) or None
                feats: (B, D, PRED_SIZE, PRED_SIZE)
        """
        # Map the normalized input back to pixel values, clipped to [0, 1].
        pixels = torch.clamp(rgb * self.img_std + self.img_mean, min=0, max=1)

        # Resample to the (height, width) resolution given by SVD_SIZE.
        svd_input = F.interpolate(
            pixels, size=SVD_SIZE, mode="bilinear", align_corners=False
        )

        # The backbone is treated as frozen: no autograd graph is recorded.
        with torch.no_grad():
            svd_features = self.svd_extractor.extract_features(svd_input)

        heads_out = self.para_heads(svd_features, query_pixels=query_pixels)
        volume_logits, feats, gripper_logits, rotation_logits = heads_out

        # ParaHeads yields (volume, feats, gripper, rotation); this interface
        # returns feats last.
        return volume_logits, gripper_logits, rotation_logits, feats
