o
    iT                     @   sJ  d Z ddlZddlZddlZddlmZ ddlZddlZddl	m
  mZ ddlmZ ddlmZ ddlmZ ddlmZmZmZ ddlmZ dd	lmZmZ dd
lmZmZ ee j j d Z!ej"#de$e! ddl%m&Z&m'Z'm(Z( ddlm)Z) dZ*dZ+dZ,dZ-G dd de)Z.G dd de)Z/dd Z0dd Z1e2dkre1  dS dS )a  Fine-tune Stable Video Diffusion (SVD) on LIBERO videos with LoRA.

Loads the pretrained SVD img2vid model, adds LoRA adapters to the UNet,
and trains on LIBERO parsed frames. VAE and CLIP image encoder are frozen.

Usage:
    CUDA_VISIBLE_DEVICES=4 python video_training/svd_finetune/train.py         --data-root /data/libero/parsed_libero/libero_spatial         --svd-path /data/cameron/vidgen/Ctrl-World/checkpoints/stable-video-diffusion-img2vid         --log_wandb --run-name svd_libero_spatial
    NPath)
DataLoader)	rearrange)tqdm)AutoencoderKLTemporalDecoder UNetSpatioTemporalConditionModelEulerDiscreteScheduler)compute_snr)CLIPVisionModelWithProjectionCLIPImageProcessor)
LoraConfigget_peft_modelZunified_video_action)LiberoVideoDatasetcollate_batch_natural_sort_key)DatasetzJ/data/cameron/vidgen/Ctrl-World/checkpoints/stable-video-diffusion-img2vid   i   c                   @   4   e Zd ZdZeeeddfddZdd Zdd	 Z	dS )
Smith300VideoDatasetu  Smith300 mac_robot_datasets episodes defined by rgb_overlay/episodes.json.

    Layout:
        <root>/rgb_NNNNNN.jpg                  — flat per-frame jpgs
        <root>/rgb_overlay/episodes.json       — defines [start, end] (inclusive) per episode
       Nc              
      s  dd l }ddlm  t|ttfr fdd|D }nt|tr1d|v r1 fdd|dD }n | g}|| _	|| _
|| _|| _|| _|d | d | _g | _|D ]nd d	 }	|	 shtd
|	 ||	 }
|
d D ]O}t|d t|d }}fddt||d D }dd |D }j d|d  }t|| jkr| j||f qstd| dt| d| j d qsqU|d ur| jd | | _d S d S )Nr   r   c                    s   g | ]} |  qS  resolve.0rr   r   7/data/cameron/para/video_training/svd_finetune/train.py
<listcomp>8       z1Smith300VideoDataset.__init__.<locals>.<listcomp>,c                    s   g | ]
}|r |  qS r   r   r   r   r   r   r   :   s    r   Zrgb_overlayzepisodes.jsonzMissing episodes.json: episodesstartendc                    s   g | ]} d |dd qS )Zrgb_Z06dz.jpgr   )r   i)r   r   r   r   L   s    c                 S   s   g | ]}|  r|qS r   )is_filer   pr   r   r   r   M   r   /idz  [smith300] skipping z: only z frames (need ))jsonpathlibr   
isinstancelisttuplestrsplitr   roots
num_framesheightwidthframe_stride
min_framesr!   r%   FileNotFoundErrorloads	read_textintrangenamelenappendprint)selfrootr3   r4   r5   r6   max_samplesr+   r2   Zepisodes_jsonmetaZepr"   r#   Zframe_pathstagr   )r   r   r   __init__2   sB   
zSmith300VideoDataset.__init__c                 C   
   t | jS Nr>   r!   rA   r   r   r   __len__W      
zSmith300VideoDataset.__len__c                 C   s   dd l }ddlm} | j| \}}t|| j }|dkr"|d|nd}g }t| jD ]<}	|||	| j	   }
|t
|
}|jd dkrI|d d }| d }tj|d| j| jfdddd}|| q+tj|d	d
}|d d }|dS Nr   )
read_image      g     o@bilinearFsizemodeZalign_cornersr   Zdim       @      ?ZrandomZtorchvision.iorN   r!   r>   r7   Zrandintr<   r3   r6   r0   shapefloatFinterpolate	unsqueezer4   r5   Zsqueezer?   torchZstack)rA   idx_rngrN   Z_ep_idjpgs	max_startr"   frameskr'   imgoutr   r   r   __getitem__Z   s*   
z Smith300VideoDataset.__getitem__
__name__
__module____qualname____doc__
NUM_FRAMES
IMG_HEIGHT	IMG_WIDTHrF   rK   rg   r   r   r   r   r   *   s    
%r   c                   @   r   )
LiberoVideoDatasetRectzELike LiberoVideoDataset but outputs non-square (H, W) frames for SVD.r   Nc                 C   s   ddl m} || | _|| _|| _|| _|| _|d | d | _g | _	t
| j D ]4}| s3q,t
| D ]&}	|	 s@q9|	d }
|
 sIq9t
|
dtd}t|| jkr_| j	| q9q,|d uro| j	d | | _	d S d S )Nr   r   r   rc   z*.png)key)r,   r   r   rB   r3   r4   r5   r6   r7   r!   sortediterdiris_dirglobr   r>   r?   )rA   rB   r3   r4   r5   r6   rC   r   Ztask_dirZdemo_dirZ
frames_dirpngsr   r   r   rF   r   s2   	zLiberoVideoDatasetRect.__init__c                 C   rG   rH   rI   rJ   r   r   r   rK      rL   zLiberoVideoDatasetRect.__len__c                 C   s   dd l }ddlm} | j| }t|| j }|d|}g }t| jD ]<}|||| j	   }	|t
|	}
|
jd dkrA|
d d }
|
 d }
tj|
d| j| jfdddd}
||
 q#tj|d	d
}|d d }|dS rM   rX   )rA   r_   r`   rN   rv   rb   r"   rc   rd   r'   re   rf   r   r   r   rg      s*   

z"LiberoVideoDatasetRect.__getitem__rh   r   r   r   r   rp   o   s    
rp   c           	      C   s   |d d }t jg d|ddddd}t jg d|ddddd}tj|dd	d
d}|| | }t   | |j}W d   n1 sGw   Y  |dS )zEncode conditioning image through CLIP.

    Args:
        pixel_values: (B, 3, H, W) in [-1, 1]
    Returns:
        image_embeddings: (B, 1, 1024) CLIP embeddings
    rW   rV   )g3<4'?gwgM?gy{ ?devicer   rP   )gB91?gwt.?g	U?)   ry   rQ   FrR   N)r^   tensorviewr[   r\   no_gradZimage_embedsr]   )	image_encoderfeature_extractorZpixel_valuesrx   Z	images_01Z	clip_meanZclip_stdZimages_clipimage_embeddingsr   r   r   encode_image_clip   s   	

r   c            J      C   s  t jdd} | jdtdddgd | jdtdd	 | jd
ttd	 | jdtdd	 | jdtdd	 | jdtdd	 | jdtdd	 | jdtdd	 | jdtdd	 | jdtdd	 | jddd | jdtdd	 | jdtdd	 | jdtd d	 | jd!td"d	 | jd#td$d	 | jd%dd&d' |  }d(d)lm	} ||j
d*}|j}|jrtjntj}t|j}|jr|jd&d&d+ |jr|jrd(d l}|jd,t||jd-d. t|j}td/ tjt|d0|d1}	|	|  |	 D ]}
d2|
_ qtd3 t!jt|d4|d1}||  | D ]}
d2|
_ qt"jt|d5d6}td7 t#jt|d8|d1}|| td9t$d:d; | D d<d= t%jt|d>d6}t&|j'|j'g d?d@dA}t(||}|)  |*  |j+dkrmt,nt-}||j.t/t0t1|j2dB}tdC|j+ dDt3| dEt/ dFt0 dGt1 dH|j2 dI t4||j5d&|j6t7d&|j6d(k|j6d(krdnd dJ}dKdL | D }tj8j9||j:ddM}|;|||\}}}d(}t<|j=D ]}t>|dN| |j dO}t?|D ]x\}}|@|E |jA\}}}}}tjBjCjD|jdP |d d d d d(f }tE||||} tF|dQ}!tG  |	H|!|jIJ }"W d    n	1 s6w   Y  tF|"dR|dS}"|"|	jKjL }"dT}#tG % |||#tM||  }$|	H|$jIN }%|%|	jKjL }%W d    n	1 sww   Y  |%OdPdU|dUdUdU}%tjQ|||dV}&|&dW dX R }'|'S|jKjT|jKjU}'tM|"}(|'V|dddd})|"|(|)  }*dY|)dZ d W  }+|*|+ },tjX|,|%gdZd[}-tjYd\d]|#gg| ||dV}.||-|'| |.d^jJ}/dY|)dZ d  }0|) |)dZ d W  }1|"|0|*  |1 }2dY|)dZ d  }3|3|/|2 dZ  Z }4W d    n	1 sw   Y  |[|4 |\  |]  W d    n	1 s9w   Y  |jr`|j^|4_ d_|d` |jr`d(d l}|j`da|4_ i|db |jr||ja d(kr|jr|d(krtdc| ddd&de zUd(d l}d(d lbmc}5 |d|}6|6  tG . tjBjCjD|jdP |d dd d d(f |}7tE|||7|}8dT}#|7|#tM|7  }9|	H|9jIN |	jKjL }:|:OdPdU|dUdUdU};tjYd\d]|#gg||dV}<t%jt|d>d6}=|=jedf|dg tjQd|d|:jAdZ |:jAd" ||dV|=jf }>|=jgD ]&}?|=h|>|?}@tjX|@|;gdZd[}A|6|A|?Od(|8|<d^jJ}B|=\|B|?|>ji}>qtF|>dh}C|C|	jKjL }C|	jj|C|dijJ}D|D k dY dj Sd(d}D|Dld(dZd"dm dk ndl}Etojpdmd2dn}F|Fjq}GW d    n	1 s~w   Y  |5jr|G|Eddodpddqdrgds |j`dt|js|Gdudvi|db t|Gjtd&dw tdx| d&de W d    n	1 sw   Y  W d    n	1 sw   Y  |6*  W n. tuy }H z!tdy| dz|H d&de d(d lv}I|Iw  |d|*  W Y d }H~Hnd }H~Hww |jrY|d(krY||jx d(krYz$|d|yt|d{  tz||{ d||d}  td~| d&de W n tuyX }H ztd| dz|H d&de W Y d }H~Hnd }H~Hww |d7 }q|jrmtd| d|  q|jr|jrd(d l}||  |jrtd|  d S d S )Nz!Fine-tune SVD on LIBERO with LoRA)Zdescriptionz	--datasetZliberoZsmith300)typedefaultZchoicesz--data-rootz)/data/libero/parsed_libero/libero_spatial)r   r   z
--svd-pathz--batch-sizer   z--gradient-accumulationrO   z--lrg-C6?z--epochsd   z	--workersz--devicecudaz
--run-nameZsvd_libero_spatialz--log_wandbZ
store_true)actionz--vis-every   z--checkpoint-everyi  z--checkpoint-dirz'video_training/svd_finetune/checkpointsz--frame-striderP   z--lora-rank   z--mixed-precisionT)r   r   r   )Accelerator)Zgradient_accumulation_steps)parentsexist_okZsvd_finetuneZonline)Zprojectconfigr=   rT   zLoading SVD components...vae)	subfolderZtorch_dtypeFz  VAE loaded (frozen)r}   r~   )r   z$  CLIP image encoder loaded (frozen)unetz  UNet loaded: c                 s   s    | ]}|  V  qd S rH   )Znumelr&   r   r   r   	<genexpr>   s    zmain.<locals>.<genexpr>r    z paramsZ	scheduler)Zto_qZto_kZto_vzto_out.0g        )r   Z
lora_alphaZtarget_modulesZlora_dropout)rB   r3   r4   r5   r6   zDataset[z]: z episodes (z
 frames @ xz	, stride r*   )
batch_sizeZshuffleZnum_workersZ
collate_fnZ
pin_memoryZpersistent_workersZprefetch_factorc                 S   s   g | ]}|j r|qS r   )requires_gradr&   r   r   r   r     s    zmain.<locals>.<listcomp>)lrZweight_decayzepoch )ZdescZdisable)enabledzb c t h w -> (b t) c h wz(b t) c h w -> b t c h w)bg{Gz?)rx   dtypeg?gffffff?rW      rU   g      @g     _@)Zencoder_hidden_statesadded_time_idsz.4f)lossstepz
train/loss)r   z 
[vis] Generating video at step z...)flush   rw   zb t c h w -> (b t) c h w)r3   rV      Zuint8z.mp4)suffixdeleteZlibx264   z	-movflagsz
+faststart)ZfpsZcodecZqualityZmacro_block_sizeZffmpeg_paramszvis/predicted_videoZmp4)format)
missing_okz[vis] Logged video at step z[vis] ERROR at step z: Z	unet_lora)r   Z	optimizerz	latest.ptz[ckpt] Saved at step z[ckpt] ERROR at step zEpoch z done, step=zDone. Checkpoints at )}argparseZArgumentParserZadd_argumentr0   SVD_DEFAULT_PATHr;   rZ   Z
parse_argsZ
accelerater   Zgradient_accumulationrx   Zmixed_precisionr^   Zfloat16Zfloat32r   Zcheckpoint_dirZis_main_processmkdirZ	log_wandbwandbinitvarsZrun_namesvd_pathr@   r   Zfrom_pretrainedtoevalZ
parametersr   r   r   r   sumr	   r   Z	lora_rankr   Zprint_trainable_parameterstraindatasetr   rp   Z	data_rootrm   rn   ro   r6   r>   r   r   Zworkersr   ZoptimZAdamWr   Zpreparer<   Zepochsr   	enumerate
accumulaterY   r   ZampZautocastr   r   r|   encodeZlatent_distZsampler   Zscaling_factorZ
randn_likerT   r]   expandZrandnZexpZclampZ	sigma_minZ	sigma_maxr{   Zsqrtcatrz   ZmeanZbackwardr   Z	zero_gradZset_postfixitemlogZ	vis_everyZ
imageio.v2Zv2Zunwrap_modelZset_timestepsZinit_noise_sigmaZ	timestepsZscale_model_inputZprev_sampledecodeZcpuZpermutenumpyZastypetempfileZNamedTemporaryFiler=   ZmimwriteZVideounlink	Exception	tracebackZ	print_excZcheckpoint_everyZsave_pretrainedZsaveZ
state_dictZfinish)Jr'   argsr   Zacceleratorrx   r   Zckpt_dirr   r   r   Zparamr}   r~   r   Znoise_schedulerZlora_configZ
DatasetClsr   loaderZtrainable_paramsZoptZglobal_stepZepochZpbarZ	batch_idxZbatchBCTHWZ
cond_imager   rc   ZlatentsZ	noise_augZcond_img_augZcond_latentZ
rnd_normalZsigmaZnoiseZsigma_bcZnoisy_latentsZc_inZscaled_noisyZ
unet_inputr   Z
model_predZc_skipZc_outtargetZweightr   ZimageioZunet_for_visZcond_imgZembZcond_augZcond_latZcond_repZ	added_idsZvis_schedulerzZt_stepZz_scaledZz_inputZpredZz_flatZdecodedZ	frames_npfZtmp_pather   r   r   r   main   s  



$
6







>
F

"

 ,$r   __main__)3rl   r   sysr   r,   r   r   Znpr^   Ztorch.nn.functionalZnnZ
functionalr[   Ztorch.utils.datar   Zeinopsr   r   Z	diffusersr   r   r	   Zdiffusers.training_utilsr
   Ztransformersr   r   Zpeftr   r   __file__r   parentZUVA_ROOTpathinsertr0   Zsimple_uva.datasetr   r   r   r   r   rm   rn   ro   r   rp   r   r   ri   r   r   r   r   <module>   s@    E4  	
