"""Visualize MoGE pointcloud in robot frame with matplotlib 3D."""
import sys
import os
sys.path.append("/Users/cameronsmith/Projects/robotics_testing/random/vggt")
sys.path.append("/Users/cameronsmith/Projects/robotics_testing/random/MoGe")
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))

import cv2
import torch
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.widgets import Slider

# Configuration
from glob import glob

PROCESSED_DIR = "scratch/processed_grasp_dataset_keyboard"
# Position (in the *sorted* sequence list) of the first sequence to show;
# earlier sequences are skipped.
START_INDEX = 19
# Keep every N-th point so the matplotlib scatter stays interactive.
UNIFORM_DS_FACTOR = 11
# Initial crop box, in the same units as the plot axes (labelled metres).
INITIAL_VOLUME_BOUNDS = {
    "x_min": -1.,
    "x_max": 0.1,
    "y_min": -.6,
    "y_max": -0.1,
    "z_min": 0.02,
    "z_max": 0.05,
}
# Range and resolution of every crop slider.
SLIDER_RANGE = (-2.0, 2.0)
SLIDER_STEP = 0.01

# Each sequence directory contains, among other files: pointmap_start.pt,
# pointmap_start_cropped(.pt/_fps.pt), start.png, grasp.png, robot_mask.png,
# human_mask.png, moge_output_raw.pt, dino_features_fps.pt,
# gripper_pose_grasp.npy and joint_states_grasp.npy.


def crop_to_volume(points, colors, x_min, x_max, y_min, y_max, z_min, z_max):
    """Return the points (and their colors) lying inside an axis-aligned box.

    Bounds are inclusive. A box with any min greater than its max selects
    nothing, and so does a NaN coordinate.

    Args:
        points: (P, D) array with D >= 3. Only the first three columns
            (x, y, z) are tested, and the full rows are returned.
        colors: array whose first axis is aligned with ``points``. It is
            divided by 255 so matplotlib can use it as RGB floats, which
            presumes 0-255 input (TODO confirm against the processing
            script).
        x_min, x_max, y_min, y_max, z_min, z_max: box bounds.

    Returns:
        Tuple ``(cropped_points, cropped_colors)``.
    """
    xyz = points[:, :3]
    lower = np.array([x_min, y_min, z_min])
    upper = np.array([x_max, y_max, z_max])
    inside = np.all((xyz >= lower) & (xyz <= upper), axis=1)
    return points[inside], colors[inside] / 255.0


def show_crop_viewer(points, colors, bounds, slider_range=(-2.0, 2.0), slider_step=0.01):
    """Show a 3D scatter of the cropped cloud with one slider per box bound.

    Every callback closes over this call's own figure, axes and sliders, so
    several viewers never redraw each other's plots.

    Args:
        points, colors: passed to :func:`crop_to_volume`.
        bounds: dict with keys x_min, x_max, y_min, y_max, z_min and z_max,
            used as the initial slider values.
        slider_range: (min, max) range shared by all sliders.
        slider_step: slider value step.

    Returns:
        ``(fig, sliders)``. Callers should keep ``sliders`` referenced while
        the figure is open, because matplotlib widgets stop responding once
        they are garbage-collected.
    """
    fig = plt.figure(figsize=(12, 8))
    ax = fig.add_subplot(111, projection='3d')

    def draw(cropped_points, cropped_colors):
        # Scatter plus axis styling; called again after every ax.clear().
        ax.scatter(cropped_points[:, 0], cropped_points[:, 1], cropped_points[:, 2],
                   c=cropped_colors, s=1)
        ax.set_xlabel('X (m)')
        ax.set_ylabel('Y (m)')
        ax.set_zlabel('Z (m)')
        ax.set_box_aspect([1, 1, 1])

    initial_points, initial_colors = crop_to_volume(points, colors, **bounds)
    draw(initial_points, initial_colors)

    # Reserve the bottom quarter of the figure for the sliders. They sit in
    # three rows (x, y, z from top to bottom), with the min slider on the
    # left and the max slider on the right.
    fig.subplots_adjust(bottom=0.25)
    sliders = {}
    for axis_name, bottom in zip("xyz", (0.15, 0.10, 0.05)):
        for suffix, left in (("min", 0.1), ("max", 0.55)):
            key = f"{axis_name}_{suffix}"
            slider_ax = fig.add_axes([left, bottom, 0.35, 0.03])
            sliders[key] = Slider(slider_ax, key, slider_range[0], slider_range[1],
                                  valinit=bounds[key], valstep=slider_step)

    def update(_val):
        """Re-crop with the current slider values and redraw the scatter."""
        current_bounds = {key: slider.val for key, slider in sliders.items()}
        ax.clear()
        draw(*crop_to_volume(points, colors, **current_bounds))
        fig.canvas.draw_idle()

    for slider in sliders.values():
        slider.on_changed(update)

    print(f"Initial filtered pointcloud: {len(initial_points)} points")
    print("\nAdjust sliders to crop the pointcloud")
    plt.show()
    return fig, sliders


# glob's order depends on the filesystem, so sort it; otherwise START_INDEX
# would select different sequences on different machines.
all_seqs = sorted(glob(os.path.join(PROCESSED_DIR, "*")))
if len(all_seqs) <= START_INDEX:
    print(f"Found {len(all_seqs)} sequences in {PROCESSED_DIR}; "
          f"none at index >= {START_INDEX}")

for seq_i, sequence_dir in enumerate(all_seqs[START_INDEX:], start=START_INDEX):
    sequence_id = os.path.basename(sequence_dir)
    print(seq_i, sequence_id)

    print("=" * 60)
    print("Loading Episode Data")
    print("=" * 60)

    pointmap_path = os.path.join(sequence_dir, "pointmap_start.pt")
    # map_location="cpu" lets tensors that were saved on a GPU load on a
    # machine without CUDA.
    pointmap_data = torch.load(pointmap_path, map_location="cpu")
    # The processing step already put the points in the robot frame.
    points = pointmap_data["points"].cpu().numpy()  # (P, 3)
    colors = pointmap_data["colors"].cpu().numpy()  # (P, 3)

    print(f"  Pointmap shape: {points.shape}")
    print("  Points already in robot frame (from processed dataset)")

    print("\n" + "=" * 60)
    print("Uniform downsampling")
    print("=" * 60)

    points_ds = points[::UNIFORM_DS_FACTOR]
    colors_ds = colors[::UNIFORM_DS_FACTOR]
    print(f"Uniform downsampling ({UNIFORM_DS_FACTOR}x): {len(points)} -> {len(points_ds)} points")

    print("\n" + "=" * 60)
    print("Creating matplotlib 3D visualization")
    print("=" * 60)

    # Keep the returned widgets bound until the next sequence replaces them.
    fig, sliders = show_crop_viewer(points_ds, colors_ds, INITIAL_VOLUME_BOUNDS,
                                    slider_range=SLIDER_RANGE, slider_step=SLIDER_STEP)
