Spaces:
Running
on
Zero
Running
on
Zero
| # coding=utf-8 | |
| # Copyright 2023 The HuggingFace Inc. team. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| from typing import List, Union | |
| from ..utils import MIN_PEFT_VERSION, check_peft_version, is_peft_available | |
class PeftAdapterMixin:
    """
    A class containing all functions for loading and using adapter weights that are supported in the PEFT library. For
    more details about adapters and injecting them in a transformer-based model, check out the PEFT
    [documentation](https://huggingface.co/docs/peft/index).

    Install the latest version of PEFT, and use this mixin to:

    - Attach new adapters in the model.
    - Attach multiple adapters and iteratively activate/deactivate them.
    - Activate/deactivate all adapters from the model.
    - Get a list of the active adapters.
    """

    # Set to True by `add_adapter` only after an adapter has been injected successfully. `set_adapter`,
    # `disable_adapters`, `enable_adapters` and `active_adapters` raise a `ValueError` while it is still False.
    _hf_peft_config_loaded = False

    def add_adapter(self, adapter_config, adapter_name: str = "default") -> None:
        r"""
        Adds a new adapter to the current model for training and makes it the active adapter. If no adapter name is
        passed, a default name is assigned to the adapter to follow the convention of the PEFT library.

        If you are not familiar with adapters and PEFT methods, we invite you to read more about them in the PEFT
        [documentation](https://huggingface.co/docs/peft).

        Args:
            adapter_config (`[~peft.PeftConfig]`):
                The configuration of the adapter to add; supported adapters are non-prefix tuning and adaption prompt
                methods. Note that its `base_model_name_or_path` attribute is overwritten with `None` in place.
            adapter_name (`str`, *optional*, defaults to `"default"`):
                The name of the adapter to add. If no name is passed, a default name is assigned to the adapter.

        Raises:
            ImportError: If PEFT is not available.
            ValueError: If `adapter_config` is not a `PeftConfig`, or if an adapter named `adapter_name` was already
                added to this model.
        """
        check_peft_version(min_version=MIN_PEFT_VERSION)
        if not is_peft_available():
            raise ImportError("PEFT is not available. Please install PEFT to use this function: `pip install peft`.")

        from peft import PeftConfig, inject_adapter_in_model

        # Validate the input before touching any model state, so a rejected call leaves the model unchanged.
        if not isinstance(adapter_config, PeftConfig):
            raise ValueError(
                f"adapter_config should be an instance of PeftConfig. Got {type(adapter_config)} instead."
            )

        # `self.peft_config` is presumably attached by `inject_adapter_in_model` (it is not assigned anywhere in this
        # class), so it is only consulted after an earlier successful injection has set the loaded flag.
        if self._hf_peft_config_loaded and adapter_name in self.peft_config:
            raise ValueError(f"Adapter with name {adapter_name} already exists. Please use a different name.")

        # Unlike transformers, here we don't need to retrieve the name_or_path of the unet as the loading logic is
        # handled by the `load_lora_layers` or `LoraLoaderMixin`. Therefore we set it to `None` here.
        adapter_config.base_model_name_or_path = None
        inject_adapter_in_model(adapter_config, self, adapter_name)

        # Flip the flag only once injection succeeded: setting it earlier would leave the model claiming an adapter is
        # loaded without a `peft_config` to back it, breaking every later call with an AttributeError.
        self._hf_peft_config_loaded = True
        self.set_adapter(adapter_name)

    def set_adapter(self, adapter_name: Union[str, List[str]]) -> None:
        """
        Sets a specific adapter by forcing the model to only use that adapter and disables the other adapters.

        If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
        [documentation](https://huggingface.co/docs/peft).

        Args:
            adapter_name (`Union[str, List[str]]`):
                The list of adapters to set or the adapter name in the case of a single adapter.

        Raises:
            ValueError: If no adapter has been loaded yet, if any requested name is not a loaded adapter, if several
                adapters are requested on a PEFT version without multi-adapter support, or if the model contains no
                PEFT tuner layer.
        """
        check_peft_version(min_version=MIN_PEFT_VERSION)
        if not self._hf_peft_config_loaded:
            raise ValueError("No adapter loaded. Please load an adapter first.")

        # Normalize to a concrete list (also materializes tuples/generators so they can be reused below).
        adapter_names = [adapter_name] if isinstance(adapter_name, str) else list(adapter_name)

        missing = set(adapter_names) - set(self.peft_config)
        if missing:
            # Sorted so the error message is deterministic (set iteration order is not).
            raise ValueError(
                f"Following adapter(s) could not be found: {', '.join(sorted(missing))}. Make sure you are passing the correct adapter name(s)."
                f" current loaded adapters are: {list(self.peft_config.keys())}"
            )

        from peft.tuners.tuners_utils import BaseTunerLayer

        adapters_have_been_set = False
        for _, module in self.named_modules():
            if not isinstance(module, BaseTunerLayer):
                continue

            if hasattr(module, "set_adapter"):
                module.set_adapter(adapter_names)
            elif len(adapter_names) != 1:
                # Previous versions of PEFT do not support multi-adapter inference.
                raise ValueError(
                    "You are trying to set multiple adapters and you have a PEFT version that does not support multi-adapter inference. Please upgrade to the latest version of PEFT."
                    " `pip install -U peft` or `pip install -U git+https://github.com/huggingface/peft.git`"
                )
            else:
                # Layers without `set_adapter` predate multi-adapter support; presumably they expect a single name
                # string here rather than a one-element list.
                module.active_adapter = adapter_names[0]
            adapters_have_been_set = True

        if not adapters_have_been_set:
            raise ValueError(
                "Did not succeeded in setting the adapter. Please make sure you are using a model that supports adapters."
            )

    def _set_adapter_layers_enabled(self, enabled: bool) -> None:
        """Turn every PEFT tuner layer of the model on (`enabled=True`) or off (`enabled=False`)."""
        from peft.tuners.tuners_utils import BaseTunerLayer

        for _, module in self.named_modules():
            if not isinstance(module, BaseTunerLayer):
                continue
            if hasattr(module, "enable_adapters"):
                module.enable_adapters(enabled=enabled)
            else:
                # support for older PEFT versions, which expose a `disable_adapters` flag instead of a method
                module.disable_adapters = not enabled

    def disable_adapters(self) -> None:
        r"""
        Disable all adapters attached to the model and fall back to inference with the base model only.

        If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
        [documentation](https://huggingface.co/docs/peft).

        Raises:
            ValueError: If no adapter has been loaded yet.
        """
        check_peft_version(min_version=MIN_PEFT_VERSION)
        if not self._hf_peft_config_loaded:
            raise ValueError("No adapter loaded. Please load an adapter first.")

        self._set_adapter_layers_enabled(False)

    def enable_adapters(self) -> None:
        """
        Re-enable the adapters on every PEFT tuner layer of the model, undoing `disable_adapters`. Each layer keeps
        whichever adapter(s) it currently has active; this method does not change that selection.

        If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
        [documentation](https://huggingface.co/docs/peft).

        Raises:
            ValueError: If no adapter has been loaded yet.
        """
        check_peft_version(min_version=MIN_PEFT_VERSION)
        if not self._hf_peft_config_loaded:
            raise ValueError("No adapter loaded. Please load an adapter first.")

        self._set_adapter_layers_enabled(True)

    def active_adapters(self) -> List[str]:
        """
        Gets the current list of active adapters of the model.

        The value is read from the first PEFT tuner layer found in `self.named_modules()` and is always returned as a
        new list (a single adapter name is wrapped in a list). An empty list is returned when the model contains no
        tuner layer.

        If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
        [documentation](https://huggingface.co/docs/peft).

        Raises:
            ImportError: If PEFT is not available.
            ValueError: If no adapter has been loaded yet.
        """
        check_peft_version(min_version=MIN_PEFT_VERSION)
        if not is_peft_available():
            raise ImportError("PEFT is not available. Please install PEFT to use this function: `pip install peft`.")

        if not self._hf_peft_config_loaded:
            raise ValueError("No adapter loaded. Please load an adapter first.")

        from peft.tuners.tuners_utils import BaseTunerLayer

        for _, module in self.named_modules():
            if isinstance(module, BaseTunerLayer):
                active = module.active_adapter
                # A layer may hold a single name string (e.g. set by the older-PEFT fallback in `set_adapter`);
                # copy lists so callers cannot mutate the layer's internal state.
                return [active] if isinstance(active, str) else list(active)
        return []