"""Estimate robot state from a single image.

This demo shows how to:
1. Load an RGB image
2. Detect ArUco markers
3. Estimate robot joint configuration
4. Render the estimated pose alongside the original image
"""
import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))

import mujoco
import matplotlib.pyplot as plt
import numpy as np
from mujoco.renderer import Renderer

from ExoConfigs import EXOSKELETON_CONFIGS
from exo_utils import estimate_robot_state, detect_and_set_link_poses, position_exoskeleton_meshes, render_from_camera_pose, get_link_poses_from_robot

import argparse

# Command-line options: which exoskeleton configuration to load, plus two
# debug toggles controlling how much of the pipeline runs.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--exo",
    type=str,
    default="so100_holemounts",
    choices=list(EXOSKELETON_CONFIGS.keys()),
    help="Exoskeleton configuration to use",
)
parser.add_argument("--just_sim_state", action="store_true", help="cam rerender but dont reset config")
parser.add_argument("--no_render", action="store_true", help="just render arm in sim")
args = parser.parse_args()

# NOTE(review): the generic lookup below is bypassed by a hard-coded debug
# override; restore it once the transparency experiment is done.
#robot_config = EXOSKELETON_CONFIGS[args.exo]

from ExoConfigs.so100_holemounts import SO100HoleMountsConfig

# Render the exoskeleton and ArUco meshes semi-transparent so the overlay
# comparison against the camera image is readable.
SO100HoleMountsConfig.exo_alpha = 0.2
SO100HoleMountsConfig.aruco_alpha = 0.2  # set to 0.0 to hide the ArUco meshes entirely

robot_config = SO100HoleMountsConfig()

# Bug fix: --exo was silently ignored by the hard-coded config above; at
# least warn the user so the mismatch is visible.
if args.exo != "so100_holemounts":
    print(f"WARNING: --exo={args.exo} is ignored; using hard-coded SO100HoleMountsConfig")

print(f"Using exoskeleton config: {args.exo} ({robot_config.name})")

# Setup video capture
import cv2
from tqdm import tqdm

video_path = "/Users/cameronsmith/Downloads/IMG_9546.MOV"
cap = cv2.VideoCapture(video_path)
# Bug fix: a missing/unreadable file previously reported 0 frames @ 0.00 fps
# and then crashed on range(..., 0); fail fast with a clear error instead.
if not cap.isOpened():
    raise FileNotFoundError(f"Could not open video: {video_path}")

# Get video properties
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
fps = cap.get(cv2.CAP_PROP_FPS)
target_fps = 4
# Clamp to at least 1: int(fps / target_fps) is 0 whenever fps < target_fps
# (or fps could not be read), which would make the range() step invalid.
frame_skip = max(1, int(fps / target_fps)) if fps > 0 else 1
print(f"Video: {total_frames} frames @ {fps:.2f} fps")
print(f"Processing at {target_fps} fps (every {frame_skip} frames)")

# Build the MuJoCo model/data pair from the selected exoskeleton's XML.
model = mujoco.MjModel.from_xml_string(robot_config.xml)
data = mujoco.MjData(model)

# Accumulators for per-frame results, sampled at the target rate.
joint_configs = []    # one entry per processed frame (None when a frame fails)
rendered_frames = []  # occasional {frame_idx, rgb, rendered, overlay} dicts
frames_to_process = list(range(0, total_frames, frame_skip))

for frame_idx in tqdm(frames_to_process, desc="Processing frames"):
    # Seek directly to the target frame index.
    cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
    ret, frame = cap.read()
    if not ret:
        break

    # OpenCV decodes to BGR; the detection/rendering pipeline expects RGB.
    rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

    # Detect link poses from ArUco markers and estimate robot state.
    # Best-effort per frame: a failed frame is recorded as None and skipped.
    try:
        link_poses, camera_pose_world, cam_K, corners_cache, corners_vis, obj_img_pts = detect_and_set_link_poses(rgb, model, data, robot_config)
        configuration = estimate_robot_state(model, data, robot_config, link_poses, ik_iterations=35)

        # Update robot state (this becomes the starting point for next frame)
        data.qpos[:] = configuration.q
        data.ctrl[:] = configuration.q[:len(data.ctrl)]
        mujoco.mj_forward(model, data)
        position_exoskeleton_meshes(robot_config, model, data, link_poses)

        joint_configs.append(configuration.q.copy())

        # Render every 100th source frame and keep it for the summary plot.
        # Bug fix: the append was commented out (with a blocking per-frame
        # plt.show() in its place), so rendered_frames stayed empty and the
        # plotting block at the end of the script never ran.
        if frame_idx % 100 == 0:
            rendered = render_from_camera_pose(model, data, camera_pose_world, cam_K, *rgb.shape[:2])
            overlay = (rgb * 0.5 + rendered * 0.5).astype(np.uint8)
            rendered_frames.append({'frame_idx': frame_idx, 'rgb': rgb, 'rendered': rendered, 'overlay': overlay})

    except Exception as e:
        print(f"\nFrame {frame_idx} failed: {e}")
        joint_configs.append(None)

cap.release()

# Bug fix: the success ratio previously used total_frames as the denominator,
# but only every frame_skip-th frame is attempted — report against the number
# of frames actually processed.
n_ok = len([c for c in joint_configs if c is not None])
print(f"\nProcessed {n_ok}/{len(frames_to_process)} frames successfully")
print(f"Rendered {len(rendered_frames)} frames")

# Plot each kept frame as one row of three panels: original | rendered | overlay.
if rendered_frames:
    num_frames = len(rendered_frames)
    fig, axes = plt.subplots(num_frames, 3, figsize=(15, 5 * num_frames))
    # subplots() returns a 1-D axes array for a single row; normalize to 2-D.
    if num_frames == 1:
        axes = axes.reshape(1, -1)

    columns = [('rgb', 'Original'), ('rendered', 'Rendered'), ('overlay', 'Overlay')]
    for row, frame_data in enumerate(rendered_frames):
        for col, (key, label) in enumerate(columns):
            ax = axes[row, col]
            ax.imshow(frame_data[key])
            ax.set_title(f"Frame {frame_data['frame_idx']} - {label}")
            ax.axis('off')

    plt.tight_layout()
    plt.show()