"""Visualize 3D keypoint predictions (token-based) vs GT in viser."""

import argparse
import os
import sys
from pathlib import Path

import cv2
import mujoco
import numpy as np
import torch
import trimesh
import viser
import yourdfpy
from scipy.spatial.transform import Rotation as R
from viser.extras import ViserUrdf

sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "."))
from model import TokenSelectionPredictor  # noqa: E402
from ExoConfigs.so100_adhesive import SO100AdhesiveConfig  # noqa: E402
from utils import project_3d_to_2d, rescale_coords, post_process_predictions, load_gt_trajectory_3d, load_dino_features, build_patch_positions, load_cam_data  # noqa: E402
from data import KEYPOINTS_LOCAL_M_ALL, KP_INDEX  # noqa: E402

# Window length shared across the script: it is passed to the model as
# ``window_size`` and to the GT trajectory loader, and an episode must contain
# at least WINDOW_SIZE + 1 frames before it is visualized.
WINDOW_SIZE = 10


def _resolve_dataset_dir(use_ood, ood_index):
    """Return the dataset directory selected by the ``--ood`` / ``--od`` flags.

    Args:
        use_ood: When false, the ``parsed_propercup_train`` directory is used.
        ood_index: Selects among the OOD directories; 0 and 1 pick specific
            ones and any other value falls through to the last one.

    Returns:
        A relative ``Path`` under ``scratch/``.
    """
    if not use_ood:
        return Path("scratch/parsed_propercup_train")
    if ood_index == 0:
        return Path("scratch/parsed_propercup_ood_cuppos")
    if ood_index == 1:
        # NOTE(review): "viewpointt1" (double "t") is kept verbatim from the
        # original path; confirm it matches the directory name on disk.
        return Path("scratch/parsed_propercup_ood_viewpointt1")
    return Path("scratch/parsed_propercup_ood_viewpoint2")


def _add_trajectory(server, cloud_name, marker_prefix, points, fade_channel, marker_mesh):
    """Add one trajectory to the viser scene as a point cloud plus sphere markers.

    Colors ramp along the trajectory: channel ``fade_channel`` goes from 1
    towards 0 while the blue channel goes from 0 towards 1.

    Args:
        server: The ``viser.ViserServer`` whose scene is populated.
        cloud_name: Scene node name for the point cloud.
        marker_prefix: Prefix for per-point marker nodes (``{prefix}_{i}``).
        points: Sequence of 3D points; may be empty.
        fade_channel: RGB channel index (0 = red, 1 = green) that fades out.
        marker_mesh: Template mesh copied once per point.
    """
    points = np.asarray(points)
    count = len(points)
    if count == 0:
        # Still register the node so the scene layout is the same either way.
        empty = np.empty((0, 3), dtype=np.float32)
        server.scene.add_point_cloud(cloud_name, points=empty, colors=empty, point_size=0.015)
        return

    progress = np.arange(count) / count
    colors = np.zeros((count, 3))
    colors[:, fade_channel] = 1 - progress
    colors[:, 2] = progress
    server.scene.add_point_cloud(
        cloud_name,
        points=points.astype(np.float32),
        colors=colors.astype(np.float32),
        point_size=0.015,
    )
    for i, kp in enumerate(points):
        server.scene.add_mesh_trimesh(
            f"{marker_prefix}_{i}",
            marker_mesh.copy(),
            wxyz=(1, 0, 0, 0),
            position=kp.astype(np.float32),
        )


def _render_scene(server, trajectory_gt_3d, trajectory_pred_3d, gripper_mesh, gripper_rot, gripper_pos):
    """Populate the viser scene with both trajectories and the start-frame gripper.

    Args:
        server: The ``viser.ViserServer`` to draw into.
        trajectory_gt_3d: Ground-truth 3D keypoints (green-to-blue ramp).
        trajectory_pred_3d: Predicted 3D keypoints (red-to-blue ramp).
        gripper_mesh: Optional gripper template mesh; skipped when ``None``.
        gripper_rot: 3x3 rotation block of the start-frame gripper pose.
        gripper_pos: Translation of the start-frame gripper pose.
    """
    marker = trimesh.creation.icosphere(subdivisions=2, radius=0.01)
    _add_trajectory(server, "/gt_trajectory", "/gt_kp", trajectory_gt_3d, 1, marker)
    _add_trajectory(server, "/pred_trajectory", "/pred_kp", trajectory_pred_3d, 0, marker)

    if gripper_mesh is not None:
        # SciPy returns quaternions scalar-last (x, y, z, w); the scene call
        # takes a ``wxyz`` argument, so move the scalar to the front.
        quat_xyzw = R.from_matrix(gripper_rot).as_quat()
        server.scene.add_mesh_trimesh(
            "/gripper",
            gripper_mesh.copy(),
            wxyz=quat_xyzw[[3, 0, 1, 2]],
            position=gripper_pos.astype(np.float32),
        )


def main():
    """Run token-based keypoint inference for one window and show it in viser.

    Steps: parse CLI flags, load the model (random weights if the checkpoint
    is missing), pick the dataset/episode/start frame, load camera data, DINO
    features and the start gripper pose, run the model, post-process the
    prediction into 3D, then render GT vs. prediction in a viser server and
    block until Ctrl+C. Missing inputs print a message and return early.
    """
    # ``time`` is not imported at file level; import it once here rather than
    # on every iteration of the keep-alive loop.
    import time

    parser = argparse.ArgumentParser(description="Visualize 3D predictions with viser")
    parser.add_argument("--ood", action="store_true", help="Use OOD dataset")
    parser.add_argument("--od", default=0, type=int, help="OOD dataset index")
    parser.add_argument("--episode_idx", default=0, type=int, help="Episode index")
    parser.add_argument("--start_frame", "--sf", default=0, type=int, help="Start frame")
    args = parser.parse_args()

    device = torch.device(
        "mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu"
    )
    print(f"Using device: {device}")

    model = TokenSelectionPredictor(
        dino_feat_dim=32,
        window_size=WINDOW_SIZE,
        num_layers=3,
        num_heads=4,
        hidden_dim=128,
        num_pos_bands=4,
    ).to(device)
    model_path = Path("token_selection_keypoints/tmpstorage/model.pt")
    if model_path.exists():
        # NOTE(review): torch.load unpickles the file; only point this at
        # checkpoints you trust.
        model.load_state_dict(torch.load(model_path, map_location=device))
        print(f"✓ Loaded model from {model_path}")
    else:
        print(f"⚠ Model not found at {model_path}, using random weights")
    model.eval()

    dataset_dir = _resolve_dataset_dir(args.ood, args.od)
    print(f"Using dataset: {dataset_dir}")
    if not dataset_dir.is_dir():
        # iterdir() would otherwise raise FileNotFoundError with a traceback.
        print(f"Dataset directory not found: {dataset_dir}")
        return

    episode_dirs = sorted(
        d for d in dataset_dir.iterdir() if d.is_dir() and d.name.startswith("episode_")
    )
    if not episode_dirs:
        print("No episodes found.")
        return
    episode_idx = args.episode_idx
    if not 0 <= episode_idx < len(episode_dirs):
        # Fall back to the first episode for any out-of-range index, including
        # negatives (which would otherwise silently wrap around).
        episode_idx = 0
    episode_dir = episode_dirs[episode_idx]
    episode_id = episode_dir.name
    print(f"Episode: {episode_id}")

    frame_files = sorted(f for f in episode_dir.glob("*.png") if f.stem.isdigit())
    if len(frame_files) < WINDOW_SIZE + 1:
        print("Not enough frames.")
        return

    start_idx = args.start_frame
    if start_idx < 0 or start_idx >= len(frame_files) - WINDOW_SIZE:
        start_idx = 0
    start_frame_file = frame_files[start_idx]

    camera_pose, cam_K = load_cam_data(episode_dir, start_frame_file)
    if camera_pose is None or cam_K is None:
        print("Camera data missing.")
        return

    # Original image resolution (before downsampling) is hardcoded. cam_K is
    # treated as calibrated at this resolution, so it is used without scaling.
    H_orig = 1080
    W_orig = 1920

    start_frame_str = f"{int(start_frame_file.stem):06d}"
    dino_path = episode_dir / f"dino_features_{start_frame_str}.pt"
    if not dino_path.exists():
        print(f"Missing DINO features: {dino_path}")
        return
    dino_tokens, H_patches_loaded, W_patches_loaded = load_dino_features(dino_path)
    num_patches = dino_tokens.shape[0]
    patch_positions_np, H_patches, W_patches = build_patch_positions(
        num_patches,
        H_patches=H_patches_loaded,
        W_patches=W_patches_loaded,
    )
    patch_positions = torch.from_numpy(patch_positions_np).float()

    current_pose_path = episode_dir / f"{start_frame_str}_gripper_pose.npy"
    if not current_pose_path.exists():
        print("Missing start gripper pose.")
        return
    current_pose = np.load(current_pose_path)
    current_rot = current_pose[:3, :3]
    current_pos = current_pose[:3, 3]

    # Move the selected local keypoint into the frame of the gripper pose,
    # project it into the image, then rescale to patch-grid coordinates.
    kp_local = KEYPOINTS_LOCAL_M_ALL[KP_INDEX]
    current_kp_3d = current_rot @ kp_local + current_pos
    current_kp_2d = project_3d_to_2d(current_kp_3d, camera_pose, cam_K)
    if current_kp_2d is None:
        print("Failed to project current KP.")
        return
    current_kp_patches = rescale_coords(
        current_kp_2d.reshape(1, 2), H_orig, W_orig, H_patches, W_patches
    )[0]
    current_eef_pos = torch.from_numpy(current_kp_patches).float()

    # GT 3D trajectory (the heights returned alongside are not used here).
    trajectory_gt_3d, _heights_gt = load_gt_trajectory_3d(
        episode_dir, frame_files, start_idx, WINDOW_SIZE, kp_local, return_heights=True
    )

    # Inference on a batch of one.
    with torch.no_grad():
        dino_b = dino_tokens.unsqueeze(0).to(device)
        patch_b = patch_positions.unsqueeze(0).to(device)
        current_b = current_eef_pos.unsqueeze(0).to(device)
        pixel_scores, heights_pred = model(dino_b, patch_b, current_b)
        pixel_scores = pixel_scores.squeeze(0).cpu().numpy()
        heights_pred = heights_pred.squeeze(0).cpu().numpy()

    trajectory_pred_3d, _pred_image_coords, _heights_pred_denorm = post_process_predictions(
        pixel_scores, heights_pred, H_patches, W_patches, H_orig, W_orig, camera_pose, cam_K
    )

    # Viser setup
    server = viser.ViserServer()
    print("✓ Viser server started")

    # Robot URDF (optional): shown only when the file exists on this machine.
    urdf_path = "/Users/cameronsmith/Projects/robotics_testing/calibration_testing/so_100_arm/urdf/so_100_arm.urdf"
    mujoco_so100_offset = np.array([0, -1.57, 1.57, 1.57, -1.57, 0])
    viser_urdf = None
    if Path(urdf_path).exists():
        urdf = yourdfpy.URDF.load(urdf_path)
        viser_urdf = ViserUrdf(
            server,
            urdf_or_path=urdf,
            load_meshes=True,
            load_collision_meshes=False,
        )
        zero_joint_state = np.zeros(6) - mujoco_so100_offset
        viser_urdf.update_cfg(zero_joint_state)
        print("✓ Loaded robot URDF")

    # Gripper mesh (optional)
    gripper_mesh_path = Path("robot_models/gripper_assembly.stl")
    gripper_mesh_template = trimesh.load(str(gripper_mesh_path)) if gripper_mesh_path.exists() else None

    # The start-frame gripper pose is already loaded above; reuse it instead of
    # reading the same .npy file a second time.
    _render_scene(
        server,
        trajectory_gt_3d,
        trajectory_pred_3d,
        gripper_mesh_template,
        current_rot,
        current_pos,
    )
    print(
        f"Episode: {episode_id} | Start Frame: {start_idx} | "
        f"GT pts: {len(trajectory_gt_3d)} | Pred pts: {len(trajectory_pred_3d)}"
    )

    print("✓ Visualization ready. Open the viser window. Ctrl+C to exit.")
    try:
        # Keep the process (and thus the viser server) alive until interrupted.
        while True:
            time.sleep(1)
    except KeyboardInterrupt:
        print("\nExiting...")

# Run the visualizer only when this file is executed as a script, not when it
# is imported.
if __name__ == "__main__":
    main()


