"""Minimal IK test: load image, estimate robot state, view in MuJoCo."""
import sys
import os
from pathlib import Path
import cv2
import mujoco
import numpy as np
import xml.etree.ElementTree as ET
from scipy.spatial.transform import Rotation as R

sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../.."))
import matplotlib.pyplot as plt
from exo_utils import render_from_camera_pose
from ExoConfigs.so100_adhesive import SO100AdhesiveConfig
from exo_utils import detect_and_set_link_poses, estimate_robot_state, position_exoskeleton_meshes, get_link_poses_from_robot
import mink

# Candidate keypoints on the gripper, expressed in the gripper pose's local
# frame. Entered in millimetres; dividing by 1000 stores them in metres.
_KEYPOINTS_LOCAL_MM = [
    [13.25, -91.42, 15.9],
    [10.77, -99.6, 0],
    [13.25, -91.42, -15.9],
    [17.96, -83.96, 0],
    [22.86, -70.46, 0],
]
KEYPOINTS_LOCAL_M_ALL = np.array(_KEYPOINTS_LOCAL_MM) / 1000.0

# Row of KEYPOINTS_LOCAL_M_ALL that is tracked as the GT trajectory point.
KP_INDEX = 3

import argparse
parser = argparse.ArgumentParser(
    description="Minimal IK test: replay an episode's GT gripper keypoint trajectory via IK and render it"
)
parser.add_argument("--dataset_dir", "-d", default="scratch/parsed_propercup_train", type=str, help="Dataset directory")
parser.add_argument("--episode_id", "-e", default=1, type=int, help="Episode ID")
args = parser.parse_args()
dataset_dir = Path(args.dataset_dir)
episode_dir = dataset_dir / f"episode_{args.episode_id:03d}"

# Frames are PNGs whose stem is an integer index. Sort numerically so that
# non-zero-padded names (1.png, 2.png, 10.png) are still in temporal order;
# for zero-padded names this matches the plain lexicographic order.
frame_files = sorted(
    (f for f in episode_dir.glob("*.png") if f.stem.isdigit()),
    key=lambda f: int(f.stem),
)
if not frame_files:
    parser.error(f"no integer-named .png frames found in {episode_dir}")

# Load GT trajectory: for every frame, read its gripper pose and map the
# chosen gripper-local keypoint into the pose's parent (world) frame.
# The pose is indexed as a homogeneous transform: rotation in [:3, :3],
# translation in [:3, 3].
trajectory_gt_3d = []
for frame_file in frame_files:
    frame_str = f"{int(frame_file.stem):06d}"  # pose files use 6-digit zero padding
    pose = np.load(episode_dir / f"{frame_str}_gripper_pose.npy")
    trajectory_gt_3d.append(pose[:3, :3] @ KEYPOINTS_LOCAL_M_ALL[KP_INDEX] + pose[:3, 3])

# Add one translucent sphere site per GT keypoint to the robot's MJCF so the
# target trajectory is drawn in the scene, then compile the MuJoCo model.
robot_config = SO100AdhesiveConfig()
xml_root = ET.fromstring(robot_config.xml)
worldbody = xml_root.find('worldbody')
if worldbody is None:
    raise ValueError("robot_config.xml has no <worldbody> element to attach GT sites to")

# Denominator of the colour ramp; max(..., 1) avoids division by zero for a
# single-frame trajectory.
color_den = max(len(trajectory_gt_3d) - 1, 1)
for i, kp_pos in enumerate(trajectory_gt_3d):
    # Colour ramps from green (first frame) to blue (last frame).
    green = 1.0 - (i / color_den)
    blue = i / color_den
    ET.SubElement(worldbody, 'site', {
        # NOTE: switch type to 'box' (and red instead of blue) to tell predictions apart from GT.
        'name': f'gt_kp_{i}', 'type': 'sphere', 'size': '0.015',
        'pos': f'{kp_pos[0]} {kp_pos[1]} {kp_pos[2]}', 'rgba': f'0 {green} {blue} 0.8'
    })

mj_model = mujoco.MjModel.from_xml_string(ET.tostring(xml_root, encoding='unicode'))
mj_data = mujoco.MjData(mj_model)

# Load the first frame of the episode as the start image (RGB).
start_frame_file = frame_files[0]
bgr = cv2.imread(str(start_frame_file))
if bgr is None:  # cv2.imread signals failure by returning None rather than raising
    raise FileNotFoundError(f"could not read image {start_frame_file}")
rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
# Only rescale genuinely floating-point images in [0, 1]. A very dark uint8
# frame can also have max() <= 1 and must not be multiplied by 255.
if np.issubdtype(rgb.dtype, np.floating) and rgb.max() <= 1.0:
    rgb = (rgb * 255).astype(np.uint8)

# Estimate robot state from image: detect link poses (plus, presumably,
# the world camera pose and intrinsics K used for rendering below), then
# solve for joint angles.
link_poses, camera_pose_world, cam_K, _, _, _ = detect_and_set_link_poses(rgb, mj_model, mj_data, robot_config)
configuration, _ = estimate_robot_state(mj_model, mj_data, robot_config, link_poses, ik_iterations=55)
mj_data.qpos[:] = configuration.q
# Assumes the actuators drive the first nu entries of qpos — TODO confirm.
mj_data.ctrl[:] = configuration.q[:len(mj_data.ctrl)]
mujoco.mj_forward(mj_model, mj_data)
position_exoskeleton_meshes(robot_config, mj_model, mj_data, link_poses)
mujoco.mj_forward(mj_model, mj_data)

# Fresh mink configuration seeded with the image-estimated joint state; it is
# then driven through the GT trajectory via IK and shown in MuJoCo / rendered.
configuration = mink.Configuration(mj_model, q=mj_data.qpos)

def ik_to_keypoint(target_pos, configuration, robot_config, mj_model, mj_data,
                   n_iters=50, dt=0.01, position_cost=1.0, orientation_cost=0.1,
                   posture_cost=1e-3, solver="daqp"):
    """Drive the ``virtual_gripper_keypoint`` body towards ``target_pos`` with mink IK.

    Each of ``n_iters`` iterations:

    1. Re-places the exoskeleton meshes from the current robot link poses.
    2. Syncs ``configuration`` to ``mj_data.qpos``.
    3. Solves one differential-IK step. The position target is ``target_pos``.
       The orientation target is the body's *current* orientation, which only
       softly resists rotation. A weak posture term pulls towards the current qpos.
    4. Integrates the step and writes the result to ``qpos`` and ``ctrl``.
    5. Advances the simulation with ``mj_step``. The next iteration therefore
       starts from the post-step state, not the raw IK solution.

    Args:
        target_pos: 3-vector target for the keypoint body, in the same frame as
            MuJoCo body positions (presumably world, in metres — confirm).
        configuration: ``mink.Configuration`` for ``mj_model``; updated in place.
        robot_config: robot config passed through to the exoskeleton helpers.
        mj_model: MuJoCo model containing the ``virtual_gripper_keypoint`` body.
        mj_data: MuJoCo data; ``qpos``/``ctrl`` are mutated in place.
        n_iters: number of IK + sim-step iterations.
        dt: IK step size passed to ``solve_ik`` and ``integrate_inplace``.
        position_cost: position weight of the keypoint frame task.
        orientation_cost: orientation weight of the keypoint frame task.
        posture_cost: weight of the posture regulariser.
        solver: QP solver name passed to ``mink.solve_ik``.

    Raises:
        ValueError: if the model has no body named ``virtual_gripper_keypoint``.
    """
    body_name = "virtual_gripper_keypoint"
    kp_body_id = mujoco.mj_name2id(mj_model, mujoco.mjtObj.mjOBJ_BODY, body_name)
    if kp_body_id < 0:
        # mj_name2id returns -1 when missing; indexing xquat[-1] would silently use the last body.
        raise ValueError(f"model has no body named {body_name!r}")

    # Tasks and limits do not depend on the iteration; only their targets change.
    kp_task = mink.FrameTask(body_name, "body", position_cost=position_cost, orientation_cost=orientation_cost)
    posture_task = mink.PostureTask(mj_model, cost=posture_cost)
    limits = [mink.ConfigurationLimit(model=mj_model)]
    target_pos = np.asarray(target_pos, dtype=float)

    for _ in range(n_iters):
        link_poses = get_link_poses_from_robot(robot_config, mj_model, mj_data)
        position_exoskeleton_meshes(robot_config, mj_model, mj_data, link_poses)
        mujoco.mj_forward(mj_model, mj_data)
        configuration.update(mj_data.qpos)

        # MuJoCo's xquat is already a unit quaternion in wxyz order, as SE3 expects.
        current_wxyz = mj_data.xquat[kp_body_id].copy()
        kp_task.set_target(mink.SE3(wxyz_xyz=np.concatenate([current_wxyz, target_pos])))
        posture_task.set_target(mj_data.qpos)

        vel = mink.solve_ik(configuration, [kp_task, posture_task], dt, solver, limits=limits)
        configuration.integrate_inplace(vel, dt)
        mj_data.qpos[:] = configuration.q
        # Assumes the actuators drive the first nu entries of qpos — TODO confirm.
        mj_data.ctrl[:] = configuration.q[:len(mj_data.ctrl)]
        mujoco.mj_step(mj_model, mj_data)

# True: run IK over the whole GT trajectory, then open the interactive MuJoCo
# viewer. False: after each IK target, render from the estimated camera and
# show the start image, the render and a 50/50 overlay with matplotlib.
USE_VIEWER = False

if USE_VIEWER:
    # `import mujoco` alone does not import the viewer submodule.
    import mujoco.viewer

    for target_pos in trajectory_gt_3d:
        ik_to_keypoint(target_pos, configuration, robot_config, mj_model, mj_data)
    with mujoco.viewer.launch_passive(mj_model, mj_data, show_left_ui=False, show_right_ui=False) as viewer:
        # Keep the exoskeleton meshes attached to the simulated links until the window is closed.
        while viewer.is_running():
            position_exoskeleton_meshes(robot_config, mj_model, mj_data,
                                        get_link_poses_from_robot(robot_config, mj_model, mj_data))
            mujoco.mj_step(mj_model, mj_data)
            viewer.sync()
else:
    img_h, img_w = rgb.shape[:2]
    for target_pos in trajectory_gt_3d:
        ik_to_keypoint(target_pos, configuration, robot_config, mj_model, mj_data)

        rendered = render_from_camera_pose(mj_model, mj_data, camera_pose_world, cam_K, img_h, img_w)
        # NOTE(review): the overlay always uses the start frame `rgb`, not the
        # frame that matches target_pos — confirm this is intended.
        overlay = (rgb * 0.5 + rendered * 0.5).astype(np.uint8)

        # Display results: start image | render | overlay.
        fig, axes = plt.subplots(1, 3, figsize=(15, 5))
        for ax, img in zip(axes, [rgb, rendered, overlay]):
            ax.imshow(img)
            ax.axis('off')
        plt.tight_layout()
        plt.show()
        plt.close(fig)  # release the figure so memory does not grow over long trajectories
