import copy
import logging
import gguf
import torch
from modules.Device import Device
from modules.Model import ModelPatcher
from modules.Utilities import util
from modules.clip import Clip
from modules.cond import cast
# Constants for torch-compatible quantization types
TORCH_COMPATIBLE_QTYPES = {
None,
gguf.GGMLQuantizationType.F32,
gguf.GGMLQuantizationType.F16,
}
def is_torch_compatible(tensor: torch.Tensor) -> bool:
"""#### Check if a tensor is compatible with PyTorch operations.
#### Args:
- `tensor` (torch.Tensor): The tensor to check.
#### Returns:
- `bool`: Whether the tensor is torch-compatible.
"""
return (
tensor is None
or getattr(tensor, "tensor_type", None) in TORCH_COMPATIBLE_QTYPES
)
def is_quantized(tensor: torch.Tensor) -> bool:
"""#### Check if a tensor is quantized.
#### Args:
- `tensor` (torch.Tensor): The tensor to check.
#### Returns:
- `bool`: Whether the tensor is quantized.
"""
return not is_torch_compatible(tensor)
def dequantize(
data: torch.Tensor,
qtype: gguf.GGMLQuantizationType,
oshape: tuple,
dtype: torch.dtype = None,
) -> torch.Tensor:
"""#### Dequantize tensor back to usable shape/dtype.
#### Args:
- `data` (torch.Tensor): The quantized data.
- `qtype` (gguf.GGMLQuantizationType): The quantization type.
- `oshape` (tuple): The output shape.
- `dtype` (torch.dtype, optional): The output dtype. Defaults to None.
#### Returns:
- `torch.Tensor`: The dequantized tensor.
"""
# Get block size and type size for quantization format
block_size, type_size = gguf.GGML_QUANT_SIZES[qtype]
dequantize_blocks = dequantize_functions[qtype]
# Reshape data into blocks
rows = data.reshape((-1, data.shape[-1])).view(torch.uint8)
n_blocks = rows.numel() // type_size
blocks = rows.reshape((n_blocks, type_size))
# Dequantize blocks and reshape to target shape
blocks = dequantize_blocks(blocks, block_size, type_size, dtype)
return blocks.reshape(oshape)
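# Usage sketch (illustrative, hypothetical variable names): given a raw quantized buffer `raw`
# read from a GGUF tensor, the original logical shape `oshape`, and a target dtype:
#   weight = dequantize(raw, gguf.GGMLQuantizationType.Q8_0, oshape, dtype=torch.float16)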
def split_block_dims(blocks: torch.Tensor, *args) -> list:
"""#### Split blocks into dimensions.
#### Args:
- `blocks` (torch.Tensor): The blocks to split.
- `*args`: The dimensions to split into.
#### Returns:
- `list`: The split blocks.
"""
n_max = blocks.shape[1]
dims = list(args) + [n_max - sum(args)]
return torch.split(blocks, dims, dim=1)
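# Example: for Q8_0 blocks of shape (n_blocks, 34), split_block_dims(blocks, 2) yields the
# 2-byte scale columns and the remaining 32 quantized-value columns as two tensors.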
# Legacy Quantization Functions
def dequantize_blocks_Q8_0(
blocks: torch.Tensor, block_size: int, type_size: int, dtype: torch.dtype = None
) -> torch.Tensor:
"""#### Dequantize Q8_0 quantized blocks.
#### Args:
- `blocks` (torch.Tensor): The quantized blocks.
- `block_size` (int): The block size.
- `type_size` (int): The type size.
- `dtype` (torch.dtype, optional): The output dtype. Defaults to None.
#### Returns:
- `torch.Tensor`: The dequantized blocks.
"""
# Split blocks into scale and quantized values
d, x = split_block_dims(blocks, 2)
d = d.view(torch.float16).to(dtype)
x = x.view(torch.int8)
return d * x
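# Q8_0 layout for reference: gguf.GGML_QUANT_SIZES[Q8_0] == (32, 34), i.e. each 34-byte block
# holds a float16 scale `d` followed by 32 int8 values `x`; dequantization is simply d * x.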
# K Quants #
QK_K = 256
K_SCALE_SIZE = 12
# Mapping of quantization types to dequantization functions
dequantize_functions = {
gguf.GGMLQuantizationType.Q8_0: dequantize_blocks_Q8_0,
}
def dequantize_tensor(
tensor: torch.Tensor, dtype: torch.dtype = None, dequant_dtype: torch.dtype = None
) -> torch.Tensor:
"""#### Dequantize a potentially quantized tensor.
#### Args:
- `tensor` (torch.Tensor): The tensor to dequantize.
- `dtype` (torch.dtype, optional): Target dtype. Defaults to None.
- `dequant_dtype` (torch.dtype, optional): Intermediate dequantization dtype. Defaults to None.
#### Returns:
- `torch.Tensor`: The dequantized tensor.
"""
qtype = getattr(tensor, "tensor_type", None)
oshape = getattr(tensor, "tensor_shape", tensor.shape)
if qtype in TORCH_COMPATIBLE_QTYPES:
return tensor.to(dtype)
    elif qtype in dequantize_functions:
        dequant_dtype = dtype if dequant_dtype == "target" else dequant_dtype
        return dequantize(tensor.data, qtype, oshape, dtype=dequant_dtype).to(dtype)
    else:
        # Only the qtypes registered in dequantize_functions are supported here; fail loudly
        # instead of silently returning None for other quantization formats.
        raise NotImplementedError(f"Unsupported GGUF quantization type: {qtype}")
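# GGMLLayer.get_weight below calls `move_patch_to_device`, which is neither defined nor
# imported in this file. A minimal sketch is provided here, assuming patches are tensors
# (possibly nested inside lists/tuples) that only need to be moved to the compute device;
# replace it if the surrounding project supplies its own implementation.
def move_patch_to_device(item, device):
    """#### Recursively move patch tensors to a device.
    #### Args:
    - `item`: A tensor, or a list/tuple of (possibly nested) tensors.
    - `device` (torch.device): The target device.
    #### Returns:
    - The same structure with all contained tensors moved to `device`.
    """
    if isinstance(item, torch.Tensor):
        return item.to(device, non_blocking=True)
    if isinstance(item, tuple):
        return tuple(move_patch_to_device(x, device) for x in item)
    if isinstance(item, list):
        return [move_patch_to_device(x, device) for x in item]
    return item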
class GGMLLayer(torch.nn.Module):
"""#### Base class for GGML quantized layers.
Handles dynamic dequantization of weights during forward pass.
"""
comfy_cast_weights: bool = True
dequant_dtype: torch.dtype = None
patch_dtype: torch.dtype = None
torch_compatible_tensor_types: set = {
None,
gguf.GGMLQuantizationType.F32,
gguf.GGMLQuantizationType.F16,
}
def is_ggml_quantized(
self, *, weight: torch.Tensor = None, bias: torch.Tensor = None
) -> bool:
"""#### Check if layer weights are GGML quantized.
#### Args:
- `weight` (torch.Tensor, optional): Weight tensor to check. Defaults to self.weight.
- `bias` (torch.Tensor, optional): Bias tensor to check. Defaults to self.bias.
#### Returns:
- `bool`: Whether weights are quantized.
"""
if weight is None:
weight = self.weight
if bias is None:
bias = self.bias
return is_quantized(weight) or is_quantized(bias)
def _load_from_state_dict(
self, state_dict: dict, prefix: str, *args, **kwargs
) -> None:
"""#### Load quantized weights from state dict.
#### Args:
- `state_dict` (dict): State dictionary.
- `prefix` (str): Key prefix.
- `*args`: Additional arguments.
- `**kwargs`: Additional keyword arguments.
"""
weight = state_dict.get(f"{prefix}weight")
bias = state_dict.get(f"{prefix}bias")
# Use modified loader for quantized or linear layers
if self.is_ggml_quantized(weight=weight, bias=bias) or isinstance(
self, torch.nn.Linear
):
return self.ggml_load_from_state_dict(state_dict, prefix, *args, **kwargs)
return super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
def ggml_load_from_state_dict(
self,
state_dict: dict,
prefix: str,
local_metadata: dict,
strict: bool,
missing_keys: list,
unexpected_keys: list,
error_msgs: list,
) -> None:
"""#### Load GGML quantized weights from state dict.
#### Args:
- `state_dict` (dict): State dictionary.
- `prefix` (str): Key prefix.
- `local_metadata` (dict): Local metadata.
- `strict` (bool): Strict loading mode.
- `missing_keys` (list): Keys missing from state dict.
- `unexpected_keys` (list): Unexpected keys found.
- `error_msgs` (list): Error messages.
"""
prefix_len = len(prefix)
for k, v in state_dict.items():
if k[prefix_len:] == "weight":
self.weight = torch.nn.Parameter(v, requires_grad=False)
elif k[prefix_len:] == "bias" and v is not None:
self.bias = torch.nn.Parameter(v, requires_grad=False)
            else:
                # Keys present in the state dict that this layer does not consume are
                # "unexpected" rather than "missing" under PyTorch's loading convention.
                unexpected_keys.append(k)
def _save_to_state_dict(self, *args, **kwargs) -> None:
"""#### Save layer state to state dict.
#### Args:
- `*args`: Additional arguments.
- `**kwargs`: Additional keyword arguments.
"""
if self.is_ggml_quantized():
return self.ggml_save_to_state_dict(*args, **kwargs)
return super()._save_to_state_dict(*args, **kwargs)
def ggml_save_to_state_dict(
self, destination: dict, prefix: str, keep_vars: bool
) -> None:
"""#### Save GGML layer state to state dict.
#### Args:
- `destination` (dict): Destination dictionary.
- `prefix` (str): Key prefix.
- `keep_vars` (bool): Whether to keep variables.
"""
# Create fake tensors for VRAM estimation
weight = torch.zeros_like(self.weight, device=torch.device("meta"))
destination[prefix + "weight"] = weight
if self.bias is not None:
bias = torch.zeros_like(self.bias, device=torch.device("meta"))
destination[prefix + "bias"] = bias
return
def get_weight(self, tensor: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
"""#### Get dequantized weight tensor.
#### Args:
- `tensor` (torch.Tensor): Input tensor.
- `dtype` (torch.dtype): Target dtype.
#### Returns:
- `torch.Tensor`: Dequantized tensor.
"""
if tensor is None:
return
# Consolidate and load patches to GPU asynchronously
patch_list = []
device = tensor.device
for function, patches, key in getattr(tensor, "patches", []):
patch_list += move_patch_to_device(patches, device)
# Dequantize tensor while patches load
weight = dequantize_tensor(tensor, dtype, self.dequant_dtype)
# Apply patches
if patch_list:
if self.patch_dtype is None:
weight = function(patch_list, weight, key)
else:
# For testing, may degrade image quality
patch_dtype = (
dtype if self.patch_dtype == "target" else self.patch_dtype
)
weight = function(patch_list, weight, key, patch_dtype)
return weight
def cast_bias_weight(
self,
input: torch.Tensor = None,
dtype: torch.dtype = None,
device: torch.device = None,
bias_dtype: torch.dtype = None,
) -> tuple:
"""#### Cast layer weights and bias to target dtype/device.
#### Args:
- `input` (torch.Tensor, optional): Input tensor for type/device inference.
- `dtype` (torch.dtype, optional): Target dtype.
- `device` (torch.device, optional): Target device.
- `bias_dtype` (torch.dtype, optional): Target bias dtype.
#### Returns:
- `tuple`: (cast_weight, cast_bias)
"""
if input is not None:
if dtype is None:
dtype = getattr(input, "dtype", torch.float32)
if bias_dtype is None:
bias_dtype = dtype
if device is None:
device = input.device
bias = None
non_blocking = Device.device_supports_non_blocking(device)
if self.bias is not None:
bias = self.get_weight(self.bias.to(device), dtype)
bias = cast.cast_to(
bias, bias_dtype, device, non_blocking=non_blocking, copy=False
)
weight = self.get_weight(self.weight.to(device), dtype)
weight = cast.cast_to(
weight, dtype, device, non_blocking=non_blocking, copy=False
)
return weight, bias
def forward_comfy_cast_weights(
self, input: torch.Tensor, *args, **kwargs
) -> torch.Tensor:
"""#### Forward pass with weight casting.
#### Args:
- `input` (torch.Tensor): Input tensor.
- `*args`: Additional arguments.
- `**kwargs`: Additional keyword arguments.
#### Returns:
- `torch.Tensor`: Output tensor.
"""
if self.is_ggml_quantized():
return self.forward_ggml_cast_weights(input, *args, **kwargs)
return super().forward_comfy_cast_weights(input, *args, **kwargs)
class GGMLOps(cast.manual_cast):
"""
Dequantize weights on the fly before doing the compute
"""
class Linear(GGMLLayer, cast.manual_cast.Linear):
def __init__(
self, in_features, out_features, bias=True, device=None, dtype=None
):
"""
Initialize the Linear layer.
Args:
in_features (int): Number of input features.
out_features (int): Number of output features.
bias (bool, optional): If set to False, the layer will not learn an additive bias. Defaults to True.
device (torch.device, optional): The device to store the layer's parameters. Defaults to None.
dtype (torch.dtype, optional): The data type of the layer's parameters. Defaults to None.
"""
torch.nn.Module.__init__(self)
# TODO: better workaround for reserved memory spike on windows
# Issue is with `torch.empty` still reserving the full memory for the layer
# Windows doesn't over-commit memory so without this 24GB+ of pagefile is used
self.in_features = in_features
self.out_features = out_features
self.weight = None
self.bias = None
def forward_ggml_cast_weights(self, input: torch.Tensor) -> torch.Tensor:
"""
Forward pass with GGML cast weights.
Args:
input (torch.Tensor): The input tensor.
Returns:
torch.Tensor: The output tensor.
"""
weight, bias = self.cast_bias_weight(input)
return torch.nn.functional.linear(input, weight, bias)
class Embedding(GGMLLayer, cast.manual_cast.Embedding):
def forward_ggml_cast_weights(
self, input: torch.Tensor, out_dtype: torch.dtype = None
) -> torch.Tensor:
"""
Forward pass with GGML cast weights for embedding.
Args:
input (torch.Tensor): The input tensor.
out_dtype (torch.dtype, optional): The output data type. Defaults to None.
Returns:
torch.Tensor: The output tensor.
"""
output_dtype = out_dtype
if (
self.weight.dtype == torch.float16
or self.weight.dtype == torch.bfloat16
):
out_dtype = None
weight, _bias = self.cast_bias_weight(
self, device=input.device, dtype=out_dtype
)
return torch.nn.functional.embedding(
input,
weight,
self.padding_idx,
self.max_norm,
self.norm_type,
self.scale_grad_by_freq,
self.sparse,
).to(dtype=output_dtype)
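# Note: GGMLOps is wired in through model_options["custom_operations"] (see UnetLoaderGGUF and
# CLIPLoaderGGUF below), so models are built with these Linear/Embedding layers and weights stay
# quantized until they are dequantized in the forward pass.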
def gguf_sd_loader_get_orig_shape(
reader: gguf.GGUFReader, tensor_name: str
) -> torch.Size:
"""#### Get the original shape of a tensor from a GGUF reader.
#### Args:
- `reader` (gguf.GGUFReader): The GGUF reader.
- `tensor_name` (str): The name of the tensor.
#### Returns:
- `torch.Size`: The original shape of the tensor.
"""
field_key = f"comfy.gguf.orig_shape.{tensor_name}"
field = reader.get_field(field_key)
if field is None:
return None
# Has original shape metadata, so we try to decode it.
if (
len(field.types) != 2
or field.types[0] != gguf.GGUFValueType.ARRAY
or field.types[1] != gguf.GGUFValueType.INT32
):
raise TypeError(
f"Bad original shape metadata for {field_key}: Expected ARRAY of INT32, got {field.types}"
)
return torch.Size(tuple(int(field.parts[part_idx][0]) for part_idx in field.data))
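# Example (hypothetical tensor name): for "input_blocks.0.0.weight" this looks up the GGUF field
# "comfy.gguf.orig_shape.input_blocks.0.0.weight", which must be an INT32 array holding the
# original (pre-quantization) tensor shape.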
class GGMLTensor(torch.Tensor):
"""
Main tensor-like class for storing quantized weights
"""
def __init__(self, *args, tensor_type, tensor_shape, patches=[], **kwargs):
"""
Initialize the GGMLTensor.
Args:
*args: Variable length argument list.
tensor_type: The type of the tensor.
tensor_shape: The shape of the tensor.
patches (list, optional): List of patches. Defaults to [].
**kwargs: Arbitrary keyword arguments.
"""
super().__init__()
self.tensor_type = tensor_type
self.tensor_shape = tensor_shape
self.patches = patches
def __new__(cls, *args, tensor_type, tensor_shape, patches=[], **kwargs):
"""
Create a new instance of GGMLTensor.
Args:
*args: Variable length argument list.
tensor_type: The type of the tensor.
tensor_shape: The shape of the tensor.
patches (list, optional): List of patches. Defaults to [].
**kwargs: Arbitrary keyword arguments.
Returns:
GGMLTensor: A new instance of GGMLTensor.
"""
return super().__new__(cls, *args, **kwargs)
def to(self, *args, **kwargs):
"""
Convert the tensor to a specified device and/or dtype.
Args:
*args: Variable length argument list.
**kwargs: Arbitrary keyword arguments.
Returns:
GGMLTensor: The converted tensor.
"""
new = super().to(*args, **kwargs)
new.tensor_type = getattr(self, "tensor_type", None)
new.tensor_shape = getattr(self, "tensor_shape", new.data.shape)
new.patches = getattr(self, "patches", []).copy()
return new
def clone(self, *args, **kwargs):
"""
Clone the tensor.
Args:
*args: Variable length argument list.
**kwargs: Arbitrary keyword arguments.
Returns:
GGMLTensor: The cloned tensor.
"""
return self
def detach(self, *args, **kwargs):
"""
Detach the tensor from the computation graph.
Args:
*args: Variable length argument list.
**kwargs: Arbitrary keyword arguments.
Returns:
GGMLTensor: The detached tensor.
"""
return self
def copy_(self, *args, **kwargs):
"""
Copy the values from another tensor into this tensor.
Args:
*args: Variable length argument list.
**kwargs: Arbitrary keyword arguments.
Returns:
GGMLTensor: The tensor with copied values.
"""
try:
return super().copy_(*args, **kwargs)
except Exception as e:
print(f"ignoring 'copy_' on tensor: {e}")
def __deepcopy__(self, *args, **kwargs):
"""
Create a deep copy of the tensor.
Args:
*args: Variable length argument list.
**kwargs: Arbitrary keyword arguments.
Returns:
GGMLTensor: The deep copied tensor.
"""
new = super().__deepcopy__(*args, **kwargs)
new.tensor_type = getattr(self, "tensor_type", None)
new.tensor_shape = getattr(self, "tensor_shape", new.data.shape)
new.patches = getattr(self, "patches", []).copy()
return new
@property
def shape(self):
"""
Get the shape of the tensor.
Returns:
torch.Size: The shape of the tensor.
"""
if not hasattr(self, "tensor_shape"):
self.tensor_shape = self.size()
return self.tensor_shape
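# Usage sketch (illustrative): wrapping raw GGUF data while remembering the logical shape:
#   t = GGMLTensor(torch.from_numpy(tensor.data), tensor_type=tensor.tensor_type,
#                  tensor_shape=torch.Size((4096, 4096)))
#   t.shape   -> the logical/original shape (4096, 4096)
#   t.size()  -> the shape of the underlying quantized byte data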
def gguf_sd_loader(path: str, handle_prefix: str = "model.diffusion_model."):
"""#### Load a GGUF file into a state dict.
#### Args:
- `path` (str): The path to the GGUF file.
- `handle_prefix` (str, optional): The prefix to handle. Defaults to "model.diffusion_model.".
#### Returns:
- `dict`: The loaded state dict.
"""
reader = gguf.GGUFReader(path)
# filter and strip prefix
has_prefix = False
if handle_prefix is not None:
prefix_len = len(handle_prefix)
tensor_names = set(tensor.name for tensor in reader.tensors)
has_prefix = any(s.startswith(handle_prefix) for s in tensor_names)
tensors = []
for tensor in reader.tensors:
sd_key = tensor_name = tensor.name
if has_prefix:
if not tensor_name.startswith(handle_prefix):
continue
sd_key = tensor_name[prefix_len:]
tensors.append((sd_key, tensor))
# detect and verify architecture
    compat = None  # never reassigned below, so the stable-diffusion.cpp SDXL workaround further down is inert
arch_str = None
arch_field = reader.get_field("general.architecture")
if arch_field is not None:
if (
len(arch_field.types) != 1
or arch_field.types[0] != gguf.GGUFValueType.STRING
):
raise TypeError(
f"Bad type for GGUF general.architecture key: expected string, got {arch_field.types!r}"
)
arch_str = str(arch_field.parts[arch_field.data[-1]], encoding="utf-8")
        if arch_str not in {"flux", "sd1", "sdxl", "t5", "t5encoder"}:
            raise ValueError(
                f"Unexpected architecture type in GGUF file, expected one of flux, sd1, sdxl, t5, t5encoder but got {arch_str!r}"
            )
# main loading loop
state_dict = {}
qtype_dict = {}
for sd_key, tensor in tensors:
tensor_name = tensor.name
tensor_type_str = str(tensor.tensor_type)
torch_tensor = torch.from_numpy(tensor.data) # mmap
shape = gguf_sd_loader_get_orig_shape(reader, tensor_name)
if shape is None:
shape = torch.Size(tuple(int(v) for v in reversed(tensor.shape)))
# Workaround for stable-diffusion.cpp SDXL detection.
if compat == "sd.cpp" and arch_str == "sdxl":
if any(
[
tensor_name.endswith(x)
for x in (".proj_in.weight", ".proj_out.weight")
]
):
while len(shape) > 2 and shape[-1] == 1:
shape = shape[:-1]
# add to state dict
if tensor.tensor_type in {
gguf.GGMLQuantizationType.F32,
gguf.GGMLQuantizationType.F16,
}:
torch_tensor = torch_tensor.view(*shape)
state_dict[sd_key] = GGMLTensor(
torch_tensor, tensor_type=tensor.tensor_type, tensor_shape=shape
)
qtype_dict[tensor_type_str] = qtype_dict.get(tensor_type_str, 0) + 1
# sanity check debug print
print("\nggml_sd_loader:")
for k, v in qtype_dict.items():
print(f" {k:30}{v:3}")
return state_dict
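# Usage sketch (hypothetical file name): loading a quantized UNet into a dict of GGMLTensors:
#   sd = gguf_sd_loader("./_internal/unet/flux1-dev-Q8_0.gguf")
#   sd["input_blocks.0.0.weight"].tensor_type  # e.g. gguf.GGMLQuantizationType.Q8_0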
class GGUFModelPatcher(ModelPatcher.ModelPatcher):
patch_on_device = False
def unpatch_model(self, device_to=None, unpatch_weights=True):
"""
Unpatch the model.
Args:
device_to (torch.device, optional): The device to move the model to. Defaults to None.
unpatch_weights (bool, optional): Whether to unpatch the weights. Defaults to True.
Returns:
GGUFModelPatcher: The unpatched model.
"""
if unpatch_weights:
for p in self.model.parameters():
if is_torch_compatible(p):
continue
patches = getattr(p, "patches", [])
if len(patches) > 0:
p.patches = []
self.object_patches = {}
# TODO: Find another way to not unload after patches
return super().unpatch_model(
device_to=device_to, unpatch_weights=unpatch_weights
)
mmap_released = False
def load(self, *args, force_patch_weights=False, **kwargs):
"""
Load the model.
Args:
*args: Variable length argument list.
force_patch_weights (bool, optional): Whether to force patch weights. Defaults to False.
**kwargs: Arbitrary keyword arguments.
"""
        # Always force weight patching (even in lowvram mode) so GGML patches are applied.
        super().load(*args, force_patch_weights=True, **kwargs)
# make sure nothing stays linked to mmap after first load
if not self.mmap_released:
linked = []
if kwargs.get("lowvram_model_memory", 0) > 0:
for n, m in self.model.named_modules():
if hasattr(m, "weight"):
device = getattr(m.weight, "device", None)
if device == self.offload_device:
linked.append((n, m))
continue
if hasattr(m, "bias"):
device = getattr(m.bias, "device", None)
if device == self.offload_device:
linked.append((n, m))
continue
if linked:
print(f"Attempting to release mmap ({len(linked)})")
for n, m in linked:
# TODO: possible to OOM, find better way to detach
m.to(self.load_device).to(self.offload_device)
self.mmap_released = True
    def add_object_patch(self, name, obj):
        """
        Register an object patch by name.
        Args:
            name (str): The attribute path of the object to patch.
            obj: The replacement object.
        """
        self.object_patches[name] = obj
def clone(self, *args, **kwargs):
"""
Clone the model patcher.
Args:
*args: Variable length argument list.
**kwargs: Arbitrary keyword arguments.
Returns:
GGUFModelPatcher: The cloned model patcher.
"""
n = GGUFModelPatcher(
self.model,
self.load_device,
self.offload_device,
self.size,
weight_inplace_update=self.weight_inplace_update,
)
n.patches = {}
for k in self.patches:
n.patches[k] = self.patches[k][:]
n.patches_uuid = self.patches_uuid
n.object_patches = self.object_patches.copy()
n.model_options = copy.deepcopy(self.model_options)
n.backup = self.backup
n.object_patches_backup = self.object_patches_backup
n.patch_on_device = getattr(self, "patch_on_device", False)
return n
class UnetLoaderGGUF:
def load_unet(
self,
unet_name: str,
dequant_dtype: str = None,
patch_dtype: str = None,
patch_on_device: bool = None,
) -> tuple:
"""
Load the UNet model.
Args:
unet_name (str): The name of the UNet model.
dequant_dtype (str, optional): The dequantization data type. Defaults to None.
patch_dtype (str, optional): The patch data type. Defaults to None.
patch_on_device (bool, optional): Whether to patch on device. Defaults to None.
Returns:
tuple: The loaded model.
"""
ops = GGMLOps()
if dequant_dtype in ("default", None):
ops.Linear.dequant_dtype = None
elif dequant_dtype in ["target"]:
ops.Linear.dequant_dtype = dequant_dtype
else:
ops.Linear.dequant_dtype = getattr(torch, dequant_dtype)
if patch_dtype in ("default", None):
ops.Linear.patch_dtype = None
elif patch_dtype in ["target"]:
ops.Linear.patch_dtype = patch_dtype
else:
ops.Linear.patch_dtype = getattr(torch, patch_dtype)
unet_path = "./_internal/unet/" + unet_name
sd = gguf_sd_loader(unet_path)
model = ModelPatcher.load_diffusion_model_state_dict(
sd, model_options={"custom_operations": ops}
)
if model is None:
logging.error("ERROR UNSUPPORTED UNET {}".format(unet_path))
raise RuntimeError(
"ERROR: Could not detect model type of: {}".format(unet_path)
)
model = GGUFModelPatcher.clone(model)
model.patch_on_device = patch_on_device
return (model,)
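# Usage sketch (hypothetical file name, which must exist under ./_internal/unet/):
#   model, = UnetLoaderGGUF().load_unet("flux1-dev-Q8_0.gguf", dequant_dtype="target")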
clip_sd_map = {
"enc.": "encoder.",
".blk.": ".block.",
"token_embd": "shared",
"output_norm": "final_layer_norm",
"attn_q": "layer.0.SelfAttention.q",
"attn_k": "layer.0.SelfAttention.k",
"attn_v": "layer.0.SelfAttention.v",
"attn_o": "layer.0.SelfAttention.o",
"attn_norm": "layer.0.layer_norm",
"attn_rel_b": "layer.0.SelfAttention.relative_attention_bias",
"ffn_up": "layer.1.DenseReluDense.wi_1",
"ffn_down": "layer.1.DenseReluDense.wo",
"ffn_gate": "layer.1.DenseReluDense.wi_0",
"ffn_norm": "layer.1.layer_norm",
}
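# Example: applying clip_sd_map renames the T5 GGUF key "enc.blk.0.attn_q.weight" to
# "encoder.block.0.layer.0.SelfAttention.q.weight".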
clip_name_dict = {
"stable_diffusion": Clip.CLIPType.STABLE_DIFFUSION,
"sdxl": Clip.CLIPType.STABLE_DIFFUSION,
"sd3": Clip.CLIPType.SD3,
"flux": Clip.CLIPType.FLUX,
}
def gguf_clip_loader(path: str) -> dict:
"""#### Load a CLIP model from a GGUF file.
#### Args:
- `path` (str): The path to the GGUF file.
#### Returns:
- `dict`: The loaded CLIP model.
"""
raw_sd = gguf_sd_loader(path)
assert "enc.blk.23.ffn_up.weight" in raw_sd, "Invalid Text Encoder!"
sd = {}
for k, v in raw_sd.items():
for s, d in clip_sd_map.items():
k = k.replace(s, d)
sd[k] = v
return sd
class CLIPLoaderGGUF:
def load_data(self, ckpt_paths: list) -> list:
"""
Load data from checkpoint paths.
Args:
ckpt_paths (list): List of checkpoint paths.
Returns:
list: List of loaded data.
"""
clip_data = []
for p in ckpt_paths:
if p.endswith(".gguf"):
clip_data.append(gguf_clip_loader(p))
else:
sd = util.load_torch_file(p, safe_load=True)
clip_data.append(
{
k: GGMLTensor(
v,
tensor_type=gguf.GGMLQuantizationType.F16,
tensor_shape=v.shape,
)
for k, v in sd.items()
}
)
return clip_data
def load_patcher(self, clip_paths: list, clip_type: str, clip_data: list) -> Clip:
"""
Load the model patcher.
Args:
clip_paths (list): List of clip paths.
clip_type (str): The type of the clip.
clip_data (list): List of clip data.
Returns:
Clip: The loaded clip.
"""
clip = Clip.load_text_encoder_state_dicts(
clip_type=clip_type,
state_dicts=clip_data,
model_options={
"custom_operations": GGMLOps,
"initial_device": Device.text_encoder_offload_device(),
},
embedding_directory="models/embeddings",
)
clip.patcher = GGUFModelPatcher.clone(clip.patcher)
# for some reason this is just missing in some SAI checkpoints
if getattr(clip.cond_stage_model, "clip_l", None) is not None:
if (
getattr(
clip.cond_stage_model.clip_l.transformer.text_projection.weight,
"tensor_shape",
None,
)
is None
):
clip.cond_stage_model.clip_l.transformer.text_projection = (
cast.manual_cast.Linear(768, 768)
)
if getattr(clip.cond_stage_model, "clip_g", None) is not None:
if (
getattr(
clip.cond_stage_model.clip_g.transformer.text_projection.weight,
"tensor_shape",
None,
)
is None
):
clip.cond_stage_model.clip_g.transformer.text_projection = (
cast.manual_cast.Linear(1280, 1280)
)
return clip
class DualCLIPLoaderGGUF(CLIPLoaderGGUF):
def load_clip(self, clip_name1: str, clip_name2: str, type: str) -> tuple:
"""
Load dual clips.
Args:
clip_name1 (str): The name of the first clip.
clip_name2 (str): The name of the second clip.
type (str): The type of the clip.
Returns:
tuple: The loaded clips.
"""
clip_path1 = "./_internal/clip/" + clip_name1
clip_path2 = "./_internal/clip/" + clip_name2
clip_paths = (clip_path1, clip_path2)
clip_type = clip_name_dict.get(type, Clip.CLIPType.STABLE_DIFFUSION)
return (self.load_patcher(clip_paths, clip_type, self.load_data(clip_paths)),)
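# Usage sketch (hypothetical file names under ./_internal/clip/): loading a Flux text-encoder pair:
#   clip, = DualCLIPLoaderGGUF().load_clip("clip_l.safetensors", "t5xxl-Q8_0.gguf", "flux")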
class CLIPTextEncodeFlux:
def encode(
self,
clip: Clip,
clip_l: str,
t5xxl: str,
guidance: str,
flux_enabled: bool = False,
) -> tuple:
"""
Encode text using CLIP and T5XXL.
Args:
clip (Clip): The clip object.
clip_l (str): The clip text.
t5xxl (str): The T5XXL text.
            guidance (str): The guidance value stored in the conditioning output.
flux_enabled (bool, optional): Whether flux is enabled. Defaults to False.
Returns:
tuple: The encoded text.
"""
tokens = clip.tokenize(clip_l)
tokens["t5xxl"] = clip.tokenize(t5xxl)["t5xxl"]
output = clip.encode_from_tokens(
tokens, return_pooled=True, return_dict=True, flux_enabled=flux_enabled
)
cond = output.pop("cond")
output["guidance"] = guidance
return ([[cond, output]],)
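# Usage sketch (illustrative prompt and guidance value):
#   conditioning, = CLIPTextEncodeFlux().encode(clip, "a photo of a cat", "a photo of a cat",
#                                               guidance=3.5, flux_enabled=True)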
class ConditioningZeroOut:
def zero_out(self, conditioning: list) -> list:
"""
Zero out the conditioning.
Args:
conditioning (list): The conditioning list.
Returns:
list: The zeroed out conditioning.
"""
c = []
for t in conditioning:
d = t[1].copy()
pooled_output = d.get("pooled_output", None)
if pooled_output is not None:
d["pooled_output"] = torch.zeros_like(pooled_output)
n = [torch.zeros_like(t[0]), d]
c.append(n)
return (c,)
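# Usage sketch: deriving a zeroed ("negative") conditioning from an existing one:
#   negative, = ConditioningZeroOut().zero_out(conditioning)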