"""SVD + Global Action wrapper for eval_multistage.py.

Wraps the SVD UNet (with LoRA) + GlobalActionHead into the ACT-style interface
expected by eval_multistage.py:
    pos_pred, rot_pred, gripper_pred = model(img_tensor, start_kp,
                                              current_eef_pos=..., current_gripper=...)
"""

import sys
import os
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image

# Make SVD modules importable. The Motion-LoRA checkout location can be
# overridden with the SVD_ROOT environment variable; otherwise the original
# hard-coded path is used.
SVD_ROOT = os.environ.get(
    "SVD_ROOT", "/data/cameron/vidgen/svd_motion_lora/Motion-LoRA"
)
if SVD_ROOT not in sys.path:
    # Prepend so this checkout's `svd` package wins over any other on the path.
    sys.path.insert(0, SVD_ROOT)

from svd.pipelines import StableVideoDiffusionPipeline
from svd.models import UNetSpatioTemporalConditionModel

# ---------------------------------------------------------------------------
# Constants (must match training script)
# ---------------------------------------------------------------------------
# Action-head hyperparameters.
PROJ_DIM = 128
N_GRIPPER_BINS = 32
N_WINDOW = 1  # the SVD global head predicts a single timestep per forward pass

# SVD sampling configuration.
SVD_BASE_MODEL = os.path.join(
    SVD_ROOT, "checkpoints/stable-video-diffusion-img2vid-xt-1-1"
)
SVD_HEIGHT, SVD_WIDTH = 320, 576
SVD_NUM_FRAMES = 7
SVD_NUM_INFERENCE_STEPS = 25

# ImageNet normalisation statistics, shaped (1, 3, 1, 1) so they broadcast
# over a (N, 3, H, W) image tensor.
IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406]).reshape(1, 3, 1, 1)
IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225]).reshape(1, 3, 1, 1)


# ---------------------------------------------------------------------------
# GlobalActionHead (copied from train_svd_global_action_regressor.py)
# ---------------------------------------------------------------------------
class GlobalActionHead(nn.Module):
    """Global regression head on SVD UNet features.

    Two feature maps (1280 and 640 channels) are each projected to
    ``proj_dim`` channels with a 1x1 conv and bilinearly resampled to a
    common 64x64 grid. They are then concatenated, refined by three
    3x3 conv + GELU layers, and global-average-pooled. The pooled vector
    feeds a position MLP (3 outputs) and a gripper MLP (``N_GRIPPER_BINS``
    logits).

    Attribute names and layer order mirror the training script so its
    ``state_dict`` loads unchanged.
    """

    _GRID = 64  # spatial size both projected feature maps are resampled to

    def __init__(self, proj_dim=PROJ_DIM, hidden_dim=256):
        super().__init__()
        width = 2 * proj_dim  # channels after concatenating the two projections

        self.proj_block1 = nn.Conv2d(1280, proj_dim, kernel_size=1)
        self.proj_block2 = nn.Conv2d(640, proj_dim, kernel_size=1)

        # Three (Conv2d 3x3, GELU) pairs -> Sequential indices 0..5.
        self.feature_convs = nn.Sequential(*[
            layer
            for _ in range(3)
            for layer in (nn.Conv2d(width, width, kernel_size=3, padding=1), nn.GELU())
        ])

        self.pool = nn.AdaptiveAvgPool2d(1)
        self.position_mlp = nn.Sequential(
            nn.Linear(width, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, 3),
        )
        self.gripper_mlp = nn.Sequential(
            nn.Linear(width, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, N_GRIPPER_BINS),
        )

    def forward(self, feat_block1, feat_block2):
        """Return ``(pos_pred, grip_logits)``.

        Shapes are ``(N, 3)`` and ``(N, N_GRIPPER_BINS)``, where N is the
        shared batch size of the inputs.
        """
        grid = (self._GRID, self._GRID)
        resampled = [
            F.interpolate(proj(feat), size=grid, mode='bilinear', align_corners=False)
            for proj, feat in (
                (self.proj_block1, feat_block1),
                (self.proj_block2, feat_block2),
            )
        ]
        refined = self.feature_convs(torch.cat(resampled, dim=1))
        summary = self.pool(refined).flatten(1)  # (N, width, 1, 1) -> (N, width)
        return self.position_mlp(summary), self.gripper_mlp(summary)


# ---------------------------------------------------------------------------
# SVD Global Action Wrapper
# ---------------------------------------------------------------------------
class SVDGlobalPredictor(nn.Module):
    """Wraps SVD UNet + GlobalActionHead for eval_multistage.py ACT-style interface.

    Each ``forward`` call runs the SVD sampling loop on the input frame only to
    drive the UNet. Forward hooks on ``unet.up_blocks[1]`` and
    ``unet.up_blocks[2]`` record their outputs. The values left after sampling
    (from the last UNet call) feed ``GlobalActionHead``. The generated video
    itself is discarded.
    """

    def __init__(self, checkpoint_dir, target_size=448,
                 svd_base=SVD_BASE_MODEL, device=None):
        """Load the LoRA-tuned UNet, the action head and the SVD pipeline.

        Args:
            checkpoint_dir: Directory containing ``unet/`` (loaded with
                ``UNetSpatioTemporalConditionModel.from_pretrained``) and
                ``action_checkpoint.pt``. The latter is a dict with key
                ``"action_head"`` and an optional ``"stats"`` key.
            target_size: Stored as ``self.target_size``; not used in this class.
            svd_base: Base SVD pipeline to build around the loaded UNet.
            device: Target device; defaults to CUDA when available, else CPU.
        """
        super().__init__()
        self.target_size = target_size
        self.n_window = N_WINDOW

        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self._device = device

        checkpoint_dir = Path(checkpoint_dir)

        # --- Load UNet with LoRA weights (fp16, frozen) ---
        unet_dir = checkpoint_dir / "unet"
        print(f"[SVDGlobal] Loading UNet from {unet_dir}")
        self.unet = UNetSpatioTemporalConditionModel.from_pretrained(
            str(unet_dir), torch_dtype=torch.float16,
        )
        self.unet.eval()
        self.unet.requires_grad_(False)

        # --- Load action head ---
        action_ckpt_path = checkpoint_dir / "action_checkpoint.pt"
        print(f"[SVDGlobal] Loading action head from {action_ckpt_path}")
        # NOTE(review): recent torch versions default torch.load to
        # weights_only=True. If "stats" holds non-tensor objects, this load may
        # need weights_only=False; confirm against the installed torch version.
        action_ckpt = torch.load(str(action_ckpt_path), map_location="cpu")

        self.action_head = GlobalActionHead().to(device)
        self.action_head.load_state_dict(action_ckpt["action_head"])
        self.action_head.eval()

        # --- Store stats for eval_multistage to read ---
        self.stats = action_ckpt.get("stats", {})

        # --- Feature capture hooks ---
        # Every UNet call overwrites these entries, so after sampling the dict
        # holds the outputs of the final UNet call (stored as float32).
        self.captured = {}

        def make_hook(name):
            def hook_fn(module, inp, out):
                self.captured[name] = (out[0] if isinstance(out, tuple) else out).detach().float()
            return hook_fn

        self.unet.up_blocks[1].register_forward_hook(make_hook("up_block_1"))
        self.unet.up_blocks[2].register_forward_hook(make_hook("up_block_2"))

        # --- Build SVD pipeline (shares self.unet) ---
        print(f"[SVDGlobal] Building SVD pipeline from {svd_base}")
        self.pipe = StableVideoDiffusionPipeline.from_pretrained(
            svd_base, unet=self.unet,
            torch_dtype=torch.float16, variant="fp16",
        )
        self.pipe.to(device)

        print("[SVDGlobal] Ready.")

    def _denorm_to_pil(self, img_tensor):
        """Convert an ImageNet-normalized (1, 3, H, W) tensor to a PIL image.

        Only batch element 0 is used. The result is resized to
        ``(SVD_WIDTH, SVD_HEIGHT)`` for the SVD pipeline.
        """
        mean = IMAGENET_MEAN.to(img_tensor.device)
        std = IMAGENET_STD.to(img_tensor.device)
        img = (img_tensor.detach() * std + mean).clamp(0, 1)  # back to [0, 1]
        img = img[0].permute(1, 2, 0).cpu().numpy()  # (H, W, 3)
        # Round instead of truncating: float round-off after de-normalization
        # (e.g. 254.99998) would otherwise drop pixels by one intensity level.
        img = np.rint(img * 255).astype(np.uint8)
        return Image.fromarray(img).resize((SVD_WIDTH, SVD_HEIGHT))

    def forward(self, img_tensor, start_keypoint_2d, current_eef_pos=None,
                current_gripper=None, **kwargs):
        """Run SVD pipeline to capture UNet features, then run global action head.

        Args:
            img_tensor:        (1, 3, H, W) ImageNet-normalized
            start_keypoint_2d: (2,) or (1, 2) current EEF pixel (unused for SVD)
            current_eef_pos:   (1, 3) normalized [0,1] EEF position (unused by this head)
            current_gripper:   (1, 1) normalized [0,1] gripper state (unused by this head)
            **kwargs:          accepted for interface compatibility; ignored

        Returns:
            pos_pred:      (1, N_WINDOW, 3) sigmoid of the head's position output, in [0, 1]
            rot_pred:      (1, N_WINDOW, 3) zeros (rotation not predicted)
            gripper_pred:  (1, N_WINDOW) log-odds that the predicted gripper bin
                           lies in the upper half of the bins (positive = close)
        """
        self.captured.clear()

        # Generate video via SVD pipeline (triggers UNet forward + hooks).
        # The decoded frames are not used.
        pil_img = self._denorm_to_pil(img_tensor)
        with torch.inference_mode():
            _ = self.pipe(
                pil_img,
                height=SVD_HEIGHT, width=SVD_WIDTH,
                num_frames=SVD_NUM_FRAMES,
                decode_chunk_size=4,
                num_inference_steps=SVD_NUM_INFERENCE_STEPS,
            )

        # Extract captured features (from the last UNet call)
        feat1 = self.captured["up_block_1"]  # (B*T, 1280, 20, 36)
        feat2 = self.captured["up_block_2"]  # (B*T, 640, 40, 72)

        # Use only the first row (first frame).
        # NOTE(review): with classifier-free guidance the UNet batch holds two
        # branches, and upstream diffusers' SVD pipeline puts the unconditional
        # branch first. In that case row 0 is unconditional. Confirm this matches
        # how features were taken at training time.
        # Match the head's device/dtype so a later .to()/.half() stays consistent.
        head_param = next(self.action_head.parameters())
        feat1 = feat1[0:1].to(head_param)  # (1, 1280, 20, 36)
        feat2 = feat2[0:1].to(head_param)  # (1, 640, 40, 72)

        # Run global action head
        with torch.no_grad():
            pos_raw, grip_logits = self.action_head(feat1, feat2)
            # pos_raw: (1, 3) -- raw MLP output
            # grip_logits: (1, N_GRIPPER_BINS)

        # Sigmoid to normalize position to [0, 1]
        pos_pred = torch.sigmoid(pos_raw).unsqueeze(1)  # (1, N_WINDOW, 3)

        # No rotation prediction -- return zeros
        rot_pred = torch.zeros(1, self.n_window, 3, device=pos_pred.device)

        # Gripper: collapse the bin distribution into one logit,
        #   log P(upper-half bin) - log P(lower-half bin).
        # The softmax normaliser cancels, so sigmoid(gripper_pred) equals the
        # probability mass of the upper half (positive = close, negative = open).
        # Summing raw logits (the previous approach) could disagree with the
        # dominant bin.
        half = N_GRIPPER_BINS // 2
        gripper_pred = (
            torch.logsumexp(grip_logits[:, half:], dim=-1)
            - torch.logsumexp(grip_logits[:, :half], dim=-1)
        )
        gripper_pred = gripper_pred.unsqueeze(1)  # (1, N_WINDOW)

        return pos_pred, rot_pred, gripper_pred

    def load_state_dict(self, state_dict, strict=True, **kwargs):
        """No-op: weights are loaded in __init__ from the checkpoint directory.

        Extra keyword arguments accepted by ``nn.Module.load_state_dict``
        (e.g. ``assign=``) are accepted and ignored.
        """
        pass

    def to(self, *args, **kwargs):
        """Move the model, mirroring ``nn.Module.to``'s call signature.

        The action head receives ``args``/``kwargs`` unchanged. The SVD
        pipeline (which owns the UNet) is then moved to the head's resulting
        device, so the two cannot end up on different devices. The pipeline
        keeps its fp16 dtype. Returns ``self`` for chaining.
        """
        self.action_head = self.action_head.to(*args, **kwargs)
        device = next(self.action_head.parameters()).device
        if next(self.unet.parameters()).device != device:
            self.pipe.to(device)
        self._device = device
        return self
