Spaces:
Runtime error
Runtime error
qkv
Browse files
fa3.py
CHANGED
|
@@ -36,9 +36,9 @@ class FlashFusedFluxAttnProcessor3_0:
|
|
| 36 |
batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
| 37 |
|
| 38 |
# `sample` projections.
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
|
| 43 |
inner_dim = key.shape[-1]
|
| 44 |
head_dim = inner_dim // attn.heads
|
|
|
|
| 36 |
batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
| 37 |
|
| 38 |
# `sample` projections.
|
| 39 |
+
query = attn.to_q(hidden_states)
|
| 40 |
+
key = attn.to_k(hidden_states)
|
| 41 |
+
value = attn.to_v(hidden_states)
|
| 42 |
|
| 43 |
inner_dim = key.shape[-1]
|
| 44 |
head_dim = inner_dim // attn.heads
|