"""Model that uses DINOv3 backbone with LoRA fine-tuning for heatmap predictions."""
import os

import torch
import torch.nn as nn
import torch.nn.functional as F

# DINO configuration.
# The repo/weights locations default to the original author's local checkout; set the
# DINO_REPO_DIR / DINO_WEIGHTS_PATH environment variables to point at another copy.
DINO_REPO_DIR = os.environ.get(
    "DINO_REPO_DIR",
    "/Users/cameronsmith/Projects/robotics_testing/random/dinov3",
)
DINO_WEIGHTS_PATH = os.environ.get(
    "DINO_WEIGHTS_PATH",
    "/Users/cameronsmith/Projects/robotics_testing/random/dinov3/weights/dinov3_vits16plus_pretrain_lvd1689m-4057cbaa.pth",
)
DINO_PATCH_SIZE = 16  # side length (pixels) of one ViT patch; backbone inputs are resized to multiples of it
DINO_N_LAYERS = 12  # backbone layer count (not referenced by the model code in this module)
# Per-channel RGB statistics used to normalize [0, 1] images before the backbone.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

WINDOW_SIZE = 3  # Extrema trajectory: close, open, end


class LoRALayer(nn.Module):
    """LoRA (Low-Rank Adaptation) layer for efficient fine-tuning."""
    def __init__(self, original_layer: nn.Linear, rank: int = 8, alpha: float = 16.0, dropout: float = 0.0):
        """
        Args:
            original_layer: The original linear layer to adapt
            rank: LoRA rank (lower = fewer parameters)
            alpha: LoRA scaling factor (typically rank * 2)
            dropout: Dropout probability for LoRA layers
        """
        super().__init__()
        self.original_layer = original_layer
        self.rank = rank
        self.alpha = alpha
        self.scaling = alpha / rank
        
        # Store in_features and out_features for compatibility
        self.in_features = original_layer.in_features
        self.out_features = original_layer.out_features
        
        # Freeze original layer
        for param in original_layer.parameters():
            param.requires_grad = False
        
        # Create LoRA matrices
        self.lora_A = nn.Parameter(torch.randn(rank, self.in_features) * 0.02)
        self.lora_B = nn.Parameter(torch.zeros(self.out_features, rank))
        self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
        
    def forward(self, x):
        # Original output
        original_out = self.original_layer(x)
        
        # LoRA adaptation: x @ A.T @ B.T * scaling
        lora_out = self.dropout(x) @ self.lora_A.T @ self.lora_B.T * self.scaling
        
        return original_out + lora_out


def apply_lora_to_linear(module: nn.Module, name: str = "", rank: int = 8, alpha: float = 16.0,
                         target_modules: list = None, dropout: float = 0.0) -> None:
    """
    Recursively replace matching ``nn.Linear`` children of ``module`` with ``LoRALayer`` wrappers, in place.

    A child is wrapped when it is an ``nn.Linear`` (subclasses included) and any entry of
    ``target_modules`` is a substring of the child's attribute name (so ``'proj'`` also matches
    e.g. ``'out_proj'``). Children that are already ``LoRALayer`` instances are skipped, so calling
    this twice never stacks a second adapter on top of an existing one.

    Args:
        module: Module to process (modified in place).
        name: Dotted path of ``module``. It is only forwarded to the recursive calls; matching
            uses each child's own attribute name, not this path.
        rank: LoRA rank.
        alpha: LoRA alpha scaling.
        target_modules: List of attribute-name substrings to wrap
            (default: ``['qkv', 'proj', 'fc1', 'fc2']``).
        dropout: Dropout probability passed to each new ``LoRALayer`` (default 0.0, i.e. disabled).
    """
    if target_modules is None:
        target_modules = ['qkv', 'proj', 'fc1', 'fc2']

    # Snapshot the children first: setattr() below mutates the module's child registry.
    for child_name, child in list(module.named_children()):
        full_name = f"{name}.{child_name}" if name else child_name

        # Already adapted: do not wrap its `original_layer` again.
        if isinstance(child, LoRALayer):
            continue

        if isinstance(child, nn.Linear) and any(target in child_name for target in target_modules):
            setattr(module, child_name, LoRALayer(child, rank=rank, alpha=alpha, dropout=dropout))
        else:
            apply_lora_to_linear(
                child,
                full_name,
                rank=rank,
                alpha=alpha,
                target_modules=target_modules,
                dropout=dropout,
            )


class DINOHeatmapPredictorLoRA(nn.Module):
    """DINOv3 backbone with LoRA adapters that predicts per-timestep heatmaps and height/gripper values.

    - Loads ``dinov3_vits16plus`` via ``torch.hub`` from ``DINO_REPO_DIR``, freezes every backbone
      parameter, then wraps the backbone's ``qkv``/``proj``/``fc1``/``fc2`` linear layers with LoRA.
    - Treats the first ``window_size`` channels of the normalized last-layer patch features as
      heatmap logits and bilinearly upsamples them to ``target_size x target_size``. No further
      transformation (e.g. inverse depth) is applied to them in this class.
    - For each timestep, pools patch features with a spatial softmax over that timestep's heatmap,
      concatenates a learned timestep embedding and regresses (height, gripper) with a small MLP.
    - Trainable parameters: the LoRA matrices, the timestep embedding and the MLP head.
    """
    def __init__(self, window_size=3, target_size=224, dino_image_size=768, lora_rank=8, lora_alpha=16.0):
        """
        Args:
            window_size: Number of timesteps = number of heatmap channels taken from the patch features.
            target_size: Side length of the returned square heatmaps.
            dino_image_size: Backbone input height (truncated to a multiple of the patch size); the
                width is scaled to keep the input aspect ratio.
            lora_rank: Rank of every LoRA adapter.
            lora_alpha: LoRA alpha; each adapter's update is scaled by ``lora_alpha / lora_rank``.
        """
        super().__init__()
        self.window_size = window_size
        self.target_size = target_size
        self.dino_image_size = dino_image_size
        self.patch_size = DINO_PATCH_SIZE
        self.lora_rank = lora_rank
        self.lora_alpha = lora_alpha

        # Load DINOv3 model (will be moved to device later)
        print("Loading DINOv3 model...")
        self.dino = torch.hub.load(
            DINO_REPO_DIR,
            'dinov3_vits16plus',
            source='local',
            weights=DINO_WEIGHTS_PATH
        )

        # Freeze all DINO parameters first; the LoRA matrices added below are the only trainable backbone params.
        for param in self.dino.parameters():
            param.requires_grad = False

        # Apply LoRA to attention and MLP layers
        print(f"Applying LoRA (rank={lora_rank}, alpha={lora_alpha}) to DINO backbone...")
        apply_lora_to_linear(
            self.dino,
            name="",
            rank=lora_rank,
            alpha=lora_alpha,
            target_modules=['qkv', 'proj', 'fc1', 'fc2']
        )

        # Count LoRA parameters (backbone only; the heads below are not included)
        lora_params = sum(p.numel() for p in self.dino.parameters() if p.requires_grad)
        total_params = sum(p.numel() for p in self.dino.parameters())
        print(f"✓ LoRA applied: {lora_params:,} trainable parameters ({100*lora_params/total_params:.2f}% of total)")
        print(f"✓ Using first {window_size} DINO feature dimensions as heatmap predictions")

        # Learnable per-timestep embeddings, concatenated with pooled features for height/gripper prediction
        embed_dim = self.dino.embed_dim
        self.timestep_embedding = nn.Embedding(window_size, embed_dim)

        # Tiny MLP to regress height and gripper from pooled features + timestep embedding
        self.height_gripper_head = nn.Sequential(
            nn.Linear(embed_dim * 2, 128),  # pooled features + timestep embedding
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(128, 2)  # height and gripper for this timestep
        )
        print(f"✓ Added timestep embeddings and MLP head for height/gripper prediction from pooled patch features")

    def to(self, *args, **kwargs):
        """Move/cast the model; accepts every argument form of ``nn.Module.to``.

        ``self.dino`` is a registered submodule, so the base implementation already moves it.
        This override only delegates. The previous ``to(self, device)`` signature rejected
        ``.to(dtype=...)`` and ``.to(device, non_blocking=True)``.
        """
        return super().to(*args, **kwargs)

    def _extract_dino_features_batch(self, x):
        """
        Run the DINOv3 forward pass manually, block by block, so gradients reach the LoRA adapters.

        Args:
            x: (B, 3, H, W) - ImageNet-normalized image tensor (H, W multiples of the patch size)

        Returns:
            patch_features: (B, D, H_patches, W_patches) - normalized last-layer patch tokens
            cls_token: (B, D) - normalized last-layer CLS token
        """
        # Patch embedding + CLS/storage tokens; (H, W) is the patch grid (see the reshape below).
        x, (H, W) = self.dino.prepare_tokens_with_masks(x)
        # Index 0 is the CLS token, followed by n_storage_tokens storage tokens, then the patch tokens.
        n_prefix = self.dino.n_storage_tokens + 1

        for blk in self.dino.blocks:
            # Recomputed for every block, exactly as before.
            # NOTE(review): only hoist this out of the loop after confirming rope_embed
            # has no train-time randomness.
            if self.dino.rope_embed is not None:
                rope_sincos = self.dino.rope_embed(H=H, W=W)
            else:
                rope_sincos = None
            x = blk(x, rope_sincos)

        # Final normalization (separate norms for prefix and patch tokens when untied)
        if self.dino.untie_cls_and_patch_norms:
            x_norm_cls_reg = self.dino.cls_norm(x[:, :n_prefix])
            x_norm_patch = self.dino.norm(x[:, n_prefix:])
            x_norm = torch.cat((x_norm_cls_reg, x_norm_patch), dim=1)
        else:
            x_norm = self.dino.norm(x)

        cls_token = x_norm[:, 0]  # (B, D)
        patch_tokens = x_norm[:, n_prefix:]  # (B, N_patches, D)

        # Reshape to spatial format: (B, D, H_patches, W_patches)
        B = patch_tokens.shape[0]
        patch_features = patch_tokens.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()

        return patch_features, cls_token

    def _preprocess(self, rgb_image):
        """Resize ``rgb_image`` to a patch-aligned size and apply ImageNet mean/std normalization.

        Matches the resize_transform logic: h_patches = dino_image_size / patch_size and
        w_patches = (W * dino_image_size) / (H * patch_size), each truncated with ``int``.

        Args:
            rgb_image: (B, 3, H, W) - RGB image tensor in [0, 1] range

        Returns:
            (B, 3, h_patches * patch_size, w_patches * patch_size) normalized tensor.
        """
        _, _, H_in, W_in = rgb_image.shape
        h_patches = int(self.dino_image_size / self.patch_size)
        w_patches = int((W_in * self.dino_image_size) / (H_in * self.patch_size))

        x_resized = F.interpolate(
            rgb_image,
            size=(h_patches * self.patch_size, w_patches * self.patch_size),
            mode='bilinear',
            align_corners=False
        )

        mean = torch.tensor(IMAGENET_MEAN, device=rgb_image.device).view(1, 3, 1, 1)
        std = torch.tensor(IMAGENET_STD, device=rgb_image.device).view(1, 3, 1, 1)
        return (x_resized - mean) / std

    def _predict_height_gripper(self, patch_features, heatmap_logits_patches):
        """Regress (height, gripper) for every timestep from heatmap-weighted pooled patch features.

        Args:
            patch_features: (B, D, H_patches, W_patches) - backbone patch features
            heatmap_logits_patches: (B, window_size, H_patches, W_patches) - patch-resolution heatmap logits

        Returns:
            heights_pred, grippers_pred: (B, window_size) each, after the ``x * 0.5 + 0.5`` output mapping.
        """
        B, _, H_patches, W_patches = patch_features.shape

        # One embedding lookup for all timesteps: (window_size, D).
        timestep_embs = self.timestep_embedding(
            torch.arange(self.window_size, device=patch_features.device)
        )

        heights_pred_list = []
        grippers_pred_list = []
        for t in range(self.window_size):
            # Spatial softmax over this timestep's heatmap -> pooling weights that sum to 1 per sample.
            heatmap_t_flat = heatmap_logits_patches[:, t].reshape(B, H_patches * W_patches)
            attention_weights = F.softmax(heatmap_t_flat, dim=1).reshape(B, 1, H_patches, W_patches)
            pooled_features = (patch_features * attention_weights).sum(dim=(2, 3))  # (B, D)

            timestep_emb = timestep_embs[t].unsqueeze(0).expand(B, -1)  # (B, D)
            combined_features = torch.cat([pooled_features, timestep_emb], dim=1)  # (B, 2D)

            height_gripper_t = self.height_gripper_head(combined_features)  # (B, 2)
            heights_pred_list.append(height_gripper_t[:, 0])
            grippers_pred_list.append(height_gripper_t[:, 1])

        heights_pred = torch.stack(heights_pred_list, dim=1)  # (B, window_size)
        grippers_pred = torch.stack(grippers_pred_list, dim=1)  # (B, window_size)

        # Affine map [-1, 1] -> [0, 1], used instead of a sigmoid. This is NOT a clamp:
        # head outputs outside [-1, 1] end up outside [0, 1].
        return heights_pred * 0.5 + 0.5, grippers_pred * 0.5 + 0.5

    def forward(self, rgb_image):
        """
        Args:
            rgb_image: (B, 3, H, W) - RGB image tensor in [0, 1] range

        Returns:
            heatmap_logits: (B, window_size, target_size, target_size) - the first window_size
                patch-feature channels, bilinearly upsampled
            dino_output: dict with 'patch_features' (B, D, H_patches, W_patches), 'H_patches',
                'W_patches', 'heights_pred' (B, window_size) and 'grippers_pred' (B, window_size)

        Raises:
            ValueError: if the backbone feature dimension D is smaller than window_size.
        """
        x_norm = self._preprocess(rgb_image)

        # The CLS token is returned by the extractor but not used by these heads.
        patch_features, _ = self._extract_dino_features_batch(x_norm)  # (B, D, H_patches, W_patches)
        _, D, H_patches, W_patches = patch_features.shape

        if D < self.window_size:
            raise ValueError(f"DINO features have {D} dimensions, but need at least {self.window_size} for window_size={self.window_size}")

        # First window_size channels serve directly as per-timestep heatmap logits.
        heatmap_logits_patches = patch_features[:, :self.window_size, :, :]  # (B, window_size, H_patches, W_patches)

        if H_patches != self.target_size or W_patches != self.target_size:
            heatmap_logits = F.interpolate(
                heatmap_logits_patches,
                size=(self.target_size, self.target_size),
                mode='bilinear',
                align_corners=False
            )  # (B, window_size, target_size, target_size)
        else:
            heatmap_logits = heatmap_logits_patches

        heights_pred, grippers_pred = self._predict_height_gripper(patch_features, heatmap_logits_patches)

        # Prepare output dict for compatibility
        dino_output_dict = {
            'patch_features': patch_features,  # (B, D, H_patches, W_patches)
            'H_patches': H_patches,
            'W_patches': W_patches,
            'heights_pred': heights_pred,  # (B, window_size)
            'grippers_pred': grippers_pred,  # (B, window_size)
        }

        return heatmap_logits, dino_output_dict

    def get_lora_state_dict(self):
        """Return ``{name: parameter}`` for every LoRA matrix (``lora_A`` / ``lora_B``) in the model.

        Selection is by name only. LoRA matrices whose ``requires_grad`` was switched off are still
        included, so a checkpoint is never silently missing adapters.
        """
        return {
            name: param
            for name, param in self.named_parameters()
            if 'lora_A' in name or 'lora_B' in name
        }

    def load_lora_state_dict(self, lora_state_dict):
        """Copy LoRA tensors from ``lora_state_dict`` into the model's parameters in place.

        Names not present in the model print a warning and are skipped.
        """
        model_dict = self.state_dict()
        with torch.no_grad():
            for name, param in lora_state_dict.items():
                if name in model_dict:
                    model_dict[name].copy_(param)
                else:
                    print(f"Warning: LoRA parameter {name} not found in model")
