"""Add DINO features and visualizations to existing dataset episodes."""
import sys
import os
sys.path.append("/Users/cameronsmith/Projects/robotics_testing/random/vggt")
sys.path.append("/Users/cameronsmith/Projects/robotics_testing/random/dinov3")
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))

import cv2
import numpy as np
import matplotlib.pyplot as plt
import torch
from torchvision.transforms import functional as TF
from PIL import Image
from tqdm import tqdm
import argparse
import pickle
from pathlib import Path

# Configuration constants
PATCH_SIZE = 16
IMAGE_SIZE = 768
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)
N_LAYERS = 12


def process_timestep(episode_dir, timestep_str, model_dino, pca_embedder, device):
    """Process a single timestep: extract and save DINO features.

    Args:
        episode_dir: Path to the episode directory containing `<timestep>.png`.
        timestep_str: Timestep identifier (the image filename stem).
        model_dino: Loaded DINOv3 model passed through to run_dino_features.
        pca_embedder: Fitted PCA embedder passed through to run_dino_features.
        device: torch.device used for inference.

    Returns:
        True on success (or when all outputs already exist),
        False when the image is missing or unreadable.
    """
    image_path = episode_dir / f"{timestep_str}.png"

    if not image_path.exists():
        print(f"  ⚠ No image file found for timestep {timestep_str}")
        return False

    # Output paths — checked BEFORE loading the image so that skipping an
    # already-processed timestep is cheap (the original loaded the image and
    # built full-pixel meshgrids even when it was about to skip).
    dino_features_path = episode_dir / f"dino_features_{timestep_str}.pt"
    dino_features_hw_path = episode_dir / f"dino_features_hw_{timestep_str}.pt"
    dino_vis_path = episode_dir / f"dino_features_vis_{timestep_str}.png"

    if dino_features_path.exists() and dino_features_hw_path.exists() and dino_vis_path.exists():
        print(f"    ⏭ Skipping DINO feature extraction (dino_features_{timestep_str}.pt, dino_features_hw_{timestep_str}.pt and dino_features_vis_{timestep_str}.png exist)")
        return True

    # Load image. cv2.imread returns None (not an exception) on corrupt or
    # unreadable files; guard so we fail this timestep instead of crashing.
    bgr = cv2.imread(str(image_path))
    if bgr is None:
        print(f"  ⚠ No image file found for timestep {timestep_str}")
        return False
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    # Defensive: cv2.imread normally yields uint8, but normalize float-like
    # [0, 1] data to uint8 just in case.
    if rgb.max() <= 1.0:
        rgb = (rgb * 255).astype(np.uint8)

    H, W = rgb.shape[:2]

    # Pixel coordinates for every point. NOTE: run_dino_features currently
    # ignores these (see its trailing comment); kept for interface stability.
    y_coords, x_coords = np.meshgrid(np.arange(H), np.arange(W), indexing='ij')
    valid_y_coords = y_coords.reshape(-1)  # Use all pixels
    valid_x_coords = x_coords.reshape(-1)  # Use all pixels

    dino_features, dino_features_hw, pca_features_patches = run_dino_features(
        rgb, model_dino, pca_embedder, device, H, W, valid_y_coords, valid_x_coords
    )

    # Save patch-level DINO features.
    torch.save(
        torch.from_numpy(dino_features.astype(np.float32)),
        dino_features_path
    )

    # Save full (H, W, C) resolution features.
    torch.save(
        torch.from_numpy(dino_features_hw.astype(np.float32)),
        dino_features_hw_path
    )

    # Visualize and save
    visualize_dino_features(rgb, pca_features_patches, H, W, episode_dir, timestep_str)
    print(f"    ✓ Saved DINO features: {len(dino_features)} points, {dino_features_hw.shape} full resolution")

    return True


def resize_transform(img: Image.Image, image_size: int = IMAGE_SIZE, patch_size: int = PATCH_SIZE) -> torch.Tensor:
    """Resize *img* so both sides are exact multiples of *patch_size*.

    Height is scaled to (roughly) *image_size* and width scaled to keep the
    aspect ratio; both are truncated to whole patches. Returns a float
    tensor in [0, 1] with shape (C, H, W).
    """
    width, height = img.size
    rows = int(image_size / patch_size)
    cols = int((width * image_size) / (height * patch_size))
    target = (rows * patch_size, cols * patch_size)
    resized = TF.resize(img, target)
    return TF.to_tensor(resized)


def run_dino_features(rgb_scene, model_dino, pca_embedder, device, H, W, valid_y_coords, valid_x_coords):
    """Extract PCA-reduced DINO features at patch and full resolution.

    Args:
        rgb_scene: RGB image as a numpy array (converted to PIL below).
        model_dino: Model exposing `get_intermediate_layers` (DINOv3-style API).
        pca_embedder: Fitted PCA with a `transform` method; comments below say
            it reduces to 32 dims — confirm against the fitted embedder.
        device: torch.device used for inference (mps or cpu per the autocast).
        H, W: Full-resolution height/width to upsample the features to.
        valid_y_coords, valid_x_coords: Pixel coordinates — currently UNUSED
            (see the note at the bottom of this function).

    Returns:
        Tuple of three numpy arrays:
        - channel-first patch features, shape (C, H_patches, W_patches)
          (note: `len()` of this is C, not a point count),
        - upsampled features, shape (H, W, C),
        - channel-last patch features, shape (H_patches, W_patches, C).
    """
    # Load and preprocess image for DINO
    img_pil = Image.fromarray(rgb_scene).convert("RGB")
    image_resized = resize_transform(img_pil)
    image_resized_norm = TF.normalize(image_resized, mean=IMAGENET_MEAN, std=IMAGENET_STD)
    
    # Extract DINO features
    with torch.inference_mode():
        # NOTE(review): autocast with dtype=torch.float32 is effectively a
        # no-op (autocast targets reduced precision) and may warn or error on
        # some torch versions — confirm this is intentional.
        with torch.autocast(device_type='mps' if device.type == 'mps' else 'cpu', dtype=torch.float32):
            # Takes the LAST of the requested N_LAYERS intermediate layers.
            feats = model_dino.get_intermediate_layers(
                image_resized_norm.unsqueeze(0).to(device),
                n=range(N_LAYERS),
                reshape=True,
                norm=True
            )
            x = feats[-1].squeeze().detach().cpu()  # (D, H_patches, W_patches)
            dim = x.shape[0]
            x = x.view(dim, -1).permute(1, 0).numpy()  # (H_patches * W_patches, D)
            
            # Apply PCA to reduce to 32 dimensions
            pca_features_all = pca_embedder.transform(x)  # (H_patches * W_patches, 32)
            
            # Get patch resolution from the resized image (C, H, W) tensor.
            h_patches, w_patches = [int(d / PATCH_SIZE) for d in image_resized.shape[1:]]
            pca_features_patches = pca_features_all.reshape(h_patches, w_patches, -1)  # (H_patches, W_patches, 32)
            
            # Upsample features to full image resolution
            pca_features_tensor = torch.from_numpy(pca_features_patches).permute(2, 0, 1).float()  # (32, H_patches, W_patches)
            pca_features_upsampled = TF.resize(
                pca_features_tensor,
                (H, W),
                interpolation=TF.InterpolationMode.BILINEAR
            ).permute(1, 2, 0).numpy()  # (H, W, 32)
            
            # Sample features for each point using its pixel coordinate
            # For now, we'll return the patch-level features and upsampled features
            # The valid_y_coords and valid_x_coords are not used in this simplified version
    
    return pca_features_tensor.numpy(), pca_features_upsampled, pca_features_patches


def visualize_dino_features(rgb_scene, pca_features_patches, H, W, episode_dir, timestep_str):
    """Visualize DINO features and save a 3-panel figure.

    Panels: original image, first-3-PCA-components-as-RGB map, and a 50/50
    overlay of the two. Saves to `dino_features_vis_<timestep>.png` inside
    *episode_dir*.

    Args:
        rgb_scene: HxWx3 uint8 RGB image.
        pca_features_patches: (H_patches, W_patches, C) numpy feature array;
            only the first 3 channels are used.
        H, W: Target full resolution for the upsampled feature map.
        episode_dir: Directory to write the visualization into.
        timestep_str: Timestep identifier used in the output filename.
    """
    # Map first 3 PCA components to RGB; squash into (0, 1) with a scaled
    # sigmoid for display. (torch.sigmoid replaces the deprecated
    # torch.nn.functional.sigmoid alias.)
    pca_rgb = pca_features_patches[:, :, :3]
    pca_rgb_normalized = torch.sigmoid(torch.from_numpy(pca_rgb).mul(2.0)).numpy()
    pca_rgb_upsampled = TF.resize(
        torch.from_numpy(pca_rgb_normalized).permute(2, 0, 1).float(),
        (H, W),
        interpolation=TF.InterpolationMode.BILINEAR
    ).permute(1, 2, 0).numpy()

    # 50/50 blend of the image and the feature map.
    img_normalized = rgb_scene.astype(float) / 255.0
    overlay = img_normalized * 0.5 + pca_rgb_upsampled * 0.5

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    axes[0].imshow(rgb_scene)
    axes[0].set_title('Original Image')
    axes[0].axis('off')

    axes[1].imshow(pca_rgb_upsampled)
    axes[1].set_title('DINO PCA (first 3 components)')
    axes[1].axis('off')

    axes[2].imshow(overlay)
    axes[2].set_title('Overlay: RGB + DINO')
    axes[2].axis('off')

    plt.tight_layout()
    plt.savefig(episode_dir / f"dino_features_vis_{timestep_str}.png", dpi=150, bbox_inches='tight')
    # Close this specific figure to avoid leaking figures in long batch runs.
    plt.close(fig)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Add DINO features to existing dataset episodes")
    parser.add_argument("--input_dir", "-i", type=str, required=True,
                        help="Input directory with parsed episodes")
    args = parser.parse_args()

    input_dir = Path(args.input_dir)

    # DINOv3 configuration (machine-specific absolute paths — adjust per host).
    REPO_DIR = "/Users/cameronsmith/Projects/robotics_testing/random/dinov3"
    WEIGHTS_PATH = "/Users/cameronsmith/Projects/robotics_testing/random/dinov3/weights/dinov3_vits16plus_pretrain_lvd1689m-4057cbaa.pth"

    print("=" * 60)
    print("Finding all episodes")
    print("=" * 60)

    # Find all episode directories
    episode_dirs = sorted([d for d in input_dir.iterdir() if d.is_dir() and d.name.startswith("episode_")])

    print(f"Found {len(episode_dirs)} episodes")
    for ep_dir in episode_dirs:
        print(f"  - {ep_dir.name}")

    # Load models once and reuse across all episodes/timesteps.
    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
    print(f"\nUsing device: {device}")

    # Load DINOv3 model and PCA embedder
    print("Loading DINOv3 model...")
    with torch.inference_mode():
        model_dino = torch.hub.load(REPO_DIR, 'dinov3_vits16plus', source='local', weights=WEIGHTS_PATH).to(device)
        model_dino.eval()

    pca_path = "scratch/dino_pca_embedder.pkl"
    if not os.path.exists(pca_path):
        raise FileNotFoundError(f"PCA embedder not found: {pca_path}. Please run get_dino_pca_emb.py first.")

    print(f"Loading PCA embedder from {pca_path}")
    with open(pca_path, 'rb') as f:
        # NOTE: pickle.load on a trusted, locally generated artifact only.
        pca_data = pickle.load(f)
        pca_embedder = pca_data['pca']
        print(f"PCA embedder: {pca_embedder.n_components_} dimensions")

    print("\n" + "=" * 60)
    print("Processing episodes")
    print("=" * 60)

    # Process each episode
    for episode_dir in tqdm(episode_dirs, desc="Processing episodes"):
        print(f"\n{'='*60}")
        print(f"Processing episode: {episode_dir.name}")
        print(f"{'='*60}")

        # Timestep images have digit-only stems (e.g. "0001.png").
        image_files = sorted([f for f in episode_dir.glob("*.png") if f.stem.isdigit()])

        if len(image_files) == 0:
            print(f"  ⚠ No images found in {episode_dir.name}, skipping")
            continue

        # BUGFIX: removed a corrupted leftover debugger line here
        # ("import pdb; pdb.<tab>et_trace()") that was a SyntaxError and
        # prevented the script from running at all.
        timesteps_processed = 0
        timesteps_failed = 0

        for img_path in tqdm(image_files, desc=f"  {episode_dir.name}", leave=False):
            timestep_str = img_path.stem

            success = process_timestep(
                episode_dir,
                timestep_str,
                model_dino,
                pca_embedder,
                device
            )

            if success:
                timesteps_processed += 1
            else:
                timesteps_failed += 1

        print(f"  ✓ Processed {timesteps_processed} timesteps, {timesteps_failed} failed")

    print(f"\n{'='*60}")
    print(f"Done! Processed {len(episode_dirs)} episodes")
    print(f"Input directory: {input_dir}")
    print(f"{'='*60}")
