"""Visualize 3D offset fields from pointmap to keypoints."""
import argparse
import sys
import os
sys.path.append("/Users/cameronsmith/Projects/robotics_testing/random/vggt")
sys.path.append("/Users/cameronsmith/Projects/robotics_testing/random/MoGe")
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))

import torch
import numpy as np
import viser
import cv2
import matplotlib.pyplot as plt
import trimesh
from scipy.spatial.transform import Rotation as R

# Gripper-frame keypoint coordinates, one row per keypoint (x, y, z).
# Authored in millimetres; the metre-scale copy below is what the rest of
# the script uses.
KEYPOINTS_LOCAL_MM = np.array(
    [
        (13.25, -91.42, 15.9),
        (10.77, -99.6, 0),
        (13.25, -91.42, -15.9),
        (17.96, -83.96, 0),
        (22.86, -70.46, 0),
    ]
)
KEYPOINTS_LOCAL_M = KEYPOINTS_LOCAL_MM / 1000.0

# Command-line interface: which processed episode to visualize, and where
# the processed episodes live on disk.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--sequence_id",
    type=str,
    required=True,
    help="Sequence ID to visualize",
)
parser.add_argument(
    "--processed_dir",
    type=str,
    default="scratch/processed_grasp_dataset_keyboard",
    help="Directory with processed episodes",
)
args = parser.parse_args()

# Resolve the episode directory from the CLI arguments.
processed_dir = args.processed_dir
sequence_id = args.sequence_id
sequence_dir = os.path.join(processed_dir, sequence_id)

banner = "=" * 60
print(banner)
print(f"Loading episode: {sequence_id}")
print(banner)

# Load the robot-aligned pointmap saved by the preprocessing pipeline.
# Expected keys: "points" -> (N, 3) positions in the robot frame,
# "colors" -> (N, 3) RGB values (dtype varies by pipeline version).
pointmap_path = os.path.join(sequence_dir, "pointmap_start.pt")
pointmap = torch.load(pointmap_path)

points = pointmap["points"].cpu().numpy()  # (N, 3) in robot frame
colors = pointmap["colors"].cpu().numpy()  # (N, 3) RGB colors

# Normalize colors to uint8 [0, 255]. Float colors in [0, 1] get rescaled;
# anything else is assumed to already be on the 0-255 scale and is only cast.
# (PEP 8: one statement per line — no compound `if ...: stmt` forms.)
if colors.dtype != np.uint8:
    if colors.max() <= 1.0:
        colors = (colors * 255).astype(np.uint8)
    else:
        colors = colors.astype(np.uint8)

# Load the 4x4 grasp pose of the gripper and map the local-frame keypoints
# into the robot frame: p_robot = R @ p_local + t.
gripper_pose_path = os.path.join(sequence_dir, "gripper_pose_grasp.npy")
gripper_pose = np.load(gripper_pose_path)  # homogeneous 4x4 transform

gripper_rot = gripper_pose[:3, :3]  # rotation part R
gripper_pos = gripper_pose[:3, 3]   # translation part t
keypoints_robot = (gripper_rot @ KEYPOINTS_LOCAL_M.T).T + gripper_pos[None, :]

print(f"Loaded {len(points)} points")
print("Keypoints in robot frame:")
for idx, (kx, ky, kz) in enumerate(keypoints_robot):
    print(f"  KP {idx}: [{kx:.4f}, {ky:.4f}, {kz:.4f}]")

# Offset field for a single keypoint (index 4): the vector from every scene
# point to that keypoint.
kp_idx = 4
keypoint = keypoints_robot[kp_idx]
offsets = keypoint[None, :] - points  # (N, 3): offset[i] = keypoint - points[i]
print(f"  KP {kp_idx} offset range: [{offsets.min():.4f}, {offsets.max():.4f}]")

# Inverse distance drives a grayscale shading (white = close, black = far).
# Clamp distances away from zero before dividing.
point_distances = np.maximum(np.linalg.norm(offsets, axis=1), 1e-6)  # (N,)
inverse_distances = 1.0 / point_distances  # (N,)

# Min-max normalize to [0, 1]; a degenerate (constant) field maps to all ones.
inv_dist_min, inv_dist_max = inverse_distances.min(), inverse_distances.max()
if inv_dist_max > inv_dist_min:
    inv_dist_normalized = (inverse_distances - inv_dist_min) / (
        inv_dist_max - inv_dist_min
    )
else:
    inv_dist_normalized = np.ones_like(inverse_distances)

# Broadcast the scalar intensity into an (N, 3) grayscale RGB array.
intensity_colors = np.repeat(
    (inv_dist_normalized[:, None] * 255).astype(np.uint8), 3, axis=1
)

# Spin up the viser server and publish the raw and intensity-shaded clouds.
print("\n" + "=" * 60)
print("Launching viser visualization")
print("=" * 60)
server = viser.ViserServer()

cloud_xyz = points.astype(np.float32)

# Original RGB point cloud.
server.scene.add_point_cloud(
    name="/pointmap_original",
    points=cloud_xyz,
    colors=colors.astype(np.uint8),
    point_size=0.002,
)

# Same geometry, shaded by normalized inverse distance to the keypoint.
server.scene.add_point_cloud(
    name="/pointmap_inverse_distance",
    points=cloud_xyz,
    colors=intensity_colors,
    point_size=0.002,
)

# Point cloud colored by offset magnitude for keypoint `kp_idx`:
# blue at the minimum distance, red at the maximum (linear blend).
offset_magnitude = np.linalg.norm(offsets, axis=1)  # (N,)

# Min-max normalize the magnitudes to [0, 1]; constant fields map to zeros.
mag_min, mag_max = offset_magnitude.min(), offset_magnitude.max()
if mag_max > mag_min:
    mag_normalized = (offset_magnitude - mag_min) / (mag_max - mag_min)
else:
    mag_normalized = np.zeros_like(offset_magnitude)

# Assemble (N, 3) uint8 RGB: red rises with magnitude, blue falls, green zero.
red_channel = (mag_normalized * 255).astype(np.uint8)
blue_channel = ((1.0 - mag_normalized) * 255).astype(np.uint8)
offset_colors = np.stack(
    [red_channel, np.zeros_like(red_channel), blue_channel], axis=1
)

server.scene.add_point_cloud(
    name=f"/offset_field_kp{kp_idx}",
    points=points.astype(np.float32),
    colors=offset_colors,
    point_size=0.002,
)

# Draw a subsampled set of offset vectors as line segments (one line per
# point would overwhelm the renderer).
# Clamp the sample count: np.random.choice with replace=False raises
# ValueError when asked for more samples than there are points.
num_samples = min(1000, len(points))
sample_indices = np.random.choice(len(points), num_samples, replace=False)

# Segment endpoints, built vectorized instead of a Python loop:
# start at the scene point, end at point + offset (i.e. exactly the keypoint).
starts = points[sample_indices].astype(np.float32)
ends = (points[sample_indices] + offsets[sample_indices]).astype(np.float32)
line_points = np.stack([starts, ends], axis=1)  # (num_samples, 2, 3)

# Per-segment color by offset magnitude, matching the offset-field cloud:
# blue (short) -> red (long). Both endpoints share the segment's color.
if mag_max > mag_min:
    seg_norm = (offset_magnitude[sample_indices] - mag_min) / (mag_max - mag_min)
else:
    seg_norm = np.zeros(num_samples)
seg_rgb = np.zeros((num_samples, 3), dtype=np.uint8)
seg_rgb[:, 0] = (seg_norm * 255).astype(np.uint8)        # red channel
seg_rgb[:, 2] = ((1.0 - seg_norm) * 255).astype(np.uint8)  # blue channel
line_colors = np.repeat(seg_rgb[:, None, :], 2, axis=1)  # (num_samples, 2, 3)

server.scene.add_line_segments(
    f"/offset_vectors_kp{kp_idx}",
    points=line_points,
    colors=line_colors,
    line_width=1.0,
)

# Marker for the selected keypoint: a small ball mesh loaded from disk.
ball_stl_path = "robot_models/so100_blender_testings/ball.stl"
ball_mesh = trimesh.load(ball_stl_path)
if isinstance(ball_mesh, trimesh.Scene):
    # trimesh may wrap the STL in a Scene; unwrap its first geometry.
    ball_mesh = next(iter(ball_mesh.geometry.values()))
# Heuristic unit fix: an extent over 1 m suggests the STL is authored in mm.
bounds_min, bounds_max = ball_mesh.bounds
if np.max(bounds_max - bounds_min) > 1.0:
    ball_mesh.apply_scale(0.001)

server.scene.add_mesh_trimesh(
    name=f"/keypoint_{kp_idx}",
    mesh=ball_mesh,
    wxyz=(1.0, 0.0, 0.0, 0.0),  # identity orientation
    position=keypoint.astype(np.float32),
)

print(f"\nViser server running at http://localhost:8080")
print(f"Sequence: {sequence_id}")
print(f"Visualizing keypoint {kp_idx} offset field")
print(f"Press Ctrl+C to exit")
print("=" * 60)

# Keep the process (and thus the viser server) alive until Ctrl+C.
# `import time` is hoisted here: the original re-executed the import
# statement on every loop iteration.
import time

try:
    while True:
        time.sleep(0.1)
except KeyboardInterrupt:
    pass

print("\nDone!")

