Spaces:
Running
Running
| import os | |
| import sys | |
| import logging | |
| import torch | |
# Module-level flag, initialised to False. Its readers/writers are not visible
# in this part of the file (presumably the plotting helpers -- TODO confirm).
MATPLOTLIB_FLAG = False

# Send every record from DEBUG upwards to stdout via the root logger.
logging.basicConfig(
    level=logging.DEBUG,
    stream=sys.stdout,
)

# ``logger`` is deliberately the ``logging`` module itself, so calls such as
# ``logger.info(...)`` go through the root logger configured above.
logger = logging
def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False, version=None):
    """Restore model weights (and optionally optimizer state) from a checkpoint.

    The checkpoint file must contain a dict with the keys ``'model'``,
    ``'iteration'`` and ``'learning_rate'``; an ``'optimizer'`` entry is
    optional.

    Args:
        checkpoint_path: Path of the checkpoint file to read.
        model: Module to load weights into. If it has a ``module`` attribute
            (e.g. a parallel wrapper), the wrapped module is used instead.
        optimizer: Optimizer whose state should be restored, or ``None``.
        skip_optimizer: When true, the optimizer is left untouched even if
            the checkpoint carries optimizer state.
        version: Legacy model version declared by the caller. Only used to
            decide whether a hint is logged for missing ``ja_bert_proj`` keys.

    Returns:
        Tuple ``(model, optimizer, learning_rate, iteration)`` where the last
        two values are read from the checkpoint.

    Raises:
        AssertionError: If ``checkpoint_path`` is not an existing file.
        KeyError: If the checkpoint lacks ``'model'``, ``'iteration'`` or
            ``'learning_rate'``.
    """
    assert os.path.isfile(checkpoint_path), checkpoint_path
    # NOTE(security): torch.load unpickles the file -- only load checkpoints
    # that come from a trusted source.
    checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")
    iteration = checkpoint_dict["iteration"]
    learning_rate = checkpoint_dict["learning_rate"]

    # Restore the optimizer only when there is one to restore into, the caller
    # did not opt out, and the checkpoint actually stores optimizer state.
    # In every other case (including ``optimizer is None``) it is left as-is;
    # the former fallback branch dereferenced a ``None`` optimizer and could
    # only raise AttributeError.
    saved_optimizer_state = checkpoint_dict.get("optimizer")
    if optimizer is not None and not skip_optimizer and saved_optimizer_state is not None:
        optimizer.load_state_dict(saved_optimizer_state)

    saved_state_dict = checkpoint_dict["model"]
    # Unwrap wrappers that expose the real network as ``.module``.
    target = model.module if hasattr(model, "module") else model
    state_dict = target.state_dict()

    new_state_dict = {}
    for k, v in state_dict.items():
        # A saved tensor is usable only if it exists and has the same shape
        # as the model's current tensor for that key.
        try:
            saved_value = saved_state_dict[k]
            usable = saved_value.shape == v.shape
        except (KeyError, AttributeError):
            usable = False

        if usable:
            new_state_dict[k] = saved_value
            continue

        # Handle legacy model versions and provide appropriate warnings.
        if "ja_bert_proj" in k:
            # Legacy checkpoints lack this projection: start it from zeros.
            v = torch.zeros_like(v)
            if version is None:
                logger.error(f"{k} is not in the checkpoint")
                logger.warning(
                    "If you're using an older version of the model, consider adding the "
                    "\"version\" parameter to the model's config.json under the \"data\" "
                    "section. For instance: \"legacy_version\": \"1.0.1\""
                )
        elif "flow.flows.0.enc.attn_layers.3" in k:
            logger.error(f"{k} is not in the checkpoint")
            logger.warning(
                "If you're using a transitional version, please add the "
                "\"version\": \"1.1.0-transition\" parameter within the \"data\" "
                "section of the model's config.json."
            )
        else:
            logger.error(f"{k} is not in the checkpoint")
        # Fall back to the model's own (or zeroed) tensor for this key.
        new_state_dict[k] = v

    target.load_state_dict(new_state_dict, strict=False)
    logger.info(f"Loaded checkpoint '{checkpoint_path}' (iteration {iteration})")
    return model, optimizer, learning_rate, iteration
def process_legacy_versions(hps):
    """Return the model version declared in *hps*, without a leading "v".

    ``hps.version`` takes precedence; ``hps.data.version`` is consulted only
    when ``hps`` has no ``version`` attribute. A single leading ``"v"`` or
    ``"V"`` is stripped from a non-empty version string.

    Args:
        hps: Hyper-parameter object exposing ``version`` and/or
            ``data.version`` as attributes.

    Returns:
        The normalized version, or ``None`` when no version is declared.
        Falsy values (``None``, ``""``) are returned unchanged.

    Raises:
        AttributeError: If ``hps`` has neither ``version`` nor ``data``.
    """
    try:
        version = hps.version
    except AttributeError:
        # Look at ``hps.data`` lazily so that an ``hps`` carrying ``version``
        # but no ``data`` section no longer fails.
        version = getattr(hps.data, "version", None)

    if version and version[0].lower() == "v":
        version = version[1:]
    return version