import logging
import os
import subprocess
import time
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
from uuid import uuid4

import boto3
import tyro
from sagemaker.aws_batch.training_queue import TrainingQueue as Queue
from sagemaker.estimator import Estimator

import sagemaker

# Project tag; it is appended to the user name to form the ECR repository name
# and the base of every training job name.
NAME = "openpi"

# CLI shorthand for an instance family -> full SageMaker instance type.
INSTANCE_MAPPER = dict(
    p4de="ml.p4de.24xlarge",
    p5="ml.p5.48xlarge",
    p5en="ml.p5en.48xlarge",
    p6="ml.p6-b200.48xlarge",
)

# region -> SageMaker instance type -> training queue name.
# The "ml" substring of a queue name is substituted with `Args.queue_name`
# when the queue is opened in `main`.
QUEUE_MAPPER = {
    "us-west-2": {
        "ml.p5.48xlarge": "fss-ml-p5-48xlarge-us-west-2",
        "ml.p5en.48xlarge": "fss-ml-p5en-48xlarge-us-west-2",
        "ml.p4de.24xlarge": "fss-ml-p4de-24xlarge-us-west-2",
        "ml.p4d.24xlarge": "fss-ml-p4d-24xlarge-us-west-2",
        "ml.p6-b200.48xlarge": "fss-ml-p6-b200-48xlarge-us-west-2",
    },
}


@dataclass(frozen=True)
class Args():
    config: str

    ##########
    user: str = "arhanjain"
    local: bool = False
    name_prefix: str|None = None

    # Volume size in GB
    volume_size: int = 500

    # AWS profile args
    region: str = "us-west-2"
    profile: str = "default"
    arn: str = "arn:aws:iam::124224456861:role/service-role/SageMaker-SageMakerAllAccess"

    # Instance args
    instance_count: int = 1
    instance_type: str = "p5"
    max_run: int = 5

    # SageMaker queue args
    queue_name: str = "ml"
    priority: int = 10


def run_command(command):
    """Echo ``command`` and execute it through the system shell.

    The command is a single shell string (it may contain pipes, ``||`` and
    newlines), hence ``shell=True``.

    Raises:
        subprocess.CalledProcessError: if the shell exits with a non-zero status.
    """
    print(f"=> {command}")
    subprocess.run(command, check=True, shell=True)



def get_image(user, profile="default", region="us-east-1"):
    """Build the training image, push it to ECR, and return the image URI.

    Args:
        user: User name; combined with ``NAME`` to form the repository name.
        profile: AWS CLI profile used for the STS and ECR login calls.
        region: AWS region of the ECR registry.

    Returns:
        ``<account>.dkr.ecr.<region>.amazonaws.com/<user>-<NAME>:latest``.

    Raises:
        AssertionError: if the STS call does not return a numeric account id.
        subprocess.CalledProcessError: if any shell step exits non-zero.
    """
    # Also export the profile; presumably so that the `aws ecr` repository
    # commands below, which carry no --profile flag, use the same credentials.
    os.environ["AWS_PROFILE"] = str(profile)

    sts_query = f"aws --region {region} --profile {profile} sts get-caller-identity --query Account --output text"
    account = subprocess.getoutput(sts_query)
    assert account.isdigit(), f"Invalid account value: {account}"

    algorithm_name = f"{user}-{NAME}"
    registry_host = f"{account}.dkr.ecr.{region}.amazonaws.com"
    fullname = f"{registry_host}/{algorithm_name}:latest"
    # The Dockerfile sits next to this script; the build context is the cwd.
    dockerfile_base = Path(__file__).parent / "Dockerfile.openpi"

    login_cmd = (
        f"aws ecr get-login-password --region {region} --profile {profile} | "
        "docker login --username AWS --password-stdin"
    )

    print("Building container")
    # Log in to Sagemaker account to get image.
    base_image_login = f"{login_cmd} 763104351884.dkr.ecr.{region}.amazonaws.com"
    build = f"docker build  -f {dockerfile_base} --build-arg AWS_REGION={region} -t {algorithm_name} ."
    tag = f"docker tag {algorithm_name} {fullname}"
    own_registry_login = f"{login_cmd} {fullname}"
    # Create the repository only when describing it fails.
    ensure_repository = (
        f"aws --region {region} ecr describe-repositories --repository-names {algorithm_name} --no-cli-pager || "
        f"aws --region {region} ecr create-repository --repository-name {algorithm_name} --no-cli-pager"
    )
    steps = (base_image_login, build, tag, own_registry_login, ensure_repository)

    # One shell script, one step per line; each line aborts the script on failure.
    script = "\n".join(f"{step} || exit 1" for step in steps)
    run_command(script)
    run_command(f"docker push {fullname}")

    print("Sleeping for 5 seconds to ensure push succeeded")
    time.sleep(5)
    return fullname


# SageMaker job names must match [a-zA-Z0-9](-*[a-zA-Z0-9]){0,62}, i.e. at most 63 characters.
_JOB_NAME_MAX_LEN = 63


def _sanitize_name(name):
    """Reduce ``name`` to ASCII letters, digits and inner hyphens ("job" if nothing is left)."""
    hyphenated = name.replace("_", "-")
    # str.isalnum() alone also accepts non-ASCII letters/digits, which the
    # job-name pattern does not allow, hence the extra isascii() check.
    kept = "".join(c for c in hyphenated if c == "-" or (c.isascii() and c.isalnum()))
    return kept.strip("-") or "job"


def _make_job_name(base, now=None):
    """Return ``<base>-<timestamp>``, at most ``_JOB_NAME_MAX_LEN`` characters long."""
    now = datetime.now() if now is None else now
    # Format example: 2023-03-03-10-14-02
    stamp = now.strftime("%Y-%m-%d-%H-%M-%S")
    # Shorten the base, never the timestamp: the timestamp is the only part
    # that distinguishes two submissions sharing the same base name.
    room = _JOB_NAME_MAX_LEN - len(stamp) - 1
    prefix = _sanitize_name(base)[:room].rstrip("-")
    return f"{prefix}-{stamp}"


def _read_env_file(path):
    """Parse ``KEY=VALUE`` lines from ``path`` into a dict.

    Blank lines, lines starting with ``#`` and lines without ``=`` are skipped.
    Whitespace and surrounding quote characters are stripped from values.
    """
    parsed = {}
    with open(path, "r") as f:
        for raw in f:
            line = raw.strip()
            if not line or line.startswith("#") or "=" not in line:
                continue
            key, value = line.split("=", 1)
            parsed[key.strip()] = value.strip().strip("\"'")
    return parsed


def main():
    """Build and push the training image, then queue one SageMaker training job.

    Steps: parse ``Args`` from the command line, validate them, build/push the
    Docker image via ``get_image``, create a SageMaker session, assemble the
    job environment (including entries from ``secrets.env`` in the current
    working directory), construct the ``Estimator`` and submit it to the
    training queue mapped to the chosen region and instance type.

    Raises:
        ValueError: for an unknown instance type, a region/instance type with
            no configured queue, or a missing execution role ARN.
        FileNotFoundError: if ``secrets.env`` does not exist.
        subprocess.CalledProcessError: if an image build/push step fails.
    """
    args = tyro.cli(Args)

    # Validate everything up front so a bad argument fails before the slow
    # docker build/push instead of after it.
    if args.instance_type not in INSTANCE_MAPPER:
        raise ValueError(
            f"Unknown instance type {args.instance_type!r}; expected one of {sorted(INSTANCE_MAPPER)}"
        )
    instance_type = INSTANCE_MAPPER[args.instance_type]
    try:
        queue_template = QUEUE_MAPPER[args.region][instance_type]
    except KeyError as err:
        raise ValueError(
            f"No training queue configured for instance type {instance_type!r} in region {args.region!r}"
        ) from err

    # Provide a pre-existing role ARN; fall back to the environment when unset.
    role = args.arn
    if role is None:
        role = os.environ.get("SAGEMAKER_ARN")
        if role is None:
            raise ValueError("Please specify --arn or set the SAGEMAKER_ARN environment variable")

    image = get_image(
        args.user,
        region=args.region,
        profile=args.profile,
    )
    os.environ["AWS_DEFAULT_REGION"] = args.region

    ##########
    # Create session and make sure of account and region
    ##########
    if args.local:
        from sagemaker.local import LocalSession

        sagemaker_session = LocalSession()
    else:
        sagemaker_session = sagemaker.Session(boto_session=boto3.session.Session(region_name=args.region))

    role_name = role.split("/")[-1]
    print(f"SageMaker Execution Role:{role}")
    print(f"The name of the Execution role: {role_name}")

    client = boto3.client("sts")
    account = client.get_caller_identity()["Account"]
    print(f"AWS account:{account}")

    ##########
    # Configure the training
    ##########
    name_prefix = f"{args.name_prefix}-" if args.name_prefix else ""
    base_job_name = _sanitize_name(f"{name_prefix}{args.user.replace('.', '-')}-{NAME}")
    job_name = _make_job_name(base_job_name)

    environment = {
        "SM_USE_RESERVED_CAPACITY": "1",
        "WANDB_PROJECT": "openpi",
        "NCCL_DEBUG": "INFO",
        "XLA_PYTHON_CLIENT_MEM_FRACTION": "0.95",
        "TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS": "1",
        "CONFIG_NAME": args.config,
        "FI_EFA_FORK_SAFE": "1",
        "OPENPI_DATA_HOME": "/tmp/.cache/openpi",
    }
    # Entries from secrets.env override the defaults above on key collision.
    environment.update(_read_env_file("secrets.env"))

    # max_run is expressed in days on the command line.
    max_run_seconds = args.max_run * 24 * 60 * 60

    estimator = Estimator(
        entry_point="sagemaker/train.sh",
        sagemaker_session=sagemaker_session,
        base_job_name=base_job_name,
        role=role,
        image_uri=image,
        instance_count=args.instance_count,
        instance_type="local_gpu" if args.local else instance_type,
        input_mode="FastFile",
        max_run=max_run_seconds,
        environment=environment,
        keep_alive_period_in_seconds=5 * 60,  # 5 minutes
        tags=[
            # NOTE(review): the project tag value looks like a placeholder - confirm the real one.
            {"Key": "tri.project", "Value": "idk yet"},
            {"Key": "tri.owner.email", "Value": f"{args.user}@tri.global"},
        ],
        volume_size=args.volume_size,
    )

    # The mapped queue name contains "ml"; --queue-name swaps in another value.
    queue = Queue(queue_name=queue_template.replace("ml", args.queue_name))
    queue.map(
        estimator,
        inputs=[None],
        job_names=[job_name],
        priority=args.priority,
        share_identifier="default",
        timeout={"attemptDurationSeconds": max_run_seconds},
    )
    print(f"Queued {job_name}")


if __name__ == "__main__":
    main()
