Spaces:
Runtime error
Runtime error
| # Copyright 2025 ByteDance and/or its affiliates. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import argparse | |
| import json | |
| import os | |
| import re | |
| import yaml | |
# Process-wide hyper-parameter registry. set_hparams() clears and refills this
# dict in place when it is called with global_hparams=True.
hparams = {}
# True until set_hparams() has printed the resolved hparams once; it is then
# switched to False so later calls do not print them again.
global_print_hparams = True
class Args:
    """Plain attribute bag used in place of an ``argparse`` namespace."""

    def __init__(self, **kwargs):
        # Expose every keyword argument as an instance attribute of the same name.
        for name in kwargs:
            setattr(self, name, kwargs[name])
def override_config(old_config: dict, new_config: dict):
    """Recursively merge ``new_config`` into ``old_config`` in place.

    Args:
        old_config: Dict that receives the merged values; it is mutated.
        new_config: Dict whose entries take precedence. If it contains a
            truthy ``'__replace'`` entry, ``old_config`` is emptied first so
            the result holds only the entries of ``new_config`` (including
            the ``'__replace'`` marker itself).

    Returns:
        None. The merge result lives in ``old_config``.
    """
    if new_config.get('__replace', False):
        old_config.clear()
    for key, value in new_config.items():
        existing = old_config.get(key)
        # Recurse only when BOTH sides are dicts. If the existing value is a
        # scalar or None there is nothing to merge into, so the new dict simply
        # replaces it (recursing would raise TypeError on the non-dict).
        if isinstance(value, dict) and isinstance(existing, dict):
            override_config(existing, value)
        else:
            old_config[key] = value
def traverse_dict(d, func, ctx):
    """Apply ``func(value, ctx)`` to every non-dict value of ``d``, in place.

    Nested dicts are walked recursively; each leaf value is replaced by the
    result of ``func``. Nothing is returned.
    """
    # Iterate over a snapshot of the keys; values are read at visit time.
    for key in tuple(d):
        value = d[key]
        if not isinstance(value, dict):
            d[key] = func(value, ctx)
        else:
            traverse_dict(value, func, ctx)
def parse_config(v, context=None):
    """Resolve the special string forms a config value may take.

    * Non-string values are returned unchanged.
    * ``'^<path>'`` is replaced by the config loaded from ``<path>`` via
      ``load_config``.
    * A string starting with ``${<expr>}`` is replaced by the result of
      evaluating ``<expr>`` with ``context`` as the local namespace (an empty
      dict when ``context`` is None).
    * Any other string is returned unchanged.
    """
    if not isinstance(v, str):
        return v
    if context is None:
        context = {}
    if v.startswith('^'):
        return load_config(v[1:], [], set())
    # The pattern is anchored at the start only, so any text after the last
    # closing brace is not part of the evaluated expression.
    found = re.match(r"\${(.*)}", v)
    if found is None:
        return v
    # NOTE(review): eval() executes arbitrary Python taken from the config
    # text -- only use this with trusted config files.
    return eval(found.group(1), {}, context)
def remove_meta_key(d):
    """Delete meta entries (string keys starting with ``'__'``) from ``d`` in place.

    Nested dicts are cleaned recursively. An entry whose value is itself a
    dict is always kept and descended into, whatever its key is; only
    non-dict values under a ``'__'``-prefixed key are removed.
    """
    # Snapshot the keys because entries are deleted while walking.
    for key in list(d):
        value = d[key]
        if isinstance(value, dict):
            remove_meta_key(value)
        # Keys are not necessarily strings (YAML allows e.g. integer keys);
        # slicing/prefix-testing a non-string key would raise TypeError.
        elif isinstance(key, str) and key.startswith('__'):
            del d[key]
def load_config(config_fn, config_chains, loaded_configs):
    """Load a YAML config file, resolving ``base_config`` inheritance.

    Base configs are loaded depth-first and merged in listing order with
    ``override_config``; the file's own entries are merged last so they win.
    A base path starting with ``'.'`` is resolved relative to the directory
    of ``config_fn``; other paths are used as written.

    Args:
        config_fn: Path of the YAML file to load.
        config_chains: List that receives the path of every file loaded;
            base configs are appended before the file that inherits them.
        loaded_configs: Set of paths already visited; a base config found in
            this set is not loaded again.

    Returns:
        The merged config dict, or ``{}`` (after printing a warning) when
        ``config_fn`` does not exist.
    """
    if not os.path.exists(config_fn):
        print(f"| WARN: {config_fn} not exist.", )
        return {}
    with open(config_fn) as f:
        hparams_ = yaml.safe_load(f)
    if hparams_ is None:
        # safe_load yields None for an empty document; treat it as an empty
        # config instead of failing on the membership test below.
        hparams_ = {}
    loaded_configs.add(config_fn)
    if 'base_config' in hparams_:
        ret_hparams = {}
        # Normalise a single base path to a one-element list (stored back so
        # the merged result keeps the list form).
        if not isinstance(hparams_['base_config'], list):
            hparams_['base_config'] = [hparams_['base_config']]
        for base_fn in hparams_['base_config']:
            if base_fn.startswith('.'):
                base_fn = f'{os.path.dirname(config_fn)}/{base_fn}'
                base_fn = os.path.normpath(base_fn)
            if base_fn not in loaded_configs:
                override_config(ret_hparams, load_config(base_fn, config_chains, loaded_configs))
        override_config(ret_hparams, hparams_)
    else:
        ret_hparams = hparams_
    config_chains.append(config_fn)
    return ret_hparams
def _infer_hparam_value(raw):
    """Guess the type of a command-line override that has no typed value to copy from.

    Precedence: ``'true'``/``'false'`` (any case) -> bool, then int, then
    float; anything else stays the raw string.
    """
    lowered = raw.lower()
    if lowered in ('false', 'true'):
        return lowered == 'true'
    for cast in (int, float):
        try:
            return cast(raw)
        except ValueError:
            pass
    return raw


def set_hparams(config='', exp_name='', hparams_str='', print_hparams=True, global_hparams=True):
    """Build the hyper-parameter dict from config files, a saved run config and overrides.

    Steps, in order:
      1. Take ``config`` / ``exp_name`` / ``hparams_str`` from the arguments,
         or parse them (and the run flags) from the command line when both
         ``config`` and ``exp_name`` are empty.
      2. Load ``config`` through ``load_config`` (with ``base_config``
         inheritance).
      3. Unless ``--reset`` is given, overlay ``<exp_name>/config.yaml`` if
         that file exists (shallow ``dict.update``).
      4. Resolve special string values with ``parse_config``.
      5. Apply the comma-separated overrides, e.g. ``"a=1,b.c=2,d=[1 1 1]"``.
      6. Store the run flags, drop ``'__'`` meta keys, and optionally publish
         the result into the module-level ``hparams`` dict.

    Args:
        config: Path to the YAML config file ('' to rely on ``exp_name``).
        exp_name: Experiment directory; also used as ``work_dir``.
        hparams_str: Comma-separated ``key=value`` overrides; dotted keys
            address nested dicts.
        print_hparams: Whether to print diagnostics and the resolved hparams.
        global_hparams: Whether to replace the contents of the module-level
            ``hparams`` dict with the result.

    Returns:
        The resolved hparams dict.

    Raises:
        AssertionError: If neither a config nor an exp_name is given, or the
            config path does not exist.
    """
    global hparams, global_print_hparams
    if config == '' and exp_name == '':
        parser = argparse.ArgumentParser(description='')
        parser.add_argument('--config', type=str, default='',
                            help='location of the data corpus')
        parser.add_argument('--exp_name', type=str, default='', help='exp_name')
        parser.add_argument('-hp', '--hparams', type=str, default='',
                            help='location of the data corpus')
        parser.add_argument('--infer', action='store_true', help='infer')
        parser.add_argument('--validate', action='store_true', help='validate')
        parser.add_argument('--reset', action='store_true', help='reset hparams')
        parser.add_argument('--remove', action='store_true', help='remove old ckpt')
        parser.add_argument('--debug', action='store_true', help='debug')
        parser.add_argument('--start_rank', type=int, default=-1,
                            help='the start rank id for DDP, keep 0 when single-machine multi-GPU')
        parser.add_argument('--world_size', type=int, default=-1,
                            help='the total number of GPU used across all machines, keep -1 for single-machine multi-GPU')
        parser.add_argument('--init_method', type=str, default='tcp', help='method to init ddp, use tcp or file')
        parser.add_argument('--master_addr', type=str, default='', help='')
        parser.add_argument('--ddp_dir', type=str, default='', help='')
        # Unknown arguments are tolerated so other components can define their own.
        args, unknown = parser.parse_known_args()
        if print_hparams:
            print("| set_hparams Unknow hparams: ", unknown)
    else:
        args = Args(config=config, exp_name=exp_name, hparams=hparams_str,
                    infer=False, validate=False, reset=False, debug=False, remove=False,
                    start_rank=-1, world_size=-1, init_method='tcp', ddp_dir='', master_addr='')
    assert args.config != '' or args.exp_name != ''
    if args.config != '':
        assert os.path.exists(args.config), f"{args.config} not exists"

    # Config saved by a previous run of this experiment, if there is one.
    saved_hparams = {}
    args_work_dir = ''
    if args.exp_name != '':
        args_work_dir = f'{args.exp_name}'
        ckpt_config_path = f'{args_work_dir}/config.yaml'
        if os.path.exists(ckpt_config_path):
            with open(ckpt_config_path) as f:
                saved_hparams_ = yaml.safe_load(f)
                if saved_hparams_ is not None:
                    saved_hparams.update(saved_hparams_)

    hparams_ = {}
    config_chains = []
    if args.config != '':
        hparams_.update(load_config(args.config, config_chains, set()))
        if len(config_chains) > 1 and print_hparams:
            print('| Hparams chains: ', config_chains)
    if not args.reset:
        hparams_.update(saved_hparams)
    # The hparams dict itself is the evaluation context for '${...}' values.
    traverse_dict(hparams_, parse_config, hparams_)
    hparams_['work_dir'] = args_work_dir

    # Support config overriding in command line. Support list type config overriding.
    # Examples: --hparams="a=1,b.c=2,d=[1 1 1]"
    if args.hparams != "":
        for new_hparam in args.hparams.split(","):
            if new_hparam.strip() == '':
                # Tolerate a trailing or doubled comma.
                continue
            # Split on the first '=' only so a value may itself contain '='.
            k, v = new_hparam.split("=", 1)
            v = v.strip("\'\" ")
            *parent_keys, k = k.split(".")
            config_node = hparams_
            for k_ in parent_keys:
                config_node = config_node[k_]
            if k not in config_node or config_node[k] is None:
                # No typed value to imitate (a None value cannot be used as a
                # constructor), so infer the type from the text.
                config_node[k] = _infer_hparam_value(v)
                continue
            old_type = type(config_node[k])
            if v in ['True', 'False'] or old_type in [bool, list, dict]:
                if old_type == list:
                    # "[1 1 1]" -> "[1,1,1]"; '^' stands in for a double quote.
                    v = v.replace(" ", ",").replace('^', "\"")
                    # NOTE(review): "a|b|c" list form; the element type is
                    # copied from the first existing element (str if empty).
                    if '|' in v:
                        tp = type(config_node[k][0]) if len(config_node[k]) else str
                        config_node[k] = [tp(x) for x in v.split("|") if x != '']
                        continue
                # NOTE(review): eval() on command-line text -- trusted input only.
                config_node[k] = eval(v)
            else:
                # Keep the type of the value being overridden.
                config_node[k] = old_type(v)

    if args_work_dir != '' and not args.infer:
        os.makedirs(hparams_['work_dir'], exist_ok=True)

    hparams_['infer'] = args.infer
    hparams_['debug'] = args.debug
    hparams_['validate'] = args.validate
    hparams_['exp_name'] = args.exp_name
    hparams_['start_rank'] = args.start_rank  # useful for multi-machine training
    hparams_['world_size'] = args.world_size
    hparams_['init_method'] = args.init_method
    hparams_['ddp_dir'] = args.ddp_dir
    hparams_['master_addr'] = args.master_addr
    remove_meta_key(hparams_)
    if global_hparams:
        # Mutate in place so modules that imported `hparams` see the update.
        hparams.clear()
        hparams.update(hparams_)
    if print_hparams and global_print_hparams and global_hparams:
        print('| Hparams: ', json.dumps(hparams_, indent=2, sort_keys=True))
        # Print the full dump only once per process.
        global_print_hparams = False
    return hparams_