SageMaker @Remote function: Added multi-node functionality (#4984)
* implemented multi-node distribution with @Remote function

* completed unit tests

* added distributed training with CPU and torchrun

* backwards compatibility for nproc_per_node

* fixed permissions for non-root users, integration tests

* fixed docstyle

* refactor nproc_per_node for backwards compatibility

* refactor nproc_per_node for backwards compatibility

* pylint fix, newlines

* added unit tests for bootstrap_environment remote
brunopistone authored Jan 16, 2025
1 parent a58654e commit ae3cc1c
Showing 6 changed files with 908 additions and 91 deletions.
30 changes: 20 additions & 10 deletions src/sagemaker/remote_function/client.py
@@ -91,7 +91,7 @@ def remote(
use_spot_instances=False,
max_wait_time_in_seconds=None,
use_torchrun=False,
nproc_per_node=1,
nproc_per_node: Optional[int] = None,
):
"""Decorator for running the annotated function as a SageMaker training job.
@@ -284,8 +284,9 @@ def remote(
use_torchrun (bool): Specifies whether to use torchrun for distributed training.
Defaults to ``False``.
nproc_per_node (int): Specifies the number of processes per node for distributed training.
Defaults to ``1``.
nproc_per_node (Optional[int]): Specifies the number of processes per node for
distributed training. Defaults to ``None``.
This is automatically configured based on the instance type.
"""

def _remote(func):
@@ -325,9 +326,13 @@ def _remote(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):

if instance_count > 1 and not spark_config:
if instance_count > 1 and not (
(spark_config is not None and not use_torchrun)
or (spark_config is None and use_torchrun)
):
raise ValueError(
"Remote function do not support training on multi instances. "
"Remote function do not support training on multi instances "
+ "without spark_config or use_torchrun. "
+ "Please provide instance_count = 1"
)
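To illustrate the behavior this check now permits, here is a minimal usage sketch of the decorator with the new multi-node torchrun path. The instance type, the training function body, and the commented-out call are placeholders, not part of this change:

from sagemaker.remote_function import remote

# Hedged sketch: with use_torchrun=True, instance_count > 1 no longer raises,
# and torchrun fans the function out across nodes.
@remote(
    instance_type="ml.g5.12xlarge",
    instance_count=2,
    use_torchrun=True,
    # nproc_per_node left as None so it is derived from the instance type
)
def train():
    import torch.distributed as dist

    dist.init_process_group(backend="nccl")
    return dist.get_world_size()

# world_size = train()  # submits a two-node training job and blocks until it completes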

@@ -532,7 +537,7 @@ def __init__(
use_spot_instances=False,
max_wait_time_in_seconds=None,
use_torchrun=False,
nproc_per_node=1,
nproc_per_node: Optional[int] = None,
):
"""Constructor for RemoteExecutor
@@ -725,17 +730,22 @@ def __init__(
use_torchrun (bool): Specifies whether to use torchrun for distributed training.
Defaults to ``False``.
nproc_per_node (int): Specifies the number of processes per node.
Defaults to ``1``.
nproc_per_node (Optional[int]): Specifies the number of processes per node for
distributed training. Defaults to ``None``.
This is automatically configured based on the instance type.
"""
self.max_parallel_jobs = max_parallel_jobs

if self.max_parallel_jobs <= 0:
raise ValueError("max_parallel_jobs must be greater than 0.")

if instance_count > 1 and not spark_config:
if instance_count > 1 and not (
(spark_config is not None and not use_torchrun)
or (spark_config is None and use_torchrun)
):
raise ValueError(
"Remote function do not support training on multi instances. "
"Remote function do not support training on multi instances "
+ "without spark_config or use_torchrun. "
+ "Please provide instance_count = 1"
)
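The same validation applies to RemoteExecutor. A minimal sketch of the equivalent executor-based usage follows; `train_fn` stands in for a plain (undecorated) training function and all arguments are placeholders:

from sagemaker.remote_function import RemoteExecutor

# Hedged sketch mirroring the decorator example above.
with RemoteExecutor(
    instance_type="ml.g5.12xlarge",
    instance_count=2,
    use_torchrun=True,
    max_parallel_jobs=1,
) as executor:
    future = executor.submit(train_fn)
    # result = future.result()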

6 changes: 0 additions & 6 deletions src/sagemaker/remote_function/core/stored_function.py
@@ -55,8 +55,6 @@ def __init__(
hmac_key: str,
s3_kms_key: str = None,
context: Context = Context(),
use_torchrun: bool = False,
nproc_per_node: int = 1,
):
"""Construct a StoredFunction object.
@@ -67,16 +65,12 @@
s3_kms_key: KMS key used to encrypt artifacts uploaded to S3.
hmac_key: Key used to encrypt serialized and deserialized function and arguments.
context: Build or run context of a pipeline step.
use_torchrun: Whether to use torchrun for distributed training.
nproc_per_node: Number of processes per node for distributed training.
"""
self.sagemaker_session = sagemaker_session
self.s3_base_uri = s3_base_uri
self.s3_kms_key = s3_kms_key
self.hmac_key = hmac_key
self.context = context
self.use_torchrun = use_torchrun
self.nproc_per_node = nproc_per_node

self.func_upload_path = s3_path_join(
s3_base_uri, context.step_name, context.func_step_s3_dir
75 changes: 64 additions & 11 deletions src/sagemaker/remote_function/job.py
@@ -130,9 +130,12 @@
export PIP_CACHE_DIR=${{PERSISTENT_CACHE_DIR}}/sm_remotefunction_user_dependencies_cache/pip
printf "INFO: PIP_CACHE_DIR is set to '$PIP_CACHE_DIR'\\n"
printf "INFO: /opt/ml/input/config/resourceconfig.json:\\n"
cat /opt/ml/input/config/resourceconfig.json
printf "INFO: Bootstraping runtime environment.\\n"
python /opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{BOOTSTRAP_SCRIPT_NAME} "$@"
source /opt/ml/input/sm_training.env
if [ -d {JOB_REMOTE_FUNCTION_WORKSPACE} ]
then
@@ -155,9 +158,11 @@
fi
printf "INFO: Invoking remote function inside conda environment: $conda_env.\\n"
printf "INFO: $conda_exe run -n $conda_env python -m sagemaker.remote_function.invoke_function \\n"
$conda_exe run -n $conda_env python -m sagemaker.remote_function.invoke_function "$@"
else
printf "INFO: No conda env provided. Invoking remote function\\n"
printf "INFO: python -m sagemaker.remote_function.invoke_function \\n"
python -m sagemaker.remote_function.invoke_function "$@"
fi
"""
@@ -175,9 +180,12 @@
export PIP_CACHE_DIR=${{PERSISTENT_CACHE_DIR}}/sm_remotefunction_user_dependencies_cache/pip
printf "INFO: PIP_CACHE_DIR is set to '$PIP_CACHE_DIR'\\n"
printf "INFO: /opt/ml/input/config/resourceconfig.json:\\n"
cat /opt/ml/input/config/resourceconfig.json
printf "INFO: Bootstraping runtime environment.\\n"
python /opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{BOOTSTRAP_SCRIPT_NAME} "$@"
source /opt/ml/input/sm_training.env
if [ -d {JOB_REMOTE_FUNCTION_WORKSPACE} ]
then
@@ -200,11 +208,18 @@
fi
printf "INFO: Invoking remote function with torchrun inside conda environment: $conda_env.\\n"
$conda_exe run -n $conda_env torchrun --nproc_per_node $NPROC_PER_NODE \
printf "INFO: $conda_exe run -n $conda_env torchrun --nnodes $SM_HOST_COUNT --nproc_per_node $SM_NPROC_PER_NODE \
--master_addr $SM_MASTER_ADDR --master_port $SM_MASTER_PORT --node_rank $SM_CURRENT_HOST_RANK \
-m sagemaker.remote_function.invoke_function \\n"
$conda_exe run -n $conda_env torchrun --nnodes $SM_HOST_COUNT --nproc_per_node $SM_NPROC_PER_NODE \
--master_addr $SM_MASTER_ADDR --master_port $SM_MASTER_PORT --node_rank $SM_CURRENT_HOST_RANK \
-m sagemaker.remote_function.invoke_function "$@"
else
printf "INFO: No conda env provided. Invoking remote function with torchrun\\n"
torchrun --nproc_per_node $NPROC_PER_NODE -m sagemaker.remote_function.invoke_function "$@"
printf "INFO: torchrun --nnodes $SM_HOST_COUNT --nproc_per_node $SM_NPROC_PER_NODE --master_addr $SM_MASTER_ADDR \
--master_port $SM_MASTER_PORT --node_rank $SM_CURRENT_HOST_RANK -m sagemaker.remote_function.invoke_function \\n"
torchrun --nnodes $SM_HOST_COUNT --nproc_per_node $SM_NPROC_PER_NODE --master_addr $SM_MASTER_ADDR \
--master_port $SM_MASTER_PORT --node_rank $SM_CURRENT_HOST_RANK -m sagemaker.remote_function.invoke_function "$@"
fi
"""

@@ -262,8 +277,8 @@ def __init__(
spark_config: SparkConfig = None,
use_spot_instances=False,
max_wait_time_in_seconds=None,
use_torchrun=False,
nproc_per_node=1,
use_torchrun: bool = False,
nproc_per_node: Optional[int] = None,
):
"""Initialize a _JobSettings instance which configures the remote job.
@@ -445,6 +460,13 @@ def __init__(
max_wait_time_in_seconds (int): Timeout in seconds waiting for spot training job.
After this amount of time Amazon SageMaker will stop waiting for managed spot
training job to complete. Defaults to ``None``.
use_torchrun (bool): Specifies whether to use torchrun for distributed training.
Defaults to ``False``.
nproc_per_node (Optional[int]): Specifies the number of processes per node for
distributed training. Defaults to ``None``.
This is automatically configured based on the instance type.
"""
self.sagemaker_session = sagemaker_session or Session()
self.environment_variables = resolve_value_from_config(
@@ -732,6 +754,7 @@ def start(job_settings: _JobSettings, func, func_args, func_kwargs, run_info=Non
)

logger.info("Creating job: %s", job_name)

job_settings.sagemaker_session.sagemaker_client.create_training_job(**training_job_request)

return _Job(
@@ -776,8 +799,6 @@ def compile(
s3_base_uri=s3_base_uri,
hmac_key=hmac_key,
s3_kms_key=job_settings.s3_kms_key,
use_torchrun=job_settings.use_torchrun,
nproc_per_node=job_settings.nproc_per_node,
)
stored_function.save(func, *func_args, **func_kwargs)
else:
@@ -790,8 +811,6 @@
step_name=step_compilation_context.step_name,
func_step_s3_dir=step_compilation_context.pipeline_build_time,
),
use_torchrun=job_settings.use_torchrun,
nproc_per_node=job_settings.nproc_per_node,
)

stored_function.save_pipeline_step_function(serialized_data)
@@ -931,6 +950,7 @@ def compile(
request_dict["Environment"].update({"REMOTE_FUNCTION_SECRET_KEY": hmac_key})

extended_request = _extend_spark_config_to_request(request_dict, job_settings, s3_base_uri)
extended_request = _extend_torchrun_to_request(extended_request, job_settings)

return extended_request

@@ -1011,7 +1031,7 @@ def _prepare_and_upload_runtime_scripts(
s3_kms_key: str,
sagemaker_session: Session,
use_torchrun: bool = False,
nproc_per_node: int = 1,
nproc_per_node: Optional[int] = None,
):
"""Copy runtime scripts to a folder and upload to S3.
@@ -1030,7 +1050,7 @@
use_torchrun (bool): Whether to use torchrun or not.
nproc_per_node (int): Number of processes per node.
nproc_per_node (Optional[int]): Number of processes per node.
"""

from sagemaker.workflow.utilities import load_step_compilation_context
@@ -1054,7 +1074,11 @@

if use_torchrun:
entry_point_script = ENTRYPOINT_TORCHRUN_SCRIPT
entry_point_script = entry_point_script.replace("$NPROC_PER_NODE", str(nproc_per_node))

if nproc_per_node is not None and nproc_per_node > 0:
entry_point_script = entry_point_script.replace(
"$SM_NPROC_PER_NODE", str(nproc_per_node)
)

with open(entrypoint_script_path, "w", newline="\n") as file:
file.writelines(entry_point_script)
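In other words, an explicitly provided nproc_per_node overrides the $SM_NPROC_PER_NODE placeholder in the template; otherwise the value sourced from sm_training.env is used at run time. A small, abbreviated illustration of the substitution (the template string is a placeholder, not the full script):

# Abbreviated illustration of the substitution above.
template = "torchrun --nnodes $SM_HOST_COUNT --nproc_per_node $SM_NPROC_PER_NODE ..."

nproc_per_node = 4  # e.g. the user passed nproc_per_node=4
if nproc_per_node is not None and nproc_per_node > 0:
    template = template.replace("$SM_NPROC_PER_NODE", str(nproc_per_node))

print(template)  # torchrun --nnodes $SM_HOST_COUNT --nproc_per_node 4 ...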
@@ -1435,6 +1459,35 @@ def _upload_serialized_spark_configuration(
return config_file_s3_uri


def _extend_torchrun_to_request(
request_dict: Dict,
job_settings: _JobSettings,
) -> Dict:
"""Extend the create training job request with torchrun configuration.
Args:
request_dict (Dict): create training job request dict.
job_settings (_JobSettings): the job settings.
"""
use_torchrun = job_settings.use_torchrun
instance_count = job_settings.instance_count

if not use_torchrun:
return request_dict

if instance_count == 1:
return request_dict

extended_request = request_dict.copy()

for input_channel in extended_request["InputDataConfig"]:
s3_data_source = input_channel["DataSource"].get("S3DataSource", None)
if s3_data_source:
s3_data_source["S3DataDistributionType"] = "FullyReplicated"

return extended_request
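A minimal sketch of the effect of this function: when torchrun is used with more than one instance, every S3 input channel is marked FullyReplicated so each node receives the full workspace and dependency data. The request dict below is a stripped-down placeholder, not a real create_training_job payload:

# Stripped-down illustration of _extend_torchrun_to_request.
request = {
    "InputDataConfig": [
        {"DataSource": {"S3DataSource": {"S3Uri": "s3://bucket/workspace", "S3DataType": "S3Prefix"}}},
    ]
}

for channel in request["InputDataConfig"]:
    s3_source = channel["DataSource"].get("S3DataSource")
    if s3_source:
        s3_source["S3DataDistributionType"] = "FullyReplicated"

# request["InputDataConfig"][0]["DataSource"]["S3DataSource"]["S3DataDistributionType"]
# == "FullyReplicated"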


def _extend_spark_config_to_request(
request_dict: Dict,
job_settings: _JobSettings,
