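"""Per-run configs for SAC training.

Frozen dataclasses (``SACAlgorithmConfig``, ``SACRunConfig``) plus a registry
of named run configs (``get_sac_configs``) and a tyro CLI entry point
(``cli``), following the same pattern as bcppo_configs.py.
"""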

import dataclasses
from dataclasses import field

import tyro


# Action keys for the relative-IK action space. Judging by the key names:
# per-arm relative end-effector position ("xyz_relative") and rotation
# ("axis_angle_relative") for the left and right Panda arms, followed by the
# left/right Panda hand (gripper) commands.
# NOTE(review): the consumer presumably assembles the action vector in list
# order, so the order here likely matters — confirm before reordering.
# This single list object may be handed to several configs; treat it as
# read-only (mutating it would affect every config that references it).
RELATIVE_IK_ACTION_KEYS: list[str] = [
    "robot__action__poses__left::panda__xyz_relative",
    "robot__action__poses__left::panda__axis_angle_relative",
    "robot__action__poses__right::panda__xyz_relative",
    "robot__action__poses__right::panda__axis_angle_relative",
    "robot__action__grippers__left::panda_hand",
    "robot__action__grippers__right::panda_hand",
]


@dataclasses.dataclass(frozen=True)
class SACAlgorithmConfig:
    """SAC algorithm hyperparameters.

    Instances are immutable (``frozen=True``); derive variants with
    ``dataclasses.replace``. ``__post_init__`` rejects values that are out of
    range or mutually inconsistent, so a mistyped override fails when the
    config is constructed instead of deep inside training.
    """

    # Network
    actor_hidden_dim: int = 512
    critic_hidden_dim: int = 768
    use_tanh: bool = True
    use_layer_norm: bool = True
    # Actor log-std bounds; validated so that log_std_min <= log_std_max.
    log_std_max: float = 0.0
    log_std_min: float = -5.0
    # NOTE(review): presumably the atom count and value range of a
    # distributional critic's support — confirm against the critic code.
    # Validated so that num_atoms >= 1 and v_min < v_max.
    num_atoms: int = 101
    v_min: float = -20.0
    v_max: float = 20.0
    num_q_networks: int = 2

    # Optimizers
    actor_learning_rate: float = 3e-4
    critic_learning_rate: float = 3e-4
    alpha_learning_rate: float = 3e-4
    weight_decay: float = 0.001

    # SAC core
    alpha_init: float = 0.001
    target_entropy_ratio: float = 0.0
    use_autotune: bool = True
    gamma: float = 0.97
    tau: float = 0.125

    # Replay / training schedule
    buffer_size: int = 1024
    batch_size: int = 8192
    num_updates: int = 8
    policy_frequency: int = 4
    learning_starts: int = 10
    num_steps: int = 1
    # NOTE(review): 0.0 presumably disables gradient clipping — confirm.
    max_grad_norm: float = 0.0

    # Normalization / precision
    obs_normalization: bool = True
    amp: bool = True
    amp_dtype: str = "bf16"
    compile: bool = False

    def __post_init__(self) -> None:
        """Validate hyperparameter ranges.

        Raises:
            ValueError: listing every violated constraint at once.
        """
        problems: list[str] = []

        if self.log_std_min > self.log_std_max:
            problems.append(
                f"log_std_min ({self.log_std_min}) must be <= "
                f"log_std_max ({self.log_std_max})"
            )
        # An empty (or inverted) value range leaves no valid support.
        if self.v_min >= self.v_max:
            problems.append(f"v_min ({self.v_min}) must be < v_max ({self.v_max})")

        for name in ("num_atoms", "num_q_networks", "batch_size", "buffer_size", "policy_frequency"):
            value = getattr(self, name)
            if value < 1:
                problems.append(f"{name} must be >= 1, got {value}")

        for name in ("learning_starts", "max_grad_norm", "weight_decay"):
            value = getattr(self, name)
            if value < 0:
                problems.append(f"{name} must be >= 0, got {value}")

        if not 0.0 <= self.gamma <= 1.0:
            problems.append(f"gamma must be in [0, 1], got {self.gamma}")
        # NOTE(review): tau is presumably the Polyak target-update rate;
        # tau == 0 would never move the target networks.
        if not 0.0 < self.tau <= 1.0:
            problems.append(f"tau must be in (0, 1], got {self.tau}")

        if problems:
            raise ValueError("Invalid SACAlgorithmConfig: " + "; ".join(problems))


@dataclasses.dataclass(frozen=True)
class SACRunConfig:
    """Configuration for a SAC training run.

    Immutable (``frozen=True``), but list-valued fields are ordinary lists
    whose contents are not protected by freezing. ``__post_init__`` rejects
    out-of-range or inconsistent settings at construction time.
    """

    name: str
    environment: str
    scene_path: str
    library_dir: str

    log_dir: str = "./runs/test"
    overwrite: bool = False

    num_envs: int = 16  # validated to be >= 1
    action_keys: list[str] = field(default_factory=list)

    # Obs group mapping (actor / critic)
    actor_obs_groups: list[str] = field(default_factory=lambda: ["policy"])
    critic_obs_groups: list[str] = field(default_factory=lambda: ["policy"])

    # RLPD demo buffer
    demo_dataset_paths: list[str] = field(default_factory=list)
    # NOTE(review): presumably the fraction of each batch drawn from demos —
    # confirm; validated to lie in [0, 1].
    demo_ratio: float = 0.0
    demo_success_only: bool = False

    # Reset distribution
    reset_dataset_path: str = ""
    reset_from_dataset: bool = False
    reset_initial_only: bool = True
    curriculum: bool = False
    curriculum_initial_frontier: float = 0.8
    # When curriculum is enabled, validated so retreat < advance.
    curriculum_advance_threshold: float = 0.7
    curriculum_retreat_threshold: float = 0.3
    curriculum_step_size: float = 0.05
    curriculum_graduation: bool = True

    # SAC algorithm hyperparams
    sac: SACAlgorithmConfig = field(default_factory=SACAlgorithmConfig)
    max_iterations: int = 50_000
    save_interval: int = 5000
    log_interval: int = 100

    device: str = "cuda"
    headless: bool = True
    seed: int = 42
    video: bool = False
    video_length: int = 200
    video_interval: int = 10_000
    resume: str = ""
    eval_only: bool = False
    note: str = ""

    # Logging
    logger: str = "wandb"
    wandb_project: str = "sim-improvement"
    experiment_name: str = ""

    def __post_init__(self) -> None:
        """Validate run-level settings.

        Raises:
            ValueError: listing every violated constraint at once.
        """
        problems: list[str] = []

        if self.num_envs < 1:
            problems.append(f"num_envs must be >= 1, got {self.num_envs}")
        if not 0.0 <= self.demo_ratio <= 1.0:
            problems.append(f"demo_ratio must be in [0, 1], got {self.demo_ratio}")
        # NOTE(review): presumably the frontier advances on high success and
        # retreats on low success; if retreat >= advance, one success rate
        # could trigger both. Only checked when the curriculum is enabled.
        if (
            self.curriculum
            and self.curriculum_retreat_threshold >= self.curriculum_advance_threshold
        ):
            problems.append(
                f"curriculum_retreat_threshold ({self.curriculum_retreat_threshold}) "
                f"must be < curriculum_advance_threshold "
                f"({self.curriculum_advance_threshold})"
            )

        if problems:
            raise ValueError(f"Invalid SACRunConfig {self.name!r}: " + "; ".join(problems))


def get_sac_configs() -> dict[str, SACRunConfig]:
    """Return a registry of named SAC run configs.

    Each config receives its own copy of ``RELATIVE_IK_ACTION_KEYS``.
    Frozen dataclasses do not protect list contents, so a shared list would
    let a mutation of one config's ``action_keys`` leak into the module
    constant and every other config.

    Returns:
        A dict mapping each config's ``name`` to the config, in definition
        order.

    Raises:
        ValueError: if two configs share the same ``name``, rather than
            letting the later one silently replace the earlier one.
    """
    configs = [
        SACRunConfig(
            name="mughang_single",
            environment="LBM-Scenario-ImplicitRelativeIK-State",
            scene_path="./envs/lbm_configs/1_riverway_drying_rack/HangMugFromDryingRack.json",
            library_dir="./envs/lbm_usd_library",
            log_dir="runs/HangMugFromDryingRack_processed_single_clean",
            action_keys=list(RELATIVE_IK_ACTION_KEYS),
            # NOTE(review): 8096 is not a power of two (other configs use
            # 4096); possibly meant 8192 — confirm before changing.
            num_envs=8096,
            max_iterations=200_000,

            demo_dataset_paths=["runs/HangMugFromDryingRack_processed/clean_single.hdf5"],
            demo_ratio=0.05,
            demo_success_only=False,

            reset_dataset_path="runs/HangMugFromDryingRack_processed/clean_single.hdf5",
            reset_from_dataset=True,
            reset_initial_only=False,
            curriculum=True,
            curriculum_initial_frontier=0.95,
            curriculum_retreat_threshold=0.2,
            curriculum_advance_threshold=0.7,
            curriculum_graduation=False,

            sac=SACAlgorithmConfig(
                amp=False,
            ),
        ),
        SACRunConfig(
            name="mughang",
            environment="LBM-Scenario-ImplicitRelativeIK-State",
            scene_path="./envs/lbm_configs/1_riverway_drying_rack/HangMugFromDryingRack.json",
            library_dir="./envs/lbm_usd_library",
            log_dir="runs/HangMugFromDryingRack_processed",
            action_keys=list(RELATIVE_IK_ACTION_KEYS),
            # NOTE(review): see num_envs note on "mughang_single" (8096 vs 8192).
            num_envs=8096,
            max_iterations=200_000,

            demo_dataset_paths=[
                "runs/HangMugFromDryingRack_processed/replayed.hdf5",
                "runs/HangMugFromDryingRack_processed/replayed_failed.hdf5",
            ],
            demo_ratio=0.1,
            demo_success_only=False,

            reset_dataset_path="runs/HangMugFromDryingRack_processed/replayed.hdf5",
            reset_from_dataset=True,
            reset_initial_only=False,
            curriculum=True,
            curriculum_initial_frontier=0.95,
            curriculum_retreat_threshold=0.2,
            curriculum_advance_threshold=0.7,
            curriculum_graduation=False,

            sac=SACAlgorithmConfig(
                amp=False,
            ),
        ),
        SACRunConfig(
            name="bananaonsaucer_rlpd",
            environment="LBM-Scenario-ImplicitRelativeIK-State",
            scene_path="./envs/lbm_configs/3_cabot_breakfast/BananaOnSaucer.json",
            library_dir="./envs/lbm_usd_library",
            log_dir="runs/bananaonsaucer",
            action_keys=list(RELATIVE_IK_ACTION_KEYS),
            num_envs=4096,
            max_iterations=100_000,

            demo_dataset_paths=[
                "runs/bananaonsaucer/rollouts.hdf5",
                "runs/bananaonsaucer/rollouts_failed.hdf5",
            ],
            demo_ratio=0.05,

            reset_dataset_path="runs/bananaonsaucer/rollouts.hdf5",
            reset_from_dataset=True,
            reset_initial_only=False,
            curriculum=True,
            sac=SACAlgorithmConfig(
                amp=False,
            ),
        ),
        SACRunConfig(
            name="kiwimanip_rlpd",
            environment="LBM-Scenario-ImplicitRelativeIK-State",
            scene_path="./envs/lbm_configs/3_cabot_breakfast/KiwiManip.json",
            library_dir="./envs/lbm_usd_library",
            log_dir="runs/kiwimanip_rlpd",
            action_keys=list(RELATIVE_IK_ACTION_KEYS),
            num_envs=4096,
            max_iterations=100_000,

            demo_dataset_paths=[
                "runs/kiwimanip_rlpd/rollouts.hdf5",
                "runs/kiwimanip_rlpd/rollouts_failed.hdf5",
            ],
            demo_ratio=0.05,

            reset_dataset_path="runs/kiwimanip_rlpd/rollouts.hdf5",
            reset_from_dataset=True,
            reset_initial_only=False,
            curriculum=True,
            curriculum_graduation=False,
            sac=SACAlgorithmConfig(
                amp=False,
            ),
        ),
        SACRunConfig(
            name="kiwimanipsmall_sac",
            environment="LBM-Scenario-ImplicitRelativeIK-State",
            scene_path="./envs/lbm_configs/3_cabot_breakfast/KiwiManipSmall.json",
            library_dir="./envs/lbm_usd_library",
            log_dir="runs/kiwimanipsmall_sac",
            action_keys=list(RELATIVE_IK_ACTION_KEYS),
            num_envs=4096,
            max_iterations=100_000,
            reset_dataset_path="runs/kiwimanipsmall/rollouts.hdf5",
            reset_from_dataset=True,
            reset_initial_only=False,
            curriculum=True,
            sac=SACAlgorithmConfig(),
        ),
    ]

    registry: dict[str, SACRunConfig] = {}
    for config in configs:
        if config.name in registry:
            raise ValueError(f"Duplicate SAC config name: {config.name!r}")
        registry[config.name] = config
    return registry


def cli() -> SACRunConfig:
    """Select a registered SAC run config by name and apply CLI overrides.

    Returns:
        The chosen ``SACRunConfig`` with any command-line overrides applied.
    """
    registry = get_sac_configs()
    # Each entry maps a subcommand name to (description, default config);
    # the config name is reused as its description.
    named_choices = {name: (name, config) for name, config in registry.items()}
    return tyro.extras.overridable_config_cli(named_choices)
