Spaces:
Runtime error
Runtime error
| # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | |
| # SPDX-License-Identifier: LicenseRef-NvidiaProprietary | |
| # | |
| # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual | |
| # property and proprietary rights in and to this material, related | |
| # documentation and any modifications thereto. Any use, reproduction, | |
| # disclosure or distribution of this material and related documentation | |
| # without an express license agreement from NVIDIA CORPORATION or | |
| # its affiliates is strictly prohibited. | |
| import glob | |
| import hashlib | |
| import importlib | |
| import os | |
| import re | |
| import shutil | |
| import uuid | |
| import torch | |
| import torch.utils.cpp_extension | |
| from torch.utils.file_baton import FileBaton | |
#----------------------------------------------------------------------------
# Global options.

# How chatty get_plugin() is while setting up a plugin: 'none', 'brief' or
# 'full'. Only 'full' also enables verbose output from the extension build.
verbosity = 'brief'
| #---------------------------------------------------------------------------- | |
| # Internal helper funcs. | |
| def _find_compiler_bindir(): | |
| patterns = [ | |
| 'C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64', | |
| 'C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64', | |
| 'C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64', | |
| 'C:/Program Files (x86)/Microsoft Visual Studio */vc/bin', | |
| ] | |
| for pattern in patterns: | |
| matches = sorted(glob.glob(pattern)) | |
| if len(matches): | |
| return matches[-1] | |
| return None | |
#----------------------------------------------------------------------------

def _get_mangled_gpu_name():
    """Return the current CUDA device name made safe for a directory name.

    The name is lowercased and every character outside [a-z0-9_-] is replaced
    by '-'. get_plugin() appends the result to the cached build directory name.
    """
    name = torch.cuda.get_device_name().lower()
    # One pass over the string instead of a regex match per character.
    return re.sub(r'[^a-z0-9_-]', '-', name)
#----------------------------------------------------------------------------
# Main entry point for compiling and loading C++/CUDA plugins.

# Modules already set up by get_plugin() in this process, keyed by module_name.
_cached_plugins = {}


def _compute_source_digest(paths):
    """Return an md5 hex digest of the files in `paths`, in the given order.

    Each file contributes its base name and byte length as well as its
    contents. Hashing contents alone yields the same digest after a file is
    renamed (or bytes move from one file to the next). A rebuild would then
    reuse a cached build directory that lacks the renamed file.
    """
    hash_md5 = hashlib.md5()
    for path in paths:
        with open(path, 'rb') as f:
            data = f.read()
        hash_md5.update(f'{os.path.basename(path)}\0{len(data)}\0'.encode('utf-8'))
        hash_md5.update(data)
    return hash_md5.hexdigest()


def _populate_build_dir(build_top_dir, cached_build_dir, paths):
    """Copy the files in `paths` into `cached_build_dir` unless it already exists.

    Files are staged in a uniquely named directory under `build_top_dir` and
    then moved into place with os.replace() (atomic on POSIX). This prevents a
    concurrent process from picking up a half-copied directory. The staging
    directory is removed if copying fails or if another process created
    `cached_build_dir` first.
    """
    if os.path.isdir(cached_build_dir):
        return
    tmpdir = os.path.join(build_top_dir, f'srctmp-{uuid.uuid4().hex}')
    os.makedirs(tmpdir)
    try:
        for src in paths:
            shutil.copyfile(src, os.path.join(tmpdir, os.path.basename(src)))
    except BaseException:
        # Don't leave a partial staging directory behind.
        shutil.rmtree(tmpdir, ignore_errors=True)
        raise
    try:
        os.replace(tmpdir, cached_build_dir)
    except OSError:
        # The target already exists (presumably created by another process):
        # discard our copy and use theirs; re-raise if it is genuinely missing.
        shutil.rmtree(tmpdir)
        if not os.path.isdir(cached_build_dir):
            raise


def get_plugin(module_name, sources, headers=None, source_dir=None, **build_kwargs):
    """Compile (if needed) and return a C++/CUDA extension module, cached per process.

    Args:
        module_name: Extension name passed to torch.utils.cpp_extension.load();
            also the key of the in-process cache.
        sources: Iterable of source file paths to compile.
        headers: Optional iterable of header paths. They are not passed as
            sources. When all files share one directory, headers are hashed
            and copied alongside the sources, so editing them changes the
            build directory.
        source_dir: If given, every entry of `sources` and `headers` is joined
            onto this directory.
        **build_kwargs: Forwarded to torch.utils.cpp_extension.load().

    Returns:
        The extension module.

    Raises:
        AssertionError: If the module-level `verbosity` is not 'none',
            'brief' or 'full'.
        RuntimeError: On Windows, when cl.exe is not on PATH and
            _find_compiler_bindir() finds nothing.
        Any exception from building or importing is re-raised, after
        printing 'Failed!' when verbosity is 'brief'.
    """
    assert verbosity in ['none', 'brief', 'full']

    # Already cached? Checked first so the cached path never touches the inputs.
    if module_name in _cached_plugins:
        return _cached_plugins[module_name]

    # Accept any iterable (e.g. tuples), not only lists.
    sources = list(sources)
    headers = [] if headers is None else list(headers)
    if source_dir is not None:
        sources = [os.path.join(source_dir, fname) for fname in sources]
        headers = [os.path.join(source_dir, fname) for fname in headers]

    # Print status.
    if verbosity == 'full':
        print(f'Setting up PyTorch plugin "{module_name}"...')
    elif verbosity == 'brief':
        print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True)
    verbose_build = (verbosity == 'full')

    # Compile and load.
    try:
        # Make sure we can find the necessary compiler binaries.
        if os.name == 'nt' and shutil.which('cl.exe') is None:
            compiler_bindir = _find_compiler_bindir()
            if compiler_bindir is None:
                raise RuntimeError(
                    f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".'
                )
            os.environ['PATH'] += ';' + compiler_bindir

        # Some containers set TORCH_CUDA_ARCH_LIST to a list that can either
        # break the build or unnecessarily restrict what's available to nvcc.
        # Unset it to let nvcc decide based on what's available on the
        # machine.
        os.environ['TORCH_CUDA_ARCH_LIST'] = ''

        # Incremental build md5sum trickery. Copies all the input source files
        # into a cached build directory named after a combined md5 digest of
        # the input files. Copying happens only when the digest changes. This
        # keeps input file timestamps and filenames the same as in previous
        # extension builds, allowing for fast incremental rebuilds.
        #
        # This is done only when all the files reside in a single directory
        # (for simplicity). It used to also require the TORCH_EXTENSIONS_DIR
        # environment variable to be set. It now runs regardless, to work
        # around the *.cu dependency bug in the ninja config.
        all_source_files = sorted(sources + headers)
        all_source_dirs = {os.path.dirname(fname) for fname in all_source_files}
        if len(all_source_dirs) == 1:
            source_digest = _compute_source_digest(all_source_files)
            build_top_dir = torch.utils.cpp_extension._get_build_directory(
                module_name, verbose=verbose_build)  # pylint: disable=protected-access
            cached_build_dir = os.path.join(
                build_top_dir, f'{source_digest}-{_get_mangled_gpu_name()}')
            _populate_build_dir(build_top_dir, cached_build_dir, all_source_files)

            # Compile from the cached copies.
            cached_sources = [
                os.path.join(cached_build_dir, os.path.basename(fname))
                for fname in sources
            ]
            module = torch.utils.cpp_extension.load(name=module_name,
                                                    build_directory=cached_build_dir,
                                                    verbose=verbose_build,
                                                    sources=cached_sources,
                                                    **build_kwargs)
        else:
            module = torch.utils.cpp_extension.load(name=module_name,
                                                    verbose=verbose_build,
                                                    sources=sources,
                                                    **build_kwargs)

        # Load. Use the module object returned by load() rather than relying on
        # the extension loader having registered it in sys.modules. Fall back
        # to a regular import only when load() returned no module (None, or a
        # path for standalone builds).
        if module is None or isinstance(module, str):
            module = importlib.import_module(module_name)
    except BaseException:
        # Report any failure (including KeyboardInterrupt), then propagate it.
        if verbosity == 'brief':
            print('Failed!')
        raise

    # Print status and add to cache dict.
    if verbosity == 'full':
        print(f'Done setting up PyTorch plugin "{module_name}".')
    elif verbosity == 'brief':
        print('Done.')
    _cached_plugins[module_name] = module
    return module
| #---------------------------------------------------------------------------- | |