"""Add MoGe pointmaps (robot-frame aligned) and depth masks to existing dataset episodes."""
import sys
import os
sys.path.append("/Users/cameronsmith/Projects/robotics_testing/random/vggt")
sys.path.append("/Users/cameronsmith/Projects/robotics_testing/random/MoGe")
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'Demos'))

import cv2
import numpy as np
import torch
import mujoco
from tqdm import tqdm
import argparse
from pathlib import Path

from moge.model.v2 import MoGeModel
import utils3d

from ExoConfigs.so100_adhesive import SO100AdhesiveConfig
from exo_utils import (
    detect_and_set_link_poses,
    estimate_robot_state,
    position_exoskeleton_meshes,
)
from demo_utils import procrustes_alignment


def process_timestep(episode_dir, timestep_str, model, robot_config, device, model_moge):
    """Process one timestep: run MoGe, align its pointmap to the robot frame, save outputs.

    Args:
        episode_dir: Path to the episode directory containing "<t>.png" / "<t>.npy".
        timestep_str: Timestep string used in all per-frame filenames.
        model: MuJoCo MjModel for the robot scene.
        robot_config: Robot configuration consumed by the exoskeleton helpers.
        device: torch device used for MoGe inference.
        model_moge: Loaded MoGe model instance.

    Returns:
        True on success (or when outputs already exist); False when input files
        are missing or ArUco-based pose detection fails.

    Side effects (files written into episode_dir):
        moge_pointmap_<t>.pt         robot-frame pointmap dict (points/colors/mask)
        moge_mask_<t>.npy            boolean depth-edge mask
        moge_to_robot_frame_<t>.npy  cached 4x4 Procrustes transform
    """
    # Input files for this timestep
    image_path = episode_dir / f"{timestep_str}.png"
    joint_path = episode_dir / f"{timestep_str}.npy"

    if not image_path.exists() or not joint_path.exists():
        print(f"  ⚠ Missing image or joint file for timestep {timestep_str}")
        return False

    # Skip all work if both outputs already exist
    pointmap_robot_path = episode_dir / f"moge_pointmap_{timestep_str}.pt"
    moge_mask_path = episode_dir / f"moge_mask_{timestep_str}.npy"

    if pointmap_robot_path.exists() and moge_mask_path.exists():
        print(f"    ⏭ Skipping MoGe processing (moge_pointmap_{timestep_str}.pt and moge_mask_{timestep_str}.npy exist)")
        return True

    # Load image as RGB. cv2.imread yields uint8, so the <= 1.0 branch only fires
    # for an (almost) all-black frame — NOTE(review): heuristic, confirm intent.
    rgb = cv2.cvtColor(cv2.imread(str(image_path)), cv2.COLOR_BGR2RGB)
    if rgb.max() <= 1.0:
        rgb = (rgb * 255).astype(np.uint8)

    H, W = rgb.shape[:2]

    # Load joint state and push it through MuJoCo forward kinematics
    qpos = np.load(joint_path)
    data = mujoco.MjData(model)
    data.qpos[:] = qpos
    data.ctrl[:] = qpos[:len(data.ctrl)]
    mujoco.mj_forward(model, data)

    # Detect ArUco markers; yields camera pose, intrinsics, and per-object corners
    try:
        link_poses, camera_pose_world, cam_K, corners_cache, corners_vis, obj_img_pts = detect_and_set_link_poses(
            rgb, model, data, robot_config
        )
    except Exception as e:
        print(f"  ⚠ Failed to detect link poses for {timestep_str}: {e}")
        return False

    position_exoskeleton_meshes(robot_config, model, data, link_poses)
    mujoco.mj_forward(model, data)

    # MoGe pointcloud in camera frame (cached per timestep on disk)
    points_cam, mask, moge_edge_mask, H_moge, W_moge = get_moge_pointcloud(
        rgb, model_moge, device, episode_dir, timestep_str
    )

    # MoGe may resize the input; keep RGB and (H, W) consistent with its output
    if H_moge != H or W_moge != W:
        print(f"    ⚠ Image size mismatch: RGB ({H}x{W}) vs MoGe ({H_moge}x{W_moge})")
        rgb = cv2.resize(rgb, (W_moge, H_moge), interpolation=cv2.INTER_LINEAR)
        H, W = H_moge, W_moge

    # Procrustes transform (MoGe camera frame -> robot frame), cached on disk
    moge_to_robot_frame_path = episode_dir / f"moge_to_robot_frame_{timestep_str}.npy"

    if moge_to_robot_frame_path.exists():
        print(f"    ⏭ Using existing Procrustes alignment (moge_to_robot_frame_{timestep_str}.npy)")
        T_procrustes = np.load(moge_to_robot_frame_path)
    else:
        # Build paired correspondences: each ArUco corner in the robot frame is
        # matched with the MoGe point sampled at that corner's pixel.
        # BUGFIX: both sides are now appended together under the bounds check —
        # previously the robot-frame corners were added unconditionally while MoGe
        # samples were skipped when out of bounds, which desynchronized the lists
        # and silently corrupted every subsequent correspondence pair.
        aruco_corners_robot_frame = []
        moge_aruco_corners = []

        for obj_name, (obj_img_pts_cam, img_pts_px) in obj_img_pts.items():
            if obj_name not in ["alignment_board", "larger_base"]:
                continue

            # Camera-frame corners -> robot frame via the inverse camera pose
            obj_cam_h = np.hstack([obj_img_pts_cam, np.ones((obj_img_pts_cam.shape[0], 1))])
            obj_robot = (np.linalg.inv(camera_pose_world) @ obj_cam_h.T).T[:, :3]

            for robot_pt, px in zip(obj_robot, img_pts_px):
                x, y = int(px[0]), int(px[1])
                if 0 <= y < H and 0 <= x < W:
                    aruco_corners_robot_frame.append(robot_pt)
                    moge_aruco_corners.append(points_cam[y, x])

        aruco_corners_robot_frame = np.array(aruco_corners_robot_frame)
        moge_aruco_corners = np.array(moge_aruco_corners)

        # Procrustes needs at least 3 correspondences; fall back to identity otherwise
        if len(aruco_corners_robot_frame) >= 3 and len(moge_aruco_corners) >= 3:
            T_procrustes, scale, rotation, translation = procrustes_alignment(
                aruco_corners_robot_frame, moge_aruco_corners
            )
            np.save(moge_to_robot_frame_path, T_procrustes)
            print(f"    ✓ Computed Procrustes alignment: {len(aruco_corners_robot_frame)} correspondences")
        else:
            print(f"    ⚠ Insufficient correspondences for Procrustes ({len(aruco_corners_robot_frame)} robot, {len(moge_aruco_corners)} MoGe), using identity")
            T_procrustes = np.eye(4)
            np.save(moge_to_robot_frame_path, T_procrustes)

    # Apply the 4x4 transform to every pixel's camera-frame point
    points_cam_flat = points_cam.reshape(-1, 3)  # (H*W, 3)
    points_cam_h = np.hstack([points_cam_flat, np.ones((len(points_cam_flat), 1))])  # (H*W, 4)
    points_robot_flat = (T_procrustes @ points_cam_h.T).T[:, :3]  # (H*W, 3)
    points_robot = points_robot_flat.reshape(H, W, 3)  # (H, W, 3)

    # Save robot-frame pointmap: points + per-pixel colors + MoGe validity mask
    colors_flat = rgb.reshape(-1, 3)
    pointmap_robot = {
        "points": torch.from_numpy(points_robot.astype(np.float32)),  # (H, W, 3)
        "colors": torch.from_numpy(colors_flat.astype(np.uint8).reshape(H, W, 3)),  # (H, W, 3)
        "mask": torch.from_numpy(mask.astype(bool)),  # (H, W)
    }
    torch.save(pointmap_robot, pointmap_robot_path)
    print(f"    ✓ Saved robot-frame pointmap: {pointmap_robot_path.name}")

    # Save depth-discontinuity mask separately
    np.save(moge_mask_path, moge_edge_mask.astype(bool))
    print(f"    ✓ Saved depth mask: {moge_mask_path.name}")

    return True


def get_moge_pointcloud(rgb_scene, model_moge, device, episode_dir, timestep_str):
    """Run MoGe on an RGB frame, reusing a cached raw output when present.

    Returns:
        Tuple (points, valid_mask, edge_mask, height, width) where points is the
        (H, W, 3) camera-frame pointmap, valid_mask is MoGe's boolean validity
        mask, and edge_mask marks depth discontinuities in the z channel.
    """
    cache_path = episode_dir / f"moge_output_{timestep_str}_raw.pt"

    if cache_path.exists():
        output_moge = torch.load(cache_path, weights_only=False)
    else:
        # CHW float tensor in [0, 1], as MoGe's infer() expects
        frame = torch.tensor(rgb_scene / 255.0, dtype=torch.float32, device=device).permute(2, 0, 1)
        with torch.no_grad():
            output_moge = model_moge.infer(frame)
        torch.save(output_moge, cache_path)

    points = output_moge["points"].cpu().numpy()  # (H, W, 3)
    valid_mask = output_moge["mask"].cpu().numpy().astype(bool)

    # Depth-edge mask computed from the z channel of the pointmap
    edge_mask = utils3d.np.depth_map_edge(points[:, :, 2], rtol=0.005)

    height, width = points.shape[:2]
    return points, valid_mask, edge_mask, height, width


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Add MoGe pointmaps (robot-frame aligned) to existing dataset episodes")
    parser.add_argument("--input_dir", "-i", type=str, required=True,
                        help="Input directory with parsed episodes")
    cli = parser.parse_args()

    input_dir = Path(cli.input_dir)

    # Separator reused for all section banners (prints identically to '=' * 60)
    banner = "=" * 60

    print(banner)
    print("Finding all episodes")
    print(banner)

    # Episodes are the immediate subdirectories named "episode_*"
    episode_dirs = sorted(
        child for child in input_dir.iterdir()
        if child.is_dir() and child.name.startswith("episode_")
    )

    print(f"Found {len(episode_dirs)} episodes")
    for ep_dir in episode_dirs:
        print(f"  - {ep_dir.name}")

    # Build robot config and MuJoCo model once; reused for every timestep
    SO100AdhesiveConfig.exo_alpha = 0.2
    SO100AdhesiveConfig.aruco_alpha = 0.2
    robot_config = SO100AdhesiveConfig()
    model = mujoco.MjModel.from_xml_string(robot_config.xml)

    # Prefer Apple Metal when available, otherwise fall back to CPU
    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
    print(f"\nUsing device: {device}")

    print("Loading MoGe model...")
    model_moge = MoGeModel.from_pretrained("Ruicheng/moge-2-vitl-normal").to(device)
    model_moge.eval()

    print("\n" + banner)
    print("Processing episodes")
    print(banner)

    for episode_dir in tqdm(episode_dirs, desc="Processing episodes"):
        print(f"\n{'='*60}")
        print(f"Processing episode: {episode_dir.name}")
        print(f"{'='*60}")

        # Only numerically-named PNGs are timestep frames
        frames = sorted(f for f in episode_dir.glob("*.png") if f.stem.isdigit())

        if not frames:
            print(f"  ⚠ No images found in {episode_dir.name}, skipping")
            continue

        # Run every timestep, collecting per-frame success flags
        results = [
            process_timestep(episode_dir, frame.stem, model, robot_config, device, model_moge)
            for frame in tqdm(frames, desc=f"  {episode_dir.name}", leave=False)
        ]
        timesteps_processed = sum(1 for ok in results if ok)
        timesteps_failed = len(results) - timesteps_processed

        print(f"  ✓ Processed {timesteps_processed} timesteps, {timesteps_failed} failed")

    print(f"\n{'='*60}")
    print(f"Done! Processed {len(episode_dirs)} episodes")
    print(f"Input directory: {input_dir}")
    print(f"{'='*60}")
