Spaces:
Runtime error
Runtime error
| #!/usr/bin/env python3 | |
| import os | |
| from argparse import ArgumentParser | |
| def ssim_fid100_f1(metrics, fid_scale=100): | |
| ssim = metrics.loc['total', 'ssim']['mean'] | |
| fid = metrics.loc['total', 'fid']['mean'] | |
| fid_rel = max(0, fid_scale - fid) / fid_scale | |
| f1 = 2 * ssim * fid_rel / (ssim + fid_rel + 1e-3) | |
| return f1 | |
| def find_best_checkpoint(model_list, models_dir): | |
| with open(model_list) as f: | |
| models = [m.strip() for m in f.readlines()] | |
| with open(f'{model_list}_best', 'w') as f: | |
| for model in models: | |
| print(model) | |
| best_f1 = 0 | |
| best_epoch = 0 | |
| best_step = 0 | |
| with open(os.path.join(models_dir, model, 'train.log')) as fm: | |
| lines = fm.readlines() | |
| for line_index in range(len(lines)): | |
| line = lines[line_index] | |
| if 'Validation metrics after epoch' in line: | |
| sharp_index = line.index('#') | |
| cur_ep = line[sharp_index + 1:] | |
| comma_index = cur_ep.index(',') | |
| cur_ep = int(cur_ep[:comma_index]) | |
| total_index = line.index('total ') | |
| step = int(line[total_index:].split()[1].strip()) | |
| total_line = lines[line_index + 5] | |
| if not total_line.startswith('total'): | |
| continue | |
| words = total_line.strip().split() | |
| f1 = float(words[-1]) | |
| print(f'\tEpoch: {cur_ep}, f1={f1}') | |
| if f1 > best_f1: | |
| best_f1 = f1 | |
| best_epoch = cur_ep | |
| best_step = step | |
| f.write(f'{model}\t{best_epoch}\t{best_step}\t{best_f1}\n') | |
| if __name__ == '__main__': | |
| parser = ArgumentParser() | |
| parser.add_argument('model_list') | |
| parser.add_argument('models_dir') | |
| args = parser.parse_args() | |
| find_best_checkpoint(args.model_list, args.models_dir) | |