"""Visualization utilities for token selection model."""
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
import torch
import cv2
from torchvision.utils import make_grid
from pathlib import Path
from model import MIN_HEIGHT, MAX_HEIGHT

def _build_prediction_layout(fig):
    """Add the prediction-panel axes to *fig* and return them keyed by role.

    3x3 GridSpec: row 0 holds the 'rgb', 'dino' and 'dino_onehot' panes; row 1 is
    a single 'attention_grid' axis spanning all columns; row 2 is a single
    'height' axis spanning all columns, left empty for the caller to draw on.
    """
    gs = GridSpec(3, 3, figure=fig, hspace=0.3, wspace=0.3,
                  height_ratios=[1, 2, 0.3], width_ratios=[1, 1, 1])
    # Dict literal evaluates in order, so subplots are added in the same order as before.
    return {
        'rgb': fig.add_subplot(gs[0, 0]),
        'dino': fig.add_subplot(gs[0, 1]),
        'dino_onehot': fig.add_subplot(gs[0, 2]),
        'attention_grid': fig.add_subplot(gs[1, :]),
        'height': fig.add_subplot(gs[2, :]),
    }


def _paint_squares(canvas, points, scale_x, scale_y, radius, color):
    """Fill a ``(2 * radius + 1)``-pixel square of *color* around each scaled point.

    Each (x, y) point is multiplied by (scale_x, scale_y), clipped to the canvas
    and rounded to the nearest pixel; squares are cropped at the canvas border and
    later points overwrite earlier ones. *canvas* (H, W, C) is modified in place.
    ``None`` or empty *points* is a no-op.
    """
    if points is None or len(points) == 0:
        return
    height, width = canvas.shape[:2]
    for x, y in np.asarray(points, dtype=np.float64)[:, :2]:
        cx = int(np.round(np.clip(x * scale_x, 0, width - 1)))
        cy = int(np.round(np.clip(y * scale_y, 0, height - 1)))
        canvas[max(0, cy - radius):min(height, cy + radius + 1),
               max(0, cx - radius):min(width, cx + radius + 1)] = color


def _plot_trajectories(ax, gt_points, pred_points, current_point, scale_x, scale_y):
    """Draw GT / predicted trajectories and the current EEF on *ax* in display pixels.

    Coordinates are converted to float copies before scaling, so integer inputs
    (e.g. patch indices) are neither rejected by in-place float multiplication nor
    truncated, and the caller's arrays are never modified. GT is a blue line with
    viridis-coloured dots; the prediction is a red line with plasma-coloured
    crosses and is drawn whether or not GT is available.
    """
    scale = np.array([scale_x, scale_y], dtype=np.float64)

    if gt_points is not None and len(gt_points) > 0:
        gt = np.asarray(gt_points, dtype=np.float64)[:, :2] * scale
        ax.plot(gt[:, 0], gt[:, 1], 'b-', linewidth=2, alpha=0.7, label='GT Trajectory')
        for i, (x, y) in enumerate(gt):
            ax.plot(x, y, 'o', color=plt.cm.viridis(i / len(gt)), markersize=5,
                    markeredgecolor='white', markeredgewidth=0.5)

    if pred_points is not None and len(pred_points) > 0:
        pred = np.asarray(pred_points, dtype=np.float64)[:, :2] * scale
        ax.plot(pred[:, 0], pred[:, 1], 'r-', linewidth=2, alpha=0.7, label='Pred Trajectory')
        for i, (x, y) in enumerate(pred):
            ax.plot(x, y, 'x', color=plt.cm.plasma(i / len(pred)), markersize=6,
                    markeredgewidth=1)

    if current_point is not None:
        cx, cy = np.asarray(current_point, dtype=np.float64)[:2] * scale
        ax.plot(cx, cy, 'ro', markersize=8, markeredgecolor='white', markeredgewidth=1,
                label='Current EEF', zorder=10)


def _legend_if_labeled(ax):
    """Add the upper-right legend only when *ax* has labelled artists.

    Avoids matplotlib's "No artists with labels found" warning when nothing
    labelled was drawn (no trajectories and no current EEF).
    """
    if ax.get_legend_handles_labels()[0]:
        ax.legend(loc='upper right', fontsize=8)


def _timestep_grids(attention_scores, gt_points, pred_points, H_patches, W_patches, num_steps):
    """Build per-timestep attention and one-hot tile mosaics with ``make_grid``.

    For each t in ``range(num_steps)``: an attention tile (min-max normalised,
    'hot' colormap) is made when row t of *attention_scores* exists, and a one-hot
    tile is always made, with the GT patch of step t in white and the predicted
    patch in red (red wins on overlap). Tiles are laid out 5 per row.

    Returns:
        (attention_grid, onehot_grid): contiguous (H, W, 3) float32 arrays, each
        ``None`` when no tiles of that kind were produced.
    """
    attention_tiles, onehot_tiles = [], []
    for t in range(num_steps):
        if attention_scores is not None and t < attention_scores.shape[0]:
            attn = attention_scores[t].reshape(H_patches, W_patches)
            attn = (attn - attn.min()) / (attn.max() - attn.min() + 1e-8)
            attn_rgb = plt.cm.hot(attn)[:, :, :3]  # (H, W, 3), drop alpha
            attention_tiles.append(torch.from_numpy(attn_rgb).permute(2, 0, 1).float())

        tile = np.zeros((3, H_patches, W_patches), dtype=np.float32)
        if gt_points is not None and t < len(gt_points):
            gx = int(np.round(np.clip(gt_points[t, 0], 0, W_patches - 1)))
            gy = int(np.round(np.clip(gt_points[t, 1], 0, H_patches - 1)))
            tile[:, gy, gx] = 1.0  # white
        if pred_points is not None and t < len(pred_points):
            px = int(np.round(np.clip(pred_points[t, 0], 0, W_patches - 1)))
            py = int(np.round(np.clip(pred_points[t, 1], 0, H_patches - 1)))
            tile[:, py, px] = (1.0, 0.0, 0.0)  # red
        onehot_tiles.append(torch.from_numpy(tile))

    def to_image(tiles, pad_value):
        if not tiles:
            return None
        grid = make_grid(tiles, nrow=5, padding=2, pad_value=pad_value)  # (3, H, W)
        # permute() yields a strided view; make it contiguous for cv2.resize.
        return np.ascontiguousarray(grid.permute(1, 2, 0).cpu().numpy())

    return to_image(attention_tiles, 0.5), to_image(onehot_tiles, 0.0)


def visualize_predictions(
    rgb_lowres, dino_vis, 
    trajectory_points_lowres, trajectory_points_patches,
    predicted_trajectory_lowres, predicted_trajectory_patches,
    current_kp_2d_lowres, current_kp_2d_patches,
    attention_scores, H_patches, W_patches,
    episode_id, start_idx, window_size=10,
    fig=None, axes_dict=None
):
    """
    Draw the prediction dashboard: RGB and DINO panes with trajectories, plus a
    mosaic of per-timestep attention maps and one-hot patch maps.

    When the axes are created here the layout is a 3x3 GridSpec: row 0 = RGB,
    DINO, DINO + one-hot squares; row 1 = attention mosaic (top) stacked on the
    one-hot mosaic (bottom); row 2 = an empty 'height' axis for the caller.

    All points are (x, y). Both image panes are resized to 425x240 (W x H) and
    points are rescaled per axis from their source resolution (the RGB image's
    own width/height, or the DINO grid's). The predicted trajectory is drawn even
    when no GT trajectory is given.

    Args:
        rgb_lowres: (H, W, 3) RGB image at low resolution; divided by 255 before
            blending, so 0-255 values are assumed.
        dino_vis: (H_patches, W_patches, 3) DINO visualization; float values are
            blended as-is, integer values are divided by 255 first.
        trajectory_points_lowres: (N, 2) GT trajectory in low-res coordinates, or None
        trajectory_points_patches: (N, 2) GT trajectory in patch coordinates, or None
        predicted_trajectory_lowres: (M, 2) Predicted trajectory in low-res coordinates
        predicted_trajectory_patches: (M, 2) Predicted trajectory in patch coordinates
        current_kp_2d_lowres: (2,) Current EEF position in low-res coordinates, or None
        current_kp_2d_patches: (2,) Current EEF position in patch coordinates, or None
        attention_scores: (window_size, num_patches) Attention scores, or None
        H_patches, W_patches: Patch grid dimensions
        episode_id: Episode identifier string
        start_idx: Start frame index
        window_size: Number of future timesteps shown in the mosaics
        fig: Optional figure to plot on (for live updates)
        axes_dict: Optional dict of axes (for live updates); needs 'rgb' and
            'dino'; 'dino_onehot' and 'attention_grid' are drawn only if present.

    Returns:
        fig, axes_dict: Figure and axes dict, reusable for live updates.
    """
    display_h, display_w = 240, 425
    overlay_alpha = 0.7
    white, red = [1.0, 1.0, 1.0], [1.0, 0.0, 0.0]

    if fig is None:
        # A fresh figure always gets a fresh layout (any axes_dict passed is ignored).
        fig = plt.figure()
        axes_dict = _build_prediction_layout(fig)
    elif axes_dict is None:
        axes_dict = _build_prediction_layout(fig)

    # --- RGB pane: image blended with one-hot squares, plus trajectories ---
    ax_rgb = axes_dict['rgb']
    ax_rgb.clear()
    src_h, src_w = rgb_lowres.shape[:2]
    rgb_scale_x = display_w / src_w
    rgb_scale_y = display_h / src_h
    rgb_resized = cv2.resize(rgb_lowres, (display_w, display_h), interpolation=cv2.INTER_LINEAR)

    rgb_overlay = np.zeros((display_h, display_w, 3), dtype=np.float32)
    rgb_radius = max(2, int(2.5 * rgb_scale_x))
    _paint_squares(rgb_overlay, trajectory_points_lowres, rgb_scale_x, rgb_scale_y, rgb_radius, white)
    _paint_squares(rgb_overlay, predicted_trajectory_lowres, rgb_scale_x, rgb_scale_y, rgb_radius, red)
    rgb_blended = (rgb_resized.astype(np.float32) / 255.0 * (1 - overlay_alpha)
                   + rgb_overlay * overlay_alpha)
    ax_rgb.imshow(rgb_blended)
    _plot_trajectories(ax_rgb, trajectory_points_lowres, predicted_trajectory_lowres,
                       current_kp_2d_lowres, rgb_scale_x, rgb_scale_y)
    ax_rgb.set_title(f'RGB Image ({display_w}x{display_h})\n{episode_id} - Frame {start_idx}', fontsize=10, fontweight='bold')
    ax_rgb.axis('off')
    _legend_if_labeled(ax_rgb)

    # --- DINO pane: upscaled patch features plus trajectories ---
    ax_dino = axes_dict['dino']
    ax_dino.clear()
    dino_h, dino_w = dino_vis.shape[:2]
    dino_scale_x = display_w / dino_w
    dino_scale_y = display_h / dino_h
    dino_upscaled = cv2.resize(dino_vis, (display_w, display_h), interpolation=cv2.INTER_LINEAR)
    ax_dino.imshow(dino_upscaled)
    _plot_trajectories(ax_dino, trajectory_points_patches, predicted_trajectory_patches,
                       current_kp_2d_patches, dino_scale_x, dino_scale_y)
    ax_dino.set_title(f'DINO Patch Features ({H_patches}x{W_patches})\n{episode_id} - Frame {start_idx}', fontsize=10, fontweight='bold')
    ax_dino.axis('off')
    _legend_if_labeled(ax_dino)

    # --- DINO + one-hot squares pane (optional) ---
    if 'dino_onehot' in axes_dict:
        ax_onehot = axes_dict['dino_onehot']
        ax_onehot.clear()
        dino_overlay = np.zeros((display_h, display_w, 3), dtype=np.float32)
        dino_radius = max(2, int(2.5 * dino_scale_x))
        _paint_squares(dino_overlay, trajectory_points_patches, dino_scale_x, dino_scale_y, dino_radius, white)
        _paint_squares(dino_overlay, predicted_trajectory_patches, dino_scale_x, dino_scale_y, dino_radius, red)
        # The overlay lives in [0, 1]; bring integer (e.g. uint8) images into the same range.
        if np.issubdtype(dino_upscaled.dtype, np.integer):
            dino_base = dino_upscaled.astype(np.float32) / 255.0
        else:
            dino_base = dino_upscaled
        ax_onehot.imshow(dino_base * (1 - overlay_alpha) + dino_overlay * overlay_alpha)
        ax_onehot.set_title('DINO + One-hot Pixels\n(White=GT, Red=Pred)', fontsize=10, fontweight='bold')
        ax_onehot.axis('off')

    # --- Attention / one-hot mosaics ---
    attention_grid, onehot_grid = _timestep_grids(
        attention_scores, trajectory_points_patches, predicted_trajectory_patches,
        H_patches, W_patches, window_size,
    )

    if 'attention_grid' in axes_dict:
        ax_grid = axes_dict['attention_grid']
        ax_grid.clear()
        mosaic_w = 1275
        if attention_grid is not None and onehot_grid is not None:
            # Resize both to the same width (keeping aspect ratio) so they stack cleanly.
            stacked = [
                cv2.resize(g, (mosaic_w, int(g.shape[0] * mosaic_w / g.shape[1])),
                           interpolation=cv2.INTER_LINEAR)
                for g in (attention_grid, onehot_grid)
            ]
            ax_grid.imshow(np.vstack(stacked))
            ax_grid.set_title('Attention Maps (top) | One-hot: White=GT, Red=Pred (bottom)', fontsize=10, fontweight='bold')
        elif attention_grid is not None:
            ax_grid.imshow(cv2.resize(attention_grid, (mosaic_w, 280), interpolation=cv2.INTER_LINEAR))
            ax_grid.set_title('Attention Maps (all timesteps)', fontsize=10, fontweight='bold')
        elif onehot_grid is not None:
            ax_grid.imshow(cv2.resize(onehot_grid, (mosaic_w, 280), interpolation=cv2.INTER_LINEAR))
            ax_grid.set_title('One-hot: White=GT, Red=Pred (all timesteps)', fontsize=10, fontweight='bold')
        ax_grid.axis('off')

    return fig, axes_dict

def visualize_training_sample(
    dino_tokens_sample, patch_positions_sample, current_eef_pos_sample, 
    onehot_targets_sample, heights_sample, seq_id_sample,
    attention_scores, heights_pred_sample, H_patches, W_patches, window_size,
    ax_vis, ax_attn, ax_height, epoch
):
    """
    Visualize one training sample on three existing axes (each is cleared first).

    - ax_vis: the first three DINO token channels shown as a per-channel
      min-max-normalised RGB image of the patch grid, with GT (argmax of the
      one-hot targets) and predicted (argmax of the attention) patch trajectories
      and the current EEF.
    - ax_attn: attention mosaic (top) stacked on the one-hot mosaic (bottom) for
      at most the first 10 timesteps.
    - ax_height: GT vs predicted heights, de-normalised with MIN_HEIGHT/MAX_HEIGHT.

    Tensor inputs are detached and copied to host memory before use, so CUDA or
    grad-tracking tensors are accepted and the caller's tensors are never modified.

    Args:
        dino_tokens_sample: (num_patches, dino_feat_dim) DINO tokens
        patch_positions_sample: (num_patches, 2) Patch positions (not used here)
        current_eef_pos_sample: (2,) Current EEF position in patch coordinates, or None
        onehot_targets_sample: (window_size, num_patches) One-hot targets
        heights_sample: (window_size,) GT heights (normalized)
        seq_id_sample: Episode ID string
        attention_scores: (window_size, num_patches) Attention scores
        heights_pred_sample: (window_size,) Predicted heights (normalized)
        H_patches, W_patches: Patch grid dimensions
        window_size: Number of future timesteps
        ax_vis, ax_attn, ax_height: Matplotlib axes
        epoch: Current epoch number (0-based; shown as epoch + 1)
    """
    def to_numpy(x):
        # Detach + move to host so GPU / grad-tracking tensors can be plotted.
        if isinstance(x, torch.Tensor):
            return x.detach().cpu().numpy()
        return np.asarray(x)

    def indices_to_xy(flat_indices):
        # Flat row-major patch index -> (x, y) = (column, row) on the H x W grid.
        flat_indices = np.asarray(flat_indices).astype(np.int64).reshape(-1)
        return np.stack([flat_indices % W_patches, flat_indices // W_patches], axis=1)

    attention_np = to_numpy(attention_scores)
    predicted_trajectory_patches = indices_to_xy(attention_np.argmax(axis=1))  # (window_size, 2)
    trajectory_points_patches = indices_to_xy(to_numpy(onehot_targets_sample).argmax(axis=1))
    current_kp_2d_patches = None if current_eef_pos_sample is None else to_numpy(current_eef_pos_sample)

    # First three token channels as an RGB image. astype(copy=True) is essential:
    # numpy views of CPU tensors share memory, and the normalisation below must not
    # overwrite the caller's tokens.
    dino_vis = to_numpy(dino_tokens_sample)[:, :3].reshape(H_patches, W_patches, 3).astype(np.float32, copy=True)
    channel_min = dino_vis.min(axis=(0, 1))
    channel_span = dino_vis.max(axis=(0, 1)) - channel_min
    has_range = channel_span > 0
    # Min-max normalise each channel; constant channels become mid-grey (0.5).
    dino_vis = np.where(has_range, (dino_vis - channel_min) / np.where(has_range, channel_span, 1.0), 0.5)
    dino_vis = np.clip(dino_vis, 0, 1)

    # --- DINO features with trajectories ---
    ax_vis.clear()
    ax_vis.imshow(dino_vis)

    if len(trajectory_points_patches) > 0:
        ax_vis.plot(trajectory_points_patches[:, 0], trajectory_points_patches[:, 1], 'b-', linewidth=2, alpha=0.7, label='GT')
        for i, (x, y) in enumerate(trajectory_points_patches):
            color = plt.cm.viridis(i / len(trajectory_points_patches))
            ax_vis.plot(x, y, 'o', color=color, markersize=4, markeredgecolor='white', markeredgewidth=0.5)

    if len(predicted_trajectory_patches) > 0:
        ax_vis.plot(predicted_trajectory_patches[:, 0], predicted_trajectory_patches[:, 1], 'r-', linewidth=2, alpha=0.7, label='Pred')
        for i, (x, y) in enumerate(predicted_trajectory_patches):
            color = plt.cm.plasma(i / len(predicted_trajectory_patches))
            ax_vis.plot(x, y, 'x', color=color, markersize=5, markeredgewidth=1)

    if current_kp_2d_patches is not None:
        ax_vis.plot(current_kp_2d_patches[0], current_kp_2d_patches[1], 'ro', markersize=6,
                    markeredgecolor='white', markeredgewidth=1, label='Current EEF', zorder=10)

    ax_vis.set_title(f'{seq_id_sample} | Epoch {epoch+1} | DINO Patches ({H_patches}x{W_patches})', fontsize=10)
    if ax_vis.get_legend_handles_labels()[0]:
        ax_vis.legend(loc='upper right', fontsize=8)
    ax_vis.axis('off')

    # --- Attention maps and one-hot tiles (at most 10 timesteps) ---
    attention_imgs = []
    onehot_imgs = []
    for t in range(min(window_size, 10)):
        if t < attention_np.shape[0]:
            attention_map = attention_np[t].reshape(H_patches, W_patches)
            attention_map_norm = (attention_map - attention_map.min()) / (attention_map.max() - attention_map.min() + 1e-8)
            attention_rgb = plt.cm.hot(attention_map_norm)[:, :, :3]  # (H, W, 3), drop alpha
            attention_imgs.append(torch.from_numpy(attention_rgb).permute(2, 0, 1).float())

        onehot_img = np.zeros((3, H_patches, W_patches), dtype=np.float32)
        if t < len(trajectory_points_patches):
            gx = int(np.clip(trajectory_points_patches[t, 0], 0, W_patches - 1))
            gy = int(np.clip(trajectory_points_patches[t, 1], 0, H_patches - 1))
            onehot_img[:, gy, gx] = 1.0  # white
        if t < len(predicted_trajectory_patches):
            px = int(np.clip(predicted_trajectory_patches[t, 0], 0, W_patches - 1))
            py = int(np.clip(predicted_trajectory_patches[t, 1], 0, H_patches - 1))
            onehot_img[:, py, px] = (1.0, 0.0, 0.0)  # red (overrides GT on overlap)
        onehot_imgs.append(torch.from_numpy(onehot_img))

    attention_grid_np = None
    if attention_imgs:
        attention_grid_np = np.ascontiguousarray(
            make_grid(attention_imgs, nrow=5, padding=2, pad_value=0.5).permute(1, 2, 0).cpu().numpy())
    onehot_grid_np = None
    if onehot_imgs:
        onehot_grid_np = np.ascontiguousarray(
            make_grid(onehot_imgs, nrow=5, padding=2, pad_value=0.0).permute(1, 2, 0).cpu().numpy())

    # Stack attention and one-hot grids vertically
    ax_attn.clear()
    if attention_grid_np is not None and onehot_grid_np is not None:
        # Stretch both to the wider grid's width so they can be stacked.
        target_width = max(attention_grid_np.shape[1], onehot_grid_np.shape[1])
        attention_resized = cv2.resize(attention_grid_np, (target_width, attention_grid_np.shape[0]), interpolation=cv2.INTER_LINEAR)
        onehot_resized = cv2.resize(onehot_grid_np, (target_width, onehot_grid_np.shape[0]), interpolation=cv2.INTER_LINEAR)
        ax_attn.imshow(np.vstack([attention_resized, onehot_resized]))
        ax_attn.set_title('Attention Maps (top) | One-hot: White=GT, Red=Pred (bottom)', fontsize=10)
    elif attention_grid_np is not None:
        ax_attn.imshow(attention_grid_np)
        ax_attn.set_title('Attention Maps', fontsize=10)
    elif onehot_grid_np is not None:
        ax_attn.imshow(onehot_grid_np)
        ax_attn.set_title('One-hot: White=GT, Red=Pred', fontsize=10)
    ax_attn.axis('off')

    # --- Height bar chart ---
    ax_height.clear()
    height_span = MAX_HEIGHT - MIN_HEIGHT
    heights_gt_denorm = to_numpy(heights_sample) * height_span + MIN_HEIGHT
    # Show at most one predicted bar per GT step; each series gets its own x
    # positions so a shorter prediction no longer raises a shape mismatch.
    heights_pred_denorm = (to_numpy(heights_pred_sample) * height_span + MIN_HEIGHT)[:len(heights_gt_denorm)]
    gt_steps = np.arange(1, len(heights_gt_denorm) + 1)
    pred_steps = np.arange(1, len(heights_pred_denorm) + 1)
    ax_height.bar(gt_steps - 0.2, heights_gt_denorm, width=0.4, alpha=0.6, color='green', label='GT Height')
    ax_height.bar(pred_steps + 0.2, heights_pred_denorm, width=0.4, alpha=0.6, color='red', label='Pred Height')
    ax_height.set_xlabel('Timestep', fontsize=9)
    ax_height.set_ylabel('Height (m)', fontsize=9)
    ax_height.set_title('Height Trajectory', fontsize=10)
    ax_height.legend(fontsize=8)
    ax_height.grid(alpha=0.3)

def visualize_evaluation_full(
    rgb_lowres, dino_vis,
    trajectory_points_lowres, trajectory_points_patches,
    predicted_trajectory_lowres, predicted_trajectory_patches,
    current_kp_2d_lowres, current_kp_2d_patches,
    attention_scores, H_patches, W_patches,
    heights_pred, heights_gt=None,
    episode_id="eval", start_idx=0, window_size=10,
    save_path=None
):
    """
    Full evaluation visualization: the visualize_predictions dashboard plus a
    height bar chart in its bottom row, optionally saved to disk.

    Heights are de-normalised with ``h * (MAX_HEIGHT - MIN_HEIGHT) + MIN_HEIGHT``.
    When GT heights are given, at most one predicted bar is drawn per GT step, and
    GT and prediction series may differ in length.

    Args:
        rgb_lowres: (H, W, 3) RGB image at low resolution
        dino_vis: (H_patches, W_patches, 3) DINO visualization
        trajectory_points_lowres: (N, 2) GT trajectory in low-res coordinates (can be None)
        trajectory_points_patches: (N, 2) GT trajectory in patch coordinates (can be None)
        predicted_trajectory_lowres: (M, 2) Predicted trajectory in low-res coordinates
        predicted_trajectory_patches: (M, 2) Predicted trajectory in patch coordinates
        current_kp_2d_lowres: (2,) Current EEF position in low-res coordinates
        current_kp_2d_patches: (2,) Current EEF position in patch coordinates
        attention_scores: (window_size, num_patches) Attention scores
        H_patches, W_patches: Patch grid dimensions
        heights_pred: (window_size,) Predicted heights (normalized); array-like
        heights_gt: (window_size,) GT heights (normalized, optional); array-like
        episode_id: Episode identifier string
        start_idx: Start frame index
        window_size: Number of future timesteps
        save_path: Optional path to save figure (parent directories are created)

    Returns:
        fig: Figure object (not closed; the caller owns it)
    """
    fig, axes_dict = visualize_predictions(
        rgb_lowres, dino_vis,
        trajectory_points_lowres, trajectory_points_patches,
        predicted_trajectory_lowres, predicted_trajectory_patches,
        current_kp_2d_lowres, current_kp_2d_patches,
        attention_scores, H_patches, W_patches,
        episode_id, start_idx, window_size
    )

    # Add height bar chart in the bottom row
    if 'height' in axes_dict:
        ax_height = axes_dict['height']
        ax_height.clear()
        height_span = MAX_HEIGHT - MIN_HEIGHT
        heights_pred_denorm = np.asarray(heights_pred, dtype=np.float64) * height_span + MIN_HEIGHT

        if heights_gt is not None and len(heights_gt) > 0:
            heights_gt_denorm = np.asarray(heights_gt, dtype=np.float64) * height_span + MIN_HEIGHT
            pred_shown = heights_pred_denorm[:len(heights_gt_denorm)]
            # Separate x positions per series: previously both used the prediction
            # length, raising a shape mismatch whenever the two lengths differed.
            gt_steps = np.arange(1, len(heights_gt_denorm) + 1)
            pred_steps = np.arange(1, len(pred_shown) + 1)
            ax_height.bar(gt_steps - 0.2, heights_gt_denorm, width=0.4, alpha=0.6, color='green', label='GT Height')
            ax_height.bar(pred_steps + 0.2, pred_shown, width=0.4, alpha=0.6, color='red', label='Pred Height')
        else:
            pred_steps = np.arange(1, len(heights_pred_denorm) + 1)
            ax_height.bar(pred_steps, heights_pred_denorm, width=0.4, alpha=0.6, color='red', label='Pred Height')

        ax_height.set_xlabel('Timestep', fontsize=9)
        ax_height.set_ylabel('Height (m)', fontsize=9)
        ax_height.set_title('Height Trajectory', fontsize=10)
        ax_height.legend(fontsize=8)
        ax_height.grid(alpha=0.3)

    # Act on this figure explicitly rather than pyplot's "current" figure.
    fig.tight_layout()

    if save_path:
        out_path = Path(save_path)
        out_path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(out_path, dpi=150, bbox_inches='tight')
        print(f"✓ Saved {save_path}")

    return fig
