# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import random
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
from einops import rearrange
from hydra.core.config_store import ConfigStore
from omegaconf import ListConfig

from cosmos_predict2._src.imaginaire.lazy_config import LazyCall as L
from cosmos_predict2._src.imaginaire.lazy_config import LazyDict
from cosmos_predict2._src.imaginaire.utils import log
from cosmos_predict2._src.imaginaire.utils.context_parallel import broadcast_split_tensor
from cosmos_predict2._src.imaginaire.utils.validator import Validator
from cosmos_predict2._src.predict2.conditioner import Text2WorldCondition, TextAttr
from cosmos_predict2._src.predict2.configs.video2world.defaults.conditioner import (
    _SHARED_CONFIG,
    GeneralConditioner,
    ReMapkey,
    Video2WorldCondition,
)
from cosmos_predict2._src.predict2_multiview.conditioner import MVTextAttr


class ConditionLocation(Enum):
    """
    Where the conditioning signal comes from in multiview video generation.

    Members:
        NO_CAM: No camera is used for conditioning (i.e. text2world).
        REF_CAM: A reference camera is used for conditioning
            (i.e. single-to-multiview-text2world).
        ANY_CAM: Any camera may be used for conditioning
            (i.e. any-to-multiview-text2world).
        FIRST_RANDOM_N: A random number of leading frames from all cameras is
            used for conditioning (i.e. video2world-multiview).

    Several locations may be combined when they are compatible:
        - NO_CAM must be the only location when present.
        - ANY_CAM and REF_CAM are mutually exclusive.
        - FIRST_RANDOM_N may be combined with either ANY_CAM or REF_CAM.
    """

    NO_CAM = "no_cam"
    REF_CAM = "ref_cam"
    ANY_CAM = "any_cam"
    FIRST_RANDOM_N = "first_random_n"


class ConditionLocationListValidator(Validator):
    """
    Validator for a list of ConditionLocation objects.
    Validates that:
        - every element is a ConditionLocation enum member
        - NO_CAM is not set with any other location
        - ANY_CAM and REF_CAM are not set together
    """

    def __init__(self, default: List[ConditionLocation], hidden=False, tooltip=None):
        self.default = default
        self.hidden = hidden
        self.tooltip = tooltip

    def validate(self, value: List[ConditionLocation]):
        """
        Check `value` against the compatibility rules and return it unchanged.

        Raises:
            TypeError: if any element is not a ConditionLocation.
            AssertionError: if NO_CAM is combined with any other location.
            ValueError: if ANY_CAM and REF_CAM are both present.
        """
        for v in value:
            if not isinstance(v, ConditionLocation):
                raise TypeError(f"All elements must be ConditionLocation enums, got {type(v)}: {v}")
        if ConditionLocation.NO_CAM in value:
            # Raised explicitly instead of via an `assert` statement so the check is
            # not stripped when Python runs with -O. The exception type and message
            # are the same as the assert produced.
            if len(value) != 1:
                raise AssertionError(
                    f"Cannot set ConditionLocation.NO_CAM and other locations together. Got {value=}"
                )
        elif ConditionLocation.ANY_CAM in value and ConditionLocation.REF_CAM in value:
            raise ValueError("ConditionLocation.ANY_CAM and ConditionLocation.REF_CAM cannot be set together.")
        return value

    def __repr__(self) -> str:
        return f"ConditionLocationValidator({self.default=}, {self.hidden=})"

    def json(self):
        """Return a dict describing this validator (type name, default, tooltip)."""
        return {
            "type": ConditionLocationListValidator.__name__,
            "default": self.default,
            "tooltip": self.tooltip,
        }


class ConditionLocationList(list):
    """
    A list whose items are always ConditionLocation members.

    Items that are not already ConditionLocation members are coerced with
    ``ConditionLocation(item)``; the resulting list is then validated with
    ConditionLocationListValidator.
    """

    def __init__(self, locations: List[ConditionLocation]):
        # ConditionLocation(item) raises ValueError for values that do not name a member.
        coerced = [item if isinstance(item, ConditionLocation) else ConditionLocation(item) for item in locations]
        super().__init__(coerced)
        self.validator = ConditionLocationListValidator(default=[])
        self.validator.validate(self)

    def __repr__(self) -> str:
        inner = super().__repr__()
        return f"ConditionLocationList({inner})"

    def to_json(self):
        """Return a dict with the class name and the string values of the contained locations."""
        values = [member.value for member in self]
        return {"type": ConditionLocationList.__name__, "locations": values}


@dataclass(frozen=True)
class MultiViewCondition(Video2WorldCondition):
    """
    Video2WorldCondition extended with multiview bookkeeping.

    The time axis of the video tensors handled here is laid out as "(V T)": V views,
    each contributing `state_t` consecutive frames (see the rearranges in
    `set_video_condition` and `broadcast`).
    """

    # Number of frames per view; T of gt_frames must be a multiple of it.
    state_t: Optional[int] = None
    # Indexed as [batch, frame] in set_video_condition to look up a camera/view index.
    view_indices_B_T: Optional[torch.Tensor] = None
    # Position (within the sampled views) of the reference camera; used by REF_CAM conditioning.
    ref_cam_view_idx_sample_position: Optional[torch.Tensor] = None

    def set_video_condition(
        self,
        state_t: int,
        gt_frames: torch.Tensor,
        condition_locations: Optional[Union[ConditionLocationList, ListConfig]] = None,
        random_min_num_conditional_frames_per_view: Optional[int] = None,
        random_max_num_conditional_frames_per_view: Optional[int] = None,
        num_conditional_frames_per_view: Optional[int | List[int]] = None,
        condition_cam_idx: Optional[int] = None,
        view_condition_dropout_max: int = 0,
        conditional_frames_probs: Optional[Dict[int, float]] = None,
    ) -> "MultiViewCondition":
        """
        Sets the video conditioning frames for anymulti-to-multiview generation.

        This method creates a conditioning mask for the input video frames that determines
        which frames will be used as context frames for generating new frames. The method
        handles video batches (T>1) and does not support images (T=1).

        Args:
            state_t: Number of frames per view. T of gt_frames must be a multiple of it;
                the number of views is T // state_t.

            gt_frames: A tensor of ground truth frames with shape [B, C, T, H, W], where:
                B = batch size
                C = number of channels
                T = number of frames per view * number of views
                H = height
                W = width

            condition_locations: Non-empty list of ConditionLocation members (or their string
                values). Must be provided; leaving it as None fails the
                "condition_locations must be provided." assertion.

            random_min_num_conditional_frames_per_view: Minimum number of frames per view to use for conditioning
                when randomly selecting a number of conditioning frames.

            random_max_num_conditional_frames_per_view: Maximum number of frames per view to use for conditioning
                when randomly selecting a number of conditioning frames.

            num_conditional_frames_per_view: Optional[int | List[int]]; If provided, all examples in the batch will use
                exactly this many frames per view for conditioning. If None, a random number of frames per view
                between random_min_num_conditional_frames_per_view and random_max_num_conditional_frames_per_view
                will be selected for each example in the batch. Can also be a list of integers, one for each view.

            condition_cam_idx: Optional; Used only if ConditionLocation.ANY_CAM is in condition_locations.
                If provided, all examples in the batch will use the same cam_idx for conditioning. If None,
                a random cam_idx will be selected for each example in the batch.
            view_condition_dropout_max: Optional; If provided and > 0, then a random number of views will be dropped from the conditioning.

            conditional_frames_probs: Optional; Dictionary mapping number of frames to probabilities.
                If provided, overrides the random_min/max_num_conditional_frames with weighted sampling.
                Example: {0: 0.5, 1: 0.25, 2: 0.25} for 50% chance of 0 frames, 25% for 1, 25% for 2.

        Returns:
            A new MultiViewCondition object with the gt_frames and conditioning mask set.
            The conditioning mask (condition_video_input_mask_B_C_T_H_W) is a binary tensor
            of shape [B, 1, T, H, W] where 1 indicates frames used for conditioning and 0
            indicates frames to be generated.

        Notes:
            - Image batches are not supported.
            - For video batches multiple condition_locations can be provided and combined:
                - If num_conditional_frames_per_view is provided and "first_random_n" is in condition_locations,
                then all examples will use the same number of frames per view for conditioning,
                otherwise, if num_conditional_frames_per_view is not provided,
                then each example will randomly use between random_min_num_conditional_frames_per_view
                and random_max_num_conditional_frames_per_view frames per view.
                - If "ref_cam" is in condition_locations, then for each example,
                all frames of the reference view will be used for conditioning.
        """
        kwargs = self.to_dict(skip_underscore=False)
        kwargs["state_t"] = state_t
        kwargs["gt_frames"] = gt_frames
        B, _, T, H, W = gt_frames.shape

        # A missing argument is normalised to an empty list so that the assertion below
        # reports it clearly (a dataclasses.field(...) default is not usable in a function
        # signature: it would hand a Field object to ConditionLocationList).
        if condition_locations is None:
            condition_locations = []
        if not isinstance(condition_locations, ConditionLocationList):
            condition_locations = ConditionLocationList(condition_locations)
        assert len(condition_locations) > 0, "condition_locations must be provided."
        assert state_t is not None, "state_t must be provided."
        assert T > 1, "Image batches are not supported."
        assert T % state_t == 0, f"T must be a multiple of state_t. Got T={T} and state_t={state_t}."
        sample_n_views = T // state_t
        # The mask is built with an explicit view axis and flattened back to (V T) at the end.
        condition_video_input_mask_B_C_V_T_H_W = torch.zeros(
            B, 1, sample_n_views, state_t, H, W, dtype=gt_frames.dtype, device=gt_frames.device
        )
        # The fully-conditioned camera (REF_CAM / ANY_CAM) is removed from this list below,
        # so view dropout never clears it.
        views_eligible_for_dropout = list(range(sample_n_views))

        if ConditionLocation.REF_CAM in condition_locations:
            ref_cam_view_idx_sample_position = kwargs["ref_cam_view_idx_sample_position"]
            assert ref_cam_view_idx_sample_position is not None, (
                "ref_cam_view_idx_sample_position must be set when ConditionLocation.REF_CAM is used."
            )
            ref_cam_idx_B = (
                torch.ones(B, dtype=torch.int32, device=ref_cam_view_idx_sample_position.device)
                * ref_cam_view_idx_sample_position
            )
            condition_video_input_mask_B_C_V_T_H_W = self.enable_ref_cam_condition(
                ref_cam_idx_B, condition_video_input_mask_B_C_V_T_H_W
            )
            assert (ref_cam_view_idx_sample_position == ref_cam_view_idx_sample_position[0]).all(), (
                f"ref_cam_view_idx_sample_position must be the same for all examples. Got {ref_cam_view_idx_sample_position=}"
            )
            ref_cam_view_idx_sample_position_int = ref_cam_view_idx_sample_position[0].cpu().item()
            views_eligible_for_dropout.remove(ref_cam_view_idx_sample_position_int)
        elif ConditionLocation.ANY_CAM in condition_locations:
            if condition_cam_idx is None:
                view_indices = kwargs["view_indices_B_T"]
                assert view_indices is not None, (
                    "view_indices_B_T must be set when ConditionLocation.ANY_CAM is used without condition_cam_idx."
                )
                assert view_indices.shape[-1] % sample_n_views == 0, (
                    f"view_indices_B_T last dimension must be a multiple of sample_n_views. Got view_indices_B_T.shape={kwargs['view_indices_B_T'].shape}, sample_n_views={sample_n_views}"
                )
                # NOTE(review): the index is drawn from [0, state_t), i.e. only from the first
                # state_t entries of view_indices_B_T. With the view-major "(V T)" layout used in
                # `broadcast`, that appears to always pick the first view - confirm this is intended.
                selected_cam_latent_t_index = torch.randint(0, state_t, size=(B,))
                any_cam_idx_B = view_indices[torch.arange(B), selected_cam_latent_t_index]
            else:
                any_cam_idx_B = torch.full((B,), condition_cam_idx, dtype=torch.int32)
            condition_video_input_mask_B_C_V_T_H_W = self.enable_ref_cam_condition(
                any_cam_idx_B, condition_video_input_mask_B_C_V_T_H_W
            )
            assert (any_cam_idx_B == any_cam_idx_B[0]).all(), (
                f"any_cam_idx_B must be the same for all examples. Got {any_cam_idx_B=}"
            )
            any_cam_idx_B_int = any_cam_idx_B[0].cpu().item()
            views_eligible_for_dropout.remove(any_cam_idx_B_int)
        if ConditionLocation.FIRST_RANDOM_N in condition_locations:
            # A degenerate random range (min == max) is treated as an explicit frame count.
            # NOTE(review): when min == max is not None this also bypasses
            # conditional_frames_probs, although the docstring says the probabilities
            # override the min/max range - confirm which precedence is intended.
            if (
                num_conditional_frames_per_view is None
                and random_min_num_conditional_frames_per_view == random_max_num_conditional_frames_per_view
            ):
                num_conditional_frames_per_view = random_min_num_conditional_frames_per_view
            if num_conditional_frames_per_view is not None:
                if isinstance(num_conditional_frames_per_view, list):
                    assert len(num_conditional_frames_per_view) == sample_n_views, (
                        f"num_conditional_frames_per_view must be a list of length {sample_n_views}. Got {num_conditional_frames_per_view=}"
                    )
                    log.info(
                        f"Setting num_conditional_frames_per_view_B_V explicitly from list: {num_conditional_frames_per_view}"
                    )
                    # One count per view, shared by every example in the batch.
                    num_conditional_frames_per_view_B_V = torch.tensor(
                        num_conditional_frames_per_view, dtype=torch.int32
                    ).repeat(B, 1)
                else:
                    num_conditional_frames_per_view_B_V = (
                        torch.ones((B, sample_n_views), dtype=torch.int32) * num_conditional_frames_per_view
                    )
            elif conditional_frames_probs is not None:
                # Use weighted sampling based on provided probabilities:
                # one draw per example, shared across all of its views.
                frames_options = list(conditional_frames_probs.keys())
                weights = list(conditional_frames_probs.values())
                num_conditional_frames_per_view_B_V = (
                    torch.tensor(random.choices(frames_options, weights=weights, k=B), dtype=torch.int32)
                    .view(B, 1)
                    .repeat(1, sample_n_views)
                )
            else:
                assert (
                    random_min_num_conditional_frames_per_view is not None
                    and random_max_num_conditional_frames_per_view is not None
                ), (
                    f"random_min_num_conditional_frames_per_view and random_max_num_conditional_frames_per_view must be provided if num_conditional_frames_per_view is None. Got {random_min_num_conditional_frames_per_view=}, {random_max_num_conditional_frames_per_view=}, {num_conditional_frames_per_view=}"
                )
                # Uniform draw in [min, max] (randint's upper bound is exclusive), one per
                # example, shared across all of its views.
                num_conditional_frames_per_view_B_V = torch.randint(
                    random_min_num_conditional_frames_per_view,
                    random_max_num_conditional_frames_per_view + 1,
                    size=(B, 1),
                ).repeat(1, sample_n_views)
            condition_video_input_mask_B_C_V_T_H_W = self.enable_first_random_n_condition(
                condition_video_input_mask_B_C_V_T_H_W, num_conditional_frames_per_view_B_V
            )
        if view_condition_dropout_max > 0:
            # Clear the conditioning of up to view_condition_dropout_max randomly chosen
            # eligible views; the same views are dropped for the whole batch.
            random.shuffle(views_eligible_for_dropout)
            n_views_to_dropout = random.randint(0, view_condition_dropout_max)
            views_to_dropout = views_eligible_for_dropout[:n_views_to_dropout]
            for view_idx in views_to_dropout:
                condition_video_input_mask_B_C_V_T_H_W[:, :, view_idx] = 0

        condition_video_input_mask_B_C_T_H_W = rearrange(
            condition_video_input_mask_B_C_V_T_H_W, "B C V T H W -> B C (V T) H W", V=sample_n_views
        )
        kwargs["condition_video_input_mask_B_C_T_H_W"] = condition_video_input_mask_B_C_T_H_W
        return type(self)(**kwargs)

    def enable_ref_cam_condition(self, cam_idx_B: torch.Tensor, condition_video_input_mask_B_C_V_T_H_W: torch.Tensor):
        """
        Sets condition video input mask to 1 for all frames of the cam_idx[i] view in each example i
        Args:
            cam_idx_B: A tensor of shape [B]
            condition_video_input_mask_B_C_V_T_H_W: A tensor of shape [B, 1, V, T, H, W]
            where V is the number of views, T is the number of frames per view, H is the height, and W is the width
        Returns:
            A copy of the condition video input mask with the cam_idx[i] view set to 1 for example i
        """
        assert condition_video_input_mask_B_C_V_T_H_W.ndim == 6, (
            f"condition_video_input_mask_B_C_V_T_H_W must have 6 dimensions. Got {condition_video_input_mask_B_C_V_T_H_W.shape=}"
        )
        assert cam_idx_B.ndim == 1, f"cam_idx_B must have 1 dimension. Got {cam_idx_B.shape=}"
        # Work on a clone so the caller's mask is left untouched.
        copy_condition_video_input_mask_B_C_V_T_H_W = condition_video_input_mask_B_C_V_T_H_W.clone()
        for i in range(copy_condition_video_input_mask_B_C_V_T_H_W.shape[0]):
            copy_condition_video_input_mask_B_C_V_T_H_W[i, :, cam_idx_B[i]] = 1
        return copy_condition_video_input_mask_B_C_V_T_H_W

    def enable_first_random_n_condition(
        self, condition_video_input_mask_B_C_V_T_H_W: torch.Tensor, num_conditional_frames_per_view_B_V: torch.Tensor
    ):
        """
        Sets condition video input mask to 1 for the first num_conditional_frames_per_view_B_V frames of each view
        Args:
            condition_video_input_mask_B_C_V_T_H_W: A tensor of shape [B, 1, V, T, H, W]
            num_conditional_frames_per_view_B_V: A tensor of shape [B, V]
        Returns:
            A copy of the condition video input mask with the first num_conditional_frames_per_view_B_V frames of each view set to 1
        """
        assert condition_video_input_mask_B_C_V_T_H_W.ndim == 6, (
            "condition_video_input_mask_B_C_V_T_H_W must have 6 dimensions"
        )
        B, _, _, _, _, _ = condition_video_input_mask_B_C_V_T_H_W.shape
        # Work on a clone so the caller's mask is left untouched.
        copy_condition_video_input_mask_B_C_V_T_H_W = condition_video_input_mask_B_C_V_T_H_W.clone()
        for idx in range(B):
            for view_idx in range(num_conditional_frames_per_view_B_V.shape[1]):
                # Slice [:n] along the per-view frame axis; n == 0 leaves the view unconditioned.
                copy_condition_video_input_mask_B_C_V_T_H_W[
                    idx, :, view_idx, : num_conditional_frames_per_view_B_V[idx, view_idx]
                ] = 1
        return copy_condition_video_input_mask_B_C_V_T_H_W

    def edit_for_inference(
        self,
        condition_locations: Optional[Union[ConditionLocationList, ListConfig]] = None,
        is_cfg_conditional: bool = True,
        num_conditional_frames_per_view: int = 1,
    ) -> "MultiViewCondition":
        """
        Rebuilds the conditioning mask for inference from this object's own state_t and gt_frames.

        Args:
            condition_locations: Non-empty list of ConditionLocation members (or their string
                values); forwarded to set_video_condition, which requires it.
            is_cfg_conditional: When False, use_video_condition of the returned condition is
                filled with True in place.
            num_conditional_frames_per_view: Fixed number of conditioning frames per view.

        Returns:
            A new MultiViewCondition with a deterministic frame count and no view dropout.
        """
        _condition = self.set_video_condition(
            state_t=self.state_t,
            gt_frames=self.gt_frames,
            condition_locations=condition_locations,
            random_min_num_conditional_frames_per_view=0,
            random_max_num_conditional_frames_per_view=0,
            num_conditional_frames_per_view=num_conditional_frames_per_view,
            view_condition_dropout_max=0,
        )
        if not is_cfg_conditional:
            # Do not use classifier free guidance on conditional frames.
            # YB found that it leads to worse results.
            _condition.use_video_condition.fill_(True)
        return _condition

    def broadcast(self, process_group: torch.distributed.ProcessGroup) -> "MultiViewCondition":
        """
        Broadcasts the condition over `process_group`, splitting video tensors per view.

        The non-video fields go through Text2WorldCondition.broadcast. gt_frames, the
        conditioning mask and view_indices_B_T are reshaped to expose the view axis, passed
        through broadcast_split_tensor along the per-view frame axis, and flattened back to
        "(V T)". The split only happens when a process group with more than one rank is
        given and T > 1; otherwise the three tensors are passed through unchanged.

        Returns:
            `self` if already broadcasted, else a new MultiViewCondition.
        """
        if self.is_broadcasted:
            return self
        gt_frames_B_C_T_H_W = self.gt_frames
        view_indices_B_T = self.view_indices_B_T
        condition_video_input_mask_B_C_T_H_W = self.condition_video_input_mask_B_C_T_H_W
        kwargs = self.to_dict(skip_underscore=False)
        # These three tensors are handled below with a view-aware split, so they are
        # withheld from the base-class broadcast.
        kwargs["gt_frames"] = None
        kwargs["condition_video_input_mask_B_C_T_H_W"] = None
        kwargs["view_indices_B_T"] = None
        new_condition = Text2WorldCondition.broadcast(
            type(self)(**kwargs),
            process_group,
        )

        kwargs = new_condition.to_dict(skip_underscore=False)
        _, _, T, _, _ = gt_frames_B_C_T_H_W.shape
        n_views = T // self.state_t
        assert T % self.state_t == 0, f"T must be a multiple of state_t. Got T={T} and state_t={self.state_t}."
        if process_group is not None:
            if T > 1 and process_group.size() > 1:
                log.debug(f"Broadcasting {gt_frames_B_C_T_H_W.shape=} to {n_views=} views")
                # Expose the view axis so the split runs along each view's own frame axis
                # rather than across the concatenated (V T) axis.
                gt_frames_B_C_V_T_H_W = rearrange(gt_frames_B_C_T_H_W, "B C (V T) H W -> B C V T H W", V=n_views)
                condition_video_input_mask_B_C_V_T_H_W = rearrange(
                    condition_video_input_mask_B_C_T_H_W, "B C (V T) H W -> B C V T H W", V=n_views
                )
                view_indices_B_V_T = rearrange(view_indices_B_T, "B (V T) -> B V T", V=n_views)

                gt_frames_B_C_V_T_H_W = broadcast_split_tensor(
                    gt_frames_B_C_V_T_H_W, seq_dim=3, process_group=process_group
                )
                condition_video_input_mask_B_C_V_T_H_W = broadcast_split_tensor(
                    condition_video_input_mask_B_C_V_T_H_W, seq_dim=3, process_group=process_group
                )
                view_indices_B_V_T = broadcast_split_tensor(view_indices_B_V_T, seq_dim=2, process_group=process_group)

                gt_frames_B_C_T_H_W = rearrange(gt_frames_B_C_V_T_H_W, "B C V T H W -> B C (V T) H W", V=n_views)
                condition_video_input_mask_B_C_T_H_W = rearrange(
                    condition_video_input_mask_B_C_V_T_H_W, "B C V T H W -> B C (V T) H W", V=n_views
                )
                view_indices_B_T = rearrange(view_indices_B_V_T, "B V T -> B (V T)", V=n_views)

        kwargs["gt_frames"] = gt_frames_B_C_T_H_W
        kwargs["condition_video_input_mask_B_C_T_H_W"] = condition_video_input_mask_B_C_T_H_W
        kwargs["view_indices_B_T"] = view_indices_B_T
        return type(self)(**kwargs)


class MultiViewConditioner(GeneralConditioner):
    """Conditioner whose forward pass wraps the embedder outputs in a MultiViewCondition."""

    def forward(self, batch: Dict, override_dropout_rate: Optional[Dict[str, float]] = None) -> MultiViewCondition:
        embedder_outputs = super()._forward(batch, override_dropout_rate)
        return MultiViewCondition(**embedder_outputs)

    def get_condition_with_negative_prompt(
        self,
        data_batch: Dict,
    ) -> Tuple[Any, Any]:
        """
        Build a (condition, un_condition) pair where the unconditional branch is driven
        by the negative prompt instead of by dropping the text embedding.

        Text embedders (TextAttr / MVTextAttr) are kept in both branches. Every other
        embedder is kept in the conditional branch and, in the unconditional branch,
        dropped (rate 1.0) only if its configured dropout_rate exceeds 1e-4.
        """
        text_embedder_types = (TextAttr, MVTextAttr)
        keep_all_rates: Dict[str, float] = {}
        negative_rates: Dict[str, float] = {}
        for name, module in self.embedders.items():
            keep_all_rates[name] = 0.0
            if isinstance(module, text_embedder_types):
                negative_rates[name] = 0.0
            elif module.dropout_rate > 1e-4:
                negative_rates[name] = 1.0
            else:
                negative_rates[name] = 0.0

        # The copy is taken before any forward pass so the negative branch starts from
        # the batch exactly as it was passed in.
        negative_batch = copy.deepcopy(data_batch)
        if "neg_t5_text_embeddings" in negative_batch and isinstance(
            negative_batch["neg_t5_text_embeddings"], torch.Tensor
        ):
            # Swap the negative-prompt embedding in for the positive one.
            negative_batch["t5_text_embeddings"] = negative_batch["neg_t5_text_embeddings"]

        condition: Any = self(data_batch, override_dropout_rate=keep_all_rates)
        un_condition: Any = self(negative_batch, override_dropout_rate=negative_rates)
        return condition, un_condition


# Lazy config for MultiViewConditioner: the shared embedder entries from _SHARED_CONFIG
# plus two ReMapkey entries that feed the MultiViewCondition fields of the same name.
# NOTE(review): ReMapkey presumably forwards batch[input_key] under output_key - confirm
# against its definition.
MultiViewConditionerConfig: LazyDict = L(MultiViewConditioner)(
    **_SHARED_CONFIG,
    # Batch key "latent_view_indices_B_T" -> condition field "view_indices_B_T"; never dropped.
    view_indices_B_T=L(ReMapkey)(
        input_key="latent_view_indices_B_T",
        output_key="view_indices_B_T",
        dropout_rate=0.0,
        dtype=None,
    ),
    # Same key in and out; never dropped.
    ref_cam_view_idx_sample_position=L(ReMapkey)(
        input_key="ref_cam_view_idx_sample_position",
        output_key="ref_cam_view_idx_sample_position",
        dropout_rate=0.0,
        dtype=None,
    ),
)


class TextAttrEmptyStringDropout(TextAttr):
    """
    Text embedder whose dropout replaces a sample's embedding with either zeros or a
    supplied "empty prompt" embedding, instead of only zeroing it.
    """

    def __init__(
        self,
        input_key: str,
        pos_input_key: str,
        dropout_input_key: str,
        dropout_rate: Optional[float] = 0.0,
        use_empty_string: bool = False,
        **kwargs,
    ):
        """
        Args:
            input_key: Forwarded to TextAttr.__init__.
            pos_input_key: Key of the positive-prompt embedding inside the input dict.
            dropout_input_key: Key of the replacement embedding used for dropped samples
                when use_empty_string is True.
            dropout_rate: Default per-sample dropout probability.
            use_empty_string: If True, dropped samples get the embedding stored under
                dropout_input_key; otherwise they get zeros.
            **kwargs: Accepted and ignored.
        """
        self._input_key = input_key
        self._pos_input_key = pos_input_key
        self._dropout_input_key = dropout_input_key
        self._dropout_rate = dropout_rate
        self._use_empty_string = use_empty_string
        super().__init__(input_key, dropout_rate)

    def forward(self, tensor: torch.Tensor):
        return {"crossattn_emb": tensor}

    def random_dropout_input(
        self,
        in_tensor_dict: torch.Tensor | Dict[str, torch.Tensor],
        dropout_rate: Optional[float] = None,
        key: Optional[str] = None,
    ) -> torch.Tensor:
        """
        Randomly replace whole samples of the positive embedding.

        Args:
            in_tensor_dict: Dict holding the positive embedding under pos_input_key and,
                if use_empty_string is set, the replacement under dropout_input_key.
            dropout_rate: Per-sample drop probability; defaults to self.dropout_rate.
            key: If it contains "mask", the input is returned untouched.

        Returns:
            A tensor shaped like the positive embedding, where each sample is either the
            original embedding or the replacement (zeros / empty prompt).
        """
        if key is not None and "mask" in key:
            return in_tensor_dict
        del key
        assert isinstance(in_tensor_dict, dict), f"in_tensor_dict must be a dict. Got {type(in_tensor_dict)}"
        in_tensor = in_tensor_dict[self._pos_input_key]
        B = in_tensor.shape[0]
        dropout_rate = dropout_rate if dropout_rate is not None else self.dropout_rate
        keep_mask = torch.bernoulli((1.0 - dropout_rate) * torch.ones(B)).type_as(in_tensor)
        # Reshape (B,) -> (B, 1, ..., 1) so the mask selects whole samples. A bare (B,)
        # mask would broadcast against the LAST dimension of in_tensor, which either
        # fails for B > 1 or silently masks feature columns instead of samples.
        keep_mask = keep_mask.view(B, *([1] * (in_tensor.ndim - 1)))
        if self._use_empty_string:
            empty_prompt = in_tensor_dict[self._dropout_input_key]
            if empty_prompt.shape[0] != B:
                # Tile a single empty-prompt embedding across the batch.
                empty_prompt = empty_prompt.repeat(B, 1, 1)
        else:
            empty_prompt = torch.zeros_like(in_tensor)

        return keep_mask * in_tensor + (1 - keep_mask) * empty_prompt

    def details(self) -> str:
        return "Output key: [crossattn_emb]"


# Deep-copied so that replacing the "text" entry below does not mutate _SHARED_CONFIG,
# which MultiViewConditionerConfig also expands.
_SHARED_CONFIG_PER_VIEW_DROPOUT = copy.deepcopy(_SHARED_CONFIG)
# Swap the text embedder for MVTextAttr with a 0.2 dropout rate.
# NOTE(review): the name suggests MVTextAttr applies dropout per view - confirm against its definition.
_SHARED_CONFIG_PER_VIEW_DROPOUT["text"] = L(MVTextAttr)(
    input_key=["t5_text_embeddings"],
    dropout_rate=0.2,
    use_empty_string=False,
)

# Same as MultiViewConditionerConfig, but built on the per-view-dropout shared config.
MultiViewConditionerPerViewDropoutConfig: LazyDict = L(MultiViewConditioner)(
    **_SHARED_CONFIG_PER_VIEW_DROPOUT,
    # Batch key "latent_view_indices_B_T" -> condition field "view_indices_B_T"; never dropped.
    view_indices_B_T=L(ReMapkey)(
        input_key="latent_view_indices_B_T",
        output_key="view_indices_B_T",
        dropout_rate=0.0,
        dtype=None,
    ),
    # Same key in and out; never dropped.
    ref_cam_view_idx_sample_position=L(ReMapkey)(
        input_key="ref_cam_view_idx_sample_position",
        output_key="ref_cam_view_idx_sample_position",
        dropout_rate=0.0,
        dtype=None,
    ),
)


def register_conditioner():
    """Register both multiview conditioner configs in Hydra's ConfigStore under group "conditioner"."""
    store = ConfigStore.instance()
    named_nodes = (
        ("video_prediction_multiview_conditioner", MultiViewConditionerConfig),
        ("video_prediction_multiview_conditioner_per_view_dropout", MultiViewConditionerPerViewDropoutConfig),
    )
    for config_name, config_node in named_nodes:
        store.store(
            group="conditioner",
            package="model.config.conditioner",
            name=config_name,
            node=config_node,
        )
