Spaces:
Running
on
Zero
Running
on
Zero
| """ | |
| Useful class for Experiment tracking, and ensuring code is | |
| saved alongside files. | |
| """ # fmt: skip | |
| import datetime | |
| import os | |
| import shlex | |
| import shutil | |
| import subprocess | |
| import typing | |
| from pathlib import Path | |
| import randomname | |
class Experiment:
    """Context manager that runs code inside a per-experiment folder.

    Constructing an ``Experiment`` creates ``exp_directory / exp_name``
    (including missing parents) and records the list of files tracked by
    git at ``HEAD`` together with the absolute path of the directory the
    process was in at construction time. Entering the context changes the
    working directory to the experiment folder; leaving it restores the
    directory that was current when the context was entered.

    Parameters
    ----------
    exp_directory : str
        Folder where all experiments are saved, by default "runs/".
    exp_name : str, optional
        Name of the experiment. When omitted, a name is produced by
        :meth:`generate_exp_name` (current date plus a random name).

    Raises
    ------
    subprocess.CalledProcessError
        If the ``git ls-tree`` command exits with a non-zero status, e.g.
        when the current directory is not inside a git repository. The
        experiment folder has already been created at that point.
    """

    def __init__(
        self,
        exp_directory: str = "runs/",
        exp_name: typing.Optional[str] = None,
    ):
        if exp_name is None:
            exp_name = self.generate_exp_name()

        exp_dir = Path(exp_directory) / exp_name
        exp_dir.mkdir(parents=True, exist_ok=True)

        self.exp_dir = exp_dir
        self.exp_name = exp_name
        # Paths are reported relative to the repository root because of
        # ``--full-tree``; one entry per line of git's output.
        self.git_tracked_files = (
            subprocess.check_output(
                shlex.split("git ls-tree --full-tree --name-only -r HEAD")
            )
            .decode("utf-8")
            .splitlines()
        )
        # Remembered as an absolute path so that ``snapshot`` can still find
        # the source files after ``__enter__`` has changed directory.
        self.parent_directory = Path(".").absolute()

    def __enter__(self) -> "Experiment":
        """Switch the working directory to the experiment folder.

        Returns
        -------
        Experiment
            This instance, so it can be bound with ``with ... as exp``.
        """
        self.prev_dir = os.getcwd()
        os.chdir(self.exp_dir)
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        """Restore the working directory saved by ``__enter__``.

        Returns ``None``, so an exception raised inside the ``with`` block
        is not suppressed.
        """
        os.chdir(self.prev_dir)

    # Declared static: the function takes no arguments, and ``__init__``
    # invokes it through ``self``; without the decorator that call would
    # pass the instance as an unexpected positional argument (TypeError).
    @staticmethod
    def generate_exp_name() -> str:
        """Generate an experiment name from the date and a random name.

        Returns
        -------
        str
            ``"<yymmdd>-<name>"``, where ``<name>`` is whatever
            ``randomname.get_name()`` returns.
        """
        date = datetime.datetime.now().strftime("%y%m%d")
        return f"{date}-{randomname.get_name()}"

    def snapshot(
        self, filter_fn: typing.Callable[[str], bool] = lambda f: True
    ) -> None:
        """Copy the git-tracked files into the current working directory.

        For every path recorded in ``git_tracked_files`` that ``filter_fn``
        accepts, the file is read from ``parent_directory`` and written to
        the same relative path under the current working directory,
        creating intermediate folders as needed. Destinations are relative,
        so call this inside the ``with`` block for the copies to land in
        the experiment folder. The current on-disk contents are copied; no
        separate diff file is written.

        Parameters
        ----------
        filter_fn : typing.Callable, optional
            Called with each tracked path (a ``str``); the file is copied
            only when the result is truthy. By default accepts all files.
        """
        for tracked in self.git_tracked_files:
            if not filter_fn(tracked):
                continue
            Path(tracked).parent.mkdir(parents=True, exist_ok=True)
            shutil.copyfile(self.parent_directory / tracked, tracked)