"""Visualize gripper joint values across episodes by displaying frames in a grid."""
import sys
import os
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

# Make the script's parent directory and the script's own directory importable.
# The second insert also goes to index 0, so the script's directory ends up
# searched before its parent.
# NOTE(review): nothing in this file imports a local module, so these inserts
# may be leftovers — confirm before removing.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
sys.path.insert(0, os.path.dirname(__file__))

# Configuration
# Relative path: resolved against the current working directory, not the
# location of this script. Expected to contain `episode_*` subdirectories.
DATASET_ROOT = "scratch/parsed_moredata_pickplace_home"
MAX_EPISODES = 8  # Upper bound on the number of episodes to visualize
MAX_FRAMES_PER_EPISODE = 30  # Per-episode cap (frames from the start of each episode) to keep the grid manageable
IMAGE_RESIZE = (160, 120)  # (width, height) in pixels for PIL's Image.resize; low resolution for grid display

def _episode_sort_key(episode_dir):
    """Return a sort key that orders ``episode_<n>`` directories by number.

    Plain string sorting puts ``episode_10`` before ``episode_2`` unless the
    numbers are zero-padded. Names whose suffix is not an integer sort after
    all numbered episodes, alphabetically.
    """
    suffix = episode_dir.name[len("episode_"):]
    # isdecimal (not isdigit) guarantees int() can parse the suffix.
    if suffix.isdecimal():
        return (0, int(suffix), episode_dir.name)
    return (1, 0, episode_dir.name)


def _collect_episode_frames(episode_dir, max_frames):
    """Return up to ``max_frames`` frame records from one episode directory.

    A frame is an ``<integer>.png`` image paired with a ``.npy`` file of the
    same stem holding the joint state. Images without that pair are skipped.
    Frames are taken in ascending numeric index order, starting from the
    beginning of the episode.

    Each record is a dict with keys ``'frame_idx'`` (int), ``'image_path'``
    (Path) and ``'gripper_value'`` (float, the last element of the joint-state
    array).
    """
    image_paths = sorted(
        (p for p in episode_dir.glob("*.png") if p.stem.isdecimal()),
        key=lambda p: int(p.stem),
    )
    frames = []
    for image_path in image_paths:
        if len(frames) >= max_frames:
            break
        # Use the real file name instead of re-deriving it from the index, so
        # frames that are not zero-padded to 6 digits are still found.
        joint_state_path = image_path.with_suffix(".npy")
        if not joint_state_path.exists():
            continue
        joint_state = np.load(joint_state_path)
        frames.append({
            'frame_idx': int(image_path.stem),
            'image_path': image_path,
            'gripper_value': float(joint_state[-1]),  # Last value is gripper
        })
    return frames


def _draw_frame(ax, frame_info):
    """Draw one frame's image into ``ax``, titled with its gripper value.

    If the image cannot be loaded, an error message is written into the cell
    instead, so one bad file does not abort the whole grid.
    """
    try:
        # The context manager closes the file handle that Image.open keeps.
        with Image.open(frame_info['image_path']) as img:
            img_array = np.array(img.resize(IMAGE_RESIZE, Image.Resampling.LANCZOS))
    except Exception as e:  # deliberate best-effort: report in the cell
        ax.text(0.5, 0.5, f"Error\n{e}", ha='center', va='center',
                fontsize=6, transform=ax.transAxes)
        return
    # cmap only affects single-channel (grayscale) images; RGB(A) ignores it.
    ax.imshow(img_array, cmap='gray')
    ax.set_title(f"{frame_info['gripper_value']:.3f}", fontsize=8, pad=2)


def _plot_episode_grid(episode_data):
    """Build a figure with one row per episode and one column per frame.

    Rows shorter than the longest episode leave their trailing cells blank.
    ``episode_data`` must be non-empty, and every entry must have at least one
    frame. Returns the matplotlib Figure.
    """
    n_rows = len(episode_data)
    n_cols = max(len(ep['frames']) for ep in episode_data)
    # squeeze=False keeps `axes` a 2-D array for every grid shape. With the
    # default squeeze=True, a 1x1 grid returns a bare Axes that cannot be indexed.
    fig, axes = plt.subplots(n_rows, n_cols,
                             figsize=(n_cols * 1.5, n_rows * 2),
                             squeeze=False)
    for row_axes, ep_data in zip(axes, episode_data):
        for ax in row_axes:
            ax.axis('off')  # also blanks the unused trailing cells of the row
        for ax, frame_info in zip(row_axes, ep_data['frames']):
            _draw_frame(ax, frame_info)
        # Anchor the episode label to the row's first cell so it stays aligned
        # with its row and is included by tight_layout / bbox_inches='tight'.
        first_ax = row_axes[0]
        first_ax.text(-0.05, 0.5, ep_data['episode_id'],
                      transform=first_ax.transAxes, clip_on=False,
                      rotation=90, ha='right', va='center',
                      fontsize=10, fontweight='bold')
    return fig


def main():
    """Render a grid of episode frames, each titled with its gripper value.

    Scans DATASET_ROOT for ``episode_*`` directories and takes up to
    MAX_EPISODES of them, ordered by episode number. From each episode it
    takes up to MAX_FRAMES_PER_EPISODE frames, starting from the beginning.
    The grid is saved to ``gripper_values_grid.png`` in the current working
    directory and then shown. Problems with the dataset layout are reported
    on stdout and end the run early.
    """
    dataset_root = Path(DATASET_ROOT)

    if not dataset_root.is_dir():
        print(f"Dataset directory not found: {dataset_root}")
        return

    episode_dirs = sorted(
        (d for d in dataset_root.iterdir()
         if d.is_dir() and d.name.startswith("episode_")),
        key=_episode_sort_key,
    )

    if not episode_dirs:
        print(f"No episodes found in {dataset_root}")
        return

    selected_dirs = episode_dirs[:MAX_EPISODES]
    print(f"Found {len(episode_dirs)} episodes")
    print(f"Visualizing {len(selected_dirs)} episodes...")

    # Collect frame data for each episode
    episode_data = []
    for episode_dir in selected_dirs:
        frames_info = _collect_episode_frames(episode_dir, MAX_FRAMES_PER_EPISODE)
        if not frames_info:
            print(f"  {episode_dir.name}: no frames with joint states, skipping")
            continue
        episode_data.append({
            'episode_id': episode_dir.name,
            'frames': frames_info,
        })
        gripper_vals = [f['gripper_value'] for f in frames_info]
        print(f"  {episode_dir.name}: {len(frames_info)} frames, gripper range: [{min(gripper_vals):.3f}, {max(gripper_vals):.3f}]")

    if not episode_data:
        print("No frames found in any episode")
        return

    fig = _plot_episode_grid(episode_data)
    fig.tight_layout()
    fig.savefig('gripper_values_grid.png', dpi=150, bbox_inches='tight')
    print("\n✓ Saved grid visualization to gripper_values_grid.png")

    plt.show()

# Entry point: build and display the grid only when run as a script, not on import.
if __name__ == "__main__":
    main()