o
    ii                     @   s   d Z ddlZddlmZ ddlZddlZddlZddlm	Z	 ddl
mZ dZdZdZd	Zd
Zd\ZZejg dejdZejg dejdZeeedfddZdd ZG dd de	ZG dd de	ZdS )u  DROID dataset for PARA training.

Streams episodes from HuggingFace (cadene/droid_1.0.1) and produces training
samples in the same format as CachedTrajectoryDataset (LIBERO).

NOTE: Camera intrinsics are estimated (no per-camera intrinsics in the dataset).
The estimated fy=130 corresponds to ZED 2 wide mode at 320×180.

Key differences from LIBERO:
  - Real images (320×180), resized to 448×448, no vertical flip needed
  - Camera extrinsics from DROID 6D format [x,y,z,rx,ry,rz]
  - Gripper: DROID [0,1] → mapped to [-1,+1] for PARA convention
  - EEF pose from cartesian_position [x,y,z,roll,pitch,yaw]
  - Robot base at world origin (base_z=0)
  - 15 Hz (vs LIBERO 20 Hz), frame_stride=2 gives ~7.5Hz
    N)Path)Dataset)Rotation   g            ?g        g     @`@)i@     )g
ףp=
?gv/?gCl?dtype)gZd;O?gy&1?g?  c                 C   s  | dd  tj}td| dd  }tjdtjd}||ddddf< ||dddf< |j}| | }	tjdtjd}
||
ddddf< |	|
dddf< || }|| }|| }|| }|d | }|d | }tj|d|dgd||dgg d	g d
gtjd}||
 }tj|d|gd||gg dgtjd}|	 }|d  t
|  < |d  t
|  < | tj| tj| tjfS )u  Build camera matrices from DROID 6D extrinsics.

    DROID extrinsics: [x,y,z,rx,ry,rz] = camera pose in robot base frame.
    R = Rotation.from_euler("xyz", [rx,ry,rz]) maps camera frame → base frame.

    All matrices are in the RESIZED (target_size × target_size) image space,
    since training images are resized from 320×180 to 448×448.

    Returns:
        camera_pose: (4,4) camera→world (for unprojection / 3D recovery)
        world_to_cam_proj: (4,4) full projection matrix K@[R|t] at target_size
                           (compatible with robosuite project_points_from_world_to_camera)
        cam_K_norm: (3,3) normalized intrinsic matrix at target_size (divided by target_size)
    N   xyz   r   r   g       @r   )r   r      r   )r   r   r   r   )r   r   r   r   )astypenpfloat64ScipyR
from_euler	as_matrixeyeTarraycopyfloatfloat32)Zext6dfyimg_wimg_htarget_sizeposZ
R_base_camcamera_poseZ
R_cam_baseZt_camZ
T_cam_basesxsyZfx_effZfy_effZcx_effZcy_effZK_4x4world_to_cam_projZK_3x3
cam_K_norm r%   6/data/cameron/para_droid_pretrain/libero/data_droid.py_build_camera_matrices(   sB   


&


r'   c                 C   s   t | t jd}|| }|d dkr"t j|d |d gt jdS |d |d  }|d |d  }t |d|d }t |d|d }t j||gt jdS )u   Project 3D world point to pixel (u, v) in target_size image space.

    Uses the full 4×4 projection matrix (K @ extrinsic) at target_size.
    Returns (u, v) clipped to [0, target_size-1].
    r      r   r   r   )r   appendr   r   r   r   clip)Zp_worldr#   r   Zp_hZp_projuvr%   r%   r&   _project_to_pixel_targetsizee   s   r-   c                   @   sH   e Zd ZdZddddedefddZd	d
 Zdd Zdd Z	dd Z
dS )DroidLocalDataseta  Reads DROID episodes from a local directory (downloaded via huggingface-cli).

    Init is fast: only scans file paths and reads parquet row counts from metadata.
    All per-episode data (parquet columns, projections) is loaded lazily in __getitem__.

    Expected layout under data_root:
        data/chunk-NNN/episode_NNNNNN.parquet
        videos/chunk-NNN/observation.images.exterior_{1,2}_left/episode_NNNNNN.mp4

    Args:
        data_root: path to downloaded dataset (e.g. /data/cameron/droid)
        camera: "ext1" or "ext2"
        max_episodes: limit number of episodes (0 = all)
        image_size, n_window, frame_stride, fy: same as DroidStreamingDataset
    ext2r    r
   r(   c	              	   C   sp  t || _|| _|| _|| _|| _|| _|dkrdnd| _|dkr#dnd| _dd l	}	|		 }
|rt |
 rdd l}t|}||}W d    n1 sNw   Y  td|  tdt|d	  d
|dd  g }g }g }| j| j }|d	 D ]`}|d }|d }||k rq|d|d d}d|d}| jd | | d }| jd | | j | d }|
 r|
 sq||t| |t| || |dkrt||kr nq|ndd lm} td| j d d}g }g }g }| j| j }|dkrt||n|}t|D ]j}d|d d}d|d}| jd | | d }| jd | | j | d }|
 r@|
 sBqz|t|}|j}W n tyY   Y qw ||k raq|t| |t| || q|| _|| _tj|tjd| _t | j| _!t| j!dkrt"| j!d nd| _#|		 |
 }tdt| d| j# d|d d! d S )"Next1"observation.images.exterior_1_left"observation.images.exterior_2_left!camera_extrinsics.exterior_1_left!camera_extrinsics.exterior_2_leftr   z)DroidLocalDataset: loading from manifest z  episodesz episodes, min_in_frame=Zmin_in_frame?ep_idx
num_frameschunk-  03depisode_06ddata.parquetvideos.mp4zDroidLocalDataset: indexing z/data/ (no manifest)...ipu r   zDroidLocalDataset:  episodes, z samples (init z.1fzs))$r   	data_root
image_sizen_windowframe_strider   cameracam_keyext_coltimeexistsjsonopenloadprintlengetr)   strZpyarrow.parquetZparquetminrangeZread_metadatanum_rows	Exception	_pq_paths
_vid_pathsr   r   int32_frame_countscumsum_cumsumint_total_samples)selfrE   rI   max_episodesmanifest_pathrF   rG   rH   r   rL   t0rN   fmanifestZpq_pathsZ	vid_pathsZframe_countsZ
min_framesentryepr   chunkep_strpq_pathvid_pathpqZTOTAL_EPISODESZ	n_to_scan
video_pathmetaelapsedr%   r%   r&   __init__   s   





$zDroidLocalDataset.__init__c                 C   sn  ddl }| j| }| j| }t| j| }||}t|d j	tj
}|d j	tj
}tj|| j jd tj
d}	d|jv rKt|d jd nd}
t|	| jtt| jd\}}}|dddd	f }tj|d
ftj
d}t|D ]}t|| 	tj|| j||< qs|d
 d  }|ddd	df 	tj}tjdd |D dd	tj
}||||||||||
d
S )zLLoad and process a single episode's parquet. Called lazily from __getitem__.r   N$observation.state.cartesian_position"observation.state.gripper_positionr   language_instructionr0   r   r   r   r   r   r(   r   r   c                 S      g | ]
}t d | qS r   r   r   as_quat.0er%   r%   r&   
<listcomp>      z8DroidLocalDataset._load_episode_data.<locals>.<listcomp>axis)
rn   eef_poseef_quatgripperpix_uvr    world_to_camr$   r9   language)pandasrY   rZ   r_   r\   read_parquetr   stackvaluesr   r   r   rK   iloccolumnsrT   r'   r   IMG_W_NATIVEIMG_H_NATIVErF   zerosrV   r-   r   )ra   r8   pdrk   rn   r   dfcartesian_positionsgripper_positions
extrinsicsr   r    r#   r$   r   r   t_igripper_para	eef_eulerr   r%   r%   r&   _load_episode_data   sP   


 

z$DroidLocalDataset._load_episode_datac           	         s   ddl }zB||}t|}t|}i  t|jddD ]\}}||v r,|jdd |< ||kr2 nq|   s=td fdd|D W S  t	y_   t
jttd	ft
jd
gt|  Y S w )zDecode specific frames from mp4 using PyAV. Returns list of (H,W,3) uint8.

        Handles corrupt videos gracefully by returning black frames.
        r   Nvideorgb24formatzNo frames decodedc              	      s$   g | ]}  | t   qS r%   )rS   maxkeys)r{   iresultr%   r&   r}   3     $ z4DroidLocalDataset._decode_frames.<locals>.<listcomp>r   r   )avrO   setr   	enumeratedecode
to_ndarrayclose
ValueErrorrX   r   r   r   r   uint8rR   )	ra   rn   frame_indicesr   	containerZframes_needed	max_framer   framer%   r   r&   _decode_frames  s&   
$z DroidLocalDataset._decode_framesc                 C   s   | j S N)r`   ra   r%   r%   r&   __len__8  s   zDroidLocalDataset.__len__c                    s  t tjj|dd}t ||dkrj|d  nd |}|d   fddtjD }|d |}g }g }g }g }	g }
d }t|D ]k\}}|| 	tj
d	 }|jd jksi|jd jkrvtj|jjftjd
}|d u r||}|
| ||d | 	tj |	|d | 	tj |tt|d | tt ||d |   qKtj|dd	tj
}tj|dd	tj
}tj|tj
d}tj|	dd	tj
}	tjdd |	D dd	tj
}dd l}tj|jtjd}t|tjfdd|	D dd	tj
}tj|
dd	tj
}
g }tjD ]@}|| \}}t ttt|djd }t ttt|djd }tjjjftj
d}d|||f< || q&tj|dd}t |!ddd }tj"t#tj
d$ddd}tj"t%tj
d$ddd}|| | }i d|dt | dt | dt | dt | dt |	 dt | dt | dt |
 dt |d   d!tj"t&tj
dd"t |d#  d$t |d$  d%t |d%  d&tj"|tj'dd'tj"tj'dd(td)|(d*d+tj"d,tj)dtdjjtjdtj*d-tj
dtj*dtj
dtj*d-tj
dtjd.S )/Nright)sider   r   r9   c                    s$   g | ]}t |j   d  qS )r   )rU   rH   )r{   k)r   ra   start_tr%   r&   r}   C  r   z1DroidLocalDataset.__getitem__.<locals>.<listcomp>rn        o@interpolationr   r   r   r   r   r   c                 S      g | ]
}t |d qS rw   r   	from_quatas_eulerr{   qr%   r%   r&   r}   a  r~   c                    "   g | ]}   t|  qS r%   invr   r   	as_rotvecr   ref_rotr%   r&   r}   h      r   r(   r   rgbheatmap_targettrajectory_2dtrajectory_3dtrajectory_grippertrajectory_quattrajectory_eulertrajectory_delta_rotvecrgb_frames_rawworld_to_camerar   base_z	target_3drC   r    r$   demo_idxr   clip_embedding   r   r0   Fr   task_description	has_wrist	wrist_rgbwrist_trajectory_2dwrist_camera_posewrist_cam_K_normwrist_world_to_camerawrist_in_view)+r_   r   searchsortedr^   r   rV   rG   r   r   r   r   shaperF   cv2resizeINTER_LINEARr)   r   r   r*   MIN_GRIPPERMAX_GRIPPERr   r   r   modelREF_ROTATION_QUATr   r   roundr   torch
from_numpypermutetensorIMAGENET_MEANviewIMAGENET_STDBASE_ZlongrS   boolr   )ra   idxr   rh   r   
raw_framesr   r   r   r   r   rgb_refr   tr   r   _model_moduleref_quatr   heatmap_targetst_kxyx_iy_ihmrgb_tmeanstdr%   )r   r   ra   r   r&   __getitem__;  s   "
 
 

  	



zDroidLocalDataset.__getitem__N)__name__
__module____qualname____doc__N_WINDOW
DEFAULT_FYrq   r   r   r   r  r%   r%   r%   r&   r.   v   s    
f1r.   c                   @   s>   e Zd ZdZdddedefddZdd	 Zd
d Zdd Z	dS )DroidStreamingDatasetuI  Streams DROID episodes from HuggingFace, caches in memory.

    Downloads episode parquet + video on first access, then serves
    samples from RAM. Suitable for prototyping with small episode counts.

    Args:
        episode_indices: list of int episode indices to load
        camera: "ext1" or "ext2" (exterior cameras only, no wrist)
        image_size: target image size (default 448)
        n_window: number of future timesteps per sample
        frame_stride: stride between frames (default 2 for 15Hz → ~7.5Hz)
        fy: estimated focal length for camera intrinsics
    Nr/   r
   r(   c                 C   s  |d u r
t td}|| _|| _|| _|| _|| _g | _g | _t	dt
| d |D ]I}z+| |}|d u r9W q+t
| j}	| j| |d }
t|
D ]
}| j|	|f qLW q+ tyt } zt	d| d|  W Y d }~q+d }~ww t	dt
| j dt
| j d	 d S )
N
   zDroidStreamingDataset: loading z episodes from HuggingFace...r9   z"  Warning: failed to load episode : zDroidStreamingDataset: rD   z samples)listrV   rF   rG   rH   r   rI   r6   samplesrQ   rR   _load_episoder)   rX   )ra   episode_indicesrI   rF   rG   rH   r   r8   Zep_datar   r   r   r|   r%   r%   r&   rq     s6   	

&zDroidStreamingDataset.__init__c           !      C   s  ddl }ddlm} ddl}d}|d }d|d}d|d	}||d
| d| ddd}	| jdkr5dnd}
||d| d|
 d| ddd}||	}t|}||}g }|jddD ]}|	|j
dd q^|  t|}|jd |krtd| d|jd  d|  t|jd |}|d| }t|d jd| }t|d jd| }|d jd| tj}| jdkrdnd}tj|| jd tjd}d |jv r|d  jd nd!}t|| jtt| jd"\}}}tj|d#ftjd}t|D ]}||dd$f tj}t||| j||< q|d# d%  }|d|d$d&f tj}tjd'd( |D dd)tj} td*| d+| d,|dd-  d. ||d|dd$f tj|tj| ||||||t |t!rp|d/S d!d/S )0z%Download and parse one DROID episode.r   N)hf_hub_downloadzcadene/droid_1.0.1r;   r=   r>   r:   r<   zdata//r@   dataset)Z	repo_typer1   r2   r3   zvideos/rB   r   r   r   z  Warning: ep z video frames z != parquet rows z observation.state.joint_positionrr   rs   r4   r5   r   rt   r0   ru   r(   r   r   r   c                 S   rv   rw   rx   rz   r%   r%   r&   r}     r~   z7DroidStreamingDataset._load_episode.<locals>.<listcomp>r   z  Loaded ep r  z frames, task='2   ')imagesr   r   r   r   r   r    r   r$   r9   r   )"r   huggingface_hubr  r   rI   r   rR   rO   r   r)   r   r   r   r   r   rQ   rU   r   r   r   r   r   r   r'   r   r   r   rF   r   rV   r   r-   
isinstancerT   )!ra   r8   r   r  r   ZREPO_IDri   rj   Z	chunk_strZparquet_pathrJ   rn   r   r   r   framesr   r  Zjoint_positionsr   r   Zext_keyr   r   r    r#   r$   r   r   r   r   r   r   r%   r%   r&   r    s   


 

$
z#DroidStreamingDataset._load_episodec                 C   s
   t | jS r   )rR   r  r   r%   r%   r&   r   &  s   
zDroidStreamingDataset.__len__c                    s  | j | \}}| j| }|d }g }g }g }g }	g }
d }t| jD ]w}t||| j  |d }|d | tjd }|j	d | j
ksK|j	d | j
krXtj|| j
| j
ftjd}|d u r^|}|
| ||d | tj |	|d | tj |tt|d	 | tt ||d
 |   q!tj|ddtj}tj|ddtj}tj|tjd}tj|	ddtj}	tjdd |	D ddtj}dd l}tj|jtjd}t| tj fdd|	D ddtj}tj|
ddtj}
g }t| jD ]@}|| \}}tttt|d| j
d }tttt|d| j
d }tj| j
| j
ftjd}d|||f< || qtj|dd}t| ddd }tj!t"tjd#ddd}tj!t$tjd#ddd}|| | }i d|dt| dt| dt| dt| dt|	 dt| dt| dt|
 dt|d  dtj!t%tjddt|d   d!t|d!  d"t|d"  d#tj!|tj&dd$tj!|tj&dd%td&|'d'd(tj!d)tj(dtd| j
| j
t| jdtj)d*tjdtj)dtjdtj)d*tjdt| jd+S ),Nr9   r   r  r   r   r   r   r   r   r   r   r   c                 S   r   rw   r   r   r%   r%   r&   r}   J  r~   z5DroidStreamingDataset.__getitem__.<locals>.<listcomp>c                    r   r%   r   r   r   r%   r&   r}   R  r   r   r(   r   r   r   r   r   r   r   r   r   r   r   r   r   r   rC   r    r$   r   r   r   r   r   r0   Fr   r   )*r  r6   rV   rG   rU   rH   r   r   r   r   rF   r   r   r   r)   r   r   r*   r   r   r   r   r   r   r   r   r   r_   r   r   r   r   r   r   r   r   r   r   r   rS   r   r   )ra   r   r   r   rh   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r  r  r  r%   r   r&   r  )  s   
 
 

  	



z!DroidStreamingDataset.__getitem__)
r  r  r  r  r	  r
  rq   r  r   r  r%   r%   r%   r&   r    s    
'Vr  )r  ospathlibr   r   numpyr   r   torch.utils.datar   scipy.spatial.transformr   r   r	  r   r   r   r
  r   r   r   r   r   r   r'   r-   r.   r  r%   r%   r%   r&   <module>   s0    
=  &