cbensimon (HF Staff) committed
Commit 72118d6 · verified · 1 Parent(s): d70d883

Cleanup fa3.py

Files changed (1):
  1. fa3.py +1 -98
fa3.py CHANGED
@@ -15,101 +15,4 @@ def flash_attn_func(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
 
 @flash_attn_func.register_fake
 def _(q, k, v, **kwargs):
-    # two outputs:
-    # 1. output: (batch, seq_len, num_heads, head_dim)
-    # 2. softmax_lse: (batch, num_heads, seq_len) with dtype=torch.float32
-    meta_q = torch.empty_like(q).contiguous()
-    return meta_q #, q.new_empty((q.size(0), q.size(2), q.size(1)), dtype=torch.float32)
-
-# Copied FusedFluxAttnProcessor2_0 but using flash v3 instead of SDPA
-class FlashFusedFluxAttnProcessor3_0:
-    """Attention processor used typically in processing the SD3-like self-attention projections."""
-
-    def __call__(
-        self,
-        attn,
-        hidden_states: torch.FloatTensor,
-        encoder_hidden_states: torch.FloatTensor | None = None,
-        attention_mask: torch.FloatTensor | None = None,
-        image_rotary_emb: torch.Tensor | None = None,
-    ) -> torch.FloatTensor:
-        batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-
-        # `sample` projections.
-        qkv = attn.to_qkv(hidden_states)
-        split_size = qkv.shape[-1] // 3
-        query, key, value = torch.split(qkv, split_size, dim=-1)
-
-        inner_dim = key.shape[-1]
-        head_dim = inner_dim // attn.heads
-
-        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-        if attn.norm_q is not None:
-            query = attn.norm_q(query)
-        if attn.norm_k is not None:
-            key = attn.norm_k(key)
-
-        # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
-        # `context` projections.
-        if encoder_hidden_states is not None:
-            encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
-            split_size = encoder_qkv.shape[-1] // 3
-            (
-                encoder_hidden_states_query_proj,
-                encoder_hidden_states_key_proj,
-                encoder_hidden_states_value_proj,
-            ) = torch.split(encoder_qkv, split_size, dim=-1)
-
-            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
-                batch_size, -1, attn.heads, head_dim
-            ).transpose(1, 2)
-            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
-                batch_size, -1, attn.heads, head_dim
-            ).transpose(1, 2)
-            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
-                batch_size, -1, attn.heads, head_dim
-            ).transpose(1, 2)
-
-            if attn.norm_added_q is not None:
-                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
-            if attn.norm_added_k is not None:
-                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
-
-            # attention
-            query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
-            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
-            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
-
-        if image_rotary_emb is not None:
-            from diffusers.models.embeddings import apply_rotary_emb
-
-            query = apply_rotary_emb(query, image_rotary_emb)
-            key = apply_rotary_emb(key, image_rotary_emb)
-
-        # NB: transposes are necessary to match expected SDPA input shape
-        hidden_states = flash_attn_func(
-            query.transpose(1, 2),
-            key.transpose(1, 2),
-            value.transpose(1, 2))[0].transpose(1, 2)
-
-        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
-        hidden_states = hidden_states.to(query.dtype)
-
-        if encoder_hidden_states is not None:
-            encoder_hidden_states, hidden_states = (
-                hidden_states[:, : encoder_hidden_states.shape[1]],
-                hidden_states[:, encoder_hidden_states.shape[1] :],
-            )
-
-            # linear proj
-            hidden_states = attn.to_out[0](hidden_states)
-            # dropout
-            hidden_states = attn.to_out[1](hidden_states)
-            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
-
-            return hidden_states, encoder_hidden_states
-        else:
-            return hidden_states
+    return torch.empty_like(q).contiguous()
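For context, the hunk only shows the register_fake half of the custom op; the actual kernel binding lives in the first lines of fa3.py, which this commit leaves untouched. The sketch below is a minimal reconstruction of that pattern, assuming PyTorch >= 2.4 (torch.library.custom_op / CustomOpDef.register_fake) and a flash-attn 3 build exposing flash_attn_interface.flash_attn_func; the op name "flash::flash_attn_func", the import path, and the tuple handling are illustrative assumptions, not taken from this repository.

import torch
from torch.library import custom_op


# Hypothetical binding of the FA3 kernel as a PyTorch custom op
# (the op name and the flash_attn_interface import are assumptions).
@custom_op("flash::flash_attn_func", mutates_args=())
def flash_attn_func(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    from flash_attn_interface import flash_attn_func as fa3_func
    out = fa3_func(q, k, v)
    # Some FA3 builds return (output, softmax_lse); keep only the output.
    return out[0] if isinstance(out, (tuple, list)) else out


# Fake (meta) implementation used by torch.compile / AOT export: only the
# output's shape, dtype and device matter, and the FA3 output has the same
# (batch, seq_len, num_heads, head_dim) layout as q, so an empty tensor
# shaped like q is enough -- exactly what the cleaned-up fake returns.
@flash_attn_func.register_fake
def _(q, k, v, **kwargs):
    return torch.empty_like(q).contiguous()

Returning a single tensor from the fake also matches the cleanup above: the commented-out softmax_lse output is gone, so the fake now mirrors an op that returns just the attention output.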