Spaces:
Runtime error
Runtime error
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import torch | |
class TorchAutocast:
    """TorchAutocast utility class.

    Allows you to enable and disable autocast. This is especially useful
    when dealing with different architectures and clusters with different
    levels of support.

    Args:
        enabled (bool): Whether to enable torch.autocast or not. When False,
            no ``torch.autocast`` object is created and entering/exiting this
            context manager does nothing.
        args: Additional args for torch.autocast.
        kwargs: Additional kwargs for torch.autocast.
    """

    def __init__(self, enabled: bool, *args, **kwargs) -> None:
        # torch.autocast is only constructed when enabled; otherwise
        # self.autocast is None and both context-manager methods return early.
        self.autocast = torch.autocast(*args, **kwargs) if enabled else None

    def __enter__(self) -> None:
        """Enter the wrapped ``torch.autocast`` context, if enabled.

        Raises:
            RuntimeError: if entering ``torch.autocast`` fails. The new error
                reports the requested dtype/device and keeps the original
                exception as its ``__cause__``.
        """
        if self.autocast is None:
            return
        try:
            self.autocast.__enter__()
        except RuntimeError as err:
            device = self.autocast.device
            dtype = self.autocast.fast_dtype
            # Chain the original error so its message/traceback is not lost.
            raise RuntimeError(
                f"There was an error autocasting with dtype={dtype} device={device}\n"
                "If you are on the FAIR Cluster, you might need to use autocast_dtype=float16"
            ) from err

    def __exit__(self, *args, **kwargs):
        """Exit the wrapped ``torch.autocast`` context, if enabled.

        Returns:
            The value returned by ``torch.autocast.__exit__``, forwarded so the
            wrapper is transparent to the context-manager protocol (a falsy
            value means any in-flight exception is not suppressed), or None
            when autocast is disabled.
        """
        if self.autocast is None:
            return None
        return self.autocast.__exit__(*args, **kwargs)