# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
A collection of utilities for ensuring that training can always occur. Heavily influenced by the
[toma](https://github.com/BlackHC/toma) library.
"""

import functools
import gc
import inspect

import torch

from .imports import is_mlu_available, is_mps_available, is_npu_available, is_xpu_available


def clear_device_cache(garbage_collection=False):
    """
    Clears the device cache by calling `torch.{backend}.empty_cache`. Can also run `gc.collect()`, but do note that
    this is a *considerable* slowdown and should be used sparingly.
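
    Example:

    A minimal sketch of clearing the cache after dropping a large tensor. It assumes a CUDA device and that this
    helper is importable from `accelerate.utils` like `release_memory` below:

        ```python
        >>> import torch
        >>> from accelerate.utils import clear_device_cache

        >>> tensor = torch.ones(1000, 1000, device="cuda")
        >>> del tensor  # drop the last reference so the allocation can actually be freed
        >>> clear_device_cache(garbage_collection=True)
        ```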
| """ | |
| if garbage_collection: | |
| gc.collect() | |
| if is_xpu_available(): | |
| torch.xpu.empty_cache() | |
| elif is_mlu_available(): | |
| torch.mlu.empty_cache() | |
| elif is_npu_available(): | |
| torch.npu.empty_cache() | |
| elif is_mps_available(min_version="2.0"): | |
| torch.mps.empty_cache() | |
| else: | |
| torch.cuda.empty_cache() | |


def release_memory(*objects):
    """
    Releases memory from `objects` by setting them to `None` and then calling `gc.collect()` and emptying the device
    cache. Returned objects should be reassigned to the same variables.

    Args:
        objects (`Iterable`):
            An iterable of objects

    Returns:
        A list of `None` objects to replace `objects`

    Example:

        ```python
        >>> import torch
        >>> from accelerate.utils import release_memory

        >>> a = torch.ones(1000, 1000).cuda()
        >>> b = torch.ones(1000, 1000).cuda()
        >>> a, b = release_memory(a, b)
        ```
    """
    if not isinstance(objects, list):
        objects = list(objects)
    for i in range(len(objects)):
        objects[i] = None
    clear_device_cache(garbage_collection=True)
    return objects


def should_reduce_batch_size(exception: Exception) -> bool:
    """
    Checks if `exception` relates to CUDA out-of-memory, CUDNN not supported, or CPU out-of-memory.

    Args:
        exception (`Exception`):
            An exception
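
    Example:

    A minimal sketch of classifying an exception; the error strings are illustrative stand-ins for what PyTorch
    raises on an actual out-of-memory failure:

        ```python
        >>> from accelerate.utils.memory import should_reduce_batch_size

        >>> should_reduce_batch_size(RuntimeError("CUDA out of memory. Tried to allocate 20.00 GiB"))
        True
        >>> should_reduce_batch_size(ValueError("unrelated failure"))
        False
        ```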
| """ | |
| _statements = [ | |
| "CUDA out of memory.", # CUDA OOM | |
| "cuDNN error: CUDNN_STATUS_NOT_SUPPORTED.", # CUDNN SNAFU | |
| "DefaultCPUAllocator: can't allocate memory", # CPU OOM | |
| ] | |
| if isinstance(exception, RuntimeError) and len(exception.args) == 1: | |
| return any(err in exception.args[0] for err in _statements) | |
| return False | |


def find_executable_batch_size(function: callable = None, starting_batch_size: int = 128):
    """
    A basic decorator that will try to execute `function`. If it fails from exceptions related to out-of-memory or
    CUDNN, the batch size is cut in half and passed to `function`.

    `function` must take in a `batch_size` parameter as its first argument.

    Args:
        function (`callable`, *optional*):
            A function to wrap
        starting_batch_size (`int`, *optional*):
            The batch size to try and fit into memory

    Example:

        ```python
        >>> from accelerate.utils import find_executable_batch_size


        >>> @find_executable_batch_size(starting_batch_size=128)
        ... def train(batch_size, model, optimizer):
        ...     ...


        >>> train(model, optimizer)
        ```
    """
    if function is None:
        return functools.partial(find_executable_batch_size, starting_batch_size=starting_batch_size)

    batch_size = starting_batch_size

    def decorator(*args, **kwargs):
        nonlocal batch_size
        clear_device_cache(garbage_collection=True)
        params = list(inspect.signature(function).parameters.keys())
        # Guard against user error: the decorator injects `batch_size` itself, so the caller must not pass it
        if len(params) < (len(args) + 1):
            arg_str = ", ".join([f"{arg}={value}" for arg, value in zip(params[1:], args[1:])])
            raise TypeError(
                f"Batch size was passed into `{function.__name__}` as the first argument when called. "
                f"Remove this as the decorator already does so: `{function.__name__}({arg_str})`"
            )
        # Retry with a halved batch size until the call succeeds or the batch size reaches zero
        while True:
            if batch_size == 0:
                raise RuntimeError("No executable batch size found, reached zero.")
            try:
                return function(batch_size, *args, **kwargs)
            except Exception as e:
                if should_reduce_batch_size(e):
                    clear_device_cache(garbage_collection=True)
                    batch_size //= 2
                else:
                    raise

    return decorator