"""Create a GIF animation from parsed episodes with episode labels."""
import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))

import argparse
from pathlib import Path
import cv2
import numpy as np
from PIL import Image, ImageDraw, ImageFont

# Command-line interface: where the parsed episodes live and how fast to play them.
parser = argparse.ArgumentParser(description="Create GIF from parsed episodes")
parser.add_argument("--input_dir", type=str, default="scratch/parsed_episodes_cup_synch", help="Input directory with parsed episodes")
parser.add_argument("--fps", type=float, default=16.0, help="Frames per second for GIF")
args = parser.parse_args()

# The per-frame duration is later computed as 1000 / fps, so reject a zero,
# negative or NaN rate here instead of failing after every frame was loaded.
if not args.fps > 0:
    parser.error(f"--fps must be a positive number, got {args.fps}")

# The GIF is always written inside the input directory, next to the episodes.
input_dir = Path(args.input_dir)
output_path = input_dir / "animation.gif"
# Kept on the namespace as a string for any code that reads args.output_path.
args.output_path = str(output_path)

# Episode folders are the immediate sub-directories of the input directory
# whose name begins with "episode_", taken in sorted path order.
episode_dirs = sorted(
    entry
    for entry in input_dir.iterdir()
    if entry.is_dir() and entry.name.startswith("episode_")
)
num_episodes = len(episode_dirs)

# Nothing to animate: stop before any image work is attempted.
if not episode_dirs:
    raise ValueError(f"No episode directories found in {input_dir}")

print(f"Found {num_episodes} episodes")
print("Creating GIF animation...")

# Collect all frames from all episodes
all_frames = []
episode_frame_counts = []

# Every labelled frame is scaled to this (width, height) before it joins the GIF.
GIF_FRAME_SIZE = (1280 // 3, 720 // 3)

# Fonts keyed by requested size, so font files are opened once per size
# instead of once per frame.
_label_fonts = {}


def _label_font(font_size):
    """Return the font used for the episode label, loading it on first use.

    Tries Helvetica from /System/Library/Fonts at ``font_size``, then
    DejaVu Sans Bold at size 24, and finally Pillow's built-in default font.
    """
    if font_size not in _label_fonts:
        try:
            font = ImageFont.truetype("/System/Library/Fonts/Helvetica.ttc", font_size)
        except (OSError, ImportError):
            # Font file missing/unreadable, or Pillow built without FreeType.
            try:
                font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 24)
            except (OSError, ImportError):
                font = ImageFont.load_default()
        _label_fonts[font_size] = font
    return _label_fonts[font_size]


for ep_idx, episode_dir in enumerate(episode_dirs):
    # Frames are the six-character-named PNG files of the episode, in sorted order.
    image_files = sorted(episode_dir.glob("??????.png"))

    if not image_files:
        print(f"Warning: No images found in {episode_dir.name}, skipping")
        continue

    # The label is identical for every frame of an episode.
    text = f"Episode: {ep_idx + 1}/{num_episodes}"

    episode_frames = []
    for img_path in image_files:
        # cv2.imread signals an unreadable or corrupt file by returning None
        # rather than raising, so check before converting.
        img = cv2.imread(str(img_path))
        if img is None:
            print(f"Warning: Could not read {img_path}, skipping")
            continue
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        height, width = img.shape[:2]

        # Convert to PIL Image for text overlay
        pil_img = Image.fromarray(img)
        draw = ImageDraw.Draw(pil_img)

        # Scale the label with the image, but never below 20.
        font = _label_font(max(20, min(height, width) // 20))

        # Measure the label so it can be right-aligned.
        bbox = draw.textbbox((0, 0), text, font=font)
        text_width = bbox[2] - bbox[0]
        text_height = bbox[3] - bbox[1]

        # Position in the top-right corner with padding
        padding = 10
        x = width - text_width - padding
        y = padding

        # Opaque black backdrop behind the label. The image is RGB, so an
        # alpha component in the fill colour would have no effect.
        bg_padding = 5
        draw.rectangle(
            [x - bg_padding, y - bg_padding,
             x + text_width + bg_padding, y + text_height + bg_padding],
            fill=(0, 0, 0)
        )

        # Draw text
        draw.text((x, y), text, fill=(255, 255, 255), font=font)

        # Back to a numpy array, scaled down to the GIF frame size.
        img_with_text = np.array(pil_img)
        episode_frames.append(cv2.resize(img_with_text, GIF_FRAME_SIZE))

    all_frames.extend(episode_frames)
    episode_frame_counts.append(len(episode_frames))
    print(f"  Episode {ep_idx + 1}: {len(episode_frames)} frames")

print(f"\nTotal frames: {len(all_frames)}")

# Create GIF
# Every episode may have been empty or skipped; fail with a clear message
# instead of an IndexError on the first-frame lookup below.
if not all_frames:
    raise ValueError("No frames were loaded from any episode; cannot create GIF")

print(f"Saving GIF to {output_path}...")
output_path.parent.mkdir(parents=True, exist_ok=True)

# Convert frames to PIL Images
pil_frames = [Image.fromarray(frame) for frame in all_frames]

# Save as GIF: the first frame carries the save call, the rest are appended.
pil_frames[0].save(
    str(output_path),
    save_all=True,
    append_images=pil_frames[1:],
    duration=int(1000 / args.fps),  # Per-frame duration in milliseconds
    loop=0  # Loop forever
)

print(f"✓ Saved GIF with {len(all_frames)} frames at {args.fps} fps")
print(f"  Output: {output_path}")

