"""Dataset for real dense trajectory prediction."""
import torch
import numpy as np
from torch.utils.data import Dataset
from pathlib import Path
import matplotlib.pyplot as plt
import cv2
import math

# Number of waypoints in each predicted trajectory window (one per timestep)
N_WINDOW = 12

# Hard-coded gripper joint range used as the regression target range
MIN_GRIPPER = -0.2
MAX_GRIPPER = 0.8

def process_gripper_value(gripper_value):
    """Process a raw gripper joint value: fix wraparound and clamp to range.

    Readings above 4.0 are treated as angle wraparound (e.g. ~2*pi - 0.2
    shows up as ~6.1) and mapped to the closed position.

    Args:
        gripper_value: scalar gripper joint value

    Returns:
        Processed value in [MIN_GRIPPER, MAX_GRIPPER] = [-0.2, 0.8]:
        1. Values > 4.0 map to MIN_GRIPPER (wraparound fix)
        2. All other values are clamped into [MIN_GRIPPER, MAX_GRIPPER]
    """
    if gripper_value > 4.0:
        # Wrapped-around reading: treat as fully closed.
        return MIN_GRIPPER
    # Clamp to the valid regression range (handles values < -0.2 or > 0.8).
    return max(MIN_GRIPPER, min(MAX_GRIPPER, gripper_value))

# Commented out binarization function (kept for reference)
# def binarize_gripper(gripper_value):
#     """Binarize gripper value to open (1.0) or closed (0.0).
#     
#     Args:
#         gripper_value: scalar gripper joint value
#     
#     Returns:
#         Binary gripper value: 1.0 for open, 0.0 for closed
#         Closed if joint < 0.1 or joint > 4.0 (wrapped values), else open
#     """
#     # Fix wrapping: if value > 4, it's wrapped (2π ≈ 6.28, so values > 4 are likely wrapped)
#     #if gripper_value > 4.0:
#     #    # Convert: 2π - value (so 6.1 becomes ~-0.2)
#     #    gripper_value = 2 * math.pi - gripper_value
#     
#     # Binarize: closed if < 0.1 or > 4.0 (after wrapping fix, > 4.0 shouldn't happen, but check anyway)
#     if gripper_value < 0.1 or gripper_value > 4.0:
#         return 0.0  # Closed
#     else:
#         return 1.0  # Open

# Candidate virtual gripper keypoints expressed in the local gripper frame.
# Values are in millimeters and divided by 1000 to get meters.
KEYPOINTS_LOCAL_M_ALL = np.array([[13.25, -91.42, 15.9], [10.77, -99.6, 0], [13.25, -91.42, -15.9], 
                                   [17.96, -83.96, 0], [22.86, -70.46, 0]]) / 1000.0
KP_INDEX = 3  # Using index 3
# (3,) selected keypoint (meters, local gripper frame) used throughout the dataset
kp_local = KEYPOINTS_LOCAL_M_ALL[KP_INDEX]

def project_3d_to_2d(point_3d, camera_pose, cam_K):
    """Project a 3D world point into 2D pixel coordinates.

    Args:
        point_3d: (3,) 3D point in world coordinates
        camera_pose: (4, 4) camera pose matrix (world-to-camera)
        cam_K: (3, 3) camera intrinsics

    Returns:
        (2,) 2D pixel coordinates [x, y], or None if behind camera
    """
    # Homogeneous world point transformed into the camera frame.
    homogeneous = np.concatenate([point_3d, [1.0]])
    cam_coords = camera_pose @ homogeneous

    # Points at or behind the image plane cannot be projected.
    depth = cam_coords[2]
    if depth <= 0:
        return None

    # Pinhole projection followed by the perspective divide.
    pixel_h = cam_K @ cam_coords[:3]
    return pixel_h[:2] / pixel_h[2]

class RealTrajectoryDataset(Dataset):
    """Dataset for real dense trajectory prediction.

    Each sample contains:
        - RGB image (from any frame)
        - Dense 2D trajectory (N_WINDOW waypoints starting from that frame)
        - Camera parameters (pose and intrinsics)
        - Ground truth heatmaps for each timestep

    For each episode, creates samples from every frame, with trajectories
    padded with the last observed keypoint if there aren't enough subsequent frames.

    Expected per-frame files in each episode directory (FFFFFF = zero-padded
    frame index):
        FFFFFF.png               RGB image
        FFFFFF.npy               joint-state vector; last entry is the gripper joint
        FFFFFF_gripper_pose.npy  (4, 4) gripper pose in the world frame
        FFFFFF_cam_K.npy         (3, 3) intrinsics normalized by image width/height
        FFFFFF_camera_pose.npy   (4, 4) camera pose (applied as world-to-camera
                                 by project_3d_to_2d)
    """
    
    def __init__(self, dataset_root="scratch/", image_size=448, episode: str | None = None, max_episodes: int | None = None):
        """Initialize dataset.

        Args:
            dataset_root: Root directory containing episodes
            image_size: Size to resize images to (will be square, default 448)
            episode: Optional episode directory name (e.g. "episode_001") to load only.
            max_episodes: Optional limit on how many episodes to load (after sorting).

        Raises:
            ValueError: If the root directory does not exist, the requested
                episode is missing, or no episode directories are found.
        """
        self.dataset_root = Path(dataset_root)
        self.image_size = image_size
        self.n_window = N_WINDOW  # number of trajectory waypoints per sample
        
        # Load episode directories
        if not self.dataset_root.exists():
            raise ValueError(f"Dataset directory not found: {self.dataset_root}")
        
        # Any subdirectory with "episode" in its name counts as an episode.
        episode_dirs = sorted([d for d in self.dataset_root.iterdir()
                             if d.is_dir() and "episode" in d.name])

        if episode is not None:
            # Restrict to a single named episode when requested.
            episode_dirs = [d for d in episode_dirs if d.name == episode]
            if len(episode_dirs) == 0:
                raise ValueError(f"Episode '{episode}' not found in {self.dataset_root}")

        if max_episodes is not None:
            episode_dirs = episode_dirs[:max_episodes]
        
        if len(episode_dirs) == 0:
            raise ValueError(f"No episodes found in {self.dataset_root}")
        
        # Build list of (episode_dir, frame_idx) pairs
        self.samples = []
        for episode_dir in episode_dirs:
            # Find all frame images; only purely numeric stems are frames
            # (skips auxiliary files like FFFFFF_gripper_pose.npy).
            frame_files = sorted([f for f in episode_dir.glob("*.png") if f.stem.isdigit()])
            
            # Create a sample for each frame
            for frame_file in frame_files:
                frame_idx = int(frame_file.stem)
                self.samples.append((episode_dir, frame_idx))
        
        print(f"Loaded {len(episode_dirs)} episodes")
        print(f"Created {len(self.samples)} samples (one per frame)")
    
    def __len__(self) -> int:
        # One sample per (episode, frame) pair.
        return len(self.samples)
    
    def __getitem__(self, idx: int):
        """Get a single sample.

        Args:
            idx: Index into self.samples, i.e. one (episode, frame) pair.

        Returns:
            dict with keys:
                - rgb: (3, H, W) normalized RGB image
                - heatmap_target: (N_WINDOW, H, W) one-hot heatmaps for each timestep
                - trajectory_2d: (N_WINDOW, 2) 2D pixel locations for each timestep
                - trajectory_3d: (N_WINDOW, 3) 3D world positions for each timestep
                - trajectory_gripper: (N_WINDOW,) gripper values for each timestep
                - target_3d: (3,) final target 3D world position (last waypoint)
                - camera_pose: (4, 4) camera pose matrix (from first frame)
                - cam_K_norm: (3, 3) normalized intrinsics (from first frame)
                plus string/int metadata ('episode_id', 'episode_name',
                'frame_idx', 'episode_dir') and two optional tensors
                ('joint_state', 'episode_start_joint_state') that are None
                when the corresponding .npy files are missing or unreadable.

        Raises:
            ValueError: If no 2D waypoint can be computed for this frame.
        """
        episode_dir, frame_idx = self.samples[idx]
        frame_str = f"{frame_idx:06d}"
        
        # Load RGB image for this frame
        rgb_path = episode_dir / f"{frame_str}.png"
        rgb = plt.imread(rgb_path)[..., :3]  # (H, W, 3); drops alpha channel if present
        H_orig, W_orig = rgb.shape[:2]
        
        # Find all subsequent frames
        # NOTE(review): this glob is per-episode and identical on every
        # __getitem__ call; the frame list could be cached in __init__.
        frame_files = sorted([f for f in episode_dir.glob("*.png") if f.stem.isdigit()])
        frame_indices = [int(f.stem) for f in frame_files]
        start_frame_idx = frame_indices.index(frame_idx)
        
        # Build trajectory from subsequent frames
        trajectory_2d = []
        trajectory_3d = []
        trajectory_gripper = []  # Gripper values for each timestep
        
        # Process up to N_WINDOW frames starting from current frame.
        # Any missing file or unprojectable keypoint stops the loop early;
        # the partial trajectory is padded to n_window below.
        for i in range(self.n_window):
            if start_frame_idx + i >= len(frame_indices):
                break
            
            next_frame_idx = frame_indices[start_frame_idx + i]
            next_frame_str = f"{next_frame_idx:06d}"
            
            # Load gripper pose
            gripper_pose_path = episode_dir / f"{next_frame_str}_gripper_pose.npy"
            if not gripper_pose_path.exists():
                break
            
            gripper_pose = np.load(gripper_pose_path)  # (4, 4)
            gripper_rot = gripper_pose[:3, :3]
            gripper_pos = gripper_pose[:3, 3]
            
            # Compute 3D keypoint from gripper pose (local keypoint -> world frame)
            kp_3d = gripper_rot @ kp_local + gripper_pos
            trajectory_3d.append(kp_3d)
            
            # Load gripper value from joint state file (last dimension)
            joint_state_path = episode_dir / f"{next_frame_str}.npy"
            if joint_state_path.exists():
                joint_state = np.load(joint_state_path)
                gripper_value = float(joint_state[-1])  # Last value is gripper
                gripper_value = process_gripper_value(gripper_value)  # Process: map >4.0 to -0.2, clamp to [-0.2, 0.8]
                trajectory_gripper.append(gripper_value)
            else:
                # Default to 1.0 (open) if joint state not found
                trajectory_gripper.append(1.0)
            
            # Load camera parameters (use first frame's camera params for consistency)
            # NOTE(review): these paths are keyed on frame_str (the current
            # frame), so they are loop-invariant yet re-loaded from disk on
            # every iteration; they could be loaded once before the loop.
            cam_K_norm_path = episode_dir / f"{frame_str}_cam_K.npy"
            camera_pose_path = episode_dir / f"{frame_str}_camera_pose.npy"
            
            if not (cam_K_norm_path.exists() and camera_pose_path.exists()):
                break
            
            cam_K_norm = np.load(cam_K_norm_path)  # (3, 3) normalized
            camera_pose = np.load(camera_pose_path)  # (4, 4)
            
            # Scale intrinsics to image resolution
            cam_K = cam_K_norm.copy()
            cam_K[0] *= W_orig  # fx
            cam_K[1] *= H_orig  # fy
            
            # Project 3D keypoint to 2D
            kp_2d = project_3d_to_2d(kp_3d, camera_pose, cam_K)
            if kp_2d is None:
                # Keypoint behind the camera: stop extending the trajectory.
                break
            
            trajectory_2d.append(kp_2d)
        
        if len(trajectory_2d) == 0:
            # Fallback: use current frame only (the loop above produced no
            # waypoints, e.g. due to missing files on the first iteration).
            gripper_pose = np.load(episode_dir / f"{frame_str}_gripper_pose.npy")
            gripper_rot = gripper_pose[:3, :3]
            gripper_pos = gripper_pose[:3, 3]
            kp_3d = gripper_rot @ kp_local + gripper_pos
            trajectory_3d = [kp_3d]
            
            # Load gripper value for current frame
            joint_state_path = episode_dir / f"{frame_str}.npy"
            if joint_state_path.exists():
                joint_state = np.load(joint_state_path)
                gripper_value = float(joint_state[-1])
                gripper_value = process_gripper_value(gripper_value)  # Fix wrapping and clamp
                trajectory_gripper = [gripper_value]
            else:
                trajectory_gripper = [1.0]
            
            cam_K_norm = np.load(episode_dir / f"{frame_str}_cam_K.npy")
            camera_pose = np.load(episode_dir / f"{frame_str}_camera_pose.npy")
            cam_K = cam_K_norm.copy()
            cam_K[0] *= W_orig
            cam_K[1] *= H_orig
            kp_2d = project_3d_to_2d(kp_3d, camera_pose, cam_K)
            if kp_2d is not None:
                trajectory_2d = [kp_2d]
        
        if len(trajectory_2d) == 0:
            raise ValueError(f"Could not compute trajectory for {episode_dir.name} frame {frame_str}")
        
        trajectory_2d = np.array(trajectory_2d)  # (N, 2)
        trajectory_3d = np.array(trajectory_3d)  # (N, 3)
        trajectory_gripper = np.array(trajectory_gripper)  # (N,)
        
        # Pad trajectory with last observed keypoint if needed
        if len(trajectory_2d) < self.n_window:
            # Pad with last point (repeats the final waypoint and gripper value)
            last_point_2d = trajectory_2d[-1:]  # (1, 2)
            last_point_3d = trajectory_3d[-1:]  # (1, 3)
            last_gripper = trajectory_gripper[-1:]  # (1,)
            n_pad = self.n_window - len(trajectory_2d)
            trajectory_2d = np.concatenate([trajectory_2d, np.tile(last_point_2d, (n_pad, 1))], axis=0)
            trajectory_3d = np.concatenate([trajectory_3d, np.tile(last_point_3d, (n_pad, 1))], axis=0)
            trajectory_gripper = np.concatenate([trajectory_gripper, np.tile(last_gripper, (n_pad,))], axis=0)
        elif len(trajectory_2d) > self.n_window:
            # Truncate to first N_WINDOW points (defensive: the loop above
            # appends at most n_window points, so this should not trigger)
            trajectory_2d = trajectory_2d[:self.n_window]
            trajectory_3d = trajectory_3d[:self.n_window]
            trajectory_gripper = trajectory_gripper[:self.n_window]
        
        # Get target 3D position (final waypoint)
        target_3d = trajectory_3d[-1]
        
        # Load camera parameters (from first frame for consistency); intrinsics
        # are returned in their normalized (un-scaled) form.
        camera_pose = np.load(episode_dir / f"{frame_str}_camera_pose.npy")
        cam_K_norm = np.load(episode_dir / f"{frame_str}_cam_K.npy")
        
        # Resize image if needed
        if H_orig != self.image_size or W_orig != self.image_size:
            rgb = cv2.resize(rgb, (self.image_size, self.image_size), 
                           interpolation=cv2.INTER_LINEAR)
            
            # Scale trajectory_2d to new resolution
            scale_x = self.image_size / W_orig
            scale_y = self.image_size / H_orig
            trajectory_2d = trajectory_2d * np.array([scale_x, scale_y])
        
        # Create heatmaps for each timestep (at the resized image resolution)
        heatmap_targets = []
        for t in range(self.n_window):
            target_2d = trajectory_2d[t]
            
            # Clip to image bounds
            target_2d = np.array([
                np.clip(target_2d[0], 0, self.image_size - 1),
                np.clip(target_2d[1], 0, self.image_size - 1)
            ])
            
            # Create one-hot heatmap at model resolution: a single 1.0 at the
            # rounded waypoint pixel, zeros everywhere else
            heatmap_target = np.zeros((self.image_size, self.image_size), dtype=np.float32)
            target_x = int(round(target_2d[0]))
            target_y = int(round(target_2d[1]))
            if 0 <= target_x < self.image_size and 0 <= target_y < self.image_size:
                heatmap_target[target_y, target_x] = 1.0
            
            heatmap_targets.append(heatmap_target)
        
        heatmap_targets = np.array(heatmap_targets)  # (N_WINDOW, H, W)
        
        # Convert to torch tensors
        # RGB: (H, W, 3) -> (3, H, W) and normalize to [0, 1]
        rgb_tensor = torch.from_numpy(rgb).permute(2, 0, 1).float()
        # plt.imread typically yields floats already in [0, 1] for 8-bit PNGs;
        # this branch is a safeguard for uint8-style [0, 255] data.
        if rgb_tensor.max() > 1.0:
            rgb_tensor = rgb_tensor / 255.0
        
        # Normalize RGB to ImageNet stats (for DINOv2)
        mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
        std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
        rgb_tensor = (rgb_tensor - mean) / std
        
        heatmap_tensor = torch.from_numpy(heatmap_targets).float()  # (N_WINDOW, H, W)
        trajectory_2d_tensor = torch.from_numpy(trajectory_2d).float()  # (N_WINDOW, 2)
        trajectory_3d_tensor = torch.from_numpy(trajectory_3d).float()  # (N_WINDOW, 3)
        trajectory_gripper_tensor = torch.from_numpy(trajectory_gripper).float()  # (N_WINDOW,)
        target_3d_tensor = torch.from_numpy(target_3d).float()
        camera_pose_tensor = torch.from_numpy(camera_pose).float()
        cam_K_norm_tensor = torch.from_numpy(cam_K_norm).float()

        # Load joint states for debugging / IK initialization
        joint_state_path_cur = episode_dir / f"{frame_str}.npy"
        joint_state_cur = None
        if joint_state_path_cur.exists():
            try:
                joint_state_cur = np.load(joint_state_path_cur)
            except Exception:
                # Best-effort: a corrupt/unreadable file just yields None.
                joint_state_cur = None

        # Joint state at the episode's first frame (000000), if present.
        joint_state_path_ep0 = episode_dir / "000000.npy"
        joint_state_ep0 = None
        if joint_state_path_ep0.exists():
            try:
                joint_state_ep0 = np.load(joint_state_path_ep0)
            except Exception:
                joint_state_ep0 = None

        joint_state_cur_tensor = (
            torch.from_numpy(np.asarray(joint_state_cur, dtype=np.float32))
            if joint_state_cur is not None
            else None
        )
        joint_state_ep0_tensor = (
            torch.from_numpy(np.asarray(joint_state_ep0, dtype=np.float32))
            if joint_state_ep0 is not None
            else None
        )
        
        # NOTE(review): 'joint_state' / 'episode_start_joint_state' may be
        # None, which PyTorch's default DataLoader collate cannot batch —
        # confirm callers handle this (e.g. custom collate_fn or batch_size=1).
        return {
            'rgb': rgb_tensor,
            'heatmap_target': heatmap_tensor,
            'trajectory_2d': trajectory_2d_tensor,
            'trajectory_3d': trajectory_3d_tensor,
            'trajectory_gripper': trajectory_gripper_tensor,
            'target_3d': target_3d_tensor,
            'camera_pose': camera_pose_tensor,
            'cam_K_norm': cam_K_norm_tensor,
            'episode_id': f"{episode_dir.name}_frame_{frame_str}",
            'episode_name': episode_dir.name,
            'frame_idx': frame_idx,
            'episode_dir': str(episode_dir),
            'joint_state': joint_state_cur_tensor,
            'episode_start_joint_state': joint_state_ep0_tensor,
        }


if __name__ == "__main__":

    def _smoke_test() -> None:
        """Build the dataset from disk and print diagnostics for sample 0."""
        print("Testing RealTrajectoryDataset...")

        try:
            dataset = RealTrajectoryDataset(image_size=448, dataset_root="scratch/parsed_moredata_pickplace_home")
            print(f"✓ Loaded {len(dataset)} samples")

            # Pull one sample and report tensor shapes / value ranges.
            sample = dataset[0]
            print(f"\nSample 0:")
            print(f"  RGB shape: {sample['rgb'].shape}")
            print(f"  RGB range: [{sample['rgb'].min():.3f}, {sample['rgb'].max():.3f}]")
            print(f"  Heatmap shape: {sample['heatmap_target'].shape}")
            print(f"  Heatmap sum per timestep: {sample['heatmap_target'].sum(dim=(1,2))}")
            print(f"  Trajectory 2D shape: {sample['trajectory_2d'].shape}")
            print(f"  Trajectory 2D (first 3): {sample['trajectory_2d'][:3]}")
            print(f"  Trajectory 3D shape: {sample['trajectory_3d'].shape}")
            print(f"  Target 3D: {sample['target_3d']}")
            print(f"  Episode ID: {sample['episode_id']}")

        except Exception as e:
            # Report the failure but still print the trailing banner below.
            print(f"✗ Error: {e}")
            import traceback
            traceback.print_exc()

        print(f"\n{'='*60}")
        print("✓ Dataset test complete!")

    _smoke_test()
