Spaces:
Paused
Paused
| #!/usr/bin/env python | |
| # Copyright 2021 The HuggingFace Team. All rights reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import json | |
| import os | |
| from ...utils.constants import SAGEMAKER_PARALLEL_EC2_INSTANCES, TORCH_DYNAMO_MODES | |
| from ...utils.dataclasses import ComputeEnvironment, SageMakerDistributedType | |
| from ...utils.imports import is_boto3_available | |
| from .config_args import SageMakerConfig | |
| from .config_utils import ( | |
| DYNAMO_BACKENDS, | |
| _ask_field, | |
| _ask_options, | |
| _convert_dynamo_backend, | |
| _convert_mixed_precision, | |
| _convert_sagemaker_distributed_mode, | |
| _convert_yes_no_to_bool, | |
| ) | |
| if is_boto3_available(): | |
| import boto3 # noqa: F401 | |
def _create_iam_role_for_sagemaker(role_name):
    """Create an IAM role assumable by SageMaker and attach an inline permission policy.

    Args:
        role_name: Name of the IAM role to create. The inline policy is named
            ``f"{role_name}_policy_permission"``.

    If IAM reports that the entity already exists, a message is printed and the
    function returns without raising; in that case no policy is attached here.
    """
    iam = boto3.client("iam")

    # Trust relationship: allows the SageMaker service to assume the role.
    assume_role_document = {
        "Version": "2012-10-17",
        "Statement": [
            {"Effect": "Allow", "Principal": {"Service": "sagemaker.amazonaws.com"}, "Action": "sts:AssumeRole"}
        ],
    }

    # Actions granted to the role by the inline policy (applied to all resources).
    granted_actions = [
        "sagemaker:*",
        "ecr:GetDownloadUrlForLayer",
        "ecr:BatchGetImage",
        "ecr:BatchCheckLayerAvailability",
        "ecr:GetAuthorizationToken",
        "cloudwatch:PutMetricData",
        "cloudwatch:GetMetricData",
        "cloudwatch:GetMetricStatistics",
        "cloudwatch:ListMetrics",
        "logs:CreateLogGroup",
        "logs:CreateLogStream",
        "logs:DescribeLogStreams",
        "logs:PutLogEvents",
        "logs:GetLogEvents",
        "s3:CreateBucket",
        "s3:ListBucket",
        "s3:GetBucketLocation",
        "s3:GetObject",
        "s3:PutObject",
    ]
    permission_document = {
        "Version": "2012-10-17",
        "Statement": [{"Effect": "Allow", "Action": granted_actions, "Resource": "*"}],
    }

    try:
        # Step 1: create the role with the trust policy above.
        iam.create_role(
            RoleName=role_name,
            AssumeRolePolicyDocument=json.dumps(assume_role_document, indent=2),
        )
        # Step 2: attach the inline permission policy to the freshly created role.
        iam.put_role_policy(
            RoleName=role_name,
            PolicyName=f"{role_name}_policy_permission",
            PolicyDocument=json.dumps(permission_document, indent=2),
        )
    except iam.exceptions.EntityAlreadyExistsException:
        print(f"role {role_name} already exists. Using existing one")
def _get_iam_role_arn(role_name):
    """Return the ARN of the IAM role called ``role_name``.

    Args:
        role_name: Name of an IAM role, passed to the IAM ``get_role`` call.

    Returns:
        The ``["Role"]["Arn"]`` entry of the ``get_role`` response.
    """
    role_description = boto3.client("iam").get_role(RoleName=role_name)
    return role_description["Role"]["Arn"]
def get_sagemaker_input():
    """Interactively collect the settings for an Amazon SageMaker training setup.

    Prompts (through ``_ask_field`` / ``_ask_options``) for AWS authorization,
    region, IAM role, an optional custom Docker image, optional input-channel
    and metrics TSV files, the distributed mode, torch dynamo options, the EC2
    instance type, the number of machines and the mixed-precision mode.

    Side effects:
        Writes the chosen profile or credentials and the region to
        ``os.environ`` (``AWS_PROFILE`` or ``AWS_ACCESS_KEY_ID`` /
        ``AWS_SECRET_ACCESS_KEY``, plus ``AWS_DEFAULT_REGION``). When the user
        chooses to create a role, calls ``_create_iam_role_for_sagemaker`` for
        the role ``accelerate_sagemaker_execution_role``.

    Returns:
        A ``SageMakerConfig`` built from the collected answers.
    """
    # --- AWS authorization -------------------------------------------------
    credentials_configuration = _ask_options(
        "How do you want to authorize?",
        ["AWS Profile", "Credentials (AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY) "],
        int,
    )
    aws_profile = None
    if credentials_configuration == 0:
        aws_profile = _ask_field("Enter your AWS Profile name: [default] ", default="default")
        os.environ["AWS_PROFILE"] = aws_profile
    else:
        # The two literals are concatenated, so the first must end with a space.
        print(
            "Note you will need to provide AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY when you launch your training script with, "
            "`accelerate launch --aws_access_key_id XXX --aws_secret_access_key YYY`"
        )
        aws_access_key_id = _ask_field("AWS Access Key ID: ")
        os.environ["AWS_ACCESS_KEY_ID"] = aws_access_key_id

        aws_secret_access_key = _ask_field("AWS Secret Access Key: ")
        os.environ["AWS_SECRET_ACCESS_KEY"] = aws_secret_access_key

    aws_region = _ask_field("Enter your AWS Region: [us-east-1]", default="us-east-1")
    os.environ["AWS_DEFAULT_REGION"] = aws_region

    # --- IAM role ------------------------------------------------------------
    role_management = _ask_options(
        "Do you already have an IAM Role for executing Amazon SageMaker Training Jobs?",
        ["Provide IAM Role name", "Create new IAM role using credentials"],
        int,
    )
    if role_management == 0:
        iam_role_name = _ask_field("Enter your IAM role name: ")
    else:
        iam_role_name = "accelerate_sagemaker_execution_role"
        print(f'Accelerate will create an iam role "{iam_role_name}" using the provided credentials')
        _create_iam_role_for_sagemaker(iam_role_name)

    # --- Optional custom Docker image ---------------------------------------
    is_custom_docker_image = _ask_field(
        "Do you want to use custom Docker image? [yes/NO]: ",
        _convert_yes_no_to_bool,
        default=False,
        error_message="Please enter yes or no.",
    )
    docker_image = None
    if is_custom_docker_image:
        docker_image = _ask_field("Enter your Docker image: ", lambda x: str(x).lower())

    # --- Optional SageMaker input channels ----------------------------------
    is_sagemaker_inputs_enabled = _ask_field(
        "Do you want to provide SageMaker input channels with data locations? [yes/NO]: ",
        _convert_yes_no_to_bool,
        default=False,
        error_message="Please enter yes or no.",
    )
    sagemaker_inputs_file = None
    if is_sagemaker_inputs_enabled:
        # Keep the path exactly as typed: lower-casing it would point to a
        # different (likely missing) file on case-sensitive filesystems.
        sagemaker_inputs_file = _ask_field(
            "Enter the path to the SageMaker inputs TSV file with columns (channel_name, data_location): ",
            str,
        )

    # --- Optional SageMaker metrics ------------------------------------------
    is_sagemaker_metrics_enabled = _ask_field(
        "Do you want to enable SageMaker metrics? [yes/NO]: ",
        _convert_yes_no_to_bool,
        default=False,
        error_message="Please enter yes or no.",
    )
    sagemaker_metrics_file = None
    if is_sagemaker_metrics_enabled:
        # Same as above: file paths are case-sensitive, so do not normalise case.
        sagemaker_metrics_file = _ask_field(
            "Enter the path to the SageMaker metrics TSV file with columns (metric_name, metric_regex): ",
            str,
        )

    # --- Distributed mode ------------------------------------------------------
    distributed_type = _ask_options(
        "What is the distributed mode?",
        ["No distributed training", "Data parallelism"],
        _convert_sagemaker_distributed_mode,
    )

    # --- Torch dynamo ------------------------------------------------------------
    dynamo_config = {}
    use_dynamo = _ask_field(
        "Do you wish to optimize your script with torch dynamo?[yes/NO]:",
        _convert_yes_no_to_bool,
        default=False,
        error_message="Please enter yes or no.",
    )
    if use_dynamo:
        prefix = "dynamo_"
        dynamo_config[prefix + "backend"] = _ask_options(
            "Which dynamo backend would you like to use?",
            [x.lower() for x in DYNAMO_BACKENDS],
            _convert_dynamo_backend,
            default=2,
        )
        use_custom_options = _ask_field(
            "Do you want to customize the defaults sent to torch.compile? [yes/NO]: ",
            _convert_yes_no_to_bool,
            default=False,
            error_message="Please enter yes or no.",
        )

        if use_custom_options:
            # The converter indexes TORCH_DYNAMO_MODES with int(x) and the backend
            # prompt above passes an integer default, so the default must be an
            # option index, not the mode name.
            # NOTE(review): assumes TORCH_DYNAMO_MODES[0] is "default" - confirm
            # against utils.constants.
            dynamo_config[prefix + "mode"] = _ask_options(
                "Which mode do you want to use?",
                TORCH_DYNAMO_MODES,
                lambda x: TORCH_DYNAMO_MODES[int(x)],
                default=0,
            )
            dynamo_config[prefix + "use_fullgraph"] = _ask_field(
                "Do you want the fullgraph mode or it is ok to break model into several subgraphs? [yes/NO]: ",
                _convert_yes_no_to_bool,
                default=False,
                error_message="Please enter yes or no.",
            )
            dynamo_config[prefix + "use_dynamic"] = _ask_field(
                "Do you want to enable dynamic shape tracing? [yes/NO]: ",
                _convert_yes_no_to_bool,
                default=False,
                error_message="Please enter yes or no.",
            )

    # --- EC2 instance type ---------------------------------------------------------
    ec2_instance_query = "Which EC2 instance type you want to use for your training?"
    if distributed_type != SageMakerDistributedType.NO:
        # Distributed runs pick from the fixed list of supported instances.
        ec2_instance_type = _ask_options(
            ec2_instance_query, SAGEMAKER_PARALLEL_EC2_INSTANCES, lambda x: SAGEMAKER_PARALLEL_EC2_INSTANCES[int(x)]
        )
    else:
        # The base question already ends with "?"; only append the default hint.
        ec2_instance_query += " [ml.p3.2xlarge]:"
        ec2_instance_type = _ask_field(ec2_instance_query, lambda x: str(x).lower(), default="ml.p3.2xlarge")

    # --- Distributed debugging and machine count ---------------------------------
    debug = False
    if distributed_type != SageMakerDistributedType.NO:
        debug = _ask_field(
            "Should distributed operations be checked while running for errors? This can avoid timeout issues but will be slower. [yes/NO]: ",
            _convert_yes_no_to_bool,
            default=False,
            error_message="Please enter yes or no.",
        )

    num_machines = 1
    if distributed_type in (SageMakerDistributedType.DATA_PARALLEL, SageMakerDistributedType.MODEL_PARALLEL):
        num_machines = _ask_field(
            "How many machines do you want use? [1]: ",
            int,
            default=1,
        )

    # --- Mixed precision -------------------------------------------------------------
    mixed_precision = _ask_options(
        "Do you wish to use FP16 or BF16 (mixed precision)?",
        ["no", "fp16", "bf16", "fp8"],
        _convert_mixed_precision,
    )

    if use_dynamo and mixed_precision == "no":
        print(
            "Torch dynamo used without mixed precision requires TF32 to be efficient. Accelerate will enable it by default when launching your scripts."
        )

    return SageMakerConfig(
        image_uri=docker_image,
        compute_environment=ComputeEnvironment.AMAZON_SAGEMAKER,
        distributed_type=distributed_type,
        use_cpu=False,
        dynamo_config=dynamo_config,
        ec2_instance_type=ec2_instance_type,
        profile=aws_profile,
        region=aws_region,
        iam_role_name=iam_role_name,
        mixed_precision=mixed_precision,
        num_machines=num_machines,
        sagemaker_inputs_file=sagemaker_inputs_file,
        sagemaker_metrics_file=sagemaker_metrics_file,
        debug=debug,
    )