"""Simple RGB and kinematics dataset recording script (UMI config).

Records RGB images and joint state at a fixed framerate. Uses UMI exo config;
joint state is recovered from IK on fixed_gripper + moveable_gripper link poses.
Saved format matches simple_dataset_record.py: image, joint_state, gripper_pose, camera_pose, cam_K_norm.
"""
import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))

import mujoco
import cv2
import pickle
import numpy as np
import time
import random
import string
import argparse
from pathlib import Path

from ExoConfigs import EXOSKELETON_CONFIGS
from ExoConfigs.umi_so100 import UMI_SO100_CONFIG
from exo_utils import estimate_robot_state, detect_and_set_link_poses, position_exoskeleton_meshes, render_from_camera_pose, get_link_poses_from_robot
from robot_models.so100_controller import Arm

# Configuration
DEFAULT_FPS = 3  # Recording framerate used when --fps is not given

# Command-line interface. All flags are optional; with no arguments the script
# records from camera 0 with the "umi_so100" config at DEFAULT_FPS.
parser = argparse.ArgumentParser(description="Record RGB and joint state dataset (UMI; IK from gripper link poses)")
parser.add_argument("--exo", type=str, default="umi_so100",
                    choices=list(EXOSKELETON_CONFIGS.keys()),
                    help="Exoskeleton configuration to use")
parser.add_argument("--camera", type=int, default=0,
                    help="Camera device ID (default: 0)")
# A non-string default is used as-is by argparse, so the default stays the int 3
# while values given on the command line are parsed as float (e.g. 0.5).
parser.add_argument("--fps", type=float, default=DEFAULT_FPS,
                    help=f"Recording framerate in frames per second (default: {DEFAULT_FPS})")
parser.add_argument("--use_robot_state", action="store_true",
                    help="Use direct robot joint states for rendering (else estimate from images)")
parser.add_argument("--render", action="store_true", help="Show corners_vis + rendered + overlay during recording")
parser.add_argument("--show_rgb", action="store_true", help="Show RGB only during recording")
parser.add_argument("--no_arm", action="store_true", help="No arm connected (estimate pose for rendering)")
parser.add_argument("--dont_save", action="store_true", help="Don't save images and joint states")
args = parser.parse_args()

# Reject zero, negative and NaN rates up front: frame_interval below divides by fps.
if not args.fps > 0:
    parser.error("--fps must be greater than 0")

fps = args.fps  # Recording framerate
frame_interval = 1.0 / fps  # Time between frames in seconds

# Generate unique output directory
def make_unique_id():
    """Return a random 6-character tag drawn from ASCII letters and digits.

    The tag is random, not guaranteed unique; it is only used to name the
    per-run output directory.
    """
    alphabet = string.ascii_letters + string.digits
    picked = random.choices(alphabet, k=6)
    return ''.join(picked)

# Every run writes into its own scratch/umi_capture_<id> directory. The
# directory is only created when saving is enabled, but its path is always
# printed.
unique_id = make_unique_id()
output_dir = Path("scratch") / f"umi_capture_{unique_id}"
if not args.dont_save:
    os.makedirs(output_dir, exist_ok=True)
print(f"Output directory: {output_dir}")

# Resolve the exoskeleton config chosen on the command line (falls back to the
# UMI SO100 config if the key is somehow missing) and the camera device index.
robot_config = EXOSKELETON_CONFIGS.get(args.exo, UMI_SO100_CONFIG)
camera_device = args.camera

print(f"Using exoskeleton config: {args.exo} ({robot_config.name})")
print(f"Recording at {fps} fps (saving: image, joint_state, gripper_pose, camera_pose, cam_K_norm)")
print(f"Initializing camera device {camera_device}...")

# Optional: use robot state for rendering when arm is connected
arm = None
calib_path = "robot_models/arm_offsets/middleservo_calib_redo_fromimg.pkl"
if os.path.exists(calib_path) and not args.no_arm:
    # NOTE: pickle executes arbitrary code on load; calib_path must be a trusted file.
    # The context manager closes the calibration file as soon as it has been read.
    with open(calib_path, 'rb') as calib_file:
        arm_calibration = pickle.load(calib_file)
    arm = Arm(arm_calibration)
    print("✓ Connected to robot (use for rendering when --use_robot_state)")
else:
    # Without an arm there is no joint state to read, so the flag is forced off.
    # Tell the user when that overrides an explicit request.
    if args.use_robot_state:
        print("Warning: --use_robot_state ignored (no arm connected or calibration file missing); "
              "using IK-estimated joint state instead")
    args.use_robot_state = False

# Load model from config
model = mujoco.MjModel.from_xml_string(robot_config.xml)
data = mujoco.MjData(model)

# Initialize camera
cap = cv2.VideoCapture(camera_device)
if not cap.isOpened():
    raise RuntimeError(f"Failed to open camera device {camera_device}")

# Get first frame to determine resolution.
# The first reads may fail, so retry — but only for a bounded time, so a device
# that opens yet never delivers a frame produces an error instead of a
# silent busy-loop.
first_frame_timeout_s = 10.0
first_frame_deadline = time.monotonic() + first_frame_timeout_s
ret, frame = cap.read()
while not ret:
    if time.monotonic() > first_frame_deadline:
        cap.release()
        raise RuntimeError(
            f"Camera device {camera_device} opened but returned no frame "
            f"within {first_frame_timeout_s:.0f} seconds"
        )
    time.sleep(0.01)  # avoid spinning at full CPU between retries
    ret, frame = cap.read()
rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
height, width = rgb.shape[:2]
print(f"Camera resolution: {width}x{height}")

# Downsampled resolution (half size) used for the saved images
ds_height_save, ds_width_save = height // 2, width // 2
# Smaller resolution (1/vis_res_ds_factor of full size) for visualization
vis_res_ds_factor = 5
ds_height_vis, ds_width_vis = height // vis_res_ds_factor, width // vis_res_ds_factor

# Build the MuJoCo renderer at full camera resolution only when --render is set;
# otherwise leave it as None.
if args.render:
    from mujoco.renderer import Renderer
    renderer = Renderer(model, height=height, width=width)
    print("Rendering enabled")
else:
    renderer = None

# Loop state: cam_K starts unset and is passed to / returned from
# detect_and_set_link_poses each frame; timestep counts processed frames and
# names the saved files.
cam_K, timestep = None, 0

_banner_rule = "=" * 60
print(f"\n{_banner_rule}")
print("Recording started. Press 'q' to quit.")
print(_banner_rule)

# Joint configuration written into qpos/ctrl before IK runs on each frame
# (presumably joint angles in radians — confirm against the robot model).
targ_robot_pos = np.array([
    -0.00268742,
    -1.6865245,
    1.65632287,
    1.51128661,
    1.55649603,
    0,
])

# Main recording loop.
#
# Per iteration: grab a frame, detect the UMI link poses and camera pose,
# recover the joint state by IK (or read it from the arm), optionally save the
# sample, optionally display it, then sleep to hold the target framerate.
# Frames that fail detection are skipped without advancing `timestep`; frames
# with a large IK error advance `timestep` but are not saved, so saved file
# indices may have gaps.

# IK position error (mm) above which a frame is reported as "BAD IK" and not saved.
IK_POSITION_TOL_MM = 5.0

# Number of samples actually written to disk. This differs from `timestep`
# whenever frames are rejected for bad IK or saving is disabled.
saved_frames = 0

try:
    while True:
        frame_start_time = time.time()

        # Capture frame
        ret, frame = cap.read()
        if not ret:
            print("Failed to read frame from camera")
            continue

        rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

        # Detect UMI link poses and camera pose
        camera_pose_world = None
        link_poses = None
        corners_vis = None
        try:
            # cam_K is fed back in and replaced by the returned value on every frame.
            link_poses, camera_pose_world, cam_K, _corners_cache, corners_vis, _obj_img_pts = detect_and_set_link_poses(
                rgb, model, data, robot_config, cam_K=cam_K
            )
            # Exactly three link poses are required; anything else is skipped.
            if len(link_poses) != 3:
                print("skipping motion blur")
                continue
            position_exoskeleton_meshes(robot_config, model, data, link_poses)

            # Reset the model to the fixed seed configuration before every IK solve.
            # Written in place, like the other qpos/ctrl updates below.
            data.qpos[:] = data.ctrl[:] = targ_robot_pos
            mujoco.mj_forward(model, data)
            configuration, ik_error = estimate_robot_state(
                model, data, robot_config, link_poses, ik_iterations=35, return_error=True
            )
            bad_ik = ik_error['position_mm'] > IK_POSITION_TOL_MM
            if bad_ik:
                print("BAD IK")

            # Joint state source: the physical arm when requested and connected,
            # otherwise the IK solution.
            if args.use_robot_state and arm is not None:
                data.qpos[:] = data.ctrl[:] = arm.get_pos()
            else:
                data.qpos[:] = data.ctrl[:] = configuration.q
            mujoco.mj_forward(model, data)
            joint_state = data.qpos.copy()
            fixed_gripper_pose = get_link_poses_from_robot(robot_config, model, data)["fixed_gripper"]
        except Exception as e:
            # Best effort: a failed detection/IK drops this frame but keeps recording.
            print(f"Error detecting link poses: {e}")
            if args.render or args.show_rgb:
                cv2.imshow('Recording', cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR))
                if cv2.waitKey(1) & 0xFF == ord('q'):
                    break
            continue

        # Save in same format as simple_dataset_record.py
        if not args.dont_save and not bad_ik:
            timestep_str = f"{timestep:06d}"
            image_path = output_dir / f"{timestep_str}.png"
            joint_path = output_dir / f"{timestep_str}.npy"
            gripper_pose_path = output_dir / f"{timestep_str}_gripper_pose.npy"
            camera_pose_path = output_dir / f"{timestep_str}_camera_pose.npy"
            cam_K_path = output_dir / f"{timestep_str}_cam_K_norm.npy"

            # Downsample image for save (same as simple_dataset_record)
            rgb_downsampled_save = cv2.resize(rgb, (ds_width_save, ds_height_save), interpolation=cv2.INTER_LINEAR)

            cv2.imwrite(str(image_path), cv2.cvtColor(rgb_downsampled_save, cv2.COLOR_RGB2BGR))
            np.save(joint_path, joint_state)
            np.save(gripper_pose_path, fixed_gripper_pose)
            np.save(camera_pose_path, camera_pose_world)
            # Normalize intrinsics by the full-resolution image size: row 0 by
            # width, row 1 by height (assumes cam_K is a 3x3 matrix — confirm).
            cam_K_norm = cam_K.copy()
            cam_K_norm[0] /= rgb.shape[1]
            cam_K_norm[1] /= rgb.shape[0]
            np.save(cam_K_path, cam_K_norm)
            saved_frames += 1
            print(f"Saved timestep {timestep_str}: {image_path.name}, {joint_path.name}")

        # Visualization: corners_vis + rendered + overlay (like Demos/umi_cam.py)
        if args.render and camera_pose_world is not None:
            rendered = render_from_camera_pose(model, data, camera_pose_world, cam_K, height, width)
            overlay = (rgb.astype(float) * 0.5 + rendered.astype(float) * 0.5).astype(np.uint8)
            left_panel = corners_vis if corners_vis is not None else rgb
            display = np.hstack([left_panel, rendered, overlay])
            display = cv2.resize(display, (display.shape[1] // 2, display.shape[0] // 2), interpolation=cv2.INTER_LINEAR)
            cv2.imshow('Recording', cv2.cvtColor(display, cv2.COLOR_RGB2BGR))
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
        elif args.show_rgb:
            cv2.imshow('Recording', cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR))
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break

        timestep += 1

        # Maintain framerate
        elapsed = time.time() - frame_start_time
        sleep_time = max(0, frame_interval - elapsed)
        if sleep_time > 0:
            time.sleep(sleep_time)

except KeyboardInterrupt:
    print("\nRecording interrupted by user")

finally:
    # Always release the camera and close any preview window, even on error.
    cap.release()
    if args.render or args.show_rgb:
        cv2.destroyAllWindows()

    print(f"\n{'='*60}")
    print("Recording complete!")
    # Report what was written, not how many frames were processed.
    print(f"Saved {saved_frames} frames to {output_dir}")
    print(f"{'='*60}")

