"""General utilities for token selection model: geometry, visualization, and IK."""
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
import torch
import cv2
import mujoco
import time
from pathlib import Path
from torchvision.utils import make_grid
from scipy.spatial.transform import Rotation as R
import mink

# ========== Data Loading Functions ==========

def load_dino_features(path):
    """Load DINO patch features from a ``.pt`` file and flatten them to (N, D).

    Accepted layouts and how they are handled:
      * 4-D with a leading batch of 1: the batch axis is dropped first.
      * 3-D: permuted from (C, H, W) to (H, W, C) when ``shape[0] < shape[-1]``,
        then flattened to (H*W, C); the grid size is returned.
      * 2-D: transposed when the first axis is shorter than the second and at
        most 128 long (treated as (D, N)); grid size is unknown.
      * anything else: reshaped to (shape[0], -1); grid size is unknown.

    NOTE(review): the 3-D rule assumes the feature dim is smaller than the grid
    width (e.g. PCA-reduced features). Full-width DINO features (D=384/768)
    would be mis-permuted in either layout; confirm against the data on disk.

    Args:
        path: Path to DINO features file (.pt); may contain a tensor or ndarray.

    Returns:
        feats: float32 tensor of shape (num_patches, dino_feat_dim)
        H_patches: Grid height in patches, or None when not determinable
        W_patches: Grid width in patches, or None when not determinable
    """
    # NOTE(review): weights_only=False unpickles arbitrary objects; only load trusted files.
    loaded = torch.load(path, map_location="cpu", weights_only=False)
    tensor = (torch.from_numpy(loaded) if isinstance(loaded, np.ndarray) else loaded).float()

    if tensor.dim() == 4 and tensor.shape[0] == 1:
        tensor = tensor[0]

    rank = tensor.dim()
    if rank == 3:
        first, _, last = tensor.shape
        grid = tensor.permute(1, 2, 0) if first < last else tensor
        rows, cols, depth = grid.shape
        return grid.reshape(-1, depth), rows, cols

    if rank == 2:
        n_rows, n_cols = tensor.shape
        if n_rows < n_cols and n_rows <= 128:
            tensor = tensor.transpose(0, 1)
        return tensor, None, None

    return tensor.reshape(tensor.shape[0], -1), None, None

def build_patch_positions(num_patches, H_patches=None, W_patches=None):
    """Build (x, y) patch-grid coordinates in row-major order.

    When either grid dimension is missing, both are inferred from
    ``num_patches``: ``H_patches`` is the largest divisor of ``num_patches``
    not exceeding ``floor(sqrt(num_patches))`` and ``W_patches`` is the
    matching quotient. Zero patches produce an empty 0x0 grid.

    Args:
        num_patches: Total number of patches (used only when inferring the grid)
        H_patches: Height in patches (if known)
        W_patches: Width in patches (if known)

    Returns:
        patch_positions: (H_patches * W_patches, 2) float32 array of [x, y],
            where x is the column index and y the row index
        H_patches: Height in patches
        W_patches: Width in patches

    Raises:
        ValueError: if the grid must be inferred and ``num_patches`` is negative.
    """
    if H_patches is None or W_patches is None:
        if num_patches < 0:
            raise ValueError(f"num_patches must be non-negative, got {num_patches}")
        if num_patches == 0:
            # Avoid `num_patches % 0` below; an empty grid is the natural answer.
            H_patches, W_patches = 0, 0
        else:
            H_patches = int(np.sqrt(num_patches))
            while num_patches % H_patches != 0:
                H_patches -= 1
            W_patches = num_patches // H_patches
    y_coords, x_coords = np.meshgrid(np.arange(H_patches), np.arange(W_patches), indexing="ij")
    patch_positions = np.stack([x_coords.flatten(), y_coords.flatten()], axis=1).astype(np.float32)
    return patch_positions, H_patches, W_patches

def load_cam_data(episode_dir, frame_file):
    """Load camera pose and intrinsics for a frame.

    Per-frame files (``robot_camera_pose_<frame:06d>.npy`` and
    ``cam_K_<frame:06d>.npy``) take precedence. Otherwise the episode-wide
    ``robot_camera_pose.npy`` / ``cam_K.npy`` pair is used. A pair is only used
    when both of its files exist.

    Args:
        episode_dir: Path (or str) of the episode directory
        frame_file: Path (or str) of the frame file; its stem must be an integer

    Returns:
        camera_pose: 4x4 camera pose matrix (or None if not found)
        cam_K: 3x3 camera intrinsics (or None if not found)
    """
    # Accept plain strings as well as Path objects (`/` below needs a Path).
    episode_dir = Path(episode_dir)
    frame_str = f"{int(Path(frame_file).stem):06d}"
    candidates = (
        (episode_dir / f"robot_camera_pose_{frame_str}.npy", episode_dir / f"cam_K_{frame_str}.npy"),
        (episode_dir / "robot_camera_pose.npy", episode_dir / "cam_K.npy"),
    )
    for pose_path, k_path in candidates:
        if pose_path.exists() and k_path.exists():
            return np.load(pose_path), np.load(k_path)
    return None, None

# ========== Geometric Transformations ==========

def project_3d_to_2d(point_3d, camera_pose, cam_K):
    """Project a 3D point to pixel coordinates with a pinhole model.

    Args:
        point_3d: (3,) point, transformed by ``camera_pose`` as a homogeneous vector
        camera_pose: 4x4 transform into the camera frame
        cam_K: 3x3 camera intrinsics

    Returns:
        (2,) pixel coordinates (u, v), or None when the camera-frame depth is <= 0.
    """
    x_c, y_c, z_c = (camera_pose @ np.append(point_3d, 1.0))[:3]
    if z_c <= 0:
        return None  # at or behind the camera plane
    uvw = np.dot(cam_K, [x_c, y_c, z_c])
    return uvw[:2] / uvw[2]

def rescale_coords(coords, H_orig, W_orig, H_new, W_new):
    """Scale 2D pixel coordinates from one image resolution to another.

    Column 0 (x) is multiplied by ``W_new / W_orig`` and column 1 (y) by
    ``H_new / H_orig``; only those two components of the last axis are kept.

    Args:
        coords: None, an empty array-like, a single point of shape (2,), or
            points of shape (..., 2)
        H_orig, W_orig: Original image dimensions
        H_new, W_new: New image dimensions

    Returns:
        None if ``coords`` is None; the (float32) input unchanged if empty;
        otherwise rescaled coordinates, with a single (2,) point kept 1-D.
    """
    if coords is None:
        return None
    pts = np.asarray(coords, dtype=np.float32)
    if pts.size == 0:
        return pts

    is_single_point = pts.ndim == 1
    batch = pts.reshape(1, -1) if is_single_point else pts
    x_scaled = batch[..., 0] * (W_new / W_orig)
    y_scaled = batch[..., 1] * (H_new / H_orig)
    scaled = np.stack([x_scaled, y_scaled], axis=-1)
    return scaled[0] if is_single_point else scaled

def unproject_2d_to_ray(point_2d, camera_pose, cam_K):
    """Back-project a pixel into a viewing ray expressed in the robot frame.

    ``camera_pose`` maps robot-frame points into the camera frame, so its
    inverse provides the camera origin and the camera-to-robot rotation.
    Only fx, fy, cx, cy are read from ``cam_K``; any skew term is ignored.

    Args:
        point_2d: Pixel coordinates (u, v)
        camera_pose: 4x4 transformation matrix from robot frame to camera frame
        cam_K: 3x3 camera intrinsics

    Returns:
        cam_pos_robot: (3,) camera position in the robot frame
        ray_robot: (3,) ray direction in the robot frame (the camera-frame ray
            is unit length before rotation)
    """
    robot_from_cam = np.linalg.inv(camera_pose)
    rot_robot_from_cam = robot_from_cam[:3, :3]
    origin = robot_from_cam[:3, 3]

    u, v = point_2d[0], point_2d[1]
    direction = np.array([
        (u - cam_K[0, 2]) / cam_K[0, 0],
        (v - cam_K[1, 2]) / cam_K[1, 1],
        1.0,
    ])
    direction = direction / np.linalg.norm(direction)
    return origin, rot_robot_from_cam @ direction

def recover_3d_from_keypoint_and_height(kp_2d_image, height, camera_pose, cam_K):
    """Recover a 3D keypoint from the image of its ground projection plus its height.

    The 2D point is the projection of the keypoint's ground-plane footprint
    (z = 0, index 2, in the MuJoCo Z-up robot frame), not of the keypoint
    itself. The camera ray through that pixel is intersected with z = 0, and
    the resulting point's z is replaced by ``height``.

    Args:
        kp_2d_image: 2D image coordinates of the ground projection of the keypoint
        height: z coordinate (index 2) of the actual 3D keypoint
        camera_pose: 4x4 camera pose matrix (robot -> camera)
        cam_K: 3x3 camera intrinsics

    Returns:
        (3,) keypoint position, or None if an input is None, the ray is
        (nearly) parallel to the ground plane, or the ground intersection lies
        behind the camera.
    """
    if kp_2d_image is None or height is None:
        return None
    cam_pos, ray = unproject_2d_to_ray(kp_2d_image, camera_pose, cam_K)

    # Ray: cam_pos + t * ray; solve cam_pos[2] + t * ray[2] = 0 for the ground plane.
    if abs(ray[2]) < 1e-6:
        return None  # Ray is parallel to the ground plane
    t_ground = -cam_pos[2] / ray[2]
    if t_ground < 0:
        # Intersection behind the camera is not a valid projection; reject it
        # like recover_3d_from_direct_keypoint_and_height does.
        return None

    kp_3d = cam_pos + t_ground * ray  # fresh array, safe to modify
    kp_3d[2] = height  # lift from the ground to the keypoint height
    return kp_3d

def recover_3d_from_direct_keypoint_and_height(kp_2d_image, height, camera_pose, cam_K):
    """Recover a 3D keypoint from its direct image projection and its height.

    The 2D point is the projection of the keypoint itself (not its ground
    footprint). The result is the point on the camera ray through that pixel
    whose z coordinate (index 2, MuJoCo Z-up) equals ``height``.

    Args:
        kp_2d_image: 2D image coordinates of the direct keypoint projection
        height: z coordinate (index 2) of the 3D keypoint
        camera_pose: 4x4 camera pose matrix (robot -> camera)
        cam_K: 3x3 camera intrinsics

    Returns:
        (3,) keypoint position, or None if an input is None, the ray is
        (nearly) parallel to the height plane, or the plane is hit behind the camera.
    """
    if kp_2d_image is None or height is None:
        return None

    origin, direction = unproject_2d_to_ray(kp_2d_image, camera_pose, cam_K)
    dz = direction[2]
    # A (near-)horizontal ray never reaches the requested height plane.
    if abs(dz) < 1e-6:
        return None

    # Solve origin[2] + s * dz == height for the distance s along the ray.
    s = (height - origin[2]) / dz
    # Negative s means the plane is behind the camera.
    return None if s < 0 else origin + s * direction

# ========== GT Trajectory Loading ==========

def load_gt_trajectory_3d(episode_dir, frame_files, start_idx, window_size, kp_local, return_heights=False, return_orientations=False):
    """
    Load the ground-truth 3D keypoint trajectory for the frames after ``start_idx``.

    For frame indices ``start_idx + 1 .. start_idx + window_size`` (truncated at
    the end of ``frame_files``), the 4x4 gripper pose is read from
    ``episode_dir / f"{frame:06d}_gripper_pose.npy"`` and the keypoint is placed
    at ``R @ kp_local + t``. Frames without a pose file are skipped, so fewer
    than ``window_size`` rows may be returned.

    Args:
        episode_dir: Path to episode directory
        frame_files: List of frame file paths (stems must be integers)
        start_idx: Index of the current frame; loading starts at the next one
        window_size: Number of future frames to load
        kp_local: Keypoint offset in the gripper frame (presumably
            KEYPOINTS_LOCAL_M_ALL[KP_INDEX] — confirm at the call site)
        return_heights: If True, also return heights (z components)
        return_orientations: If True, also return gripper rotation matrices

    Returns:
        trajectory_gt_3d alone when no flag is set; otherwise the tuple
        (trajectory_gt_3d[, heights_gt][, orientations_gt]) in that order.
        trajectory_gt_3d: (N, 3); heights_gt: (N,); orientations_gt: (N, 3, 3).
    """
    stop = min(start_idx + window_size + 1, len(frame_files))
    keypoints = []
    rotations = []
    for f_idx in range(start_idx + 1, stop):
        stem = f"{int(frame_files[f_idx].stem):06d}"
        pose_file = episode_dir / f"{stem}_gripper_pose.npy"
        if not pose_file.exists():
            continue
        pose = np.load(pose_file)
        rotation, translation = pose[:3, :3], pose[:3, 3]
        keypoints.append(rotation @ kp_local + translation)
        rotations.append(rotation)

    outputs = [np.array(keypoints) if keypoints else np.array([]).reshape(0, 3)]
    if return_heights:
        heights = [kp[2] for kp in keypoints]
        outputs.append(np.array(heights) if heights else np.array([]))
    if return_orientations:
        outputs.append(np.array(rotations) if rotations else np.array([]).reshape(0, 3, 3))
    return outputs[0] if len(outputs) == 1 else tuple(outputs)

# ========== Prediction Post-Processing ==========

def post_process_predictions(pixel_scores, heights_pred, H_patches, W_patches, H_orig, W_orig,
                            camera_pose, cam_K, min_height=None, max_height=None):
    """
    Post-process model predictions: convert pixel scores to a 3D trajectory.

    For each timestep the highest-scoring patch is taken as the predicted
    keypoint location. It is mapped from patch-grid to image coordinates and
    lifted to 3D along the camera ray at the de-normalized predicted height.

    Args:
        pixel_scores: (window_size, num_patches) or (num_patches,) - attention/pixel scores
        heights_pred: (window_size,) or scalar - predicted heights (normalized)
        H_patches, W_patches: Patch grid dimensions
        H_orig, W_orig: Original image dimensions
        camera_pose: 4x4 camera pose matrix
        cam_K: 3x3 camera intrinsics
        min_height, max_height: Range used to de-normalize heights as
            ``h * (max_height - min_height) + min_height``. They default to the
            module-level ``MIN_HEIGHT`` / ``MAX_HEIGHT`` constants when omitted.

    Returns:
        trajectory_pred_3d: (N, 3) array of 3D keypoints. Timesteps whose 3D
            recovery fails are dropped, so N can be smaller than T.
        trajectory_pred_2d_image: (T, 2) image coordinates, i.e. argmax patch
            (x, y) indices scaled by W_orig / W_patches and H_orig / H_patches.
            NOTE(review): this is the patch's top-left corner, not its centre.
        heights_pred_denorm: (T,) array of denormalized heights
    """
    if min_height is None:
        min_height = MIN_HEIGHT
    if max_height is None:
        max_height = MAX_HEIGHT

    pixel_scores = np.asarray(pixel_scores)
    heights_pred = np.asarray(heights_pred)

    # Handle single timestep case
    if pixel_scores.ndim == 1:
        pixel_scores = pixel_scores.reshape(1, -1)
    if heights_pred.ndim == 0:
        heights_pred = heights_pred.reshape(1)

    # Flat argmax patch index -> (x = column, y = row) on the patch grid.
    pred_patch_idx = pixel_scores.argmax(axis=1)
    rows, cols = np.divmod(pred_patch_idx, W_patches)
    pred_patches = np.stack([cols, rows], axis=1)  # (T, 2), also for T == 0

    # Convert patch coordinates to image coordinates
    pred_image_coords = rescale_coords(pred_patches, H_patches, W_patches, H_orig, W_orig)

    heights_pred_denorm = heights_pred * (max_height - min_height) + min_height

    # 2D + height -> 3D using the direct keypoint projection (not the ground plane).
    # zip stops at the shorter of the two sequences.
    trajectory_pred_3d = []
    for kp_2d, h in zip(pred_image_coords, heights_pred_denorm):
        kp_3d_pred = recover_3d_from_direct_keypoint_and_height(kp_2d, h, camera_pose, cam_K)
        if kp_3d_pred is not None:
            trajectory_pred_3d.append(kp_3d_pred)

    trajectory_pred_3d = np.array(trajectory_pred_3d) if trajectory_pred_3d else np.empty((0, 3))

    return trajectory_pred_3d, pred_image_coords, heights_pred_denorm

# ========== IK Functions ==========

def ik_to_keypoint_and_rotation(target_kp_pos, target_gripper_rot, configuration, robot_config, mj_model, mj_data, max_iterations=100, min_height_above_ground=0.02):
    """Solve IK with separate tasks for keypoint position and gripper rotation.

    Each iteration runs one mink differential-IK step with three tasks:
      * position-only task pulling ``virtual_gripper_keypoint`` to the target,
      * orientation-only task pulling ``Fixed_Jaw`` to ``target_gripper_rot``,
      * a small posture regularizer toward the current joint configuration.
    The target height (index 2) is clamped to ``min_height_above_ground``.
    Iteration stops early when the keypoint is within 1 mm of the target. This
    is checked on iterations 0, 10, 20, ...

    If the keypoint still ends below ``min_height_above_ground``, the target
    height is set to exactly that value. Up to 10 extra steps are then run with
    a larger gripper-orientation cost (0.1 instead of 0.015).

    ``configuration`` and ``mj_data`` (qpos, ctrl and derived quantities) are
    updated in place; nothing is returned.

    Args:
        target_kp_pos: (3,) target keypoint position (array-like)
        target_gripper_rot: (3, 3) target gripper rotation matrix
        configuration: mink.Configuration object
        robot_config: Robot configuration
        mj_model: MuJoCo model
        mj_data: MuJoCo data
        max_iterations: Maximum number of IK iterations
        min_height_above_ground: Minimum keypoint height (z, index 2) in meters.
            Default 0.02m (2cm).
    """
    from exo_utils import get_link_poses_from_robot, position_exoskeleton_meshes

    # Float copy: never mutate the caller's array, and an integer input dtype
    # cannot silently truncate the clamped height (e.g. 0.02 -> 0).
    target = np.array(target_kp_pos, dtype=float)
    if target[2] < min_height_above_ground:
        target[2] = min_height_above_ground

    # Loop-invariant lookups, hoisted out of the IK iterations.
    kp_body_id = mujoco.mj_name2id(mj_model, mujoco.mjtObj.mjOBJ_BODY, "virtual_gripper_keypoint")
    gripper_body_id = mujoco.mj_name2id(mj_model, mujoco.mjtObj.mjOBJ_BODY, "Fixed_Jaw")
    # SciPy returns quaternions as (x, y, z, w); mink.SE3 expects (w, x, y, z).
    qx, qy, qz, qw = R.from_matrix(target_gripper_rot).as_quat()
    gripper_wxyz = np.array([qw, qx, qy, qz])
    limits = [mink.ConfigurationLimit(model=mj_model)]

    def _step(gripper_orientation_cost):
        """Run one IK step, re-run forward kinematics, and return the keypoint position."""
        link_poses = get_link_poses_from_robot(robot_config, mj_model, mj_data)
        position_exoskeleton_meshes(robot_config, mj_model, mj_data, link_poses)
        mujoco.mj_forward(mj_model, mj_data)
        configuration.update(mj_data.qpos)

        # Task 1: keypoint position. Its orientation cost is 0, so the current
        # orientation (MuJoCo xquat, already w-x-y-z) serves as a don't-care target.
        kp_task = mink.FrameTask("virtual_gripper_keypoint", "body", position_cost=1.0, orientation_cost=0.0)
        kp_task.set_target(mink.SE3(wxyz_xyz=np.concatenate([mj_data.xquat[kp_body_id], target])))

        # Task 2: gripper orientation only; position target is its current position (cost 0).
        gripper_task = mink.FrameTask("Fixed_Jaw", "body", position_cost=0.0, orientation_cost=gripper_orientation_cost)
        gripper_task.set_target(mink.SE3(wxyz_xyz=np.concatenate([gripper_wxyz, mj_data.xpos[gripper_body_id]])))

        posture_task = mink.PostureTask(mj_model, cost=1e-3)
        posture_task.set_target(mj_data.qpos)

        vel = mink.solve_ik(configuration, [kp_task, gripper_task, posture_task], 0.01, "daqp", limits=limits)
        configuration.integrate_inplace(vel, 0.01)
        mj_data.qpos[:] = configuration.q
        mj_data.ctrl[:] = configuration.q[:len(mj_data.ctrl)]
        mujoco.mj_step(mj_model, mj_data)
        mujoco.mj_forward(mj_model, mj_data)
        return mj_data.xpos[kp_body_id].copy()

    for iteration in range(max_iterations):
        kp_pos = _step(0.015)
        # Early stop within 1 mm; checked every 10 iterations, starting at the first.
        if iteration % 10 == 0 and np.linalg.norm(kp_pos - target) < 0.001:
            break

    # Final enforcement: if the keypoint still ended below the minimum height,
    # retarget to exactly that height and take a few more-orientation-weighted steps.
    mujoco.mj_forward(mj_model, mj_data)
    if mj_data.xpos[kp_body_id][2] < min_height_above_ground:
        target[2] = min_height_above_ground
        for _ in range(10):
            if _step(0.1)[2] >= min_height_above_ground:
                break

# Commented out: keypoint-based IK (replaced by ik_to_targ_se3)
# def ik_to_keypoint(target_pos, configuration, robot_config, mj_model, mj_data, target_rot=None):
#     """Solve IK to move virtual_gripper_keypoint to target position and optionally orientation.
#     
#     Args:
#         target_pos: (3,) target position
#         configuration: mink.Configuration object
#         robot_config: Robot configuration
#         mj_model: MuJoCo model
#         mj_data: MuJoCo data
#         target_rot: (3, 3) target rotation matrix (optional). If None, keeps current orientation.
#     """
#     for _ in range(50):
#         from exo_utils import get_link_poses_from_robot, position_exoskeleton_meshes
#         link_poses = get_link_poses_from_robot(robot_config, mj_model, mj_data)
#         position_exoskeleton_meshes(robot_config, mj_model, mj_data, link_poses)
#         mujoco.mj_forward(mj_model, mj_data)
#         configuration.update(mj_data.qpos)
#         kp_body_id = mujoco.mj_name2id(mj_model, mujoco.mjtObj.mjOBJ_BODY, "virtual_gripper_keypoint")
#         
#         if target_rot is not None:
#             # Use provided target rotation
#             target_quat = R.from_matrix(target_rot).as_quat()
#         else:
#             # Keep current orientation
#             kp_rot = R.from_quat(mj_data.xquat[kp_body_id][[1, 2, 3, 0]]).as_matrix()
#             target_quat = R.from_matrix(kp_rot).as_quat()
#         
#         kp_task = mink.FrameTask("virtual_gripper_keypoint", "body", position_cost=1.0, orientation_cost=1)
#         kp_task.set_target(mink.SE3(wxyz_xyz=np.concatenate([[target_quat[3], target_quat[0], target_quat[1], target_quat[2]], target_pos])))
#         posture_task = mink.PostureTask(mj_model, cost=1e-3)
#         posture_task.set_target(mj_data.qpos)
#         vel = mink.solve_ik(configuration, [kp_task, posture_task], 0.01, "daqp", limits=[mink.ConfigurationLimit(model=mj_model)])
#         configuration.integrate_inplace(vel, 0.01)
#         mj_data.qpos[:] = configuration.q
#         mj_data.ctrl[:] = configuration.q[:len(mj_data.ctrl)]
#         mujoco.mj_step(mj_model, mj_data)

# ========== Visualization Functions ==========

def _round_to_patch_cell(x, y, H_patches, W_patches):
    """Clip and round continuous patch coordinates (x, y) to an in-bounds (row, col) cell."""
    col = int(np.round(np.clip(x, 0, W_patches - 1)))
    row = int(np.round(np.clip(y, 0, H_patches - 1)))
    return row, col


def _plot_gt_and_pred(ax, gt_points, pred_points):
    """Draw GT (blue line, viridis dots) and predicted (red line, plasma crosses) trajectories."""
    if gt_points is not None and len(gt_points) > 0:
        ax.plot(gt_points[:, 0], gt_points[:, 1], 'b-', linewidth=2, alpha=0.7, label='GT Trajectory')
        for i, (x, y) in enumerate(gt_points):
            color = plt.cm.viridis(i / len(gt_points))
            ax.plot(x, y, 'o', color=color, markersize=5, markeredgecolor='white', markeredgewidth=0.5)
    # Drawn independently of GT so predictions stay visible when no GT is available.
    if pred_points is not None and len(pred_points) > 0:
        ax.plot(pred_points[:, 0], pred_points[:, 1], 'r-', linewidth=2, alpha=0.7, label='Pred Trajectory')
        for i, (x, y) in enumerate(pred_points):
            color = plt.cm.plasma(i / len(pred_points))
            ax.plot(x, y, 'x', color=color, markersize=6, markeredgewidth=1)


def visualize_predictions(
    rgb_lowres, dino_vis, 
    trajectory_points_lowres, trajectory_points_patches,
    predicted_trajectory_lowres, predicted_trajectory_patches,
    current_kp_2d_lowres, current_kp_2d_patches,
    attention_scores, H_patches, W_patches,
    episode_id, start_idx, window_size=10,
    fig=None, axes_dict=None
):
    """
    Create (or update in place) a 5x5 grid visualization of predictions.

    Layout: (0,0) low-res RGB, (0,1) DINO features, (0,2) DINO with GT/pred
    one-hot cells overlaid, then alternating one-hot / attention panes per
    timestep, row by row, until the 25 cells run out. At most 10 timesteps
    are drawn.

    Args:
        rgb_lowres: (H, W, 3) RGB image at low resolution
        dino_vis: (H_patches, W_patches, 3) DINO visualization
        trajectory_points_lowres: (N, 2) GT trajectory in low-res coordinates
        trajectory_points_patches: (N, 2) GT trajectory in patch coordinates
        predicted_trajectory_lowres: (M, 2) Predicted trajectory in low-res coordinates
        predicted_trajectory_patches: (M, 2) Predicted trajectory in patch coordinates
        current_kp_2d_lowres: (2,) Current EEF position in low-res coordinates
        current_kp_2d_patches: (2,) Current EEF position in patch coordinates
        attention_scores: (window_size, num_patches) Attention scores
        H_patches, W_patches: Patch grid dimensions
        episode_id: Episode identifier string
        start_idx: Start frame index
        window_size: Number of future timesteps
        fig: Optional figure to plot on (for live updates)
        axes_dict: Optional dict of axes (for live updates); required when fig is given

    Returns:
        fig, axes_dict: Figure and axes dict for live updates
    """
    res_low = rgb_lowres.shape[0]
    n_display = min(window_size, 10)

    if fig is None:
        # 5x5 grid: 3 overview panes + (one-hot, attention) pairs per timestep.
        fig = plt.figure(figsize=(20, 20))
        gs = GridSpec(5, 5, figure=fig, hspace=0.3, wspace=0.3)
        axes_dict = {
            'rgb': fig.add_subplot(gs[0, 0]),
            'dino': fig.add_subplot(gs[0, 1]),
            'dino_onehot': fig.add_subplot(gs[0, 2]),
        }
        grid_idx = 3  # first free cell after the overview panes
        for t in range(window_size):
            row, col = divmod(grid_idx, 5)
            if row >= 5:
                break
            axes_dict[f'onehot_{t}'] = fig.add_subplot(gs[row, col])
            grid_idx += 1

            row, col = divmod(grid_idx, 5)
            if row >= 5:
                break
            axes_dict[f'attention_{t}'] = fig.add_subplot(gs[row, col])
            grid_idx += 1

    # RGB and DINO panes share the same trajectory / EEF overlay.
    overview_panes = (
        ('rgb', rgb_lowres, trajectory_points_lowres, predicted_trajectory_lowres, current_kp_2d_lowres,
         f'Low-Res Image ({res_low}x{res_low})\n{episode_id} - Frame {start_idx}'),
        ('dino', dino_vis, trajectory_points_patches, predicted_trajectory_patches, current_kp_2d_patches,
         f'DINO Patch Features ({H_patches}x{W_patches})\n{episode_id} - Frame {start_idx}'),
    )
    for key, image, gt_points, pred_points, current_kp, title in overview_panes:
        ax = axes_dict[key]
        ax.clear()
        ax.imshow(image)
        _plot_gt_and_pred(ax, gt_points, pred_points)
        if current_kp is not None:
            ax.plot(current_kp[0], current_kp[1], 'ro', markersize=8,
                    markeredgecolor='white', markeredgewidth=1, label='Current EEF', zorder=10)
        ax.set_title(title, fontsize=10, fontweight='bold')
        ax.axis('off')
        ax.legend(loc='upper right', fontsize=8)

    # DINO with all one-hot cells overlaid (GT white first, then pred red on top).
    if 'dino_onehot' in axes_dict:
        ax3 = axes_dict['dino_onehot']
        ax3.clear()
        overlay = np.zeros((H_patches, W_patches, 3), dtype=np.float32)
        for points, color in ((trajectory_points_patches, (1.0, 1.0, 1.0)),
                              (predicted_trajectory_patches, (1.0, 0.0, 0.0))):
            if points is None:
                continue
            for point in points:
                row, col = _round_to_patch_cell(point[0], point[1], H_patches, W_patches)
                overlay[row, col, :] = color
        alpha = 0.7
        blended = dino_vis * (1 - alpha) + overlay * alpha
        ax3.imshow(blended)
        ax3.set_title('DINO + One-hot Pixels\n(White=GT, Red=Pred)', fontsize=10, fontweight='bold')
        ax3.axis('off')

    # Per-timestep one-hot panes (skipped where the grid had no room).
    for t in range(n_display):
        ax_onehot = axes_dict.get(f'onehot_{t}')
        if ax_onehot is None:
            continue
        ax_onehot.clear()
        onehot_img = np.zeros((H_patches, W_patches, 3), dtype=np.float32)
        if trajectory_points_patches is not None and t < len(trajectory_points_patches):
            row, col = _round_to_patch_cell(trajectory_points_patches[t, 0], trajectory_points_patches[t, 1],
                                            H_patches, W_patches)
            onehot_img[row, col, :] = [1.0, 1.0, 1.0]  # White
        if predicted_trajectory_patches is not None and t < len(predicted_trajectory_patches):
            row, col = _round_to_patch_cell(predicted_trajectory_patches[t, 0], predicted_trajectory_patches[t, 1],
                                            H_patches, W_patches)
            onehot_img[row, col, :] = [1.0, 0.0, 0.0]  # Red
        ax_onehot.imshow(onehot_img)
        ax_onehot.set_title(f'One-hot t+{t+1}\n(White=GT, Red=Pred)', fontsize=10)
        ax_onehot.axis('off')

    # Per-timestep attention maps, min-max normalized to [0, 1].
    if attention_scores is not None:
        for t in range(min(n_display, attention_scores.shape[0])):
            ax_attn = axes_dict.get(f'attention_{t}')
            if ax_attn is None:
                continue
            attention_map = attention_scores[t].reshape(H_patches, W_patches)
            attention_map_norm = (attention_map - attention_map.min()) / (attention_map.max() - attention_map.min() + 1e-8)
            ax_attn.clear()
            ax_attn.imshow(attention_map_norm, cmap='hot', vmin=0, vmax=1)
            ax_attn.set_title(f'Attention t+{t+1}', fontsize=10)
            ax_attn.axis('off')

    return fig, axes_dict

def _to_numpy(x):
    """Return ``x`` as a NumPy array.

    Torch tensors are detached and moved to CPU first; anything else goes
    through ``np.asarray``. The result may share memory with ``x`` (a CPU
    tensor or an existing ndarray), so callers must copy before writing into it.
    """
    if isinstance(x, torch.Tensor):
        return x.detach().cpu().numpy()
    return np.asarray(x)


def _patch_indices_to_xy(indices, W_patches):
    """Convert flat row-major patch indices into an (N, 2) int array of [x, y] patch coordinates."""
    flat = np.asarray(indices, dtype=np.int64).reshape(-1)
    return np.stack([flat % W_patches, flat // W_patches], axis=1)


def visualize_training_sample(
    dino_tokens_sample, groundplane_coords_sample, current_eef_pos_sample, 
    onehot_targets_sample, heights_sample, grippers_sample, seq_id_sample,
    attention_scores, heights_pred_sample, grippers_pred_sample, H_patches, W_patches, max_timesteps,
    ax_vis, ax_attn, ax_height, ax_gripper, epoch, volume_mask_sample=None
):
    """
    Visualize a training sample during training.

    Draws four panels: DINO pseudo-RGB patches with GT/predicted trajectories,
    a grid of attention maps stacked over one-hot GT/prediction maps, GT vs.
    predicted heights, and GT vs. predicted gripper values. None of the input
    arrays/tensors are modified (all in-place work happens on copies).

    Args:
        dino_tokens_sample: (num_patches, dino_feat_dim) DINO tokens; the first
            3 channels are shown as a min-max normalized pseudo-RGB image.
        groundplane_coords_sample: (num_patches, 2) ground-plane coordinates
            (currently unused by this function).
        current_eef_pos_sample: (2,) current EEF position in patch coordinates,
            or None to use the first GT trajectory point instead.
        onehot_targets_sample: (max_timesteps, num_patches) one-hot targets.
        heights_sample: (max_timesteps,) GT heights, normalized to [0, 1] by
            MIN_HEIGHT/MAX_HEIGHT (de-normalized here for plotting).
        grippers_sample: (max_timesteps,) GT gripper values.
        seq_id_sample: Episode ID string (shown in the title).
        attention_scores: (max_timesteps, num_patches) attention scores; the
            argmax of each row is the predicted patch.
        heights_pred_sample: (max_timesteps,) predicted normalized heights.
        grippers_pred_sample: (max_timesteps,) predicted gripper values.
        H_patches, W_patches: Patch grid dimensions.
        max_timesteps: Maximum number of timesteps (at most 10 are drawn in
            the attention/one-hot grid).
        ax_vis, ax_attn, ax_height, ax_gripper: Matplotlib axes (cleared and redrawn).
        epoch: Current epoch number (0-based; shown as epoch + 1).
        volume_mask_sample: Optional (num_patches,) mask; values > 0.5 are valid.
            Invalid patches are flattened in the attention maps and tinted cyan.
    """
    attention_np = _to_numpy(attention_scores)
    num_vis = min(max_timesteps, 10)  # Show up to 10 timesteps

    # Predicted trajectory: argmax patch of each timestep's attention row.
    predicted_trajectory_patches = _patch_indices_to_xy(attention_np.argmax(axis=1), W_patches)
    # GT trajectory: the hot patch of each one-hot target row.
    trajectory_points_patches = _patch_indices_to_xy(
        _to_numpy(onehot_targets_sample).argmax(axis=1), W_patches
    )

    # Current EEF position is optional (not used for groundplane model)
    if current_eef_pos_sample is not None:
        current_kp_2d_patches = _to_numpy(current_eef_pos_sample)
    elif len(trajectory_points_patches) > 0:
        # Use first GT trajectory point as reference
        current_kp_2d_patches = trajectory_points_patches[0]
    else:
        current_kp_2d_patches = np.array([W_patches // 2, H_patches // 2])

    # DINO pseudo-RGB. astype() makes a copy, so the normalization below cannot
    # write back into the caller's tensor (Tensor.numpy() shares its memory).
    dino_vis = _to_numpy(dino_tokens_sample)[:, :3].reshape(H_patches, W_patches, 3).astype(np.float32)
    ch_min = dino_vis.min(axis=(0, 1))
    ch_span = dino_vis.max(axis=(0, 1)) - ch_min
    # Per-channel min-max normalization; constant channels become mid-gray (0.5).
    safe_span = np.where(ch_span > 0, ch_span, 1.0)
    dino_vis = np.where(ch_span > 0, (dino_vis - ch_min) / safe_span, 0.5)
    dino_vis = np.clip(dino_vis, 0, 1)

    # Visualize DINO features with trajectories
    ax_vis.clear()
    ax_vis.imshow(dino_vis)

    if len(trajectory_points_patches) > 0:
        # Draw GT trajectory
        ax_vis.plot(trajectory_points_patches[:, 0], trajectory_points_patches[:, 1], 'b-', linewidth=2, alpha=0.7, label='GT')
        for i, (x, y) in enumerate(trajectory_points_patches):
            color = plt.cm.viridis(i / len(trajectory_points_patches))
            ax_vis.plot(x, y, 'o', color=color, markersize=4, markeredgecolor='white', markeredgewidth=0.5)

        # Draw predicted trajectory
        if len(predicted_trajectory_patches) > 0:
            ax_vis.plot(predicted_trajectory_patches[:, 0], predicted_trajectory_patches[:, 1], 'r-', linewidth=2, alpha=0.7, label='Pred')
            for i, (x, y) in enumerate(predicted_trajectory_patches):
                color = plt.cm.plasma(i / len(predicted_trajectory_patches))
                ax_vis.plot(x, y, 'x', color=color, markersize=5, markeredgewidth=1)

    ax_vis.plot(current_kp_2d_patches[0], current_kp_2d_patches[1], 'ro', markersize=6,
                markeredgecolor='white', markeredgewidth=1, label='Current EEF', zorder=10)

    ax_vis.set_title(f'{seq_id_sample} | Epoch {epoch+1} | DINO Patches ({H_patches}x{W_patches})', fontsize=10)
    ax_vis.legend(loc='upper right', fontsize=8)
    ax_vis.axis('off')

    # Volume mask as a boolean (H, W) grid; True = valid patch.
    volume_mask_2d = None
    if volume_mask_sample is not None:
        volume_mask_2d = _to_numpy(volume_mask_sample).reshape(H_patches, W_patches) > 0.5

    # Build per-timestep attention and one-hot images (CHW tensors for make_grid).
    attention_imgs = []
    onehot_imgs = []
    for t in range(num_vis):
        if t < attention_np.shape[0]:
            # Copy: reshape() of a row is a view, and the masking below must not
            # overwrite the caller's attention scores.
            attention_map = attention_np[t].reshape(H_patches, W_patches).copy()

            # Fill masked regions with the minimum valid value so they don't skew normalization.
            if volume_mask_2d is not None and np.any(volume_mask_2d):
                attention_map[~volume_mask_2d] = attention_map[volume_mask_2d].min()

            attention_map_norm = (attention_map - attention_map.min()) / (attention_map.max() - attention_map.min() + 1e-8)
            attention_rgb = plt.cm.hot(attention_map_norm)[:, :, :3]  # (H, W, 3)

            # Blend a cyan tint over invalid (masked-out) patches.
            if volume_mask_2d is not None:
                mask_overlay = np.zeros_like(attention_rgb)
                mask_overlay[~volume_mask_2d] = [0.0, 1.0, 1.0]
                attention_rgb = 0.7 * attention_rgb + 0.3 * mask_overlay

            attention_imgs.append(torch.from_numpy(attention_rgb).permute(2, 0, 1).float())  # (3, H, W)

        # One-hot visualization: GT in white, prediction in red (drawn last, so it wins on overlap).
        onehot_img = np.zeros((3, H_patches, W_patches), dtype=np.float32)

        if t < len(trajectory_points_patches):
            kp_x, kp_y = trajectory_points_patches[t, 0], trajectory_points_patches[t, 1]
            patch_x_gt = int(np.round(np.clip(kp_x, 0, W_patches - 1)))
            patch_y_gt = int(np.round(np.clip(kp_y, 0, H_patches - 1)))
            onehot_img[:, patch_y_gt, patch_x_gt] = 1.0  # White

        if t < len(predicted_trajectory_patches):
            patch_x_pred = int(np.round(np.clip(predicted_trajectory_patches[t, 0], 0, W_patches - 1)))
            patch_y_pred = int(np.round(np.clip(predicted_trajectory_patches[t, 1], 0, H_patches - 1)))
            onehot_img[:, patch_y_pred, patch_x_pred] = [1.0, 0.0, 0.0]  # Red

        onehot_imgs.append(torch.from_numpy(onehot_img).float())

    # Create grids
    if len(attention_imgs) > 0:
        attention_grid = make_grid(attention_imgs, nrow=5, padding=2, pad_value=0.5)  # (3, H_grid, W_grid)
        attention_grid_np = attention_grid.permute(1, 2, 0).cpu().numpy()  # (H_grid, W_grid, 3)

    if len(onehot_imgs) > 0:
        onehot_grid = make_grid(onehot_imgs, nrow=5, padding=2, pad_value=0.0)  # (3, H_grid, W_grid)
        onehot_grid_np = onehot_grid.permute(1, 2, 0).cpu().numpy()  # (H_grid, W_grid, 3)

    # Stack attention and one-hot grids vertically
    ax_attn.clear()
    if len(attention_imgs) > 0 and len(onehot_imgs) > 0:
        # The grids can differ in width when fewer attention rows than timesteps exist.
        target_width = max(attention_grid_np.shape[1], onehot_grid_np.shape[1])
        attention_resized = cv2.resize(attention_grid_np, (target_width, attention_grid_np.shape[0]), interpolation=cv2.INTER_LINEAR)
        onehot_resized = cv2.resize(onehot_grid_np, (target_width, onehot_grid_np.shape[0]), interpolation=cv2.INTER_LINEAR)
        ax_attn.imshow(np.vstack([attention_resized, onehot_resized]))
        ax_attn.set_title('Attention Maps (top) | One-hot: White=GT, Red=Pred (bottom)', fontsize=10)
    elif len(attention_imgs) > 0:
        ax_attn.imshow(attention_grid_np)
        mask_label = " (Volume Masked)" if volume_mask_2d is not None else ""
        ax_attn.set_title(f'Attention Maps{mask_label}', fontsize=10)
    elif len(onehot_imgs) > 0:
        ax_attn.imshow(onehot_grid_np)
        ax_attn.set_title('One-hot: White=GT, Red=Pred', fontsize=10)
    ax_attn.axis('off')

    # Height bar chart (de-normalized to metres)
    ax_height.clear()
    height_range = MAX_HEIGHT - MIN_HEIGHT
    heights_gt_denorm = _to_numpy(heights_sample) * height_range + MIN_HEIGHT
    heights_pred_denorm = _to_numpy(heights_pred_sample) * height_range + MIN_HEIGHT
    timesteps = np.arange(1, len(heights_gt_denorm) + 1)
    ax_height.bar(timesteps - 0.2, heights_gt_denorm, width=0.4, alpha=0.6, color='green', label='GT Height')
    ax_height.bar(timesteps + 0.2, heights_pred_denorm[:len(heights_gt_denorm)], width=0.4, alpha=0.6, color='red', label='Pred Height')
    ax_height.set_xlabel('Timestep', fontsize=9)
    ax_height.set_ylabel('Height (m)', fontsize=9)
    ax_height.set_title('Height Trajectory', fontsize=10)
    ax_height.legend(fontsize=8)
    ax_height.grid(alpha=0.3)

    # Gripper bar chart (both GT and prediction may be tensors or arrays)
    ax_gripper.clear()
    grippers_gt_np = _to_numpy(grippers_sample)
    grippers_pred_np = _to_numpy(grippers_pred_sample)
    timesteps_gripper = np.arange(1, len(grippers_gt_np) + 1)
    ax_gripper.bar(timesteps_gripper - 0.2, grippers_gt_np, width=0.4, alpha=0.6, color='blue', label='GT Gripper')
    ax_gripper.bar(timesteps_gripper + 0.2, grippers_pred_np[:len(grippers_gt_np)], width=0.4, alpha=0.6, color='orange', label='Pred Gripper')
    ax_gripper.set_xlabel('Timestep', fontsize=9)
    ax_gripper.set_ylabel('Gripper Value', fontsize=9)
    ax_gripper.set_title('Gripper Open/Close Trajectory', fontsize=10)
    ax_gripper.legend(fontsize=8)
    ax_gripper.grid(alpha=0.3)
