"""Extract kinematics for start/grasp episodes using temporally smoothed estimates."""
import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))

import cv2
import whisper
import numpy as np
import mujoco
import matplotlib.pyplot as plt
from tqdm import tqdm

from ExoConfigs.so100_holemounts import SO100HoleMountsConfig
from exo_utils import (estimate_robot_state, detect_and_set_link_poses, 
                       position_exoskeleton_meshes, render_from_camera_pose)

# Configuration. The video path may be overridden on the command line:
#   python <this_script> [video_path]
# Without an argument, the original hard-coded recording is used.
video_path = sys.argv[1] if len(sys.argv) > 1 else "/Users/cameronsmith/Downloads/IMG_9546.MOV"
# Rate (frames per second of video time) at which step 2 samples the video.
target_fps = 4

# Setup robot config. The class attributes are overridden before instantiation
# so the constructed config sees them.
# NOTE(review): the meaning of exo_alpha / aruco_alpha is defined in
# SO100HoleMountsConfig (render opacity or a smoothing factor?) — confirm there.
SO100HoleMountsConfig.exo_alpha = 0.2
SO100HoleMountsConfig.aruco_alpha = 0.2
robot_config = SO100HoleMountsConfig()

print("=" * 60)
print("STEP 1: Extract audio timestamps for start/grasp")
print("=" * 60)

# Load Whisper and transcribe with per-word timestamps. The initial prompt
# primes the transcription with the two keywords spoken during recording.
model_whisper = whisper.load_model("small")
result = model_whisper.transcribe(
    video_path,
    word_timestamps=True,
    temperature=0,
    no_speech_threshold=0.6,
    initial_prompt="Start. Grasp.",
)

# Extract all words with their start timestamps (seconds). Words are stripped
# and lower-cased, but may still carry punctuation (e.g. "start."), which is
# why the keyword tests below use substring matching rather than equality.
all_words = []
for segment in result['segments']:
    for word in segment.get('words', []):
        all_words.append({
            'word': word['word'].strip().lower(),
            'time': word['start']
        })

# Group: each 'start' opens a new episode; every following 'grasp' is attached
# to it until the next 'start'. 'grasp' words heard before the first 'start'
# belong to no episode and are dropped.
groups = []
current_start = None
current_grasps = []

for w in all_words:
    if 'start' in w['word']:
        if current_start is not None:
            groups.append({'start': current_start, 'grasps': current_grasps})
        current_start = w['time']
        current_grasps = []
    elif 'grasp' in w['word'] and current_start is not None:
        current_grasps.append(w['time'])

# Flush the final open episode.
if current_start is not None:
    groups.append({'start': current_start, 'grasps': current_grasps})

print(f"Found {len(groups)} episodes")
if not groups:
    print("Warning: no 'start' keyword detected in the transcript; nothing to refine.")

print("\n" + "=" * 60)
print(f"STEP 2: Process video at {target_fps}fps for temporally smoothed estimates")
print("=" * 60)

# Setup video capture. Fail loudly instead of hitting a ZeroDivisionError or a
# zero range() step further down when the file cannot be decoded.
cap = cv2.VideoCapture(video_path)
if not cap.isOpened():
    raise RuntimeError(f"Could not open video: {video_path}")
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
fps = cap.get(cv2.CAP_PROP_FPS)
if fps <= 0:
    cap.release()
    raise RuntimeError(f"Video reports an invalid frame rate ({fps}): {video_path}")
# Clamp to at least 1: when target_fps >= fps the quotient truncates to 0,
# which range() rejects; in that case every frame is processed.
frame_skip = max(1, int(fps / target_fps))
print(f"Video: {total_frames} frames @ {fps:.2f} fps")
print(f"Processing at {target_fps} fps (every {frame_skip} frames)")

# Load MuJoCo model
model = mujoco.MjModel.from_xml_string(robot_config.xml)
data = mujoco.MjData(model)

# Process frames at target fps - store both configs and metadata.
# data.qpos carries over between iterations, so each frame starts from the
# previous frame's solution (presumably used as the IK initial guess inside
# estimate_robot_state — confirm there).
smoothed_estimates = []  # List of dicts with frame_idx, time, qpos, camera_pose, cam_K
frames_to_process = list(range(0, total_frames, frame_skip))

for frame_idx in tqdm(frames_to_process, desc=f"{target_fps}fps processing"):
    cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
    ret, frame = cap.read()
    if not ret:
        # Treat an unreadable frame as the end of the stream.
        break

    rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    frame_time = frame_idx / fps

    try:
        link_poses, camera_pose_world, cam_K, _, _, _ = detect_and_set_link_poses(
            rgb, model, data, robot_config)
        configuration = estimate_robot_state(model, data, robot_config, link_poses, ik_iterations=35)

        # Update robot state (warm start for next frame)
        data.qpos[:] = configuration.q
        data.ctrl[:] = configuration.q[:len(data.ctrl)]
        mujoco.mj_forward(model, data)
        position_exoskeleton_meshes(robot_config, model, data, link_poses)

        smoothed_estimates.append({
            'frame_idx': frame_idx,
            'time': frame_time,
            'qpos': configuration.q.copy(),
            'camera_pose': camera_pose_world.copy(),
            'cam_K': cam_K.copy()
        })
    except Exception as e:
        # Best effort: a frame where detection/IK fails is reported and skipped.
        print(f"\nFrame {frame_idx} failed: {e}")

print(f"Successfully processed {len(smoothed_estimates)} frames")

print("\n" + "=" * 60)
print("STEP 3: Refine estimates for start/grasp frames")
print("=" * 60)

# Helper function to find nearest smoothed estimate
def find_nearest_estimate(target_time, estimates=None):
    """Return the estimate whose 'time' is closest to ``target_time``.

    Args:
        target_time: Time in seconds to match.
        estimates: Sequence of dicts, each with a 'time' key in seconds.
            Defaults to the module-level ``smoothed_estimates`` built in step 2.

    Returns:
        The matching dict itself (not a copy). On a tie, the earliest entry in
        ``estimates`` wins, because ``np.argmin`` returns the first minimum.

    Raises:
        ValueError: If there are no estimates to choose from.
    """
    if estimates is None:
        estimates = smoothed_estimates
    if len(estimates) == 0:
        raise ValueError(f"No estimates available to match time {target_time}s")
    times = np.array([est['time'] for est in estimates], dtype=float)
    idx = int(np.argmin(np.abs(times - target_time)))
    return estimates[idx]

# Process each episode: for every spoken keyword, re-run detection + IK on the
# frame at that timestamp, seeding data.qpos with the nearest step-2 estimate.
episode_data = []

if smoothed_estimates:
    groups_to_refine = groups
else:
    # find_nearest_estimate has nothing to initialize from; skip instead of crashing.
    print("No sampled estimates succeeded in step 2; skipping refinement of all episodes")
    groups_to_refine = []

for group_idx, group in enumerate(groups_to_refine):
    print(f"\nEpisode {group_idx + 1}/{len(groups)}")
    episode_frames = []

    # Collect all timestamps for this episode: the 'start' first, then each 'grasp'.
    timestamps = [('start', group['start'])] + [('grasp', t) for t in group['grasps']]

    for label, timestamp in timestamps:
        # Find nearest smoothed estimate
        nearest = find_nearest_estimate(timestamp)
        print(f"  {label} at {timestamp:.2f}s -> using frame {nearest['frame_idx']} ({nearest['time']:.2f}s) as init")

        # Seek by frame index, as step 2 does, rather than CAP_PROP_POS_MSEC,
        # whose support and accuracy differ between OpenCV backends. int()
        # floors, selecting the frame on screen at `timestamp`.
        target_frame = max(0, int(timestamp * fps))
        cap.set(cv2.CAP_PROP_POS_FRAMES, target_frame)
        ret, frame = cap.read()
        if not ret:
            print(f"  Could not read {label} frame {target_frame} ({timestamp:.2f}s); skipping")
            continue

        rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

        # Initialize with nearest smoothed estimate
        data.qpos[:] = nearest['qpos']
        data.ctrl[:] = nearest['qpos'][:len(data.ctrl)]
        mujoco.mj_forward(model, data)

        try:
            # Refine estimate for this specific frame (fewer IK iterations than
            # step 2, since the initial state is already close).
            link_poses, camera_pose_world, cam_K, _, _, _ = detect_and_set_link_poses(
                rgb, model, data, robot_config)
            configuration = estimate_robot_state(model, data, robot_config, link_poses, ik_iterations=20)

            data.qpos[:] = configuration.q
            data.ctrl[:] = configuration.q[:len(data.ctrl)]
            mujoco.mj_forward(model, data)
            position_exoskeleton_meshes(robot_config, model, data, link_poses)

            # Render the model from the estimated camera and blend it 50/50 with the photo.
            rendered = render_from_camera_pose(model, data, camera_pose_world, cam_K, *rgb.shape[:2])
            overlay = (rgb * 0.5 + rendered * 0.5).astype(np.uint8)

            episode_frames.append({
                'label': label,
                'time': timestamp,
                'qpos': configuration.q.copy(),
                'rgb': rgb,
                'rendered': rendered,
                'overlay': overlay
            })
        except Exception as e:
            # Best effort: report and continue with the next keyword frame.
            print(f"  Failed to process {label} frame: {e}")

    episode_data.append(episode_frames)

cap.release()

print("\n" + "=" * 60)
print("STEP 4: Visualize results")
print("=" * 60)

# One figure per episode that has at least one refined frame: a single row of
# overlay panels, one per keyword, with the 'start' title emphasized.
for episode_idx, episode_frames in enumerate(episode_data):
    if episode_frames:
        panel_count = len(episode_frames)
        # squeeze=False always returns a 2-D grid of Axes, so a one-panel
        # episode needs no special-casing.
        fig, axes_grid = plt.subplots(1, panel_count, figsize=(5 * panel_count, 5), squeeze=False)
        for ax, frame_data in zip(axes_grid[0], episode_frames):
            title_weight = 'bold' if frame_data['label'] == 'start' else 'normal'
            ax.imshow(frame_data['overlay'])
            ax.set_title(
                f"{frame_data['label'].upper()}\n{frame_data['time']:.2f}s",
                fontsize=12,
                fontweight=title_weight,
            )
            ax.axis('off')

        plt.suptitle(f"Episode {episode_idx + 1}/{len(episode_data)}", fontsize=14, fontweight='bold')
        plt.tight_layout()
        plt.show()

print(f"\nProcessed {len(episode_data)} episodes")

