# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from hydra.core.config_store import ConfigStore

from cosmos_predict2._src.imaginaire.lazy_config import LazyCall as L
from cosmos_predict2._src.predict2.datasets.local_datasets.dataset_video import (
    VideoDataset,
    get_generic_dataloader,
    get_sampler,
)
from cosmos_predict2.config import MODEL_CHECKPOINTS, ModelKey, ModelSize

DEFAULT_CHECKPOINT_2B = MODEL_CHECKPOINTS[ModelKey(post_trained=False)]
DEFAULT_CHECKPOINT_14B = MODEL_CHECKPOINTS[ModelKey(post_trained=False, size=ModelSize._14B)]

# GR1 dataset and dataloader (93 frames)
example_video_dataset_gr1 = L(VideoDataset)(
    dataset_dir="datasets/benchmark_train/gr1",
    num_frames=93,
    video_size=(432, 768),
)

# Create DataLoader with distributed sampler
dataloader_train_gr1 = L(get_generic_dataloader)(
    dataset=example_video_dataset_gr1,
    sampler=L(get_sampler)(dataset=example_video_dataset_gr1),
    batch_size=1,
    drop_last=True,
    num_workers=4,
    pin_memory=True,
)

# GR1 dataset and dataloader (45 frames)
example_video_dataset_gr1_short = L(VideoDataset)(
    dataset_dir="datasets/benchmark_train/gr1",
    num_frames=45,
    video_size=(432, 768),
)

# Create DataLoader with distributed sampler
dataloader_train_gr1_short = L(get_generic_dataloader)(
    dataset=example_video_dataset_gr1_short,
    sampler=L(get_sampler)(dataset=example_video_dataset_gr1_short),
    batch_size=1,
    drop_last=True,
    num_workers=4,
    pin_memory=True,
)

# Video2World post-training configuration for 2B model
# torchrun --nproc_per_node=1 --master_port=12341 -m scripts.train --config=cosmos_predict2/_src/predict2/configs/video2world/config.py -- experiment=predict2_video2world_training_2b_groot_gr1_480
predict2_video2world_training_2b_groot_gr1_480 = dict(
    defaults=[
        f"/experiment/{DEFAULT_CHECKPOINT_2B.experiment}",
        {"override /data_train": "mock"},
        {"override /data_val": "mock"},
        "_self_",
    ],
    dataloader_train=dataloader_train_gr1,
    checkpoint=dict(
        save_iter=200,
        load_path=DEFAULT_CHECKPOINT_2B.s3.uri,
        load_from_object_store=dict(
            enabled=False,
        ),
        save_to_object_store=dict(
            enabled=False,
        ),
    ),
    job=dict(
        project="cosmos_predict_v2p5",
        group="video2world",
        name="2b_groot_gr1_480",
    ),
    optimizer=dict(
        lr=2 ** (-14.5),
        weight_decay=0.001,
    ),
    scheduler=dict(
        f_max=[0.5],
        f_min=[0.2],
        warm_up_steps=[1_000],
        cycle_lengths=[100000],
    ),
    trainer=dict(
        logging_iter=100,
        max_iter=1000,
        callbacks=dict(
            heart_beat=dict(
                save_s3=False,
            ),
            iter_speed=dict(
                hit_thres=100,
                save_s3=False,
            ),
            device_monitor=dict(
                save_s3=False,
            ),
            every_n_sample_reg=dict(
                every_n=200,
                save_s3=False,
            ),
            every_n_sample_ema=dict(
                every_n=200,
                save_s3=False,
            ),
            wandb=dict(
                save_s3=False,
            ),
            wandb_10x=dict(
                save_s3=False,
            ),
            dataloader_speed=dict(
                save_s3=False,
            ),
        ),
    ),
    model_parallel=dict(
        context_parallel_size=1,
    ),
)

# Video2World post-training configuration for 14B model
# torchrun --nproc_per_node=8 --master_port=12341 -m scripts.train --config=cosmos_predict2/_src/predict2/configs/video2world/config.py -- experiment=predict2_video2world_training_14b_groot_gr1_480
predict2_video2world_training_14b_groot_gr1_480 = dict(
    defaults=[
        f"/experiment/{DEFAULT_CHECKPOINT_14B.experiment}",
        {"override /data_train": "mock"},
        {"override /data_val": "mock"},
        "_self_",
    ],
    dataloader_train=dataloader_train_gr1_short,
    checkpoint=dict(
        save_iter=200,
        load_path=DEFAULT_CHECKPOINT_14B.s3.uri,
        load_from_object_store=dict(
            enabled=False,
        ),
        save_to_object_store=dict(
            enabled=False,
        ),
    ),
    job=dict(
        project="cosmos_predict_v2p5",
        group="video2world",
        name="14b_groot_gr1_480",
    ),
    optimizer=dict(
        lr=2 ** (-15.5),
        weight_decay=0.001,
    ),
    scheduler=dict(
        f_max=[0.5],
        f_min=[0.2],
        warm_up_steps=[1_000],
        cycle_lengths=[100000],
    ),
    trainer=dict(
        logging_iter=100,
        max_iter=1000,
        straggler_detection=dict(
            enabled=False,
        ),
        callbacks=dict(
            heart_beat=dict(
                save_s3=False,
            ),
            iter_speed=dict(
                hit_thres=100,
                save_s3=False,
            ),
            device_monitor=dict(
                save_s3=False,
            ),
            every_n_sample_reg=dict(
                every_n=200,
                save_s3=False,
            ),
            every_n_sample_ema=dict(
                every_n=200,
                save_s3=False,
            ),
            wandb=dict(
                save_s3=False,
            ),
            wandb_10x=dict(
                save_s3=False,
            ),
            dataloader_speed=dict(
                save_s3=False,
            ),
        ),
    ),
    model_parallel=dict(
        context_parallel_size=1,
    ),
    model=dict(
        config=dict(
            state_t=12,
        ),
    ),
)

cs = ConfigStore.instance()

for _item in [
    predict2_video2world_training_2b_groot_gr1_480,
    predict2_video2world_training_14b_groot_gr1_480,
]:
    # Get the experiment name from the global variable
    experiment_name = [name.lower() for name, value in globals().items() if value is _item][0]
    cs.store(
        group="experiment",
        package="_global_",
        name=experiment_name,
        node=_item,
    )
