Spaces:
Runtime error
Runtime error
| import json | |
| from dataclasses import dataclass, make_dataclass, asdict, field | |
| from typing import List | |
@dataclass
class Config:
    """Default settings for a run.

    ``load_config`` starts from ``asdict(Config())`` and then overlays values
    from a JSON file and from command-line arguments, so each field below is
    the fallback used when neither of those sources provides a value.
    """

    # The @dataclass decorator is required: without it, asdict(Config()) raises
    # TypeError and field(...) is left on the class as a raw Field object.

    device: str = "cpu"

    # paths
    config: str = "config/default.json"
    loader: str = "loaders/google_sc.py"
    dataset: str = ""
    indices: str = ""
    model_dir: str = "default_model_dir"
    # default_factory gives every instance its own list rather than one shared
    # mutable default.
    validation_datasets: List = field(default_factory=list)

    # training settings/hyperparams
    batch_size: int = 4
    verbose: bool = True

    # pretrained models
    # presumably a Hugging Face model id — TODO confirm against the loader code
    encoder_model_id: str = "distilroberta-base"

    # reward settings
    # reward names as strings; presumably resolved to classes elsewhere — TODO confirm
    rewards: tuple = (
        "FluencyReward",
        "CrossSimilarityReward",
    )
def load_config(args):
    """
    Build a settings object by layering three sources, later ones winning:

    1. the defaults declared on the ``Config`` dataclass,
    2. the JSON file named by ``args.config`` (skipped when falsy or absent),
    3. every attribute of ``args`` (e.g. the result of argparse's parse_args()).

    Overlapping keys are overwritten in that order.

    Example usage:
        (...)
        args = load_config(parser.parse_args())
        args.batch_size

    Args:
        args: namespace-like object whose ``vars()`` are merged last.

    Returns:
        An instance of a freshly generated dataclass named ``Config`` with one
        field per merged key.

    Raises:
        OSError / json.JSONDecodeError: if the config file cannot be read or
            parsed.
        TypeError: if a merged key is not a valid Python identifier (raised by
            ``make_dataclass``).
    """
    config = asdict(Config())

    config_path = getattr(args, "config", None)
    if config_path:
        with open(config_path, encoding="utf-8") as f:
            config.update(json.load(f))

    config.update(vars(args))

    # make_dataclass takes (name, type) pairs; annotate each field with the
    # type of its value rather than with the value itself.
    field_specs = [(name, type(value)) for name, value in config.items()]
    Config_ = make_dataclass("Config", fields=field_specs)
    return Config_(**config)