Spaces:
Runtime error
Runtime error
Revert "qkv"
Browse filesThis reverts commit cfad87d56b207199f6380882371c5b31639b8c41.
fa3.py
CHANGED
|
@@ -36,9 +36,9 @@ class FlashFusedFluxAttnProcessor3_0:
|
|
| 36 |
batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
| 37 |
|
| 38 |
# `sample` projections.
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
value =
|
| 42 |
|
| 43 |
inner_dim = key.shape[-1]
|
| 44 |
head_dim = inner_dim // attn.heads
|
|
|
|
| 36 |
batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
| 37 |
|
| 38 |
# `sample` projections.
|
| 39 |
+
qkv = attn.to_qkv(hidden_states)
|
| 40 |
+
split_size = qkv.shape[-1] // 3
|
| 41 |
+
query, key, value = torch.split(qkv, split_size, dim=-1)
|
| 42 |
|
| 43 |
inner_dim = key.shape[-1]
|
| 44 |
head_dim = inner_dim // attn.heads
|