"""SVD + PARA wrapper for eval_multistage.py.

Wraps the SVD UNet (with LoRA) + ParaHeadsOnUNet into the interface expected
by eval_multistage.py's PARA code path:
    volume_logits, _, _, feats = model(img_tensor, start_kp)
    gripper_logits, rotation_logits = model.predict_at_pixels(feats, pred_pixels)
"""

import sys
import os
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image

# Make SVD modules importable
SVD_ROOT = "/data/cameron/vidgen/svd_motion_lora/Motion-LoRA"
if SVD_ROOT not in sys.path:
    sys.path.insert(0, SVD_ROOT)

from svd.pipelines import StableVideoDiffusionPipeline
from svd.models import UNetSpatioTemporalConditionModel

# ---------------------------------------------------------------------------
# Constants (must match training script)
# ---------------------------------------------------------------------------
PARA_OUT_SIZE = 64
N_HEIGHT_BINS = 32
N_GRIPPER_BINS = 32
N_ROT_BINS = 32
PROJ_DIM = 128
N_WINDOW = 1  # SVD PARA predicts 1 timestep per forward pass

SVD_BASE_MODEL = os.path.join(SVD_ROOT, "checkpoints/stable-video-diffusion-img2vid-xt-1-1")
SVD_HEIGHT = 320
SVD_WIDTH = 576
SVD_NUM_FRAMES = 7
SVD_NUM_INFERENCE_STEPS = 25

IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)


# ---------------------------------------------------------------------------
# ParaHeadsOnUNet (copied from train_svd_para_joint.py to avoid heavy imports)
# ---------------------------------------------------------------------------
class ParaHeadsOnUNet(nn.Module):
    """PARA heads that attach to SVD UNet's up_block_1 and up_block_2.

    Two 1x1 convolutions project the UNet feature maps to a common channel
    width; both are bilinearly resized to a PARA_OUT_SIZE square grid,
    concatenated, refined by three 3x3 conv+GELU stages, and fed to three
    1x1 prediction heads (height volume, gripper, rotation).
    """

    def __init__(self, n_window=1, n_height_bins=N_HEIGHT_BINS, proj_dim=PROJ_DIM):
        # n_window is accepted for checkpoint/config compatibility but is not
        # used by this module.
        super().__init__()
        self.n_height_bins = n_height_bins

        # Project up_block_1 (1280 ch) and up_block_2 (640 ch) to proj_dim each.
        self.proj_block1 = nn.Conv2d(1280, proj_dim, 1)
        self.proj_block2 = nn.Conv2d(640, proj_dim, 1)

        fused_dim = proj_dim * 2  # 256 when proj_dim == 128
        refinement = []
        for _ in range(3):
            refinement.append(nn.Conv2d(fused_dim, fused_dim, 3, padding=1))
            refinement.append(nn.GELU())
        self.feature_convs = nn.Sequential(*refinement)

        self.volume_head = nn.Conv2d(fused_dim, n_height_bins, 1)
        self.gripper_head = nn.Conv2d(fused_dim, N_GRIPPER_BINS, 1)
        self.rotation_head = nn.Conv2d(fused_dim, 3 * N_ROT_BINS, 1)

    def forward(self, feat_block1, feat_block2, query_pixels=None):
        """Fuse the two UNet feature maps and run the prediction heads.

        Args:
            feat_block1:  (B*T, 1280, h1, w1) up_block_1 output
            feat_block2:  (B*T, 640, h2, w2) up_block_2 output
            query_pixels: optional (B*T, 2) [x, y] coords in PARA_OUT_SIZE space

        Returns:
            (volume_logits, gripper_logits, rotation_logits, feats); the last
            two are None unless query_pixels is given.
        """
        out_size = PARA_OUT_SIZE

        proj1 = self.proj_block1(feat_block1)
        proj1 = F.interpolate(proj1, size=(out_size, out_size),
                              mode='bilinear', align_corners=False)

        proj2 = self.proj_block2(feat_block2)
        proj2 = F.interpolate(proj2, size=(out_size, out_size),
                              mode='bilinear', align_corners=False)

        fused = torch.cat([proj1, proj2], dim=1)  # (B*T, 256, P, P)
        fused = self.feature_convs(fused)

        volume_logits = self.volume_head(fused)  # (B*T, Nh, P, P)

        gripper_logits = None
        rotation_logits = None
        if query_pixels is not None:
            n_items = fused.shape[0]
            xs = query_pixels[:, 0].long().clamp(0, out_size - 1)
            ys = query_pixels[:, 1].long().clamp(0, out_size - 1)
            rows = torch.arange(n_items, device=fused.device)

            # detach(): the pointwise heads must not backprop into the shared
            # feature trunk (matches the training-time setup).
            grip_map = self.gripper_head(fused.detach())
            gripper_logits = grip_map[rows, :, ys, xs]

            rot_map = self.rotation_head(fused.detach())
            rotation_logits = rot_map[rows, :, ys, xs].view(n_items, 3, N_ROT_BINS)

        return volume_logits, gripper_logits, rotation_logits, fused


# ---------------------------------------------------------------------------
# SVD PARA Wrapper
# ---------------------------------------------------------------------------
class SVDParaPredictor(nn.Module):
    """Wraps SVD UNet + PARA heads for eval_multistage.py.

    Exposes the two entry points eval_multistage.py's PARA code path expects:

        volume_logits, _, _, feats = model(img_tensor, start_kp)
        gripper_logits, rotation_logits = model.predict_at_pixels(feats, pred_pixels)

    Intermediate UNet activations are harvested through forward hooks on
    ``up_blocks[1]`` and ``up_blocks[2]`` while a full SVD sampling run is
    executed; the PARA heads then decode those activations.

    NOTE(review): ``load_state_dict`` and ``to`` deliberately shadow the
    nn.Module versions (see their docstrings) — keep that in mind before
    treating this like a regular module.
    """

    def __init__(self, checkpoint_dir, target_size=448, pred_size=PARA_OUT_SIZE,
                 svd_base=SVD_BASE_MODEL, device=None):
        """Load UNet (fp16, frozen), PARA heads, and build the SVD pipeline.

        Args:
            checkpoint_dir: directory containing ``unet/`` (diffusers-format
                UNet weights, LoRA already merged — presumably; confirm
                against the training export) and ``para_checkpoint.pt``.
            target_size: kept for interface parity with eval_multistage.py;
                not used in this file's visible code.
            pred_size: side length of the PARA output grid (pixel space for
                predict_at_pixels).
            svd_base: base SVD model path used for the non-UNet pipeline parts
                (VAE, image encoder, scheduler).
            device: target device; defaults to CUDA when available.
        """
        super().__init__()
        self.target_size = target_size
        self.pred_size = pred_size
        self.n_window = N_WINDOW

        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self._device = device

        checkpoint_dir = Path(checkpoint_dir)

        # --- Load UNet with LoRA weights ---
        unet_dir = checkpoint_dir / "unet"
        print(f"[SVDPara] Loading UNet from {unet_dir}")
        self.unet = UNetSpatioTemporalConditionModel.from_pretrained(
            str(unet_dir), torch_dtype=torch.float16,
        )
        # Inference-only: freeze so hooks/features carry no grad state.
        self.unet.eval()
        self.unet.requires_grad_(False)

        # --- Load PARA heads ---
        # SECURITY NOTE: torch.load unpickles arbitrary objects — only load
        # checkpoints from trusted sources.
        para_ckpt_path = checkpoint_dir / "para_checkpoint.pt"
        print(f"[SVDPara] Loading PARA heads from {para_ckpt_path}")
        para_ckpt = torch.load(str(para_ckpt_path), map_location="cpu")

        self.para_heads = ParaHeadsOnUNet(n_window=N_WINDOW).to(device)
        self.para_heads.load_state_dict(para_ckpt["para_heads"])
        self.para_heads.eval()

        # --- Store stats for eval_multistage to read ---
        # Normalization/bin statistics saved at training time; schema is
        # whatever the trainer stored ("stats" key, may be absent).
        self.stats = para_ckpt.get("stats", {})

        # --- Feature capture hooks ---
        # Each UNet forward overwrites these entries, so after a sampling run
        # the dict holds the activations from the final denoising step.
        self.captured = {}

        def make_hook(name):
            def hook_fn(module, inp, out):
                # Some blocks return tuples; keep the primary tensor, detached
                # and upcast to fp32 for the PARA heads.
                self.captured[name] = (out[0] if isinstance(out, tuple) else out).detach().float()
            return hook_fn

        self.unet.up_blocks[1].register_forward_hook(make_hook("up_block_1"))
        self.unet.up_blocks[2].register_forward_hook(make_hook("up_block_2"))

        # --- Build SVD pipeline ---
        # Pipeline is built with OUR (hooked, LoRA) UNet instance.
        print(f"[SVDPara] Building SVD pipeline from {svd_base}")
        self.pipe = StableVideoDiffusionPipeline.from_pretrained(
            svd_base, unet=self.unet,
            torch_dtype=torch.float16, variant="fp16",
        )
        self.pipe.to(device)

        print("[SVDPara] Ready.")

    def _denorm_to_pil(self, img_tensor):
        """Convert ImageNet-normalized (1, 3, H, W) tensor to a PIL image
        resized to the SVD input resolution (SVD_WIDTH x SVD_HEIGHT)."""
        mean = IMAGENET_MEAN.to(img_tensor.device)
        std = IMAGENET_STD.to(img_tensor.device)
        img = img_tensor * std + mean  # undo normalization -> [0, 1]
        img = img.clamp(0, 1)
        img = img[0].permute(1, 2, 0).cpu().numpy()  # (H, W, 3)
        img = (img * 255).astype(np.uint8)
        # PIL's resize takes (width, height); default resampling filter.
        pil = Image.fromarray(img).resize((SVD_WIDTH, SVD_HEIGHT))
        return pil

    def forward(self, img_tensor, start_keypoint_2d, **kwargs):
        """Run SVD pipeline to capture UNet features, then run PARA heads.

        The generated video itself is discarded; the pipeline run exists only
        to drive the UNet so the forward hooks populate ``self.captured``.

        Args:
            img_tensor:        (1, 3, H, W) ImageNet-normalized
            start_keypoint_2d: (2,) or (1, 2) current EEF pixel (unused for SVD feature extraction)
            **kwargs:          accepted for interface parity; ignored.

        Returns:
            volume_logits: (1, N_WINDOW, N_HEIGHT_BINS, pred_size, pred_size)
            None
            None
            feats:         (1, 256, pred_size, pred_size)
        """
        # Drop features from any previous call before the new sampling run.
        self.captured.clear()

        # Generate video via SVD pipeline (triggers UNet forward + hooks)
        pil_img = self._denorm_to_pil(img_tensor)
        with torch.inference_mode():
            _ = self.pipe(
                pil_img,
                height=SVD_HEIGHT, width=SVD_WIDTH,
                num_frames=SVD_NUM_FRAMES,
                decode_chunk_size=4,
                num_inference_steps=SVD_NUM_INFERENCE_STEPS,
            )

        # Extract captured features (from the last denoising step)
        # Shapes below are assumptions from the comments — TODO confirm they
        # match this pipeline's latent resolution.
        feat1 = self.captured["up_block_1"]  # (T, 1280, 20, 36)
        feat2 = self.captured["up_block_2"]  # (T, 640, 40, 72)
        # Keep at most N_WINDOW leading frames for the PARA heads.
        n_t = min(feat1.shape[0], N_WINDOW)

        # Run PARA heads on all timesteps
        with torch.no_grad():
            vol, _, _, feats_all = self.para_heads(feat1[:n_t], feat2[:n_t])
            # vol: (n_t, Nh, P, P), feats_all: (n_t, 256, P, P)

        # Reshape: (n_t, Nh, P, P) -> (1, n_t, Nh, P, P)
        volume_logits = vol.unsqueeze(0)

        # Use first timestep's features for predict_at_pixels
        feats = feats_all[:1]  # (1, 256, P, P)

        return volume_logits, None, None, feats

    def predict_at_pixels(self, feats, pred_pixels):
        """Index gripper/rotation heads at specified pixels.

        Applies the two 1x1 heads to the full feature map, then gathers the
        logits at the requested (x, y) locations (clamped into bounds).

        Args:
            feats:        (1, 256, pred_size, pred_size) from forward()
            pred_pixels:  (1, N_WINDOW, 2) pixel coords [x, y] in pred_size space

        Returns:
            gripper_logits:  (1, N_WINDOW, N_GRIPPER_BINS)
            rotation_logits: (1, N_WINDOW, 3, N_ROT_BINS)
        """
        B = feats.shape[0]
        N = pred_pixels.shape[1]
        P = self.pred_size

        px = pred_pixels[..., 0].long().clamp(0, P - 1)  # (B, N)
        py = pred_pixels[..., 1].long().clamp(0, P - 1)  # (B, N)
        batch_idx = torch.arange(B, device=feats.device).view(B, 1).expand(B, N)

        grip_map = self.para_heads.gripper_head(feats.detach())  # (B, Ng, P, P)
        # grip_map is (B, Ng, P, P); we need (B, N, Ng)
        # Advanced indices separated by ":" put their broadcast dims first,
        # so the result is (B, N, Ng) with the channel axis last.
        gripper_logits = grip_map[batch_idx, :, py, px]  # (B, N, Ng)

        rot_map = self.para_heads.rotation_head(feats.detach())  # (B, 3*Nr, P, P)
        rot_at_px = rot_map[batch_idx, :, py, px]  # (B, N, 3*Nr)
        rotation_logits = rot_at_px.view(B, N, 3, N_ROT_BINS)

        return gripper_logits, rotation_logits

    def load_state_dict(self, state_dict, strict=True):
        """No-op: weights are loaded in __init__ from checkpoint directory.

        Intentionally shadows nn.Module.load_state_dict so that
        eval_multistage.py's generic checkpoint-loading path does not clobber
        the weights loaded in __init__. Returns None (not IncompatibleKeys).
        """
        pass

    def to(self, device):
        """Move PARA heads to device. UNet/pipeline are already on device.

        Intentionally shadows nn.Module.to; the fp16 UNet and pipeline were
        placed on the target device in __init__ and are left untouched here.
        """
        self.para_heads = self.para_heads.to(device)
        self._device = device
        return self
