"""
LoRA (Low-Rank Adaptation) for video diffusion models.
Injects trainable low-rank matrices into attention linear layers (to_q, to_k, to_v, to_out).
"""
import math
from typing import List, Optional, Set

import torch
import torch.nn as nn


class LoRALinear(nn.Module):
    """Wrap an ``nn.Linear`` with a trainable low-rank update.

    ``out = linear(x) + (x @ lora_A.T @ lora_B.T) * scale``, where ``lora_A`` has shape
    ``(rank, in_features)`` and ``lora_B`` has shape ``(out_features, rank)``. ``lora_B`` is
    zero-initialised, so a freshly wrapped layer returns exactly the wrapped linear's output.

    Args:
        linear: The layer to wrap; kept as ``self.linear`` and called unchanged in ``forward``.
        rank: Requested rank, clamped to ``min(in_features, out_features)``.
        scale: Multiplier applied to the low-rank term.
    """

    def __init__(
        self,
        linear: nn.Linear,
        rank: int = 4,
        scale: float = 1.0,
    ):
        super().__init__()
        self.linear = linear
        in_features = linear.in_features
        out_features = linear.out_features
        self.rank = min(rank, in_features, out_features)
        self.scale = scale
        # Allocate the adapter next to the wrapped weight. With torch's defaults (CPU, float32)
        # forward would fail for a linear that lives on another device or uses another dtype.
        weight = linear.weight
        factory = {"device": weight.device}
        if weight.is_floating_point():
            factory["dtype"] = weight.dtype
        self.lora_A = nn.Parameter(torch.zeros(self.rank, in_features, **factory))
        self.lora_B = nn.Parameter(torch.zeros(out_features, self.rank, **factory))
        self._init_lora()

    def _init_lora(self):
        """Kaiming-uniform init for ``lora_A``; zeros for ``lora_B`` so the update starts at 0."""
        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
        nn.init.zeros_(self.lora_B)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``linear(x)`` plus the scaled low-rank update ``x @ A.T @ B.T``."""
        out = self.linear(x)
        out = out + (x @ self.lora_A.T @ self.lora_B.T) * self.scale
        return out


def _get_inner_model(model: nn.Module) -> nn.Module:
    """Return ``model.diffusion_model`` when the attribute exists, otherwise ``model`` itself.

    Used to unwrap a wrapper (e.g. an OpenAIWrapper) around the actual diffusion model.
    """
    return getattr(model, "diffusion_model", model)


def _find_attention_linears(module: nn.Module, prefix: str = "") -> List[tuple]:
    """Collect the projection Linears of every CrossAttention-like submodule.

    A child counts as attention when "CrossAttention" occurs in its class name or in
    ``str(type(child))``. Matching children are not descended into; all other children are
    searched recursively, in ``named_children`` order.

    Returns:
        ``(parent, attr_name, linear, qualified_name)`` tuples for each of ``to_q``, ``to_k``,
        ``to_v``, ``to_out`` that is an ``nn.Linear``, or an ``nn.Sequential`` whose first element
        is an ``nn.Linear`` (then ``linear`` is that element and the name ends in ``[0]``).
    """
    found: List[tuple] = []
    for child_name, child in module.named_children():
        path = child_name if not prefix else f"{prefix}.{child_name}"
        child_type = type(child)
        looks_like_attention = (
            "CrossAttention" in child_type.__name__ or "CrossAttention" in str(child_type)
        )
        if not looks_like_attention:
            found += _find_attention_linears(child, path)
            continue
        for attr in ("to_q", "to_k", "to_v", "to_out"):
            if not hasattr(child, attr):
                continue
            target = getattr(child, attr)
            if isinstance(target, nn.Linear):
                found.append((child, attr, target, f"{path}.{attr}"))
            elif (
                isinstance(target, nn.Sequential)
                and len(target) > 0
                and isinstance(target[0], nn.Linear)
            ):
                found.append((child, attr, target[0], f"{path}.{attr}[0]"))
    return found


def inject_lora(
    model: nn.Module,
    rank: int = 4,
    scale: float = 1.0,
    target_modules: Optional[Set[str]] = None,
) -> List[nn.Parameter]:
    """
    Inject LoRA into attention linears (to_q, to_k, to_v, to_out) and freeze the rest.

    Each ``nn.Linear`` found by ``_find_attention_linears`` is replaced by a ``LoRALinear``
    wrapping it. When ``to_out`` is an ``nn.Sequential`` whose first element is the Linear, only
    that first element is swapped, in place.

    Args:
        model: Model to modify in place, or a wrapper exposing ``.diffusion_model``.
        rank: Requested LoRA rank for every injected layer.
        scale: Multiplier for the low-rank update of every injected layer.
        target_modules: Optional set of qualified names restricting injection, in the form
            produced by ``_find_attention_linears`` (e.g. ``"<path>.to_q"`` or
            ``"<path>.to_out[0]"``). Names that match no attention linear trigger a warning.

    Returns:
        The ``lora_A``/``lora_B`` parameters of the injected layers (for the optimizer). After
        the call these are the only parameters of the inner model with ``requires_grad=True``.

    Note:
        Layers that are already ``LoRALinear`` are not ``nn.Linear`` and are not found again, so
        a second call injects nothing and freezes the previously injected LoRA parameters.
    """
    import warnings

    inner = _get_inner_model(model)
    pairs = _find_attention_linears(inner)
    lora_params: List[nn.Parameter] = []
    for parent, attr_name, linear_layer, full_name in pairs:
        if target_modules is not None and full_name not in target_modules:
            continue
        lora_layer = LoRALinear(linear_layer, rank=rank, scale=scale)
        current = getattr(parent, attr_name)
        if isinstance(current, nn.Sequential):
            # Swap the leading Linear in place. Rebuilding a new Sequential would drop the
            # container's child names (changing state_dict keys) and any hooks attached to it.
            current[0] = lora_layer
        else:
            setattr(parent, attr_name, lora_layer)
        lora_params.extend([lora_layer.lora_A, lora_layer.lora_B])

    if target_modules is not None:
        unmatched = set(target_modules) - {entry[3] for entry in pairs}
        if unmatched:
            warnings.warn(
                f"inject_lora: {len(unmatched)} target_modules entries matched no attention "
                f"linear and were ignored (e.g. {sorted(unmatched)[:5]})"
            )

    # Freeze everything, including the wrapped base weights, then re-enable only LoRA.
    for p in inner.parameters():
        p.requires_grad_(False)
    for p in lora_params:
        p.requires_grad_(True)
    return lora_params


def save_lora_state_dict(model: nn.Module, path: str, rank: Optional[int] = None):
    """Save only the LoRA parameters (lora_A, lora_B) of every LoRALinear layer to ``path``.

    Keys are ``"<module name>.lora_A"`` / ``"<module name>.lora_B"``, with names from
    ``named_modules()`` of the inner model (``model.diffusion_model`` when present). Tensors are
    detached copies on CPU.

    Metadata entries:
        ``_lora_rank``: ``rank`` if given. Otherwise it is the largest ``lora_A`` rank found;
            since each layer's rank is ``min(requested, in, out)``, re-injecting with that value
            reproduces every layer's shape. It is omitted only when ``rank`` is None and the
            model has no LoRA layers.
        ``_lora_scale``: the layers' ``scale``, written only when all layers share one value.
    """
    inner = _get_inner_model(model)
    state = {}
    layer_ranks = []
    layer_scales = set()
    for name, mod in inner.named_modules():
        if isinstance(mod, LoRALinear):
            # detach() so a plain tensor is serialized, not a live Parameter / autograd node.
            state[f"{name}.lora_A"] = mod.lora_A.detach().cpu()
            state[f"{name}.lora_B"] = mod.lora_B.detach().cpu()
            layer_ranks.append(mod.lora_A.shape[0])
            layer_scales.add(mod.scale)
    if rank is None and layer_ranks:
        rank = max(layer_ranks)
    if rank is not None:
        state["_lora_rank"] = rank
    # Without this the scale is lost and a reload would silently fall back to 1.0.
    if len(layer_scales) == 1:
        state["_lora_scale"] = next(iter(layer_scales))
    torch.save(state, path)


def load_lora_state_dict(model: nn.Module, path: str, device=None):
    """Load a LoRA state dict written by ``save_lora_state_dict`` into ``model``.

    If the inner model has no ``LoRALinear`` layers yet, LoRA is injected first using the saved
    ``_lora_rank`` / ``_lora_scale`` metadata. When the rank is absent (older files), it is
    inferred from the saved ``lora_A`` tensors, falling back to 4; the scale falls back to 1.0.
    When the model already has LoRA layers, the saved scale is not applied.

    Args:
        model: The model, or a wrapper exposing ``.diffusion_model``.
        path: File produced by ``save_lora_state_dict``.
        device: Optional ``map_location`` for ``torch.load``, also used as the destination of the
            loaded tensors. When ``None``, each tensor goes to the device of the linear wrapped by
            its ``LoRALinear``.

    Loaded tensors are cast to the dtype of the parameter they replace. Saved entries that match
    no ``LoRALinear`` in the model are reported with ``warnings.warn``.
    """
    import warnings

    # NOTE(review): torch.load unpickles the file. On torch versions where weights_only is not
    # the default, this can execute arbitrary code, so only load checkpoints from trusted sources.
    state = torch.load(path, map_location=device)
    rank = state.pop("_lora_rank", None)
    scale = state.pop("_lora_scale", 1.0)
    inner = _get_inner_model(model)
    if not any(isinstance(m, LoRALinear) for m in inner.modules()):
        if rank is None:
            # Older files lack the rank; the largest saved lora_A rank reproduces every layer's
            # clamped rank when re-injected into the same architecture.
            saved_ranks = [t.shape[0] for k, t in state.items() if k.endswith(".lora_A")]
            rank = max(saved_ranks, default=4)
        inject_lora(model, rank=rank, scale=scale)

    if device is not None:
        explicit_device = torch.device(device) if isinstance(device, str) else device
    else:
        explicit_device = None

    consumed = set()
    for name, mod in inner.named_modules():
        if not isinstance(mod, LoRALinear):
            continue
        # Default to the wrapped linear's device so LoRA params sit next to the weights they adapt.
        target_device = explicit_device if explicit_device is not None else mod.linear.weight.device
        for attr in ("lora_A", "lora_B"):
            key = f"{name}.{attr}"
            if key not in state:
                continue
            param = getattr(mod, attr)
            # Assign .data (rather than copy_) so a saved adapter of a different rank still loads.
            # Match the parameter's dtype so forward does not mix dtypes.
            param.data = state[key].to(device=target_device, dtype=param.dtype)
            consumed.add(key)
        # Keep the bookkeeping attribute consistent with the (possibly resized) adapter.
        mod.rank = mod.lora_A.shape[0]

    unexpected = sorted(set(state) - consumed)
    if unexpected:
        warnings.warn(
            f"load_lora_state_dict: {len(unexpected)} saved entries matched no LoRALinear layer "
            f"and were ignored (e.g. {unexpected[:3]})"
        )
