Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
|
@@ -33,7 +33,7 @@ try:
|
|
| 33 |
with open(attention_file_path, "r") as f:
|
| 34 |
content = f.read()
|
| 35 |
|
| 36 |
-
#
|
| 37 |
original_code = """ x, *_ = flash_attn_func(
|
| 38 |
q,
|
| 39 |
k,
|
|
@@ -42,17 +42,25 @@ try:
|
|
| 42 |
)
|
| 43 |
x = rearrange(x, "B S H D -> B H S D")"""
|
| 44 |
|
| 45 |
-
#
|
| 46 |
-
corrected_code = """
|
| 47 |
-
# We get B and S from the rearranged q's shape and reshape the output tensor x.
|
| 48 |
-
B, S, H, D = q.shape
|
| 49 |
-
x, *_ = flash_attn_func(
|
| 50 |
q,
|
| 51 |
k,
|
| 52 |
v,
|
| 53 |
softmax_scale=self.scale,
|
| 54 |
)
|
| 55 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
x = rearrange(x, "B S H D -> B H S D")"""
|
| 57 |
|
| 58 |
if original_code in content:
|
|
|
|
| 33 |
with open(attention_file_path, "r") as f:
|
| 34 |
content = f.read()
|
| 35 |
|
| 36 |
+
# Original code block that we need to replace
|
| 37 |
original_code = """ x, *_ = flash_attn_func(
|
| 38 |
q,
|
| 39 |
k,
|
|
|
|
| 42 |
)
|
| 43 |
x = rearrange(x, "B S H D -> B H S D")"""
|
| 44 |
|
| 45 |
+
# Corrected code block to handle FA3's 3D output shape
|
| 46 |
+
corrected_code = """ x, *_ = flash_attn_func(
|
|
|
|
|
|
|
|
|
|
| 47 |
q,
|
| 48 |
k,
|
| 49 |
v,
|
| 50 |
softmax_scale=self.scale,
|
| 51 |
)
|
| 52 |
+
# The output of FA3's flash_attn_func can be 3D (total_tokens, H, D).
|
| 53 |
+
# We need to robustly reshape it back to the 4D format (B, S, H, D) that the
|
| 54 |
+
# subsequent rearrange operation expects.
|
| 55 |
+
if x.ndim == 3:
|
| 56 |
+
# B is the original batch size from the input q tensor
|
| 57 |
+
B = q.shape[0]
|
| 58 |
+
# S_total is the flattened batch and sequence length
|
| 59 |
+
S_total, H, D = x.shape
|
| 60 |
+
# Calculate the sequence length per batch item
|
| 61 |
+
S = S_total // B
|
| 62 |
+
x = x.view(B, S, H, D)
|
| 63 |
+
|
| 64 |
x = rearrange(x, "B S H D -> B H S D")"""
|
| 65 |
|
| 66 |
if original_code in content:
|