"""Process dense dataset: process all timesteps in each episode."""
import sys
import os
sys.path.append("/Users/cameronsmith/Projects/robotics_testing/random/vggt")
sys.path.append("/Users/cameronsmith/Projects/robotics_testing/random/MoGe")
sys.path.append("/Users/cameronsmith/Projects/robotics_testing/random/dinov3")
sys.path.append("Demos")
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))

import cv2
import numpy as np
import matplotlib.pyplot as plt
import torch
import torchvision
from torchvision import transforms
from torchvision.transforms import functional as TF
from PIL import Image
from tqdm import tqdm
from scipy.spatial.transform import Rotation as R
import argparse
import pickle
from pathlib import Path

import mujoco
import trimesh

from moge.model.v2 import MoGeModel
import utils3d

from ExoConfigs.so100_adhesive import SO100AdhesiveConfig
from ExoConfigs.alignment_board import ALIGNMENT_BOARD_CONFIG
from exo_utils import (
    detect_and_set_link_poses,
    estimate_robot_state,
    position_exoskeleton_meshes,
    render_from_camera_pose,
    detect_and_position_alignment_board,
    combine_xmls,
    get_link_poses_from_robot,
)
from demo_utils import procrustes_alignment

# Configuration constants
# ViT patch size in pixels. resize_transform snaps the DINO input to multiples
# of this, and run_dino_features derives the feature grid as
# (height / PATCH_SIZE, width / PATCH_SIZE). Presumably must match the "16" of
# the 'dinov3_vits16plus' model loaded in __main__ — confirm if the model changes.
PATCH_SIZE = 16
# Target height (px) of the image fed to DINO; resize_transform scales the
# width to keep the aspect ratio, then both are snapped down to PATCH_SIZE.
IMAGE_SIZE = 768
# Per-channel mean/std passed to TF.normalize on the [0, 1] RGB tensor
# before the DINO forward pass.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)
# Number of layers requested from get_intermediate_layers (n=range(N_LAYERS));
# only the last of them (feats[-1]) is actually used.
N_LAYERS = 12


def process_timestep(episode_dir, timestep_str, model, robot_config, device, model_seg, model_moge, preprocess, model_dino, pca_embedder):
    """Process a single timestep of an episode and write its derived artifacts.

    Reads ``{timestep_str}.png`` (image) and ``{timestep_str}.npy`` (joint
    state, written into ``data.qpos``) from ``episode_dir`` and writes next to
    them:

    * ``robot_camera_pose_{ts}.npy`` and ``cam_K_{ts}.npy`` - camera pose and
      intrinsics returned by ``detect_and_set_link_poses``;
    * ``gripper_pos_{ts}.npy`` - mocap position of the
      ``fixed_gripper_exo_mesh`` body;
    * ``gripper_mesh_{ts}.stl`` - Fixed_Jaw mesh transformed by
      ``{ts}_gripper_pose.npy`` (only when that pose file exists);
    * ``robot_mask_{ts}.png`` - rendered robot + supporting board mask;
    * ``pointmap_{ts}_raw.pt`` - MoGe pointmap (reused if present);
    * ``moge_to_robot_frame_{ts}.npy`` - Procrustes transform from ArUco
      correspondences, identity if fewer than 3 (reused if present);
    * ``dino_features_{ts}.pt``, ``dino_features_hw_{ts}.pt`` and
      ``dino_features_vis_{ts}.png`` - PCA-reduced DINO features (reused if
      all three are present).

    ``model_seg`` and ``preprocess`` are accepted for call compatibility but
    are not used by this function.

    Returns:
        True on success; False if the image/joint files are missing, the image
        cannot be decoded, or link-pose detection raises.
    """
    # File paths
    image_path = episode_dir / f"{timestep_str}.png"
    joint_path = episode_dir / f"{timestep_str}.npy"

    if not image_path.exists() or not joint_path.exists():
        return False

    # Load image. cv2.imread returns None (instead of raising) for an
    # unreadable/corrupt file, which would otherwise crash inside cvtColor.
    bgr = cv2.imread(str(image_path))
    if bgr is None:
        print(f"  ⚠ Failed to read image for {timestep_str}: {image_path}")
        return False
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    # Only rescale genuinely normalized float images; a very dark uint8 frame
    # (max <= 1) must not be stretched to 0/255.
    if np.issubdtype(rgb.dtype, np.floating) and rgb.max() <= 1.0:
        rgb = (rgb * 255).astype(np.uint8)

    # Load joint state and pose the MuJoCo model with it
    qpos = np.load(joint_path)
    data = mujoco.MjData(model)
    data.qpos[:] = qpos
    data.ctrl[:] = qpos[:len(data.ctrl)]
    mujoco.mj_forward(model, data)

    # Detect ArUco markers and get camera pose
    try:
        link_poses, camera_pose_world, cam_K, corners_cache, corners_vis, obj_img_pts = detect_and_set_link_poses(
            rgb, model, data, robot_config
        )
    except Exception as e:
        # Broad on purpose: a failed detection skips this frame instead of
        # aborting the whole episode.
        print(f"  ⚠ Failed to detect link poses for {timestep_str}: {e}")
        return False

    # Save camera pose and intrinsics
    np.save(episode_dir / f"robot_camera_pose_{timestep_str}.npy", camera_pose_world)
    np.save(episode_dir / f"cam_K_{timestep_str}.npy", cam_K)

    # Get and save gripper position (mocap position of the exo mesh body)
    exo_mesh_body_name = "fixed_gripper_exo_mesh"
    exo_mesh_body_id = mujoco.mj_name2id(model, mujoco.mjtObj.mjOBJ_BODY, exo_mesh_body_name)
    exo_mesh_mocap_id = model.body_mocapid[exo_mesh_body_id]
    gripper_pos = data.mocap_pos[exo_mesh_mocap_id].copy()
    np.save(episode_dir / f"gripper_pos_{timestep_str}.npy", gripper_pos)

    # Check for gripper pose file and add gripper STL if it exists
    gripper_pose_path = episode_dir / f"{timestep_str}_gripper_pose.npy"
    if gripper_pose_path.exists():
        gripper_pose = np.load(gripper_pose_path)  # 4x4 transformation matrix

        fixed_gripper_stl_path = "robot_models/so100_model/assets/Fixed_Jaw.stl"
        fixed_gripper_mesh = trimesh.load(fixed_gripper_stl_path)
        if isinstance(fixed_gripper_mesh, trimesh.Scene):
            fixed_gripper_mesh = list(fixed_gripper_mesh.geometry.values())[0]

        # An extent above 1 is taken to mean the STL is in millimetres.
        bounds = fixed_gripper_mesh.bounds
        max_extent = np.max(bounds[1] - bounds[0])
        if max_extent > 1.0:
            fixed_gripper_mesh.apply_scale(0.001)

        fixed_gripper_mesh.apply_transform(gripper_pose)

        gripper_mesh_path = episode_dir / f"gripper_mesh_{timestep_str}.stl"
        fixed_gripper_mesh.export(gripper_mesh_path)
        print(f"    ✓ Saved gripper mesh: {gripper_mesh_path.name}")

    # Detect alignment board; its points join the Procrustes correspondences
    board_result = detect_and_position_alignment_board(
        rgb, model, data, ALIGNMENT_BOARD_CONFIG, cam_K, camera_pose_world, corners_cache, visualize=False
    )
    if board_result is not None:
        board_pose, board_pts = board_result
        obj_img_pts["alignment_board"] = board_pts

    position_exoskeleton_meshes(robot_config, model, data, link_poses)
    mujoco.mj_forward(model, data)

    # Render and save robot+supporting board mask.
    # NOTE: reusing an existing mask is deliberately disabled (the original
    # condition was `exists() and 0`), so the mask is always re-rendered.
    robot_mask_path = episode_dir / f"robot_mask_{timestep_str}.png"
    reuse_cached_mask = False
    if reuse_cached_mask and robot_mask_path.exists():
        print(f"    ⏭ Skipping robot mask rendering (robot_mask_{timestep_str}.png exists)")
    else:
        seg = render_from_camera_pose(
            model, data, camera_pose_world, cam_K, *rgb.shape[:2], segmentation=True
        )
        robot_mask = (seg[..., 0] > 0)

        # Pin vmin/vmax: imsave otherwise normalizes to the data's own
        # min/max, which turns an all-foreground mask into an all-black PNG.
        mask_binary = (robot_mask.astype(np.uint8) * 255)
        plt.imsave(robot_mask_path, mask_binary, cmap='gray', vmin=0, vmax=255)
        print(f"    ✓ Saved robot+board mask: {robot_mask_path.name}")

    # Get MoGe pointcloud (skip if already exists)
    H, W = rgb.shape[:2]
    pointmap_raw_path = episode_dir / f"pointmap_{timestep_str}_raw.pt"

    if pointmap_raw_path.exists():
        print(f"    ⏭ Skipping MoGe inference (pointmap_{timestep_str}_raw.pt exists)")
        pointmap_raw = torch.load(pointmap_raw_path)
        points_cam_flat = pointmap_raw["points"].numpy()
        mask_flat = pointmap_raw["mask"].numpy().astype(bool)
        points_cam = points_cam_flat.reshape(H, W, 3)
        mask = mask_flat.reshape(H, W)
    else:
        moge_output = get_moge_pointcloud(rgb, model_moge, device, episode_dir, timestep_str)
        points_cam, mask, moge_edge_mask, H, W = moge_output

        # Save raw pointmap in camera frame
        points_cam_flat = points_cam.reshape(-1, 3)
        mask_flat = mask.reshape(-1)
        colors_flat = rgb.reshape(-1, 3)
        pointmap_raw = {
            "points": torch.from_numpy(points_cam_flat.astype(np.float32)),
            "colors": torch.from_numpy(colors_flat.astype(np.uint8)),
            "mask": torch.from_numpy(mask_flat.astype(bool)),
        }
        torch.save(pointmap_raw, pointmap_raw_path)

    # Compute Procrustes alignment transformation
    moge_to_robot_frame_path = episode_dir / f"moge_to_robot_frame_{timestep_str}.npy"

    if moge_to_robot_frame_path.exists():
        print(f"    ⏭ Skipping Procrustes alignment (moge_to_robot_frame_{timestep_str}.npy exists)")
        T_procrustes = np.load(moge_to_robot_frame_path)
    else:
        # Collect ArUco correspondences. Each 3D corner must stay paired with
        # the MoGe point sampled at ITS pixel, so a pair is kept or dropped
        # as a unit (previously an out-of-image pixel dropped only the MoGe
        # side, misaligning every later pair).
        aruco_corners_robot_frame = []
        moge_aruco_corners = []
        inv_camera_pose = np.linalg.inv(camera_pose_world)

        for obj_name, (obj_img_pts_cam, img_pts_px) in obj_img_pts.items():
            if obj_name not in ("alignment_board", "larger_base"):
                continue

            if len(img_pts_px) != len(obj_img_pts_cam):
                print(f"    ⚠ Skipping {obj_name}: {len(obj_img_pts_cam)} 3D points vs {len(img_pts_px)} pixels")
                continue

            # Transform ArUco corners from camera frame to robot frame
            obj_cam_h = np.hstack([obj_img_pts_cam, np.ones((obj_img_pts_cam.shape[0], 1))])
            obj_robot = (inv_camera_pose @ obj_cam_h.T).T[:, :3]

            # Sample MoGe pointcloud at ArUco pixel coordinates
            for corner_robot, pt in zip(obj_robot, img_pts_px):
                x, y = int(pt[0]), int(pt[1])
                if not (0 <= y < H and 0 <= x < W):
                    continue
                moge_pt = points_cam[y, x]
                # Non-finite MoGe points would poison the fit.
                if not np.all(np.isfinite(moge_pt)):
                    continue
                aruco_corners_robot_frame.append(corner_robot)
                moge_aruco_corners.append(moge_pt)

        aruco_corners_robot_frame = np.array(aruco_corners_robot_frame)
        moge_aruco_corners = np.array(moge_aruco_corners)

        if len(aruco_corners_robot_frame) >= 3 and len(moge_aruco_corners) >= 3:
            T_procrustes, scale, rotation, translation = procrustes_alignment(
                aruco_corners_robot_frame, moge_aruco_corners
            )
            np.save(moge_to_robot_frame_path, T_procrustes)
            print(f"    ✓ Computed Procrustes alignment: {len(aruco_corners_robot_frame)} correspondences")
        else:
            print(f"    ⚠ Insufficient correspondences for Procrustes ({len(aruco_corners_robot_frame)} robot, {len(moge_aruco_corners)} MoGe), using identity")
            T_procrustes = np.eye(4)
            np.save(moge_to_robot_frame_path, T_procrustes)

    # Extract and save DINO features (skip if already exists)
    dino_features_path = episode_dir / f"dino_features_{timestep_str}.pt"
    dino_features_hw_path = episode_dir / f"dino_features_hw_{timestep_str}.pt"
    dino_vis_path = episode_dir / f"dino_features_vis_{timestep_str}.png"

    if dino_features_path.exists() and dino_features_hw_path.exists() and dino_vis_path.exists():
        print(f"    ⏭ Skipping DINO feature extraction (dino_features_{timestep_str}.pt, dino_features_hw_{timestep_str}.pt and dino_features_vis_{timestep_str}.png exist)")
    else:
        # Pixel coordinates of every MoGe-valid point, in row-major order
        # (matching points_cam_flat / mask_flat).
        y_coords, x_coords = np.meshgrid(np.arange(H), np.arange(W), indexing='ij')
        valid_y_coords = y_coords.reshape(-1)[mask_flat]
        valid_x_coords = x_coords.reshape(-1)[mask_flat]

        dino_features, dino_features_hw, pca_features_patches = run_dino_features(
            rgb, model_dino, pca_embedder, device, H, W, valid_y_coords, valid_x_coords
        )

        torch.save(
            torch.from_numpy(dino_features.astype(np.float32)),
            dino_features_path
        )
        torch.save(
            torch.from_numpy(dino_features_hw.astype(np.float32)),
            dino_features_hw_path
        )
        visualize_dino_features(rgb, pca_features_patches, H, W, episode_dir, timestep_str)
        print(f"    ✓ Saved DINO features: {len(dino_features)} points, {dino_features_hw.shape} full resolution")

    return True


def get_moge_pointcloud(rgb_scene, model_moge, device, episode_dir, timestep_str):
    """Run MoGe on a frame (or load its cached raw output) and unpack it.

    The raw ``model_moge.infer`` output is cached at
    ``episode_dir / f"moge_output_{timestep_str}_raw.pt"``.

    Args:
        rgb_scene: (H, W, 3) image; divided by 255 and fed channels-first.
        model_moge: model exposing ``infer`` returning ``"points"``/``"mask"``.
        device: device for the inference input tensor.
        episode_dir: directory holding the cache file.
        timestep_str: timestep identifier used in the cache file name.

    Returns:
        ``(points, mask, moge_edge_mask, H, W)``: the (H, W, 3) pointmap, its
        boolean validity mask, a depth-edge mask computed from ``points[..., 2]``
        (returned separately, NOT applied to ``mask``), and the pointmap size.
    """
    moge_output_path = episode_dir / f"moge_output_{timestep_str}_raw.pt"

    if moge_output_path.exists():
        # map_location="cpu": the cache may hold tensors saved from an MPS
        # device; loading those without that device fails, and only CPU numpy
        # copies are used below anyway. weights_only=False loads the pickled
        # dict this script wrote itself - only point this at trusted files.
        output_moge = torch.load(moge_output_path, map_location="cpu", weights_only=False)
    else:
        input_tensor_moge = torch.tensor(rgb_scene / 255.0, dtype=torch.float32, device=device).permute(2, 0, 1)
        with torch.no_grad():
            output_moge = model_moge.infer(input_tensor_moge)
        torch.save(output_moge, moge_output_path)

    points = output_moge["points"].cpu().numpy()  # (H, W, 3)
    mask = output_moge["mask"].cpu().numpy().astype(bool)

    # Depth-edge mask is returned for callers but intentionally not folded
    # into `mask`, so all MoGe-valid points are kept.
    moge_edge_mask = utils3d.np.depth_map_edge(points[:, :, 2], rtol=0.005)

    H, W = points.shape[:2]
    return points, mask, moge_edge_mask, H, W


def resize_transform(img: Image.Image, image_size: int = IMAGE_SIZE, patch_size: int = PATCH_SIZE) -> torch.Tensor:
    """Resize a PIL image so both sides are whole multiples of ``patch_size``.

    The output height is ``image_size`` rounded down to a patch multiple; the
    width keeps the input aspect ratio and is likewise rounded down. Returns
    the resized image as a float tensor via ``TF.to_tensor``.
    """
    width, height = img.size
    rows = int(image_size / patch_size)
    cols = int((width * image_size) / (height * patch_size))
    target_hw = (rows * patch_size, cols * patch_size)
    resized = TF.resize(img, target_hw)
    return TF.to_tensor(resized)


def run_dino_features(rgb_scene, model_dino, pca_embedder, device, H, W, valid_y_coords, valid_x_coords):
    """Extract PCA-reduced DINO features for given pixels and the full image.

    The frame is resized by ``resize_transform``, normalized with
    ``IMAGENET_MEAN``/``IMAGENET_STD`` and passed through ``model_dino``. The
    last of the ``N_LAYERS`` requested layers is projected with
    ``pca_embedder`` and bilinearly upsampled to ``(H, W)``.

    Args:
        rgb_scene: RGB image array accepted by ``PIL.Image.fromarray``.
        model_dino: model exposing ``get_intermediate_layers``.
        pca_embedder: fitted PCA exposing ``transform`` and ``n_components_``.
        device: device for the DINO forward pass.
        H, W: resolution of the dense output feature map.
        valid_y_coords, valid_x_coords: integer pixel coordinates to sample.

    Returns:
        ``(dino_features, pca_features_upsampled, pca_features_patches)``:
        an ``(N, n_components)`` float64 array whose row i is sampled at
        ``(valid_y_coords[i], valid_x_coords[i])`` (rows whose coordinate lies
        outside the H x W image stay zero), the ``(H, W, n_components)``
        upsampled map, and the ``(h_patches, w_patches, n_components)`` grid.
    """
    # Load and preprocess image for DINO
    img_pil = Image.fromarray(rgb_scene).convert("RGB")
    image_resized = resize_transform(img_pil)
    image_resized_norm = TF.normalize(image_resized, mean=IMAGENET_MEAN, std=IMAGENET_STD)

    # The former torch.autocast(..., dtype=torch.float32) wrapper was a no-op:
    # CPU/MPS autocast only supports reduced-precision dtypes and disables
    # itself (with a warning) for float32, so precision is unchanged.
    with torch.inference_mode():
        feats = model_dino.get_intermediate_layers(
            image_resized_norm.unsqueeze(0).to(device),
            n=range(N_LAYERS),
            reshape=True,
            norm=True,
        )
        # Last requested layer, batch element 0 -> (D, h_patches, w_patches).
        # Index [0] rather than squeeze() so singleton spatial dims survive.
        x = feats[-1][0].detach().cpu()
        dim = x.shape[0]
        # reshape (not view) tolerates a non-contiguous feature tensor.
        x = x.reshape(dim, -1).permute(1, 0).numpy()  # (h_patches * w_patches, D)

        # Apply PCA to reduce dimensionality
        pca_features_all = pca_embedder.transform(x)

        # Get patch resolution
        h_patches, w_patches = [int(d / PATCH_SIZE) for d in image_resized.shape[1:]]
        pca_features_patches = pca_features_all.reshape(h_patches, w_patches, -1)

        # Upsample features to full image resolution
        pca_features_tensor = torch.from_numpy(pca_features_patches).permute(2, 0, 1).float()
        pca_features_upsampled = TF.resize(
            pca_features_tensor,
            (H, W),
            interpolation=TF.InterpolationMode.BILINEAR
        ).permute(1, 2, 0).numpy()  # (H, W, n_components)

    # Vectorized gather of each point's feature at its pixel (replaces a
    # per-pixel Python loop that ran ~H*W iterations per frame).
    ys = np.asarray(valid_y_coords, dtype=np.intp)
    xs = np.asarray(valid_x_coords, dtype=np.intp)
    dino_features = np.zeros((len(ys), pca_embedder.n_components_))
    in_bounds = (ys >= 0) & (ys < H) & (xs >= 0) & (xs < W)
    dino_features[in_bounds] = pca_features_upsampled[ys[in_bounds], xs[in_bounds]]

    return dino_features, pca_features_upsampled, pca_features_patches


def visualize_dino_features(rgb_scene, pca_features_patches, H, W, episode_dir, timestep_str):
    """Save a three-panel figure of the frame, its DINO PCA colours, and a blend.

    The first three PCA channels of the patch grid are squashed into (0, 1)
    with ``sigmoid(2 * value)``, bilinearly upsampled to ``(H, W)`` and shown
    as RGB. The figure is written to
    ``episode_dir / f"dino_features_vis_{timestep_str}.png"``.
    """
    first_three = torch.from_numpy(pca_features_patches[:, :, :3])
    squashed = torch.sigmoid(first_three * 2.0).numpy()

    upsampled_rgb = TF.resize(
        torch.from_numpy(squashed).permute(2, 0, 1).float(),
        (H, W),
        interpolation=TF.InterpolationMode.BILINEAR,
    ).permute(1, 2, 0).numpy()

    # Equal-weight blend of the [0, 1] image and the PCA colours.
    blended = (rgb_scene.astype(float) / 255.0) * 0.5 + upsampled_rgb * 0.5

    panels = (
        (rgb_scene, 'Original Image'),
        (upsampled_rgb, 'DINO PCA (first 3 components)'),
        (blended, 'Overlay: RGB + DINO'),
    )
    fig, axes = plt.subplots(1, len(panels), figsize=(15, 5))
    for ax, (picture, title) in zip(axes, panels):
        ax.imshow(picture)
        ax.set_title(title)
        ax.axis('off')

    plt.tight_layout()
    fig.savefig(episode_dir / f"dino_features_vis_{timestep_str}.png", dpi=150, bbox_inches='tight')
    plt.close(fig)


def main():
    """Command-line entry point: run ``process_timestep`` over every episode.

    Loads the MuJoCo robot + alignment-board model, DeepLabV3, MoGe, DINOv3
    and the pickled PCA embedder once, then processes every raw ``*.png``
    frame of every ``episode_*`` directory under ``--input_dir``.
    """
    parser = argparse.ArgumentParser(description="Process dense dataset from parsed episodes")
    parser.add_argument("--input_dir", type=str, default="scratch/parsed_episodes_cup_synch",
                        help="Input directory with parsed episodes")
    parser.add_argument("--dino_repo", type=str,
                        default="/Users/cameronsmith/Projects/robotics_testing/random/dinov3",
                        help="Local DINOv3 repository passed to torch.hub.load")
    parser.add_argument("--dino_weights", type=str,
                        default="/Users/cameronsmith/Projects/robotics_testing/random/dinov3/weights/dinov3_vits16plus_pretrain_lvd1689m-4057cbaa.pth",
                        help="DINOv3 checkpoint path")
    parser.add_argument("--pca_path", type=str, default="scratch/dino_pca_embedder.pkl",
                        help="Pickled PCA embedder produced by get_dino_pca_emb.py")
    args = parser.parse_args()

    input_dir = Path(args.input_dir)

    # DINOv3 configuration
    REPO_DIR = args.dino_repo
    WEIGHTS_PATH = args.dino_weights

    print("=" * 60)
    print("Finding all episodes")
    print("=" * 60)

    # Find all episode directories
    episode_dirs = sorted(d for d in input_dir.iterdir() if d.is_dir() and d.name.startswith("episode_"))

    print(f"Found {len(episode_dirs)} episodes")
    for ep_dir in episode_dirs:
        print(f"  - {ep_dir.name}")

    # Setup robot config and models (load once, reuse)
    SO100AdhesiveConfig.exo_alpha = 0.2
    SO100AdhesiveConfig.aruco_alpha = 0.2
    robot_config = SO100AdhesiveConfig()
    combined_xml = combine_xmls(robot_config.xml, ALIGNMENT_BOARD_CONFIG.get_xml_addition())
    model = mujoco.MjModel.from_xml_string(combined_xml)

    # Load models once
    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
    print(f"\nUsing device: {device}")

    print("Loading DeepLabV3 segmentation model...")
    # `pretrained=True` is deprecated since torchvision 0.13; DEFAULT selects
    # the same COCO_WITH_VOC_LABELS_V1 weights.
    # NOTE(review): process_timestep does not currently use model_seg.
    seg_weights = torchvision.models.segmentation.DeepLabV3_ResNet101_Weights.DEFAULT
    model_seg = torchvision.models.segmentation.deeplabv3_resnet101(weights=seg_weights)
    model_seg.eval()
    model_seg = model_seg.to(device)

    print("Loading MoGE model...")
    model_moge = MoGeModel.from_pretrained("Ruicheng/moge-2-vitl-normal").to(device)
    model_moge.eval()

    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ])

    # Load DINOv3 model and PCA embedder
    print("Loading DINOv3 model...")
    with torch.inference_mode():
        model_dino = torch.hub.load(REPO_DIR, 'dinov3_vits16plus', source='local', weights=WEIGHTS_PATH).to(device)
        model_dino.eval()

    pca_path = args.pca_path
    if not os.path.exists(pca_path):
        raise FileNotFoundError(f"PCA embedder not found: {pca_path}. Please run get_dino_pca_emb.py first.")

    print(f"Loading PCA embedder from {pca_path}")
    # pickle can execute arbitrary code on load: only use trusted files here.
    with open(pca_path, 'rb') as f:
        pca_data = pickle.load(f)
    pca_embedder = pca_data['pca']
    print(f"PCA embedder: {pca_embedder.n_components_} dimensions")

    print("\n" + "=" * 60)
    print("Processing episodes")
    print("=" * 60)

    # process_timestep writes these PNGs into the episode directory itself;
    # they are outputs, not raw frames, and must not be picked up as
    # timesteps on a re-run (they would all be counted as failures).
    output_png_prefixes = ("robot_mask_", "dino_features_vis_")

    for episode_dir in tqdm(episode_dirs, desc="Processing episodes"):
        print(f"\n{'='*60}")
        print(f"Processing episode: {episode_dir.name}")
        print(f"{'='*60}")

        image_files = sorted(
            p for p in episode_dir.glob("*.png")
            if not p.name.startswith(output_png_prefixes)
        )

        if not image_files:
            print(f"  ⚠ No images found in {episode_dir.name}, skipping")
            continue

        timesteps_processed = 0
        timesteps_failed = 0

        for img_path in tqdm(image_files, desc=f"  {episode_dir.name}", leave=False):
            success = process_timestep(
                episode_dir,
                img_path.stem,
                model,
                robot_config,
                device,
                model_seg,
                model_moge,
                preprocess,
                model_dino,
                pca_embedder
            )

            if success:
                timesteps_processed += 1
            else:
                timesteps_failed += 1

        print(f"  ✓ Processed {timesteps_processed} timesteps, {timesteps_failed} failed")

    print(f"\n{'='*60}")
    print(f"Done! Processed {len(episode_dirs)} episodes")
    print(f"Input directory: {input_dir}")
    print(f"{'='*60}")


if __name__ == "__main__":
    main()

