"""Visualize processed keyboard grasp episode with MoGE pointcloud in robot frame using viser."""
import argparse
import sys
import os
sys.path.append("/Users/cameronsmith/Projects/robotics_testing/random/vggt")
sys.path.append("/Users/cameronsmith/Projects/robotics_testing/random/MoGe")
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))

import torch
import numpy as np
import viser
from viser.extras import ViserUrdf
import yourdfpy
import mujoco
from scipy.spatial.transform import Rotation as R
import trimesh
import torch.nn.functional as F
from ExoConfigs.so100_adhesive import SO100AdhesiveConfig
from exo_utils import combine_xmls, get_link_poses_from_robot, position_exoskeleton_meshes
from ExoConfigs.alignment_board import ALIGNMENT_BOARD_CONFIG

# Command-line interface. Both flags default to the previously hard-coded
# values, so existing invocations behave exactly as before.
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("--sequence_id", type=str, default="0vld05", help="Sequence ID to visualize")
parser.add_argument(
    "--processed_dir",
    type=str,
    default="scratch/processed_grasp_dataset_keyboard",
    help="Directory containing one processed sub-directory per sequence ID "
    "(relative paths are resolved against the current working directory)",
)
args = parser.parse_args()

# Configuration
processed_dir = args.processed_dir
sequence_id = args.sequence_id
sequence_dir = os.path.join(processed_dir, sequence_id)

# Fail early with a clear message instead of on the first np.load/torch.load.
if not os.path.isdir(sequence_dir):
    raise FileNotFoundError(f"Sequence directory not found: {sequence_dir}")

print("=" * 60)
print(f"Loading processed sequence: {sequence_id}")
print("=" * 60)

# Grasp-time joint configuration written by the processing step; it is used
# below both for the MuJoCo state and for the URDF configuration.
joint_states = np.load(os.path.join(sequence_dir, "joint_states_grasp.npy"))
print(f"Loaded grasp joint states: {joint_states.shape}")

def _load_optional_pointmap(path, label):
    """Load an optional saved pointmap.

    Returns ``(points, colors)`` as numpy arrays taken from the ``"points"`` and
    ``"colors"`` entries of the saved dict. If ``path`` does not exist, prints a
    warning mentioning ``label`` and returns ``(None, None)``.
    """
    if not os.path.exists(path):
        print(f"  Warning: {label} not found")
        return None, None
    saved = torch.load(path)
    return saved["points"].numpy(), saved["colors"].numpy()


# Pointclouds saved by the processing step (the full one is mandatory, the
# cropped and FPS-subsampled crops are optional).
pointmap_full_path = os.path.join(sequence_dir, "pointmap_start.pt")
pointmap_cropped_path = os.path.join(sequence_dir, "pointmap_start_cropped.pt")
pointmap_cropped_fps_path = os.path.join(sequence_dir, "pointmap_start_cropped_fps.pt")

# Full cloud: points are stored already expressed in the robot frame.
pointmap_full = torch.load(pointmap_full_path)
points_full, colors_full = pointmap_full["points"].numpy(), pointmap_full["colors"].numpy()

points_cropped, colors_cropped = _load_optional_pointmap(pointmap_cropped_path, "Cropped pointcloud")
points_cropped_fps, colors_cropped_fps = _load_optional_pointmap(
    pointmap_cropped_fps_path, "Cropped FPS pointcloud"
)

# Load DINO features for the FPS-subsampled crop, if available. The original
# notes describe them as (N, 32) PCA-reduced features — TODO confirm against
# the processing script. Only the first 3 components are used, as colors.
dino_features_path = os.path.join(sequence_dir, "dino_features_fps.pt")

if os.path.exists(dino_features_path):
    dino_features_tensor = torch.load(dino_features_path)
    dino_features_fps = dino_features_tensor.numpy()
    # Map the first 3 components to RGB via sigmoid(2x) -> (0, 1).
    # torch.sigmoid replaces the deprecated torch.nn.functional.sigmoid.
    dino_rgb = torch.sigmoid(dino_features_tensor[:, :3] * 2.0).numpy()
    dino_colors_fps = (dino_rgb * 255).astype(np.uint8)
    print(f"Loaded DINO features: {len(dino_features_fps)} points, {dino_features_fps.shape[1]} dimensions")
else:
    dino_features_fps = None
    dino_colors_fps = None
    print("  Warning: DINO features not found")

# Report point counts; optional clouds are listed only when they were loaded.
print(f"Full pointcloud: {len(points_full)} points")
for _summary_label, _summary_points in (
    ("Cropped pointcloud", points_cropped),
    ("Cropped FPS pointcloud", points_cropped_fps),
):
    if _summary_points is not None:
        print(f"{_summary_label}: {len(_summary_points)} points")

# Get gripper pose from MuJoCo
print("\n" + "=" * 60)
print("Computing gripper pose from MuJoCo")
print("=" * 60)
# NOTE: these assign *class* attributes, so they affect every
# SO100AdhesiveConfig instance in this process. Presumably they control mesh
# transparency — confirm in ExoConfigs.
SO100AdhesiveConfig.exo_alpha = 0.2
SO100AdhesiveConfig.aruco_alpha = 0.2
robot_config = SO100AdhesiveConfig()
combined_xml = combine_xmls(robot_config.xml, ALIGNMENT_BOARD_CONFIG.get_xml_addition())
mj_model = mujoco.MjModel.from_xml_string(combined_xml)
mj_data = mujoco.MjData(mj_model)
# Assumes joint_states has exactly mj_model.nq entries (numpy raises on a
# shape mismatch) — TODO confirm for the combined model.
mj_data.qpos[:] = joint_states
mj_data.ctrl[:] = joint_states[:len(mj_data.ctrl)]
mujoco.mj_forward(mj_model, mj_data)

# Position exoskeleton meshes to ensure gripper is at correct pose, then run
# forward kinematics again so the updated poses are reflected in mj_data.
link_poses = get_link_poses_from_robot(robot_config, mj_model, mj_data)
position_exoskeleton_meshes(robot_config, mj_model, mj_data, link_poses)
mujoco.mj_forward(mj_model, mj_data)

exo_mesh_body_name = "fixed_gripper_exo_mesh"
exo_mesh_body_id = mujoco.mj_name2id(mj_model, mujoco.mjtObj.mjOBJ_BODY, exo_mesh_body_name)
# mj_name2id returns -1 for unknown names, and body_mocapid is -1 for bodies
# that are not mocap bodies. Indexing with -1 would silently read the *last*
# entry and yield a wrong pose, so reject both cases explicitly.
if exo_mesh_body_id < 0:
    raise ValueError(f"Body '{exo_mesh_body_name}' not found in MuJoCo model")
exo_mesh_mocap_id = mj_model.body_mocapid[exo_mesh_body_id]
if exo_mesh_mocap_id < 0:
    raise ValueError(f"Body '{exo_mesh_body_name}' is not a mocap body")
exo_pos = mj_data.mocap_pos[exo_mesh_mocap_id].copy()
exo_quat_wxyz = mj_data.mocap_quat[exo_mesh_mocap_id].copy()

# 4x4 homogeneous gripper pose. MuJoCo quaternions are (w, x, y, z), while
# scipy's from_quat expects (x, y, z, w), hence the reindexing.
gripper_pose = np.eye(4)
gripper_pose[:3, :3] = R.from_quat(exo_quat_wxyz[[1, 2, 3, 0]]).as_matrix()
gripper_pose[:3, 3] = exo_pos
print(f"  Gripper position: [{exo_pos[0]:.3f}, {exo_pos[1]:.3f}, {exo_pos[2]:.3f}]")

# Fixed-jaw mesh, rendered later at the gripper pose.
fixed_gripper_stl_path = "robot_models/so100_model/assets/Fixed_Jaw.stl"
_loaded_gripper = trimesh.load(fixed_gripper_stl_path)
# trimesh.load may return a Scene; if so, use its first geometry.
if isinstance(_loaded_gripper, trimesh.Scene):
    _loaded_gripper = list(_loaded_gripper.geometry.values())[0]
fixed_gripper_mesh = _loaded_gripper
# Unit heuristic: if any bounding-box side exceeds 1.0, scale by 1/1000
# (presumably converting millimetres to metres — confirm for the asset).
_gripper_extent = fixed_gripper_mesh.bounds[1] - fixed_gripper_mesh.bounds[0]
if np.max(_gripper_extent) > 1.0:
    fixed_gripper_mesh.apply_scale(0.001)

# Start the viser server and show the robot at the grasp configuration.
print("\n" + "=" * 60)
print("Launching viser with robot and pointcloud")
print("=" * 60)
server = viser.ViserServer()

# Robot model (absolute, machine-specific path); only visual meshes are loaded.
urdf_path = "/Users/cameronsmith/Projects/robotics_testing/calibration_testing/so_100_arm/urdf/so_100_arm.urdf"
viser_urdf = ViserUrdf(
    server,
    urdf_or_path=yourdfpy.URDF.load(urdf_path),
    load_meshes=True,
    load_collision_meshes=False,
    collision_mesh_color_override=(1.0, 0.0, 0.0, 0.5),
)

# Per-joint offset subtracted from the MuJoCo joint values before they are
# sent to the URDF (matches the convention used elsewhere in the project).
mujoco_so100_offset = np.array([0, -1.57, 1.57, 1.57, -1.57, 0])
urdf_cfg = joint_states - mujoco_so100_offset
viser_urdf.update_cfg(urdf_cfg)
print("  Robot URDF loaded with grasp joint state")

# Point-cloud layers (robot frame): (scene name, points, colors, point size,
# label). A layer is skipped when its points or colors were not loaded. The
# smaller clouds use larger points so they stay visible over the full cloud.
_pointcloud_layers = [
    ("/moge_aligned_full", points_full, colors_full, 0.001, "full pointcloud"),
    ("/moge_aligned_cropped", points_cropped, colors_cropped, 0.002, "cropped pointcloud"),
    ("/moge_aligned_cropped_fps", points_cropped_fps, colors_cropped_fps, 0.003, "cropped FPS pointcloud"),
    ("/dino_features_fps", points_cropped_fps, dino_colors_fps, 0.003, "DINO feature pointcloud"),
]
for _layer_name, _layer_points, _layer_colors, _layer_size, _layer_label in _pointcloud_layers:
    if _layer_points is None or _layer_colors is None:
        continue
    server.scene.add_point_cloud(
        name=_layer_name,
        points=_layer_points.astype(np.float32),
        colors=_layer_colors.astype(np.uint8),
        point_size=_layer_size,
    )
    print(f"  Added {_layer_label}: {len(_layer_points)} points")

# Add the fixed-jaw mesh at the gripper pose computed from MuJoCo.
pos = gripper_pose[:3, 3]
rot = gripper_pose[:3, :3]
quat = R.from_matrix(rot).as_quat()  # scipy order: (x, y, z, w)

try:
    # Preferred path: send the mesh in its local frame together with a pose.
    server.scene.add_mesh_trimesh(
        name="/fixed_gripper",
        mesh=fixed_gripper_mesh,
        wxyz=quat[[3, 0, 1, 2]],  # viser expects (w, x, y, z)
        position=pos,
    )
    print("  Added gripper mesh")
except Exception as exc:  # not a bare except: let KeyboardInterrupt/SystemExit propagate
    print(f"  add_mesh_trimesh failed ({exc!r}); using add_mesh_simple fallback")
    # Fallback: bake the pose into the vertices (only needed on this path)
    # and send raw geometry without a separate transform.
    vertices_homogeneous = np.hstack(
        [fixed_gripper_mesh.vertices, np.ones((fixed_gripper_mesh.vertices.shape[0], 1))]
    )
    transformed_vertices = (gripper_pose @ vertices_homogeneous.T).T[:, :3]
    server.scene.add_mesh_simple(
        name="/fixed_gripper",
        vertices=transformed_vertices.astype(np.float32),
        faces=fixed_gripper_mesh.faces.astype(np.int32),
        color=(100, 150, 200),  # viser takes an RGB color; opacity is separate
    )
    print("  Added gripper mesh (fallback)")

import time  # hoisted: previously re-executed on every loop iteration

print("\n" + "=" * 60)
# Report the port the server actually bound rather than a hard-coded 8080
# (viser may pick a different port if the default is unavailable).
print(f"Viser server running at http://localhost:{server.get_port()}")
print(f"Sequence: {sequence_id}")
print("Press Ctrl+C to exit")
print("=" * 60)

# Keep the process alive so the server keeps serving; Ctrl+C exits cleanly.
try:
    while True:
        time.sleep(0.1)
except KeyboardInterrupt:
    pass

print("\nDone!")

