# Uploaded by Aatricks via huggingface_hub ("Upload folder using huggingface_hub"),
# commit a1693c9.
import math
import os
import torch
from modules.Attention import Attention
from modules.Device import Device
from modules.SD15 import SDClip, SDToken
from modules.cond import cast
from transformers import T5TokenizerFast
def _gelu_tanh(a: torch.Tensor) -> torch.Tensor:
    """Tanh-approximated GELU, registered under the name "gelu_pytorch_tanh"."""
    return torch.nn.functional.gelu(a, approximate="tanh")


# Lookup table from a config activation name (e.g. the T5 config's
# `dense_act_fn`) to the callable applied inside the feed-forward layers.
activations = {
    "gelu_pytorch_tanh": _gelu_tanh,
    "relu": torch.nn.functional.relu,
}
class T5DenseGatedActDense(torch.nn.Module):
    """#### Dense Gated Activation Layer

    Gated feed-forward block computing ``wo(act(wi_0(x)) * wi_1(x))``.
    All projections are bias-free.
    """

    def __init__(self, model_dim: int, ff_dim: int, ff_activation: str, dtype: torch.dtype, device: torch.device, operations):
        """#### Initialize Dense Gated Activation Layer

        #### Args:
            - `model_dim` (int): Model (input/output) dimension.
            - `ff_dim` (int): Hidden feed-forward dimension.
            - `ff_activation` (str): Key into the module-level `activations` table
              (a KeyError is raised for unknown names).
            - `dtype` (torch.dtype): Parameter data type.
            - `device` (torch.device): Parameter device.
            - `operations` (Operations): Provider of the `Linear` layer class.
        """
        super().__init__()
        shared = dict(bias=False, dtype=dtype, device=device)
        # Gate projection (passed through the activation).
        self.wi_0 = operations.Linear(model_dim, ff_dim, **shared)
        # Linear (un-activated) projection that the gate multiplies.
        self.wi_1 = operations.Linear(model_dim, ff_dim, **shared)
        # Projection back to the model dimension.
        self.wo = operations.Linear(ff_dim, model_dim, **shared)
        # The reference implementation's dropout is not applied here.
        self.act = activations[ff_activation]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """#### Forward Pass

        #### Args:
            - `x` (torch.Tensor): Input whose last dimension is `model_dim`.

        #### Returns:
            - `torch.Tensor`: Output with the same last dimension as `x`.
        """
        gate = self.act(self.wi_0(x))
        return self.wo(gate * self.wi_1(x))
class T5DenseActDense(torch.nn.Module):
    """#### Dense (non-gated) Activation Layer

    Plain feed-forward block computing ``wo(act(wi(x)))``. All projections are
    bias-free. Used by `T5LayerFF` when `gated_act` is false.
    """

    def __init__(self, model_dim: int, ff_dim: int, ff_activation: str, dtype: torch.dtype, device: torch.device, operations):
        """#### Initialize Dense Activation Layer

        #### Args:
            - `model_dim` (int): Model (input/output) dimension.
            - `ff_dim` (int): Hidden feed-forward dimension.
            - `ff_activation` (str): Key into the module-level `activations` table.
            - `dtype` (torch.dtype): Parameter data type.
            - `device` (torch.device): Parameter device.
            - `operations` (Operations): Provider of the `Linear` layer class.
        """
        super().__init__()
        self.wi = operations.Linear(
            model_dim, ff_dim, bias=False, dtype=dtype, device=device
        )
        self.wo = operations.Linear(
            ff_dim, model_dim, bias=False, dtype=dtype, device=device
        )
        self.act = activations[ff_activation]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """#### Forward Pass

        #### Args:
            - `x` (torch.Tensor): Input whose last dimension is `model_dim`.

        #### Returns:
            - `torch.Tensor`: Output with the same last dimension as `x`.
        """
        return self.wo(self.act(self.wi(x)))


class T5LayerFF(torch.nn.Module):
    """#### Feedforward Layer

    Pre-norm residual feed-forward sub-layer: ``x + DenseReluDense(layer_norm(x))``.
    """

    def __init__(
        self, model_dim: int, ff_dim: int, ff_activation: str, gated_act: bool, dtype: torch.dtype, device: torch.device, operations
    ):
        """#### Initialize Feedforward Layer

        #### Args:
            - `model_dim` (int): Model dimension.
            - `ff_dim` (int): Feedforward dimension.
            - `ff_activation` (str): Feedforward activation function name.
            - `gated_act` (bool): If true use the gated variant
              (`T5DenseGatedActDense`, weights `wi_0`/`wi_1`/`wo`); otherwise the
              plain variant (`T5DenseActDense`, weights `wi`/`wo`).
            - `dtype` (torch.dtype): Data type.
            - `device` (torch.device): Device.
            - `operations` (Operations): Operations.
        """
        super().__init__()
        # Previously the non-gated branch left `DenseReluDense` undefined, so
        # forward() raised AttributeError for non-gated configs.
        dense_cls = T5DenseGatedActDense if gated_act else T5DenseActDense
        self.DenseReluDense = dense_cls(
            model_dim, ff_dim, ff_activation, dtype, device, operations
        )
        self.layer_norm = T5LayerNorm(
            model_dim, dtype=dtype, device=device, operations=operations
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """#### Forward Pass

        #### Args:
            - `x` (torch.Tensor): Input tensor. The residual is added to it
              in place, so the caller's tensor is modified.

        #### Returns:
            - `torch.Tensor`: `x` after the residual update (the same object).
        """
        forwarded_states = self.layer_norm(x)
        forwarded_states = self.DenseReluDense(forwarded_states)
        # In-place residual add; the reference implementation's dropout is not applied.
        x += forwarded_states
        return x
class T5Attention(torch.nn.Module):
    """#### Attention Layer

    Multi-head self-attention with an optional T5 relative-position bias.
    The bias, when computed, is added to the attention mask.
    """

    def __init__(
        self,
        model_dim: int,
        inner_dim: int,
        num_heads: int,
        relative_attention_bias: bool,
        dtype: torch.dtype,
        device: torch.device,
        operations,
    ):
        """#### Initialize Attention Layer

        #### Args:
            - `model_dim` (int): Model dimension.
            - `inner_dim` (int): Total projection width across all heads.
            - `num_heads` (int): Number of attention heads.
            - `relative_attention_bias` (bool): If true, own a bias table of
              32 buckets x `num_heads` (max distance 128) and compute the bias
              in `forward`.
            - `dtype` (torch.dtype): Data type.
            - `device` (torch.device): Device.
            - `operations` (Operations): Provider of `Linear` and `Embedding`.
        """
        super().__init__()
        # Mesh TensorFlow initialization to avoid scaling before softmax
        self.q = operations.Linear(
            model_dim, inner_dim, bias=False, dtype=dtype, device=device
        )
        self.k = operations.Linear(
            model_dim, inner_dim, bias=False, dtype=dtype, device=device
        )
        self.v = operations.Linear(
            model_dim, inner_dim, bias=False, dtype=dtype, device=device
        )
        self.o = operations.Linear(
            inner_dim, model_dim, bias=False, dtype=dtype, device=device
        )
        self.num_heads = num_heads
        self.relative_attention_bias = None
        if relative_attention_bias:
            self.relative_attention_num_buckets = 32
            self.relative_attention_max_distance = 128
            self.relative_attention_bias = operations.Embedding(
                self.relative_attention_num_buckets,
                self.num_heads,
                device=device,
                dtype=dtype,
            )

    @staticmethod
    def _relative_position_bucket(
        relative_position: torch.Tensor, bidirectional: bool = True, num_buckets: int = 32, max_distance: int = 128
    ) -> torch.Tensor:
        """
        Adapted from Mesh Tensorflow:
        https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593

        Translate a relative position (``memory_position - query_position``) to a
        bucket index for relative attention.

        - Small distances get one bucket each; larger distances share
          logarithmically wider buckets.
        - All distances ``>= max_distance`` map to the last bucket of their side.
        - With ``bidirectional=True``, half the buckets are used for positive
          offsets and half for non-positive ones.
        - With ``bidirectional=False``, positive offsets all map to bucket 0.

        #### Args:
            - `relative_position` (torch.Tensor): Integer relative positions.
            - `bidirectional` (bool): Whether the attention is bidirectional.
            - `num_buckets` (int): Number of buckets.
            - `max_distance` (int): Distance at which the buckets saturate.

        #### Returns:
            - `torch.Tensor`: Long tensor of bucket indices, same shape as input.
        """
        relative_buckets = 0
        if bidirectional:
            num_buckets //= 2
            relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
            relative_position = torch.abs(relative_position)
        else:
            relative_position = -torch.min(
                relative_position, torch.zeros_like(relative_position)
            )
        # now relative_position is in the range [0, inf)

        # half of the buckets are for exact increments in positions
        max_exact = num_buckets // 2
        is_small = relative_position < max_exact

        # The other half of the buckets are for logarithmically bigger bins in
        # positions up to max_distance. log(0) yields -inf for position 0, but
        # that entry is discarded by torch.where because is_small is true there.
        relative_position_if_large = max_exact + (
            torch.log(relative_position.float() / max_exact)
            / math.log(max_distance / max_exact)
            * (num_buckets - max_exact)
        ).to(torch.long)
        relative_position_if_large = torch.min(
            relative_position_if_large,
            torch.full_like(relative_position_if_large, num_buckets - 1),
        )

        relative_buckets += torch.where(
            is_small, relative_position, relative_position_if_large
        )
        return relative_buckets

    def compute_bias(self, query_length: int, key_length: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
        """#### Compute binned relative position bias

        Only valid when the layer was built with `relative_attention_bias=True`.

        #### Args:
            - `query_length` (int): Length of the query.
            - `key_length` (int): Length of the key.
            - `device` (torch.device): Device for the position indices.
            - `dtype` (torch.dtype): Output dtype requested from the embedding.

        #### Returns:
            - `torch.Tensor`: Bias of shape (1, num_heads, query_length, key_length).
        """
        context_position = torch.arange(query_length, dtype=torch.long, device=device)[
            :, None
        ]
        memory_position = torch.arange(key_length, dtype=torch.long, device=device)[
            None, :
        ]
        relative_position = (
            memory_position - context_position
        )  # shape (query_length, key_length)
        relative_position_bucket = self._relative_position_bucket(
            relative_position,  # shape (query_length, key_length)
            bidirectional=True,
            num_buckets=self.relative_attention_num_buckets,
            max_distance=self.relative_attention_max_distance,
        )
        values = self.relative_attention_bias(
            relative_position_bucket, out_dtype=dtype
        )  # shape (query_length, key_length, num_heads)
        values = values.permute([2, 0, 1]).unsqueeze(
            0
        )  # shape (1, num_heads, query_length, key_length)
        return values

    def forward(self, x: torch.Tensor, mask: torch.Tensor = None, past_bias: torch.Tensor = None, optimized_attention = None) -> tuple:
        """#### Forward Pass

        #### Args:
            - `x` (torch.Tensor): Input; dims 1 and -1 are read as the sequence
              length and feature width (assumes (batch, seq, model_dim)).
            - `mask` (torch.Tensor, optional): Additive attention mask. Defaults to None.
            - `past_bias` (torch.Tensor, optional): Relative-position bias from an
              earlier layer. It is replaced when this layer owns its own bias table.
              Defaults to None.
            - `optimized_attention` (callable, optional): Attention kernel called as
              ``fn(q, k, v, num_heads, mask)``. Defaults to
              ``Attention.optimized_attention_for_device()`` when None.

        #### Returns:
            - `tuple`: ``(output, past_bias)``. ``past_bias`` is the bias used by
              this layer (possibly None), so later layers can reuse it.
        """
        if optimized_attention is None:
            # Previously a None default crashed with "NoneType is not callable".
            optimized_attention = Attention.optimized_attention_for_device()
        q = self.q(x)
        k = self.k(x)
        v = self.v(x)
        if self.relative_attention_bias is not None:
            past_bias = self.compute_bias(x.shape[1], x.shape[1], x.device, x.dtype)

        if past_bias is not None:
            if mask is not None:
                mask = mask + past_bias
            else:
                mask = past_bias

        # Keys are pre-multiplied by sqrt(head_dim). NOTE(review): presumably this
        # cancels a 1/sqrt(head_dim) scaling applied inside `optimized_attention`,
        # because T5 uses unscaled dot-product attention. Confirm against the backend.
        out = optimized_attention(
            q, k * ((k.shape[-1] / self.num_heads) ** 0.5), v, self.num_heads, mask
        )
        return self.o(out), past_bias
class T5LayerSelfAttention(torch.nn.Module):
    """#### Self-Attention Layer

    Pre-norm residual self-attention sub-layer: ``x + SelfAttention(layer_norm(x))``.
    """

    def __init__(
        self,
        model_dim: int,
        inner_dim: int,
        ff_dim: int,
        num_heads: int,
        relative_attention_bias: bool,
        dtype: torch.dtype,
        device: torch.device,
        operations,
    ):
        """#### Initialize Self-Attention Layer

        #### Args:
            - `model_dim` (int): Model dimension.
            - `inner_dim` (int): Inner (attention projection) dimension.
            - `ff_dim` (int): Feedforward dimension. Accepted for signature
              parity with the other layers; unused here.
            - `num_heads` (int): Number of attention heads.
            - `relative_attention_bias` (bool): Whether this layer owns a
              relative attention bias table.
            - `dtype` (torch.dtype): Data type.
            - `device` (torch.device): Device.
            - `operations` (Operations): Operations.
        """
        super().__init__()
        self.SelfAttention = T5Attention(
            model_dim,
            inner_dim,
            num_heads,
            relative_attention_bias,
            dtype,
            device,
            operations,
        )
        self.layer_norm = T5LayerNorm(
            model_dim, dtype=dtype, device=device, operations=operations
        )

    def forward(self, x: torch.Tensor, mask: torch.Tensor = None, past_bias: torch.Tensor = None, optimized_attention = None) -> tuple:
        """#### Forward Pass

        #### Args:
            - `x` (torch.Tensor): Input tensor. The residual is added to it
              in place, so the caller's tensor is modified.
            - `mask` (torch.Tensor, optional): Additive attention mask. Defaults to None.
            - `past_bias` (torch.Tensor, optional): Relative-position bias from an
              earlier layer. Defaults to None.
            - `optimized_attention` (callable, optional): Attention kernel. Defaults to None.

        #### Returns:
            - `tuple`: ``(x, past_bias)`` after the residual update.
        """
        # Normalize once. The original computed layer_norm(x) twice and threw
        # away the first result.
        normed = self.layer_norm(x)
        output, past_bias = self.SelfAttention(
            normed,
            mask=mask,
            past_bias=past_bias,
            optimized_attention=optimized_attention,
        )
        # In-place residual add; the reference implementation's dropout is not applied.
        x += output
        return x, past_bias
class T5Block(torch.nn.Module):
    """#### T5 Block

    One encoder block: self-attention sub-layer (``layer[0]``) followed by a
    feed-forward sub-layer (``layer[1]``).
    """

    def __init__(
        self,
        model_dim: int,
        inner_dim: int,
        ff_dim: int,
        ff_activation: str,
        gated_act: bool,
        num_heads: int,
        relative_attention_bias: bool,
        dtype: torch.dtype,
        device: torch.device,
        operations,
    ):
        """#### Initialize T5 Block

        #### Args:
            - `model_dim` (int): Model dimension.
            - `inner_dim` (int): Inner (attention projection) dimension.
            - `ff_dim` (int): Feedforward dimension.
            - `ff_activation` (str): Feedforward activation function name.
            - `gated_act` (bool): Whether to use gated activation.
            - `num_heads` (int): Number of attention heads.
            - `relative_attention_bias` (bool): Whether the attention sub-layer
              owns a relative attention bias table.
            - `dtype` (torch.dtype): Data type.
            - `device` (torch.device): Device.
            - `operations` (Operations): Operations.
        """
        super().__init__()
        attention = T5LayerSelfAttention(
            model_dim,
            inner_dim,
            ff_dim,
            num_heads,
            relative_attention_bias,
            dtype,
            device,
            operations,
        )
        feed_forward = T5LayerFF(
            model_dim, ff_dim, ff_activation, gated_act, dtype, device, operations
        )
        # Registered as layer.0 / layer.1 to match the checkpoint key layout.
        self.layer = torch.nn.ModuleList([attention, feed_forward])

    def forward(self, x: torch.Tensor, mask: torch.Tensor = None, past_bias: torch.Tensor = None, optimized_attention = None) -> tuple:
        """#### Forward Pass

        #### Args:
            - `x` (torch.Tensor): Input tensor.
            - `mask` (torch.Tensor, optional): Additive attention mask. Defaults to None.
            - `past_bias` (torch.Tensor, optional): Relative-position bias from an
              earlier block. Defaults to None.
            - `optimized_attention` (callable, optional): Attention kernel. Defaults to None.

        #### Returns:
            - `tuple`: ``(x, past_bias)``. The bias is passed along so the next
              block can reuse it.
        """
        attention_layer = self.layer[0]
        feed_forward_layer = self.layer[-1]
        x, past_bias = attention_layer(x, mask, past_bias, optimized_attention)
        return feed_forward_layer(x), past_bias
class T5Stack(torch.nn.Module):
    """#### T5 Stack

    A sequence of `T5Block`s followed by a final `T5LayerNorm`.
    """

    def __init__(
        self,
        num_layers: int,
        model_dim: int,
        inner_dim: int,
        ff_dim: int,
        ff_activation: str,
        gated_act: bool,
        num_heads: int,
        relative_attention: bool,
        dtype: torch.dtype,
        device: torch.device,
        operations,
    ):
        """#### Initialize T5 Stack

        #### Args:
            - `num_layers` (int): Number of blocks.
            - `model_dim` (int): Model dimension.
            - `inner_dim` (int): Inner (attention projection) dimension.
            - `ff_dim` (int): Feedforward dimension.
            - `ff_activation` (str): Feedforward activation function name.
            - `gated_act` (bool): Whether to use gated activation.
            - `num_heads` (int): Number of attention heads.
            - `relative_attention` (bool): If true, only block 0 owns a relative
              bias table and later blocks reuse its bias through `past_bias`.
              If false, every block owns and computes its own bias.
            - `dtype` (torch.dtype): Data type.
            - `device` (torch.device): Device.
            - `operations` (Operations): Operations.
        """
        super().__init__()
        self.block = torch.nn.ModuleList(
            [
                T5Block(
                    model_dim,
                    inner_dim,
                    ff_dim,
                    ff_activation,
                    gated_act,
                    num_heads,
                    relative_attention_bias=((not relative_attention) or (i == 0)),
                    dtype=dtype,
                    device=device,
                    operations=operations,
                )
                for i in range(num_layers)
            ]
        )
        self.final_layer_norm = T5LayerNorm(
            model_dim, dtype=dtype, device=device, operations=operations
        )

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: torch.Tensor = None,
        intermediate_output: int = None,
        final_layer_norm_intermediate: bool = True,
        dtype: torch.dtype = None,
    ) -> tuple:
        """#### Forward Pass

        #### Args:
            - `x` (torch.Tensor): Input embeddings.
            - `attention_mask` (torch.Tensor, optional): Per-token mask, assumed to
              be (batch, seq) (TODO confirm). Positions whose value is not 1 are
              excluded from attention. Defaults to None.
            - `intermediate_output` (int, optional): Index of the block whose output
              is also returned. Negative values count from the last block, as in
              Python indexing. Defaults to None.
            - `final_layer_norm_intermediate` (bool, optional): Whether to apply the
              final layer norm to the intermediate output. Defaults to True.
            - `dtype` (torch.dtype, optional): Accepted for interface compatibility;
              unused. Defaults to None.

        #### Returns:
            - `tuple`: ``(x, intermediate)``. ``intermediate`` is None when no
              intermediate block was requested.
        """
        mask = None
        if attention_mask is not None:
            # Broadcast to (batch, 1, seq, seq). 1 - m is 0 where m == 1; every
            # other position is set to -inf in the additive mask.
            mask = 1.0 - attention_mask.to(x.dtype).reshape(
                (attention_mask.shape[0], 1, -1, attention_mask.shape[-1])
            ).expand(
                attention_mask.shape[0],
                1,
                attention_mask.shape[-1],
                attention_mask.shape[-1],
            )
            mask = mask.masked_fill(mask.to(torch.bool), float("-inf"))

        if intermediate_output is not None and intermediate_output < 0:
            # Without this, a negative index never equals an enumerate() index and
            # the intermediate output was silently None.
            intermediate_output = len(self.block) + intermediate_output

        intermediate = None
        optimized_attention = Attention.optimized_attention_for_device()
        past_bias = None
        for i, l in enumerate(self.block):
            x, past_bias = l(x, mask, past_bias, optimized_attention)
            if i == intermediate_output:
                # Clone, because later blocks update x in place.
                intermediate = x.clone()
        x = self.final_layer_norm(x)
        if intermediate is not None and final_layer_norm_intermediate:
            intermediate = self.final_layer_norm(intermediate)
        return x, intermediate
class T5(torch.nn.Module):
    """#### T5 encoder model

    A token embedding table (`shared`) feeding a `T5Stack` encoder.
    """

    def __init__(self, config_dict, dtype, device, operations):
        """#### Initialize T5 Model

        #### Args:
            - `config_dict` (dict): Must provide `num_layers`, `d_model`, `d_ff`,
              `dense_act_fn`, `is_gated_act`, `num_heads`, `model_type` and
              `vocab_size`.
            - `dtype` (torch.dtype): Data type.
            - `device` (torch.device): Device.
            - `operations` (Operations): Operations.
        """
        super().__init__()
        cfg = config_dict
        self.num_layers = cfg["num_layers"]
        width = cfg["d_model"]
        # For "umt5" models every block owns its own relative bias. For other
        # models only the first block does and the rest reuse it.
        self.encoder = T5Stack(
            self.num_layers,
            width,
            width,
            cfg["d_ff"],
            cfg["dense_act_fn"],
            cfg["is_gated_act"],
            cfg["num_heads"],
            cfg["model_type"] != "umt5",
            dtype,
            device,
            operations,
        )
        self.dtype = dtype
        self.shared = operations.Embedding(
            cfg["vocab_size"], width, device=device, dtype=dtype
        )

    def get_input_embeddings(self) -> torch.nn.Embedding:
        """#### Get input embeddings

        #### Returns:
            - `torch.nn.Embedding`: The token embedding table.
        """
        return self.shared

    def set_input_embeddings(self, embeddings: torch.nn.Embedding) -> None:
        """#### Set input embeddings

        #### Args:
            - `embeddings` (torch.nn.Embedding): Replacement token embedding table.
        """
        self.shared = embeddings

    def forward(self, input_ids: torch.Tensor, *args, **kwargs) -> tuple:
        """#### Forward pass

        #### Args:
            - `input_ids` (torch.Tensor): Token ids.
            - `*args`: Forwarded positionally to `T5Stack.forward`.
            - `**kwargs`: Forwarded to `T5Stack.forward`. `dtype` (default
              float32) also selects the embedding output dtype.

        #### Returns:
            - `tuple`: Whatever `T5Stack.forward` returns, i.e. ``(x, intermediate)``.
        """
        embed_dtype = kwargs.get("dtype", torch.float32)
        x = self.shared(input_ids, out_dtype=embed_dtype)
        standard_dtypes = (torch.float32, torch.float16, torch.bfloat16)
        if self.dtype not in standard_dtypes:
            # Weights stored in non-standard dtypes (e.g. fp8) can yield NaN/inf
            # embeddings, so sanitize them.
            x = torch.nan_to_num(x)  # Fix for fp8 T5 base
        return self.encoder(x, *args, **kwargs)
class T5XXLModel(SDClip.SDClipModel):
    """#### T5-XXL text encoder wrapper

    Configures `SDClip.SDClipModel` to build a `T5` from the bundled
    ``clip/t5_config_xxl.json``.
    """

    def __init__(
        self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options=None
    ):
        """#### Initialize T5XXL Model

        #### Args:
            - `device` (str, optional): Device. Defaults to "cpu".
            - `layer` (str, optional): Output layer selection. Defaults to "last".
            - `layer_idx` (int, optional): Layer index. Defaults to None.
            - `dtype` (torch.dtype, optional): Data type. Defaults to None.
            - `model_options` (dict, optional): Model options. Defaults to a new
              empty dict per call. Previously a shared ``{}`` default was passed to
              the superclass and shared across all instances.
        """
        if model_options is None:
            model_options = {}
        textmodel_json_config = os.path.join(
            os.path.dirname(os.path.realpath(__file__)),
            "./clip/t5_config_xxl.json",
        )
        super().__init__(
            device=device,
            layer=layer,
            layer_idx=layer_idx,
            textmodel_json_config=textmodel_json_config,
            dtype=dtype,
            # T5 uses id 1 as end-of-sequence and id 0 as padding.
            special_tokens={"end": 1, "pad": 0},
            model_class=T5,
            model_options=model_options,
        )
class T5XXLTokenizer(SDToken.SDTokenizer):
    """#### T5-XXL tokenizer

    `SDToken.SDTokenizer` configured with the bundled ``clip/t5_tokenizer``
    files and `T5TokenizerFast`.
    """

    def __init__(self, embedding_directory=None, tokenizer_data={}):
        """#### Initialize T5XXL Tokenizer

        #### Args:
            - `embedding_directory` (str, optional): Embedding directory.
              Defaults to None.
            - `tokenizer_data` (dict, optional): Tokenizer data. Defaults to {}.

        NOTE(review): neither argument is forwarded to the superclass; confirm
        whether `embedding_directory` should be passed through.
        """
        module_dir = os.path.dirname(os.path.realpath(__file__))
        super().__init__(
            os.path.join(module_dir, "./clip/t5_tokenizer"),
            pad_with_end=False,
            embedding_size=4096,
            embedding_key="t5xxl",
            tokenizer_class=T5TokenizerFast,
            has_start_token=False,
            pad_to_max_length=False,
            # Effectively unbounded maximum length, with a 256-token minimum.
            max_length=99999999,
            min_length=256,
        )
class T5LayerNorm(torch.nn.Module):
    """#### T5 Layer Normalization

    RMS-style normalization: the input is divided by the root mean square of
    its last dimension and scaled by a learned weight. There is no mean
    subtraction and no bias.
    """

    def __init__(self, hidden_size, eps=1e-6, dtype=None, device=None, operations=None):
        """#### Initialize T5 Layer Normalization

        #### Args:
            - `hidden_size` (int): Size of the normalized (last) dimension.
            - `eps` (float, optional): Added to the mean square for numerical
              stability. Defaults to 1e-6.
            - `dtype` (torch.dtype, optional): Weight data type. Defaults to None.
            - `device` (torch.device, optional): Weight device. Defaults to None.
            - `operations` (Operations, optional): Accepted for interface parity;
              unused. Defaults to None.
        """
        super().__init__()
        # Left uninitialized (torch.empty); presumably filled by a state-dict load.
        self.weight = torch.nn.Parameter(
            torch.empty(hidden_size, dtype=dtype, device=device)
        )
        self.variance_epsilon = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """#### Forward pass

        #### Args:
            - `x` (torch.Tensor): Input tensor.

        #### Returns:
            - `torch.Tensor`: Normalized and scaled tensor.
        """
        mean_square = x.pow(2).mean(-1, keepdim=True)
        normalized = x * torch.rsqrt(mean_square + self.variance_epsilon)
        # Cast the weight to match the input before scaling.
        scale = cast.cast_to_input(self.weight, normalized)
        return scale * normalized
class FluxTokenizer:
    """#### Flux Tokenizer

    Tokenizes a prompt with both a CLIP-L tokenizer and a T5-XXL tokenizer.
    """

    def __init__(self, embedding_directory=None, tokenizer_data={}):
        """#### Initialize Flux Tokenizer

        #### Args:
            - `embedding_directory` (str, optional): Embedding directory passed to
              both tokenizers. Defaults to None.
            - `tokenizer_data` (dict, optional): Read-only options. The key
              ``"clip_l_tokenizer_class"`` overrides the CLIP-L tokenizer class
              (default `SDToken.SDTokenizer`). Defaults to {}.
        """
        l_tokenizer_cls = tokenizer_data.get(
            "clip_l_tokenizer_class", SDToken.SDTokenizer
        )
        self.clip_l = l_tokenizer_cls(embedding_directory=embedding_directory)
        self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory)

    def tokenize_with_weights(self, text: str, return_word_ids=False) -> dict:
        """#### Tokenize text with weights

        #### Args:
            - `text` (str): Text to tokenize.
            - `return_word_ids` (bool, optional): Whether to return word IDs.
              Defaults to False.

        #### Returns:
            - `dict`: ``{"l": <CLIP-L tokens>, "t5xxl": <T5-XXL tokens>}``.
        """
        return {
            "l": self.clip_l.tokenize_with_weights(text, return_word_ids),
            "t5xxl": self.t5xxl.tokenize_with_weights(text, return_word_ids),
        }
class FluxClipModel(torch.nn.Module):
    """#### Flux text encoder

    Pairs a CLIP-L encoder with a T5-XXL encoder and combines their outputs.
    """

    def __init__(self, dtype_t5=None, device="cpu", dtype=None, model_options=None):
        """#### Initialize FluxClip Model

        #### Args:
            - `dtype_t5` (torch.dtype, optional): Requested T5 data type, resolved
              through `Device.pick_weight_dtype`. Defaults to None.
            - `device` (str, optional): Device. Defaults to "cpu".
            - `dtype` (torch.dtype, optional): CLIP-L data type. Defaults to None.
            - `model_options` (dict, optional): Options forwarded to both
              sub-models. The key ``"clip_l_class"`` overrides the CLIP-L class.
              Defaults to a new empty dict per call. Previously a shared ``{}``
              default was handed to both sub-models of every instance.
        """
        super().__init__()
        if model_options is None:
            model_options = {}
        dtype_t5 = Device.pick_weight_dtype(dtype_t5, dtype, device)
        clip_l_class = model_options.get("clip_l_class", SDClip.SDClipModel)
        self.clip_l = clip_l_class(
            device=device,
            dtype=dtype,
            return_projected_pooled=False,
            model_options=model_options,
        )
        self.t5xxl = T5XXLModel(
            device=device, dtype=dtype_t5, model_options=model_options
        )
        self.dtypes = {dtype, dtype_t5}

    def reset_clip_options(self) -> None:
        """#### Reset CLIP options on both sub-models."""
        self.clip_l.reset_clip_options()
        self.t5xxl.reset_clip_options()

    def encode_token_weights(self, token_weight_pairs: dict) -> tuple:
        """#### Encode token weights

        #### Args:
            - `token_weight_pairs` (dict): Output of
              `FluxTokenizer.tokenize_with_weights`, with keys ``"l"`` and ``"t5xxl"``.

        #### Returns:
            - `tuple`: ``(t5_out, l_pooled)``. This is the first element of the T5
              encoder's result and the second element of the CLIP-L result
              (presumably the sequence embedding and the pooled embedding).
        """
        token_weight_pairs_l = token_weight_pairs["l"]
        token_weight_pairs_t5 = token_weight_pairs["t5xxl"]
        t5_out, _t5_pooled = self.t5xxl.encode_token_weights(token_weight_pairs_t5)
        _l_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l)
        return t5_out, l_pooled

    def load_sd(self, sd: dict):
        """#### Load state dictionary

        Routes the state dict to CLIP-L if it contains a CLIP text-model key,
        otherwise to T5-XXL.

        #### Args:
            - `sd` (dict): State dictionary.

        #### Returns:
            Whatever the chosen sub-model's `load_sd` returns.
        """
        if "text_model.encoder.layers.1.mlp.fc1.weight" in sd:
            return self.clip_l.load_sd(sd)
        return self.t5xxl.load_sd(sd)
def flux_clip(dtype_t5=None):
    """#### Create FluxClip Model class

    Builds a `FluxClipModel` subclass with `dtype_t5` bound, so the class can be
    constructed with the standard ``(device, dtype, model_options)`` signature.

    #### Args:
        - `dtype_t5` (torch.dtype, optional): T5 data type. Defaults to None.

    #### Returns:
        - `type`: A `FluxClipModel` subclass.
    """

    class FluxClipModel_(FluxClipModel):
        def __init__(self, device="cpu", dtype=None, model_options=None):
            """#### Initialize FluxClip Model

            #### Args:
                - `device` (str, optional): Device. Defaults to "cpu".
                - `dtype` (torch.dtype, optional): Data type. Defaults to None.
                - `model_options` (dict, optional): Model options. Defaults to a
                  new empty dict per call rather than a shared ``{}``.
            """
            if model_options is None:
                model_options = {}
            super().__init__(
                dtype_t5=dtype_t5,
                device=device,
                dtype=dtype,
                model_options=model_options,
            )

    return FluxClipModel_