Spaces: Running on Zero
Cleanup fa3.py
fa3.py CHANGED
@@ -15,101 +15,4 @@ def flash_attn_func(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.
 
 @flash_attn_func.register_fake
 def _(q, k, v, **kwargs):
-
-    # 1. output: (batch, seq_len, num_heads, head_dim)
-    # 2. softmax_lse: (batch, num_heads, seq_len) with dtype=torch.float32
-    meta_q = torch.empty_like(q).contiguous()
-    return meta_q #, q.new_empty((q.size(0), q.size(2), q.size(1)), dtype=torch.float32)
-
-# Copied FusedFluxAttnProcessor2_0 but using flash v3 instead of SDPA
-class FlashFusedFluxAttnProcessor3_0:
-    """Attention processor used typically in processing the SD3-like self-attention projections."""
-
-    def __call__(
-        self,
-        attn,
-        hidden_states: torch.FloatTensor,
-        encoder_hidden_states: torch.FloatTensor | None = None,
-        attention_mask: torch.FloatTensor | None = None,
-        image_rotary_emb: torch.Tensor | None = None,
-    ) -> torch.FloatTensor:
-        batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-
-        # `sample` projections.
-        qkv = attn.to_qkv(hidden_states)
-        split_size = qkv.shape[-1] // 3
-        query, key, value = torch.split(qkv, split_size, dim=-1)
-
-        inner_dim = key.shape[-1]
-        head_dim = inner_dim // attn.heads
-
-        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-        if attn.norm_q is not None:
-            query = attn.norm_q(query)
-        if attn.norm_k is not None:
-            key = attn.norm_k(key)
-
-        # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
-        # `context` projections.
-        if encoder_hidden_states is not None:
-            encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
-            split_size = encoder_qkv.shape[-1] // 3
-            (
-                encoder_hidden_states_query_proj,
-                encoder_hidden_states_key_proj,
-                encoder_hidden_states_value_proj,
-            ) = torch.split(encoder_qkv, split_size, dim=-1)
-
-            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
-                batch_size, -1, attn.heads, head_dim
-            ).transpose(1, 2)
-            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
-                batch_size, -1, attn.heads, head_dim
-            ).transpose(1, 2)
-            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
-                batch_size, -1, attn.heads, head_dim
-            ).transpose(1, 2)
-
-            if attn.norm_added_q is not None:
-                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
-            if attn.norm_added_k is not None:
-                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
-
-            # attention
-            query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
-            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
-            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
-
-        if image_rotary_emb is not None:
-            from diffusers.models.embeddings import apply_rotary_emb
-
-            query = apply_rotary_emb(query, image_rotary_emb)
-            key = apply_rotary_emb(key, image_rotary_emb)
-
-        # NB: transposes are necessary to match expected SDPA input shape
-        hidden_states = flash_attn_func(
-            query.transpose(1, 2),
-            key.transpose(1, 2),
-            value.transpose(1, 2))[0].transpose(1, 2)
-
-        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
-        hidden_states = hidden_states.to(query.dtype)
-
-        if encoder_hidden_states is not None:
-            encoder_hidden_states, hidden_states = (
-                hidden_states[:, : encoder_hidden_states.shape[1]],
-                hidden_states[:, encoder_hidden_states.shape[1] :],
-            )
-
-            # linear proj
-            hidden_states = attn.to_out[0](hidden_states)
-            # dropout
-            hidden_states = attn.to_out[1](hidden_states)
-            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
-
-            return hidden_states, encoder_hidden_states
-        else:
-            return hidden_states
+    return torch.empty_like(q).contiguous()
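
For readers skimming the diff: `register_fake` is a method that only exists on a `torch.library` custom op, and the fake implementation's only job is to return a tensor with the real kernel's output shape and dtype so that `torch.compile` can trace through the op. The simplified fake keeps exactly that contract. For reference, the wrapper it attaches to is presumably declared along the following lines earlier in fa3.py; this is a hedged sketch, not the file's actual code, and the op namespace, the `flash_attn_interface` import, and the tuple handling are assumptions.

import torch

# Sketch only: wrap FlashAttention-3 as a PyTorch custom op so torch.compile
# can trace it via the fake implementation registered in the diff above.
# The "flash::" namespace is hypothetical.
@torch.library.custom_op("flash::flash_attn_func", mutates_args=())
def flash_attn_func(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # Assumption: the Hopper build of flash-attn exposes flash_attn_interface
    # and accepts (batch, seq_len, num_heads, head_dim) inputs.
    from flash_attn_interface import flash_attn_func as fa3_flash_attn

    out = fa3_flash_attn(q, k, v)
    # Some flash-attn builds return (out, softmax_lse); keep only the output,
    # which is what the fake's torch.empty_like(q) mirrors.
    return out[0] if isinstance(out, (tuple, list)) else out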
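
To sanity-check that the simplified fake still satisfies that shape/dtype contract without a Hopper GPU, one hypothetical option is to call the op under PyTorch's internal FakeTensorMode, the mechanism torch.compile relies on during tracing, which dispatches to the registered fake instead of the real kernel.

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

# Hypothetical smoke test: inside FakeTensorMode every tensor carries no data,
# so flash_attn_func dispatches to the register_fake implementation above.
# The shapes are example Flux-like values (batch, seq_len, num_heads, head_dim).
with FakeTensorMode():
    q = torch.empty(1, 4096, 24, 128, dtype=torch.bfloat16)
    k = torch.empty(1, 4096, 24, 128, dtype=torch.bfloat16)
    v = torch.empty(1, 4096, 24, 128, dtype=torch.bfloat16)
    out = flash_attn_func(q, k, v)
    assert out.shape == q.shape and out.dtype == q.dtype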