# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
from collections import OrderedDict

import torch
from PIL.Image import Image

from scepter.modules.model.registry import BACKBONES, EMBEDDERS, MODELS
from scepter.modules.utils.config import Config
from scepter.modules.utils.distribute import we
from scepter.modules.utils.logger import get_logger
from scepter.modules.utils.registry import Registry, build_from_config
from scepter.studio.utils.env import get_available_memory
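

# Helper: unwrap the underlying model from a module record built by
# BaseInference.infer_model().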
def get_model(model_tuple):
    assert 'model' in model_tuple
    return model_tuple['model']


class BaseInference():
    '''
    Supports loading components dynamically: each component model is created
    and loaded the first time it is actually run, instead of at start-up.
    '''
    def __init__(self, cfg, logger=None):
        if logger is None:
            logger = get_logger(name='scepter')
        self.logger = logger
        self.name = cfg.NAME
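
    # Attach already-built modules (e.g. ones shared across pipelines) as
    # attributes of this inference object.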
    def init_from_modules(self, modules):
        for k, v in modules.items():
            setattr(self, k, v)
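
    # Build a lazy "module record" for one component: the model itself stays
    # None (device 'offline') until load() is called. PARAS/FUNCTION entries
    # from the config are normalized into `paras` and `function_info`; the
    # INPUT names are matched against self.input, which load_default() fills
    # from cfg.INPUT, so load_default() must run first.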
    def infer_model(self, cfg, module_paras=None):
        module = {
            'model': None,
            'cfg': cfg,
            'device': 'offline',
            'name': cfg.NAME,
            'function_info': {},
            'paras': {}
        }
        if module_paras is None:
            return module
        function_info = {}
        paras = {
            k.lower(): v
            for k, v in module_paras.get('PARAS', {}).items()
        }
        for function in module_paras.get('FUNCTION', []):
            input_dict = {}
            for inp in function.get('INPUT', []):
                if inp.lower() in self.input:
                    input_dict[inp.lower()] = self.input[inp.lower()]
            function_info[function.NAME] = {
                'dtype': function.get('DTYPE', 'float32'),
                'input': input_dict
            }
        module['paras'] = paras
        module['function_info'] = function_info
        return module
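
    # Restore weights from a .safetensors file or a torch checkpoint, skipping
    # any key that contains one of the substrings in `ignore_keys`.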
    def init_from_ckpt(self, path, model, ignore_keys=list()):
        if path.endswith('safetensors'):
            from safetensors.torch import load_file as load_safetensors
            sd = load_safetensors(path)
        else:
            sd = torch.load(path, map_location='cpu', weights_only=True)
        new_sd = OrderedDict()
        for k, v in sd.items():
            ignored = False
            for ik in ignore_keys:
                if ik in k:
                    if we.rank == 0:
                        self.logger.info(
                            'Ignore key {} from state_dict.'.format(k))
                    ignored = True
                    break
            if not ignored:
                new_sd[k] = v
        missing, unexpected = model.load_state_dict(new_sd, strict=False)
        if we.rank == 0:
            self.logger.info(
                f'Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys'
            )
            if len(missing) > 0:
                self.logger.info(f'Missing Keys:\n {missing}')
            if len(unexpected) > 0:
                self.logger.info(f'\nUnexpected Keys:\n {unexpected}')
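
    # Materialize an offline module: build the model through the MODELS,
    # BACKBONES or EMBEDDERS registry, optionally cast its dtype and reload
    # checkpoint weights, then move it onto the current device.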
    def load(self, module):
        if module['device'] == 'offline':
            from scepter.modules.utils.import_utils import LazyImportModule
            if (LazyImportModule.get_module_type(
                    ('MODELS', module['cfg'].NAME))
                    or module['cfg'].NAME in MODELS.class_map):
                model = MODELS.build(module['cfg'], logger=self.logger).eval()
            elif (LazyImportModule.get_module_type(
                    ('BACKBONES', module['cfg'].NAME))
                    or module['cfg'].NAME in BACKBONES.class_map):
                model = BACKBONES.build(module['cfg'],
                                        logger=self.logger).eval()
            elif (LazyImportModule.get_module_type(
                    ('EMBEDDERS', module['cfg'].NAME))
                    or module['cfg'].NAME in EMBEDDERS.class_map):
                model = EMBEDDERS.build(module['cfg'],
                                        logger=self.logger).eval()
            else:
                raise NotImplementedError
            if 'DTYPE' in module['cfg'] and module['cfg']['DTYPE'] is not None:
                model = model.to(getattr(torch, module['cfg'].DTYPE))
            if module['cfg'].get('RELOAD_MODEL', None):
                self.init_from_ckpt(module['cfg'].RELOAD_MODEL, model)
            module['model'] = model
            module['device'] = 'cpu'
        if module['device'] == 'cpu':
            module['device'] = we.device_id
            module['model'] = module['model'].to(we.device_id)
        return module
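
    # Release a module depending on memory pressure: if available memory drops
    # below half of total, drop the model entirely (back to 'offline');
    # otherwise keep a CPU copy so the next load is cheap.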
    def unload(self, module):
        if module is None:
            return module
        mem = get_available_memory()
        free_mem = int(mem['available'] / (1024**2))
        total_mem = int(mem['total'] / (1024**2))
        if free_mem < 0.5 * total_mem:
            if module['model'] is not None:
                module['model'] = module['model'].to('cpu')
                del module['model']
                module['model'] = None
            module['device'] = 'offline'
            self.logger.info(
                'Deleted module {} to free memory.'.format(module['name']))
        else:
            if module['model'] is not None:
                module['model'] = module['model'].to('cpu')
                module['device'] = 'cpu'
            else:
                module['device'] = 'offline'
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
        return module
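
    # Load a module on demand. `name` is an attribute name registered in
    # self.loaded_model_name; 'all' loads every registered module. A module
    # whose config changed since the last load is unloaded and rebuilt.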
    def dynamic_load(self, module=None, name=''):
        self.logger.info('Loading {} model'.format(name))
        if name == 'all':
            for subname in self.loaded_model_name:
                self.loaded_model[subname] = self.dynamic_load(
                    getattr(self, subname), subname)
        elif name in self.loaded_model_name:
            if name in self.loaded_model:
                if module['cfg'] != self.loaded_model[name]['cfg']:
                    self.unload(self.loaded_model[name])
                    module = self.load(module)
                    self.loaded_model[name] = module
                    return module
                elif module['device'] == 'cpu' or module['device'] == 'offline':
                    module = self.load(module)
                    return module
                else:
                    return module
            else:
                module = self.load(module)
                self.loaded_model[name] = module
                return module
        else:
            return self.load(module)
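
    # Counterpart of dynamic_load(). With skip_loaded=True, modules tracked in
    # self.loaded_model are left in place; only untracked modules are freed.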
    def dynamic_unload(self, module=None, name='', skip_loaded=False):
        self.logger.info('Unloading {} model'.format(name))
        if name == 'all':
            for subname, submodule in self.loaded_model.items():
                self.loaded_model[subname] = self.unload(submodule)
        elif name in self.loaded_model_name:
            if name in self.loaded_model:
                if not skip_loaded:
                    module = self.unload(self.loaded_model[name])
                    self.loaded_model[name] = module
            else:
                self.unload(module)
        else:
            self.unload(module)
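
    # Cache the pipeline-level config: default input values, input/output
    # specs, and the per-module parameters (MODULES_PARAS) that infer_model()
    # consumes. Returns the MODULES_PARAS block.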
    def load_default(self, cfg):
        module_paras = {}
        if cfg is not None:
            self.paras = cfg.PARAS
            self.input_cfg = {k.lower(): v for k, v in cfg.INPUT.items()}
            self.input = {
                k.lower():
                dict(v).get('DEFAULT', None) if isinstance(
                    v, (dict, OrderedDict, Config)) else v
                for k, v in cfg.INPUT.items()
            }
            self.output = {k.lower(): v for k, v in cfg.OUTPUT.items()}
            module_paras = cfg.MODULES_PARAS
        return module_paras
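
    # Placeholder hook: subclasses are expected to turn `image` into a batched
    # tensor; the base implementation accepts tensors and PIL images as-is.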
    def load_image(self, image, num_samples=1):
        if isinstance(image, torch.Tensor):
            pass
        elif isinstance(image, Image):
            pass
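
    # Resolve which configured function to run: return (name, dtype) for an
    # explicit function_name, or for the single entry when exactly one
    # function is configured.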
    def get_function_info(self, module, function_name=None):
        all_function = module['function_info']
        if function_name in all_function:
            return function_name, all_function[function_name]['dtype']
        if function_name is None and len(all_function) == 1:
            for k, v in all_function.items():
                return k, v['dtype']

    def __call__(self, input, **kwargs):
        return


def build_inference(cfg, registry, logger=None, *args, **kwargs):
    """ After building the model, load pretrained weights if the key
    `pretrain` exists.

    pretrain (str, dict): Describes how to load the pretrained model.
        str: treat `pretrain` as the model path;
        dict: should contain the key `path`, plus other parameters taken by
            the function load_pretrained().
    """
    if not isinstance(cfg, Config):
        raise TypeError(f'cfg must be a Config instance, got {type(cfg)}')
    model = build_from_config(cfg, registry, logger=logger, *args, **kwargs)
    return model


# Register inference classes for diffusion.
INFERENCES = Registry('INFERENCE', build_func=build_inference)
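
# Minimal usage sketch. Everything below is illustrative only: the config
# path, the registered subclass, and the `diffusion_model` attribute are
# assumptions, not definitions from this file.
#
#     cfg = Config(cfg_file='path/to/inference.yaml')   # hypothetical config
#     pipe = INFERENCES.build(cfg, logger=get_logger(name='scepter'))
#     pipe.load_default(cfg)                  # cache defaults and I/O specs
#     m = pipe.dynamic_load(pipe.diffusion_model, 'diffusion_model')
#     func_name, dtype = pipe.get_function_info(m)
#     ...                                     # run the model via pipe(input)
#     pipe.dynamic_unload(m, 'diffusion_model')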