"""Per-run configs for BC+PPO training.

Frozen dataclasses with a registry function, following the OpenPi pattern.
"""

import dataclasses
import os
from dataclasses import field

import tyro


# Preset: RelativeIK action keys, 14 dims in total (4 x 3 for the arm poses
# + 2 x 1 for the grippers). The order must match the ActionsCfg term order:
# dual-arm IK terms come first, then the grippers.
RELATIVE_IK_ACTION_KEYS: list[str] = (
    [
        # Arm poses, left arm then right arm: xyz (3,) then axis-angle (3,).
        "robot__action__poses__left::panda__xyz_relative",
        "robot__action__poses__left::panda__axis_angle_relative",
        "robot__action__poses__right::panda__xyz_relative",
        "robot__action__poses__right::panda__axis_angle_relative",
    ]
    + [
        # Grippers, left then right: one (1,) entry each.
        "robot__action__grippers__left::panda_hand",
        "robot__action__grippers__right::panda_hand",
    ]
)


@dataclasses.dataclass(frozen=True)
class RLConfig:
    """RL hyperparameters for BC+PPO, nested inside ``BCPPORunConfig.rl``.

    The dataclass is frozen: to change a value, build a new instance
    (e.g. ``RLConfig(gamma=0.99)``) rather than assigning to a field.
    """

    # ---- PPO optimisation ----
    learning_rate: float = 1.0e-4
    gamma: float = 0.98
    lam: float = 0.95
    entropy_coef: float = 6e-3
    desired_kl: float = 1e-2
    clip_param: float = 0.2
    max_grad_norm: float = 1.0
    num_learning_epochs: int = 5
    num_mini_batches: int = 4
    num_steps_per_env: int = 32
    value_loss_coef: float = 1.0

    # ---- Actor / critic networks ----
    # Hidden layer widths, listed from input side to output side.
    actor_hidden_dims: tuple[int, ...] = (1024, 512, 256)
    critic_hidden_dims: tuple[int, ...] = (1024, 512, 256)
    activation: str = "elu"
    init_noise_std: float = 1.0
    actor_obs_normalization: bool = True
    critic_obs_normalization: bool = True

    # ---- BC-PPO coupling ----
    bc_coefficient: float = 0.1


@dataclasses.dataclass(frozen=True)
class BCPPORunConfig:
    """Configuration for a BC+PPO training run.

    ``name``, ``environment``, ``scene_path`` and ``library_dir`` are
    required; every other field has a default. The dataclass is frozen, so
    fields cannot be reassigned after construction. A contained list such
    as ``action_keys`` can still be mutated in place.
    """

    name: str
    environment: str
    scene_path: str
    library_dir: str

    log_dir: str = "./runs/test"
    overwrite: bool = False

    num_envs: int = 16
    # hash=False: a list is unhashable, so keeping it in the generated
    # __hash__ made hash(config) raise TypeError for every instance.
    # Equality (==) still compares action_keys.
    action_keys: list[str] = field(default_factory=list, hash=False)

    # BC hyperparams
    bc_epochs: int = 200
    bc_batch_size: int = 256
    bc_learning_rate: float = 1e-4
    bc_eval_freq: int = 25

    # Reset distribution
    reset_from_dataset: bool = False
    reset_initial_only: bool = True
    curriculum: bool = False
    curriculum_initial_frontier: float = 0.8
    curriculum_advance_threshold: float = 0.7
    curriculum_retreat_threshold: float = 0.3
    curriculum_step_size: float = 0.05

    # RL hyperparams
    rl: RLConfig = field(default_factory=RLConfig)
    max_iterations: int = 5000

    # Runtime / misc
    device: str = "cuda"
    headless: bool = True
    seed: int = 42
    video: bool = False
    video_length: int = 200
    video_interval: int = 4500
    resume: str = ""
    eval_only: bool = False
    note: str = ""

    @property
    def dataset_path(self) -> str:
        """Return the rollouts dataset path: ``rollouts.hdf5`` inside ``log_dir``.

        The path is always derived from ``log_dir``; there is no separate
        override or fallback. ``os.path.join`` avoids a doubled separator
        when ``log_dir`` ends in "/". It also keeps an empty ``log_dir``
        relative ("rollouts.hdf5") instead of producing "/rollouts.hdf5"
        at the filesystem root.
        """
        return os.path.join(self.log_dir, "rollouts.hdf5")


def get_bcppo_configs() -> dict[str, BCPPORunConfig]:
    """Return the registry of named BC+PPO run configs, keyed by ``name``.

    Raises:
        ValueError: if two presets share the same ``name``. A plain dict
            comprehension would silently keep only the last one.
    """
    # Each preset gets its own copy of the action keys. frozen=True does not
    # freeze the list itself, so sharing the module-level constant would let
    # an in-place edit on one config leak into every preset and the constant.
    configs = [
        BCPPORunConfig(
            name="kiwimanip_relative_ik",
            environment="LBM-Scenario-ImplicitRelativeIK-State",
            scene_path="./envs/lbm_configs/3_cabot_breakfast/KiwiManip.json",
            library_dir="./envs/lbm_usd_library",
            log_dir="runs/kiwimanip",
            action_keys=list(RELATIVE_IK_ACTION_KEYS),
            num_envs=2048,
            bc_epochs=500,
            bc_eval_freq=500,
            bc_batch_size=256,
            bc_learning_rate=1e-3,
            max_iterations=5000,
            reset_from_dataset=True,
            reset_initial_only=False,
        ),
        BCPPORunConfig(
            name="kiwimanipsimple_relative_ik",
            environment="LBM-Scenario-ImplicitRelativeIK-State",
            scene_path="./envs/lbm_configs/3_cabot_breakfast/KiwiManipSimple.json",
            library_dir="./envs/lbm_usd_library",
            log_dir="runs/kiwimanipsimple",
            action_keys=list(RELATIVE_IK_ACTION_KEYS),
            num_envs=10,
            bc_epochs=200,
            bc_batch_size=256,
            bc_learning_rate=1e-3,
            max_iterations=5000,
        ),
        BCPPORunConfig(
            name="kiwimanipsmall_relative_ik",
            environment="LBM-Scenario-ImplicitRelativeIK-State",
            scene_path="./envs/lbm_configs/3_cabot_breakfast/KiwiManipSmall.json",
            library_dir="./envs/lbm_usd_library",
            log_dir="runs/kiwimanipsmall",
            action_keys=list(RELATIVE_IK_ACTION_KEYS),
            num_envs=4096,
            bc_epochs=400,
            bc_eval_freq=400,
            bc_batch_size=256,
            bc_learning_rate=1e-3,
            max_iterations=30_000,
            reset_from_dataset=True,
            reset_initial_only=False,
            curriculum=True,
            rl=RLConfig(
                # num_steps_per_env=64,
                gamma=0.99,
            ),
        ),
    ]

    registry: dict[str, BCPPORunConfig] = {}
    for cfg in configs:
        if cfg.name in registry:
            raise ValueError(f"Duplicate BC+PPO config name: {cfg.name!r}")
        registry[cfg.name] = cfg
    return registry

def cli() -> BCPPORunConfig:
    """Pick a registered preset via tyro and apply command-line overrides.

    Every entry returned by ``get_bcppo_configs()`` is offered to
    ``tyro.extras.overridable_config_cli``, reusing the preset's own name
    as its description.

    Returns:
        The selected BCPPORunConfig with any overrides applied.
    """
    presets = get_bcppo_configs()
    choices = {}
    for preset_name, preset in presets.items():
        choices[preset_name] = (preset_name, preset)
    return tyro.extras.overridable_config_cli(choices)