diff --git a/src/sagemaker/remote_function/client.py b/src/sagemaker/remote_function/client.py
index 73a308ddf5..15051dc04a 100644
--- a/src/sagemaker/remote_function/client.py
+++ b/src/sagemaker/remote_function/client.py
@@ -91,7 +91,7 @@ def remote(
     use_spot_instances=False,
     max_wait_time_in_seconds=None,
     use_torchrun=False,
-    nproc_per_node=1,
+    nproc_per_node: Optional[int] = None,
 ):
     """Decorator for running the annotated function as a SageMaker training job.
@@ -284,8 +284,9 @@ def remote(
         use_torchrun (bool): Specifies whether to use torchrun for distributed training.
             Defaults to ``False``.

-        nproc_per_node (int): Specifies the number of processes per node for distributed training.
-            Defaults to ``1``.
+        nproc_per_node (Optional[int]): Specifies the number of processes per node for
+            distributed training. Defaults to ``None``.
+            If not specified, it is configured automatically based on the instance type.
     """

     def _remote(func):
@@ -325,9 +326,13 @@ def _remote(func):

         @functools.wraps(func)
         def wrapper(*args, **kwargs):
-            if instance_count > 1 and not spark_config:
+            if instance_count > 1 and not (
+                (spark_config is not None and not use_torchrun)
+                or (spark_config is None and use_torchrun)
+            ):
                 raise ValueError(
-                    "Remote function do not support training on multi instances. "
+                    "Remote function does not support training on multiple instances "
+                    + "without spark_config or use_torchrun. "
                     + "Please provide instance_count = 1"
                 )

@@ -532,7 +537,7 @@
         use_spot_instances=False,
         max_wait_time_in_seconds=None,
         use_torchrun=False,
-        nproc_per_node=1,
+        nproc_per_node: Optional[int] = None,
     ):
         """Constructor for RemoteExecutor
@@ -725,17 +730,22 @@ def __init__(
             use_torchrun (bool): Specifies whether to use torchrun for distributed training.
                 Defaults to ``False``.

-            nproc_per_node (int): Specifies the number of processes per node.
-                Defaults to ``1``.
+            nproc_per_node (Optional[int]): Specifies the number of processes per node for
+                distributed training. Defaults to ``None``.
+                If not specified, it is configured automatically based on the instance type.
         """
         self.max_parallel_jobs = max_parallel_jobs

         if self.max_parallel_jobs <= 0:
             raise ValueError("max_parallel_jobs must be greater than 0.")

-        if instance_count > 1 and not spark_config:
+        if instance_count > 1 and not (
+            (spark_config is not None and not use_torchrun)
+            or (spark_config is None and use_torchrun)
+        ):
             raise ValueError(
-                "Remote function do not support training on multi instances. "
+                "Remote function does not support training on multiple instances "
+                + "without spark_config or use_torchrun. "
                 + "Please provide instance_count = 1"
             )

diff --git a/src/sagemaker/remote_function/core/stored_function.py b/src/sagemaker/remote_function/core/stored_function.py
index ade4a9e652..862c67d9ee 100644
--- a/src/sagemaker/remote_function/core/stored_function.py
+++ b/src/sagemaker/remote_function/core/stored_function.py
@@ -55,8 +55,6 @@ def __init__(
         hmac_key: str,
         s3_kms_key: str = None,
         context: Context = Context(),
-        use_torchrun: bool = False,
-        nproc_per_node: int = 1,
     ):
         """Construct a StoredFunction object.

@@ -67,16 +65,12 @@
             s3_kms_key: KMS key used to encrypt artifacts uploaded to S3.
             hmac_key: Key used to encrypt serialized and deserialized function and arguments.
             context: Build or run context of a pipeline step.
-            use_torchrun: Whether to use torchrun for distributed training.
-            nproc_per_node: Number of processes per node for distributed training.
""" self.sagemaker_session = sagemaker_session self.s3_base_uri = s3_base_uri self.s3_kms_key = s3_kms_key self.hmac_key = hmac_key self.context = context - self.use_torchrun = use_torchrun - self.nproc_per_node = nproc_per_node self.func_upload_path = s3_path_join( s3_base_uri, context.step_name, context.func_step_s3_dir diff --git a/src/sagemaker/remote_function/job.py b/src/sagemaker/remote_function/job.py index 8ab4d420e5..4e2e749bcb 100644 --- a/src/sagemaker/remote_function/job.py +++ b/src/sagemaker/remote_function/job.py @@ -130,9 +130,12 @@ export PIP_CACHE_DIR=${{PERSISTENT_CACHE_DIR}}/sm_remotefunction_user_dependencies_cache/pip printf "INFO: PIP_CACHE_DIR is set to '$PIP_CACHE_DIR'\\n" +printf "INFO: /opt/ml/input/config/resourceconfig.json:\\n" +cat /opt/ml/input/config/resourceconfig.json printf "INFO: Bootstraping runtime environment.\\n" python /opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{BOOTSTRAP_SCRIPT_NAME} "$@" +source /opt/ml/input/sm_training.env if [ -d {JOB_REMOTE_FUNCTION_WORKSPACE} ] then @@ -155,9 +158,11 @@ fi printf "INFO: Invoking remote function inside conda environment: $conda_env.\\n" + printf "INFO: $conda_exe run -n $conda_env python -m sagemaker.remote_function.invoke_function \\n" $conda_exe run -n $conda_env python -m sagemaker.remote_function.invoke_function "$@" else printf "INFO: No conda env provided. Invoking remote function\\n" + printf "INFO: python -m sagemaker.remote_function.invoke_function \\n" python -m sagemaker.remote_function.invoke_function "$@" fi """ @@ -175,9 +180,12 @@ export PIP_CACHE_DIR=${{PERSISTENT_CACHE_DIR}}/sm_remotefunction_user_dependencies_cache/pip printf "INFO: PIP_CACHE_DIR is set to '$PIP_CACHE_DIR'\\n" +printf "INFO: /opt/ml/input/config/resourceconfig.json:\\n" +cat /opt/ml/input/config/resourceconfig.json printf "INFO: Bootstraping runtime environment.\\n" python /opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{BOOTSTRAP_SCRIPT_NAME} "$@" +source /opt/ml/input/sm_training.env if [ -d {JOB_REMOTE_FUNCTION_WORKSPACE} ] then @@ -200,11 +208,18 @@ fi printf "INFO: Invoking remote function with torchrun inside conda environment: $conda_env.\\n" - $conda_exe run -n $conda_env torchrun --nproc_per_node $NPROC_PER_NODE \ + printf "INFO: $conda_exe run -n $conda_env torchrun --nnodes $SM_HOST_COUNT --nproc_per_node $SM_NPROC_PER_NODE \ + --master_addr $SM_MASTER_ADDR --master_port $SM_MASTER_PORT --node_rank $SM_CURRENT_HOST_RANK \ + -m sagemaker.remote_function.invoke_function \\n" + $conda_exe run -n $conda_env torchrun --nnodes $SM_HOST_COUNT --nproc_per_node $SM_NPROC_PER_NODE \ + --master_addr $SM_MASTER_ADDR --master_port $SM_MASTER_PORT --node_rank $SM_CURRENT_HOST_RANK \ -m sagemaker.remote_function.invoke_function "$@" else printf "INFO: No conda env provided. 
Invoking remote function with torchrun\\n"
-    torchrun --nproc_per_node $NPROC_PER_NODE -m sagemaker.remote_function.invoke_function "$@"
+    printf "INFO: torchrun --nnodes $SM_HOST_COUNT --nproc_per_node $SM_NPROC_PER_NODE --master_addr $SM_MASTER_ADDR \
+        --master_port $SM_MASTER_PORT --node_rank $SM_CURRENT_HOST_RANK -m sagemaker.remote_function.invoke_function \\n"
+    torchrun --nnodes $SM_HOST_COUNT --nproc_per_node $SM_NPROC_PER_NODE --master_addr $SM_MASTER_ADDR \
+        --master_port $SM_MASTER_PORT --node_rank $SM_CURRENT_HOST_RANK -m sagemaker.remote_function.invoke_function "$@"
 fi
 """

@@ -262,8 +277,8 @@ def __init__(
         spark_config: SparkConfig = None,
         use_spot_instances=False,
         max_wait_time_in_seconds=None,
-        use_torchrun=False,
-        nproc_per_node=1,
+        use_torchrun: bool = False,
+        nproc_per_node: Optional[int] = None,
     ):
         """Initialize a _JobSettings instance which configures the remote job.

@@ -445,6 +460,13 @@ def __init__(
             max_wait_time_in_seconds (int): Timeout in seconds waiting for spot training job.
                 After this amount of time Amazon SageMaker will stop waiting for managed spot
                 training job to complete. Defaults to ``None``.
+
+            use_torchrun (bool): Specifies whether to use torchrun for distributed training.
+                Defaults to ``False``.
+
+            nproc_per_node (Optional[int]): Specifies the number of processes per node for
+                distributed training. Defaults to ``None``.
+                If not specified, it is configured automatically based on the instance type.
         """
         self.sagemaker_session = sagemaker_session or Session()
         self.environment_variables = resolve_value_from_config(
@@ -732,6 +754,7 @@ def start(job_settings: _JobSettings, func, func_args, func_kwargs, run_info=Non
         )

         logger.info("Creating job: %s", job_name)
+
         job_settings.sagemaker_session.sagemaker_client.create_training_job(**training_job_request)

         return _Job(
@@ -776,8 +799,6 @@ def compile(
                 s3_base_uri=s3_base_uri,
                 hmac_key=hmac_key,
                 s3_kms_key=job_settings.s3_kms_key,
-                use_torchrun=job_settings.use_torchrun,
-                nproc_per_node=job_settings.nproc_per_node,
             )
             stored_function.save(func, *func_args, **func_kwargs)
         else:
@@ -790,8 +811,6 @@
                     step_name=step_compilation_context.step_name,
                     func_step_s3_dir=step_compilation_context.pipeline_build_time,
                 ),
-                use_torchrun=job_settings.use_torchrun,
-                nproc_per_node=job_settings.nproc_per_node,
             )
             stored_function.save_pipeline_step_function(serialized_data)

@@ -931,6 +950,7 @@ def compile(
         request_dict["Environment"].update({"REMOTE_FUNCTION_SECRET_KEY": hmac_key})

         extended_request = _extend_spark_config_to_request(request_dict, job_settings, s3_base_uri)
+        extended_request = _extend_torchrun_to_request(extended_request, job_settings)

         return extended_request

@@ -1011,7 +1031,7 @@ def _prepare_and_upload_runtime_scripts(
     s3_kms_key: str,
     sagemaker_session: Session,
     use_torchrun: bool = False,
-    nproc_per_node: int = 1,
+    nproc_per_node: Optional[int] = None,
 ):
     """Copy runtime scripts to a folder and upload to S3.

@@ -1030,7 +1050,7 @@

         use_torchrun (bool): Whether to use torchrun or not.

-        nproc_per_node (int): Number of processes per node.
+        nproc_per_node (Optional[int]): Number of processes per node. Defaults to ``None``.
     """
     from sagemaker.workflow.utilities import load_step_compilation_context

@@ -1054,7 +1074,11 @@

         if use_torchrun:
             entry_point_script = ENTRYPOINT_TORCHRUN_SCRIPT
-            entry_point_script = entry_point_script.replace("$NPROC_PER_NODE", str(nproc_per_node))
+
+            if nproc_per_node is not None and nproc_per_node > 0:
+                entry_point_script = entry_point_script.replace(
+                    "$SM_NPROC_PER_NODE", str(nproc_per_node)
+                )

         with open(entrypoint_script_path, "w", newline="\n") as file:
             file.writelines(entry_point_script)
@@ -1435,6 +1459,38 @@ def _upload_serialized_spark_configuration(
     return config_file_s3_uri


+def _extend_torchrun_to_request(
+    request_dict: Dict,
+    job_settings: _JobSettings,
+) -> Dict:
+    """Extend the create training job request with torchrun configuration.
+
+    Args:
+        request_dict (Dict): create training job request dict.
+        job_settings (_JobSettings): the job settings.
+
+    Returns:
+        Dict: the request dict, with ``S3DataDistributionType`` set to ``FullyReplicated``
+        on all S3 input channels when a multi-node torchrun job is requested.
+    """
+    use_torchrun = job_settings.use_torchrun
+    instance_count = job_settings.instance_count
+
+    if not use_torchrun:
+        return request_dict
+
+    if instance_count == 1:
+        return request_dict
+
+    extended_request = request_dict.copy()
+
+    for input_channel in extended_request["InputDataConfig"]:
+        s3_data_source = input_channel["DataSource"].get("S3DataSource", None)
+        if s3_data_source:
+            s3_data_source["S3DataDistributionType"] = "FullyReplicated"
+
+    return extended_request
+
+
 def _extend_spark_config_to_request(
     request_dict: Dict,
     job_settings: _JobSettings,
diff --git a/src/sagemaker/remote_function/runtime_environment/bootstrap_runtime_environment.py b/src/sagemaker/remote_function/runtime_environment/bootstrap_runtime_environment.py
index 8fd83bfcfe..0b0823da77 100644
--- a/src/sagemaker/remote_function/runtime_environment/bootstrap_runtime_environment.py
+++ b/src/sagemaker/remote_function/runtime_environment/bootstrap_runtime_environment.py
@@ -15,10 +15,14 @@

 import argparse
 import getpass
-import sys
+import json
+import multiprocessing
 import os
-import shutil
 import pathlib
+import shutil
+import subprocess
+import sys
+from typing import Dict, Any

 if __package__ is None or __package__ == "":
     from runtime_environment_manager import (
@@ -39,64 +43,48 @@

 REMOTE_FUNCTION_WORKSPACE = "sm_rf_user_ws"
 BASE_CHANNEL_PATH = "/opt/ml/input/data"
 FAILURE_REASON_PATH = "/opt/ml/output/failure"
-JOB_OUTPUT_DIRS = ["/opt/ml/output", "/opt/ml/model", "/tmp"]
+JOB_OUTPUT_DIRS = ["/opt/ml/input", "/opt/ml/output", "/opt/ml/model", "/tmp"]
 PRE_EXECUTION_SCRIPT_NAME = "pre_exec.sh"
 JOB_REMOTE_FUNCTION_WORKSPACE = "sagemaker_remote_function_workspace"
 SCRIPT_AND_DEPENDENCIES_CHANNEL_NAME = "pre_exec_script_and_dependencies"
+SM_MODEL_DIR = "/opt/ml/model"

-logger = get_logger()
+SM_INPUT_DIR = "/opt/ml/input"
+SM_INPUT_DATA_DIR = "/opt/ml/input/data"
+SM_INPUT_CONFIG_DIR = "/opt/ml/input/config"
+SM_OUTPUT_DIR = "/opt/ml/output"
+SM_OUTPUT_FAILURE = "/opt/ml/output/failure"
+SM_OUTPUT_DATA_DIR = "/opt/ml/output/data"

-def main(sys_args=None):
-    """Entry point for bootstrap script"""
-
-    exit_code = DEFAULT_FAILURE_CODE
+SM_MASTER_ADDR = "algo-1"
+SM_MASTER_PORT = 7777

-    try:
-        args = _parse_args(sys_args)
-        client_python_version = args.client_python_version
-        client_sagemaker_pysdk_version = args.client_sagemaker_pysdk_version
-        job_conda_env = args.job_conda_env
-        pipeline_execution_id = args.pipeline_execution_id
-        dependency_settings = _DependencySettings.from_string(args.dependency_settings)
-        func_step_workspace = 
args.func_step_s3_dir +RESOURCE_CONFIG = f"{SM_INPUT_CONFIG_DIR}/resourceconfig.json" +ENV_OUTPUT_FILE = "/opt/ml/input/sm_training.env" - conda_env = job_conda_env or os.getenv("SAGEMAKER_JOB_CONDA_ENV") +SENSITIVE_KEYWORDS = ["SECRET", "PASSWORD", "KEY", "TOKEN", "PRIVATE", "CREDS", "CREDENTIALS"] +HIDDEN_VALUE = "******" - RuntimeEnvironmentManager()._validate_python_version(client_python_version, conda_env) +SM_EFA_NCCL_INSTANCES = [ + "ml.g4dn.8xlarge", + "ml.g4dn.12xlarge", + "ml.g5.48xlarge", + "ml.p3dn.24xlarge", + "ml.p4d.24xlarge", + "ml.p4de.24xlarge", + "ml.p5.48xlarge", + "ml.trn1.32xlarge", +] - user = getpass.getuser() - if user != "root": - log_message = ( - "The job is running on non-root user: %s. Adding write permissions to the " - "following job output directories: %s." - ) - logger.info(log_message, user, JOB_OUTPUT_DIRS) - RuntimeEnvironmentManager().change_dir_permission( - dirs=JOB_OUTPUT_DIRS, new_permission="777" - ) +SM_EFA_RDMA_INSTANCES = [ + "ml.p4d.24xlarge", + "ml.p4de.24xlarge", + "ml.trn1.32xlarge", +] - if pipeline_execution_id: - _bootstrap_runtime_env_for_pipeline_step( - client_python_version, func_step_workspace, conda_env, dependency_settings - ) - else: - _bootstrap_runtime_env_for_remote_function( - client_python_version, conda_env, dependency_settings - ) - - RuntimeEnvironmentManager()._validate_sagemaker_pysdk_version( - client_sagemaker_pysdk_version - ) - - exit_code = SUCCESS_EXIT_CODE - except Exception as e: # pylint: disable=broad-except - logger.exception("Error encountered while bootstrapping runtime environment: %s", e) - - _write_failure_reason_file(str(e)) - finally: - sys.exit(exit_code) +logger = get_logger() def _bootstrap_runtime_env_for_remote_function( @@ -287,5 +275,283 @@ def _parse_args(sys_args): return args +def log_key_value(key: str, value: str): + """Log a key-value pair, masking sensitive values if necessary.""" + if any(keyword.lower() in key.lower() for keyword in SENSITIVE_KEYWORDS): + logger.info("%s=%s", key, HIDDEN_VALUE) + elif isinstance(value, dict): + masked_value = mask_sensitive_info(value) + logger.info("%s=%s", key, json.dumps(masked_value)) + else: + try: + decoded_value = json.loads(value) + if isinstance(decoded_value, dict): + masked_value = mask_sensitive_info(decoded_value) + logger.info("%s=%s", key, json.dumps(masked_value)) + else: + logger.info("%s=%s", key, decoded_value) + except (json.JSONDecodeError, TypeError): + logger.info("%s=%s", key, value) + + +def log_env_variables(env_vars_dict: Dict[str, Any]): + """Log Environment Variables from the environment and an env_vars_dict.""" + for key, value in os.environ.items(): + log_key_value(key, value) + + for key, value in env_vars_dict.items(): + log_key_value(key, value) + + +def mask_sensitive_info(data): + """Recursively mask sensitive information in a dictionary.""" + if isinstance(data, dict): + for k, v in data.items(): + if isinstance(v, dict): + data[k] = mask_sensitive_info(v) + elif isinstance(v, str) and any( + keyword.lower() in k.lower() for keyword in SENSITIVE_KEYWORDS + ): + data[k] = HIDDEN_VALUE + return data + + +def num_cpus() -> int: + """Return the number of CPUs available in the current container. + + Returns: + int: Number of CPUs available in the current container. + """ + return multiprocessing.cpu_count() + + +def num_gpus() -> int: + """Return the number of GPUs available in the current container. + + Returns: + int: Number of GPUs available in the current container. 
+ """ + try: + cmd = ["nvidia-smi", "--list-gpus"] + output = subprocess.check_output(cmd).decode("utf-8") + return sum(1 for line in output.splitlines() if line.startswith("GPU ")) + except (OSError, subprocess.CalledProcessError): + logger.info("No GPUs detected (normal if no gpus installed)") + return 0 + + +def num_neurons() -> int: + """Return the number of neuron cores available in the current container. + + Returns: + int: Number of Neuron Cores available in the current container. + """ + try: + cmd = ["neuron-ls", "-j"] + output = subprocess.check_output(cmd, stderr=subprocess.STDOUT).decode("utf-8") + j = json.loads(output) + neuron_cores = 0 + for item in j: + neuron_cores += item.get("nc_count", 0) + logger.info("Found %s neurons on this instance", neuron_cores) + return neuron_cores + except OSError: + logger.info("No Neurons detected (normal if no neurons installed)") + return 0 + except subprocess.CalledProcessError as e: + if e.output is not None: + try: + msg = e.output.decode("utf-8").partition("error=")[2] + logger.info( + "No Neurons detected (normal if no neurons installed). \ + If neuron installed then %s", + msg, + ) + except AttributeError: + logger.info("No Neurons detected (normal if no neurons installed)") + else: + logger.info("No Neurons detected (normal if no neurons installed)") + + return 0 + + +def safe_serialize(data): + """Serialize the data without wrapping strings in quotes. + + This function handles the following cases: + 1. If `data` is a string, it returns the string as-is without wrapping in quotes. + 2. If `data` is serializable (e.g., a dictionary, list, int, float), it returns + the JSON-encoded string using `json.dumps()`. + 3. If `data` cannot be serialized (e.g., a custom object), it returns the string + representation of the data using `str(data)`. + + Args: + data (Any): The data to serialize. + + Returns: + str: The serialized JSON-compatible string or the string representation of the input. + """ + if isinstance(data, str): + return data + try: + return json.dumps(data) + except TypeError: + return str(data) + + +def set_env( + resource_config: Dict[str, Any], + output_file: str = ENV_OUTPUT_FILE, +): + """Set environment variables for the training job container. + + Args: + resource_config (Dict[str, Any]): Resource configuration for the training job. + output_file (str): Output file to write the environment variables. + """ + # Constants + env_vars = { + "SM_MODEL_DIR": SM_MODEL_DIR, + "SM_INPUT_DIR": SM_INPUT_DIR, + "SM_INPUT_DATA_DIR": SM_INPUT_DATA_DIR, + "SM_INPUT_CONFIG_DIR": SM_INPUT_CONFIG_DIR, + "SM_OUTPUT_DIR": SM_OUTPUT_DIR, + "SM_OUTPUT_FAILURE": SM_OUTPUT_FAILURE, + "SM_OUTPUT_DATA_DIR": SM_OUTPUT_DATA_DIR, + "SM_MASTER_ADDR": SM_MASTER_ADDR, + "SM_MASTER_PORT": SM_MASTER_PORT, + } + + # Host Variables + current_host = resource_config["current_host"] + current_instance_type = resource_config["current_instance_type"] + hosts = resource_config["hosts"] + sorted_hosts = sorted(hosts) + + env_vars["SM_CURRENT_HOST"] = current_host + env_vars["SM_CURRENT_INSTANCE_TYPE"] = current_instance_type + env_vars["SM_HOSTS"] = sorted_hosts + env_vars["SM_NETWORK_INTERFACE_NAME"] = resource_config["network_interface_name"] + env_vars["SM_HOST_COUNT"] = len(sorted_hosts) + env_vars["SM_CURRENT_HOST_RANK"] = sorted_hosts.index(current_host) + + env_vars["SM_NUM_CPUS"] = num_cpus() + env_vars["SM_NUM_GPUS"] = num_gpus() + env_vars["SM_NUM_NEURONS"] = num_neurons() + + # Misc. 
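+    # The raw resource config is embedded verbatim so downstream code can parse it.
+    # The SM_NPROC_PER_NODE fallback below prefers GPUs, then Neuron cores, then CPUs;
+    # this is the value torchrun's --nproc_per_node receives when the user leaves it unset.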
+ env_vars["SM_RESOURCE_CONFIG"] = resource_config + + if int(env_vars["SM_NUM_GPUS"]) > 0: + env_vars["SM_NPROC_PER_NODE"] = int(env_vars["SM_NUM_GPUS"]) + elif int(env_vars["SM_NUM_NEURONS"]) > 0: + env_vars["SM_NPROC_PER_NODE"] = int(env_vars["SM_NUM_NEURONS"]) + else: + env_vars["SM_NPROC_PER_NODE"] = int(env_vars["SM_NUM_CPUS"]) + + # All Training Environment Variables + env_vars["SM_TRAINING_ENV"] = { + "current_host": env_vars["SM_CURRENT_HOST"], + "current_instance_type": env_vars["SM_CURRENT_INSTANCE_TYPE"], + "hosts": env_vars["SM_HOSTS"], + "host_count": env_vars["SM_HOST_COUNT"], + "nproc_per_node": env_vars["SM_NPROC_PER_NODE"], + "master_addr": env_vars["SM_MASTER_ADDR"], + "master_port": env_vars["SM_MASTER_PORT"], + "input_config_dir": env_vars["SM_INPUT_CONFIG_DIR"], + "input_data_dir": env_vars["SM_INPUT_DATA_DIR"], + "input_dir": env_vars["SM_INPUT_DIR"], + "job_name": os.environ["TRAINING_JOB_NAME"], + "model_dir": env_vars["SM_MODEL_DIR"], + "network_interface_name": env_vars["SM_NETWORK_INTERFACE_NAME"], + "num_cpus": env_vars["SM_NUM_CPUS"], + "num_gpus": env_vars["SM_NUM_GPUS"], + "num_neurons": env_vars["SM_NUM_NEURONS"], + "output_data_dir": env_vars["SM_OUTPUT_DATA_DIR"], + "resource_config": env_vars["SM_RESOURCE_CONFIG"], + } + + instance_type = env_vars["SM_CURRENT_INSTANCE_TYPE"] + network_interface_name = env_vars.get("SM_NETWORK_INTERFACE_NAME", "eth0") + + if instance_type in SM_EFA_NCCL_INSTANCES: + # Enable EFA use + env_vars["FI_PROVIDER"] = "efa" + if instance_type in SM_EFA_RDMA_INSTANCES: + # Use EFA's RDMA functionality for one-sided and two-sided transfer + env_vars["FI_EFA_USE_DEVICE_RDMA"] = "1" + env_vars["RDMAV_FORK_SAFE"] = "1" + env_vars["NCCL_SOCKET_IFNAME"] = str(network_interface_name) + env_vars["NCCL_PROTO"] = "simple" + + with open(output_file, "w") as f: + for key, value in env_vars.items(): + f.write(f"export {key}='{safe_serialize(value)}'\n") + + logger.info("Environment Variables:") + log_env_variables(env_vars_dict=env_vars) + + +def main(sys_args=None): + """Entry point for bootstrap script""" + + exit_code = DEFAULT_FAILURE_CODE + + try: + args = _parse_args(sys_args) + client_python_version = args.client_python_version + client_sagemaker_pysdk_version = args.client_sagemaker_pysdk_version + job_conda_env = args.job_conda_env + pipeline_execution_id = args.pipeline_execution_id + dependency_settings = _DependencySettings.from_string(args.dependency_settings) + func_step_workspace = args.func_step_s3_dir + + conda_env = job_conda_env or os.getenv("SAGEMAKER_JOB_CONDA_ENV") + + RuntimeEnvironmentManager()._validate_python_version(client_python_version, conda_env) + + user = getpass.getuser() + if user != "root": + log_message = ( + "The job is running on non-root user: %s. Adding write permissions to the " + "following job output directories: %s." 
+            )
+            logger.info(log_message, user, JOB_OUTPUT_DIRS)
+            RuntimeEnvironmentManager().change_dir_permission(
+                dirs=JOB_OUTPUT_DIRS, new_permission="777"
+            )
+
+        if pipeline_execution_id:
+            _bootstrap_runtime_env_for_pipeline_step(
+                client_python_version, func_step_workspace, conda_env, dependency_settings
+            )
+        else:
+            _bootstrap_runtime_env_for_remote_function(
+                client_python_version, conda_env, dependency_settings
+            )
+
+        RuntimeEnvironmentManager()._validate_sagemaker_pysdk_version(
+            client_sagemaker_pysdk_version
+        )
+
+        if os.path.exists(RESOURCE_CONFIG):
+            try:
+                logger.info("Found %s", RESOURCE_CONFIG)
+                with open(RESOURCE_CONFIG, "r") as f:
+                    resource_config = json.load(f)
+                set_env(resource_config=resource_config)
+            except (json.JSONDecodeError, FileNotFoundError) as e:
+                # Log and continue; writing the training env file is best-effort.
+                logger.info("ERROR: Error processing %s: %s", RESOURCE_CONFIG, str(e))
+
+        exit_code = SUCCESS_EXIT_CODE
+    except Exception as e:  # pylint: disable=broad-except
+        logger.exception("Error encountered while bootstrapping runtime environment: %s", e)
+
+        _write_failure_reason_file(str(e))
+    finally:
+        sys.exit(exit_code)
+
+
 if __name__ == "__main__":
     main(sys.argv[1:])
diff --git a/tests/integ/sagemaker/remote_function/test_decorator.py b/tests/integ/sagemaker/remote_function/test_decorator.py
index 2717bb9afe..680bfc01df 100644
--- a/tests/integ/sagemaker/remote_function/test_decorator.py
+++ b/tests/integ/sagemaker/remote_function/test_decorator.py
@@ -825,7 +825,6 @@ def test_decorator_torchrun(
     dummy_container_without_error,
     gpu_instance_type,
     use_torchrun=False,
-    nproc_per_node=1,
 ):
     @remote(
         role=ROLE,
@@ -834,7 +833,6 @@
         sagemaker_session=sagemaker_session,
         keep_alive_period_in_seconds=60,
         use_torchrun=use_torchrun,
-        nproc_per_node=nproc_per_node,
     )
     def divide(x, y):
         return x / y
diff --git a/tests/unit/sagemaker/remote_function/test_job.py b/tests/unit/sagemaker/remote_function/test_job.py
index 888c634bfe..c7d35b6481 100644
--- a/tests/unit/sagemaker/remote_function/test_job.py
+++ b/tests/unit/sagemaker/remote_function/test_job.py
@@ -49,6 +49,11 @@
     _prepare_dependencies_and_pre_execution_scripts,
 )

+from sagemaker.remote_function.runtime_environment.bootstrap_runtime_environment import (
+    set_env,
+    safe_serialize,
+)
+
 REGION = "us-west-2"
 TRAINING_JOB_ARN = "training-job-arn"
@@ -68,6 +73,87 @@
 EXPECTED_OUTPUT_URI = S3_URI + "/output"
 EXPECTED_DEPENDENCIES_URI = S3_URI + "/additional_dependencies/requirements.txt"

+# flake8: noqa
+EXPECTED_ENV_SINGLE_NODE_CPU = """
+export SM_MODEL_DIR='/opt/ml/model'
+export SM_INPUT_DIR='/opt/ml/input'
+export SM_INPUT_DATA_DIR='/opt/ml/input/data'
+export SM_INPUT_CONFIG_DIR='/opt/ml/input/config'
+export SM_OUTPUT_DIR='/opt/ml/output'
+export SM_OUTPUT_FAILURE='/opt/ml/output/failure'
+export SM_OUTPUT_DATA_DIR='/opt/ml/output/data'
+export SM_MASTER_ADDR='algo-1'
+export SM_MASTER_PORT='7777'
+export SM_CURRENT_HOST='algo-1'
+export SM_CURRENT_INSTANCE_TYPE='ml.t3.xlarge'
+export SM_HOSTS='["algo-1"]'
+export SM_NETWORK_INTERFACE_NAME='eth0'
+export SM_HOST_COUNT='1'
+export SM_CURRENT_HOST_RANK='0'
+export SM_NUM_CPUS='4'
+export SM_NUM_GPUS='0'
+export SM_NUM_NEURONS='0'
+export SM_RESOURCE_CONFIG='{"current_host": "algo-1", "hosts": ["algo-1"], "current_group_name": "homogeneousCluster", "current_instance_type": "ml.t3.xlarge", "instance_groups": [{"instance_group_name": "homogeneousCluster", "instance_type": "ml.t3.xlarge", "hosts": ["algo-1"]}], "network_interface_name": 
"eth0"}' +export SM_NPROC_PER_NODE='4' +export SM_TRAINING_ENV='{"current_host": "algo-1", "current_instance_type": "ml.t3.xlarge", "hosts": ["algo-1"], "host_count": 1, "nproc_per_node": 4, "master_addr": "algo-1", "master_port": 7777, "input_config_dir": "/opt/ml/input/config", "input_data_dir": "/opt/ml/input/data", "input_dir": "/opt/ml/input", "job_name": "test-job", "model_dir": "/opt/ml/model", "network_interface_name": "eth0", "num_cpus": 4, "num_gpus": 0, "num_neurons": 0, "output_data_dir": "/opt/ml/output/data", "resource_config": {"current_host": "algo-1", "hosts": ["algo-1"], "current_group_name": "homogeneousCluster", "current_instance_type": "ml.t3.xlarge", "instance_groups": [{"instance_group_name": "homogeneousCluster", "instance_type": "ml.t3.xlarge", "hosts": ["algo-1"]}], "network_interface_name": "eth0"}}' +export NCCL_SOCKET_IFNAME='eth0' +export NCCL_PROTO='simple' +""" + +# flake8: noqa +EXPECTED_ENV_SINGLE_NODE_MULTI_GPUS = """ +export SM_MODEL_DIR='/opt/ml/model' +export SM_INPUT_DIR='/opt/ml/input' +export SM_INPUT_DATA_DIR='/opt/ml/input/data' +export SM_INPUT_CONFIG_DIR='/opt/ml/input/config' +export SM_OUTPUT_DIR='/opt/ml/output' +export SM_OUTPUT_FAILURE='/opt/ml/output/failure' +export SM_OUTPUT_DATA_DIR='/opt/ml/output/data' +export SM_MASTER_ADDR='algo-1' +export SM_MASTER_PORT='7777' +export SM_CURRENT_HOST='algo-1' +export SM_CURRENT_INSTANCE_TYPE='ml.g5.12xlarge' +export SM_HOSTS='["algo-1"]' +export SM_NETWORK_INTERFACE_NAME='eth0' +export SM_HOST_COUNT='1' +export SM_CURRENT_HOST_RANK='0' +export SM_NUM_CPUS='48' +export SM_NUM_GPUS='4' +export SM_NUM_NEURONS='0' +export SM_RESOURCE_CONFIG='{"current_host": "algo-1", "hosts": ["algo-1"], "current_group_name": "homogeneousCluster", "current_instance_type": "ml.g5.12xlarge", "instance_groups": [{"instance_group_name": "homogeneousCluster", "instance_type": "ml.g5.12xlarge", "hosts": ["algo-1"]}], "network_interface_name": "eth0"}' +export SM_NPROC_PER_NODE='4' +export SM_TRAINING_ENV='{"current_host": "algo-1", "current_instance_type": "ml.g5.12xlarge", "hosts": ["algo-1"], "host_count": 1, "nproc_per_node": 4, "master_addr": "algo-1", "master_port": 7777, "input_config_dir": "/opt/ml/input/config", "input_data_dir": "/opt/ml/input/data", "input_dir": "/opt/ml/input", "job_name": "test-job", "model_dir": "/opt/ml/model", "network_interface_name": "eth0", "num_cpus": 48, "num_gpus": 4, "num_neurons": 0, "output_data_dir": "/opt/ml/output/data", "resource_config": {"current_host": "algo-1", "hosts": ["algo-1"], "current_group_name": "homogeneousCluster", "current_instance_type": "ml.g5.12xlarge", "instance_groups": [{"instance_group_name": "homogeneousCluster", "instance_type": "ml.g5.12xlarge", "hosts": ["algo-1"]}], "network_interface_name": "eth0"}}' +export NCCL_SOCKET_IFNAME='eth0' +export NCCL_PROTO='simple' +""" + +# flake8: noqa +EXPECTED_ENV_MULTI_NODE_MULTI_GPUS = """ +export SM_MODEL_DIR='/opt/ml/model' +export SM_INPUT_DIR='/opt/ml/input' +export SM_INPUT_DATA_DIR='/opt/ml/input/data' +export SM_INPUT_CONFIG_DIR='/opt/ml/input/config' +export SM_OUTPUT_DIR='/opt/ml/output' +export SM_OUTPUT_FAILURE='/opt/ml/output/failure' +export SM_OUTPUT_DATA_DIR='/opt/ml/output/data' +export SM_MASTER_ADDR='algo-1' +export SM_MASTER_PORT='7777' +export SM_CURRENT_HOST='algo-1' +export SM_CURRENT_INSTANCE_TYPE='ml.g5.2xlarge' +export SM_HOSTS='["algo-1", "algo-2", "algo-3", "algo-4"]' +export SM_NETWORK_INTERFACE_NAME='eth0' +export SM_HOST_COUNT='4' +export SM_CURRENT_HOST_RANK='0' +export SM_NUM_CPUS='8' 
+export SM_NUM_GPUS='1' +export SM_NUM_NEURONS='0' +export SM_RESOURCE_CONFIG='{"current_host": "algo-1", "hosts": ["algo-1", "algo-2", "algo-3", "algo-4"], "current_group_name": "homogeneousCluster", "current_instance_type": "ml.g5.2xlarge", "instance_groups": [{"instance_group_name": "homogeneousCluster", "instance_type": "ml.g5.2xlarge", "hosts": ["algo-4", "algo-2", "algo-1", "algo-3"]}], "network_interface_name": "eth0"}' +export SM_NPROC_PER_NODE='1' +export SM_TRAINING_ENV='{"current_host": "algo-1", "current_instance_type": "ml.g5.2xlarge", "hosts": ["algo-1", "algo-2", "algo-3", "algo-4"], "host_count": 4, "nproc_per_node": 1, "master_addr": "algo-1", "master_port": 7777, "input_config_dir": "/opt/ml/input/config", "input_data_dir": "/opt/ml/input/data", "input_dir": "/opt/ml/input", "job_name": "test-job", "model_dir": "/opt/ml/model", "network_interface_name": "eth0", "num_cpus": 8, "num_gpus": 1, "num_neurons": 0, "output_data_dir": "/opt/ml/output/data", "resource_config": {"current_host": "algo-1", "hosts": ["algo-1", "algo-2", "algo-3", "algo-4"], "current_group_name": "homogeneousCluster", "current_instance_type": "ml.g5.2xlarge", "instance_groups": [{"instance_group_name": "homogeneousCluster", "instance_type": "ml.g5.2xlarge", "hosts": ["algo-4", "algo-2", "algo-1", "algo-3"]}], "network_interface_name": "eth0"}}' +export NCCL_SOCKET_IFNAME='eth0' +export NCCL_PROTO='simple' +""" + DESCRIBE_TRAINING_JOB_RESPONSE = { "TrainingJobArn": TRAINING_JOB_ARN, "TrainingJobStatus": "{}", @@ -79,6 +165,8 @@ "OutputDataConfig": {"S3OutputPath": "s3://sagemaker-123/image_uri/output"}, } +OUTPUT_FILE = os.path.join(os.path.dirname(__file__), "sm_training.env") + TEST_JOB_NAME = "my-job-name" TEST_PIPELINE_NAME = "my-pipeline" TEST_EXP_NAME = "my-exp-name" @@ -376,8 +464,6 @@ def test_start( s3_base_uri=f"{S3_URI}/{job.job_name}", hmac_key=HMAC_KEY, s3_kms_key=None, - use_torchrun=False, - nproc_per_node=1, ) mock_stored_function().save.assert_called_once_with(job_function, *(1, 2), **{"c": 3, "d": 4}) @@ -392,7 +478,7 @@ def test_start( s3_kms_key=None, sagemaker_session=session(), use_torchrun=False, - nproc_per_node=1, + nproc_per_node=None, ) mock_dependency_upload.assert_called_once_with( @@ -510,8 +596,6 @@ def test_start_with_checkpoint_location( s3_base_uri=f"{S3_URI}/{job.job_name}", hmac_key=HMAC_KEY, s3_kms_key=None, - use_torchrun=False, - nproc_per_node=1, ) mock_stored_function().save.assert_called_once_with( @@ -665,8 +749,6 @@ def test_start_with_complete_job_settings( s3_base_uri=f"{S3_URI}/{job.job_name}", hmac_key=HMAC_KEY, s3_kms_key=KMS_KEY_ARN, - use_torchrun=False, - nproc_per_node=1, ) local_dependencies_path = mock_runtime_manager().snapshot() @@ -679,7 +761,7 @@ def test_start_with_complete_job_settings( s3_kms_key=job_settings.s3_kms_key, sagemaker_session=session(), use_torchrun=False, - nproc_per_node=1, + nproc_per_node=None, ) mock_user_workspace_upload.assert_called_once_with( @@ -838,8 +920,6 @@ def test_get_train_args_under_pipeline_context( step_name=MOCKED_PIPELINE_CONFIG.step_name, func_step_s3_dir=MOCKED_PIPELINE_CONFIG.pipeline_build_time, ), - use_torchrun=False, - nproc_per_node=1, ) mock_stored_function.save_pipeline_step_function.assert_called_once_with(mocked_serialized_data) @@ -853,7 +933,7 @@ def test_get_train_args_under_pipeline_context( s3_kms_key=job_settings.s3_kms_key, sagemaker_session=session(), use_torchrun=False, - nproc_per_node=1, + nproc_per_node=None, ) mock_user_workspace_upload.assert_called_once_with( @@ -1029,7 +1109,7 @@ 
def test_start_with_spark( s3_kms_key=None, sagemaker_session=session(), use_torchrun=False, - nproc_per_node=1, + nproc_per_node=None, ) session().sagemaker_client.create_training_job.assert_called_once_with( @@ -1184,8 +1264,6 @@ def test_prepare_and_upload_runtime_scripts(session, mock_copy, mock_s3_upload): s3_base_uri=S3_URI, s3_kms_key=KMS_KEY_ARN, sagemaker_session=session(), - use_torchrun=False, - nproc_per_node=1, ) assert s3_path == mock_s3_upload.return_value @@ -1619,3 +1697,421 @@ def test_extend_spark_config_to_request( } ], ) + + +@patch("sagemaker.experiments._run_context._RunContext.get_current_run", new=mock_get_current_run) +@patch("secrets.token_hex", return_value=HMAC_KEY) +@patch("sagemaker.remote_function.job._prepare_and_upload_workspace", return_value="some_s3_uri") +@patch( + "sagemaker.remote_function.job._prepare_and_upload_runtime_scripts", return_value="some_s3_uri" +) +@patch("sagemaker.remote_function.job.RuntimeEnvironmentManager") +@patch("sagemaker.remote_function.job.StoredFunction") +@patch("sagemaker.remote_function.job.Session", return_value=mock_session()) +def test_start_with_torchrun_single_node( + session, + mock_stored_function, + mock_runtime_manager, + mock_script_upload, + mock_dependency_upload, + secret_token, +): + + job_settings = _JobSettings( + image_uri=IMAGE, + s3_root_uri=S3_URI, + role=ROLE_ARN, + include_local_workdir=True, + instance_type="ml.g5.12xlarge", + encrypt_inter_container_traffic=True, + use_torchrun=True, + nproc_per_node=None, + ) + + job = _Job.start(job_settings, job_function, func_args=(1, 2), func_kwargs={"c": 3, "d": 4}) + + assert job.job_name.startswith("job-function") + + mock_stored_function.assert_called_once_with( + sagemaker_session=session(), + s3_base_uri=f"{S3_URI}/{job.job_name}", + hmac_key=HMAC_KEY, + s3_kms_key=None, + ) + + mock_stored_function().save.assert_called_once_with(job_function, *(1, 2), **{"c": 3, "d": 4}) + + local_dependencies_path = mock_runtime_manager().snapshot() + mock_python_version = mock_runtime_manager()._current_python_version() + mock_sagemaker_pysdk_version = mock_runtime_manager()._current_sagemaker_pysdk_version() + + mock_script_upload.assert_called_once_with( + spark_config=None, + s3_base_uri=f"{S3_URI}/{job.job_name}", + s3_kms_key=None, + sagemaker_session=session(), + use_torchrun=True, + nproc_per_node=None, + ) + + mock_dependency_upload.assert_called_once_with( + local_dependencies_path=local_dependencies_path, + include_local_workdir=True, + pre_execution_commands=None, + pre_execution_script_local_path=None, + s3_base_uri=f"{S3_URI}/{job.job_name}", + s3_kms_key=None, + sagemaker_session=session(), + custom_file_filter=None, + ) + + session().sagemaker_client.create_training_job.assert_called_once_with( + TrainingJobName=job.job_name, + RoleArn=ROLE_ARN, + StoppingCondition={"MaxRuntimeInSeconds": 86400}, + RetryStrategy={"MaximumRetryAttempts": 1}, + InputDataConfig=[ + dict( + ChannelName=RUNTIME_SCRIPTS_CHANNEL_NAME, + DataSource={ + "S3DataSource": { + "S3Uri": mock_script_upload.return_value, + "S3DataType": "S3Prefix", + } + }, + ), + dict( + ChannelName=REMOTE_FUNCTION_WORKSPACE, + DataSource={ + "S3DataSource": { + "S3Uri": mock_dependency_upload.return_value, + "S3DataType": "S3Prefix", + } + }, + ), + ], + OutputDataConfig={"S3OutputPath": f"{S3_URI}/{job.job_name}"}, + AlgorithmSpecification=dict( + TrainingImage=IMAGE, + TrainingInputMode="File", + ContainerEntrypoint=[ + "/bin/bash", + 
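+                # job_driver.sh is the generated entrypoint (ENTRYPOINT_TORCHRUN_SCRIPT above);
+                # it bootstraps the runtime env, sources sm_training.env, and launches torchrun
+                # with the SM_* values written by set_env().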
"/opt/ml/input/data/sagemaker_remote_function_bootstrap/job_driver.sh", + ], + ContainerArguments=[ + "--s3_base_uri", + f"{S3_URI}/{job.job_name}", + "--region", + TEST_REGION, + "--client_python_version", + mock_python_version, + "--client_sagemaker_pysdk_version", + mock_sagemaker_pysdk_version, + "--dependency_settings", + '{"dependency_file": null}', + "--run_in_context", + '{"experiment_name": "my-exp-name", "run_name": "my-run-name"}', + ], + ), + ResourceConfig=dict( + VolumeSizeInGB=30, + InstanceCount=1, + InstanceType="ml.g5.12xlarge", + KeepAlivePeriodInSeconds=0, + ), + EnableNetworkIsolation=False, + EnableInterContainerTrafficEncryption=True, + EnableManagedSpotTraining=False, + Environment={"AWS_DEFAULT_REGION": "us-west-2", "REMOTE_FUNCTION_SECRET_KEY": HMAC_KEY}, + ) + + +@patch("sagemaker.experiments._run_context._RunContext.get_current_run", new=mock_get_current_run) +@patch("secrets.token_hex", return_value=HMAC_KEY) +@patch("sagemaker.remote_function.job._prepare_and_upload_workspace", return_value="some_s3_uri") +@patch( + "sagemaker.remote_function.job._prepare_and_upload_runtime_scripts", return_value="some_s3_uri" +) +@patch("sagemaker.remote_function.job.RuntimeEnvironmentManager") +@patch("sagemaker.remote_function.job.StoredFunction") +@patch("sagemaker.remote_function.job.Session", return_value=mock_session()) +def test_start_with_torchrun_multi_node( + session, + mock_stored_function, + mock_runtime_manager, + mock_script_upload, + mock_dependency_upload, + secret_token, +): + + job_settings = _JobSettings( + image_uri=IMAGE, + s3_root_uri=S3_URI, + role=ROLE_ARN, + include_local_workdir=True, + instance_count=2, + instance_type="ml.g5.2xlarge", + encrypt_inter_container_traffic=True, + use_torchrun=True, + nproc_per_node=None, + ) + + job = _Job.start(job_settings, job_function, func_args=(1, 2), func_kwargs={"c": 3, "d": 4}) + + assert job.job_name.startswith("job-function") + + mock_stored_function.assert_called_once_with( + sagemaker_session=session(), + s3_base_uri=f"{S3_URI}/{job.job_name}", + hmac_key=HMAC_KEY, + s3_kms_key=None, + ) + + mock_stored_function().save.assert_called_once_with(job_function, *(1, 2), **{"c": 3, "d": 4}) + + local_dependencies_path = mock_runtime_manager().snapshot() + mock_python_version = mock_runtime_manager()._current_python_version() + mock_sagemaker_pysdk_version = mock_runtime_manager()._current_sagemaker_pysdk_version() + + mock_script_upload.assert_called_once_with( + spark_config=None, + s3_base_uri=f"{S3_URI}/{job.job_name}", + s3_kms_key=None, + sagemaker_session=session(), + use_torchrun=True, + nproc_per_node=None, + ) + + mock_dependency_upload.assert_called_once_with( + local_dependencies_path=local_dependencies_path, + include_local_workdir=True, + pre_execution_commands=None, + pre_execution_script_local_path=None, + s3_base_uri=f"{S3_URI}/{job.job_name}", + s3_kms_key=None, + sagemaker_session=session(), + custom_file_filter=None, + ) + + session().sagemaker_client.create_training_job.assert_called_once_with( + TrainingJobName=job.job_name, + RoleArn=ROLE_ARN, + StoppingCondition={"MaxRuntimeInSeconds": 86400}, + RetryStrategy={"MaximumRetryAttempts": 1}, + InputDataConfig=[ + dict( + ChannelName=RUNTIME_SCRIPTS_CHANNEL_NAME, + DataSource={ + "S3DataSource": { + "S3Uri": mock_script_upload.return_value, + "S3DataType": "S3Prefix", + "S3DataDistributionType": "FullyReplicated", + } + }, + ), + dict( + ChannelName=REMOTE_FUNCTION_WORKSPACE, + DataSource={ + "S3DataSource": { + "S3Uri": 
mock_dependency_upload.return_value, + "S3DataType": "S3Prefix", + "S3DataDistributionType": "FullyReplicated", + } + }, + ), + ], + OutputDataConfig={"S3OutputPath": f"{S3_URI}/{job.job_name}"}, + AlgorithmSpecification=dict( + TrainingImage=IMAGE, + TrainingInputMode="File", + ContainerEntrypoint=[ + "/bin/bash", + "/opt/ml/input/data/sagemaker_remote_function_bootstrap/job_driver.sh", + ], + ContainerArguments=[ + "--s3_base_uri", + f"{S3_URI}/{job.job_name}", + "--region", + TEST_REGION, + "--client_python_version", + mock_python_version, + "--client_sagemaker_pysdk_version", + mock_sagemaker_pysdk_version, + "--dependency_settings", + '{"dependency_file": null}', + "--run_in_context", + '{"experiment_name": "my-exp-name", "run_name": "my-run-name"}', + ], + ), + ResourceConfig=dict( + VolumeSizeInGB=30, + InstanceCount=2, + InstanceType="ml.g5.2xlarge", + KeepAlivePeriodInSeconds=0, + ), + EnableNetworkIsolation=False, + EnableInterContainerTrafficEncryption=True, + EnableManagedSpotTraining=False, + Environment={"AWS_DEFAULT_REGION": "us-west-2", "REMOTE_FUNCTION_SECRET_KEY": HMAC_KEY}, + ) + + +@patch( + "sagemaker.remote_function.runtime_environment.bootstrap_runtime_environment.num_cpus", + return_value=4, +) +@patch( + "sagemaker.remote_function.runtime_environment.bootstrap_runtime_environment.num_gpus", + return_value=0, +) +@patch( + "sagemaker.remote_function.runtime_environment.bootstrap_runtime_environment.num_neurons", + return_value=0, +) +@patch( + "sagemaker.modules.train.container_drivers.scripts.environment.safe_serialize", + side_effect=safe_serialize, +) +def test_set_env_single_node_cpu( + mock_safe_serialize, mock_num_cpus, mock_num_gpus, mock_num_neurons +): + with patch.dict(os.environ, {"TRAINING_JOB_NAME": "test-job"}): + set_env( + resource_config=dict( + current_host="algo-1", + hosts=["algo-1"], + current_group_name="homogeneousCluster", + current_instance_type="ml.t3.xlarge", + instance_groups=[ + dict( + instance_group_name="homogeneousCluster", + instance_type="ml.t3.xlarge", + hosts=["algo-1"], + ) + ], + network_interface_name="eth0", + ), + output_file=OUTPUT_FILE, + ) + + mock_num_cpus.assert_called_once() + mock_num_gpus.assert_called_once() + mock_num_neurons.assert_called_once() + + with open(OUTPUT_FILE, "r") as f: + env_file = f.read().strip() + expected_env = _remove_extra_lines(EXPECTED_ENV_SINGLE_NODE_CPU) + env_file = _remove_extra_lines(env_file) + + assert env_file == expected_env + os.remove(OUTPUT_FILE) + assert not os.path.exists(OUTPUT_FILE) + + +@patch( + "sagemaker.remote_function.runtime_environment.bootstrap_runtime_environment.num_cpus", + return_value=48, +) +@patch( + "sagemaker.remote_function.runtime_environment.bootstrap_runtime_environment.num_gpus", + return_value=4, +) +@patch( + "sagemaker.remote_function.runtime_environment.bootstrap_runtime_environment.num_neurons", + return_value=0, +) +@patch( + "sagemaker.modules.train.container_drivers.scripts.environment.safe_serialize", + side_effect=safe_serialize, +) +def test_set_env_single_node_multi_gpu( + mock_safe_serialize, mock_num_cpus, mock_num_gpus, mock_num_neurons +): + with patch.dict(os.environ, {"TRAINING_JOB_NAME": "test-job"}): + set_env( + resource_config=dict( + current_host="algo-1", + hosts=["algo-1"], + current_group_name="homogeneousCluster", + current_instance_type="ml.g5.12xlarge", + instance_groups=[ + dict( + instance_group_name="homogeneousCluster", + instance_type="ml.g5.12xlarge", + hosts=["algo-1"], + ) + ], + network_interface_name="eth0", + ), + 
output_file=OUTPUT_FILE, + ) + + mock_num_cpus.assert_called_once() + mock_num_gpus.assert_called_once() + mock_num_neurons.assert_called_once() + + with open(OUTPUT_FILE, "r") as f: + env_file = f.read().strip() + expected_env = _remove_extra_lines(EXPECTED_ENV_SINGLE_NODE_MULTI_GPUS) + env_file = _remove_extra_lines(env_file) + + assert env_file == expected_env + os.remove(OUTPUT_FILE) + assert not os.path.exists(OUTPUT_FILE) + + +@patch( + "sagemaker.remote_function.runtime_environment.bootstrap_runtime_environment.num_cpus", + return_value=8, +) +@patch( + "sagemaker.remote_function.runtime_environment.bootstrap_runtime_environment.num_gpus", + return_value=1, +) +@patch( + "sagemaker.remote_function.runtime_environment.bootstrap_runtime_environment.num_neurons", + return_value=0, +) +@patch( + "sagemaker.modules.train.container_drivers.scripts.environment.safe_serialize", + side_effect=safe_serialize, +) +def test_set_env_multi_node_multi_gpu( + mock_safe_serialize, mock_num_cpus, mock_num_gpus, mock_num_neurons +): + with patch.dict(os.environ, {"TRAINING_JOB_NAME": "test-job"}): + set_env( + resource_config=dict( + current_host="algo-1", + hosts=["algo-1", "algo-2", "algo-3", "algo-4"], + current_group_name="homogeneousCluster", + current_instance_type="ml.g5.2xlarge", + instance_groups=[ + dict( + instance_group_name="homogeneousCluster", + instance_type="ml.g5.2xlarge", + hosts=["algo-4", "algo-2", "algo-1", "algo-3"], + ) + ], + network_interface_name="eth0", + ), + output_file=OUTPUT_FILE, + ) + + mock_num_cpus.assert_called_once() + mock_num_gpus.assert_called_once() + mock_num_neurons.assert_called_once() + + with open(OUTPUT_FILE, "r") as f: + env_file = f.read().strip() + expected_env = _remove_extra_lines(EXPECTED_ENV_MULTI_NODE_MULTI_GPUS) + env_file = _remove_extra_lines(env_file) + + assert env_file == expected_env + os.remove(OUTPUT_FILE) + assert not os.path.exists(OUTPUT_FILE) + + +def _remove_extra_lines(string): + """Removes extra blank lines from a string.""" + return "\n".join([line for line in string.splitlines() if line.strip()])
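For context, a minimal caller-side sketch of the multi-node torchrun path introduced above. The image URI, role ARN, and bucket below are placeholders rather than values from this diff, and the sketch assumes torch is installed in the training image; leaving `nproc_per_node` unset lets the bootstrap derive it from the instance's GPUs, Neuron cores, or CPUs:

```python
from sagemaker.remote_function import remote

# Placeholder values for illustration only.
@remote(
    image_uri="<training-image-uri>",
    role="<execution-role-arn>",
    s3_root_uri="s3://<bucket>/remote-fn",
    instance_type="ml.g5.2xlarge",
    instance_count=2,     # multi-node is now accepted when use_torchrun=True
    use_torchrun=True,    # entrypoint launches torchrun with the SM_* env vars
    nproc_per_node=None,  # auto-configured: GPUs, then Neuron cores, then CPUs
)
def train():
    # Assumes torch is available in the image; rendezvous settings come from
    # SM_MASTER_ADDR / SM_MASTER_PORT exported by sm_training.env.
    import torch.distributed as dist

    dist.init_process_group(backend="nccl")
    # ... distributed training body ...
    dist.destroy_process_group()


train()  # submits a 2-node SageMaker training job and blocks until it completes
```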