"""Collect DINOv3 patch features from each sequence's start image and fit a whitened PCA to N_COMPONENTS (64) dimensions."""
import sys
import os
sys.path.append("/Users/cameronsmith/Projects/robotics_testing/random/dinov3")

import numpy as np
import torch
import torchvision.transforms.functional as TF
from PIL import Image
from sklearn.decomposition import PCA
from tqdm import tqdm
import pickle

# Local DINOv3 checkout and the dinov3_vits16plus checkpoint stored inside it.
REPO_DIR: str = "/Users/cameronsmith/Projects/robotics_testing/random/dinov3"
WEIGHTS_PATH: str = f"{REPO_DIR}/weights/dinov3_vits16plus_pretrain_lvd1689m-4057cbaa.pth"

# Input geometry: images are resized so both sides are whole multiples of PATCH_SIZE.
PATCH_SIZE: int = 16
IMAGE_SIZE: int = 768  # target height in pixels before patch alignment

# Per-channel normalisation statistics (ImageNet), applied to the [0, 1] tensor.
IMAGENET_MEAN: tuple = (0.485, 0.456, 0.406)
IMAGENET_STD: tuple = (0.229, 0.224, 0.225)

# PCA output width, and the number of intermediate layers the backbone is queried for.
N_COMPONENTS: int = 64
N_LAYERS: int = 12

# Image resize transform to dimensions divisible by patch size
def resize_transform(img: Image.Image, image_size: int = IMAGE_SIZE, patch_size: int = PATCH_SIZE) -> torch.Tensor:
    w, h = img.size
    h_patches = int(image_size / patch_size)
    w_patches = int((w * image_size) / (h * patch_size))
    return TF.to_tensor(TF.resize(img, (h_patches * patch_size, w_patches * patch_size)))

# Configuration
import argparse
parser = argparse.ArgumentParser(description="Get DINOv3 PCA embedder")
parser.add_argument("--processed_dir", "-i", type=str, required=True,
                    help="Input directory with parsed episodes")
parser.add_argument("--output_path", "-o", type=str, default="scratch/dino_pca_embedder.pkl",
                    help="Where to write the pickled PCA embedder (default: %(default)s)")
args = parser.parse_args()
processed_dir = args.processed_dir
output_path = args.output_path
# A bare filename has an empty dirname, which os.makedirs cannot create; anchor
# it to the current directory so the save step always has a valid parent dir.
if not os.path.dirname(output_path):
    output_path = os.path.join(".", output_path)
# Fail fast with a CLI error instead of a FileNotFoundError from os.listdir later.
if not os.path.isdir(processed_dir):
    parser.error(f"--processed_dir is not a directory: {processed_dir}")

# Every sub-directory of processed_dir is treated as one sequence.
sequence_ids = sorted(
    d for d in os.listdir(processed_dir)
    if os.path.isdir(os.path.join(processed_dir, d))
)


# Load DINOv3 model once
print(f"\nLoading DINOv3 model from {WEIGHTS_PATH}")
# Prefer the MPS backend when available; fall back to CPU elsewhere.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
dinov3_model = torch.hub.load(REPO_DIR, 'dinov3_vits16plus', source='local', weights=WEIGHTS_PATH)
# eval() switches off any training-only layers so extracted features are
# deterministic; gradient tracking is disabled at extraction time instead.
dinov3_model = dinov3_model.to(device).eval()
print(f"Loaded DINOv3 model on {device}")



# Collect all features: one (num_patches, D) array per start image
all_features = []

print("\nExtracting features from start images...")
for seq_id in tqdm(sequence_ids, desc="Processing sequences"):
    start_img_path = os.path.join(processed_dir, seq_id, "000000.png")

    if not os.path.exists(start_img_path):
        print(f"  ⚠ Skipping {seq_id}: 000000.png not found")
        continue

    # Load and preprocess image. The context manager closes the file handle that
    # Image.open keeps open; convert() returns an independent in-memory copy.
    with Image.open(start_img_path) as raw_img:
        img = raw_img.convert("RGB")
    image_resized = resize_transform(img)
    image_resized_norm = TF.normalize(image_resized, mean=IMAGENET_MEAN, std=IMAGENET_STD)

    # Extract features. No autocast here: float32 is not a supported autocast
    # target on CPU/MPS, so torch would only warn and disable it.
    with torch.inference_mode():
        feats = dinov3_model.get_intermediate_layers(
            image_resized_norm.unsqueeze(0).to(device),
            # Only the last of the first N_LAYERS blocks is used below, so request
            # just that one (same tensor as feats[-1] of range(N_LAYERS)).
            n=[N_LAYERS - 1],
            reshape=True,
            norm=True,
        )
        # feats[-1] is (1, D, H_patches, W_patches). Index the batch dim explicitly:
        # a bare squeeze() would also drop a patch dimension of size 1.
        x = feats[-1][0].cpu()  # (D, H_patches, W_patches)
        dim = x.shape[0]
        x = x.reshape(dim, -1).permute(1, 0).numpy()  # (H_patches * W_patches, D)

    all_features.append(x)

# Concatenate all features
print(f"\nCollected features from {len(all_features)} images")
if not all_features:
    # np.concatenate would otherwise fail with an opaque
    # "need at least one array to concatenate" error.
    raise SystemExit(f"No start images (000000.png) found under {processed_dir}; nothing to fit.")
all_features_array = np.concatenate(all_features, axis=0)  # (N_total_patches, D)
print(f"Total feature vectors: {all_features_array.shape[0]}")
print(f"Feature dimension: {all_features_array.shape[1]}")

# Fit a whitened PCA to N_COMPONENTS dimensions over all patch features
print(f"\nFitting PCA to {N_COMPONENTS} dimensions...")
pca = PCA(n_components=N_COMPONENTS, whiten=True)
pca.fit(all_features_array)

print(f"Explained variance ratio: {pca.explained_variance_ratio_.sum():.4f}")
print("Cumulative explained variance:")
# Report the running total in groups of 8 components.
for i in range(0, N_COMPONENTS, 8):
    end_idx = min(i + 8, N_COMPONENTS)
    cumsum = pca.explained_variance_ratio_[:end_idx].sum()
    print(f"  Components 0-{end_idx-1}: {cumsum:.4f}")

# Save PCA embedder together with the preprocessing settings used to build it
print(f"\nSaving PCA embedder to {output_path}")
# os.path.dirname() is "" for a bare filename, and os.makedirs("") raises,
# so only create a directory when the path actually contains one.
output_dir = os.path.dirname(output_path)
if output_dir:
    os.makedirs(output_dir, exist_ok=True)
with open(output_path, 'wb') as f:
    pickle.dump({
        'pca': pca,
        'patch_size': PATCH_SIZE,
        'image_size': IMAGE_SIZE,
        'n_components': N_COMPONENTS,
        'n_layers': N_LAYERS,
    }, f)

# Completion banner: a 60-char rule, "Done!", and another rule, one per line.
banner = "=" * 60
print(f"{banner}\nDone!\n{banner}")

