"""Estimate robot state from a single image.

This demo shows how to:
1. Load an RGB image
2. Detect ArUco markers
3. Estimate robot joint configuration
4. Render the estimated pose alongside the original image
"""
import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
from tqdm import tqdm

from glob import glob

import mujoco
import matplotlib.pyplot as plt
import numpy as np
from mujoco.renderer import Renderer

from ExoConfigs import EXOSKELETON_CONFIGS
from exo_utils import estimate_robot_state, detect_and_set_link_poses, position_exoskeleton_meshes, render_from_camera_pose, get_link_poses_from_robot

import argparse

# Command-line options for the batch state-estimation run.
parser = argparse.ArgumentParser()
parser.add_argument("--exo", type=str, default="so100_adhesive",
                    choices=list(EXOSKELETON_CONFIGS.keys()),
                    help="Exoskeleton configuration to use")
# Boolean toggles share the same registration shape; keep them together.
for flag, desc in (
    ("--just_sim_state", "cam rerender but dont reset config"),
    ("--no_render", "just render arm in sim"),
):
    parser.add_argument(flag, action="store_true", help=desc)
args = parser.parse_args()

# NOTE: config selection from --exo is currently disabled; the SO100
# adhesive config is hard-coded below with translucent overlay meshes.
#robot_config = EXOSKELETON_CONFIGS[args.exo]

from ExoConfigs.so100_adhesive import SO100AdhesiveConfig
SO100AdhesiveConfig.exo_alpha = 0.2    # translucent exoskeleton meshes
SO100AdhesiveConfig.aruco_alpha = 0.2  # translucent ArUco overlays (0 would hide them)
robot_config = SO100AdhesiveConfig()

# Report the config actually in use, and warn if --exo was silently ignored
# (previously the print claimed args.exo was the active config).
if args.exo != "so100_adhesive":
    print(f"WARNING: --exo={args.exo} is ignored; config is hard-coded to so100_adhesive")
print(f"Using exoskeleton config: so100_adhesive ({robot_config.name})")

# Camera intrinsics: estimated from the first processed frame, then reused.
cam_K = None

# Batch-process every parsed episode: for each frame, detect ArUco markers,
# run IK to recover the joint configuration, and save per-frame pose files.
for data_dir in glob("scratch/parsed_propercup_*"):
    for episode in tqdm(glob(os.path.join(data_dir, "episode_*"))):
        # Frames are sorted, so img_i == 0 is the first frame of the episode.
        for img_i,image_path in tqdm(enumerate(sorted(glob(os.path.join(episode, "*.png")))),leave=False,desc=episode):
            if "overlay" in image_path: continue  # skip previously written overlay renders
            # Recorded joint state stored next to each frame as <frame>.npy.
            joint_state=np.load(image_path.replace(".png", ".npy"))

            # Load model from config
            # A fresh model/data is built per frame so no simulator state
            # (including mesh repositioning below) leaks between frames.
            model = mujoco.MjModel.from_xml_string(robot_config.xml)
            data = mujoco.MjData(model)

            # Set virtual robot state from image
            # plt.imread returns floats in [0, 1] for PNGs; convert to uint8.
            rgb = plt.imread(image_path)[..., :3]
            if rgb.max() <= 1.0: rgb = (rgb * 255).astype(np.uint8)

            # Detect link poses from ArUco markers
            # cam_K is None on the first frame, then the returned intrinsics
            # are fed back in for all subsequent frames.
            #data.qpos[:] = joint_state*0
            link_poses, camera_pose_world, cam_K, corners_cache,corners_vis,obj_img_pts = detect_and_set_link_poses(rgb, model, data, robot_config,cam_K=cam_K)

            mujoco.mj_forward(model, data)

            # IK: fit the joint configuration to the detected link poses.
            configuration = estimate_robot_state( model, data, robot_config, link_poses, ik_iterations=55)
            data.qpos[:] = configuration.q
            # Gripper joint comes from the recorded state, not from IK.
            data.qpos[-1]=joint_state[-1]
            mujoco.mj_forward(model, data)
            position_exoskeleton_meshes(robot_config, model, data, link_poses)

            # NOTE(review): this skip runs AFTER the expensive IK above, and on
            # the first frame it also skips saving the camera pose/intrinsics.
            if len(link_poses)<5: continue
            gripper_pose = link_poses["fixed_gripper"]

            if img_i==0:
                # NOTE(review): image_path[:-10] assumes a fixed-length frame
                # suffix (e.g. "000000.png") — confirm against the dataset layout.
                np.save(image_path[:-10]+"robot_camera_pose.npy", camera_pose_world)
                np.save(image_path[:-10]+"cam_K.npy", cam_K)
            np.save(image_path.replace(".png", "_gripper_pose.npy"), gripper_pose)
            np.save(image_path.replace(".png", "_joint_state.npy"), data.qpos)
            #overlay = (corners_vis * 0.5 + rendered * 0.5).astype(np.uint8)
            #rendered = render_from_camera_pose(model, data, camera_pose_world, cam_K, *rgb.shape[:2])
            #plt.imsave(image_path.replace(".png", "_overlay.png"), overlay)

            if 0:
                # Dead debug branch: `rendered` and `overlay` are only defined
                # by the commented-out lines above, so enabling this requires
                # uncommenting those first.
                # Display results
                fig, axes = plt.subplots(1, 3, figsize=(15, 5))
                for ax, img in zip(axes, [rgb, rendered, overlay]): ax.imshow(img);ax.axis('off')
                plt.tight_layout()
                plt.show()