"""Process dense dataset: process all timesteps in each episode."""
import sys
import os
sys.path.append("/Users/cameronsmith/Projects/robotics_testing/random/vggt")
sys.path.append("/Users/cameronsmith/Projects/robotics_testing/random/dinov3")
sys.path.append("Demos")
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))

import cv2
import numpy as np
import matplotlib.pyplot as plt
import torch
import torchvision
from torchvision import transforms
from torchvision.transforms import functional as TF
from PIL import Image
from tqdm import tqdm
from scipy.spatial.transform import Rotation as R
import argparse
import pickle
from pathlib import Path

import mujoco
import trimesh


from ExoConfigs.so100_adhesive import SO100AdhesiveConfig
from ExoConfigs.alignment_board import ALIGNMENT_BOARD_CONFIG
from exo_utils import (
    detect_and_set_link_poses,
    estimate_robot_state,
    position_exoskeleton_meshes,
    render_from_camera_pose,
    detect_and_position_alignment_board,
    combine_xmls,
    get_link_poses_from_robot,
)

# Configuration constants
PATCH_SIZE = 16  # ViT patch side length in pixels (DINOv3 ViT-S/16+)
IMAGE_SIZE = 768  # target image height (pixels) before DINO feature extraction
IMAGENET_MEAN = (0.485, 0.456, 0.406)  # per-channel RGB normalization mean
IMAGENET_STD = (0.229, 0.224, 0.225)  # per-channel RGB normalization std
N_LAYERS = 12  # number of backbone layers queried via get_intermediate_layers


def process_timestep(episode_dir, timestep_str, model, robot_config, device, model_seg, preprocess, model_dino, pca_embedder):
    """Process a single timestep of one episode.

    Loads the RGB frame and joint state, localizes the robot via ArUco
    markers, saves camera pose/intrinsics, gripper position and SE3 pose,
    renders a robot+board segmentation mask, and extracts DINO PCA features.
    All outputs are written alongside the inputs in `episode_dir`.

    Args:
        episode_dir: Path to the episode directory containing
            "<timestep>.png" and "<timestep>.npy" files.
        timestep_str: Timestep identifier (file stem) to process.
        model: MuJoCo MjModel of the combined robot + alignment-board scene.
        robot_config: Robot configuration (e.g. an SO100AdhesiveConfig).
        device: torch.device used for DINO inference.
        model_seg: DeepLabV3 segmentation model. Currently unused here;
            kept in the signature for caller compatibility.
        preprocess: torchvision transform for model_seg. Currently unused.
        model_dino: DINOv3 backbone used for feature extraction.
        pca_embedder: Fitted PCA used to reduce DINO features.

    Returns:
        True on success, False when an input file is missing or marker
        detection fails.
    """
    # Input file paths for this timestep.
    image_path = episode_dir / f"{timestep_str}.png"
    joint_path = episode_dir / f"{timestep_str}.npy"

    # BUGFIX: cv2.imread returns None on a missing/unreadable file; fail
    # gracefully like the other missing-input paths instead of crashing
    # inside cvtColor.
    bgr = cv2.imread(str(image_path))
    if bgr is None:
        print(f"  ⚠ No image file found for timestep {timestep_str}")
        return False
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    if rgb.max() <= 1.0:  # defensive: normalize float-style images to uint8
        rgb = (rgb * 255).astype(np.uint8)

    # Load joint state and run forward kinematics.
    data = mujoco.MjData(model)
    if not joint_path.exists():
        print(f"  ⚠ No joint state file found for timestep {timestep_str}")
        return False

    qpos = np.load(joint_path)
    data.qpos[:] = qpos
    data.ctrl[:] = qpos[:len(data.ctrl)]
    mujoco.mj_forward(model, data)

    # Save raw joint state.
    joint_state_path = episode_dir / f"joint_state_{timestep_str}.npy"
    np.save(joint_state_path, qpos)

    # Detect ArUco markers and recover camera pose + per-link poses.
    try:
        link_poses, camera_pose_world, cam_K, corners_cache, corners_vis, obj_img_pts = detect_and_set_link_poses(
            rgb, model, data, robot_config
        )
    except Exception as e:
        print(f"  ⚠ Failed to detect link poses for {timestep_str}: {e}")
        return False

    # Save camera extrinsics (world pose) and intrinsics.
    np.save(episode_dir / f"robot_camera_pose_{timestep_str}.npy", camera_pose_world)
    np.save(episode_dir / f"cam_K_{timestep_str}.npy", cam_K)

    # Save gripper position taken from the exoskeleton mocap body.
    exo_mesh_body_name = "fixed_gripper_exo_mesh"
    exo_mesh_body_id = mujoco.mj_name2id(model, mujoco.mjtObj.mjOBJ_BODY, exo_mesh_body_name)
    exo_mesh_mocap_id = model.body_mocapid[exo_mesh_body_id]
    gripper_pos = data.mocap_pos[exo_mesh_mocap_id].copy()
    np.save(episode_dir / f"gripper_pos_{timestep_str}.npy", gripper_pos)

    # Detect the alignment board; record its image points when found.
    board_result = detect_and_position_alignment_board(
        rgb, model, data, ALIGNMENT_BOARD_CONFIG, cam_K, camera_pose_world, corners_cache, visualize=False
    )
    if board_result is not None:
        board_pose, board_pts = board_result
        obj_img_pts["alignment_board"] = board_pts

    # Position exoskeleton meshes and refresh forward kinematics.
    position_exoskeleton_meshes(robot_config, model, data, link_poses)
    mujoco.mj_forward(model, data)

    # Build and save the full SE3 gripper pose from the Fixed_Jaw body
    # (the actual gripper body).
    gripper_body_id = mujoco.mj_name2id(model, mujoco.mjtObj.mjOBJ_BODY, "Fixed_Jaw")
    gripper_pos = data.xpos[gripper_body_id].copy()
    gripper_quat_wxyz = data.xquat[gripper_body_id].copy()  # MuJoCo uses wxyz
    # scipy expects xyzw quaternion order.
    gripper_rot = R.from_quat(gripper_quat_wxyz[[1, 2, 3, 0]]).as_matrix()

    gripper_pose = np.eye(4)
    gripper_pose[:3, :3] = gripper_rot
    gripper_pose[:3, 3] = gripper_pos

    gripper_pose_path = episode_dir / f"{timestep_str}_gripper_pose.npy"
    np.save(gripper_pose_path, gripper_pose)
    print(f"    ✓ Saved gripper pose: {gripper_pose_path.name}")

    # Render and save robot+supporting-board mask.
    # BUGFIX: the original condition had a debug leftover ("and 0") that made
    # the skip branch unreachable, so the mask was re-rendered on every run.
    robot_mask_path = episode_dir / f"robot_mask_{timestep_str}.png"
    if robot_mask_path.exists():
        print(f"    ⏭ Skipping robot mask rendering (robot_mask_{timestep_str}.png exists)")
    else:
        # Render segmentation mask (robot + supporting board).
        seg = render_from_camera_pose(
            model, data, camera_pose_world, cam_K, *rgb.shape[:2], segmentation=True
        )
        robot_mask = (seg[..., 0] > 0)

        # Save binary mask as PNG.
        mask_binary = (robot_mask.astype(np.uint8) * 255)
        plt.imsave(robot_mask_path, mask_binary, cmap='gray')
        print(f"    ✓ Saved robot+board mask: {robot_mask_path.name}")

    H, W = rgb.shape[:2]

    # Extract and save DINO features (skip when outputs already exist).
    # BUGFIX: the original also built a dino_features_hw_*.pt path and
    # mentioned it in the skip message, but that file was never written;
    # the unused path and the misleading message text are removed.
    dino_features_path = episode_dir / f"dino_features_{timestep_str}.pt"
    dino_vis_path = episode_dir / f"dino_features_vis_{timestep_str}.png"

    if dino_features_path.exists() and dino_vis_path.exists():
        print(f"    ⏭ Skipping DINO feature extraction (dino_features_{timestep_str}.pt and dino_features_vis_{timestep_str}.png exist)")
    else:
        # Pixel coordinates for every pixel. NOTE: run_dino_features currently
        # ignores these arguments; they are kept for interface parity.
        y_coords, x_coords = np.meshgrid(np.arange(H), np.arange(W), indexing='ij')
        valid_y_coords = y_coords.reshape(-1)
        valid_x_coords = x_coords.reshape(-1)

        dino_features, dino_features_hw, pca_features_patches = run_dino_features(
            rgb, model_dino, pca_embedder, device, H, W, valid_y_coords, valid_x_coords
        )

        torch.save(
            torch.from_numpy(dino_features.astype(np.float32)),
            dino_features_path
        )

        visualize_dino_features(rgb, pca_features_patches, H, W, episode_dir, timestep_str)
        print(f"    ✓ Saved DINO features: {len(dino_features)} points, {dino_features_hw.shape} full resolution")

    return True


def resize_transform(img: Image.Image, image_size: int = IMAGE_SIZE, patch_size: int = PATCH_SIZE) -> torch.Tensor:
    """Resize *img* so both sides are exact multiples of *patch_size*.

    The height is scaled to roughly ``image_size`` and the width follows the
    original aspect ratio; both are truncated to whole patch counts. Returns
    the resized image converted to a CxHxW float tensor.
    """
    width, height = img.size
    rows = int(image_size / patch_size)
    cols = int((width * image_size) / (height * patch_size))
    resized = TF.resize(img, (rows * patch_size, cols * patch_size))
    return TF.to_tensor(resized)


def run_dino_features(rgb_scene, model_dino, pca_embedder, device, H, W, valid_y_coords, valid_x_coords):
    """Extract PCA-reduced DINO features for one RGB frame.

    Args:
        rgb_scene: HxWx3 uint8 RGB image.
        model_dino: DINOv3 backbone exposing ``get_intermediate_layers``.
        pca_embedder: fitted PCA-like object with ``transform``; the inline
            comments assume it reduces features to 32 dims — actual size
            depends on how the embedder was fitted.
        device: torch.device used for inference.
        H, W: full-resolution image height and width used for upsampling.
        valid_y_coords, valid_x_coords: pixel coordinates of points of
            interest. NOTE(review): currently unused — the per-point sampling
            step was never implemented (see trailing comment below).

    Returns:
        Tuple of:
        - patch-grid PCA features as a (C, H_patches, W_patches) ndarray,
        - features bilinearly upsampled to an (H, W, C) ndarray,
        - patch-grid PCA features as an (H_patches, W_patches, C) ndarray.
    """
    # Load and preprocess image for DINO: resize to patch-multiple dims,
    # then ImageNet-normalize.
    img_pil = Image.fromarray(rgb_scene).convert("RGB")
    image_resized = resize_transform(img_pil)
    image_resized_norm = TF.normalize(image_resized, mean=IMAGENET_MEAN, std=IMAGENET_STD)
    
    # Extract DINO features without tracking gradients.
    # NOTE(review): autocast with dtype=torch.float32 is effectively a no-op
    # (float32 is already the default compute dtype) — confirm whether
    # float16 was intended for the MPS path.
    with torch.inference_mode():
        with torch.autocast(device_type='mps' if device.type == 'mps' else 'cpu', dtype=torch.float32):
            # Features from the first N_LAYERS blocks; only the last is kept.
            feats = model_dino.get_intermediate_layers(
                image_resized_norm.unsqueeze(0).to(device),
                n=range(N_LAYERS),
                reshape=True,
                norm=True
            )
            x = feats[-1].squeeze().detach().cpu()  # (D, H_patches, W_patches)
            dim = x.shape[0]
            x = x.view(dim, -1).permute(1, 0).numpy()  # (H_patches * W_patches, D)
            
            # Apply PCA to reduce to 32 dimensions
            pca_features_all = pca_embedder.transform(x)  # (H_patches * W_patches, 32)
            
            # Get patch resolution from the resized image dimensions.
            h_patches, w_patches = [int(d / PATCH_SIZE) for d in image_resized.shape[1:]]
            pca_features_patches = pca_features_all.reshape(h_patches, w_patches, -1)  # (H_patches, W_patches, 32)
            
            # Upsample features to full image resolution just for visualization
            pca_features_tensor = torch.from_numpy(pca_features_patches).permute(2, 0, 1).float()  # (32, H_patches, W_patches)
            pca_features_upsampled = TF.resize(
                pca_features_tensor,
                (H, W),
                interpolation=TF.InterpolationMode.BILINEAR
            ).permute(1, 2, 0).numpy()  # (H, W, 32)
            
            # Sample features for each point using its pixel coordinate
            # NOTE(review): the sampling step announced above was never
            # implemented; valid_y_coords / valid_x_coords go unused.
    
    return pca_features_tensor.numpy(), pca_features_upsampled, pca_features_patches


def visualize_dino_features(rgb_scene, pca_features_patches, H, W, episode_dir, timestep_str):
    """Save a 3-panel figure: original image, DINO PCA pseudo-color, overlay.

    The first three PCA components are squashed into [0, 1] with a scaled
    sigmoid, upsampled to full resolution, and blended 50/50 with the image.
    The figure is written to ``dino_features_vis_<timestep>.png``.
    """
    # Map first three PCA components to pseudo-RGB in [0, 1].
    components = pca_features_patches[:, :, :3]
    squashed = torch.nn.functional.sigmoid(torch.from_numpy(components).mul(2.0)).numpy()
    upsampled = TF.resize(
        torch.from_numpy(squashed).permute(2, 0, 1).float(),
        (H, W),
        interpolation=TF.InterpolationMode.BILINEAR
    ).permute(1, 2, 0).numpy()

    # 50/50 blend of the normalized image with the PCA colors.
    blended = rgb_scene.astype(float) / 255.0 * 0.5 + upsampled * 0.5

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    panels = [
        (rgb_scene, 'Original Image'),
        (upsampled, 'DINO PCA (first 3 components)'),
        (blended, 'Overlay: RGB + DINO'),
    ]
    for ax, (panel_img, panel_title) in zip(axes, panels):
        ax.imshow(panel_img)
        ax.set_title(panel_title)
        ax.axis('off')

    plt.tight_layout()
    plt.savefig(episode_dir / f"dino_features_vis_{timestep_str}.png", dpi=150, bbox_inches='tight')
    plt.close()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Process dense dataset from parsed episodes")
    parser.add_argument("--input_dir", "-i", type=str, required=True,
                        help="Input directory with parsed episodes")
    args = parser.parse_args()

    input_dir = Path(args.input_dir)

    # DINOv3 configuration.
    # NOTE(review): machine-specific absolute paths — consider promoting
    # these to CLI arguments before running elsewhere.
    REPO_DIR = "/Users/cameronsmith/Projects/robotics_testing/random/dinov3"
    WEIGHTS_PATH = "/Users/cameronsmith/Projects/robotics_testing/random/dinov3/weights/dinov3_vits16plus_pretrain_lvd1689m-4057cbaa.pth"

    print("=" * 60)
    print("Finding all episodes")
    print("=" * 60)

    # Each episode lives in its own "episode_*" subdirectory of input_dir.
    episode_dirs = sorted([d for d in input_dir.iterdir() if d.is_dir() and d.name.startswith("episode_")])

    print(f"Found {len(episode_dirs)} episodes")
    for ep_dir in episode_dirs:
        print(f"  - {ep_dir.name}")

    # Setup robot config and models (loaded once, reused across episodes).
    # Alpha values make the exoskeleton/ArUco overlays semi-transparent.
    SO100AdhesiveConfig.exo_alpha = 0.2
    SO100AdhesiveConfig.aruco_alpha = 0.2
    robot_config = SO100AdhesiveConfig()
    # Merge robot and alignment-board XML into one MuJoCo scene.
    combined_xml = combine_xmls(robot_config.xml, ALIGNMENT_BOARD_CONFIG.get_xml_addition())
    model = mujoco.MjModel.from_xml_string(combined_xml)

    # Prefer Apple-silicon GPU (MPS) when available, otherwise CPU.
    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
    print(f"\nUsing device: {device}")

    # NOTE(review): model_seg and preprocess are loaded and passed to
    # process_timestep, but that function never uses them — dead weight
    # unless segmentation is re-enabled later.
    # NOTE(review): `pretrained=True` is deprecated in recent torchvision in
    # favor of the `weights=` argument — confirm the installed version.
    print("Loading DeepLabV3 segmentation model...")
    model_seg = torchvision.models.segmentation.deeplabv3_resnet101(pretrained=True)
    model_seg.eval()
    model_seg = model_seg.to(device)

    # Standard ImageNet normalization for the segmentation model input.
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # Load DINOv3 model from the local repo checkout.
    print("Loading DINOv3 model...")
    with torch.inference_mode():
        model_dino = torch.hub.load(REPO_DIR, 'dinov3_vits16plus', source='local', weights=WEIGHTS_PATH).to(device)
        model_dino.eval()

    # The PCA embedder must have been fitted beforehand by get_dino_pca_emb.py.
    pca_path = "scratch/dino_pca_embedder.pkl"
    if not os.path.exists(pca_path):
        raise FileNotFoundError(f"PCA embedder not found: {pca_path}. Please run get_dino_pca_emb.py first.")

    print(f"Loading PCA embedder from {pca_path}")
    # NOTE(review): pickle.load on a local artifact is fine, but never
    # unpickle untrusted files.
    with open(pca_path, 'rb') as f:
        pca_data = pickle.load(f)
        pca_embedder = pca_data['pca']
        print(f"PCA embedder: {pca_embedder.n_components_} dimensions")

    print("\n" + "=" * 60)
    print("Processing episodes")
    print("=" * 60)

    # Process each episode.
    for episode_dir in tqdm(episode_dirs, desc="Processing episodes"):
        print(f"\n{'='*60}")
        print(f"Processing episode: {episode_dir.name}")
        print(f"{'='*60}")

        # Find timestep images in this episode.
        # BUGFIX: a bare "*.png" glob also matches PNGs generated by previous
        # runs (robot masks and DINO feature visualizations), which then get
        # fed back in as fake timesteps and inflate the failure count.
        generated_prefixes = ("robot_mask_", "dino_features_vis_")
        image_files = sorted(
            p for p in episode_dir.glob("*.png")
            if not p.stem.startswith(generated_prefixes)
        )

        if len(image_files) == 0:
            print(f"  ⚠ No images found in {episode_dir.name}, skipping")
            continue

        # Process each timestep.
        timesteps_processed = 0
        timesteps_failed = 0

        for img_path in tqdm(image_files, desc=f"  {episode_dir.name}", leave=False):
            timestep_str = img_path.stem

            success = process_timestep(
                episode_dir,
                timestep_str,
                model,
                robot_config,
                device,
                model_seg,
                preprocess,
                model_dino,
                pca_embedder
            )

            if success:
                timesteps_processed += 1
            else:
                timesteps_failed += 1

        print(f"  ✓ Processed {timesteps_processed} timesteps, {timesteps_failed} failed")

    print(f"\n{'='*60}")
    print(f"Done! Processed {len(episode_dirs)} episodes")
    print(f"Input directory: {input_dir}")
    print(f"{'='*60}")
