multimodalart (HF Staff) committed
Commit 04ca81f · verified · 1 Parent(s): d78ad44

Update app.py

Files changed (1)
  1. app.py +15 -7
app.py CHANGED
@@ -33,7 +33,7 @@ try:
     with open(attention_file_path, "r") as f:
         content = f.read()
 
-    # Define the original problematic code block
+    # Original code block that we need to replace
     original_code = """        x, *_ = flash_attn_func(
             q,
             k,
@@ -42,17 +42,25 @@ try:
         )
         x = rearrange(x, "B S H D -> B H S D")"""
 
-    # Define the corrected code block that handles the 3D output of FA3
-    corrected_code = """        # The output of flash_attn_func is 3D (total_tokens, H, D), but the code expects 4D.
-        # We get B and S from the rearranged q's shape and reshape the output tensor x.
-        B, S, H, D = q.shape
-        x, *_ = flash_attn_func(
+    # Corrected code block to handle FA3's 3D output shape
+    corrected_code = """        x, *_ = flash_attn_func(
             q,
             k,
             v,
             softmax_scale=self.scale,
         )
-        x = x.view(B, S, H, D)  # Reshape from 3D to 4D
+        # The output of FA3's flash_attn_func can be 3D (total_tokens, H, D).
+        # We need to robustly reshape it back to the 4D format (B, S, H, D) that the
+        # subsequent rearrange operation expects.
+        if x.ndim == 3:
+            # B is the original batch size from the input q tensor
+            B = q.shape[0]
+            # S_total is the flattened batch and sequence length
+            S_total, H, D = x.shape
+            # Calculate the sequence length per batch item
+            S = S_total // B
+            x = x.view(B, S, H, D)
+
         x = rearrange(x, "B S H D -> B H S D")"""
 
     if original_code in content:
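
The core of the fix: FA3's flash_attn_func can return a packed 3D tensor (total_tokens, H, D), while the downstream rearrange expects (B, S, H, D). A minimal standalone sketch of the injected reshape, using dummy tensors; the shapes and torch.randn inputs are illustrative stand-ins, not values from the patched file:

    import torch

    B, S, H, D = 2, 16, 8, 64           # batch, sequence, heads, head_dim
    q = torch.randn(B, S, H, D)         # query as the attention module sees it

    # Pretend flash_attn_func returned a packed 3D tensor (total_tokens, H, D):
    x = torch.randn(B * S, H, D)

    # The same logic the patch injects:
    if x.ndim == 3:
        S_total, H, D = x.shape
        S = S_total // q.shape[0]       # tokens per batch item
        x = x.view(q.shape[0], S, H, D)

    assert x.shape == (B, S, H, D)      # back to the 4D layout rearrange expects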
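
Judging from the visible context (with open(...), content = f.read(), if original_code in content:), the surrounding app.py applies the fix as a plain string replacement at startup. A sketch of that pattern, assuming the file path, the placeholder strings, and the fallback behavior, none of which are confirmed by the diff:

    attention_file_path = "path/to/attention.py"   # hypothetical path
    original_code = "...old block..."              # the triple-quoted block above
    corrected_code = "...new block..."             # the corrected block above

    with open(attention_file_path, "r") as f:
        content = f.read()

    if original_code in content:
        content = content.replace(original_code, corrected_code)
        with open(attention_file_path, "w") as f:
            f.write(content)
    else:
        # Target not found: the file may already be patched or has changed upstream.
        print("Patch target not found; skipping.")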