o
    j[                     @   sT  d Z ddlZddlZddlZddlZddlZddlmZ ddl	m
Z ddlZdZdZdZdZd	Zd
d Zejg dg dg dg dg dgejdZdejdejjdejfddZdd ZG dd deZedkredZedee  ed Z e ! D ](\Z"Z#e$e#ej%rede" de&e#j' d e#j(  qede" de#  qdS dS )!a  Smith300 arm dataset for PARA training.

Adapted from panda_streaming/data_panda_para.py for the smith300 capture
format:
  rgb_NNNNNN.jpg     # 960x540 (anisotropic resize to 448x448 for the model)
  joints.npz         # q_motors[T,6] (no gripper recorded yet) + ticks/timestamps
  meta.json          # K, T_camera_arucoBase, T_W_baseBody_inv_aruco_offset, image_size_wh
  rgb_overlay/episodes.json  # parsed episodes (or root episodes.json)

Key differences vs the panda version:
  - q_motors has 6 entries; the smith300 MuJoCo model has 7 hinge joints
    (6 arm + 1 gripper finger). We pad q[6] = 0 since gripper isn't recorded.
  - World->camera transform is T_camera_arucoBase @ T_W_baseBody_inv_aruco_offset
    (camera pose in arucoBase frame, then arucoBase->baseBody offset).
  - EEF body in the smith300 XML is "virtual_gripper_keypoint" (not "hand").
  - trajectory_gripper is set to 0 everywhere; train_smith300_para.py sets
    GRIPPER_LOSS_WEIGHT=0 so the head is still wired but contributes no loss.
    N)Dataset)Rotation      i  Zvirtual_gripper_keypointzB/home/cameronsmith/mnt/mac/smith300_para_stuff/example_twolink.xmlc                 C   s   |ddddf |  |dddf  }|d dkrdS |d |d  |d  |d  }|d |d  |d  |d	  }t j||gt jd
S )z1Project a 3D world point to 2D pixel coordinates.N      r   r   r   r   r      r   r   r   r   dtype)nparrayfloat32)Z	pos_worldZT_cwKZp_camuv r   1/data/cameron/para/para_mac/data_smith300_para.pyproject_to_pixel&   s   (  r   )r   r   r   )r   r   r   )r   r   r   r   r   r   )r   r   r   r   rgb_hwcrngreturnc                 C   sT  | }|  dk rt|dtt }|d|f }|  dk }|  dk }|s(|rqt|d tjtj	tj
}|rK|d t|dd d	 |d< |r`t|d
 t|dd dd|d
< t|tjtjtj
d }|  dk rd|d  d|d
   d|d   }tj|||gdd}|  dk rt|t|dd dd}|S )u  Aggressive color augmentation for the (H, W, 3) float32 RGB input in
    [0, 1]. Specifically tuned to defeat the gripper-appearance shortcut
    when training on one rig (UMI: green gripper) and deploying on another
    (smith300: white gripper):
      - Channel permutation @ 80% (was 50%): turns green into red/blue/etc,
        forcing features that aren't channel-specific.
      - Full-circle hue jitter (was ±30deg): hue can land anywhere, so
        "green" isn't a stable cue.
      - Saturation 0.0-1.8x (was 0.5-1.5): includes full grayscale.
      - 30% pure grayscale convert: hardest forcing — model must use shape
        and texture, not color.
      - Brightness 0.6-1.4x (was 0.7-1.3).
    g?r   .gffffff?   ).r   i[      ).r           g?     o@g333333?gA`"?gbX9?gv/?).r   )axisg333333?gffffff?r   )random_CHAN_PERMSintegerslencv2cvtColorastyper   uint8COLOR_RGB2HSVr   intclipfloatuniformCOLOR_HSV2RGBstack)r   r   outpermZdo_hueZdo_satZhsvZgrayr   r   r   _augment_color6   s&   " &"$r4   c                 C   st   | \}}|t | }|t | }| tj}|d  |9  < |d  |9  < |d  |9  < |d  |9  < |S )zJAnisotropic K rescale from (W_orig, H_orig) -> (target_size, target_size).r   r	   r
   r   )r.   copyr)   r   float64)Zimage_size_wh_origK_origtarget_sizeWHsxsyr   r   r   r   _scale_K_to_   s   r=   c                   @   sN   e Zd ZdZdedededfdededefdd	Z	d
d Z
dd Zdd ZdS )Smith300TrajectoryDatasetzFSame contract as PandaTrajectoryDataset (returns the same dict shape).Nr   Faugment_colorn_windowuse_keyframesc	           8   	      sP  || _ || _t|| _|| _t|| _t|| _tj	
 | _ttj|d}	t|	}
W d    n1 s7w   Y  |
d \}}tj|
d tjd}tj|
d tjd}tj|
d tjd}|| }|| _|| _|| _t||f||| _|| _|d u rdD ]}tj||}tj|r|} nq|d u stj|std| d	| d
t|}	t|	}W d    n1 sw   Y  |d | _ttj|d}tj|d tjd}|jd }|jd }d|jv rtj|d tjdnd }d|jv rtj|d tjdnd }d|jv rtj|d tjdnd }|d uo&|d uo&|d u}t d| dd t!j"#|}t!$|}t!%|t!j&j't(}|dk rSt)dt(d| |j*}t+ } | jD ]"}!t,t|!d |d  t-t|!d  d D ]}"| .|" qtq\t dt/|  d|rdnd ddd i | _0t1| D ]}"||" }#|r||" 2tj3}$||" 2tj3}%||" 2tj3}&nJtj4|tjd}'|#d | |'d t,||< |'|j5d |< t!6|| |j7| 8 2tj3}$|j9| 8 }(|(g d 2tj3}%t:;|%<d 2tj3}&t=|$2tj|| j})|)d ur^d!t>|)d   ko(t>|d kn  o@d!t>|)d   ko>t>|d kn  }*tjt?|)d d|d t?|)d d|d gtj3d}+n
d"}*tj4d#tj3d}+|d$krtt3|#d% nt3d!},|$|%|&|+t|*|,d&| j0|"< qg | _@g | _Ag | _Bd}-tC| jD ]\}.}!t|!d t,t|!d |d  | jBD  | jr fd'd(|!Ed)g D }/|/st d*|!Ed+|. d,dd | jADg  q| jAD|/ t-t/|/D ]}0| j0|/|0  Ed-ds|-d7 }-q| j@D|.|0f qq| jADg    d }1t-|1D ]}2|2 }	| j0|	 Ed-ds1|-d7 }-q| j@D|.|	f qq|-rJt d.|- d/dd |8 2tj3| _F| jFd  |  < | jFd  |  < tjg d0tj3d| _Gtjg d1tj3d| _H| jrd2d( | jAD }3t d3t/| j d4t/| j@ d5| j d6|3 dd nt d7t/| j d4t/| j@ d8| dd i | _It1| D ]I}4tj|d9|4d:d;}5tJK|5tJjL}6|6d u rtd<|5 |6jd |ks|6jd |krtJjM|6||ftJjNd=}6tJO|6tJjP| jIt|4< qtQd>d? | jIR D }7t d@t/| jI dA|7dB dCdDdd d S )ENz	meta.jsonZimage_size_whr   r   T_camera_arucoBaseZT_W_baseBody_inv_aruco_offset)zrgb_overlay/episodes.jsonzepisodes.jsonzNo episodes.json found under z/rgb_overlay/ or z/.episodesz
joints.npzq_motorsr   r   eef_poseef_quat	eef_eulerzLoading MuJoCo model: T)flushzbody z missing from XML endstartz!Pre-computing per-frame data for z	 frames (zusing saved EEF poses (UMI)zrunning arm FK))r   r   r   r   xyzr   Fr   r   r   )rE   rF   rG   pixel_2dpixel_validgripperc                    s8   g | ]}t |d    kr krn nt |d  qS )frame)r,   ).0kfep_endZep_startr   r   
<listcomp>  s     
z6Smith300TrajectoryDataset.__init__.<locals>.<listcomp>Z	keyframeszWARN: episode idz. has no keyframes; skipping in keyframes mode.rN   z
  skipped z; samples whose start frame had an off-frame EEF projection.)g
ףp=
?gv/?gCl?)gZd;O?gy&1?g?c                 S   s   g | ]}|rt |qS r   )r&   rQ   kr   r   r   rU   C  s    zSmith300Dataset (keyframes): z episodes, z samples, n_window=z, keyframes/ep (post-pad)=zSmith300Dataset: z samples, stride=rgb_06d.jpgmissing frame: interpolationc                 s   s    | ]}|j V  qd S N)nbytes)rQ   ar   r   r   	<genexpr>[  s    z5Smith300TrajectoryDataset.__init__.<locals>.<genexpr>z
Preloaded z RGB frames (i   z.1fz MB) into memory.)Sdata_dir
image_sizer,   r@   frame_strideboolr?   rA   r   r#   default_rng_rngopenospathjoinjsonloadr   r6   IMG_WIMG_Hr7   r=   ZK_targetT_CAM_WORLDexistsFileNotFoundErrorrC   asarrayshapefilesprintmujocoMjModelfrom_xml_pathMjData
mj_name2idmjtObj
mjOBJ_BODYEEF_BODY_NAMERuntimeErrornqsetminrangeaddr&   
frame_datasortedr)   r   zerosqpos
mj_forwardxposr5   xquatScipyR	from_quatas_eulerr   r.   r-   samplesepisode_keyframesepisode_ends	enumerateappendget
cam_k_normmeanstd	rgb_cacher'   imreadIMREAD_COLORresize
INTER_AREAr(   COLOR_BGR2RGBsumvalues)8selfrc   Zepisodes_jsonrd   re   
mujoco_xmlr?   r@   rA   fmetaro   rp   r7   rB   ZT_W_baseBodyrq   candpep_datajointsZq_motors_alln_recorded_motorsZn_frames_totalZeef_pos_savedZeef_quat_savedZeef_euler_savedZuse_saved_eefmj_modelmj_dataeef_idn_qposZall_frame_indicesepidxZq_inrE   rF   rG   q	quat_wxyzpixrN   rM   gripZn_skipped_invalidep_idx	kf_framesZstart_kfZep_lentZ	kf_counts	frame_idxrk   bgrZn_bytesr   rS   r   __init__o   sx  












	,("	
 
z"Smith300TrajectoryDataset.__init__c                 C   s
   t | jS r_   )r&   r   )r   r   r   r   __len___  s   
z!Smith300TrajectoryDataset.__len__c                 C   s   | j t|}|d u rOtj| jd|dd}t|tj	}|d u r+t
d| |jd | jks;|jd | jkrHtj|| j| jftjd}t|tj}|tjd S )	NrY   rZ   r[   r\   r   r   r]   r    )r   r   r,   rj   rk   rl   rc   r'   r   r   rs   ru   rd   r   r   r(   r   r)   r   r   )r   r   Zrgb_u8rk   r   r   r   r   _load_rgb_resizedb  s    z+Smith300TrajectoryDataset._load_rgb_resizedc              	      s~  j | \}jr"j| t  fddtjD }nj| fddtjD }g }g }g }g }g }g }	g }
d }tjD ]G}|| }|}|d u r[|}|
| j	| }||d  ||d  ||d  ||d  ||d  |	|
d	d
 qJt|}t|}tj|tjd}t|}t|}tj|	td}	t|
}
tjjjjftjd}tjD ]1}tttt||df djd }tttt||df djd }d||||f< q|}jrt| j}t|djd d d d f  jd d d d f  }jtj}tj !jtj}i dt"#| dt"#| dt"#| dt"#| dt"#| dt"#| dt"#| dt"#|	dt"#|
 dt"#| dt"j$dt"jddt"#|d  dt"#| dt"#j% d t"j$|t"j&dd!t"j$t"j&dS )"Nc                    s"   g | ]}t |  d   qS )r   )r   rW   )r   r   start_framer   r   rU   z  s    z9Smith300TrajectoryDataset.__getitem__.<locals>.<listcomp>c                    s    g | ]}t |j   qS r   )r   re   rW   )rT   r   r   r   r   rU     s    rE   rF   rG   rM   rO   rN   Tr   r   r   g      ?r   rgbZheatmap_targettrajectory_2dtrajectory_3dtrajectory_grippertrajectory_quattrajectory_eulertrajectory_validrgb_frames_rawworld_to_cameraZbase_zr   Z	target_3dr!   camera_pose
cam_K_normZdemo_idxstart_t)'r   rA   r   r&   r   r@   r   r   r   r   r   r   r1   r   r   rt   rf   r   rd   r,   r-   roundr.   r?   r4   r5   rh   	transposer   r   rq   r)   linalginvtorch
from_numpytensorr   long)r   r   r   Zwindow_framesr   r   r   r   r   r   r   Zrgb_refrX   r   r   fdZheatmap_targetsr   Zx_iZy_iZrgb_for_modelZrgb_tr   r   r   )r   rT   r   r   r   r   __getitem__r  s   









((4
	
z%Smith300TrajectoryDataset.__getitem__)__name__
__module____qualname____doc__
IMAGE_SIZEDEFAULT_SMITH300_XMLN_WINDOWrf   r,   r   r   r   r   r   r   r   r   r>   l   s&    
 qr>   __main__z8/data/cameron/mac_robot_datasets/dataset_20260501_180125zDataset size: z  z:  ))r   rj   rm   r'   numpyr   r   torch.utils.datar   scipy.spatial.transformr   r   rx   r   ZN_MOTORS_FULLr   r   r   r   r   int64r$   ndarrayr#   	Generatorr4   r=   r>   r   dsrw   r&   sitemsrX   r   
isinstanceTensortupleru   r   r   r   r   r   <module>   sL    )  `$