"""Parse recorded video into episodes by manually selecting start/end frames.

Interactive matplotlib interface for navigating frames and marking episode boundaries.
"""
import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))

import argparse
import numpy as np
import cv2
import matplotlib.pyplot as plt
from pathlib import Path
import shutil
import json

# Enable matplotlib interactive mode (figures redraw without blocking)
plt.ion()

parser = argparse.ArgumentParser(description="Parse video into episodes by selecting start/end frames")
parser.add_argument("--input_dir", "-i", type=str, required=True,
                    help="Input directory with recorded frames (e.g., scratch/rgb_joints_capture_cup)")
parser.add_argument("--output_dir", "-o", type=str, default=None,
                    help="Directory to write parsed episodes to (default: a sibling folder "
                         "'parsed_<input folder name>' next to the input directory)")
args = parser.parse_args()

input_dir = Path(args.input_dir)
if not input_dir.is_dir():
    # Fail before creating any output folder for a mistyped input path.
    parser.error(f"input directory does not exist: {input_dir}")

if args.output_dir is None:
    # Default: sibling of the input folder, e.g.
    # scratch/rgb_joints_capture_cup -> scratch/parsed_rgb_joints_capture_cup.
    # pathlib (rather than splitting the string on "/") keeps this correct for
    # trailing slashes, absolute paths and deeper nesting.  "." / ".." have no
    # usable folder name, so resolve them to an absolute path first.
    named_dir = input_dir if input_dir.name not in ("", "..") else input_dir.resolve()
    args.output_dir = str(named_dir.parent / f"parsed_{named_dir.name}")

output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)

# Load all images and joint states
print(f"Loading frames from {input_dir}...")
image_files = sorted(input_dir.glob("*.png"))

# Pair each image with the joint-state file of the same timestep.
# The joint file is first looked up as "<image stem>.npy"
# (e.g. "00001.png" -> "00001.npy").  If the image stem ends in "_rgb"
# (e.g. "00001_rgb.png"), the stem without that suffix ("00001.npy") is
# tried next.  Images with no matching joint file are skipped.
frame_data = []
for img_file in image_files:
    # Path.stem is independent of the OS path separator.
    stem = img_file.stem
    candidates = [stem]
    if stem.endswith("_rgb"):
        candidates.append(stem[:-len("_rgb")])
    for timestep in candidates:
        joint_file = input_dir / f"{timestep}.npy"
        if joint_file.exists():
            frame_data.append({
                'timestep': timestep,
                'image_path': img_file,
                'joint_path': joint_file,
            })
            break

if not frame_data:
    raise ValueError(f"No matching image/joint pairs found in {input_dir}")

print(f"Loaded {len(frame_data)} frames")

# Interactive state shared by the key handler and the renderer.
current_frame_idx = 0  # index into frame_data of the frame on screen
episodes = []  # completed episodes, each {'start': idx, 'end': idx}
current_start = current_end = None  # marks of the episode being built, if any

# Free 's' from matplotlib's default "save figure" binding so it can mark an
# episode START.  Reassign a filtered list instead of list.remove('s'), which
# raises ValueError when 's' is not bound (e.g. a customised matplotlibrc).
# This runs before the figure is created so that toolbars which read keymaps
# when the window is built (e.g. matplotlib's toolmanager) also pick it up.
from matplotlib import rcParams
rcParams['keymap.save'] = [key for key in rcParams['keymap.save'] if key != 's']

# Create figure
fig, ax = plt.subplots(figsize=(12, 8))
fig.canvas.manager.set_window_title('Episode Parser - Navigate and mark start/end frames')

def load_frame(idx):
    """Load frame ``idx`` from ``frame_data`` as an RGB image array.

    Returns None when ``idx`` is outside ``frame_data`` or when OpenCV cannot
    read the image file.  In either case callers can skip drawing instead of
    failing inside the GUI callback.
    """
    if not 0 <= idx < len(frame_data):
        return None
    img_path = frame_data[idx]['image_path']
    img = cv2.imread(str(img_path))
    if img is None:
        # cv2.imread reports missing/undecodable files by returning None rather
        # than raising; cvtColor would otherwise fail with an opaque cv2.error.
        print(f"Warning: could not read image {img_path}")
        return None
    # OpenCV loads BGR; matplotlib expects RGB.
    return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

def _status_lines():
    """Return the lines of the help/status overlay for the current state."""
    lines = [
        f"Frame: {current_frame_idx + 1}/{len(frame_data)}",
        f"Timestep: {frame_data[current_frame_idx]['timestep']}",
        "",
        "Controls:",
        "  ←/a: Previous frame",
        "  →/d: Next frame",
        "  s: Mark START (new episode)",
        "  e: Mark END (complete episode)",
        "  u: Undo last mark",
        "  q: Quit and save",
        "",
    ]

    # Episode currently being built
    if current_start is not None:
        lines.append("Current episode (in progress):")
        lines.append(f"  START: Frame {current_start + 1} ({frame_data[current_start]['timestep']})")
        if current_end is not None:
            lines.append(f"  END: Frame {current_end + 1} ({frame_data[current_end]['timestep']})")
            lines.append(f"  Duration: {current_end - current_start + 1} frames")
        else:
            lines.append("  (waiting for END frame)")
    else:
        lines.append("Ready for new episode (press 's' to start)")

    # Completed episodes (frame numbers shown 1-based, range inclusive)
    lines.append("")
    lines.append(f"Completed episodes: {len(episodes)}")
    for num, ep in enumerate(episodes, start=1):
        duration = ep['end'] - ep['start'] + 1
        lines.append(f"  Episode {num}: Frames {ep['start']+1}-{ep['end']+1} ({duration} frames)")
    return lines


def _draw_marks(img):
    """Draw START/END labels and, if the shown frame is marked, a coloured border.

    ``img`` is the currently displayed image, or None if it could not be loaded
    (labels are then placed at fixed y positions and no border is drawn).
    """
    if current_start is not None:
        # Green START label near the bottom of the image
        y_pos = img.shape[0] - 20 if img is not None else 100
        ax.text(10, y_pos, f"START: Frame {current_start + 1}",
                bbox=dict(boxstyle="round,pad=0.5", facecolor="green", alpha=0.7),
                fontsize=12, color="white", weight="bold")

    if current_end is not None:
        # Red END label just above the START label
        y_pos = img.shape[0] - 50 if img is not None else 70
        ax.text(10, y_pos, f"END: Frame {current_end + 1}",
                bbox=dict(boxstyle="round,pad=0.5", facecolor="red", alpha=0.7),
                fontsize=12, color="white", weight="bold")

    # Border around the whole frame: green if it is the START, otherwise red if
    # it is the END.
    if img is None:
        return
    if current_start == current_frame_idx:
        edge = 'green'
    elif current_end == current_frame_idx:
        edge = 'red'
    else:
        return
    ax.add_patch(plt.Rectangle((0, 0), img.shape[1], img.shape[0],
                               linewidth=5, edgecolor=edge, facecolor='none'))


def update_display():
    """Redraw the axes: current frame, mark indicators and the status overlay."""
    ax.clear()

    img = load_frame(current_frame_idx)
    if img is not None:
        ax.imshow(img)

    _draw_marks(img)

    # Status box anchored to the top-left corner in axes coordinates
    ax.text(0.02, 0.98, "\n".join(_status_lines()), transform=ax.transAxes,
            fontsize=10, verticalalignment='top', family='monospace',
            bbox=dict(boxstyle="round,pad=0.5", facecolor="black", alpha=0.7),
            color="white")

    ax.axis('off')
    plt.tight_layout()
    fig.canvas.draw()
    fig.canvas.flush_events()

def on_key(event):
    """Handle a key press on the figure: navigate frames and edit episode marks.

    Bindings:
        left / a / A   previous frame (clamped at the first frame)
        right / d / D  next frame (clamped at the last frame)
        s / S          start a new episode at the current frame; an unfinished
                       episode in progress is discarded
        e / E          end the in-progress episode at the current frame and
                       append it to ``episodes``
        u / U          undo the most recent mark; with no episode in progress
                       this reopens the last completed episode so its END can
                       be re-marked
        q / Q          save all episodes and close every figure
    Any other key is ignored.
    """
    global current_frame_idx, current_start, current_end, episodes

    key = event.key
    if key in ('left', 'a', 'A'):
        current_frame_idx = max(0, current_frame_idx - 1)
        update_display()

    elif key in ('right', 'd', 'D'):
        current_frame_idx = min(len(frame_data) - 1, current_frame_idx + 1)
        update_display()

    elif key in ('s', 'S'):
        # Mark start frame (starts new episode); an unfinished one is dropped.
        if current_start is not None and current_end is None:
            print(f"Discarding incomplete episode (start at frame {current_start + 1})")
        current_start = current_frame_idx
        current_end = None  # Clear any previous end
        print(f"Started new episode at frame {current_frame_idx + 1} ({frame_data[current_frame_idx]['timestep']})")
        update_display()

    elif key in ('e', 'E'):
        # Mark end frame; END may equal START (single-frame episode).
        if current_start is None:
            print("Warning: Please mark START frame first (press 's')")
        elif current_frame_idx < current_start:
            print("Warning: END frame must be after START frame")
        else:
            current_end = current_frame_idx
            episodes.append({'start': current_start, 'end': current_end})
            print(f"Completed episode {len(episodes)}: Frames {current_start + 1}-{current_end + 1} ({current_end - current_start + 1} frames)")
            # Reset for next episode
            current_start = None
            current_end = None
            update_display()

    elif key in ('u', 'U'):
        if current_end is not None:
            current_end = None
            print("Undid END mark")
        elif current_start is not None:
            current_start = None
            print("Undid START mark")
        elif episodes:
            # 'e' commits an episode immediately and clears both marks, so the
            # most recent mark is the END of the last completed episode.
            # Reopen that episode so its END can be placed again.
            reopened = episodes.pop()
            current_start = reopened['start']
            print(f"Undid END mark of episode {len(episodes) + 1} "
                  f"(START at frame {current_start + 1} restored)")
        else:
            print("No marks to undo")
        update_display()

    elif key in ('q', 'Q'):
        # Quit and save
        save_episodes()
        plt.close('all')
        print("\nExiting...")

# Auxiliary per-frame files stored next to "<timestep>.npy", given as
# (source suffix, destination suffix).  "_cam_K_norm" is written out as "_cam_K".
_AUX_SUFFIXES = (
    ("_gripper_pose", "_gripper_pose"),
    ("_camera_pose", "_camera_pose"),
    ("_cam_K_norm", "_cam_K"),
)


def _copy_frame_files(frame_info, episode_dir, out_idx):
    """Copy one frame's image, joint state and auxiliary .npy files into episode_dir.

    Output files are named ``{out_idx:06d}.png``, ``{out_idx:06d}.npy`` and
    ``{out_idx:06d}<dst suffix>.npy``.  A missing source file propagates the
    FileNotFoundError raised by shutil.copy2.
    """
    stem = f"{out_idx:06d}"
    joint_path = Path(frame_info['joint_path'])
    shutil.copy2(frame_info['image_path'], episode_dir / f"{stem}.png")
    shutil.copy2(joint_path, episode_dir / f"{stem}.npy")
    for src_suffix, dst_suffix in _AUX_SUFFIXES:
        # Build sibling names from the file stem; str.replace on the whole path
        # would also rewrite any ".npy" that appears in a directory name.
        src = joint_path.with_name(f"{joint_path.stem}{src_suffix}.npy")
        shutil.copy2(src, episode_dir / f"{stem}{dst_suffix}.npy")


def save_episodes():
    """Write every completed episode to ``output_dir`` and record metadata.json.

    Episode ``k`` (1-based) is written to ``output_dir/episode_{k:03d}`` with
    its frames renumbered from 000000.  An existing folder of that name is
    removed first.  Its contents are regenerated here, and keeping it would mix
    frames from an earlier, longer save into the new episode.  Prints a message
    and writes nothing when no episodes have been marked.
    """
    if not episodes:
        print("No episodes to save")
        return

    print(f"\nSaving {len(episodes)} episodes to {output_dir}...")
    output_dir.mkdir(parents=True, exist_ok=True)

    metadata = {
        'source_dir': str(input_dir),
        'total_frames': len(frame_data),
        'episodes': [],
    }

    written_dirs = set()
    for episode_num, ep in enumerate(episodes, start=1):
        episode_dir = output_dir / f"episode_{episode_num:03d}"
        if episode_dir.exists():
            shutil.rmtree(episode_dir)
        episode_dir.mkdir()
        written_dirs.add(episode_dir.name)

        start_idx = ep['start']
        end_idx = ep['end']
        frame_indices = range(start_idx, end_idx + 1)  # END is inclusive
        for out_idx, frame_idx in enumerate(frame_indices):
            _copy_frame_files(frame_data[frame_idx], episode_dir, out_idx)
        frame_count = len(frame_indices)

        start_ts = frame_data[start_idx]['timestep']
        end_ts = frame_data[end_idx]['timestep']
        metadata['episodes'].append({
            'episode_num': episode_num,
            'start_frame_idx': start_idx,
            'end_frame_idx': end_idx,
            'start_timestep': start_ts,
            'end_timestep': end_ts,
            'num_frames': frame_count,
        })

        print(f"  Episode {episode_num}: {frame_count} frames (from {start_ts} to {end_ts})")

    # Folders left by an earlier save with more episodes are not deleted.
    # They are also not listed in metadata.json, so point them out.
    leftovers = sorted(
        p.name for p in output_dir.glob("episode_*")
        if p.is_dir() and p.name not in written_dirs
    )
    if leftovers:
        print(f"Warning: {output_dir} also contains episode folders not in this save: "
              f"{', '.join(leftovers)}")

    # Save metadata JSON
    metadata_path = output_dir / "metadata.json"
    with open(metadata_path, 'w', encoding='utf-8') as f:
        json.dump(metadata, f, indent=2)

    print(f"\nSaved episodes to {output_dir}")
    print(f"Metadata saved to {metadata_path}")

# Wire up keyboard handling and draw the first frame before entering the GUI loop.
fig.canvas.mpl_connect('key_press_event', on_key)
update_display()

_rule = "=" * 60
print("\n".join([
    "",
    _rule,
    "Episode Parser",
    _rule,
    "Navigate frames and mark start/end to create episodes",
    "Press 'q' to quit and save all episodes",
    _rule,
    "",
]))

# Block here until all figure windows are closed (pressing 'q' closes them).
plt.show(block=True)

