"""Quick inline script to visualize a colored pointmap from an episode as a pixel-aligned image."""
import sys
import os
sys.path.append("/Users/cameronsmith/Projects/robotics_testing/random/vggt")
sys.path.append("/Users/cameronsmith/Projects/robotics_testing/random/MoGe")
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))

import torch
import numpy as np
import matplotlib.pyplot as plt
import cv2

# Configuration - modify these paths
episode_dir = "scratch/processed_grasp_dataset_keyboard/2aG7ky"  # Change this
start_image_path = os.path.join(episode_dir, "start.png")
pointmap_raw_path = os.path.join(episode_dir, "pointmap_start_raw.pt")

# Load the start image; its (H, W) defines the pixel grid the pointmap is
# reshaped onto below.
start_image = cv2.imread(start_image_path)
if start_image is None:
    # cv2.imread reports a missing/unreadable file by returning None instead of
    # raising, which would otherwise surface as an opaque cvtColor error.
    raise FileNotFoundError(f"Could not read start image: {start_image_path}")
start_image = cv2.cvtColor(start_image, cv2.COLOR_BGR2RGB)
# With default flags imread already yields uint8; only rescale if a non-uint8
# image in [0, 1] shows up (e.g. if the read flags are changed). Checking the
# dtype avoids multiplying an all-dark uint8 image by 255 and overflowing.
if start_image.dtype != np.uint8 and start_image.max() <= 1.0:
    start_image = (start_image * 255).astype(np.uint8)
H, W = start_image.shape[:2]

# Load the raw pointmap (full HxW resolution): a dict whose "points" and
# "colors" tensors are expected to hold one 3-vector per image pixel.
# map_location="cpu" lets a file saved from a GPU process load on a CPU-only
# machine instead of failing on deserialization.
pointmap_raw = torch.load(pointmap_raw_path, map_location="cpu")
# detach() so .numpy() also works if a tensor was saved with requires_grad=True.
points_flat = pointmap_raw["points"].detach().cpu().numpy()  # expected (H*W, 3)
colors_flat = pointmap_raw["colors"].detach().cpu().numpy()  # expected (H*W, 3)

# Fail with a clear message if the pointmap does not match the image grid,
# rather than letting reshape raise an opaque size error.
expected_size = H * W * 3
if points_flat.size != expected_size or colors_flat.size != expected_size:
    raise ValueError(
        f"Pointmap does not match {H}x{W} start image: expected {H * W} "
        f"3-vectors, got points {points_flat.shape} and colors {colors_flat.shape}"
    )

# Reshape to (H, W, 3) so each entry lines up with a pixel of start_image.
points_2d = points_flat.reshape(H, W, 3)
colors_2d = colors_flat.reshape(H, W, 3)

# Convert colors to uint8 [0, 255] for imshow. Values whose max is <= 1.0 are
# treated as normalized and scaled up. Round (rather than truncate) and clip so
# out-of-range values saturate instead of wrapping around in the uint8 cast.
if colors_2d.dtype != np.uint8:
    if colors_2d.max() <= 1.0:
        colors_2d = colors_2d * 255.0
    colors_2d = np.clip(np.rint(colors_2d), 0, 255).astype(np.uint8)

# Show the RGB start image next to the pointmap colors laid out on the same
# (H, W) grid; if the colors are pixel-aligned, the two panels should match.
fig, axes = plt.subplots(1, 2, figsize=(16, 8))
panels = [
    (start_image, 'Start Image'),
    (colors_2d, 'Colored Pointmap (RGB)'),
]
for ax, (panel_image, panel_title) in zip(axes, panels):
    ax.imshow(panel_image)
    ax.set_title(panel_title)
    ax.axis('off')

plt.tight_layout()
plt.show()

# Report the grid size and the episode the data came from.
print(f"Visualized pointmap: {H}x{W} = {H*W} pixels from {episode_dir}")

