o
    ИiF                  
   @   s,  d Z ddlZddlZddlmZ ddlmZ ddlm	Z
 ddlZddlZdZdZdZdZd	d
 Zeg dg dg dg dg dgd ZdZee Zdd ZG dd deZedkred zedddZedee d ed Zed eded j  ed ed  d!d"ed  d!d# ed$ed% j  ed&ed% j d'd(  ed)ed* j  ed+ed* dd   ed,ed- j  ed.ed/   ed0ed1   W n" e!y Z" zed2e"  ddl#Z#e#$  W Y dZ"["ndZ"["ww ed3d4  ed5 dS dS )6z-Dataset for real dense trajectory prediction.    N)Dataset)Path      皙ɿ皙?c                 C   s,   | }|dkrd}|dkrd}t ttt|S )a|  Process gripper value: fix wrapping and clamp to range.
    
    Args:
        gripper_value: scalar gripper joint value
    
    Returns:
        Processed gripper value in [MIN_GRIPPER, MAX_GRIPPER] = [-0.2, 0.8]
        Processing steps:
        1. Map values > 4.0 to -0.2 (wraparound fix)
        2. Clamp values > 0.8 to 0.8
        3. Ensure final range is [-0.2, 0.8]
    g      @r   r   )maxMIN_GRIPPERminMAX_GRIPPER)gripper_valuex r   D/data/cameron/keygrip/volume_dino_tracks_act_baseline_joints/data.pyprocess_gripper_value   s   r   )     *@{GVg/@)g
ףp=%@gfffffXr   )r   r   g/)g(\1@g=
ףpTr   )g\(6@g=
ףpQr   g     @@   c                 C   sL   t | d}|| }|d dkrdS ||dd  }|dd |d  }|S )a0  Project 3D point to 2D pixel coordinates.
    
    Args:
        point_3d: (3,) 3D point in world coordinates
        camera_pose: (4, 4) camera pose matrix (world-to-camera)
        cam_K: (3, 3) camera intrinsics
    
    Returns:
        (2,) 2D pixel coordinates [x, y], or None if behind camera
          ?   r   Nr   )npappend)point_3dcamera_posecam_K
point_3d_h	point_cam
point_2d_hpoint_2dr   r   r   project_3d_to_2dF   s   r   c                   @   s<   e Zd ZdZddedB dedB fddZd	d
 Zdd ZdS )RealTrajectoryDataseta  Dataset for real dense trajectory prediction.
    
    Each sample contains:
        - RGB image (from any frame)
        - Dense 2D trajectory (N_WINDOW waypoints starting from that frame)
        - Camera parameters (pose and intrinsics)
        - Ground truth heatmaps for each timestep
    
    For each episode, creates samples from every frame, with trajectories
    padded with the last observed keypoint if there aren't enough subsequent frames.
    scratch/  Nepisodemax_episodesc           
         s.  t || _|| _t| _| j std| j tdd | j D } durB fdd|D }t	|dkrBtd  d| j |durL|d| }t	|dkrZtd	| j g | _
|D ] }td
d |dD }|D ]}t|j}	| j
||	f qoq_tdt	| d tdt	| j
 d dS )an  Initialize dataset.
        
        Args:
            dataset_root: Root directory containing episodes
            image_size: Size to resize images to (will be square, default 448)
            episode: Optional episode directory name (e.g. "episode_001") to load only.
            max_episodes: Optional limit on how many episodes to load (after sorting).
        zDataset directory not found: c                 S   s"   g | ]}|  rd |jv r|qS r#   )is_dirname.0dr   r   r   
<listcomp>}   s    
z2RealTrajectoryDataset.__init__.<locals>.<listcomp>Nc                    s   g | ]	}|j  kr|qS r   )r'   r(   r%   r   r   r+          r   z	Episode 'z' not found in zNo episodes found in c                 S      g | ]	}|j  r|qS r   stemisdigitr)   fr   r   r   r+      r,   *.pngzLoaded z	 episodeszCreated z samples (one per frame))r   dataset_root
image_sizeN_WINDOWn_windowexists
ValueErrorsortediterdirlensamplesglobintr/   r   print)
selfr4   r5   r#   r$   episode_dirsepisode_dirframe_files
frame_file	frame_idxr   r%   r   __init__l   s.   
	

zRealTrajectoryDataset.__init__c                 C   s
   t | jS )N)r<   r=   )rA   r   r   r   __len__   s   
zRealTrajectoryDataset.__len__c           A      C   s  | j | \}}|d}|| d }t|dddf }|jdd \}}tdd |d	D }	d
d |	D }
|
|}g }g }g }g }t| jD ]}|| t	|
krW n|
||  }|d}|| d }|
 sn nt|}|ddddf }|dddf }|t | }|| || d }|
 rt|}tj|dt tjd}|| t|d }t|}|| n nI|| d }|| d }|
 r|
 s n1t|}t|}| } | d  |9  < | d  |9  < t||| }!|!du r n||! qKt	|dkrt|| d }|ddddf }|dddf }|t | }|g}|| d }|
 rbt|}tj|dt tjdg}t|d }t|g}ndg}tjttjdg}t|| d }t|| d }| } | d  |9  < | d  |9  < t||| }!|!dur|!g}t	|dkrtd|j d| t|}t|}t|}t|}t	|| jk r-|dd }"|dd }#|dd }$|dd }%| jt	| }&tj|t|"|&dfgdd}tj|t|#|&dfgdd}tj|t|$|&fgdd}tj|t|%|&dfgdd}n$t	|| jkrQ|d| j }|d| j }|d| j }|d| j }|d }'t|| d }t|| d }|| jksu|| jkrtj|| j| jftjd}| j| }(| j| })|t|(|)g }g }*t| jD ]g}+||+ },tt |,d d| jd t |,d d| jd g},tj| j| jftjd}-t!t"|,d }.t!t"|,d }/d|.  kr| jk rn nd|/  kr| jk rn nd|-|/|.f< |*|- qt|*}*t#$|%ddd }0|0& dkr |0d }0t#'g d(ddd}1t#'g d(ddd}2|0|1 |2 }0t#$|* }3t#$| }4t#$| }5t#$| }6t#$| }7t#$|' }8t#$| }9t#$| }:|| d };d}<|;
 rzt|;}<W n t)y   d}<Y nw |d }=d}>|=
 rzt|=}>W n t)y   d}>Y nw |<durt#$tj|<tjdnd}?|>durt#$tj|>tjdnd}@|0|3|4|5|6|7|8|9|:|j d| |j|t*||?|@dS )a  Get a single sample.
        
        Returns:
            dict with keys:
                - rgb: (3, H, W) normalized RGB image
                - heatmap_target: (N_WINDOW, H, W) one-hot heatmaps for each timestep
                - trajectory_2d: (N_WINDOW, 2) 2D pixel locations for each timestep
                - trajectory_3d: (N_WINDOW, 3) 3D world positions for each timestep
                - trajectory_gripper: (N_WINDOW,) gripper values for each timestep
                - target_3d: (3,) final target 3D world position (last waypoint)
                - camera_pose: (4, 4) camera pose matrix (from first frame)
                - cam_K_norm: (3, 3) normalized intrinsics (from first frame)
        06dz.png.Nr   r   c                 S   r-   r   r.   r1   r   r   r   r+      r,   z5RealTrajectoryDataset.__getitem__.<locals>.<listcomp>r3   c                 S   s   g | ]}t |jqS r   )r?   r/   r1   r   r   r   r+      s    z_gripper_pose.npyz.npy)dtypez
_cam_K.npyz_camera_pose.npyr      r   z!Could not compute trajectory for z frame )axis)interpolationg     o@)g
ףp=
?gv/?gCl?)gZd;O?gy&1?g?z
000000.npy_frame_)rgbheatmap_targettrajectory_2dtrajectory_3dtrajectory_grippertrajectory_joints	target_3dr   
cam_K_norm
episode_idepisode_namerF   rC   joint_stateepisode_start_joint_state)+r=   pltimreadshaper:   r>   indexranger7   r<   r8   r   loadkp_localr   asarrayN_JOINTSfloat32floatr   copyr   zerosr9   r'   arrayconcatenatetiler5   cv2resizeINTER_LINEARclipr?   roundtorch
from_numpypermuter   tensorview	Exceptionstr)ArA   idxrC   rF   	frame_strrgb_pathrP   H_origW_origrD   frame_indicesstart_frame_idxrR   rS   rT   rU   inext_frame_idxnext_frame_strgripper_pose_pathgripper_posegripper_rotgripper_poskp_3djoint_state_pathrZ   joints_6r   cam_K_norm_pathcamera_pose_pathrW   r   r   kp_2dlast_point_2dlast_point_3dlast_gripperlast_jointsn_padrV   scale_xscale_yheatmap_targetst	target_2drQ   target_xtarget_y
rgb_tensormeanstdheatmap_tensortrajectory_2d_tensortrajectory_3d_tensortrajectory_gripper_tensortrajectory_joints_tensortarget_3d_tensorcamera_pose_tensorcam_K_norm_tensorjoint_state_path_curjoint_state_curjoint_state_path_ep0joint_state_ep0joint_state_cur_tensorjoint_state_ep0_tensorr   r   r   __getitem__   sP  














 

<




z!RealTrajectoryDataset.__getitem__)r!   r"   NN)	__name__
__module____qualname____doc__rw   r?   rG   rH   r   r   r   r   r   r    _   s
    -r    __main__z Testing RealTrajectoryDataset...r"   z&scratch/parsed_moredata_pickplace_home)r5   r4   u   ✓ Loaded z samplesz

Sample 0:z  RGB shape: rP   z  RGB range: [z.3fz, ]z  Heatmap shape: rQ   z  Heatmap sum per timestep: )rL   r   )dimz  Trajectory 2D shape: rR   z  Trajectory 2D (first 3): z  Trajectory 3D shape: rS   z  Target 3D: rV   z  Episode ID: rX   u   ✗ Error: 
z<============================================================u   ✓ Dataset test complete!)%r   rq   numpyr   torch.utils.datar   pathlibr   matplotlib.pyplotpyplotr\   rl   mathr6   rd   r	   r   r   ri   KEYPOINTS_LOCAL_M_ALLKP_INDEXrb   r   r    r   r@   datasetr<   sampler^   r
   r   sumrv   e	traceback	print_excr   r   r   r   <module>   sb    .  
3*