"""Process a single helper image together with its recorded joint state (kinematics).

Tasks:
 a) Render a robot mask from the provided image/joint state, estimate a human mask
    with DeepLabV3, and visualize both
 b) Run MoGE pointmap estimation on the image
 c) Align the MoGE pointmap to the robot frame via Procrustes (using ArUco correspondences),
    crop it to a fixed volume, FPS-downsample it, and visualize it in a viser scene
    together with the robot at the provided joint state
"""
import sys
import os
sys.path.append("/Users/cameronsmith/Projects/robotics_testing/random/vggt")
sys.path.append("/Users/cameronsmith/Projects/robotics_testing/random/MoGe")
sys.path.append("Demos")
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))

import cv2
import numpy as np
import matplotlib.pyplot as plt
import torch
import torchvision
from torchvision import transforms
from PIL import Image

import mujoco
import viser
from viser.extras import ViserUrdf
import yourdfpy

from moge.model.v2 import MoGeModel
import utils3d
import fpsample

from ExoConfigs.so100_adhesive import SO100AdhesiveConfig
from ExoConfigs.alignment_board import ALIGNMENT_BOARD_CONFIG
from exo_utils import (
    detect_and_set_link_poses,
    estimate_robot_state,
    position_exoskeleton_meshes,
    render_from_camera_pose,
    detect_and_position_alignment_board,
    combine_xmls,
)
from demo_utils import procrustes_alignment


# Input paths: the helper image, a second image (currently the same file), and
# the joint state recorded alongside it.
img_path = "scratch/grasp_dataset_keyboard/1Ady4y_helper_grasp.png"
img2_path = "scratch/grasp_dataset_keyboard/1Ady4y_helper_grasp.png"
qpos_path = "scratch/grasp_dataset_keyboard/1Ady4y_helper_grasp.npy"


def _load_rgb(path):
    """Load an image file as an RGB ``uint8`` array.

    ``cv2.imread`` reports a missing/unreadable file by returning ``None``
    instead of raising, which would otherwise surface later as an opaque
    ``cvtColor`` assertion. With its default flags it returns 8-bit data, so
    no rescaling from a [0, 1] float range is needed.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``path``.
    """
    bgr = cv2.imread(path)
    if bgr is None:
        raise FileNotFoundError(f"Could not read image: {path}")
    # OpenCV loads BGR; the rest of this script treats the array as RGB.
    return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)


# Load images (rgb2 is not used further in this script)
rgb = _load_rgb(img_path)
rgb2 = _load_rgb(img2_path)

# Load joint state and override its last entry
# (presumably the gripper joint — confirm against the qpos layout).
qpos = np.load(qpos_path)
qpos[-1] = 1.2

# Setup robot config and MuJoCo model.
# Class-level transparency overrides, set before the config is instantiated
# (presumably so the generated XML picks them up — confirm in SO100AdhesiveConfig).
SO100AdhesiveConfig.exo_alpha = 0.2
SO100AdhesiveConfig.aruco_alpha = 0.2
robot_config = SO100AdhesiveConfig()

# Add supporting alignment board to the scene XML
combined_xml = combine_xmls(robot_config.xml, ALIGNMENT_BOARD_CONFIG.get_xml_addition())

model = mujoco.MjModel.from_xml_string(combined_xml)
data = mujoco.MjData(model)

# Set robot state from qpos; actuators receive the leading len(ctrl) entries.
data.qpos[:] = qpos
data.ctrl[:] = qpos[: len(data.ctrl)]
mujoco.mj_forward(model, data)

print("=" * 60)
print("Detecting ArUco, estimating camera pose, and positioning exo meshes")
print("=" * 60)

# Detect link poses and camera pose from the image; this also positions meshes.
# One call supplies every output: detecting twice wasted time and, if detection
# is not deterministic, could pair link poses and the camera pose from
# different detections.
(
    link_poses,
    camera_pose_world,
    cam_K,
    corners_cache,
    corners_vis,
    obj_img_pts,
) = detect_and_set_link_poses(rgb, model, data, robot_config)
np.save("scratch/tmp_camera_pose_world.npy", camera_pose_world)

# Detect and position supporting alignment board (adds useful correspondences)
board_result = detect_and_position_alignment_board(
    rgb, model, data, ALIGNMENT_BOARD_CONFIG, cam_K, camera_pose_world, corners_cache, visualize=False
)
if board_result is not None:
    board_pose, board_pts = board_result
    obj_img_pts["alignment_board"] = board_pts


position_exoskeleton_meshes(robot_config, model, data, link_poses)
# Disabled alternative: overwrite the recorded joint state with an IK estimate
# from the detected link poses.
#configuration = estimate_robot_state(model, data, robot_config, link_poses, ik_iterations=55)
#data.qpos[:] = configuration.q
#data.ctrl[:] = configuration.q[:len(data.ctrl)]
mujoco.mj_forward(model, data)

print("\n".join(["=" * 60, "a) Render robot mask", "=" * 60]))

# Render the MuJoCo scene from the estimated camera as a segmentation image,
# passing the input image's height and width (presumably the render resolution).
seg = render_from_camera_pose(
    model,
    data,
    camera_pose_world,
    cam_K,
    rgb.shape[0],
    rgb.shape[1],
    segmentation=True,
)
# Pixels whose first segmentation channel is positive are treated as "robot".
# NOTE(review): the scene also contains the alignment board / exo meshes; whether
# they receive positive ids depends on render_from_camera_pose — confirm.
robot_mask = seg[..., 0] > 0

print("=" * 60)
print("Human mask estimation")
print("=" * 60)

# Device shared by the segmentation and MoGE models: CUDA, then Apple MPS, then CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Load pre-trained segmentation model (DeepLabV3 with ResNet101).
# `weights=` is the torchvision >= 0.13 replacement for the deprecated `pretrained=True`.
print("Loading DeepLabV3 segmentation model...")
model_seg = torchvision.models.segmentation.deeplabv3_resnet101(
    weights=torchvision.models.segmentation.DeepLabV3_ResNet101_Weights.DEFAULT
)
model_seg.eval()
model_seg = model_seg.to(device)

# Preprocessing: [0, 1] tensor, then ImageNet mean/std normalization
preprocess = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Generate human mask for the scene image (rgb)
img_pil = Image.fromarray(rgb)
input_tensor = preprocess(img_pil)
input_batch = input_tensor.unsqueeze(0).to(device)

with torch.no_grad():
    output = model_seg(input_batch)['out'][0]

# Per-pixel class prediction (argmax over the class axis)
output_predictions = output.argmax(0).cpu().numpy()

# Index 15 is 'person' in the Pascal VOC label set these torchvision weights
# predict (not a COCO category id).
human_mask = (output_predictions == 15).astype(np.uint8)
print(f"Human mask pixels: {human_mask.sum()}")

# Visualize masks
print("=" * 60)
print("Visualizing robot and human masks")
print("=" * 60)

# Set False to skip the (blocking) matplotlib window.
SHOW_MASK_FIGURE = True


def _tint(image, mask, color):
    """Return a copy of ``image`` whose ``mask`` pixels are blended 50/50 with ``color``.

    image: (H, W, 3) uint8 array; mask: boolean (H, W) array; color: RGB triple.
    """
    out = image.copy()
    out[mask] = (0.5 * out[mask] + 0.5 * np.array(color)).astype(np.uint8)
    return out


if SHOW_MASK_FIGURE:
    fig, axes = plt.subplots(2, 3, figsize=(18, 12))
    # One row per mask: (label, mask to display, boolean mask, tint color, color name)
    mask_rows = [
        ("Robot", robot_mask, np.asarray(robot_mask, dtype=bool), (255, 0, 0), "Red"),
        ("Human", human_mask, human_mask == 1, (0, 255, 0), "Green"),
    ]
    for row, (label, shown_mask, bool_mask, color, color_name) in enumerate(mask_rows):
        panels = [
            (rgb, "Scene Image", None),
            (shown_mask, f"{label} Mask", "gray"),
            (_tint(rgb, bool_mask, color), f"{label} Overlay ({color_name})", None),
        ]
        for col, (image, title, cmap) in enumerate(panels):
            ax = axes[row, col]
            ax.imshow(image, cmap=cmap)
            ax.set_title(title, fontsize=12, fontweight='bold')
            ax.axis("off")

    plt.tight_layout()
    plt.show()
# (A stray `zz` statement here used to raise NameError and abort the script
# before the MoGE and alignment steps.)
print("=" * 60)
print("b) MoGE pointmap estimation")
print("=" * 60)

# Toggle: True runs MoGE and refreshes the on-disk cache; False reuses the cache.
RUN_MOGE = False
# Toggle: True removes robot and human pixels from the pointcloud.
EXCLUDE_ROBOT_AND_HUMAN = False
MOGE_CACHE_PATH = "scratch/tmp_moge_output.pt"

# Load MoGE model (device already set above)
if RUN_MOGE:
    model_moge = MoGeModel.from_pretrained("Ruicheng/moge-2-vitl-normal").to(device)
    model_moge.eval()
    # (3, H, W) float tensor with values in [0, 1]
    input_tensor = torch.tensor(rgb / 255.0, dtype=torch.float32, device=device).permute(2, 0, 1)
    with torch.no_grad():
        output = model_moge.infer(input_tensor)
    torch.save(output, MOGE_CACHE_PATH)
else:
    # map_location="cpu": the cache may have been written from an MPS/CUDA session.
    output = torch.load(MOGE_CACHE_PATH, map_location="cpu")

points = output["points"].cpu().numpy()  # (H, W, 3) in camera frame
mask = output["mask"].cpu().numpy().astype(bool)
# Filter out depth edges
mask = mask & ~utils3d.np.depth_map_edge(points[:, :, 2], rtol=0.005)

H, W = points.shape[:2]
# Colors below, and the ArUco pixel lookups later, index the pointmap with image
# pixel coordinates, so the two must share a resolution.
if (H, W) != rgb.shape[:2]:
    raise ValueError(
        f"MoGE pointmap is {H}x{W} but the image is {rgb.shape[0]}x{rgb.shape[1]}"
    )

# Colors from image
colors = rgb.reshape(-1, 3)
points_flat = points.reshape(-1, 3)


def _resize_mask(mask_2d, width, height):
    """Return ``mask_2d`` as a boolean (height, width) array.

    Uses nearest-neighbour resampling so a binary mask stays binary; LANCZOS
    interpolates and can ring around mask edges.
    """
    mask_2d = np.asarray(mask_2d).astype(bool)
    if mask_2d.shape == (height, width):
        return mask_2d
    resized = Image.fromarray(mask_2d.astype(np.uint8) * 255).resize(
        (width, height), Image.Resampling.NEAREST
    )
    return np.array(resized) > 127


# Masks brought to the MoGE pointmap resolution (H, W)
robot_mask_resized = _resize_mask(robot_mask, W, H)
human_mask_resized = _resize_mask(human_mask, W, H)
exclude_mask = robot_mask_resized | human_mask_resized
exclude_mask_flat = exclude_mask.reshape(-1)

mask_flat = mask.reshape(-1)
if EXCLUDE_ROBOT_AND_HUMAN:
    print("Applying robot and human masks to pointcloud...")
    mask_flat = mask_flat & ~exclude_mask_flat
else:
    print("Robot/human exclusion disabled; keeping those pixels in the pointcloud.")

valid_points_cam = points_flat[mask_flat]
valid_colors = colors[mask_flat].astype(np.uint8)

print(f"Valid MoGE points after depth filtering: {mask.sum()}")
print(f"Robot mask pixels: {robot_mask_resized.sum()}")
print(f"Human mask pixels: {human_mask_resized.sum()}")
print(f"Final valid MoGE points after mask filtering: {len(valid_points_cam)}")

print("=" * 60)
print("c) Align MoGE pointmap to robot frame using ArUco correspondences")
print("=" * 60)

# Build PAIRED correspondence sets (row i of each array is the same physical point):
# - aruco_corners_robot_frame: 3D ArUco object corners in robot base frame
# - moge_aruco_corners: MoGE pointmap samples at the matching image pixels
ALIGNMENT_OBJECTS = ("alignment_board", "larger_base")
# NOTE(review): inv(camera_pose_world) maps camera-frame points into the robot
# frame only if camera_pose_world is a robot(world)->camera transform — confirm
# the convention in exo_utils. Hoisted: the same matrix is used for every object.
inv_camera_pose = np.linalg.inv(camera_pose_world)

aruco_corners_robot_frame = []
moge_aruco_corners = []

for obj_name, (obj_img_pts_cam, img_pts_px) in obj_img_pts.items():
    if obj_name not in ALIGNMENT_OBJECTS:
        continue
    if len(obj_img_pts_cam) != len(img_pts_px):
        raise ValueError(
            f"{obj_name}: {len(obj_img_pts_cam)} 3D corners but {len(img_pts_px)} pixel corners"
        )
    # Compute ArUco corners in robot frame via camera pose
    obj_cam_h = np.hstack([obj_img_pts_cam, np.ones((obj_img_pts_cam.shape[0], 1))])
    obj_robot = (inv_camera_pose @ obj_cam_h.T).T[:, :3]

    for p_robot, pt in zip(obj_robot, img_pts_px):
        # Nearest pixel to the (sub-pixel) corner location
        x, y = int(round(float(pt[0]))), int(round(float(pt[1])))
        if not (0 <= y < H and 0 <= x < W):
            # Drop BOTH halves of the pair so the two sets stay index-aligned.
            continue
        p_moge = points[y, x]
        if not np.all(np.isfinite(p_moge)):
            # A non-finite sample would poison the Procrustes fit.
            continue
        aruco_corners_robot_frame.append(p_robot)
        moge_aruco_corners.append(p_moge)

aruco_corners_robot_frame = np.array(aruco_corners_robot_frame)
moge_aruco_corners = np.array(moge_aruco_corners)
print(f"ArUco correspondences for alignment: {len(moge_aruco_corners)}")

if len(moge_aruco_corners) < 3:
    print("Insufficient correspondences for Procrustes alignment; skipping alignment.")
    points_robot = valid_points_cam  # fallback: keep in camera frame
else:
    # Procrustes alignment: find transform that maps MoGE points -> robot frame
    T_procrustes, scale, rotation, translation = procrustes_alignment(
        aruco_corners_robot_frame, moge_aruco_corners
    )
    points_h = np.hstack([valid_points_cam, np.ones((len(valid_points_cam), 1))])
    points_robot = (T_procrustes @ points_h.T).T[:, :3]
    print("Applied Procrustes alignment to robot frame.")

# Axis-aligned crop volume in the robot frame (same as in view_episode.py).
volume_bounds = dict(
    x_min=-0.16,
    x_max=0.1,
    y_min=-0.6,
    y_max=-0.2,
    z_min=0.005,
    z_max=0.08,
)

print("=" * 60)
print("Cropping pointcloud with volume bounds")
print("=" * 60)

# Keep a point only when each coordinate lies inside its inclusive [min, max] range.
_crop_axes = ((0, "x_min", "x_max"), (1, "y_min", "y_max"), (2, "z_min", "z_max"))
mask_filtered = np.logical_and.reduce([
    (points_robot[:, col] >= volume_bounds[lo_key]) & (points_robot[:, col] <= volume_bounds[hi_key])
    for col, lo_key, hi_key in _crop_axes
])
points_robot_cropped = points_robot[mask_filtered]
valid_colors_cropped = valid_colors[mask_filtered]

print(f"Cropped pointcloud: {len(points_robot)} -> {len(points_robot_cropped)} points")
print(f"Volume bounds: x=[{volume_bounds['x_min']:.2f}, {volume_bounds['x_max']:.2f}], "
      f"y=[{volume_bounds['y_min']:.2f}, {volume_bounds['y_max']:.2f}], "
      f"z=[{volume_bounds['z_min']:.2f}, {volume_bounds['z_max']:.2f}]")

# Target size for farthest point sampling (FPS) of the cropped pointcloud
target_points = 2048

print("=" * 60)
print(f"FPS sampling to {target_points} points")
print("=" * 60)

# Downsample only when there are more cropped points than the target
num_points = min(target_points, len(points_robot_cropped))
if len(points_robot_cropped) > num_points:
    # FPS sampling returns indices into the cropped pointcloud
    indices = fpsample.fps_npdu_kdtree_sampling(points_robot_cropped, num_points)
    points_robot_fps = points_robot_cropped[indices]
    valid_colors_fps = valid_colors_cropped[indices]
    print(f"FPS downsampling: {len(points_robot_cropped)} -> {len(points_robot_fps)} points")
else:
    points_robot_fps = points_robot_cropped
    valid_colors_fps = valid_colors_cropped
    print(f"Using all {len(points_robot_fps)} points (no FPS downsampling needed)")

import time

# Launch viser visualization
print("=" * 60)
print("Launching viser with robot and aligned pointcloud")
print("=" * 60)
server = viser.ViserServer()

# Add robot URDF (absolute path on the author's machine)
urdf_path = "/Users/cameronsmith/Projects/robotics_testing/calibration_testing/so_100_arm/urdf/so_100_arm.urdf"
urdf = yourdfpy.URDF.load(urdf_path)
viser_urdf = ViserUrdf(
    server,
    urdf_or_path=urdf,
    load_meshes=True,
    load_collision_meshes=False,
    collision_mesh_color_override=(1.0, 0.0, 0.0, 0.5),
)
# Per-joint offsets subtracted from the MuJoCo qpos before posing the URDF
# (the two models presumably use different joint zero positions — confirm).
mujoco_so100_offset = np.array([0, -1.57, 1.57, 1.57, -1.57, 0])
viser_urdf.update_cfg(data.qpos - mujoco_so100_offset)

# Add full pointcloud in robot frame
server.scene.add_point_cloud(
    name="/moge_aligned_full",
    points=points_robot.astype(np.float32),
    colors=valid_colors.astype(np.uint8),
    point_size=0.001,
)

# Add filtered and FPS-sampled pointcloud
server.scene.add_point_cloud(
    name="/moge_aligned_cropped_fps",
    points=points_robot_fps.astype(np.float32),
    colors=valid_colors_fps.astype(np.uint8),
    point_size=0.002,  # Slightly larger for visibility
)

print("Viser server running at http://localhost:8080")
print("Press Ctrl+C to exit")
# Keep the process alive so the viser server keeps serving until Ctrl+C.
try:
    while True:
        time.sleep(0.1)
except KeyboardInterrupt:
    pass
