#!/usr/bin/env python3
"""
Visualize datasets by showing first and last frame of the first episode from each dataset.

Given a list of dataset directories, creates a matplotlib grid showing
the first and last frame of the first episode from each dataset side by side.
"""

import argparse
import json
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from PIL import Image


def load_episode_frames(episode_dir):
    """Load the first and last frame of an episode as NumPy arrays.

    Frames are the ``*.png`` files located directly inside ``episode_dir``.
    They are ordered by sorting their paths, and the first and last entries
    of that ordering are returned.

    Args:
        episode_dir: Episode directory, as a ``str`` or ``Path``.

    Returns:
        A ``(first_img, last_img)`` tuple of arrays, or ``(None, None)`` when
        the directory contains no PNG files. When there is exactly one PNG,
        both arrays hold the same frame.
    """
    episode_dir = Path(episode_dir)

    # NOTE(review): sorting paths is lexicographic, so this presumably relies
    # on zero-padded frame names (frame_002 before frame_010) -- confirm.
    png_files = sorted(episode_dir.glob("*.png"))
    if not png_files:
        return None, None

    def read_frame(path):
        # Convert to an array while the file is still open, then let the
        # ``with`` block release the file handle instead of leaking it.
        with Image.open(path) as img:
            return np.array(img)

    return read_frame(png_files[0]), read_frame(png_files[-1])


def get_episodes_from_dataset(dataset_dir):
    """Return the episodes of a dataset, ordered by episode number.

    When ``metadata.json`` exists in ``dataset_dir``, its ``episodes`` list
    decides which episodes are considered: each entry's ``episode_num`` is
    mapped to the directory ``episode_<num, zero-padded to 3 digits>``, and
    only directories that actually exist are returned. Entries that lack a
    usable integer ``episode_num`` are reported and skipped.

    Without ``metadata.json``, every ``episode_*`` sub-directory whose name
    has an integer after the first underscore is returned; other names are
    skipped silently.

    Args:
        dataset_dir: Dataset directory, as a ``str`` or ``Path``.

    Returns:
        A list of ``(episode_num, episode_dir)`` tuples sorted by
        ``episode_num``; empty when nothing matches.
    """
    dataset_dir = Path(dataset_dir)
    episodes = []

    metadata_path = dataset_dir / "metadata.json"
    if metadata_path.exists():
        # JSON is UTF-8 by specification; do not depend on the locale default.
        with open(metadata_path, 'r', encoding='utf-8') as f:
            metadata = json.load(f)

        # ``or []`` also covers an explicit ``"episodes": null``.
        for ep_info in metadata.get('episodes') or []:
            try:
                ep_num = ep_info['episode_num']
                ep_dir = dataset_dir / f"episode_{ep_num:03d}"
            except (KeyError, TypeError, ValueError):
                # One bad entry should not abort the whole dataset; the
                # directory-scan branch below is equally tolerant.
                print(f"Warning: Skipping malformed episode entry {ep_info!r} in {metadata_path}")
                continue
            # is_dir() (not exists()) so a stray file is not taken for an
            # episode, matching the directory-scan branch.
            if ep_dir.is_dir():
                episodes.append((ep_num, ep_dir))
    else:
        # Fallback: discover episode_XXX directories on disk.
        for ep_dir in sorted(dataset_dir.glob("episode_*")):
            if not ep_dir.is_dir():
                continue
            try:
                ep_num = int(ep_dir.name.split('_')[1])
            except (ValueError, IndexError):
                continue
            episodes.append((ep_num, ep_dir))

    return sorted(episodes, key=lambda x: x[0])


def _draw_frame(ax, img, title):
    """Render one frame on ``ax`` with a bold title and no axis decorations."""
    # The colormap only applies to 2-D (single-channel) arrays, so grayscale
    # frames are not shown in false colour; matplotlib ignores it for RGB(A).
    ax.imshow(img, cmap='gray')
    ax.axis('off')
    ax.set_title(title, fontsize=10, fontweight='bold')


def visualize_datasets(dataset_dirs, output_path=None):
    """Plot the first and last frame of the first episode of each dataset.

    The figure has two rows (first frame on top, last frame below) and one
    column per dataset that yielded an episode. Datasets that do not exist
    or contain no episodes are reported with a warning and left out. If an
    episode's frames cannot be loaded, its column is left blank.

    Args:
        dataset_dirs: Iterable of dataset directories (``str`` or ``Path``).
        output_path: Where to save the figure. When falsy, the figure is
            shown with ``plt.show()`` instead.

    Returns:
        None. Prints a message and returns early when no dataset yields an
        episode.
    """
    all_episodes = []

    # Collect the first episode of every usable dataset.
    for dataset_dir in dataset_dirs:
        dataset_dir = Path(dataset_dir)
        if not dataset_dir.exists():
            print(f"Warning: Dataset directory {dataset_dir} does not exist, skipping")
            continue

        episodes = get_episodes_from_dataset(dataset_dir)
        if not episodes:
            print(f"Warning: No episodes found in {dataset_dir}, skipping")
            continue

        ep_num, ep_dir = episodes[0]
        all_episodes.append((dataset_dir.name, ep_num, ep_dir))

    if not all_episodes:
        print("No episodes found in any dataset")
        return

    n_episodes = len(all_episodes)
    # squeeze=False keeps ``axes`` two-dimensional even for a single column,
    # so axes[row, col] indexing works without a special case.
    fig, axes = plt.subplots(
        2, n_episodes, figsize=(4 * n_episodes, 8), squeeze=False
    )

    for col_idx, (dataset_name, ep_num, ep_dir) in enumerate(all_episodes):
        ax_first, ax_last = axes[0, col_idx], axes[1, col_idx]
        first_img, last_img = load_episode_frames(ep_dir)

        if first_img is None or last_img is None:
            print(f"Warning: Could not load frames from {ep_dir}")
            ax_first.axis('off')
            ax_last.axis('off')
            continue

        _draw_frame(ax_first, first_img,
                    f"{dataset_name}\nEpisode {ep_num:03d} (First)")
        _draw_frame(ax_last, last_img,
                    f"{dataset_name}\nEpisode {ep_num:03d} (Last)")

    fig.tight_layout()

    if output_path:
        try:
            fig.savefig(output_path, dpi=150, bbox_inches='tight')
            print(f"Saved visualization to {output_path}")
        finally:
            # pyplot keeps every figure alive until it is closed; release it
            # so repeated calls do not accumulate figures in memory.
            plt.close(fig)
    else:
        plt.show()


def main(argv=None):
    """Parse command-line arguments and render the dataset visualization.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default),
            argparse reads the arguments from ``sys.argv``.
    """
    parser = argparse.ArgumentParser(
        # Only the first episode of each dataset is plotted, so say so.
        description=(
            "Visualize datasets by showing the first and last frame of the "
            "first episode of each dataset"
        )
    )
    parser.add_argument(
        "datasets",
        nargs="+",
        help="List of dataset directories (e.g., scratch/parsed_school_cap1 scratch/parsed_school_cap2)"
    )
    parser.add_argument(
        "--output", "-o",
        type=str,
        default=None,
        help="Output path for saved visualization (if not provided, displays interactively)"
    )

    args = parser.parse_args(argv)

    visualize_datasets(args.datasets, output_path=args.output)


# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
