"""Extract MoGE pointmaps for start frames in each episode."""
import sys
import os
sys.path.append("/Users/cameronsmith/Projects/robotics_testing/random/vggt")
sys.path.append("/Users/cameronsmith/Projects/robotics_testing/random/MoGe")
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))

import cv2
import torch
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

from moge.model.v2 import MoGeModel
import utils3d

# Configuration
dataset_dir = "scratch/dataset"
# Prefer the Apple-silicon GPU, but fall back to CPU so the script also
# runs on machines without Metal support (the original hard-coded "mps"
# and crashed anywhere else).
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

print("=" * 60)
print("MoGE Pointmap Extraction")
print("=" * 60)

# Load MoGE model once (weights are fetched from the HuggingFace hub on
# first use) and put it in eval mode to disable training-time behavior.
print("Loading MoGE model...")
model = MoGeModel.from_pretrained("Ruicheng/moge-2-vitl-normal").to(device)
model.eval()
print(f"Model loaded on {device}")

# Discover every episode_* subdirectory of the dataset root, sorted by name.
episode_dirs = sorted(
    entry for entry in os.listdir(dataset_dir)
    if entry.startswith('episode_')
    and os.path.isdir(os.path.join(dataset_dir, entry))
)

print(f"Found {len(episode_dirs)} episodes")

# Process each episode: run MoGE on the start frame, save the pointmap,
# and write an inverse-depth visualization image alongside it.
for episode_name in tqdm(episode_dirs, desc="Processing episodes"):
    episode_dir = os.path.join(dataset_dir, episode_name)
    
    # Only process start frame
    start_img_path = os.path.join(episode_dir, "start.png")
    
    if not os.path.exists(start_img_path):
        print(f"  Skipping {episode_name}: missing start.png")
        continue
    
    print(f"\n{episode_name}: Processing start frame")
    
    # Load image as RGB, normalize to [0, 1], and make it channel-first
    # (C, H, W) as the model expects.
    input_image = cv2.cvtColor(cv2.imread(start_img_path), cv2.COLOR_BGR2RGB)
    input_tensor = torch.tensor(input_image / 255, dtype=torch.float32, device=device).permute(2, 0, 1)
    
    # Run MoGE inference
    with torch.no_grad():
        output = model.infer(input_tensor)
    
    # Extract points and mask; additionally drop pixels that sit on depth
    # discontinuities (edges), where predicted points are unreliable.
    points = output["points"].cpu().numpy()  # (H, W, 3)
    mask = output["mask"].cpu().numpy() & ~utils3d.np.depth_map_edge(points[:, :, 2], rtol=0.005)
    
    # Save pointmap as .pt file.
    # BUG FIX: the edge-filtered `mask` above was computed but discarded —
    # the raw model mask was saved instead. Persist the filtered mask so
    # the edge suppression actually takes effect downstream.
    pointmap_path = os.path.join(episode_dir, "pointmap_start.pt")
    torch.save({
        'points': output["points"].cpu(),
        'mask': torch.from_numpy(mask)
    }, pointmap_path)
    print(f"  Saved pointmap: {pointmap_path}")
    
    # Create inverse-depth visualization.
    # BUG FIX: `1/points[:,:,2]` produced inf/NaN on invalid or zero-depth
    # pixels; divide only where the mask holds and depth is positive,
    # leaving invalid pixels at 0.
    depth = points[:, :, 2]
    inv_depth = np.zeros_like(depth)
    np.divide(1.0, depth, out=inv_depth, where=mask & (depth > 0))
    plt.imsave(os.path.join(episode_dir, "depth_vis_start.png"), inv_depth)

print("\n✓ Done!")
