"""Simple RGB and kinematics dataset recording script.

Records RGB images, joint states, gripper and camera poses, and normalized camera
intrinsics at a fixed framerate.
"""
import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))

import mujoco
import cv2
import pickle
import numpy as np
import time
import random
import string
import argparse
from pathlib import Path

from ExoConfigs import EXOSKELETON_CONFIGS
from ExoConfigs.so100_adhesive import SO100AdhesiveConfig
from exo_utils import estimate_robot_state, detect_and_set_link_poses, position_exoskeleton_meshes, render_from_camera_pose, get_link_poses_from_robot
from robot_models.so100_controller import Arm

# Recording configuration
fps = 3  # Recording framerate (frames per second)
frame_interval = 1.0 / fps  # Target time between frames in seconds

parser = argparse.ArgumentParser(description="Record RGB and joint state dataset")
parser.add_argument("--exo", type=str, default="so100_adhesive",
                    choices=list(EXOSKELETON_CONFIGS.keys()),
                    help="Exoskeleton configuration to use")
parser.add_argument("--camera", type=int, default=0,
                    help="Camera device ID (default: 0)")
# NOTE: this flag is overwritten after parsing -- direct robot joint states are
# used whenever a calibrated arm is connected, regardless of this flag.
parser.add_argument("--use_robot_state", action="store_true",
                    help="Use direct robot joint states instead of estimating from images")
parser.add_argument("--render", action="store_true", help="Enable rendering visualization during recording")
parser.add_argument("--show_rgb", action="store_true", help="Show the raw RGB camera feed during recording")
parser.add_argument("--no_arm", action="store_true", help="No arm connected")
parser.add_argument("--dont_save", action="store_true", help="Don't save images and joint states")
args = parser.parse_args()

# Generate unique output directory
def make_unique_id():
    """Return a random 6-character identifier drawn from ASCII letters and digits."""
    alphabet = string.ascii_letters + string.digits
    picks = random.choices(alphabet, k=6)
    return "".join(picks)

# Each run writes into its own randomly-suffixed folder under scratch/.
unique_id = make_unique_id()
output_dir = Path("scratch") / f"rgb_joints_capture_{unique_id}"
if not args.dont_save:
    output_dir.mkdir(parents=True, exist_ok=True)
print(f"Output directory: {output_dir}")

# Robot / exoskeleton configuration.
# NOTE(review): the config is always SO100AdhesiveConfig regardless of --exo; the
# --exo value is only echoed in the log line below. Honoring the flag would mean
# looking it up in EXOSKELETON_CONFIGS[args.exo] -- TODO confirm whether that
# registry stores classes or instances before switching.
# Class-level alpha overrides, set before instantiation (presumably consumed when
# the config builds its model/visuals -- confirm).
SO100AdhesiveConfig.exo_alpha = 0.2
SO100AdhesiveConfig.aruco_alpha = 0.8
robot_config = SO100AdhesiveConfig()
camera_device = args.camera

print(f"Using exoskeleton config: {args.exo} ({robot_config.name})")
print(f"Recording at {fps} fps")
print(f"Initializing camera device {camera_device}...")

# Connect to the physical arm unless --no_arm is given or the calibration file
# is missing; `arm` stays None in those cases.
arm = None
calib_path = "robot_models/arm_offsets/rescrew_school_fromimg.pkl"
if args.no_arm:
    print("No arm requested (--no_arm), falling back to image-based estimation")
elif not os.path.exists(calib_path):
    print(f"Warning: Calibration file not found at {calib_path}, falling back to image-based estimation")
else:
    # NOTE: pickle can execute arbitrary code on load -- only use trusted calibration files.
    with open(calib_path, "rb") as calib_file:
        arm = Arm(pickle.load(calib_file))
    print("✓ Connected to robot for direct joint state reading")

# The --use_robot_state flag is overridden: direct robot state is used iff an arm is available.
args.use_robot_state = arm is not None

# Load the MuJoCo model described by the robot config.
model = mujoco.MjModel.from_xml_string(robot_config.xml)
data = mujoco.MjData(model)

# Initialize camera
cap = cv2.VideoCapture(camera_device)
if not cap.isOpened():
    raise RuntimeError(f"Failed to open camera device {camera_device}")

# Read a first frame to determine the native resolution. Reads may fail
# transiently, so retry -- but give up after a bounded number of attempts
# instead of spinning forever on a device that never delivers frames.
max_first_frame_attempts = 200
ret, frame = cap.read()
attempts = 1
while not ret:
    if attempts >= max_first_frame_attempts:
        cap.release()
        raise RuntimeError(
            f"Camera device {camera_device} returned no frames after {attempts} attempts"
        )
    time.sleep(0.01)
    ret, frame = cap.read()
    attempts += 1
rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
height, width = rgb.shape[:2]
print(f"Camera resolution: {width}x{height}")

# Saved images are half resolution; the live visualization uses 1/vis_res_ds_factor.
ds_height_save, ds_width_save = height // 2, width // 2
vis_res_ds_factor = 5
ds_height_vis, ds_width_vis = height // vis_res_ds_factor, width // vis_res_ds_factor

# Initialize renderer if rendering is enabled.
# NOTE(review): `renderer` is not referenced anywhere else in this script (drawing
# goes through render_from_camera_pose). Kept in case constructing it has side
# effects rendering relies on -- confirm and remove if not.
renderer = None
if args.render:
    from mujoco.renderer import Renderer
    renderer = Renderer(model, height=height, width=width)
    print("Rendering enabled")

# Main capture loop: grab a frame, obtain joint state (from the arm if connected,
# otherwise from the image-based IK estimate), detect link/camera poses, save the
# sample, optionally visualize, then sleep to hold the target framerate.

# Intrinsics are passed to and returned by detect_and_set_link_poses every frame
# (presumably estimated once, then reused -- confirm).
cam_K = None
timestep = 0  # number of frames that passed detection

print("\n" + "="*60)
print("Recording started. Press 'q' to quit.")
print("="*60)

# Either visualization mode opens an OpenCV window that must be torn down later.
show_window = args.render or args.show_rgb

try:
    while True:
        frame_start_time = time.time()

        # Capture frame
        ret, frame = cap.read()
        if not ret:
            print("Failed to read frame from camera")
            continue

        rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

        camera_pose_world = None
        # Gate on the actual arm object: with --no_arm OR a missing calibration
        # file, `arm` is None and calling get_pos() would raise.
        if arm is not None:
            joint_state = arm.get_pos()
            data.qpos[:] = data.ctrl[:] = joint_state
            mujoco.mj_forward(model, data)

        # Detection runs in both modes: the camera pose is saved and used for rendering.
        try:
            link_poses, camera_pose_world, cam_K, _corners_cache, _corners_vis, _obj_img_pts = detect_and_set_link_poses(
                rgb, model, data, robot_config, cam_K=cam_K)
            position_exoskeleton_meshes(robot_config, model, data, link_poses)
            configuration, _ = estimate_robot_state(model, data, robot_config, link_poses, ik_iterations=35)
            if arm is None:
                # Image-based fallback: the IK estimate is the joint state. Without
                # this, `joint_state` is undefined and saving raises NameError.
                data.qpos[:] = data.ctrl[:] = configuration.q
                joint_state = np.array(configuration.q, copy=True)
            mujoco.mj_forward(model, data)
        except Exception as e:
            print(f"Error detecting link poses: {e}")
            continue

        # Downsample image for saving and for visualization.
        rgb_downsampled_save = cv2.resize(rgb, (ds_width_save, ds_height_save), interpolation=cv2.INTER_LINEAR)
        rgb_downsampled_vis = cv2.resize(rgb, (ds_width_vis, ds_height_vis), interpolation=cv2.INTER_LINEAR)

        # Per-timestep output paths.
        timestep_str = f"{timestep:06d}"
        image_path = output_dir / f"{timestep_str}.png"
        joint_path = output_dir / f"{timestep_str}.npy"
        gripper_pose_path = output_dir / f"{timestep_str}_gripper_pose.npy"
        camera_pose_path = output_dir / f"{timestep_str}_camera_pose.npy"
        cam_K_path = output_dir / f"{timestep_str}_cam_K_norm.npy"

        fixed_gripper_pose = get_link_poses_from_robot(robot_config, model, data)["fixed_gripper"]

        if not args.dont_save:
            print("saving image and joint state")
            # Save RGB image (converted back to BGR because cv2.imwrite expects BGR).
            cv2.imwrite(str(image_path), cv2.cvtColor(rgb_downsampled_save, cv2.COLOR_RGB2BGR))
            np.save(joint_path, joint_state)
            np.save(gripper_pose_path, fixed_gripper_pose)
            np.save(camera_pose_path, camera_pose_world)
            # Normalize intrinsics by full-resolution width (row 0) and height (row 1)
            # so they apply to any resized copy of the image, including the saved one.
            cam_K_norm = cam_K.copy()
            cam_K_norm[0] /= rgb.shape[1]
            cam_K_norm[1] /= rgb.shape[0]
            np.save(cam_K_path, cam_K_norm)
            print(f"Saved timestep {timestep_str}: {image_path.name}, {joint_path.name}")

        # Optional rendering
        if args.render and camera_pose_world is not None:
            # Rescale intrinsics to the visualization resolution by the exact size
            # ratio. The previous floor division (`//`) truncated fx/fy/cx/cy and
            # misaligned the rendered overlay.
            cam_K_low_res = cam_K.astype(float)
            cam_K_low_res[0] *= ds_width_vis / width
            cam_K_low_res[1] *= ds_height_vis / height
            rendered = render_from_camera_pose(model, data, camera_pose_world, cam_K_low_res, *rgb_downsampled_vis.shape[:2])
            overlay = (rgb_downsampled_vis.astype(float) * 0.5 + rendered.astype(float) * 0.5).astype(np.uint8)
            display = np.hstack([rgb_downsampled_vis, rendered, overlay])
            # Upscale 4x for viewing; cv2.resize expects dsize as an (int, int) tuple.
            display = cv2.resize(display, (display.shape[1] * 4, display.shape[0] * 4))
            cv2.imshow('Recording', cv2.cvtColor(display, cv2.COLOR_RGB2BGR))
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
        elif args.show_rgb:
            cv2.imshow('Recording', cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR))
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break

        timestep += 1

        # Maintain framerate
        elapsed = time.time() - frame_start_time
        sleep_time = frame_interval - elapsed
        if sleep_time > 0:
            time.sleep(sleep_time)

except KeyboardInterrupt:
    print("\nRecording interrupted by user")

finally:
    cap.release()
    # Also covers --show_rgb, which opens a window without --render.
    if show_window:
        cv2.destroyAllWindows()

    print(f"\n{'='*60}")
    print("Recording complete!")
    if args.dont_save:
        print(f"Captured {timestep} frames (not saved: --dont_save)")
    else:
        print(f"Saved {timestep} frames to {output_dir}")
    print(f"{'='*60}")

