import cv2
import matplotlib.pyplot as plt
import whisper
import numpy as np

# Video whose audio track is transcribed and whose frames are sampled.
vid_path = "/Users/cameronsmith/Downloads/IMG_9546.MOV"

# Options forwarded to Whisper's transcribe call; word_timestamps requests
# per-word timing in addition to the segment-level output.
transcribe_options = {
    "word_timestamps": True,
    "temperature": 0,
    "no_speech_threshold": 0.6,
    "initial_prompt": "Start. Grasp.",
}

# Load Whisper model and run the transcription
model = whisper.load_model("small")
result = model.transcribe(vid_path, **transcribe_options)

# Flatten every segment's word list into one sequence of
# {'word': normalized text, 'time': word start time} records.
# Segments without a 'words' entry contribute nothing.
all_words = [
    {'word': entry['word'].strip().lower(), 'time': entry['start']}
    for seg in result['segments']
    for entry in seg.get('words', [])
]

print(f"Total words: {len(all_words)}")

# Group: every word containing 'start' opens a new group; every later word
# containing 'grasp' is attached to the most recently opened group.
# 'grasp' words heard before the first 'start' are ignored.
groups = []

for w in all_words:
    spoken = w['word']
    if 'start' in spoken:
        # Open the group immediately; its grasp list is filled in place.
        groups.append({'start': w['time'], 'grasps': []})
    elif 'grasp' in spoken and groups:
        groups[-1]['grasps'].append(w['time'])

print(f"Found {len(groups)} start commands")
for i, g in enumerate(groups):
    print(f"  Group {i+1}: start at {g['start']:.2f}s with {len(g['grasps'])} grasp(s)")

# Video reader shared by get_frame() for all frame lookups
cap = cv2.VideoCapture(vid_path)


def get_frame(time_sec):
    """Return the frame at ``time_sec`` seconds as an RGB image, or None.

    Seeks the module-level ``cap`` to the requested position (given to
    OpenCV in milliseconds) and reads a single frame. ``None`` is returned
    when the read does not succeed.
    """
    cap.set(cv2.CAP_PROP_POS_MSEC, time_sec * 1000)
    ok, bgr_image = cap.read()
    if not ok:
        return None
    # Convert from BGR to RGB channel order before handing the image back.
    return cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)

# Plot each group: one figure per 'start' command, showing the frame at the
# start time followed by one panel per grasp time in that group.
# The capture is released in a finally block so it is freed even when
# plotting raises part-way through.
try:
    for group_idx, group in enumerate(groups):
        # (label, timestamp, extra title styling) for every panel, in order.
        panels = [("START", group['start'], {'fontweight': 'bold'})]
        panels.extend(
            (f"Grasp {i+1}", grasp_time, {})
            for i, grasp_time in enumerate(group['grasps'])
        )
        num_frames = len(panels)  # start + grasps

        # squeeze=False always returns a 2-D array of axes, so a group with
        # no grasps (a single panel) needs no special case.
        fig, axes = plt.subplots(
            1, num_frames, figsize=(4*num_frames, 4), squeeze=False
        )

        for ax, (label, time_sec, title_style) in zip(axes[0], panels):
            frame = get_frame(time_sec)
            title = f"{label}\n{time_sec:.2f}s"
            if frame is not None:
                ax.imshow(frame)
            else:
                # Keep the panel labelled so a failed read is visible
                # instead of leaving an anonymous empty axes.
                title += "\n(frame unavailable)"
            ax.set_title(title, fontsize=12, **title_style)
            ax.axis('off')

        fig.suptitle(f"Group {group_idx+1}/{len(groups)}", fontsize=14, fontweight='bold')
        plt.tight_layout()
        plt.show()
finally:
    cap.release()