Spaces:
Running
on
Zero
Running
on
Zero
| import logging | |
| from typing import Optional | |
| import torch | |
| import comfy.model_management | |
| from .base import WeightAdapterBase, weight_decompose | |
class LoHaAdapter(WeightAdapterBase):
    """Weight adapter registered under the name ``"loha"``.

    The adapter stores two factor pairs (``hada_w1_a``/``hada_w1_b`` and
    ``hada_w2_a``/``hada_w2_b``), optional ``hada_t1``/``hada_t2`` tensors,
    an alpha value and an optional DoRA scale.  The weight delta it applies
    is the element-wise product of the two reconstructed matrices.
    """

    name = "loha"

    def __init__(self, loaded_keys, weights):
        """Store the consumed state-dict keys and the packed weight tuple.

        ``weights`` is indexed positionally by ``calculate_weight`` as
        ``(w1_a, w1_b, alpha, w2_a, w2_b, t1, t2, dora_scale)``.
        """
        self.loaded_keys = loaded_keys
        self.weights = weights

    @classmethod
    def load(
        cls,
        x: str,
        lora: dict[str, torch.Tensor],
        alpha: Optional[float],
        dora_scale: Optional[torch.Tensor],
        loaded_keys: Optional[set[str]] = None,
    ) -> Optional["LoHaAdapter"]:
        """Build an adapter for key prefix ``x`` from the state dict ``lora``.

        Args:
            x: Key prefix; tensors are looked up as ``"{x}.hada_w1_a"`` etc.
            lora: Mapping from key names to tensors.
            alpha: Stored as-is in the weight tuple; may be ``None``.
            dora_scale: Stored as-is in the weight tuple; may be ``None``.
            loaded_keys: Set that receives every key consumed here.  A new
                set is created when ``None`` is passed.

        Returns:
            A new adapter, or ``None`` when ``"{x}.hada_w1_a"`` is not
            present in ``lora``.

        Raises:
            KeyError: If ``hada_w1_a`` is present but one of the companion
                tensors (``hada_w1_b``, ``hada_w2_a``, ``hada_w2_b``, or
                ``hada_t2`` when ``hada_t1`` exists) is missing.
        """
        if loaded_keys is None:
            loaded_keys = set()

        hada_w1_a_name = "{}.hada_w1_a".format(x)
        hada_w1_b_name = "{}.hada_w1_b".format(x)
        hada_w2_a_name = "{}.hada_w2_a".format(x)
        hada_w2_b_name = "{}.hada_w2_b".format(x)
        hada_t1_name = "{}.hada_t1".format(x)
        hada_t2_name = "{}.hada_t2".format(x)

        # The presence of hada_w1_a decides whether this prefix is handled
        # by this adapter at all.
        if hada_w1_a_name not in lora:
            return None

        hada_t1 = None
        hada_t2 = None
        if hada_t1_name in lora:
            # t1 and t2 are expected to come as a pair.
            hada_t1 = lora[hada_t1_name]
            hada_t2 = lora[hada_t2_name]
            loaded_keys.add(hada_t1_name)
            loaded_keys.add(hada_t2_name)

        weights = (
            lora[hada_w1_a_name],
            lora[hada_w1_b_name],
            alpha,
            lora[hada_w2_a_name],
            lora[hada_w2_b_name],
            hada_t1,
            hada_t2,
            dora_scale,
        )
        loaded_keys.add(hada_w1_a_name)
        loaded_keys.add(hada_w1_b_name)
        loaded_keys.add(hada_w2_a_name)
        loaded_keys.add(hada_w2_b_name)
        return cls(loaded_keys, weights)

    def calculate_weight(
        self,
        weight,
        key,
        strength,
        strength_model,
        offset,
        function,
        intermediate_dtype=torch.float32,
        original_weight=None,
    ):
        """Apply the stored delta to ``weight`` and return the result.

        Args:
            weight: Tensor to patch.  Without a DoRA scale it is updated
                in place via ``+=``.
            key: Only used in the error log message.
            strength: Multiplier applied to the delta.
            strength_model: Not used by this adapter.
            offset: Not used by this adapter.
            function: Callable applied to the scaled delta before it is
                added (and forwarded to ``weight_decompose``).
            intermediate_dtype: dtype the factors are cast to for the
                computation.
            original_weight: Not used by this adapter.

        Returns:
            The patched weight.  If building or applying the delta raises,
            the error is logged and ``weight`` is returned as it was at
            that point.
        """
        v = self.weights
        w1a = v[0]
        w1b = v[1]
        # alpha is normalised by the first dimension of w1_b
        # (presumably the rank of the decomposition -- confirm).
        if v[2] is not None:
            alpha = v[2] / w1b.shape[0]
        else:
            alpha = 1.0
        w2a = v[3]
        w2b = v[4]
        dora_scale = v[7]

        def to_device(tensor):
            # Every factor is moved to the weight's device and cast to the
            # intermediate dtype before any arithmetic.
            return comfy.model_management.cast_to_device(
                tensor, weight.device, intermediate_dtype
            )

        if v[5] is not None:  # cp decomposition
            t1 = v[5]
            t2 = v[6]
            m1 = torch.einsum(
                'i j k l, j r, i p -> p r k l',
                to_device(t1),
                to_device(w1b),
                to_device(w1a),
            )
            m2 = torch.einsum(
                'i j k l, j r, i p -> p r k l',
                to_device(t2),
                to_device(w2b),
                to_device(w2a),
            )
        else:
            m1 = torch.mm(to_device(w1a), to_device(w1b))
            m2 = torch.mm(to_device(w2a), to_device(w2b))

        try:
            # Element-wise product of the two reconstructions, reshaped to
            # the layout of the weight being patched.
            lora_diff = (m1 * m2).reshape(weight.shape)
            if dora_scale is not None:
                weight = weight_decompose(
                    dora_scale,
                    weight,
                    lora_diff,
                    alpha,
                    strength,
                    intermediate_dtype,
                    function,
                )
            else:
                weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
        except Exception as e:
            # Best effort: a failing patch is logged and skipped rather
            # than aborting the caller.
            logging.error("ERROR {} {} {}".format(self.name, key, e))
        return weight