# -----------------------------------------------------------------------------
# Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
#
# This codebase constitutes NVIDIA proprietary technology and is strictly
# confidential. Any unauthorized reproduction, distribution, or disclosure
# of this code, in whole or in part, outside NVIDIA is strictly prohibited
# without prior written consent.
#
# For inquiries regarding the use of this code in other NVIDIA proprietary
# projects, please contact the Deep Imagination Research Team at
# dir@exchange.nvidia.com.
# -----------------------------------------------------------------------------

from dataclasses import dataclass
from typing import Tuple

import torch
import torch.distributed as dist
import torch.utils.data
import wandb
from hydra.core.config_store import ConfigStore

from cosmos_policy._src.imaginaire.lazy_config import LazyCall as L
from cosmos_policy._src.imaginaire.model import ImaginaireModel
from cosmos_policy._src.imaginaire.utils import distributed, log
from cosmos_policy._src.imaginaire.utils.callback import WandBCallback as WandBCallbackImage
from cosmos_policy._src.imaginaire.utils.easy_io import easy_io
from cosmos_policy._src.predict2.callbacks.wandb_log import _LossRecord


@dataclass
class _LossRecordNoEDM:
    loss: float = 0
    iter_count: int = 0

    def reset(self) -> None:
        self.loss = 0
        self.iter_count = 0

    def get_stat(self) -> Tuple[float, float]:
        if self.iter_count > 0:
            avg_loss = self.loss / self.iter_count
            dist.all_reduce(avg_loss, op=dist.ReduceOp.AVG)
            avg_loss = avg_loss.item()
        else:
            avg_loss = 0
        self.reset()
        return avg_loss


class WandbCallback(WandBCallbackImage):
    def __init__(
        self,
        logging_iter_multipler: int = 1,
        save_logging_iter_multipler: int = 1,
        save_s3: bool = False,
    ) -> None:
        """Build the per-split loss accumulators and logging configuration.

        Args:
            logging_iter_multipler: multiplies the trainer's logging interval;
                must be > 0. Values > 1 add an "@N" suffix to WandB tags.
            save_logging_iter_multipler: extra multiplier for the S3 dump interval.
            save_s3: if True, periodically dump the logged stats to S3 as JSON.
        """
        super().__init__()
        # Sampled-metric bases; each one gets an MSE and an L1 accumulator
        # per split (train/val), named f"{split}_{base}_{kind}_loss_log".
        metric_bases = (
            "demo_sample_action",
            "demo_sample_future_proprio",
            "demo_sample_future_wrist_image",
            "demo_sample_future_image",
            "demo_sample_value",
            "world_model_sample_future_proprio",
            "world_model_sample_future_wrist_image",
            "world_model_sample_future_image",
            "world_model_sample_value",
            "value_function_sample_value",
        )
        for split in ("train", "val"):
            setattr(self, f"{split}_image_log", _LossRecord())
            setattr(self, f"{split}_video_log", _LossRecord())
            setattr(self, f"{split}_final_loss_log", _LossRecord())
            for base in metric_bases:
                for kind in ("mse", "l1"):
                    setattr(self, f"{split}_{base}_{kind}_loss_log", _LossRecordNoEDM())
            # Counters for steps skipped due to NaN/Inf loss; summed across
            # ranks at logging time.
            setattr(self, f"{split}_img_unstable_count", torch.zeros(1, device="cuda"))
            setattr(self, f"{split}_video_unstable_count", torch.zeros(1, device="cuda"))

        self.logging_iter_multipler = logging_iter_multipler
        self.save_logging_iter_multipler = save_logging_iter_multipler
        assert self.logging_iter_multipler > 0, "logging_iter_multipler should be greater than 0"
        self.save_s3 = save_s3
        self.wandb_extra_tag = f"@{logging_iter_multipler}" if logging_iter_multipler > 1 else ""
        self.name = "wandb_loss_log" + self.wandb_extra_tag

    def on_training_step_end(
        self,
        model: ImaginaireModel,
        data_batch: dict[str, torch.Tensor],
        output_batch: dict[str, torch.Tensor],
        loss: torch.Tensor,
        iteration: int = 0,
    ) -> None:
        """Accumulate per-step training losses; periodically push averages to WandB.

        If the total loss is NaN/Inf the step is counted as unstable and no
        losses are accumulated. Every `logging_iter * logging_iter_multipler`
        iterations the rank-averaged stats are computed on all ranks and logged
        (WandB, and optionally S3) from rank 0 only.
        """
        # Sampled-metric bases; output_batch carries f"{base}_{mse,l1}_loss"
        # and self carries the matching f"train_{base}_{kind}_loss_log" records.
        metric_bases = (
            "demo_sample_action",
            "demo_sample_future_proprio",
            "demo_sample_future_wrist_image",
            "demo_sample_future_image",
            "demo_sample_value",
            "world_model_sample_future_proprio",
            "world_model_sample_future_wrist_image",
            "world_model_sample_future_image",
            "world_model_sample_value",
            "value_function_sample_value",
        )
        skip_update_due_to_unstable_loss = bool(torch.isnan(loss) or torch.isinf(loss))
        if skip_update_due_to_unstable_loss:
            log.critical(
                f"Unstable loss {loss} at iteration {iteration} with is_image_batch: {model.is_image_batch(data_batch)}",
                rank0_only=False,
            )
            if model.is_image_batch(data_batch):
                self.train_img_unstable_count += 1
            else:
                self.train_video_unstable_count += 1
        else:
            loss_value = loss.detach().float()
            edm_loss = output_batch["edm_loss"].detach().float()
            batch_log = self.train_image_log if model.is_image_batch(data_batch) else self.train_video_log
            for record in (batch_log, self.train_final_loss_log):
                record.loss += loss_value
                record.iter_count += 1
                record.edm_loss += edm_loss
            # Sampled metrics are guarded individually: a single NaN metric
            # (e.g. a head that didn't run this batch) must not poison the rest.
            for base in metric_bases:
                for kind in ("mse", "l1"):
                    value = output_batch[f"{base}_{kind}_loss"].detach().float()
                    if not torch.isnan(value):
                        record = getattr(self, f"train_{base}_{kind}_loss_log")
                        record.loss += value
                        record.iter_count += 1

        if iteration % (self.config.trainer.logging_iter * self.logging_iter_multipler) == 0:
            if self.logging_iter_multipler > 1:
                timer_results = {}
            else:
                timer_results = self.trainer.training_timer.compute_average_results()
            # NOTE: get_stat() performs a distributed all_reduce, so it must run
            # on EVERY rank, in the same order, before the rank-0-only block.
            avg_image_loss, avg_image_edm_loss = self.train_image_log.get_stat()
            avg_video_loss, avg_video_edm_loss = self.train_video_log.get_stat()
            avg_final_loss, avg_final_edm_loss = self.train_final_loss_log.get_stat()
            sample_loss_avgs = {
                f"{base}_{kind}_loss": getattr(self, f"train_{base}_{kind}_loss_log").get_stat()
                for base in metric_bases
                for kind in ("mse", "l1")
            }

            dist.all_reduce(self.train_img_unstable_count, op=dist.ReduceOp.SUM)
            dist.all_reduce(self.train_video_unstable_count, op=dist.ReduceOp.SUM)

            if distributed.is_rank0():
                tag = self.wandb_extra_tag
                info = {f"timer/{key}": value for key, value in timer_results.items()}
                info.update(
                    {
                        f"train{tag}/image_loss": avg_image_loss,
                        f"train{tag}/image_edm_loss": avg_image_edm_loss,
                        f"train{tag}/video_loss": avg_video_loss,
                        f"train{tag}/video_edm_loss": avg_video_edm_loss,
                        f"train{tag}/loss": avg_final_loss,
                        f"train{tag}/edm_loss": avg_final_edm_loss,
                    }
                )
                info.update({f"train{tag}/{key}": value for key, value in sample_loss_avgs.items()})
                info.update(
                    {
                        f"train{tag}/train_img_unstable_count": self.train_img_unstable_count.item(),
                        f"train{tag}/train_video_unstable_count": self.train_video_unstable_count.item(),
                        "iteration": iteration,
                        "sample_counter": getattr(self.trainer, "sample_counter", iteration),
                    }
                )
                if self.save_s3:
                    save_period = (
                        self.config.trainer.logging_iter
                        * self.logging_iter_multipler
                        * self.save_logging_iter_multipler
                    )
                    if iteration % save_period == 0:
                        easy_io.dump(
                            info,
                            f"s3://rundir/{self.name}/Train_Iter{iteration:09d}.json",
                        )

                if wandb:
                    wandb.log(info, step=iteration)
            if self.logging_iter_multipler == 1:
                self.trainer.training_timer.reset()

            # Reset unstable-step counters for the next logging window.
            self.train_img_unstable_count.zero_()
            self.train_video_unstable_count.zero_()

    def on_validation_step_end(
        self,
        model: ImaginaireModel,
        data_batch: dict[str, torch.Tensor],
        output_batch: dict[str, torch.Tensor],
        loss: torch.Tensor,
        iteration: int = 0,
    ) -> None:
        """
        Callback run after a validation step; mirrors self.on_training_step_end().

        Things that are different from self.on_training_step_end():
            - No use of training timer
            - Using validation_iter instead of logging_iter
            - Doesn't do the push to WandB here; see self.on_validation_end() for that
        """
        # Sampled-metric bases; output_batch carries f"{base}_{mse,l1}_loss"
        # and self carries the matching f"val_{base}_{kind}_loss_log" records.
        metric_bases = (
            "demo_sample_action",
            "demo_sample_future_proprio",
            "demo_sample_future_wrist_image",
            "demo_sample_future_image",
            "demo_sample_value",
            "world_model_sample_future_proprio",
            "world_model_sample_future_wrist_image",
            "world_model_sample_future_image",
            "world_model_sample_value",
            "value_function_sample_value",
        )
        skip_update_due_to_unstable_loss = bool(torch.isnan(loss) or torch.isinf(loss))
        if skip_update_due_to_unstable_loss:
            log.critical(
                f"Unstable loss {loss} at iteration {iteration} with is_image_batch: {model.is_image_batch(data_batch)}",
                rank0_only=False,
            )
            if model.is_image_batch(data_batch):
                self.val_img_unstable_count += 1
            else:
                self.val_video_unstable_count += 1
        else:
            loss_value = loss.detach().float()
            edm_loss = output_batch["edm_loss"].detach().float()
            batch_log = self.val_image_log if model.is_image_batch(data_batch) else self.val_video_log
            for record in (batch_log, self.val_final_loss_log):
                record.loss += loss_value
                record.iter_count += 1
                record.edm_loss += edm_loss
            # Sampled metrics are guarded individually: a single NaN metric
            # must not poison the other accumulators.
            for base in metric_bases:
                for kind in ("mse", "l1"):
                    value = output_batch[f"{base}_{kind}_loss"].detach().float()
                    if not torch.isnan(value):
                        record = getattr(self, f"val_{base}_{kind}_loss_log")
                        record.loss += value
                        record.iter_count += 1

    def on_validation_end(self, model: ImaginaireModel, iteration: int = 0) -> None:
        """Computes and logs averages of all the validation metrics.

        Drains every per-rank loss record via ``get_stat`` (which all-reduces
        across ranks internally), sums the unstable-batch counters across ranks,
        then on rank 0 logs everything to wandb and optionally dumps the metrics
        dict to S3. Finally resets the unstable counters on all ranks.

        Args:
            model: Model being validated (unused here; kept for the callback API).
            iteration: Current training iteration; gates logging frequency and is
                used as the wandb step.
        """
        if iteration % (self.config.trainer.validation_iter * self.logging_iter_multipler) == 0:
            # Diffusion-style records return (loss, edm_loss) pairs.
            avg_image_loss, avg_image_edm_loss = self.val_image_log.get_stat()
            avg_video_loss, avg_video_edm_loss = self.val_video_log.get_stat()
            avg_final_loss, avg_final_edm_loss = self.val_final_loss_log.get_stat()

            # Action/proprio/image/value sample losses (single scalar each).
            avg_demo_sample_action_mse_loss = self.val_demo_sample_action_mse_loss_log.get_stat()
            avg_demo_sample_action_l1_loss = self.val_demo_sample_action_l1_loss_log.get_stat()
            avg_future_proprio_mse_loss = self.val_demo_sample_future_proprio_mse_loss_log.get_stat()
            avg_future_proprio_l1_loss = self.val_demo_sample_future_proprio_l1_loss_log.get_stat()
            avg_future_wrist_image_mse_loss = self.val_demo_sample_future_wrist_image_mse_loss_log.get_stat()
            avg_future_wrist_image_l1_loss = self.val_demo_sample_future_wrist_image_l1_loss_log.get_stat()
            avg_demo_sample_future_image_mse_loss = self.val_demo_sample_future_image_mse_loss_log.get_stat()
            avg_demo_sample_future_image_l1_loss = self.val_demo_sample_future_image_l1_loss_log.get_stat()
            avg_demo_sample_value_mse_loss = self.val_demo_sample_value_mse_loss_log.get_stat()
            avg_demo_sample_value_l1_loss = self.val_demo_sample_value_l1_loss_log.get_stat()

            # World-model rollout losses.
            avg_world_model_sample_future_proprio_mse_loss = (
                self.val_world_model_sample_future_proprio_mse_loss_log.get_stat()
            )
            avg_world_model_sample_future_proprio_l1_loss = (
                self.val_world_model_sample_future_proprio_l1_loss_log.get_stat()
            )
            avg_world_model_sample_future_wrist_image_mse_loss = (
                self.val_world_model_sample_future_wrist_image_mse_loss_log.get_stat()
            )
            avg_world_model_sample_future_wrist_image_l1_loss = (
                self.val_world_model_sample_future_wrist_image_l1_loss_log.get_stat()
            )
            avg_world_model_sample_future_image_mse_loss = (
                self.val_world_model_sample_future_image_mse_loss_log.get_stat()
            )
            avg_world_model_sample_future_image_l1_loss = (
                self.val_world_model_sample_future_image_l1_loss_log.get_stat()
            )
            avg_world_model_sample_value_mse_loss = self.val_world_model_sample_value_mse_loss_log.get_stat()
            avg_world_model_sample_value_l1_loss = self.val_world_model_sample_value_l1_loss_log.get_stat()

            # Value-function losses.
            avg_value_function_sample_value_mse_loss = self.val_value_function_sample_value_mse_loss_log.get_stat()
            avg_value_function_sample_value_l1_loss = self.val_value_function_sample_value_l1_loss_log.get_stat()

            # Total NaN-loss batch counts across all ranks (counters are tensors).
            dist.all_reduce(self.val_img_unstable_count, op=dist.ReduceOp.SUM)
            dist.all_reduce(self.val_video_unstable_count, op=dist.ReduceOp.SUM)

            if distributed.is_rank0():
                info = {}
                info.update(
                    {
                        f"val{self.wandb_extra_tag}/image_loss": avg_image_loss,
                        f"val{self.wandb_extra_tag}/image_edm_loss": avg_image_edm_loss,
                        f"val{self.wandb_extra_tag}/video_loss": avg_video_loss,
                        f"val{self.wandb_extra_tag}/video_edm_loss": avg_video_edm_loss,
                        f"val{self.wandb_extra_tag}/loss": avg_final_loss,
                        f"val{self.wandb_extra_tag}/edm_loss": avg_final_edm_loss,
                        f"val{self.wandb_extra_tag}/demo_sample_action_mse_loss": avg_demo_sample_action_mse_loss,
                        f"val{self.wandb_extra_tag}/demo_sample_action_l1_loss": avg_demo_sample_action_l1_loss,
                        f"val{self.wandb_extra_tag}/demo_sample_future_proprio_mse_loss": avg_future_proprio_mse_loss,
                        f"val{self.wandb_extra_tag}/demo_sample_future_proprio_l1_loss": avg_future_proprio_l1_loss,
                        f"val{self.wandb_extra_tag}/demo_sample_future_wrist_image_mse_loss": avg_future_wrist_image_mse_loss,
                        f"val{self.wandb_extra_tag}/demo_sample_future_wrist_image_l1_loss": avg_future_wrist_image_l1_loss,
                        f"val{self.wandb_extra_tag}/demo_sample_future_image_mse_loss": avg_demo_sample_future_image_mse_loss,
                        f"val{self.wandb_extra_tag}/demo_sample_future_image_l1_loss": avg_demo_sample_future_image_l1_loss,
                        f"val{self.wandb_extra_tag}/demo_sample_value_mse_loss": avg_demo_sample_value_mse_loss,
                        f"val{self.wandb_extra_tag}/demo_sample_value_l1_loss": avg_demo_sample_value_l1_loss,
                        f"val{self.wandb_extra_tag}/world_model_sample_future_proprio_mse_loss": avg_world_model_sample_future_proprio_mse_loss,
                        f"val{self.wandb_extra_tag}/world_model_sample_future_proprio_l1_loss": avg_world_model_sample_future_proprio_l1_loss,
                        f"val{self.wandb_extra_tag}/world_model_sample_future_wrist_image_mse_loss": avg_world_model_sample_future_wrist_image_mse_loss,
                        f"val{self.wandb_extra_tag}/world_model_sample_future_wrist_image_l1_loss": avg_world_model_sample_future_wrist_image_l1_loss,
                        f"val{self.wandb_extra_tag}/world_model_sample_future_image_mse_loss": avg_world_model_sample_future_image_mse_loss,
                        f"val{self.wandb_extra_tag}/world_model_sample_future_image_l1_loss": avg_world_model_sample_future_image_l1_loss,
                        f"val{self.wandb_extra_tag}/world_model_sample_value_mse_loss": avg_world_model_sample_value_mse_loss,
                        f"val{self.wandb_extra_tag}/world_model_sample_value_l1_loss": avg_world_model_sample_value_l1_loss,
                        f"val{self.wandb_extra_tag}/value_function_sample_value_mse_loss": avg_value_function_sample_value_mse_loss,
                        f"val{self.wandb_extra_tag}/value_function_sample_value_l1_loss": avg_value_function_sample_value_l1_loss,
                        f"val{self.wandb_extra_tag}/val_img_unstable_count": self.val_img_unstable_count.item(),
                        f"val{self.wandb_extra_tag}/val_video_unstable_count": self.val_video_unstable_count.item(),
                    }
                )
                if self.save_s3:
                    # Dump to S3 less frequently than wandb logging.
                    if (
                        iteration
                        % (
                            self.config.trainer.validation_iter
                            * self.logging_iter_multipler
                            * self.save_logging_iter_multipler
                        )
                        == 0
                    ):
                        easy_io.dump(
                            info,
                            f"s3://rundir/{self.name}/Val_Iter{iteration:09d}.json",
                        )

                if wandb:
                    wandb.log(info, step=iteration)

                # Fix: ":4f" was a width-4 spec; ".4f" gives the intended 4 decimals.
                log.info(f"Validation final loss (iteration {iteration}): {avg_final_loss:.4f}")

            # reset unstable count
            self.val_img_unstable_count.zero_()
            self.val_video_unstable_count.zero_()


# NOTE(review): this assignment is dead code — WANDB_CALLBACK_ACTIONS is
# redefined near the bottom of this module (with the same `wandb` and
# `wandb_10x` entries plus `video_log`), and that later definition is the one
# registered by register_configs(). Consider deleting this earlier copy.
WANDB_CALLBACK_ACTIONS = dict(
    wandb=L(WandbCallback)(
        save_s3="${upload_reproducible_setup}",
        logging_iter_multipler=1,
        save_logging_iter_multipler=10,
    ),
    wandb_10x=L(WandbCallback)(
        logging_iter_multipler=10,
        save_logging_iter_multipler=1,
        save_s3="${upload_reproducible_setup}",
    ),
)


class VideoLogCallback(WandBCallbackImage):
    """Logs predicted future frame videos to wandb every N iterations.

    Uses the same inference path as test_libero_rollout.py with a fixed reference
    observation, preferring frames extracted from the current training batch (so
    the predicted and ground-truth rollouts share the same scene) when the batch
    provides a usable current-image index.
    """

    def __init__(
        self,
        log_every_n_iters: int = 100,
        debug_dir: str = "/data/cameron/vidgen/cosmos-policy/rollout_outputs/training_debug",
    ) -> None:
        """
        Args:
            log_every_n_iters: Log videos every this many training iterations.
            debug_dir: Directory where predicted/GT frames are also saved as PNGs
                (previously hard-coded; the default preserves the old location).
        """
        super().__init__()
        self.log_every_n_iters = log_every_n_iters
        self.debug_dir = debug_dir
        self.name = "video_log_callback"
        self._ref_obs = None  # fixed reference observation dict (primary/wrist image, proprio)
        self._ref_task = None  # language instruction paired with the reference observation
        self._dataset_stats = None  # normalization stats passed to get_action
        self._t5_cache_initialized = False  # whether the T5 embedding cache was primed
        # Fix: previously probed via hasattr() in on_training_step_end; initialize here.
        self._logged_initial = False  # forces one log before any iteration gating

    def _ensure_reference_loaded(self):
        """Load reference observation, dataset stats, and T5 cache once.

        NOTE(review): if the sample pickle is missing, ``self._ref_obs`` stays
        ``None`` and the stats loading below re-runs on every call — confirm
        whether that retry behavior is intended.
        """
        if self._ref_obs is not None:
            return
        import os
        import pickle

        # Load the sample LIBERO observation (same as sample_rollout.mp4)
        sample_path = "cosmos_policy/experiments/robot/libero/sample_libero_10_observation.pkl"
        if os.path.exists(sample_path):
            with open(sample_path, "rb") as f:
                self._ref_obs = pickle.load(f)
            self._ref_task = "put both the alphabet soup and the tomato sauce in the basket"

        # Load dataset stats
        from cosmos_policy.experiments.robot.cosmos_utils import load_dataset_stats, init_t5_text_embeddings_cache

        stats_path = "nvidia/Cosmos-Policy-LIBERO-Predict2-2B/libero_dataset_statistics.json"
        self._dataset_stats = load_dataset_stats(stats_path)

        if not self._t5_cache_initialized:
            t5_path = "nvidia/Cosmos-Policy-LIBERO-Predict2-2B/libero_t5_embeddings.pkl"
            init_t5_text_embeddings_cache(t5_path)
            self._t5_cache_initialized = True

    def on_training_step_end(
        self,
        model: ImaginaireModel,
        data_batch: dict[str, torch.Tensor],
        output_batch: dict[str, torch.Tensor],
        loss: torch.Tensor,
        iteration: int = 0,
    ) -> None:
        """Run inference and log predicted vs. ground-truth future videos.

        Logs once at startup and then every ``log_every_n_iters`` iterations, on
        rank 0 only. All failures are caught and logged, never raised, so a
        logging problem cannot crash training.
        """
        # Log on first call (before any training) AND every N iters thereafter.
        if not self._logged_initial:
            self._logged_initial = True
            # Fall through to log.
        elif iteration % self.log_every_n_iters != 0:
            return
        if not distributed.is_rank0():
            return

        try:
            import numpy as np
            from cosmos_policy.experiments.robot.cosmos_utils import get_action

            self._ensure_reference_loaded()
            if self._ref_obs is None:
                log.warning("[VideoLogCallback] No reference observation found")
                return

            # Build a minimal config for get_action (same as test_libero_rollout.py)
            from dataclasses import dataclass

            @dataclass
            class _InfCfg:
                suite: str = "libero"
                use_third_person_image: bool = True
                num_third_person_images: int = 1
                use_wrist_image: bool = True
                num_wrist_images: int = 1
                use_proprio: bool = True
                normalize_proprio: bool = True
                unnormalize_actions: bool = True
                use_jpeg_compression: bool = True
                trained_with_image_aug: bool = True
                flip_images: bool = True
                use_variance_scale: bool = False
                chunk_size: int = 16
                num_denoising_steps_action: int = 5

            cfg = _InfCfg()

            from PIL import Image as PILImage

            # --- 1) Build observation from the training batch (same scene as GT) ---
            b = 0  # first sample in batch
            current_img_idx = data_batch.get("current_image_latent_idx")
            wrist_img_idx = data_batch.get("current_wrist_image_latent_idx")
            gt_video = data_batch.get("video")

            obs_from_batch = None
            input_img = self._ref_obs["primary_image"]  # fallback
            if gt_video is not None and current_img_idx is not None:
                ci = current_img_idx[b].item()
                wi = wrist_img_idx[b].item() if wrist_img_idx is not None else -1
                if ci >= 0:
                    # Map latent index -> pixel frame index; assumes 4x temporal
                    # compression with a standalone first frame — TODO confirm.
                    cs = 1 + (ci - 1) * 4
                    curr_frame = gt_video[b, :, cs]  # (C, H, W) in [-1,1]
                    # Denormalize [-1,1] -> uint8 [0,255].
                    curr_np = (
                        ((curr_frame.permute(1, 2, 0).cpu().float().numpy() + 1) * 127.5).clip(0, 255).astype(np.uint8)
                    )
                    # Resize to 256x256 (what get_action expects as input)
                    curr_256 = np.array(PILImage.fromarray(curr_np).resize((256, 256), PILImage.BILINEAR))
                    input_img = curr_256

                    wrist_256 = curr_256  # fallback
                    if wi >= 0:
                        ws = 1 + (wi - 1) * 4
                        wrist_frame = gt_video[b, :, ws]
                        wrist_np = (
                            ((wrist_frame.permute(1, 2, 0).cpu().float().numpy() + 1) * 127.5)
                            .clip(0, 255)
                            .astype(np.uint8)
                        )
                        wrist_256 = np.array(PILImage.fromarray(wrist_np).resize((256, 256), PILImage.BILINEAR))

                    # NOTE(review): np.zeros(9) assumes a 9-dim proprio vector — confirm.
                    proprio = data_batch["proprio"][b].cpu().float().numpy() if "proprio" in data_batch else np.zeros(9)
                    obs_from_batch = {
                        "primary_image": curr_256,
                        "wrist_image": wrist_256,
                        "proprio": proprio.astype(np.float64),
                    }

            # Use batch observation if available, else fall back to fixed reference
            obs_to_use = obs_from_batch if obs_from_batch is not None else self._ref_obs
            # Get task description from T5 cache (use first available)
            task_desc = self._ref_task

            # --- 2) Run inference to get predicted future ---
            was_training = model.training
            model.eval()
            with torch.inference_mode():
                result = get_action(
                    cfg, model, self._dataset_stats, obs_to_use, task_desc,
                    num_denoising_steps_action=5,
                    generate_future_state_and_value_in_parallel=True,
                )
            if was_training:
                model.train()

            value = result.get("value_prediction", 0)

            # --- 3) Decode all 4 future frames from the generated latent ---
            pred_four_frames = None
            try:
                gen_latent = result.get("generated_latent")
                latent_indices = result.get("latent_indices", {})
                fidx = latent_indices.get("future_image_latent_idx", -1)
                if gen_latent is not None and fidx >= 0:
                    decoded_full = model.decode(gen_latent)
                    pixel_start = 1 + (fidx - 1) * 4
                    pixel_end = pixel_start + 4
                    if pixel_end <= decoded_full.shape[2]:
                        future_pixels = decoded_full[0, :, pixel_start:pixel_end]
                        pred_four_frames = (
                            ((future_pixels.permute(1, 2, 3, 0).cpu().float().numpy() + 1) * 127.5)
                            .clip(0, 255)
                            .astype(np.uint8)
                        )
            except Exception as e4:
                log.warning(f"[VideoLogCallback] Failed to decode pred 4 frames: {e4}")

            # --- 4) GT VIDEO: extract from the training data_batch ---
            gt_four_frames = None
            gt_current_frame = None
            try:
                gt_video = data_batch.get("video")
                future_img_idx = data_batch.get("future_image_latent_idx")
                current_img_idx = data_batch.get("current_image_latent_idx")
                if gt_video is not None and future_img_idx is not None:
                    b = 0  # first sample in batch
                    fi = future_img_idx[b].item()
                    if fi >= 0:
                        ps = 1 + (fi - 1) * 4
                        pe = ps + 4
                        if pe <= gt_video.shape[2]:
                            gt_fut = gt_video[b, :, ps:pe]
                            gt_four_frames = (
                                ((gt_fut.permute(1, 2, 3, 0).cpu().float().numpy() + 1) * 127.5)
                                .clip(0, 255)
                                .astype(np.uint8)
                            )
                    if current_img_idx is not None:
                        ci = current_img_idx[b].item()
                        if ci >= 0:
                            cs = 1 + (ci - 1) * 4
                            gt_curr = gt_video[b, :, cs]
                            gt_current_frame = (
                                ((gt_curr.permute(1, 2, 0).cpu().float().numpy() + 1) * 127.5)
                                .clip(0, 255)
                                .astype(np.uint8)
                            )
            except Exception as e5:
                log.warning(f"[VideoLogCallback] Failed to extract GT frames: {e5}")

            # --- 5) Save frames to disk for offline inspection ---
            import os

            os.makedirs(self.debug_dir, exist_ok=True)
            if pred_four_frames is not None:
                for fi in range(pred_four_frames.shape[0]):
                    PILImage.fromarray(pred_four_frames[fi]).save(f"{self.debug_dir}/iter{iteration:06d}_pred_{fi}.png")
            if gt_four_frames is not None:
                for fi in range(gt_four_frames.shape[0]):
                    PILImage.fromarray(gt_four_frames[fi]).save(f"{self.debug_dir}/iter{iteration:06d}_gt_{fi}.png")

            # --- 6) Log to wandb as MP4 videos ---
            if wandb:
                log_dict = {"video/value_prediction": value}

                # Pred video: input frame -> 4 predicted future frames
                if pred_four_frames is not None:
                    inp_resized = np.array(PILImage.fromarray(input_img).resize(
                        (pred_four_frames.shape[2], pred_four_frames.shape[1]), PILImage.BILINEAR))
                    pred_vid = np.concatenate([[inp_resized], pred_four_frames], axis=0)  # (5, H, W, 3)
                    log_dict["video/pred_rollout"] = wandb.Video(
                        pred_vid.transpose(0, 3, 1, 2), fps=2, format="mp4")

                # GT video: current training frame -> 4 GT future frames (from training batch, varies each log)
                if gt_four_frames is not None:
                    if gt_current_frame is not None:
                        gt_vid = np.concatenate([[gt_current_frame], gt_four_frames], axis=0)
                    else:
                        gt_vid = gt_four_frames
                    log_dict["video/gt_rollout"] = wandb.Video(
                        gt_vid.transpose(0, 3, 1, 2), fps=2, format="mp4")

                wandb.log(log_dict, step=iteration)
                log.info(f"[VideoLogCallback] Logged pred + GT videos at iter {iteration}, value={value:.3f}, batch_obs={'yes' if obs_from_batch else 'no'}")

        except Exception as e:
            # Never let a logging failure crash training; surface it loudly instead.
            import traceback
            err_msg = f"[VideoLogCallback] Failed at iteration {iteration}: {e}\n{traceback.format_exc()}"
            log.critical(err_msg)
            print(err_msg, flush=True)


# Registry of wandb-related callback configurations selectable from Hydra.
WANDB_CALLBACK_ACTIONS = {
    "wandb": L(WandbCallback)(
        logging_iter_multipler=1,
        save_logging_iter_multipler=10,
        save_s3="${upload_reproducible_setup}",
    ),
    "wandb_10x": L(WandbCallback)(
        logging_iter_multipler=10,
        save_logging_iter_multipler=1,
        save_s3="${upload_reproducible_setup}",
    ),
    "video_log": L(VideoLogCallback)(log_every_n_iters=100),
}


def register_configs():
    """Register the wandb callback action group with Hydra's ConfigStore."""
    config_store = ConfigStore.instance()
    config_store.store(
        name="wandb_callback_actions",
        group="callbacks",
        package="trainer.callbacks",
        node=WANDB_CALLBACK_ACTIONS,
    )