"""Extract start/grasp episodes using pre-computed temporally smoothed kinematics."""
import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))

import cv2
import whisper
import numpy as np
import mujoco
import matplotlib.pyplot as plt
from tqdm import tqdm
from scipy.spatial.transform import Rotation as R

from ExoConfigs.so100_adhesive import SO100AdhesiveConfig
from exo_utils import (estimate_robot_state, detect_and_set_link_poses, 
                       position_exoskeleton_meshes, render_from_camera_pose)

# --- Configuration -----------------------------------------------------------
video_path = "/Users/cameronsmith/Downloads/IMG_9546.MOV"
kinematics_path = "scratch/kinematics_exp_test.npy"

# The robot config must match the one used to generate the kinematics file,
# including these class-level *_alpha settings (presumably render transparency).
for _alpha_attr in ("exo_alpha", "aruco_alpha", "exo_link_alpha"):
    setattr(SO100AdhesiveConfig, _alpha_attr, 0.2)
robot_config = SO100AdhesiveConfig()

_step1_rule = "=" * 60
print(_step1_rule)
print("STEP 1: Load pre-computed kinematics")
print(_step1_rule)

# Pickle loading is required: entries are indexed below as dicts (e.g. ['time']).
smoothed_estimates = np.load(kinematics_path, allow_pickle=True)
print(f"Loaded {len(smoothed_estimates)} temporally smoothed estimates")
_first_est, _last_est = smoothed_estimates[0], smoothed_estimates[-1]
print(f"Time range: {_first_est['time']:.2f}s - {_last_est['time']:.2f}s")

print("\n" + "=" * 60)
print("STEP 2: Extract audio timestamps for start/grasp")
print("=" * 60)

# Transcribe the video's audio with Whisper (word-level timestamps). The result
# is cached on disk so re-runs skip the slow transcription; if the cache is
# missing we transcribe instead of crashing on torch.load.
import torch

WHISPER_CACHE_PATH = "scratch/whisper_result.pth"
FORCE_TRANSCRIBE = False  # set True to re-transcribe even if a cache exists

if FORCE_TRANSCRIBE or not os.path.exists(WHISPER_CACHE_PATH):
    model_whisper = whisper.load_model("small")
    result = model_whisper.transcribe(
        video_path,
        word_timestamps=True,
        temperature=0,
        no_speech_threshold=0.6,
        initial_prompt="",
    )
    print("saving whisper result")
    os.makedirs(os.path.dirname(WHISPER_CACHE_PATH), exist_ok=True)
    torch.save(result, WHISPER_CACHE_PATH)
else:
    # weights_only=False unpickles arbitrary objects: only load caches this
    # script wrote itself, never files from an untrusted source.
    result = torch.load(WHISPER_CACHE_PATH, weights_only=False)

# Flatten every segment's words into {'word': normalized text, 'time': start s}.
all_words = []
for segment in result['segments']:
    for word in segment.get('words', []):
        all_words.append({
            'word': word['word'].strip().lower(),
            'time': word['start'],
        })

# Detect and correct a timestamp offset.
# Heuristic: if Whisper places the first word at (almost) 0 s, assume its
# timestamps are shifted relative to the video and add a fixed, hand-tuned
# offset to every word. Set MANUAL_TIMESTAMP_OFFSET to 0 to disable.
# NOTE(review): Whisper word timestamps are normally measured from the start of
# the audio track; confirm the shift is really needed for this recording.
MANUAL_TIMESTAMP_OFFSET = 2.0  # seconds; when speech actually starts in the video
FIRST_WORD_THRESHOLD = 0.1  # seconds; a first word earlier than this triggers the offset

timestamp_offset = 0.0
if all_words:
    first_word_time = all_words[0]['time']
    if first_word_time < FIRST_WORD_THRESHOLD and MANUAL_TIMESTAMP_OFFSET > 0:
        timestamp_offset = MANUAL_TIMESTAMP_OFFSET
        print(f"\nDetected timestamp offset: Whisper starts at {first_word_time:.2f}s")
        print(f"Applying offset of {timestamp_offset:.2f}s to all timestamps")
        # Shift every word (in place) onto the video's timeline.
        for word in all_words:
            word['time'] += timestamp_offset
print(all_words)

print(f"\nExtracted {len(all_words)} words")
if all_words:
    print(f"First word: '{all_words[0]['word']}' at {all_words[0]['time']:.2f}s")
    if len(all_words) > 1:
        print(f"Last word: '{all_words[-1]['word']}' at {all_words[-1]['time']:.2f}s")

# Split the transcript into episodes. Any word containing 'start' opens a new
# episode (kept even if no grasp follows); any later word containing 'grasp'
# is attached to the most recently opened episode. Grasps heard before the
# first start are dropped. Matching is by substring, so words such as
# "restart" or "grasping" also count.
groups = []
for w in all_words:
    token = w['word']
    if 'start' in token:
        groups.append({'start': w['time'], 'grasps': []})
    elif 'grasp' in token and groups:
        groups[-1]['grasps'].append(w['time'])

print(f"Found {len(groups)} episodes")
for ep_no, g in enumerate(groups, start=1):
    if not g['grasps']:
        print(f"  Episode {ep_no}: start at {g['start']:.2f}s with 0 grasps")
        continue
    print(f"  Episode {ep_no}: start at {g['start']:.2f}s with {len(g['grasps'])} grasp(s) at {[f'{t:.2f}s' for t in g['grasps']]}")

print("\n" + "=" * 60)
print("STEP 3: Refine estimates for start/grasp frames")
print("=" * 60)

# Setup dataset directory
dataset_dir = "scratch/dataset"
os.makedirs(dataset_dir, exist_ok=True)
print(f"Saving dataset to: {dataset_dir}")

# Load MuJoCo model
model = mujoco.MjModel.from_xml_string(robot_config.xml)
data = mujoco.MjData(model)

# Setup video capture. Fail fast: if the file cannot be opened every
# cap.read() would fail and every episode would be silently discarded.
cap = cv2.VideoCapture(video_path)
if not cap.isOpened():
    raise IOError(f"Could not open video: {video_path}")

def find_nearest_estimate(target_time, estimates=None):
    """Return the estimate whose ``'time'`` is closest to ``target_time``.

    Args:
        target_time: Query time in seconds.
        estimates: Sequence of dicts with a ``'time'`` key. Defaults to the
            module-level ``smoothed_estimates`` loaded by this script.

    Returns:
        The element of ``estimates`` nearest in time. On ties, the earliest
        element wins (``np.argmin`` returns the first minimum).

    Raises:
        ValueError: If ``estimates`` is empty.
    """
    if estimates is None:
        estimates = smoothed_estimates
    if len(estimates) == 0:
        raise ValueError("No kinematic estimates to search")
    times = np.array([est['time'] for est in estimates], dtype=float)
    idx = int(np.argmin(np.abs(times - target_time)))
    return estimates[idx]

# Process each episode.
#
# All labelled frames of an episode are refined in memory first; files are
# written only once EVERY frame succeeded, so a failed episode leaves no
# partial data behind. Successful episodes go to contiguous directories
# episode_1, episode_2, ... so that episode_data[k - 1] always corresponds to
# <dataset_dir>/episode_<k>, which the summary and visualization steps assume.


def _exo_gripper_pose():
    """Return the 4x4 world pose of the ``fixed_gripper_exo_mesh`` mocap body.

    Raises:
        ValueError: If the body is missing from the model or has no mocap
            entry (an unchecked -1 id would silently index the last element).
    """
    body_name = "fixed_gripper_exo_mesh"
    body_id = mujoco.mj_name2id(model, mujoco.mjtObj.mjOBJ_BODY, body_name)
    if body_id < 0:
        raise ValueError(f"Body '{body_name}' not found in the robot model")
    mocap_id = model.body_mocapid[body_id]
    if mocap_id < 0:
        raise ValueError(f"Body '{body_name}' is not a mocap body")
    # MuJoCo stores quaternions as (w, x, y, z); SciPy expects (x, y, z, w).
    quat_wxyz = data.mocap_quat[mocap_id]
    pose = np.eye(4)
    pose[:3, :3] = R.from_quat(quat_wxyz[[1, 2, 3, 0]]).as_matrix()
    pose[:3, 3] = data.mocap_pos[mocap_id]
    return pose


def _refine_frame(label, timestamp):
    """Refine the robot state for one labelled timestamp of the video.

    Seeds MuJoCo with the nearest smoothed estimate, re-detects link poses on
    the exact frame, runs IK, and renders the result.

    Returns:
        A dict with the frame results, or None when the frame cannot be read
        or link-pose detection fails (the caller then discards the episode).
    """
    nearest = find_nearest_estimate(timestamp)
    time_diff = abs(nearest['time'] - timestamp)
    print(f"  {label} at {timestamp:.2f}s -> init from {nearest['time']:.2f}s (Δ={time_diff:.2f}s)")

    # Get the actual frame at the timestamp
    cap.set(cv2.CAP_PROP_POS_MSEC, timestamp * 1000)
    ret, frame = cap.read()
    if not ret:
        print("    ✗ Failed to read frame - discarding episode")
        return None
    rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

    # Initialize with nearest smoothed estimate
    data.qpos[:] = nearest['qpos']
    data.ctrl[:] = nearest['qpos'][:len(data.ctrl)]
    mujoco.mj_forward(model, data)

    # Refine estimate for this specific frame (detection is best-effort)
    try:
        link_poses, camera_pose_world, cam_K, _, _, _ = detect_and_set_link_poses(rgb, model, data, robot_config)
    except Exception as e:
        print(f"    ✗ Failed to detect link poses: {e} - discarding episode")
        return None
    configuration = estimate_robot_state(model, data, robot_config, link_poses, ik_iterations=100)
    data.qpos[:] = configuration.q
    data.ctrl[:] = configuration.q[:len(data.ctrl)]
    mujoco.mj_forward(model, data)
    position_exoskeleton_meshes(robot_config, model, data, link_poses)

    rendered = render_from_camera_pose(model, data, camera_pose_world, cam_K, *rgb.shape[:2])
    overlay = (rgb * 0.5 + rendered * 0.5).astype(np.uint8)

    return {
        'label': label,
        'time': timestamp,
        'qpos': configuration.q.copy(),
        'rgb': rgb,
        'rendered': rendered,
        'overlay': overlay,
        'exo_pose': _exo_gripper_pose(),
        'camera_pose': camera_pose_world,
    }


def _save_episode(episode_dir, frames):
    """Write image, joint states and gripper pose for every frame of an episode.

    The camera pose is saved once, from the 'start' frame.

    Raises:
        IOError: If an image cannot be written.
    """
    os.makedirs(episode_dir, exist_ok=True)
    for fr in frames:
        label = fr['label']
        image_path = os.path.join(episode_dir, f"{label}.png")
        # cv2.imwrite expects BGR channel order.
        if not cv2.imwrite(image_path, cv2.cvtColor(fr['rgb'], cv2.COLOR_RGB2BGR)):
            raise IOError(f"Failed to write image: {image_path}")
        np.save(os.path.join(episode_dir, f"joint_states_{label}.npy"), fr['qpos'])
        np.save(os.path.join(episode_dir, f"fixed_gripper_exo_pose_{label}.npy"), fr['exo_pose'])
        if label == 'start':
            np.save(os.path.join(episode_dir, "robot_camera_pose.npy"), fr['camera_pose'])
    print(f"    ✓ Saved to {episode_dir}")


episode_data = []

try:
    for group_idx, group in enumerate(groups):
        episode_num = group_idx + 1
        print(f"\nEpisode {episode_num}/{len(groups)}")

        # Collect all timestamps for this episode
        timestamps = [('start', group['start'])] + [(f'grasp{i+1}', t) for i, t in enumerate(group['grasps'])]

        episode_frames = []
        episode_failed = False
        for label, timestamp in timestamps:
            frame_result = _refine_frame(label, timestamp)
            if frame_result is None:
                episode_failed = True
                break
            episode_frames.append(frame_result)

        if episode_failed:
            print(f"  ✗ Episode {episode_num} discarded due to failure")
            continue

        # Number directories by successful episodes so indices stay aligned.
        episode_dir = os.path.join(dataset_dir, f"episode_{len(episode_data) + 1}")
        _save_episode(episode_dir, episode_frames)
        episode_data.append(episode_frames)
finally:
    cap.release()

_step4_rule = "=" * 60
print("\n" + _step4_rule)
print("STEP 4: Dataset Summary")
print(_step4_rule)

total_frames_saved = sum(map(len, episode_data))
print(f"Saved {len(episode_data)} episodes with {total_frames_saved} total frames")
print(f"Dataset location: {dataset_dir}")
print("\nDataset structure:")
# NOTE(review): assumes episode_data[k - 1] was written to episode_<k>;
# confirm this still holds when episodes are discarded during processing.
for ep_no in range(1, len(episode_data) + 1):
    ep_name = f"episode_{ep_no}"
    ep_files = sorted(os.listdir(os.path.join(dataset_dir, ep_name)))
    print(f"  {ep_name}/ ({len(ep_files)} files)")
    for fname in ep_files:
        print(f"    - {fname}")

_step5_rule = "=" * 60
print("\n" + _step5_rule)
print("STEP 5: Visualize results")
print(_step5_rule)

SHOW_PLOTS = False  # set True to also open each figure interactively

# One figure per episode: a row of render overlays, one panel per labelled frame.
# NOTE(review): assumes episode_data[k - 1] belongs in episode_<k>.
for ep_no, frames in enumerate(episode_data, start=1):
    if not frames:
        continue

    n_panels = len(frames)
    # squeeze=False keeps axes 2-D even for a single panel.
    fig, axes = plt.subplots(1, n_panels, figsize=(5 * n_panels, 5), squeeze=False)
    for ax, fr in zip(axes[0], frames):
        weight = 'bold' if fr['label'] == 'start' else 'normal'
        ax.imshow(fr['overlay'])
        ax.set_title(f"{fr['label'].upper()}\n{fr['time']:.2f}s", fontsize=12, fontweight=weight)
        ax.axis('off')

    fig.suptitle(f"Episode {ep_no}/{len(episode_data)}", fontsize=14, fontweight='bold')
    fig.tight_layout()

    fig_path = os.path.join(dataset_dir, f"episode_{ep_no}", "ep_render_overlay.png")
    fig.savefig(fig_path, dpi=150, bbox_inches='tight')
    print(f"Saved visualization: {fig_path}")

    if SHOW_PLOTS:
        plt.show()
    plt.close(fig)

print(f"\n✓ Done! Saved {len(episode_data)} episodes to {dataset_dir}")

