"""Generate human segmentation masks for demonstration frames."""
import os
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import torch
import torchvision
from torchvision import transforms
from tqdm import tqdm

# Configuration
dataset_dir = "scratch/dataset"

print("=" * 60)
print("Human Mask Generation")
print("=" * 60)

# Load pre-trained segmentation model (DeepLabV3 with ResNet101).
# NOTE(review): `pretrained=True` is deprecated in newer torchvision
# releases in favor of the `weights=` enum; kept here for compatibility
# with the torchvision version this script was written against.
print("Loading segmentation model...")
model = torchvision.models.segmentation.deeplabv3_resnet101(pretrained=True)
model.eval()  # inference mode: disables dropout, uses running BN stats

# Pick the best available device instead of hard-coding 'mps', which
# raises on machines without Apple-silicon GPU support.
if torch.backends.mps.is_available():
    device = torch.device('mps')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
model = model.to(device)
print(f"Using device: {device}")

# Preprocessing: PIL image -> float tensor, normalized with ImageNet
# statistics (what the pretrained ResNet backbone expects).
preprocess = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

def get_human_mask(image_path):
    """Run DeepLabV3 on one image and extract a binary 'person' mask.

    Args:
        image_path: Path to the image file.

    Returns:
        Tuple of (mask, rgb):
            mask: uint8 array, 1 where the model predicts 'person', 0 elsewhere.
            rgb: the original image as an HxWx3 uint8 numpy array.
    """
    pil_img = Image.open(image_path).convert('RGB')
    rgb = np.array(pil_img)

    # Normalize, add a batch dimension, and move to the model's device.
    batch = preprocess(pil_img).unsqueeze(0).to(device)

    # Forward pass without autograd bookkeeping.
    with torch.no_grad():
        logits = model(batch)['out'][0]

    # Per-pixel argmax over class logits; index 15 is the 'person'
    # category in the Pascal-VOC label set this model predicts.
    labels = logits.argmax(0).cpu().numpy()
    mask = (labels == 15).astype(np.uint8)

    return mask, rgb

# Collect episode_* subdirectories of the dataset root, in sorted order.
episode_dirs = sorted(
    d for d in os.listdir(dataset_dir)
    if d.startswith('episode_') and os.path.isdir(os.path.join(dataset_dir, d))
)

print(f"Found {len(episode_dirs)} episodes")

# Process each episode: segment the start frame, save the binary mask,
# and write a 3-panel visualization (original / mask / overlay).
for episode_name in tqdm(episode_dirs, desc="Processing episodes"):
    episode_dir = os.path.join(dataset_dir, episode_name)

    # Only process start frame
    start_img_path = os.path.join(episode_dir, "start.png")

    if not os.path.exists(start_img_path):
        print(f"  Skipping {episode_name}: missing start.png")
        continue

    print(f"\n{episode_name}: Processing start frame")

    # Get human mask for start frame
    mask, rgb = get_human_mask(start_img_path)

    # Save mask as a grayscale PNG; imsave rescales the 0/1 values to
    # the full 0-255 range, so the mask is visible in an image viewer.
    mask_filename = "human_mask_start.png"
    mask_path = os.path.join(episode_dir, mask_filename)
    plt.imsave(mask_path, mask, cmap='gray')
    print(f"  Saved mask: {mask_path}")

    # Create and save visualization
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))

    # Original image
    axes[0].imshow(rgb)
    axes[0].set_title("start - Original", fontsize=12)
    axes[0].axis('off')

    # Mask
    axes[1].imshow(mask, cmap='gray')
    axes[1].set_title("start - Human Mask", fontsize=12)
    axes[1].axis('off')

    # Overlay: blend 50% red into the masked (person) pixels, then cast
    # back to uint8 for display.
    overlay = rgb.copy()
    overlay[mask == 1] = overlay[mask == 1] * 0.5 + np.array([255, 0, 0]) * 0.5
    axes[2].imshow(overlay.astype(np.uint8))
    axes[2].set_title("start - Overlay", fontsize=12)
    axes[2].axis('off')

    plt.suptitle(f"{episode_name} - start", fontsize=14, fontweight='bold')
    plt.tight_layout()

    # Save visualization
    vis_path = os.path.join(episode_dir, "human_mask_vis.png")
    plt.savefig(vis_path, dpi=150, bbox_inches='tight')
    print(f"  Saved visualization: {vis_path}")

    # Close exactly the figure we created so figures don't accumulate
    # across episodes. (Removed dead `if 0: plt.show()` code here.)
    plt.close(fig)

print("\n✓ Done!")

