File size: 37,191 Bytes
302920f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 |
# Copyright 2023-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import os
import platform
import re
import warnings
from typing import Optional
import huggingface_hub
import torch
from huggingface_hub import file_exists, hf_hub_download
from huggingface_hub.errors import EntryNotFoundError, LocalEntryNotFoundError
from safetensors.torch import load_file as safe_load_file
from transformers.utils import http_user_agent
from peft.mapping import PEFT_TYPE_TO_PREFIX_MAPPING
from .constants import INCLUDE_LINEAR_LAYERS_SHORTHAND
from .other import (
EMBEDDING_LAYER_NAMES,
SAFETENSORS_WEIGHTS_NAME,
WEIGHTS_NAME,
AuxiliaryTrainingWrapper,
check_file_exists_on_hf_hub,
infer_device,
match_target_against_key,
)
from .peft_types import PeftType
def has_valid_embedding_base_layer(layer):
    """Return whether `layer` wraps a `torch.nn.Linear` or `torch.nn.Embedding` as its `base_layer`.

    Objects without a `base_layer` attribute (plain modules, `None`, ...) yield `False`.
    """
    wrapped = getattr(layer, "base_layer", None)
    return isinstance(wrapped, (torch.nn.Linear, torch.nn.Embedding))
def get_embedding_layer_name(model, layer, is_embedding_in_target_modules):
    """Return the qualified name under which the embedding `layer` is registered in `model`.

    A submodule matches when it is `layer.base_layer` (if `layer` has one), or, only when the embedding is *not* among
    the target modules, when it is `layer` itself. The first match in `model.named_modules()` order is returned;
    `None` is returned if nothing matches.
    """
    candidates = (
        name
        for name, module in model.named_modules()
        if (not is_embedding_in_target_modules and module == layer) or module == getattr(layer, "base_layer", None)
    )
    return next(candidates, None)
def get_peft_model_state_dict(
    model, state_dict=None, adapter_name="default", unwrap_compiled=False, save_embedding_layers="auto"
):
    """
    Get the state dict of the given adapter of the PEFT model.

    This only includes the PEFT parameters, not the parameters of the base model. Thus the returned `state_dict` is
    generally small compared to the full model size. To retrieve the full `state_dict`, just call `model.state_dict()`.

    Note that the adapter name is removed from the `state_dict`, as this is just an arbitrary name that can be changed
    when loading the adapter. So e.g. if the adapter name is `'default'` and the original key is
    `'model.q_proj.lora_A.default.weight'`, the returned key will be `'model.q_proj.lora_A.weight'`. Use this function
    in conjunction with [`set_peft_model_state_dict`] to take care of the adapter name when loading weights.

    Args:
        model ([`PeftModel`]): The Peft model. When using torch.nn.DistributedDataParallel, DeepSpeed or FSDP,
            the model should be the underlying model/unwrapped model (i.e. model.module).
        state_dict (`dict`, *optional*, defaults to `None`):
            The state dict of the model. If not provided, the state dict of the passed model will be used.
        adapter_name (`str`, *optional*, defaults to `"default"`):
            The name of the adapter whose state dict should be returned.
        unwrap_compiled (`bool`, *optional*, defaults to `False`):
            Whether to unwrap the model if torch.compile was used.
        save_embedding_layers (`Union[bool, str]`, *optional*, defaults to `auto`):
            If `True`, save the embedding layers in addition to adapter weights. If `auto`, checks the common embedding
            layers `peft.utils.other.EMBEDDING_LAYER_NAMES` in config's `target_modules` when available. Based on it
            sets the boolean flag. This only works for 🤗 transformers models. With `auto`, the embedding layers are
            also saved when the vocabulary size differs from the one in the base model's config.

    Returns:
        `dict`: The adapter's entries (plus auxiliary modules such as `modules_to_save` and, if requested, embedding
        layers), with the adapter name removed from the keys.

    Raises:
        NotImplementedError: If `config.bias` has an unsupported value for LoRA/AdaLoRA/BOFT.
        ValueError: If the PEFT type is unknown, or if VeRA `save_projection` is set but `vera_A` is not present in
            `state_dict`.

    Note:
        For AdaLoRA, `config.rank_pattern` is updated in place (the adapter name is stripped from its keys).
    """
    if unwrap_compiled:
        # torch.compile-wrapped modules keep the original module under `_orig_mod`
        model = getattr(model, "_orig_mod", model)

    config = model.peft_config[adapter_name]
    if state_dict is None:
        state_dict = model.state_dict()

    # TUNER SPECIFIC CODE
    if config.peft_type in (PeftType.LORA, PeftType.ADALORA):
        # to_return = lora_state_dict(model, bias=model.peft_config.bias)
        # adapted from `https://github.com/microsoft/LoRA/blob/main/loralib/utils.py`
        # to be used directly with the state dict which is necessary when using DeepSpeed or FSDP
        bias = config.bias
        if bias == "none":
            to_return = {k: state_dict[k] for k in state_dict if "lora_" in k}
        elif bias == "all":
            to_return = {k: state_dict[k] for k in state_dict if "lora_" in k or "bias" in k}
        elif bias == "lora_only":
            to_return = {}
            for k in state_dict:
                if "lora_" in k:
                    to_return[k] = state_dict[k]
                    # the bias of the module that owns this LoRA weight, e.g. "x.q_proj.lora_A..." -> "x.q_proj.bias"
                    bias_name = k.split("lora_")[0] + "bias"
                    if bias_name in state_dict:
                        to_return[bias_name] = state_dict[bias_name]
        else:
            raise NotImplementedError
        # keep only LoRA keys that mention this adapter (substring match on the adapter name) plus any bias keys
        to_return = {k: v for k, v in to_return.items() if (("lora_" in k and adapter_name in k) or ("bias" in k))}
        if config.peft_type == PeftType.ADALORA:
            rank_pattern = config.rank_pattern
            if rank_pattern is not None:
                # note: this writes the adapter-name-free rank pattern back onto the config
                rank_pattern = {k.replace(f".{adapter_name}", ""): v for k, v in rank_pattern.items()}
                config.rank_pattern = rank_pattern
                to_return = model.resize_state_dict_by_rank_pattern(rank_pattern, to_return, adapter_name)

        if config.use_dora:
            # Here we take care of a refactor of DoRA which changed lora_magnitude_vector from a ParameterDict to a
            # ModuleDict with a DoraLayer instance. The old parameter is now the "weight" attribute of that layer. Since
            # we want the state_dict format not to change, we remove the "weight" part.
            new_dora_suffix = f"lora_magnitude_vector.{adapter_name}.weight"

            def renamed_dora_weights(k):
                if k.endswith(new_dora_suffix):
                    k = k[:-7]  # remove ".weight"
                return k

            to_return = {renamed_dora_weights(k): v for k, v in to_return.items()}

    elif config.peft_type == PeftType.BOFT:
        bias = config.bias
        if bias == "none":
            to_return = {k: state_dict[k] for k in state_dict if "boft_" in k}
        elif bias == "all":
            to_return = {k: state_dict[k] for k in state_dict if "boft_" in k or "bias" in k}
        elif bias == "boft_only":
            to_return = {}
            for k in state_dict:
                if "boft_" in k:
                    to_return[k] = state_dict[k]
                    bias_name = k.split("boft_")[0] + "bias"
                    if bias_name in state_dict:
                        to_return[bias_name] = state_dict[bias_name]
        else:
            raise NotImplementedError

    elif config.peft_type == PeftType.ADAPTION_PROMPT:
        # select entries whose last key component starts with "adaption_"
        to_return = {k: state_dict[k] for k in state_dict if k.split(".")[-1].startswith("adaption_")}

    elif config.is_prompt_learning:
        # prompt learning methods store their learned prompt embeddings under fixed keys, not via `state_dict`
        to_return = {}
        if config.peft_type == PeftType.MULTITASK_PROMPT_TUNING:
            to_return["prefix_task_cols"] = model.prompt_encoder[adapter_name].prefix_task_cols
            to_return["prefix_task_rows"] = model.prompt_encoder[adapter_name].prefix_task_rows
            prompt_embeddings = model.prompt_encoder[adapter_name].embedding.weight
        else:
            if config.inference_mode:
                prompt_embeddings = model.prompt_encoder[adapter_name].embedding.weight
            else:
                prompt_embeddings = model.get_prompt_embedding_to_save(adapter_name)
        to_return["prompt_embeddings"] = prompt_embeddings

    elif config.peft_type == PeftType.SHIRA:
        shira_prefix = PEFT_TYPE_TO_PREFIX_MAPPING[config.peft_type]
        to_return = {k: state_dict[k] for k in state_dict if shira_prefix in k}
        if platform.system() == "Windows":
            warnings.warn(
                "Windows has issues saving integers into safetensors. Hence, we convert shira_indices to float32 "
                "before saving on Windows OS. The shira_indices will always be converted to integers when loading."
            )
        # shira_indices are not picked up by the prefix filter above, so they are collected from the modules directly
        for name, module in model.named_modules():
            if hasattr(module, "shira_indices"):
                for k, v in module.shira_indices.items():
                    # Windows has some issues with saving integers into safetensors. Tests fail with some kind of
                    # PermissionError. This results in failed tests, so we are converting indices to float32 before
                    # saving and then converting them back to int when loading. This is happening only for Windows,
                    # not for Linux and Mac-OS.
                    to_return[f"{name}.shira_indices.{k}"] = (
                        v.to(torch.float32) if platform.system() == "Windows" else v
                    )

    elif config.peft_type == PeftType.VERA:
        vera_prefix = PEFT_TYPE_TO_PREFIX_MAPPING[config.peft_type]
        to_return = {k: state_dict[k] for k in state_dict if vera_prefix in k}
        if config.save_projection:
            # TODO: adding vera_A and vera_B to `self.get_base_layer` would
            # make name to match here difficult to predict.
            if f"base_model.vera_A.{adapter_name}" not in state_dict:
                raise ValueError(
                    "Model was initialised to not save vera_A and vera_B but config now specifies to save projection!"
                    " Set `config.save_projection` to `False`."
                )
            to_return["base_model.vera_A." + adapter_name] = state_dict["base_model.vera_A." + adapter_name]
            to_return["base_model.vera_B." + adapter_name] = state_dict["base_model.vera_B." + adapter_name]

    elif config.peft_type == PeftType.XLORA:
        to_return = {k: state_dict[k] for k in state_dict if "internal_xlora_classifier" in k}

    elif config.peft_type == PeftType.VBLORA:
        to_return = {}
        # choose the most efficient dtype for indices
        if config.num_vectors < 2**8:
            indices_dtype = torch.uint8
        elif config.num_vectors < 2**15:
            indices_dtype = torch.int16
        elif config.num_vectors < 2**31:
            indices_dtype = torch.int32
        else:
            indices_dtype = torch.int64
        if config.save_only_topk_weights:
            # in save_only_topk_weights mode, we save topk_indices and topk_weights for parameter efficiency
            for k in state_dict:
                if "vblora_logits" in k:
                    logits, indices = state_dict[k].topk(config.topk)
                    to_return.update({k + "_topk_indices": indices.to(dtype=indices_dtype)})
                    # the last softmax weight is dropped: it is recovered as 1 - sum(others) when loading
                    to_return.update({k + "_topk_weights": torch.softmax(logits, dim=-1)[:, :, :-1].contiguous()})
        else:
            to_return = {k: state_dict[k] for k in state_dict if "vblora_logits" in k}
        # the shared vector bank is saved in both modes
        to_return["base_model.vblora_vector_bank." + adapter_name] = state_dict[
            "base_model.vblora_vector_bank." + adapter_name
        ]

    elif config.peft_type in list(PeftType):
        # generic case: select all entries containing the tuner's parameter prefix
        prefix = PEFT_TYPE_TO_PREFIX_MAPPING[config.peft_type]
        to_return = {k: state_dict[k] for k in state_dict if prefix in k}

    else:
        raise ValueError(f"Unknown PEFT type passed: {config.peft_type}")

    # ADDITIONAL TRAINING MODULES / MODULES_TO_SAVE
    for name, module in model.named_modules():
        if isinstance(module, AuxiliaryTrainingWrapper):
            if name.startswith("_fsdp_wrapped_module."):
                # If FSDP is used, the state_dict is from the unwrapped model, which will result in a key mismatch if we
                # don't remove the FSDP-specific prefix
                name = name.removeprefix("_fsdp_wrapped_module.")
            # Compute the module-relative state dict to make it easier for the adapter to fetch the appropriate
            # keys that the module thinks need to be saved. We cannot rely on `.state_dict()` internally of the
            # module since accelerators like DeepSpeed require special handling which is done for the model
            # state dict from above but most likely not in the module itself. See #2450.
            module_state_dict = {
                k.removeprefix(f"{name}."): v for k, v in state_dict.items() if k.startswith(f"{name}.")
            }
            to_return.update(
                {f"{name}.{k}": v for k, v in module.adapter_state_dict(adapter_name, module_state_dict).items()}
            )

    # DEAL WITH EMBEDDINGS
    #
    # save_embedding_layer="auto" needs to check the following logic:
    #
    # - when vocab size was NOT changed, embeddings should be saved only when targeted
    #   but not when
    #   - using PeftType.TRAINABLE_TOKENS
    #   - LoRA using trainable_token_indices (since their goal is to be space-efficient)
    #   but
    # - when vocab size was changed, embeddings should be saved automatically regardless to cover this
    #   scenario: 1) fine-tune embedding, 2) resize embedding, 3) train with trainable tokens
    #
    embedding_is_targeted = False
    if hasattr(config, "target_modules"):
        if isinstance(config.target_modules, str) and (config.target_modules != INCLUDE_LINEAR_LAYERS_SHORTHAND):
            # `target_modules` is a single pattern string: test it against every module whose name ends in one of
            # the known embedding layer names.
            # `model` could be a PeftModel or something else like transformers/diffusers/..., in which case unwrapping is
            # not needed.
            _model = model.get_base_model() if hasattr(model, "get_base_model") else model
            embedding_is_targeted = any(
                match_target_against_key(config.target_modules, k)
                for k, _ in _model.named_modules()
                if any(re.match(rf"(.*\.)?{e}$", k) for e in EMBEDDING_LAYER_NAMES)
            )
        elif config.target_modules:
            embedding_is_targeted = any(k in config.target_modules for k in EMBEDDING_LAYER_NAMES)

    using_trainable_tokens = (
        config.peft_type == PeftType.TRAINABLE_TOKENS or getattr(config, "trainable_token_indices", None) is not None
    )

    if save_embedding_layers == "auto" and embedding_is_targeted and not using_trainable_tokens:
        warnings.warn("Setting `save_embedding_layers` to `True` as embedding layers found in `target_modules`.")
        save_embedding_layers = True
    elif save_embedding_layers == "auto":
        vocab_size = getattr(getattr(model, "config", None), "vocab_size", None)
        model_id = getattr(config, "base_model_name_or_path", None)

        # For some models e.g. diffusers the text config file is stored in a subfolder
        # we need to make sure we can download that config.
        has_base_config = False

        # ensure that this check is not performed in HF offline mode, see #1452
        if model_id is not None:
            local_config_exists = os.path.exists(os.path.join(model_id, "config.json"))
            exists = local_config_exists or check_file_exists_on_hf_hub(model_id, "config.json")
            if exists is None:
                # check failed, could not determine if it exists or not
                warnings.warn(
                    f"Could not find a config file in {model_id} - will assume that the vocabulary was not modified."
                )
                has_base_config = False
            else:
                has_base_config = exists

        # check if the vocab size of the base model is different from the vocab size of the finetuned model
        if (
            vocab_size
            and model_id
            and has_base_config
            and (vocab_size != model.config.__class__.from_pretrained(model_id).vocab_size)
        ):
            warnings.warn(
                "Setting `save_embedding_layers` to `True` as the embedding layer has been resized during finetuning."
            )
            save_embedding_layers = True
        else:
            save_embedding_layers = False

    if save_embedding_layers and hasattr(model, "get_input_embeddings"):
        for layer in [model.get_input_embeddings(), model.get_output_embeddings()]:
            # Either the layer is not targeted, then it must have been resized and needs saving. Or it is targeted and
            # therefore has a valid base layer, then we'll save it as well.
            if not embedding_is_targeted or has_valid_embedding_base_layer(layer):
                embedding_module_name = get_embedding_layer_name(model, layer, embedding_is_targeted)
                if embedding_module_name:
                    to_return.update({k: v for k, v in state_dict.items() if embedding_module_name in k})
    elif save_embedding_layers:
        warnings.warn("Could not identify embedding layer(s) because the model is not a 🤗 transformers model.")

    # REMOVE ADAPTER NAME
    # Ensure not to replace in the middle of the key because a module happens to have the same name as the adapter.
    pattern = re.compile(re.escape(f".{adapter_name}") + r"$")

    def remove_adapter_name(key):
        if "." not in key:
            # nothing to do
            return key

        if key.endswith(f".{adapter_name}"):
            # comes from an nn.Parameter, so no .weight suffix, the adapter name is directly at the end
            return key.removesuffix(f".{adapter_name}")

        # comes from an nn.Module, i.e. the adapter name is the 2nd to last element, e.g. v_proj.lora_A.default.weight
        key, _, suffix = key.rpartition(".")  # split, e.g. v_proj.lora_A.default + weight

        if (config.peft_type == PeftType.VBLORA) and suffix.startswith(f"{adapter_name}_"):
            # special case: VBLoRA creates keys that require this replacement:
            # base_model.model.lin0.vblora_logits_A.default_topk_indices =>
            # base_model.model.lin0.vblora_logits_A_topk_indices
            return key + "_" + suffix.removeprefix(f"{adapter_name}_")

        key = pattern.sub("", key)  # remove adapter name, e.g. v_proj.lora_A
        return f"{key}.{suffix}"  # stitch the suffix back, e.g, v_proj.lora_A.weight

    to_return = {remove_adapter_name(k): v for k, v in to_return.items()}
    return to_return
def _find_mismatched_keys(
model: torch.nn.Module, peft_model_state_dict: dict[str, torch.Tensor], ignore_mismatched_sizes: bool = False
) -> tuple[dict[str, torch.Tensor], list[tuple[str, tuple[int, ...], tuple[int, ...]]]]:
if not ignore_mismatched_sizes:
return peft_model_state_dict, []
mismatched = []
state_dict = model.state_dict()
for key, tensor in peft_model_state_dict.items():
if key not in state_dict:
continue
# see https://github.com/huggingface/transformers/blob/09f9f566de83eef1f13ee83b5a1bbeebde5c80c1/src/transformers/modeling_utils.py#L3858-L3864
if (state_dict[key].shape[-1] == 1) and (state_dict[key].numel() * 2 == tensor.numel()):
# This skips size mismatches for 4-bit weights. Two 4-bit values share an 8-bit container, causing size
# differences. Without matching with module type or parameter type it seems like a practical way to detect
# valid 4bit weights.
continue
if state_dict[key].shape != tensor.shape:
mismatched.append((key, tensor.shape, state_dict[key].shape))
for key, _, _ in mismatched:
del peft_model_state_dict[key]
return peft_model_state_dict, mismatched
def _insert_adapter_name_into_state_dict(
state_dict: dict[str, torch.Tensor], adapter_name: str, parameter_prefix: str
) -> dict[str, torch.Tensor]:
"""Utility function to remap the state_dict keys to fit the PEFT model by inserting the adapter name."""
peft_model_state_dict = {}
for key, val in state_dict.items():
if parameter_prefix in key:
_, _, suffix = key.rpartition(parameter_prefix)
if "." in suffix:
suffix_to_replace = ".".join(suffix.split(".")[1:])
# only replace the substring if the key ends on the substring to avoid accidental replacement inside of
# the key if a module happens to have a name that contains the substring
key = re.sub(re.escape(suffix_to_replace) + r"$", f"{adapter_name}.{suffix_to_replace}", key)
else:
key = f"{key}.{adapter_name}"
peft_model_state_dict[key] = val
else:
peft_model_state_dict[key] = val
return peft_model_state_dict
def set_peft_model_state_dict(
    model,
    peft_model_state_dict,
    adapter_name="default",
    ignore_mismatched_sizes: bool = False,
    low_cpu_mem_usage: bool = False,
):
    """
    Set the state dict of the PEFT model.

    Given a PEFT `state_dict` (as returned by [`get_peft_model_state_dict`]), insert the weights into the model. The
    model needs to have the PEFT adapters already in place (e.g. via [`inject_adapter_in_model`]).

    Setting the adapter weights also takes care of re-inserting the adapter name. This name may be a different name
    than the one originally used to train the adapter.

    The passed `peft_model_state_dict` is not modified: all key renaming happens on a shallow copy, so the same dict
    can safely be loaded more than once (e.g. under different adapter names).

    Args:
        model ([`PeftModel`]):
            The Peft model.
        peft_model_state_dict (`dict`):
            The state dict of the Peft model.
        adapter_name (`str`, *optional*, defaults to `"default"`):
            The name of the adapter whose state dict should be set.
        ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`):
            Whether to ignore weights whose shape does not match the model. Such weights are not loaded and a warning
            listing them is emitted.
        low_cpu_mem_usage (`bool`, `optional`, defaults to `False`):
            This argument must be `True` if the `model` was loaded with adapter weights on the meta device, e.g. after
            calling `inject_adapter_in_model` with `low_cpu_mem_usage=True`. Otherwise, leave it as `False`.

    Returns:
        The result of `model.load_state_dict(..., strict=False)`, i.e. the missing and unexpected keys.

    Raises:
        NotImplementedError: If the PEFT type of the adapter's config is not supported.
        ValueError: For VeRA checkpoints lacking required projections, or for old, unsupported OFT checkpoints.
    """
    config = model.peft_config[adapter_name]
    # Shallow copy: the key renaming below (auxiliary wrappers, VBLoRA logits reconstruction, dropping mismatched
    # keys) must not leak into the caller's dict. Previously the dict was aliased, so loading the same state dict a
    # second time failed with a KeyError because its keys had already been renamed.
    state_dict = dict(peft_model_state_dict)

    # handle auxiliary training wrappers such as ModulesToSaveWrapper and TrainableTokensWrapper by getting each of
    # them and translating saved state dict key (which does not include the adapter name) to loaded state dict key
    # (which includes the adapter name).
    for name, module in model.named_modules():
        if isinstance(module, AuxiliaryTrainingWrapper):
            # Not every module has a 1:1 mapping. ModulesToSaveWrapper, for example, removes the
            # `modules_to_save.{adapter_name}.` prefix. This prefix must be restored when loading the model from the
            # saved state dict which is why we fetch a load key map from the wrapper.
            key_map = module.adapter_state_dict_load_map(adapter_name)
            if name.startswith("_fsdp_wrapped_module."):
                # If FSDP is used, the state_dict is from the unwrapped model, which will result in a key mismatch if we
                # don't remove the FSDP-specific prefix
                name = name.removeprefix("_fsdp_wrapped_module.")
            for k in key_map:
                lookup_key = f"{name}.{k}"
                store_key = f"{name}.{key_map[k]}"
                # pop first, then store, so that an identity mapping (lookup_key == store_key) keeps the entry
                state_dict[store_key] = state_dict.pop(lookup_key)

    if config.is_prompt_learning or config.peft_type in (PeftType.ADAPTION_PROMPT, PeftType.XLORA):
        # these methods store their weights without an adapter name, so the keys are used as they are
        peft_model_state_dict = state_dict
    elif config.peft_type in PEFT_TYPE_TO_PREFIX_MAPPING:
        parameter_prefix = PEFT_TYPE_TO_PREFIX_MAPPING[config.peft_type]
        if config.peft_type == PeftType.VBLORA and config.save_only_topk_weights:
            num_vectors, _ = model.vblora_vector_bank[adapter_name].shape
            state_dict_keys = list(state_dict.keys())
            for k in state_dict_keys:
                # in save_only_topk_weights mode, only topk_indices and topk_weights are saved
                # note that topk_indices and topk_weights serve as an efficient representation of the logits
                # so we need to recover the logits from the topk_indices and topk_weights
                if "_topk_indices" in k:
                    v = state_dict[k].to(torch.long)
                    original_key = k.replace("_topk_indices", "")
                    weights_key = k.replace("_topk_indices", "_topk_weights")
                    # find the corresponding topk_weights from the state_dict
                    topk_weights = state_dict[weights_key]
                    # as we only save the first k-1 topk_weights, here we recover the last one
                    topk_weights = torch.cat([topk_weights, 1 - topk_weights.sum(-1, keepdim=True)], dim=-1)
                    # convert the weights to logits
                    topk_logits = torch.log(topk_weights)
                    # all non-topk positions get -inf logits
                    matrix = (
                        torch.zeros([*(topk_logits.shape[:-1]), num_vectors])
                        .fill_(float("-inf"))
                        .to(topk_logits.device)
                        .scatter(-1, v, topk_logits)
                    )
                    # add logits to the state_dict
                    state_dict[original_key] = matrix
                    # delete the topk_indices and topk_weights from the state_dict
                    del state_dict[k]
                    del state_dict[weights_key]

        peft_model_state_dict = _insert_adapter_name_into_state_dict(
            state_dict, adapter_name=adapter_name, parameter_prefix=parameter_prefix
        )

        if config.peft_type == PeftType.ADALORA:
            rank_pattern = config.rank_pattern
            if rank_pattern is not None:
                model.resize_modules_by_rank_pattern(rank_pattern, adapter_name)
        elif config.peft_type == PeftType.SHIRA:
            if platform.system() == "Windows":
                warnings.warn(
                    "Windows has issues saving integers into safetensors. Hence, we had converted shira_indices "
                    "to float32 before saving on Windows OS. The shira_indices will always be converted to integers "
                    "when loading."
                )
            for name, module in model.named_modules():
                if hasattr(module, "shira_indices"):
                    if f"{name}.shira_indices.{adapter_name}" in peft_model_state_dict:
                        # indices are set directly on the module and removed from what `load_state_dict` sees
                        shira_indices_values = peft_model_state_dict.pop(f"{name}.shira_indices.{adapter_name}")
                        # Convert shira_indices to int in case they were saved on a Windows OS and are being loaded
                        # on a Linux or a Mac-OS system. If they were saved in Linux or Mac-OS, they are already
                        # integers and the following will not affect anything.
                        module.shira_indices[adapter_name] = shira_indices_values.to(torch.int)
        elif config.peft_type == PeftType.VERA:
            if config.save_projection and "base_model.vera_A" not in peft_model_state_dict:
                raise ValueError(
                    "Specified to load vera_A and vera_B from state dictionary however they were not present!"
                )
            elif not config.save_projection and "base_model.vera_A" in peft_model_state_dict:
                warnings.warn(
                    "Specified to not load vera_A and vera_B from state dictionary however they are present in state"
                    " dictionary! Consider using them to ensure checkpoint loading is correct on all platforms using"
                    " `peft_config.save_projection = True`"
                )
            elif not config.save_projection:  # and no vera_A in state dictionary
                warnings.warn(
                    "Specified to not load vera_A and vera_B from state dictionary. This means we will be relying on"
                    " PRNG initialisation to restore these projections using `config.projection_prng_key`, which may"
                    " not be accurate on all system configurations."
                )
        elif config.peft_type == PeftType.LORA:
            # Here we take care of a refactor of DoRA which changed lora_magnitude_vector from a ParameterDict to a
            # ModuleDict with a DoraLayer instance. The old parameter is now the "weight" attribute of that layer.
            old_dora_suffix = f"lora_magnitude_vector.{adapter_name}"

            def renamed_dora_weights(k):
                if k.endswith(old_dora_suffix):
                    k = k + ".weight"
                return k

            peft_model_state_dict = {renamed_dora_weights(k): v for k, v in peft_model_state_dict.items()}
        elif config.peft_type == PeftType.OFT:
            if any(".oft_r." in key for key in peft_model_state_dict):
                raise ValueError(
                    "Trying to load old OFT checkpoint, which is no longer supported. Please install PEFT <= v0.15.2 to load it or train a new OFT adapter."
                )
    else:
        raise NotImplementedError(f"Loading a state dict is not supported for PEFT type {config.peft_type}.")

    peft_model_state_dict, mismatched_keys = _find_mismatched_keys(
        model, peft_model_state_dict, ignore_mismatched_sizes=ignore_mismatched_sizes
    )
    if low_cpu_mem_usage:
        # assign=True replaces the (meta-device) adapter tensors instead of copying into them
        load_result = model.load_state_dict(peft_model_state_dict, strict=False, assign=True)
        # ensure that the correct device is set
        for module in model.modules():
            if hasattr(module, "_move_adapter_to_device_of_base_layer"):
                module._move_adapter_to_device_of_base_layer(adapter_name)
    else:
        load_result = model.load_state_dict(peft_model_state_dict, strict=False)

    if config.is_prompt_learning:
        model.prompt_encoder[adapter_name].embedding.load_state_dict(
            {"weight": peft_model_state_dict["prompt_embeddings"]}, strict=True
        )

    if config.peft_type == PeftType.MULTITASK_PROMPT_TUNING:
        model.prompt_encoder[adapter_name].load_state_dict(peft_model_state_dict, strict=False)

    if mismatched_keys:
        # see https://github.com/huggingface/transformers/blob/09f9f566de83eef1f13ee83b5a1bbeebde5c80c1/src/transformers/modeling_utils.py#L4039
        mismatched_warning = "\n".join(
            [
                f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
                for key, shape1, shape2 in mismatched_keys
            ]
        )
        msg = (
            f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint "
            f"and are being ignored because you passed `ignore_mismatched_sizes=True`: {mismatched_warning}."
        )
        warnings.warn(msg)
    return load_result
# TODO: remove this function, use vanilla torch.load as soon as torch < 2.6.0 is no longer supported
def torch_load(*args, weights_only=True, **kwargs):
    """Thin wrapper around `torch.load` whose `weights_only` defaults to `True`.

    All other positional and keyword arguments are forwarded unchanged. The `True` default anticipates the switch of
    the default on the PyTorch side.
    """
    load_kwargs = {**kwargs, "weights_only": weights_only}
    return torch.load(*args, **load_kwargs)
def load_peft_weights(
    model_id: str, device: Optional[str] = None, key_mapping: Optional[dict[str, str]] = None, **hf_hub_download_kwargs
) -> dict:
    r"""
    A helper method to load the PEFT weights from the HuggingFace Hub or locally.

    Safetensors weights are preferred over pickled weights. The local directory (`model_id`, joined with `subfolder`
    if given) is checked first. If `HF_HUB_OFFLINE` is set, only the local Hub cache is consulted; otherwise the file
    is downloaded from the Hub.

    Args:
        model_id (`str`):
            The local path to the adapter weights or the name of the adapter to load from the HuggingFace Hub.
        device (`str`, *optional*):
            The device to load the weights onto. If `None`, it is inferred with `infer_device()`. When the device is
            MPS, safetensors weights are loaded onto CPU instead.
        key_mapping (`dict`, *optional*, defaults to `None`):
            Extra mapping of PEFT `state_dict` keys applied before loading the `state_dict`. When this mapping is
            applied, the PEFT-specific `"base_model.model"` prefix is removed beforehand and the adapter name (e.g.
            `"default"`) is not inserted yet. Only pass this argument if you know what you're doing.
        hf_hub_download_kwargs (`dict`):
            Additional arguments to pass to the `hf_hub_download` method when loading from the HuggingFace Hub.

    Returns:
        `dict`: The loaded (and, if `key_mapping` is given, remapped) adapter state dict.

    Raises:
        ValueError: If no weights file can be found on the Hub, or if `key_mapping` is given and a key does not start
            with `"base_model."`.
    """
    subfolder = hf_hub_download_kwargs.get("subfolder", None)
    path = os.path.join(model_id, subfolder) if subfolder is not None else model_id

    if device is None:
        device = infer_device()

    def get_hub_filename(use_safetensors=True):
        # Path of the weights file inside the Hub repo, for APIs without a `subfolder` argument (e.g. `file_exists`).
        # Repo paths always use forward slashes, so `os.path.join` (backslashes on Windows) must not be used here.
        weights_name = SAFETENSORS_WEIGHTS_NAME if use_safetensors else WEIGHTS_NAME
        return f"{subfolder}/{weights_name}" if subfolder else weights_name

    if "user_agent" not in hf_hub_download_kwargs:
        hf_hub_download_kwargs["user_agent"] = http_user_agent()

    if os.path.exists(os.path.join(path, SAFETENSORS_WEIGHTS_NAME)):
        filename = os.path.join(path, SAFETENSORS_WEIGHTS_NAME)
        use_safetensors = True
    elif os.path.exists(os.path.join(path, WEIGHTS_NAME)):
        filename = os.path.join(path, WEIGHTS_NAME)
        use_safetensors = False
    elif huggingface_hub.constants.HF_HUB_OFFLINE:
        # if in offline mode, check if we can find the adapter file in the local cache.
        # `hf_hub_download` prepends `subfolder` (passed through the kwargs) itself, so the bare file name must be
        # used here; a subfolder-prefixed name would resolve to `<subfolder>/<subfolder>/<file>`.
        hf_hub_download_kwargs.pop("local_files_only", None)
        try:
            filename = hf_hub_download(
                model_id, SAFETENSORS_WEIGHTS_NAME, local_files_only=True, **hf_hub_download_kwargs
            )
            use_safetensors = True
        except LocalEntryNotFoundError:
            # Could not find safetensors, try pickle. If this also fails, it's fine to let the error be raised here, as
            # it means that the user tried to load a non-cached model in offline mode.
            filename = hf_hub_download(model_id, WEIGHTS_NAME, local_files_only=True, **hf_hub_download_kwargs)
            use_safetensors = False
    else:
        token = hf_hub_download_kwargs.get("token", None)
        if token is None:
            token = hf_hub_download_kwargs.get("use_auth_token", None)

        has_remote_safetensors_file = file_exists(
            repo_id=model_id,
            filename=get_hub_filename(use_safetensors=True),
            revision=hf_hub_download_kwargs.get("revision", None),
            repo_type=hf_hub_download_kwargs.get("repo_type", None),
            token=token,
        )
        use_safetensors = has_remote_safetensors_file

        if has_remote_safetensors_file:
            # Priority 1: load safetensors weights
            filename = hf_hub_download(model_id, SAFETENSORS_WEIGHTS_NAME, **hf_hub_download_kwargs)
        else:
            try:
                filename = hf_hub_download(model_id, WEIGHTS_NAME, **hf_hub_download_kwargs)
            except EntryNotFoundError as exc:
                raise ValueError(
                    f"Can't find weights for {model_id} in {model_id} or in the Hugging Face Hub. "
                    f"Please check that the file {WEIGHTS_NAME} or {SAFETENSORS_WEIGHTS_NAME} is present at {model_id}."
                ) from exc

    if use_safetensors:
        # `device` is usually a plain string (e.g. the result of `infer_device()`), and a string never compares equal
        # to a `torch.device`, so compare the device *type* to detect MPS.
        device_type = device.type if isinstance(device, torch.device) else str(device).split(":")[0]
        # MPS targets load safetensors onto CPU (the intent of this special case; presumably direct MPS loading is
        # unsupported/unreliable in safetensors — TODO confirm whether still needed)
        load_device = "cpu" if device_type == "mps" else device
        adapters_weights = safe_load_file(filename, device=load_device)
    else:
        adapters_weights = torch_load(filename, map_location=torch.device(device))

    if not key_mapping:
        return adapters_weights

    # See discussion in https://github.com/huggingface/transformers/pull/38627
    # Remap adapter weight names according to the provided key_mapping.
    remapped_adapters_weights = {}
    for key, val in adapters_weights.items():
        if key.startswith("base_model.model."):
            prefix = "base_model.model."
        elif key.startswith("base_model."):
            prefix = "base_model."
        else:
            raise ValueError(
                "An error occurred while trying to load a PEFT state_dict with key_mapping. This should not "
                "happen. Please open an issue on https://github.com/huggingface/peft/issues and report the error."
            )

        key = key.removeprefix(prefix)  # the key map assumes that there is no prefix
        for pattern, replacement in key_mapping.items():
            key_new, n_replace = re.subn(pattern, replacement, key)
            # only the first pattern that matches is applied
            if n_replace > 0:
                key = key_new
                break

        key_with_prefix = f"{prefix}{key}"
        remapped_adapters_weights[key_with_prefix] = val

    return remapped_adapters_weights
|