"""Generate robot segmentation masks from joint states for start frames."""
import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))

import cv2
import numpy as np
import mujoco
import matplotlib.pyplot as plt
from tqdm import tqdm

from ExoConfigs.so100_adhesive import SO100AdhesiveConfig
from ExoConfigs.alignment_board import ALIGNMENT_BOARD_CONFIG
from exo_utils import (detect_and_set_link_poses, position_exoskeleton_meshes, 
                       render_from_camera_pose, detect_and_position_alignment_board, combine_xmls)
from mujoco.renderer import Renderer

# Configuration: dataset root containing ``episode_*`` sub-directories.
# An optional first command-line argument overrides the default location.
dataset_dir = sys.argv[1] if len(sys.argv) > 1 else "scratch/dataset"
if not os.path.isdir(dataset_dir):
    # Fail fast (before the MuJoCo model is built) with a readable message.
    sys.exit(f"Dataset directory not found: {dataset_dir}")

# Setup robot config. The alphas are assigned on the class itself, before it is
# instantiated, presumably so the new instance's XML picks them up.
SO100AdhesiveConfig.exo_alpha = 1.0  # Fully opaque for mask
# NOTE(review): 1.0 presumably means fully opaque, i.e. the ArUco markers are
# NOT hidden here. If hiding them in the mask was the intent, this likely
# should be 0.0 -- confirm against SO100AdhesiveConfig.
SO100AdhesiveConfig.aruco_alpha = 1.0
robot_config = SO100AdhesiveConfig()

print("=" * 60)
print("Robot Mask Generation")
print("=" * 60)

# Load MuJoCo model: the robot XML combined with the alignment-board XML so
# the board can be positioned in the same scene later on.
model = mujoco.MjModel.from_xml_string(
    combine_xmls(robot_config.xml, ALIGNMENT_BOARD_CONFIG.get_xml_addition())
)
data = mujoco.MjData(model)

# Find all episode directories, sorted by name so processing order is stable.
episode_dirs = sorted(
    d for d in os.listdir(dataset_dir)
    if d.startswith('episode_') and os.path.isdir(os.path.join(dataset_dir, d))
)

print(f"Found {len(episode_dirs)} episodes")


# Set to True to pop up an Original / Mask / Overlay figure for each episode.
# plt.show() blocks, so leave this False for batch runs.
SHOW_VISUALIZATION = False


def _to_uint8_rgb(img):
    """Return the first three channels of ``img`` as a uint8 array.

    ``plt.imread`` returns PNG data as floats in [0, 1]; such input is scaled
    to 0-255 and *rounded* before casting, because a bare ``astype`` truncates
    values like 253.99998 down to 253. Input whose maximum exceeds 1.0 is
    assumed to already be on a 0-255 scale and is cast directly.
    """
    rgb = img[..., :3]
    if rgb.max() <= 1.0:
        return np.clip(np.round(rgb * 255), 0, 255).astype(np.uint8)
    return rgb.astype(np.uint8)


# Process each episode
for episode_name in tqdm(episode_dirs, desc="Processing episodes"):
    episode_dir = os.path.join(dataset_dir, episode_name)

    # Check for start frame files
    start_img_path = os.path.join(episode_dir, "start.png")
    start_joint_path = os.path.join(episode_dir, "joint_states_start.npy")

    # tqdm.write is used instead of print so log lines don't tear the bar.
    if not os.path.exists(start_img_path) or not os.path.exists(start_joint_path):
        tqdm.write(f"  Skipping {episode_name}: missing start files")
        continue

    tqdm.write(f"\n{episode_name}: Processing start frame")

    # Load start frame image and joint states
    rgb = _to_uint8_rgb(plt.imread(start_img_path))
    joint_states = np.load(start_joint_path)

    # Set robot state.
    # NOTE(review): assumes joint_states is a flat array with one entry per
    # qpos element -- TODO confirm. ctrl takes its leading len(data.ctrl) entries.
    data.qpos[:] = joint_states
    data.ctrl[:] = joint_states[:len(data.ctrl)]
    mujoco.mj_forward(model, data)

    # Detect camera pose from ArUco markers (needed for rendering)
    link_poses, camera_pose_world, cam_K, corners_cache, _corners_vis, _ = detect_and_set_link_poses(
        rgb, model, data, robot_config
    )
    position_exoskeleton_meshes(robot_config, model, data, link_poses)
    # Presumably called for its side effect of placing the board in the scene;
    # its return value is not used here.
    detect_and_position_alignment_board(
        rgb, model, data, ALIGNMENT_BOARD_CONFIG, cam_K, camera_pose_world,
        corners_cache, visualize=False,
    )
    mujoco.mj_forward(model, data)  # Update scene with board position

    # Render a segmentation image from the estimated camera and threshold
    # channel 0. NOTE(review): presumably channel 0 holds per-pixel object ids
    # with background <= 0 -- confirm against render_from_camera_pose.
    height, width = rgb.shape[:2]
    rendered_mask = render_from_camera_pose(
        model, data, camera_pose_world, cam_K, height, width, segmentation=True
    )[..., 0] > 0

    # Save binary mask as PNG: 255 where the rendered mask is True, else 0.
    mask_path = os.path.join(episode_dir, "start_robot_mask.png")
    mask_binary = rendered_mask.astype(np.uint8) * 255
    # Pin vmin/vmax: otherwise imsave normalizes to the array's own min/max,
    # so a uniform mask (all 0 or all 255) would be written as all black.
    plt.imsave(mask_path, mask_binary, cmap='gray', vmin=0, vmax=255)
    tqdm.write(f"  ✓ Saved robot mask to {mask_path}")

    # Create visualization
    if SHOW_VISUALIZATION:
        # NOTE(review): the overlay assumes this RGB render is uint8-scaled and
        # the same size as rgb -- confirm.
        rendered_rgb = render_from_camera_pose(model, data, camera_pose_world, cam_K, height, width)
        fig, axes = plt.subplots(1, 3, figsize=(15, 5))

        # Original image
        axes[0].imshow(rgb)
        axes[0].set_title("Original", fontsize=12)
        axes[0].axis('off')

        # Robot mask
        axes[1].imshow(rendered_mask, cmap='gray')
        axes[1].set_title("Render", fontsize=12)
        axes[1].axis('off')

        # Overlay: 50/50 blend of the photo and the RGB render.
        overlay = (rgb * .5 + rendered_rgb * .5).astype(np.uint8)
        axes[2].imshow(overlay)
        axes[2].set_title("Overlay", fontsize=12)
        axes[2].axis('off')

        fig.suptitle(f"{episode_name} - Start Frame Robot Mask", fontsize=14, fontweight='bold')
        fig.tight_layout()
        plt.show()
        plt.close(fig)

    tqdm.write("  ✓ Successfully generated robot mask")

print("\n✓ Done!")

