Spaces:
Runtime error
Runtime error
| import yaml | |
| import json | |
| import argparse | |
| import logging | |
# Module-level logger; the config loaders in this file report overridden
# settings and command-line overrides through it at WARNING level.
logger = logging.getLogger(__name__)
def load_config_dict_to_opt(opt, config_dict):
    """
    Load the key, value pairs from config_dict to opt, overriding existing values in opt
    if there is any.

    A key containing '.' addresses a nested dict: "a.b.c" sets opt["a"]["b"]["c"],
    creating any missing intermediate dicts along the way.

    Args:
        opt: dict that is updated in place.
        config_dict: mapping of (possibly dotted) keys to new values.

    Raises:
        TypeError: if config_dict is not a dict.
        AssertionError: if a dotted key descends into an existing non-dict value.
    """
    if not isinstance(config_dict, dict):
        raise TypeError("Config must be a Python dictionary")
    for k, v in config_dict.items():
        k_parts = k.split('.')
        pointer = opt
        # Walk (and create, if needed) every intermediate level of the dotted key.
        for k_part in k_parts[:-1]:
            if k_part not in pointer:
                pointer[k_part] = {}
            pointer = pointer[k_part]
            assert isinstance(pointer, dict), "Overriding key needs to be inside a Python dict."
        leaf = k_parts[-1]
        # Test membership rather than truthiness so that overriding an existing
        # falsy value (0, False, "", [], None) is reported too.
        existed = leaf in pointer
        ori_value = pointer.get(leaf)
        pointer[leaf] = v
        if existed:
            logger.warning(f"Overrided {k} from {ori_value} to {pointer[leaf]}")
def load_opt_from_config_files(conf_file):
    """
    Load opt from one or more YAML config files.

    When several files are given they are applied in order, so settings in
    later files override those in earlier ones.

    Args:
        conf_file: path of a YAML config file, or a list/tuple of such paths.

    Returns:
        dict: a dictionary of opt settings

    Raises:
        TypeError: if a file's top level does not parse to a dict (for example
            an empty file, which yaml.safe_load returns as None).
    """
    # Accept a single path (the historical usage) as well as a sequence of paths.
    conf_files = conf_file if isinstance(conf_file, (list, tuple)) else [conf_file]
    opt = {}
    for path in conf_files:
        with open(path, encoding='utf-8') as f:
            config_dict = yaml.safe_load(f)
        load_config_dict_to_opt(opt, config_dict)
    return opt
def _lookup_opt_value(opt, dotted_key):
    """Return the value stored in nested dict *opt* under a dotted key such as 'a.b.c'.

    Raises KeyError if any component of the key is missing.
    """
    node = opt
    for part in dotted_key.split('.'):
        node = node[part]
    return node


def _cast_override_value(raw, target_type):
    """Convert the command-line string *raw* to *target_type*, the type of the value it replaces."""
    if target_type is bool:
        # bool('false') is True, so the text has to be interpreted explicitly.
        return raw.strip().lower() not in ('', 'false', '0', 'no', 'off')
    if target_type in (type(None), list, dict):
        # Calling these constructors on a string either fails (NoneType, dict) or
        # splits it into characters (list); parse it as a YAML literal, e.g. "[1, 2]".
        return yaml.safe_load(raw)
    return target_type(raw)


def load_opt_command(args):
    """
    Parse command-line arguments and build the opt dictionary.

    Settings are applied in this order, later steps overriding earlier ones:
      1. the YAML config given by --conf_files;
      2. --config_overrides: a JSON object whose dotted keys address nested dicts;
      3. --overrides: alternating "key value" tokens. Each value is cast to the
         type of the existing config value it replaces, so the key must already
         exist in the config (KeyError otherwise);
      4. every parsed command-line argument that is not None, copied into opt at top level.

    Args:
        args: list of argument strings; if falsy (None or empty), sys.argv is parsed instead.

    Returns:
        tuple: (opt dict, argparse.Namespace of the parsed command line)
    """
    parser = argparse.ArgumentParser(description='MainzTrain: Pretrain or fine-tune models for NLP tasks.')
    parser.add_argument('command', help='Command: train/evaluate/train-and-evaluate')
    parser.add_argument('--conf_files', required=True, help='Path(s) to the MainzTrain config file(s).')
    parser.add_argument('--config_overrides', nargs='*', help='Override parameters on config with a json style string, e.g. {"<PARAM_NAME_1>": <PARAM_VALUE_1>, "<PARAM_GROUP_2>.<PARAM_SUBGROUP_2>.<PARAM_2>": <PARAM_VALUE_2>}. A key with "." updates the object in the corresponding nested dict. Remember to escape " in command line.')
    parser.add_argument('--overrides', help='arguments that used to overide the config file in cmdline', nargs=argparse.REMAINDER)
    # parse_args(None) falls back to sys.argv[1:], matching the original "if not args" branch.
    cmdline_args = parser.parse_args(args if args else None)
    opt = load_opt_from_config_files(cmdline_args.conf_files)

    if cmdline_args.config_overrides:
        config_overrides_string = ' '.join(cmdline_args.config_overrides)
        logger.warning(f"Command line config overrides: {config_overrides_string}")
        config_dict = json.loads(config_overrides_string)
        load_config_dict_to_opt(opt, config_dict)

    overrides = cmdline_args.overrides
    if overrides:
        # parser.error (unlike assert) still fires under "python -O", where an
        # unpaired trailing key would otherwise be dropped silently by zip().
        if len(overrides) % 2 != 0:
            parser.error("overides arguments is not paired, required: key value")
        keys = overrides[0::2]
        vals = overrides[1::2]
        config_dict = {
            key: _cast_override_value(val, type(_lookup_opt_value(opt, key)))
            for key, val in zip(keys, vals)
        }
        load_config_dict_to_opt(opt, config_dict)

    # combine cmdline_args into opt dictionary
    for key, val in vars(cmdline_args).items():
        if val is not None:
            opt[key] = val
    return opt, cmdline_args
def save_opt_to_json(opt, conf_file):
    """
    Dump the options dictionary to a JSON file.

    Args:
        opt: JSON-serializable dict of settings.
        conf_file: destination path; an existing file is overwritten.
    """
    with open(conf_file, mode='w', encoding='utf-8') as json_out:
        json.dump(opt, json_out, indent=4)
def save_opt_to_yaml(opt, conf_file):
    """
    Dump the options dictionary to a YAML file using yaml.dump's default style.

    Args:
        opt: dict of settings representable by the default YAML dumper.
        conf_file: destination path; an existing file is overwritten.
    """
    with open(conf_file, mode='w', encoding='utf-8') as yaml_out:
        yaml.dump(opt, stream=yaml_out)