o
    ?߱i                     @   s  d dl Z d dlmZ d dlmZ d dlZd dlmZ d dlm	Z	 d dl
m	  mZ d dlmZ d dlmZ d dlmZmZ d dlmZ d dlmZmZmZ d d	lmZ d d
lmZ d dlmZ d dl m!Z! dgZ"dZ#G dd de	j$Z%G dd de	j&Z'G dd de	j(Z(G dd de	j&Z)G dd de	j&Z*G dd de	j&Z+G dd de	j&Z,G dd de	j&Z-dd  Z.G d!d" d"e	j&Z/			#	$	%	&d.d'e0d(e0fd)d*Z1G d+d dZ2G d,d- d-eZ3dS )/    N)nullcontext)Optional)	rearrange)parallel_state)INTERNALSMOKE)log)	broadcastget_ranksync_model_states)easy_io)VideoTokenizerInterface)plugin_mount)BenchmarkTimesWanVAE   c                       s.   e Zd ZdZ fddZd fdd	Z  ZS )CausalConv3dz 
    Causal 3d convolusion.
    c                    sP   t  j|i | | jd | jd | jd | jd d| jd  df| _d| _d S )Nr      r   r   r   r   )super__init__padding_padding)selfargskwargs	__class__ Z/data/cameron/vidgen/cosmos-predict2.5/cosmos_predict2/_src/predict2/tokenizers/wan2pt1.pyr   1   s   4
zCausalConv3d.__init__Nc                    sl   t | j}|d ur*| jd dkr*||j}tj||gdd}|d  |jd 8  < t||}t	 
|S )N   r   r   dim)listr   todevicetorchcatshapeFpadr   forward)r   xcache_xr   r   r   r   r+   6   s   
zCausalConv3d.forwardN__name__
__module____qualname____doc__r   r+   __classcell__r   r   r   r   r   ,   s    r   c                       s&   e Zd Zd fdd	Zdd Z  ZS )RMS_normTFc                    sr   t    |s	dnd}|r|g|R n|f}|| _|d | _tt|| _|r4tt	|| _
d S d| _
d S )N)r   r   r   )r   r         ?        )r   r   channel_firstscalenn	Parameterr&   onesgammazerosbias)r   r"   r8   imagesr?   Zbroadcastable_dimsr(   r   r   r   r   B   s   

$zRMS_norm.__init__c                 C   s*   t j|| jrdndd| j | j | j S )Nr   r!   )r)   	normalizer8   r9   r=   r?   r   r,   r   r   r   r+   L   s   *zRMS_norm.forwardTTFr0   r1   r2   r   r+   r4   r   r   r   r   r5   A   s    
r5   c                       s   e Zd Z fddZ  ZS )Upsamplec                    s   t  | |S )zJ
        Fix bfloat16 support for nearest neighbor interpolation.
        )r   r+   floattype_asrC   r   r   r   r+   Q   s   zUpsample.forward)r0   r1   r2   r+   r4   r   r   r   r   rF   P   s    rF   c                       s<   e Zd Z fddZddgfddZdd Zd	d
 Z  ZS )Resamplec              	      s  |dv sJ t    || _|| _|dkr+ttdddtj||d ddd	| _d S |d
krPttdddtj||d ddd	| _t	||d ddd	| _
d S |dkrgttdtj||ddd| _d S |dkrttdtj||ddd| _t	||dddd| _
d S t | _d S )N)none
upsample2d
upsample3ddownsample2ddownsample3drK   )       @rO   znearest-exact)scale_factormoder      r   r   rL   )rR   r   r   )r   r   r   rM   )r   r   r   r   )r   r   )striderN   )r   r   r   r   )rT   r   )r   r   r"   rQ   r:   
SequentialrF   Conv2dresampler   	time_conv	ZeroPad2dIdentity)r   r"   rQ   r   r   r   r   Y   s&   

&"zResample.__init__Nr   c                 C   s  |  \}}}}}| jdkr|d ur|d }	||	 d u r)d||	< |d  d7  < n|d d d d t d d d d d f  }
|
jd dk rs||	 d urs||	 dkrstj||	 d d d d dd d d d f d|
j	|
gdd}
|
jd dk r||	 d ur||	 dkrtjt
|
|
j	|
gdd}
||	 dkr| |}n| |||	 }|
||	< |d  d7  < ||d||||}t|d d dd d d d d d d d f |d d dd d d d d d d d f fd}||||d ||}|jd }t|d	}| |}t|d
|d}| jdkrr|d urr|d }	||	 d u r3| ||	< |d  d7  < |S |d d d d dd d d d d f  }
| t||	 d d d d dd d d d d f |gd}|
||	< |d  d7  < |S )NrL   r   ZRepr   r   rA   r!   rR   b c t h w -> (b t) c h wz(b t) c h w -> b c t h wtrN   )sizerQ   CACHE_Tcloner(   r&   r'   	unsqueezer$   r%   
zeros_likerX   reshapestackr   rW   )r   r,   
feat_cachefeat_idxbcr]   hwidxr-   r   r   r   r+   s   sL   
,&8& X




*<zResample.forwardc           
      C   s~   |j }tj| | \}}}}}t||}|}	tj| |	|jd d d d dddf< |j j| tj|j	j d S )Nr   r   )
weightr:   initzeros_r^   r&   eyedatacopy_r?   )
r   convconv_weightc1c2r]   ri   rj   Z
one_matrixinit_matrixr   r   r   init_weight   s   zResample.init_weightc           	      C   s   |j j}tj| | \}}}}}t|d |}||d |d d d dddf< |||d d d d dddf< |j j| tj|j	j d S )Nr   rA   r   )
rl   rp   r:   rm   rn   r^   r&   ro   rq   r?   )	r   rr   rs   rt   ru   r]   ri   rj   rv   r   r   r   init_weight2   s   zResample.init_weight2)r0   r1   r2   r   r+   rw   rx   r4   r   r   r   r   rI   X   s
    1rI   c                       s.   e Zd Zd fdd	ZddgfddZ  ZS )	ResidualBlockr7   c                    s   t    || _|| _tt|ddt t||dddt|ddt t	|t||ddd| _
||kr?t||d| _d S t | _d S )NFr@   rR   r   rS   )r   r   in_dimout_dimr:   rU   r5   SiLUr   DropoutresidualrZ   shortcut)r   r{   r|   dropoutr   r   r   r      s   


(	zResidualBlock.__init__Nr   c              	   C   s   |  |}| jD ]k}t|tro|d uro|d }|d d d d t d d d d d f  }|jd dk r[|| d ur[tj|| d d d d dd d d d f 	d
|j|gdd}|||| }|||< |d  d7  < q||}q|| S Nr   r   rA   r!   r   )r   r   
isinstancer   r_   r`   r(   r&   r'   ra   r$   r%   )r   r,   re   rf   ri   layerrk   r-   r   r   r   r+      s   

,8
zResidualBlock.forward)r7   rE   r   r   r   r   ry      s    ry   c                       s(   e Zd ZdZ fddZdd Z  ZS )AttentionBlockz3
    Causal self-attention with a single head.
    c                    sR   t    || _t|| _t||d d| _t||d| _tj	
| jj d S )NrR   r   )r   r   r"   r5   normr:   rV   to_qkvprojrm   rn   rl   )r   r"   r   r   r   r      s   

zAttentionBlock.__init__c                 C   s   |}|  \}}}}}t|d}| |}| ||| d|d ddddd jddd\}}	}
t	||	|
}|
dddd|| |||}| |}t|d|d	}|| S )
Nr[   r   rR   rA   r   r   r!   z(b t) c h w-> b c t h wr\   )r^   r   r   r   rc   permute
contiguouschunkr)   scaled_dot_product_attentionsqueezer   )r   r,   identityrg   rh   r]   ri   rj   qkvr   r   r   r+      s   

>$
zAttentionBlock.forwardr/   r   r   r   r   r      s    r   c                       D   e Zd Zddg ddg g ddf fdd	Zd	d
gfddZ  ZS )	Encoder3d   r    r   r   r    r    r   rD   r7   c              
      s`  t     | _|| _|| _|| _|| _|| _ fdddg| D }d}	td|d ddd| _	g }
t
t|d d |dd  D ]@\}\}}t|D ]}|
t||| |	|v ra|
t| |}qK|t|d kr|| rrd	nd
}|
t||d |	d }	qAtj|
 | _tt|||t|t|||| _tt|ddt t||ddd| _d S )Nc                       g | ]} | qS r   r   .0ur!   r   r   
<listcomp>      z&Encoder3d.__init__.<locals>.<listcomp>r         ?rR   r   rS   rA   rN   rM   rQ   rO   Frz   )r   r   r"   z_dimdim_multnum_res_blocksattn_scalestemperal_downsampler   conv1	enumerateziprangeappendry   r   lenrI   r:   rU   downsamplesmiddler5   r}   head)r   r"   r   r   r   r   r   r   dimsr9   r   ir{   r|   _rQ   r   r!   r   r   	  s:   

*
zEncoder3d.__init__Nr   c              	   C   s  |d ura|d }|d d d d t  d d d d d f  }|jd dk rL|| d urLtj|| d d d d dd d d d f d|j|gdd}| ||| }|||< |d  d7  < n| |}| j	D ]}|d urv||||}qi||}qi| j
D ]}t|tr|d ur||||}q~||}q~| jD ]k}t|tr|d ur|d }|d d d d t  d d d d d f  }|jd dk r|| d urtj|| d d d d dd d d d f d|j|gdd}|||| }|||< |d  d7  < q||}q|S r   )r_   r`   r(   r&   r'   ra   r$   r%   r   r   r   r   ry   r   r   r   r,   re   rf   rk   r-   r   r   r   r   r+   =  s<   ,B





,8
zEncoder3d.forwardrE   r   r   r   r   r         4r   c                       r   )	Decoder3dr   r    r   r   FTTr7   c              
      s  t     | _|| _|| _|| _|| _|| _ fdd|d g|d d d  D }ddt|d   }	t	||d ddd	| _
tt|d |d |t|d t|d |d || _g }
tt|d d |dd  D ]R\}\}}|dks~|dks~|dkr|d }t|d D ]}|
t||| |	|v r|
t| |}q|t|d kr|| rd
nd}|
t||d |	d9 }	qltj|
 | _tt|ddt t	|dddd	| _d S )Nc                    r   r   r   r   r!   r   r   r   ~  r   z&Decoder3d.__init__.<locals>.<listcomp>rA   r   r   r   rR   r   rS   rL   rK   r   rO   Frz   )r   r   r"   r   r   r   r   temperal_upsampler   r   r   r:   rU   ry   r   r   r   r   r   r   rI   	upsamplesr5   r}   r   )r   r"   r   r   r   r   r   r   r   r9   r   r   r{   r|   r   rQ   r   r!   r   r   k  s:   

&.*,zDecoder3d.__init__Nr   c              	   C   s  |d ura|d }|d d d d t  d d d d d f  }|jd dk rL|| d urLtj|| d d d d dd d d d f d|j|gdd}| ||| }|||< |d  d7  < n| |}| j	D ]}t
|tr{|d ur{||||}qi||}qi| jD ]}|d ur||||}q||}q| jD ]k}t
|tr|d ur|d }|d d d d t  d d d d d f  }|jd dk r|| d urtj|| d d d d dd d d d f d|j|gdd}|||| }|||< |d  d7  < q||}q|S r   )r_   r`   r(   r&   r'   ra   r$   r%   r   r   r   ry   r   r   r   r   r   r   r   r+     s<   ,B





,8
zDecoder3d.forwardrE   r   r   r   r   r   j  r   r   c                 C   s(   d}|   D ]}t|tr|d7 }q|S )Nr   r   )modulesr   r   )modelcountmr   r   r   count_conv3d  s   
r   c                       s   e Zd Zddg ddg g dddf fdd	Zd	d
 ZdddZejjdd Z	ejjdd Z
dddZdd ZdddZdd Z  ZS )WanVAE_r   r    r   r   rD   r7   c	           	         s   t    || _|| _|| _|| _|| _|| _|d d d | _|| _	t
||d |||| j|| _t|d |d d| _t||d| _t|||||| j|| _d S )NrA   r   r   )r   r   r"   r   r   r   r   r   r   temporal_windowr   encoderr   r   conv2r   decoder)	r   r"   r   r   r   r   r   r   r   r   r   r   r     s   
zWanVAE_.__init__c                 C   s.   |  |\}}| ||}| |}|||fS r.   )encodereparameterizedecode)r   r,   mulog_varzZx_reconr   r   r   r+     s   

zWanVAE_.forwardTc              	   C   s  |r|    |jd }d|d | j  }t|D ]?}dg| _|dkr(| |}q| j|d d d d d| j|d   d| j|  d d d d f | j| jd}t	||gd}q|d | j rdg| _| j|d d d d d| j|d   d d d d d f | j| jd}t	||gd}| 
|jddd\}	}
t|d tjr|	|d d| jddd |d d| jddd }	n
|	|d  |d  }	|r|    |	S )Nr   r   r   re   rf   r!   )clear_cacher(   r   r   _enc_conv_idx
_i0_encoder   _enc_feat_mapr&   r'   r   r   r   Tensorviewr   )r   r,   r9   clear_encoder_cacher]   iter_r   outout_r   r   r   r   r   r     s:   
<26zWanVAE_.encodec                 C   s:   | j |ddddddddddf | j| jd}|S )zi
        If enabled torch.compile uses significantly more memory for this step, so we disable it
        Nr   r   )r   r   r   )r   r,   r   r   r   r   r     s   6zWanVAE_._i0_encodec                 C   s6   | j |d d d d ddd d d d f | j| jdS )Nr   r   r   )r   	_feat_map	_conv_idxrC   r   r   r   
_i0_decode#  s   6zWanVAE_._i0_decodec           	   	   C   s   |r|    t|d tjr)||d d| jddd |d d| jddd }n
||d  |d  }|jd }| |}t|D ]5}dg| _	|dkrQ| 
|}qA| j|d d d d ||d d d d d f | j| j	d}t||gd}qA|r}|    |S )Nr   r   r   r   )r   r   r&   r   r   r   r(   r   r   r   r   r   r   r'   )	r   r   r9   clear_decoder_cacher   r,   r   r   r   r   r   r   r   '  s    6

:zWanVAE_.decodec                 C   s$   t d| }t |}|| | S )Nr6   )r&   exp
randn_like)r   r   r   stdepsr   r   r   r   <  s   
zWanVAE_.reparameterizeFc                 C   s>   |  |\}}|r|S td|dd }||t|  S )Nr6   g      >g      4@)r   r&   r   clampr   )r   imgsdeterministicr   r   r   r   r   r   sampleA  s
   zWanVAE_.samplec                 C   sH   t | j| _dg| _d g| j | _t | j| _dg| _d g| j | _d S )Nr   )	r   r   Z	_conv_numr   r   r   Z_enc_conv_numr   r   r   r   r   r   r   H  s   zWanVAE_.clear_cacheT)F)r0   r1   r2   r   r+   r   r&   compilerdisabler   r   r   r   r   r   r4   r   r   r   r   r     s(    
#



r   cpucredentials/s3_training.secretFQs3://bucket/cosmos_diffusion_v2/pretrain_weights/tokenizer/wan2pt1/Wan2.1_VAE.pths3_credential_pathmean_std_pathc                 K   s  t d|g ddg g ddd}|jdi | td tdi |}W d   n1 s.w   Y  ts9| du rp|j|d	 |rotjd
dd
d
d
|d	tjd
dd
d
d
|d	}	}
tjd
ddd
d
|d	tjd
ddd
d
|d	}}nt dkrt	sddl
m} || } | drd}tj|d|dd nd}tj| ||d}|r|dd}|dd}t	sddl
m} ||}||}tj|||d\}	}
tj|||d\}}|	d
dd
d
d
}	|
d
dd
d
d
}
|d
ddd
d
}|d
ddd
d
}td|   |j|dd n7|j|d	 |r;tjd
dd
d
d
|d	tjd
dd
d
d
|d	}	}
tjd
ddd
d
|d	tjd
ddd
d
|d	}}t| |rbtd t|	d t|
d t|d t|d ||	|
||fS |tjd
d
d
d
d
|d	tjd
d
d
d
d
|d	tjd
d
dd
d
|d	tjd
d
dd
d
|d	fS )zF
    Autoencoder3d adapted from Stable Diffusion 1.x, 2.x and XL.
    `   r   r   r   r7   )r"   r   r   r   r   r   r   metaN)r%   r          r   )get_checkpoint_pathzs3://Zwan2pt1_vaes3)backendr   )keybackend_args)backend_keymap_locationzmean_std.ptzimages_mean_std.ptzvideo_mean_std.ptzloading T)assignz"broadcast mean and std for wan2pt12   r   )dictupdater&   r%   r   r   to_emptyrandnr
   r   3cosmos_predict2._src.imaginaire.utils.checkpoint_dbr   
startswithr   set_s3_backendloadreplacerc   r   infoload_state_dictr   r	   r>   r<   )pretrained_pathr   r%   r   load_mean_stdr   r   cfgr   img_meanimg_std
video_mean	video_stdr   r   ckptZimg_mean_stdZvideo_mean_stdr   r   r   
_video_vaeR  s   	.





r  c                   @   s  e Zd Zdddddejdddddd	fd
edededededee	eef  fddZ
dd Ze d/ddZe d/ddZedd Zedd Zedd Zdejdejfd d!Zd"ejfd#d$Zd%d& Zd'd( Zd)ejdefd*d+Zd,ejdd	fd-d.Zd	S )0r   r   r   r   FzNs3://bucket/cosmos_diffusion_v2/pretrain_weights/tokenizer/wan2pt1/mean_std.ptcudaTr    Nr   r   	benchmarkr   is_parallelcp_grid_shapec              	   C   s0  || _ || _|	| _|
| _|| _|| _d| _d| _g d}g d}tj	|||d| _
tj	|||d| _| j
d| j g| _t|||||||
d\| _| _| _| _| _|rqd }t rgt }|d u rfd| f}nJ d| || | j d| _|| _|s| jj|d	| _t | _d S tjjd
|d	| _d S )NF)gy):gMOg^)gQ?gtVƿgZӼ?gBfjÿgU0*?gL
F%u?gMg&?gz6>׿gF%uȿg[ AcgMJ?gW2ıҿ)g_L@gNё\C?gQ@g?@g9#J{?g|a2U?gHPs@g0* @gJ{/L&
@gJY8@g]C@g(?gK46?gS:?go_Ι@g-?)dtyper%   r   )r   r   r   r   r   r%   r   r   z7is_parallel set, but context parallelism is initialized)r
  r  ) r
  r%   r  r   r  r	  context_parallel_enabledcp_group_initializedr&   tensormeanr   r9   r  r   r   r  r  r  r   is_initializedget_context_parallel_groupr^   _initialize_context_parallelevalrequires_grad_is_ampr$   r   contextampautocast)r   r   vae_pthr   r   r   r
  r%   r  r  r   r  r	  r  r   cp_groupr   r   r   r     sJ   
zWanVAE.__init__c                 C   s   t dd | j D S )Nc                 s   s    | ]}|  V  qd S r.   )numel)r   pr   r   r   	<genexpr>  s    z%WanVAE.count_param.<locals>.<genexpr>)sumr   
parametersr   r   r   r   count_param  s   zWanVAE.count_paramc           	   
   C   sD  | j r6| |r|   n)z| |}|   W n ty5 } ztt| |   W Y d}~nd}~ww | j	rEt
j  t }t }|j}| j5 | jsU|| j}| j	rat
j  t }| j|| j|}| j	ryt
j  t | |_W d   n1 sw   Y  ||}| j	rt
j  t | |_||fS |S )zH
        videos: A list of videos each with shape [C, T, H, W].
        N)r  _is_image_batch_disable_context_parallel&_broadcast_split_for_model_parallelsim_enable_context_parallel
ValueErrorr   warningstrr  r&   r  synchronizer   timeperf_counterr
  r  r  r$   r   r   r9   model_invocationtotal)	r   videosr   ebenchmark_times
total_timein_dtype
model_timelatentr   r   r   r     sD   








zWanVAE.encodec           	      C   s  | j rtj  t }t }| jra| |r| 	  nE|j
d | jd  dko3|j
d | jd  dk}|s]td| j d|j
d  d| jd  d|j
d  d| jd  d	 | 	  n|   |j}| j5 | jsq|| j}| j r}tj  t }| j|| j|}| j rtj  t | |_W d    n1 sw   Y  ||}| jr| jr| |}| j rtj  t | |_||fS |S )
NrR   r   r    r   &For parallel encoding with grid_shape z9 latent height should be divisible by grid_shape[0], got z / z5 and width should be divisible by grid_shape[1], got z, falling back to non CP)r  r&   r  r'  r   r(  r)  r  r   r!  r(   r	  r   r%  r#  r
  r  r  r$   r   r   r9   r*  r  _cat_outputs_cpr+  )	r   zsr   r.  r/  Zcan_apply_cpr0  r1  video_reconr   r   r   r   E  sF   


0<






zWanVAE.decodec                 C      dS N   r   r   r   r   r   spatial_compression_factorm     z!WanVAE.spatial_compression_factorc                 C   r7  Nr    r   r   r   r   r   temporal_compression_factorq  r;  z"WanVAE.temporal_compression_factorc                 C   r7  )NrR   r   r   r   r   r   _cp_dimu  r;  zWanVAE._cp_dimstatereturnc                    sZ  t  jdksJ d| j\}} jd || j  dko' jd || j  dk}|sStd| j d jd  d| jd  d	| j d
 jd  d| jd  d	| j d    fddt|| D }tj| | j	d |d   jd | } jd | }tj
| j	d}|| }	|| }
 d d d d d d |	| |	d | |
| |
d | f S )N   zState should be of shape BCTHWrR   r   r    r3  zE height should be divisible by compression_factor*grid_shape[0], got z / (z * zI) and width should be divisible by compression_factor*grid_shape[1], got r   z), falling back to non CPc                       g | ]}t  qS r   r&   rb   r   r   r?  r   r   r         zAWanVAE._broadcast_split_for_model_parallelsim.<locals>.<listcomp>group)r   r(   r	  r:  r$  r   r   distributed
all_gatherr  r
   )r   r?  Zcp_rowscp_colsZcan_cp_be_applied_to_shape
state_listZchunk_hZchunk_w
group_rankrow_idZcol_idr   rE  r   r"  y  s&   
L>z-WanVAE._broadcast_split_for_model_parallelsimlocal_video_reconc                    sV    fddt jD tj jd tjfddt jd D dd}|S )Nc                    rB  r   rC  rD  )rO  r   r   r     rF  z*WanVAE._cat_outputs_cp.<locals>.<listcomp>rG  c                    s*   g | ]}t j|d  jd  ddqS )Nr   rR   r!   )r&   r'   r	  )r   rh   )r   video_recon_chunksr   r   r     s   * r   r    r!   )r   cp_group_sizerI  rJ  r  r&   r'   r	  )r   rO  r6  r   )rO  r   rP  r   r4    s   zWanVAE._cat_outputs_cpc                 C   :   d| _ | j D ]\}}| D ]	\}}|d qqd S )NTr  pluginsitemsZ
set_enabler   r   Zplugin_listpluginr   r   r   r#       zWanVAE._enable_context_parallelc                 C   rR  )NFrS  rV  r   r   r   r!    rX  z WanVAE._disable_context_parallelr,   c                 C   s$   t |jdksJ d|jd dkS )NrA  z#Expected tensor's shape to be BCTHWr   r   )r   r(   rC   r   r   r   r     s   zWanVAE._is_image_batchr  c                 C   sh   | j du sJ d| _|| _d| _|| _tt| j| _t	| j
| j|| _|   td| d d S )NFTzEnabled CP with grid_shape: z for Wan2.1 tokenizer)r  r  r	  r  r  r   rI  get_process_group_ranksrQ  r   r   rT  r#  r   r   r   r  r	  r   r   r   r    s   z#WanVAE._initialize_context_parallelr   )r0   r1   r2   r&   bfloat16r&  boolintr   tupler   r  no_gradr   r   propertyr:  r=  r>  r   r"  r4  r#  r!  r   rI  ProcessGroupr  r   r   r   r   r     sV    

^%'


c                   @   s   e Zd Zd,defddZdejdeeef dd	fd
dZe	dd Z
dd Zdd ZdejdejfddZdejdejfddZdedefddZdedefddZe	dd Ze	d d! Ze	d"d# Ze	d$d% Ze	d&d' Ze	d(d) Ze	d*d+ Zd	S )-Wan2pt1VAEInterfaceQ   Fchunk_durationc                 K   sr   | dd| _| dd| _ttjd|| dd| dd| dd	| d
d| dd d| _~|| _d| _d S )Nkeep_decoder_cacheFkeep_encoder_cacher  r   r   r   r   r    r  r	  )r
  r  r   r  r   r   r  r	  )	getre  rf  r   r&   r[  r   rd  cp_initialized)r   rd  r   r   r   r   r   r     s$   




zWan2pt1VAEInterface.__init__r  r	  r@  Nc                 C   s&   | j du sJ d| _ | j|| d S )NFT)rh  r   r  rZ  r   r   r   initialize_context_parallel  s   z/Wan2pt1VAEInterface.initialize_context_parallelc                 C   s   | j jS r.   )r   r
  r   r   r   r   r
    s   zWan2pt1VAEInterface.dtypec                 C   s   d S r.   r   r   r   r   r   reset_dtype  s   zWan2pt1VAEInterface.reset_dtypec                 C   s   | j j   dS )z5Clear the feature cache for both encoder and decoder.N)r   r   r   r   r   r   r     s   zWan2pt1VAEInterface.clear_cacher?  c                 C   s   | j j|| j d}|jd }|dkr#|| j j| | j j| S || j jd d d d d |f | | j jd d d d d |f | S )N)r   r   r   )	r   r   rf  r(   r   rH   r  r  r  )r   r?  latents
num_framesr   r   r   r     s   
 ,zWan2pt1VAEInterface.encoder2  c                 C   s   |j d }|dkr | j|| jj| | jj|  }n,| j|| jjd d d d d |f | | jjd d d d d |f |  }t	|t
rbt|dks[J d|d d}|S )Nr   r   zAssuming batch_size=1 was usedr   )r(   r   r   r  rH   r   r   r  r  r   r#   r   ra   )r   r2  rl  Zreconr   r   r   r     s   
"&"
zWan2pt1VAEInterface.decodenum_pixel_framesc                 C   s   d|d d  S Nr   r    r   )r   rm  r   r   r   get_latent_num_frames     z)Wan2pt1VAEInterface.get_latent_num_framesnum_latent_framesc                 C   s   |d d d S rn  r   )r   rq  r   r   r   get_pixel_num_frames  rp  z(Wan2pt1VAEInterface.get_pixel_num_framesc                 C   r7  r8  r   r   r   r   r   r:  
  r;  z.Wan2pt1VAEInterface.spatial_compression_factorc                 C   r7  r<  r   r   r   r   r   r=    r;  z/Wan2pt1VAEInterface.temporal_compression_factorc                 C   s   | j S r.   )rd  r   r   r   r   pixel_chunk_duration  s   z(Wan2pt1VAEInterface.pixel_chunk_durationc                 C   s   |  | jS r.   )ro  rd  r   r   r   r   latent_chunk_duration  s   z)Wan2pt1VAEInterface.latent_chunk_durationc                 C   r7  )Nr   r   r   r   r   r   	latent_ch  r;  zWan2pt1VAEInterface.latent_chc                 C   r7  )Ni   r   r   r   r   r   spatial_resolution  r;  z&Wan2pt1VAEInterface.spatial_resolutionc                 C   r7  )Nwan2pt1_tokenizerr   r   r   r   r   name"  r;  zWan2pt1VAEInterface.name)rc  F)r0   r1   r2   r]  r   rI  ra  r^  ri  r`  r
  rj  r   r&   r   r   r   ro  rr  r:  r=  rs  rt  ru  rv  rx  r   r   r   r   rb    s2     







rb  )NNr   r   Fr   )4r(  
contextlibr   typingr   r&   torch.distributedrI  torch.nnr:   Ztorch.nn.functional
functionalr)   einopsr   megatron.corer   %cosmos_predict2._src.imaginaire.flagsr   r   %cosmos_predict2._src.imaginaire.utilsr   Z1cosmos_predict2._src.imaginaire.utils.distributedr	   r
   r   -cosmos_predict2._src.imaginaire.utils.easy_ior   Z2cosmos_predict2._src.predict2.tokenizers.interfacer   Z;cosmos_predict2._src.predict2.tokenizers.wan2pt1_2d_pluginsr   Z:cosmos_predict2._src.predict2.utils.tokenizer_benchmarkingr   __all__r_   Conv3dr   Moduler5   rF   rI   ry   r   r   r   r   r   r&  r  r   rb  r   r   r   r   <module>   sX   d%'bc~
k  