o
    jG                     @   s   d Z ddlZddlZddlZddlmZ ddlZddlZddl	Z	ddl
mZ ddlmZ ejdd ddlmZmZ dZddlZd	Zd
ZdZdZdZdd ZG dd deZedkre Zed Ze  D ]'\Z!Z"e#e"dre$de! de%e"j& de"j'  qfe$de! de"  qfdS dS )u1  All-in-memory dataset for DA3 volume training (smith300).

Extends Smith300DA3Dataset to also store world-Z (EEF height) per frame and
compute height-bin discretization stats over the whole dataset.

Per sample:
  rgb:           (3, 504, 504) float32 in [0, 1]
  gt_pix_504:    (N_WINDOW, 2) GT EEF pixel coords in 504-space
  gt_pix_valid:  (N_WINDOW,) bool — False for clamped (off-episode) future steps
  gt_z_bin:      (N_WINDOW,) long — height bin index in [0, N_HEIGHT_BINS-1]
  da3_depth:     (504, 504) float32 — frozen DA3 depth for distillation
    N)Path)DatasetRotationz/data/cameron/para/para_mac)EEF_BODY_NAME_scale_K_toz-/data/cameron/para/libero/example_twolink.xmli         c           
      C   s   t | jd df}t j| |gdd}||j jd d d df }t |d d df dd }|d d d df |d d d f  }t j|t |jd dfgdd}||j jd d d df }	|	|d d df fS )Nr      axis      gMbP?)nponesshapeconcatenateTclip)
Z	world_ptsZworld_to_cameraKr   Zpts_hZcamznormZhomogpix r   ,/data/cameron/para/libero/data_da3_volume.pyproject'   s   $"r   c                   @   sh   e Zd ZdZdeedededdddfddZd	d
 Z	dd Z
dd Zdd Zdd Zdd Zdd ZdS )Smith300DA3VolumeDatasetz}Same as Smith300DA3Dataset but adds per-frame world-Z (EEF height) +
    binning into N_HEIGHT_BINS over the dataset min/max.z8/data/cameron/mac_robot_datasets/first_mobile_collectionr
   Zda3_depth_large皙?Nc           I         s
  |	rt |	nd  _|
 _| _| _| _| _| _| _t	|}t
dd | D } jr9 fdd|D }|sOtd|  jrKd j  d tdt| d jr\d	nd  tj|}t|}t|tjjt}|j}g  _g  _g  _g  _g  _g  _g  _g  _t |D ]a\d
 }|! sqt"#t$|}|d \}}t%j&|d t%j'd}t%j&|d t%j'd}t%j&|d t%j'd}|| }t(||f||}t)fdddD d }|d u rqt"#t$|d }t%#d }t%j*|d t%j'd}|j+d }|j+d } d|j,v rt%j*|d t%j'dnd }!d|j,v r/t%j*|d t%j'dnd }"t  }#|D ] }$t-t.|$d t/t.|$d | d d D ]}%|#0|% qLq6i }&t
|#D ]=}%|!d urw|"d urw|!|% 1 }'|"|% 1 }(n8t%j2|t%j'd})||%d |f |)d t/||< |)|j3d |< t4|| |j5| 1 }'|j6| 1 }*|*g d }(t78|(9d:t%j;}+|dkrt<||%df nd },t=|'>dd!||\}-}.|-d :t%j;}-d|-d   kr|k rn nd|-d   kr|k sn q]d"|%d#d$ }/ j d"|%d#d% }0|/! sq]t?@tA|/}1|1d u r)q]t?B|1t?jC}2t?jD|2||ft?jEd&}2|2:t%j;d' }2|2Fd(dd}3|0! rYt%#|0:t%j;}4n
t%j2||ft%j;d}4t j}5|5|&|%<  jG|3  jG|4  jG|-  jGt<|'d(   jG|+  jG|,  jG q]|D ]@}$g }6t-t.|$d t/t.|$d | d d D ]}%|%|&v r|6G|&|%  qt|6d(k rϐq jGd)t%j*|6t%jHdi qtd*jI d+tJfd,d jD  d- qtKLt%jM jdd. _tKLt%jM jdd. _tKLt%jM jdd. _tKjN jtKj;d _tKLt%jM jdd.<  _tKjN jtKj;d _t< j/ t< jO }7}8|8|7krV|8|7 d/ nd/}9|7|9  _P|8|9  _QtR _StT _U j/djV jOdjV}:};|;|: d/ }<|:|< W  _X|;|< W  _Y jrtZj[! jrt%# j}=|=d0  _\|=d1  _]|=d2  _^t.|=d3  __d4 _`d  _ad  _bd  _cd5 _dd  _e f j _gtd6 j_ d7 j d8 j^W   nm jr,tZj[! jr,t%# j}>|>d9  _a|>d:  _bt<|>d;  _ct<|>d<  _dt<|>d=  _ed> _` h j _gd __td? j d@ jedAdB n d  _ad  _bd  _cd5 _dd  _edC _`d __ i j _gtdD  j j _kt< j/ t< jO }?}@|@|? }A|?|A|   _l|@|A|   _m n j _o jm jl  j }B jltKp j< dE |B   _qg  _rt  jD ]\}C}$t-t|$d) d D ]}D jrG|C|Df qqt j}E js  jt  dF }F js  jt  dF }GtdGt j dH|E dIt jr dJ|FdKdL|GdKdM tdN|?dOdP|@dOdQ jldOdP jmdOdR| dS|BdT dUdV tKu  tKjv jo|dWW }HW d    n	1 s+w   Y  tdXt/|H dYtO|H dZtJd[d |HD   d S )\Nc                 S   s   g | ]}|  r|qS r   )is_dir).0dr   r   r   
<listcomp>G   s    z5Smith300DA3VolumeDataset.__init__.<locals>.<listcomp>c                    s   g | ]
}|j  jv r|qS r   )namesessions_whitelistr    sselfr   r   r"   I       zNo sessions under z matching whitelist  z"Smith300DA3VolumeDataset: loading z	 sessionsz (whitelisted)z	meta.jsonZimage_size_whr   dtypeT_camera_arucoBaseZT_W_baseBody_inv_aruco_offsetc                 3   s$    | ]} |   r | V  qd S N)exists)r    p)sessr   r   	<genexpr>k   s    
z4Smith300DA3VolumeDataset.__init__.<locals>.<genexpr>)zrgb_overlay/episodes.jsonzepisodes.jsonepisodesz
joints.npzZq_motorsr
   r   eef_posZeef_quatstartend)r
   r   r   r   xyz      g        r   Zrgb_06dz.jpgz.npy)interpolationg     o@r   frames  : c                 3   s    | ]	}| krd V  qdS )r
   Nr   r%   )sess_idxr   r   r2          z frames loadedr   r   centroids_quatcentroids_euler
bin_counts
n_clusterskmeansg      ?z  Rotation: k-means mode (K=z	), basis=z, bin counts=meanZprincipal_axisZpca_minZpca_maxZev_ratio_pc1Z1d_pcaz  Rotation: 1D PCA mode, basis=z (EV ratio z.3f)Zper_axisz0  Rotation: per-axis mode (no PCA path provided)g      ?g    eAz Smith300DA3VolumeDataset ready: z eps, z	 frames, z samples, rgb=z.2fz GB, depth=z GBz  Height range observed: [z.4fz, u   ] → padded [z], z
 bins of ~i  z.1fzmm each)	minlengthz  Z-bin occupancy: min=z, max=z, empty_bins=c                 s   s    | ]	}|d krdV  qdS )r   r
   Nr   )r    cr   r   r   r2     r@   )wsetr$   rot_pca_pathrot_kmeans_path
image_sizen_windowr&   depth_subdirn_height_binsr   sortediterdirFileNotFoundErrorprintlenmujocoZMjModelZfrom_xml_pathZMjDataZ
mj_name2idZmjtObjZ
mjOBJ_BODYr   nqr3   rgb_tdepth_tpix_tZeef_z_teef_euler_tZ	gripper_tZsession_idx	enumerater/   jsonloadopenr   arrayfloat64r   nextasarrayr   filesrangeintminaddcopyzerosZqposZ
mj_forwardxposZxquatScipyR	from_quatas_eulerastypefloat32floatr   reshapecv2ZimreadstrZcvtColorZCOLOR_BGR2RGBresizeZINTER_LINEAR	transposeappendint64r#   sumtorch
from_numpystacktensormaxmin_gripmax_grip
N_ROT_BINS
n_rot_binsN_GRIPPER_BINSn_gripper_binsvaluestolistmin_rotmax_rotospathkmeans_centroids_quatZkmeans_centroids_eulerZkmeans_bin_countsZkmeans_n_clustersZrotation_moderot_pca_meanrot_pca_axisrot_pca_minrot_pca_maxZrot_pca_ev_ratio_bin_rotation_kmeans	rot_bin_t_bin_rotation_1d_pca_bin_rotation_bin_gripper
grip_bin_t
min_height
max_height_bin_heightz_bin_tarangebin_centerssampleselement_sizenumelno_gradbincount)Ir(   root_dirrM   rN   frame_strideZ
mujoco_xmlrO   rP   Zheight_pad_fracr$   rK   rL   rootZsessionsZmj_modelZmj_dataZeef_idZn_qpos	meta_pathmetaZIMG_WZIMG_HZK_origr-   ZT_W_baseBodyZT_CAM_WORLDZK_targetZep_pathZsess_episodesZjointsZq_motors_allZn_motorsn_framesZjoint_eef_pos_allZjoint_eef_quat_allneededepfZlocal_to_globalr4   Zeef_quat_xyzwqZ	quat_wxyzZ	eef_eulerZgripperr   _Zimg_pathZ
depth_pathZbgrrgbZrgb_chwdepthZg_idxfsZg_loZg_hiZg_padZr_loZr_hiZr_padkmpcaZz_loZz_hiZz_rangeZbin_wep_idxtnZgb_rgbZgb_dcountsr   )r(   r1   r?   r   __init__6   s  



$$,>


,
 .







 



z!Smith300DA3VolumeDataset.__init__c                 C   @   |  }|| j t| j| j d }|| j  d| jd S N:0yE>r   r
   )rq   r   r~   r   rP   longclamp)r(   r   r   r   r   r   r        z$Smith300DA3VolumeDataset._bin_heightc                 C   s\   |  }tj| jtjd}tj| jtjd}|| || d }|dd| jd  	 S )u@   eul: (N, 3). Returns (N, 3) long — per-axis euler bin indices.r+   r   r   r
   )
rq   rz   r}   r   rp   r   	clamp_minr   r   r   )r(   eullohir   r   r   r   r     s
   z&Smith300DA3VolumeDataset._bin_rotationc           	      C   s   ddl m} | }|d| }|dddf dk }||  d9  < |tjj|dddd	  }| j}|dddddf |dddddf  d
 	d}|j
dd}tj|tjdS )uz   eul: (N, 3) euler xyz. Returns (N,) long — assign each sample to nearest
        centroid in canonical-quaternion space.r   r   r7   Nr   r   T)r   keepdimsg-q=r   r   r+   )scipy.spatial.transformr   numpy
from_euleras_quatr   linalgr   r   ry   argminrz   r}   r   )	r(   r   rl   Zeul_npquatsmask	centroidsd2binsr   r   r   r     s   6z-Smith300DA3VolumeDataset._bin_rotation_kmeansc                 C   sn   |  }tj| jtjd}tj| jtjd}|| | }|| j t| j| j d }|	dd| j
d   S )uC   eul: (N, 3). Returns (N,) long — bin index along the PCA-1D axis.r+   r   r   r
   )rq   rz   r}   r   rp   r   r   r~   r   r   r   r   )r(   r   rF   r   projr   r   r   r   r   +  s   z-Smith300DA3VolumeDataset._bin_rotation_1d_pcac                 C   r   r   )rq   r   r~   r   r   r   r   )r(   gr   r   r   r   r   4  r   z%Smith300DA3VolumeDataset._bin_gripperc                 C   s
   t | jS r.   )rU   r   r'   r   r   r   __len__9  s   
z Smith300DA3VolumeDataset.__len__c                    s,  | j | \}| j| d }t|}| jt| }|d  fddt| jD } fdd|D }tj fdd|D tj	d}|| }	| j
|	 }
| j|	 }| j|	 }| j|	 }| j| }| j| }| j
| }| j| }| j| }| j| }||
|||||||||tj|tjdtjtjddS )	Nr<   r
   c                    s   g | ]
}|d     qS )r
   r   r    i)r&   r   r   r   r"   C  r)   z8Smith300DA3VolumeDataset.__getitem__.<locals>.<listcomp>c                    s   g | ]}t | qS r   )rg   r   	last_realr   r   r"   D  s    c                    s   g | ]}| kqS r   r   )r    rawr   r   r   r"   E  s    r+   )r   Z
gt_pix_504Zgt_pix_validgt_z_bin
gt_rot_bingt_grip_bincur_grip_bincur_rot_bin	cur_z_binZ	da3_depthZstart_pix_504r   start_t)r   r3   rU   r&   rf   re   rN   rz   r}   boolrZ   r   r   r   rX   rY   r   )r(   idxr   r<   LZcur_gZfuture_local_rawZfuture_localvalidZfuture_globalZgt_pixr   r   r   r   r   Z	start_pixr   r   r   r   )r   r&   r   r   __getitem__<  sD   









z$Smith300DA3VolumeDataset.__getitem__)__name__
__module____qualname____doc__	DA3_INPUTN_WINDOWDEFAULT_SMITH300_XMLN_HEIGHT_BINSr   r   r   r   r   r   r   r   r   r   r   r   r   2   s$    
 Z	r   __main__r   r=   r>    )(r   r   sysr]   pathlibr   rs   r   r   rz   torch.utils.datar   r   r   rl   r   insertZdata_smith300_parar   r   r   rV   r   r   r   r   r   r   r   r   dsr&   itemskvhasattrrT   tupler   r,   r   r   r   r   <module>   s:      4.