Spaces:
Paused
Paused
| #!/usr/bin/env python | |
| # Copyright 2021 The HuggingFace Team. All rights reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import json | |
| import os | |
| from dataclasses import dataclass | |
| from enum import Enum | |
| from typing import List, Optional, Union | |
| import yaml | |
| from ...utils import ComputeEnvironment, DistributedType, SageMakerDistributedType | |
| from ...utils.constants import SAGEMAKER_PYTHON_VERSION, SAGEMAKER_PYTORCH_VERSION, SAGEMAKER_TRANSFORMERS_VERSION | |
# Root of the Hugging Face cache: `$HF_HOME` when set, otherwise `<XDG_CACHE_HOME>/huggingface`
# with `~/.cache` standing in for an unset `XDG_CACHE_HOME`. `~` is expanded afterwards.
hf_cache_home = os.path.expanduser(
    os.environ.get("HF_HOME", os.path.join(os.environ.get("XDG_CACHE_HOME", "~/.cache"), "huggingface"))
)
cache_dir = os.path.join(hf_cache_home, "accelerate")
# The legacy JSON default and the YAML default must be two distinct paths, otherwise the
# backward-compatibility check below can never select the JSON file.
default_json_config_file = os.path.join(cache_dir, "default_config.json")
default_yaml_config_file = os.path.join(cache_dir, "default_config.yaml")

# For backward compatibility: the default config is the json one if it's the only existing file.
if os.path.isfile(default_yaml_config_file) or not os.path.isfile(default_json_config_file):
    default_config_file = default_yaml_config_file
else:
    default_config_file = default_json_config_file
def load_config_from_file(config_file):
    """
    Load a configuration object from `config_file`, or from the default config file.

    The file is parsed as JSON when its name ends with `.json` and as YAML otherwise. Its
    `compute_environment` entry (defaulting to `ComputeEnvironment.LOCAL_MACHINE` when absent)
    selects the class used to build the result: `ClusterConfig` for the local machine,
    `SageMakerConfig` for anything else.

    Args:
        config_file: Path to the configuration file, or `None` to use `default_config_file`.

    Returns:
        The `ClusterConfig` or `SageMakerConfig` built by the matching `from_json_file` /
        `from_yaml_file` constructor.

    Raises:
        FileNotFoundError: If `config_file` is given but is not an existing file.
    """
    if config_file is None:
        config_file = default_config_file
    elif not os.path.isfile(config_file):
        raise FileNotFoundError(
            f"The passed configuration file `{config_file}` does not exist. "
            "Please pass an existing file to `accelerate launch`, or use the default one "
            "created through `accelerate config` and run `accelerate launch` "
            "without the `--config_file` argument."
        )

    with open(config_file, encoding="utf-8") as handle:
        is_json = config_file.endswith(".json")
        # This first parse is only used to peek at the compute environment; the chosen class
        # re-reads the file itself in its own constructor.
        parsed = json.load(handle) if is_json else yaml.safe_load(handle)
        environment = parsed.get("compute_environment", ComputeEnvironment.LOCAL_MACHINE)
        config_class = ClusterConfig if environment == ComputeEnvironment.LOCAL_MACHINE else SageMakerConfig
        if is_json:
            return config_class.from_json_file(json_file=config_file)
        return config_class.from_yaml_file(yaml_file=config_file)
@dataclass
class BaseConfig:
    """
    Fields and (de)serialization helpers shared by every configuration class.

    Subclasses are expected to be dataclasses as well: the loaders validate the keys found in a
    file against `cls.__dataclass_fields__` and build the instance with `cls(**config_dict)`.
    """

    compute_environment: ComputeEnvironment
    distributed_type: Union[DistributedType, SageMakerDistributedType]
    mixed_precision: str
    use_cpu: bool
    debug: bool

    def to_dict(self):
        """
        Return a serializable copy of this config's attributes.

        Enum values are replaced by their underlying `.value`; attributes that are `None` or an
        empty dict are left out. The instance itself is not modified.
        """
        result = {}
        # Build a fresh dict: writing into `self.__dict__` would rewrite the live attributes.
        for key, value in self.__dict__.items():
            # For serialization, it's best to convert Enums to strings (or their underlying value type).
            if isinstance(value, Enum):
                value = value.value
            elif isinstance(value, dict) and not value:
                value = None
            if value is not None:
                result[key] = value
        return result

    @classmethod
    def _from_config_dict(cls, config_dict, source, false_mixed_precision_means_no=False):
        """
        Normalize a parsed config mapping (in place) and build an instance of `cls` from it.

        Args:
            config_dict: Mapping parsed from the config file.
            source: Path of the file the mapping came from; only used in the error message.
            false_mixed_precision_means_no: When `True`, a boolean `False` value for
                `mixed_precision` is rewritten to the string `"no"`.

        Raises:
            ValueError: If the mapping has keys that are not dataclass fields of `cls`.
        """
        if "compute_environment" not in config_dict:
            config_dict["compute_environment"] = ComputeEnvironment.LOCAL_MACHINE
        if "mixed_precision" not in config_dict:
            config_dict["mixed_precision"] = "fp16" if config_dict.get("fp16") else None
        if false_mixed_precision_means_no:
            if isinstance(config_dict["mixed_precision"], bool) and not config_dict["mixed_precision"]:
                config_dict["mixed_precision"] = "no"
        # Convert the config to the new format: `fp16` is superseded by `mixed_precision`.
        config_dict.pop("fp16", None)
        if "dynamo_backend" in config_dict:  # Convert the config to the new format.
            dynamo_backend = config_dict.pop("dynamo_backend")
            config_dict["dynamo_config"] = {} if dynamo_backend == "NO" else {"dynamo_backend": dynamo_backend}
        config_dict.setdefault("use_cpu", False)
        config_dict.setdefault("debug", False)
        config_dict.setdefault("enable_cpu_affinity", False)

        extra_keys = sorted(set(config_dict) - set(cls.__dataclass_fields__))
        if extra_keys:
            raise ValueError(
                f"The config file at {source} had unknown keys ({extra_keys}), please try upgrading your `accelerate`"
                " version or fix (and potentially remove) these keys from your config file."
            )
        return cls(**config_dict)

    @classmethod
    def from_json_file(cls, json_file=None):
        """
        Build a config from a JSON file.

        Args:
            json_file: Path to read; `default_json_config_file` is used when `None`.

        Raises:
            ValueError: If the file has keys that are not dataclass fields of `cls`.
        """
        json_file = default_json_config_file if json_file is None else json_file
        with open(json_file, encoding="utf-8") as f:
            config_dict = json.load(f)
        return cls._from_config_dict(config_dict, json_file)

    def to_json_file(self, json_file):
        """Write `to_dict()` to `json_file` as indented, key-sorted JSON followed by a newline."""
        with open(json_file, "w", encoding="utf-8") as f:
            content = json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
            f.write(content)

    @classmethod
    def from_yaml_file(cls, yaml_file=None):
        """
        Build a config from a YAML file.

        Identical to `from_json_file` except that a boolean `False` value for `mixed_precision`
        is read as the string `"no"`.

        Args:
            yaml_file: Path to read; `default_yaml_config_file` is used when `None`.

        Raises:
            ValueError: If the file has keys that are not dataclass fields of `cls`.
        """
        yaml_file = default_yaml_config_file if yaml_file is None else yaml_file
        with open(yaml_file, encoding="utf-8") as f:
            config_dict = yaml.safe_load(f)
        return cls._from_config_dict(config_dict, yaml_file, false_mixed_precision_means_no=True)

    def to_yaml_file(self, yaml_file):
        """Write `to_dict()` to `yaml_file` with `yaml.safe_dump`."""
        with open(yaml_file, "w", encoding="utf-8") as f:
            yaml.safe_dump(self.to_dict(), f)

    def __post_init__(self):
        """Coerce string-valued enum fields to their Enum types and default `dynamo_config` to `{}`."""
        if isinstance(self.compute_environment, str):
            self.compute_environment = ComputeEnvironment(self.compute_environment)
        if isinstance(self.distributed_type, str):
            # The enum used for `distributed_type` depends on the compute environment.
            if self.compute_environment == ComputeEnvironment.AMAZON_SAGEMAKER:
                self.distributed_type = SageMakerDistributedType(self.distributed_type)
            else:
                self.distributed_type = DistributedType(self.distributed_type)
        if getattr(self, "dynamo_config", None) is None:
            self.dynamo_config = {}
@dataclass
class ClusterConfig(BaseConfig):
    """
    Configuration selected by `load_config_from_file` when the compute environment is the local machine.

    Every plugin-specific `*_config` dict defaults to `None` and is replaced by an empty dict in
    `__post_init__`, so each instance gets its own dict rather than a shared mutable default.
    """

    num_processes: int
    machine_rank: int = 0
    num_machines: int = 1
    gpu_ids: Optional[str] = None
    main_process_ip: Optional[str] = None
    main_process_port: Optional[int] = None
    rdzv_backend: Optional[str] = "static"
    same_network: Optional[bool] = False
    main_training_function: str = "main"
    enable_cpu_affinity: bool = False

    # args for deepspeed_plugin
    deepspeed_config: Optional[dict] = None
    # args for fsdp
    fsdp_config: Optional[dict] = None
    # args for megatron_lm
    megatron_lm_config: Optional[dict] = None
    # args for ipex
    ipex_config: Optional[dict] = None
    # args for mpirun
    mpirun_config: Optional[dict] = None
    # args for TPU
    downcast_bf16: bool = False

    # args for TPU pods
    tpu_name: Optional[str] = None
    tpu_zone: Optional[str] = None
    tpu_use_cluster: bool = False
    tpu_use_sudo: bool = False
    command_file: Optional[str] = None
    commands: Optional[List[str]] = None
    tpu_vm: Optional[List[str]] = None
    tpu_env: Optional[List[str]] = None

    # args for dynamo
    dynamo_config: Optional[dict] = None

    def __post_init__(self):
        """Replace unset plugin config dicts with `{}`, then run the base-class normalization."""
        for name in ("deepspeed_config", "fsdp_config", "megatron_lm_config", "ipex_config", "mpirun_config"):
            if getattr(self, name) is None:
                setattr(self, name, {})
        return super().__post_init__()
@dataclass
class SageMakerConfig(BaseConfig):
    """
    Configuration selected by `load_config_from_file` when the compute environment is not the local machine.

    The PyTorch, Transformers and Python version defaults come from the `SAGEMAKER_*` constants.
    """

    ec2_instance_type: str
    iam_role_name: str
    image_uri: Optional[str] = None
    profile: Optional[str] = None
    region: str = "us-east-1"
    num_machines: int = 1
    gpu_ids: str = "all"
    # NOTE: this f-string is evaluated once, while the class body runs, using the class-level
    # default `num_machines = 1`. The default is therefore always "accelerate-sagemaker-1",
    # whatever `num_machines` an instance is created with.
    base_job_name: str = f"accelerate-sagemaker-{num_machines}"
    pytorch_version: str = SAGEMAKER_PYTORCH_VERSION
    transformers_version: str = SAGEMAKER_TRANSFORMERS_VERSION
    py_version: str = SAGEMAKER_PYTHON_VERSION
    sagemaker_inputs_file: Optional[str] = None
    sagemaker_metrics_file: Optional[str] = None
    additional_args: Optional[dict] = None
    dynamo_config: Optional[dict] = None
    enable_cpu_affinity: bool = False