# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict

import torch
class AttnProcsLayers(torch.nn.Module):
    """Wraps a dict of attention processors in a `torch.nn.Module` so their
    parameters can be saved and loaded under the original processor names."""

    def __init__(self, state_dict: Dict[str, torch.nn.Module]):
        # Note: the values are attention processor *modules* (e.g. the entries
        # of `unet.attn_processors`), not tensors -- `torch.nn.ModuleList`
        # requires modules.
        super().__init__()
        self.layers = torch.nn.ModuleList(state_dict.values())
        self.mapping = dict(enumerate(state_dict.keys()))
        self.rev_mapping = {v: k for k, v in enumerate(state_dict.keys())}

        # .processor for unet, .self_attn for text encoder
        self.split_keys = [".processor", ".self_attn"]

        # we add a hook to state_dict() and load_state_dict() so that the
        # naming fits with `unet.attn_processors`
        def map_to(module, state_dict, *args, **kwargs):
            new_state_dict = {}
            for key, value in state_dict.items():
                num = int(key.split(".")[1])  # 0 is always "layers"
                new_key = key.replace(f"layers.{num}", module.mapping[num])
                new_state_dict[new_key] = value

            return new_state_dict

        def remap_key(key, state_dict):
            for k in self.split_keys:
                if k in key:
                    return key.split(k)[0] + k

            raise ValueError(
                f"There seems to be a problem with the state_dict: {set(state_dict.keys())}. {key} has to have one of {self.split_keys}."
            )

        def map_from(module, state_dict, *args, **kwargs):
            all_keys = list(state_dict.keys())
            for key in all_keys:
                replace_key = remap_key(key, state_dict)
                new_key = key.replace(replace_key, f"layers.{module.rev_mapping[replace_key]}")
                state_dict[new_key] = state_dict[key]
                del state_dict[key]

        self._register_state_dict_hook(map_to)
        self._register_load_state_dict_pre_hook(map_from, with_module=True)
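

# --- Usage sketch (not part of the library source) ---
# A minimal example of how the two hooks remap keys in both directions.
# `DummyProcessor` is a hypothetical stand-in for a real attention processor
# (e.g. diffusers' LoRAAttnProcessor), and the key names below follow the
# `unet.attn_processors` convention but are illustrative only.
if __name__ == "__main__":

    class DummyProcessor(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.weight = torch.nn.Parameter(torch.zeros(2, 2))

    procs = {
        "down_blocks.0.attentions.0.processor": DummyProcessor(),
        "up_blocks.0.attentions.0.processor": DummyProcessor(),
    }
    model = AttnProcsLayers(procs)

    # `map_to` rewrites "layers.<i>.*" back to the original processor names:
    print(list(model.state_dict().keys()))
    # ['down_blocks.0.attentions.0.processor.weight',
    #  'up_blocks.0.attentions.0.processor.weight']

    # `map_from` reverses the mapping, so the remapped names load cleanly:
    model.load_state_dict(model.state_dict())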