o
    2Aiu                     @   sR  d dl Z d dlmZ d dlmZmZmZmZmZ d dl	Z
d dlZd dlZd dlmZmZ d dlmZmZ d dlmZ d dlmZ d dlmZmZmZ d d	lmZmZ d d
lm Z  d dl!m"Z" ddl#m$Z$ e%e&Z'dZ(dd Z)d(dej*dede+fddZ,eG dd deZ-G dd de e"Z.d)ddZ/dd  Z0d!d" Z1d#e2fd$d%Z3d&d' Z4dS )*    N)	dataclass)CallableDictListOptionalUnion)CLIPImageProcessorCLIPVisionModelWithProjection)PipelineImageInputVaeImageProcessor)AutoencoderKLTemporalDecoder)EulerDiscreteScheduler)
BaseOutputloggingreplace_example_docstring)is_compiled_modulerandn_tensor)DiffusionPipeline)LoraLoaderMixin   ) UNetSpatioTemporalConditionModela  
    Examples:
        ```py
        >>> from diffusers import StableVideoDiffusionPipeline
        >>> from diffusers.utils import load_image, export_to_video

        >>> pipe = StableVideoDiffusionPipeline.from_pretrained("stabilityai/stable-video-diffusion-img2vid-xt", torch_dtype=torch.float16, variant="fp16")
        >>> pipe.to("cuda")

        >>> image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/svd-docstring-example.jpeg")
        >>> image = image.resize((1024, 576))

        >>> frames = pipe(image, num_frames=25, decode_chunk_size=8).frames[0]
        >>> export_to_video(frames, "generated.mp4", fps=7)
        ```
c                 C   s:   || j  }|dk rtd| j  d| d| dd|   S )zNAppends dimensions to the end of a tensor until it has target_dims dimensions.r   z
input has z dims but target_dims is z, which is less).N)ndim
ValueError)xZtarget_dimsZdims_to_append r   N/data/cameron/vidgen/svd_motion_lora/Motion-LoRA/svd/pipelines/svd_pipeline.py_append_dims7   s   
r   npvideo	processoroutput_typec                 C   s   | j \}}}}}g }t|D ]}	| |	 dddd}
||
|}|| q|dkr1t|}|S |dkr<t|}|S |dksGt| d|S )	N   r   r      r   ptpilz9 does not exist. Please choose one of ['np', 'pt', 'pil'])	shaperangepermutepostprocessappendr   stacktorchr   )r   r    r!   
batch_sizechannels
num_framesheightwidthoutputs	batch_idxZ	batch_vidZbatch_outputr   r   r   
tensor2vid@   s   

r4   c                   @   s4   e Zd ZU dZeeeejj  ej	e
jf ed< dS )"StableVideoDiffusionPipelineOutputaG  
    Output class for Stable Video Diffusion pipeline.

    Args:
        frames (`[List[List[PIL.Image.Image]]`, `np.ndarray`, `torch.FloatTensor`]):
            List of denoised PIL images of length `batch_size` or numpy array or torch tensor
            of shape `(batch_size, num_frames, height, width, num_channels)`.
    framesN)__name__
__module____qualname____doc__r   r   PILImager   ndarrayr,   FloatTensor__annotations__r   r   r   r   r5   U   s   
 &	r5   c                '       s$  e Zd ZdZdZdgZdededede	de
f
 fd	d
Zdedeeejf dededejf
ddZdejdeeejf dedefddZdedededejdededefddZdCdejdedefddZd d! Z	"dDdeded#ed$ed%edejdeeejf d&ejdeej fd'd(Zed)d* Z ed+d, Z!ed-d. Z"e# e$e%d/d0d"d1d2d3d4d5d6d"d7d"d"d8d"dgd9fdee&j'j'e(e&j'j' ejf d$ed%edee d:ed;ed<ededededee dee d&eeeje(ej f  deej d=ee d>ee)eee*gd"f  d?e(e d@ef$dAdBZ+  Z,S )EStableVideoDiffusionPipelineaY  
    Pipeline to generate video from an input image using Stable Video Diffusion.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
    implemented for all pipelines (downloading, saving, running on a particular device, etc.).

    Args:
        vae ([`AutoencoderKLTemporalDecoder`]):
            Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
        image_encoder ([`~transformers.CLIPVisionModelWithProjection`]):
            Frozen CLIP image-encoder ([laion/CLIP-ViT-H-14-laion2B-s32B-b79K](https://huggingface.co/laion/CLIP-ViT-H-14-laion2B-s32B-b79K)).
        unet ([`UNetSpatioTemporalConditionModel`]):
            A `UNetSpatioTemporalConditionModel` to denoise the encoded image latents.
        scheduler ([`EulerDiscreteScheduler`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents.
        feature_extractor ([`~transformers.CLIPImageProcessor`]):
            A `CLIPImageProcessor` to extract features from generated images.
    zimage_encoder->unet->vaelatentsvaeimage_encoderunet	schedulerfeature_extractorc                    sH   t    | j|||||d dt| jjjd  | _t| jd| _	d S )N)rB   rC   rD   rE   rF   r   r"   )vae_scale_factor)
super__init__Zregister_moduleslenrB   configblock_out_channelsrG   r   image_processor)selfrB   rC   rD   rE   rF   	__class__r   r   rI   z   s   
z%StableVideoDiffusionPipeline.__init__imagedevicenum_videos_per_promptdo_classifier_free_guidancereturnc                 C   s   t | j j}t|tjs| j|}| j	|}|d d }t
|d}|d d }| j|ddddddj}|j||d}| |j}|d	}|j\}}}	|d	|d	}||| |d
}|rmt|}
t|
|g}|S )N       @      ?)   rX   TFr$   )imagesdo_normalizedo_center_crop	do_resize
do_rescalereturn_tensorsrR   dtyper"   )nextrC   
parametersr`   
isinstancer,   TensorrM   pil_to_numpynumpy_to_pt_resize_with_antialiasingrF   pixel_valuestoimage_embeds	unsqueezer&   repeatview
zeros_likecat)rN   rQ   rR   rS   rT   r`   image_embeddingsZbs_embedseq_len_Znegative_image_embeddingsr   r   r   _encode_image   s6   
	

z*StableVideoDiffusionPipeline._encode_imagec                 C   s   t d|  d|   |j|d}| j|j }t d|  d|   |r9t	|}t
||g}||ddd}|S )Nz
vae_image:,rR   zvae_image_latents:r"   )printmeanstdrj   rB   encodelatent_distmoder,   ro   rp   rm   )rN   rQ   rR   rS   rT   image_latentsZnegative_image_latentsr   r   r   _encode_vae_image   s   
z.StableVideoDiffusionPipeline._encode_vae_imagefpsmotion_bucket_idnoise_aug_strengthr`   r-   c                 C   s|   |||g}| j jjt| }	| j jjj}
|
|	kr#td|
 d|	 dtj	|g|d}|
|| d}|r<t||g}|S )Nz7Model expects an added time embedding vector of length z, but a vector of z was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`.r`   r"   )rD   rK   addition_time_embed_dimrJ   add_embeddinglinear_1in_featuresr   r,   tensorrm   rp   )rN   r   r   r   r`   r-   rS   rT   add_time_idspassed_add_embed_dimexpected_add_embed_dimr   r   r   _get_add_time_ids   s   

z.StableVideoDiffusionPipeline._get_add_time_ids   r/   decode_chunk_sizec                 C   s  | dd}d| jjj | }t| jr| jjjn| jj}dtt	|j
 v }g }td|jd |D ]+}||||  jd }i }	|rI||	d< | jj||||  fi |	j}
||
 q4tj|dd}|jd|g|jdd  R  ddddd}| }|S )	Nr   r"   r/   dimra   r   r#      )flattenrB   rK   scaling_factorr   	_orig_modforwardsetinspect	signaturerc   keysr'   r&   decodesampler*   r,   rp   reshaper(   float)rN   rA   r/   r   Zforward_vae_fnZaccepts_num_framesr6   iZnum_frames_inZdecode_kwargsframer   r   r   decode_latents   s    ",z+StableVideoDiffusionPipeline.decode_latentsc                 C   sh   t |tjst |tjjst |tstdt| |d dks'|d dkr2td| d| dd S )Nze`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is    r   z7`height` and `width` have to be divisible by 8 but are z and .)rd   r,   re   r;   r<   listr   type)rN   rQ   r0   r1   r   r   r   check_inputs  s   
z)StableVideoDiffusionPipeline.check_inputsNnum_channels_latentsr0   r1   	generatorc
                 C   s   |||d || j  || j  f}
t|tr't||kr'tdt| d| d|	d u r4t|
|||d}	n|	|}	|	| jj }	|	S )Nr   z/You have passed a list of generators of length z+, but requested an effective batch size of z@. Make sure the batch size matches the length of the generators.r   rR   r`   )	rG   rd   r   rJ   r   r   rj   rE   init_noise_sigma)rN   r-   r/   r   r0   r1   r`   rR   r   rA   r&   r   r   r   prepare_latents  s"   
z,StableVideoDiffusionPipeline.prepare_latentsc                 C      | j S r   )_guidance_scalerN   r   r   r   guidance_scale@     z+StableVideoDiffusionPipeline.guidance_scalec                 C   s(   t | jttfr| jdkS | j dkS )Nr"   )rd   r   intr   maxr   r   r   r   rT   G  s   
z8StableVideoDiffusionPipeline.do_classifier_free_guidancec                 C   r   r   )_num_timestepsr   r   r   r   num_timestepsM  r   z*StableVideoDiffusionPipeline.num_timestepsi@  i      rW   g      @      g{Gz?r"   r%   Tnum_inference_stepsmin_guidance_scalemax_guidance_scaler!   callback_on_step_end"callback_on_step_end_tensor_inputsreturn_dictc           )      C   s  |p	| j jj| j }|p| j jj| j }|dur|n| j jj}|dur%|n|}| ||| t|tjjr8d}nt|t	rBt
|}n|jd }| j}|| _| |||| j}|d }| jj|||d|}td| | f  t|
 t|j|||jd}||
|  }| jjtjko| jjj}|r| jjtjd | j|||| jd}||j}|r| jjtjd |dd|ddd}|  ||	|
|j||| j}||}| j!j"||d	 | j!j#}| j jj$}| %|| |||||j|||	}t&|||d}|||j}||| d}t'||j(}|| _t
||| j!j)  }t
|| _*| j+|d
}t,|D ]\}} | jrBt-|gd n|}!| j!.|!| }!tj-|!|gdd}!| j |!| ||ddd }"| jrt|"/d\}#}$|#| j0|$|#   }"| j!1|"| |j2}|duri }%|D ]
}&t3 |& |%|&< q|| || |%}'|'4d|}|t
|d ks|d |kr|d | j!j) dkr|5  q2W d   n	1 sw   Y  |dks|r| jjtjd | 6|||}(t7|(| j|d}(n|}(| 8  |s|(S t9|(dS )a  
        The call function to the pipeline for generation.

        Args:
            image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`):
                Image(s) to guide image generation. If you provide a tensor, the expected value range is between `[0, 1]`.
            height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
                The width in pixels of the generated image.
            num_frames (`int`, *optional*):
                The number of video frames to generate. Defaults to `self.unet.config.num_frames`
                (14 for `stable-video-diffusion-img2vid` and to 25 for `stable-video-diffusion-img2vid-xt`).
            num_inference_steps (`int`, *optional*, defaults to 25):
                The number of denoising steps. More denoising steps usually lead to a higher quality video at the
                expense of slower inference. This parameter is modulated by `strength`.
            min_guidance_scale (`float`, *optional*, defaults to 1.0):
                The minimum guidance scale. Used for the classifier free guidance with first frame.
            max_guidance_scale (`float`, *optional*, defaults to 3.0):
                The maximum guidance scale. Used for the classifier free guidance with last frame.
            fps (`int`, *optional*, defaults to 7):
                Frames per second. The rate at which the generated images shall be exported to a video after generation.
                Note that Stable Diffusion Video's UNet was micro-conditioned on fps-1 during training.
            motion_bucket_id (`int`, *optional*, defaults to 127):
                Used for conditioning the amount of motion for the generation. The higher the number the more motion
                will be in the video.
            noise_aug_strength (`float`, *optional*, defaults to 0.02):
                The amount of noise added to the init image, the higher it is the less the video will look like the init image. Increase it for more motion.
            decode_chunk_size (`int`, *optional*):
                The number of frames to decode at a time. Higher chunk size leads to better temporal consistency at the expense of more memory usage. By default, the decoder decodes all frames at once for maximal
                quality. For lower memory usage, reduce `decode_chunk_size`.
            num_videos_per_prompt (`int`, *optional*, defaults to 1):
                The number of videos to generate per prompt.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                generation deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor is generated by sampling using the supplied random `generator`.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between `pil`, `np` or `pt`.
            callback_on_step_end (`Callable`, *optional*):
                A function that is called at the end of each denoising step during inference. The function is called
                with the following arguments:
                    `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`.
                `callback_kwargs` will include a list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
            callback_on_step_end_tensor_inputs (`List`, *optional*):
                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                `._callback_tensor_inputs` attribute of your pipeline class.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.

        Examples:

        Returns:
            [`~pipelines.stable_diffusion.StableVideoDiffusionPipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableVideoDiffusionPipelineOutput`] is returned,
                otherwise a `tuple` of (`List[List[PIL.Image.Image]]` or `np.ndarray` or `torch.FloatTensor`) is returned.
        Nr"   r   )r0   r1   zpure image: r   r   )rR   rS   rT   rv   )totalr   r   F)encoder_hidden_statesadded_time_idsr   rA   latent)r!   )r6   ):rD   rK   sample_sizerG   r/   r   rd   r;   r<   r   rJ   r&   Z_execution_devicer   rt   rT   rM   
preprocessrj   rw   rx   ry   r   r`   rB   r,   float16force_upcastfloat32r~   rl   rm   r   rE   set_timesteps	timestepsin_channelsr   linspacer   r   orderr   progress_bar	enumeraterp   scale_model_inputchunkr   stepprev_samplelocalspopupdater   r4   Zmaybe_free_model_hooksr5   ))rN   rQ   r0   r1   r/   r   r   r   r   r   r   r   rS   r   rA   r!   r   r   r   r-   rR   rq   noiseZneeds_upcastingr}   r   r   r   r   num_warmup_stepsr   r   tZlatent_model_inputZ
noise_predZnoise_pred_uncondZnoise_pred_condZcallback_kwargskZcallback_outputsr6   r   r   r   __call__Q  s   W



	

	
6
%
z%StableVideoDiffusionPipeline.__call__)r   r   )-r7   r8   r9   r:   Zmodel_cpu_offload_seqZ_callback_tensor_inputsr   r	   r   r   r   rI   r
   r   strr,   rR   r   boolr>   rt   re   r~   r   r`   r   r   r   	Generatorr   r   propertyr   rT   r   no_gradr   EXAMPLE_DOC_STRINGr;   r<   r   r   r   r   __classcell__r   r   rO   r   r@   c   s   
1

	

"


	
r@   bicubicTc           
      C   s   | j dd  \}}||d  ||d  f}t|d d d dt|d d d df}ttd|d  dttd|d  df}|d d	 dkrS|d d |d f}|d d	 dkre|d |d d f}t| ||} tjjj| |||d
}	|	S )Nr   r"   rW   rV   gMbP?g      @r#   r   )sizer|   align_corners)r&   r   r   _gaussian_blur2dr,   nn
functionalinterpolate)
inputr   interpolationr   hwfactorssigmasksoutputr   r   r   rh   D  s   ,rh   c                 C   s   t | dk r
t| dd | D }dt |  dg }tt | D ]!}||d   }|d }|| }||d| d < ||d| d < q |S )zCompute padding tuple.r   c                 S   s   g | ]}|d  qS )r"   r   ).0r   r   r   r   
<listcomp>g  s    z$_compute_padding.<locals>.<listcomp>r   r"   )rJ   AssertionErrorr'   )kernel_sizecomputedout_paddingr   computed_tmp	pad_frontpad_rearr   r   r   _compute_paddinga  s   r   c                 C   s   | j \}}}}|d d d df j| j| jd}|d|dd}|j dd  \}}t||g}	tjjj	| |	dd} |
dd||}| d|d| d| d} tjjj| ||dddd	}
|
||||}|S )
N.r_   ra   r   reflect)r|   r"   r   )groupspaddingstride)r&   rj   rR   r`   expandr   r,   r   r   padr   rn   r   conv2d)r   kernelbcr   r   
tmp_kernelr0   r1   padding_shaper   outr   r   r   	_filter2dx  s    "r  window_sizec                 C   s   t |trt|gg}|jd }tj| |j|jd| d  |d}| d dkr-|d }t	|
d d|
d  }||jddd S )	Nr   r_   r   ra   g      ?rV   T)keepdim)rd   r   r,   r   r&   arangerR   r`   r   exppowsum)r  sigmar-   r   gaussr   r   r   	_gaussian  s   

$ r  c           
      C   s   t |trtj|g| jd}n|j| jd}t|d t|d }}|jd }t||d d df 	|d}t||d d df 	|d}t
| |dd d d f }t
||d }	|	S )Nr   r   r"   .).N)rd   tupler,   r   r`   rj   r   r&   r  rn   r  )
r   r   r  kykxbskernel_xkernel_yout_xr  r   r   r   r     s   

r   )r   )r   T)5r   dataclassesr   typingr   r   r   r   r   numpyr   	PIL.Imager;   r,   transformersr   r	   Zdiffusers.image_processorr
   r   Zdiffusers.modelsr   Zdiffusers.schedulersr   diffusers.utilsr   r   r   Zdiffusers.utils.torch_utilsr   r   Z"diffusers.pipelines.pipeline_utilsr   Z	LoRA.lorar   modelsr   
get_loggerr7   loggerr   r   re   r   r4   r5   r@   rh   r   r  r   r  r   r   r   r   r   <module>   s<   
	   
d