"""
Training script wrapper that sets up transformer_engine and flash_attn mocks
for systems without compiled CUDA extensions (e.g., CUDA 12.4 without TE build).
Then delegates to the real training script.
"""

import sys
import types

# === Mock transformer_engine_torch ===
sys.modules["transformer_engine_torch"] = types.ModuleType("transformer_engine_torch")

# === Mock transformer_engine.pytorch ===
import torch

def _apply_rotary_pos_emb(t, freqs, tensor_format="sbhd", fused=False, cu_seqlens=None):
    """Fallback RoPE using the rotate-half convention (matches TE).

    `fused` and `cu_seqlens` are accepted only for signature compatibility
    with transformer_engine and are ignored by this pure-PyTorch path.
    """
    rot_dim = freqs.shape[-1]
    dtype = t.dtype
    # TE emits freqs as [s, 1, 1, d]; for bshd inputs move the sequence axis
    # to dim 1 so it broadcasts against t of shape [b, s, h, d].
    if tensor_format == "bshd" and freqs.ndim == 4 and freqs.shape[1] == 1:
        freqs = freqs.permute(1, 0, 2, 3)
    cos_ = torch.cos(freqs).to(dtype)
    sin_ = torch.sin(freqs).to(dtype)
    # Only the first rot_dim channels are rotated; the rest pass through.
    t_rot = t[..., :rot_dim]
    t_pass = t[..., rot_dim:]
    x1, x2 = t_rot.chunk(2, dim=-1)
    rotated_half = torch.cat([-x2, x1], dim=-1)
    output = t_rot * cos_ + rotated_half * sin_
    if t_pass.shape[-1] > 0:
        return torch.cat([output, t_pass], dim=-1)
    return output
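
# Illustrative sanity check for the fallback (not executed; shapes arbitrary):
#   t = torch.randn(8, 2, 4, 32)           # [s, b, h, d], sbhd layout
#   freqs = torch.randn(8, 1, 1, 16)       # rotary dim 16 < head dim 32
#   out = _apply_rotary_pos_emb(t, freqs)  # same shape as t
#   assert out.shape == t.shape and torch.equal(out[..., 16:], t[..., 16:])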

for mod_name in ["transformer_engine.pytorch", "transformer_engine.pytorch.attention",
                 "transformer_engine.pytorch.attention.rope", "transformer_engine.pytorch.distributed",
                 "transformer_engine.pytorch.float8_tensor"]:
    if mod_name not in sys.modules:
        sys.modules[mod_name] = types.ModuleType(mod_name)
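
# These sys.modules entries short-circuit the import machinery: a
# `from transformer_engine.pytorch.attention import ...` in the training
# code resolves against the stub without executing real TE package code.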

sys.modules["transformer_engine.pytorch.attention.rope"].apply_rotary_pos_emb = _apply_rotary_pos_emb
sys.modules["transformer_engine.pytorch.attention"].apply_rotary_pos_emb = _apply_rotary_pos_emb
sys.modules["transformer_engine.pytorch.attention"].DotProductAttention = type("DotProductAttention", (), {})

class _FT:
    """Placeholder standing in for TE's Float8Tensor (import/isinstance only)."""

sys.modules["transformer_engine.pytorch.float8_tensor"].Float8Tensor = _FT

class _RMSNorm(torch.nn.Module):
    """Pure-PyTorch stand-in for transformer_engine.pytorch.RMSNorm."""

    def __init__(self, hidden_size, eps=1e-6, **kwargs):
        # **kwargs swallows TE-only options (device, zero_centered_gamma, ...).
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, x):
        # Normalize in fp32 for numerical stability (as reference RMSNorm
        # implementations do), then cast back to the input dtype.
        dtype = x.dtype
        x = x.float()
        x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return x.to(dtype) * self.weight

    def reset_parameters(self):
        torch.nn.init.ones_(self.weight)

sys.modules["transformer_engine.pytorch"].RMSNorm = _RMSNorm

# === Mock flash_attn ===
def _dummy(*args, **kwargs):
    """Inert stand-in for flash_attn entry points (returns None if called)."""
    return None

flash_attn_mock = types.ModuleType("flash_attn")
flash_attn_mock.__version__ = "2.7.0"
flash_attn_mock.flash_attn_varlen_func = _dummy
flash_attn_mock.flash_attn_func = _dummy
sys.modules["flash_attn"] = flash_attn_mock
for submod in ["flash_attn.bert_padding", "flash_attn.flash_attn_interface",
               "flash_attn.layers.rotary", "flash_attn.ops.rms_norm",
               "flash_attn.ops", "flash_attn.layers"]:
    m = types.ModuleType(submod)
    for attr in ["pad_input", "unpad_input", "flash_attn_varlen_func", "index_first_axis", "flash_attn_func", "rms_norm"]:
        setattr(m, attr, _dummy)
    m.apply_rotary_emb = None
    sys.modules[submod] = m

import importlib.machinery
import importlib.metadata
import importlib.util

# find_spec is documented to return a ModuleSpec (or None), so hand back a
# real spec rather than a bare truthy value.
_flash_spec = importlib.machinery.ModuleSpec("flash_attn", loader=None)
_orig_find_spec = importlib.util.find_spec
importlib.util.find_spec = lambda n, *a, **k: _flash_spec if n == "flash_attn" else _orig_find_spec(n, *a, **k)
_orig_version = importlib.metadata.version
importlib.metadata.version = lambda n: "2.7.0" if n == "flash_attn" else _orig_version(n)
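
# Both hooks matter: availability probes (e.g. transformers'
# is_flash_attn_2_available) check find_spec and then read the package
# version via importlib.metadata, so the mock must satisfy both.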

# === Patch megatron-core 0.4.0 compatibility ===
import megatron.core.parallel_state as _mps
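# Older megatron-core exposed `is_unitialized` (an upstream misspelling);
# map it to the `is_initialized` name that newer call sites expect.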
if not hasattr(_mps, "is_initialized") and hasattr(_mps, "is_unitialized"):
    _mps.is_initialized = lambda: not _mps.is_unitialized()

# === Now run the actual training script ===
import argparse
import os

from loguru import logger as logging

from cosmos_policy._src.imaginaire.config import load_config
from cosmos_policy._src.imaginaire.lazy_config import LazyConfig
from cosmos_policy._src.imaginaire.serialization import to_yaml
from cosmos_policy.scripts.train import launch

parser = argparse.ArgumentParser(description="Training with transformer_engine/flash_attn mocks")
parser.add_argument("--config", help="Path to the config file", required=False)
parser.add_argument(
    "opts", default=None, nargs=argparse.REMAINDER,
    help="Config overrides: space-separated path.key=value pairs",
)
parser.add_argument("--dryrun", action="store_true",
                    help="Resolve and save the config without launching training")
args = parser.parse_args()

config = load_config(args.config, args.opts, enable_one_logger=True)

if args.dryrun:
    logging.info("Config:\n" + config.pretty_print(use_color=True))
    os.makedirs(config.job.path_local, exist_ok=True)
    try:
        to_yaml(config, f"{config.job.path_local}/config.yaml")
    except Exception:
        # Fall back to LazyConfig's own serializer when the generic
        # serializer cannot handle this config.
        LazyConfig.save_yaml(config, f"{config.job.path_local}/config.yaml")
    print(f"{config.job.path_local}/config.yaml")
else:
    launch(config, args)
