akshat-perceptron committed
Commit 08cbfa9 · verified · 1 Parent(s): c10d94a

Update to add SDPA support
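For context: the vision tower's packed-sequence attention previously went straight to torch.ops.aten._flash_attention_forward. This change factors that call into flash_attention_document_mask_forward and adds sdpa_document_mask_forward, which reproduces the variable-length ("document") masking on top of F.scaled_dot_product_attention with an additive block-diagonal mask. Below is a minimal, self-contained sketch of that masking idea; the shapes, sizes, variable names, and equivalence check are illustrative only and are not taken from the diff.

# Sketch (assumptions: toy shapes, plain torch, CPU): pack several documents into
# one sequence, block cross-document attention with an additive block-diagonal
# mask, and check the result matches running SDPA on each document separately.
import torch
import torch.nn.functional as F

H, D = 4, 8                                      # heads, head_dim (illustrative)
seq_sizes = torch.tensor([3, 5, 2])              # three packed documents
L = int(seq_sizes.sum())
cu_seqlens = F.pad(seq_sizes.cumsum(0), (1, 0))  # [0, 3, 8, 10]

q, k, v = (torch.randn(1, H, L, D) for _ in range(3))

# -inf wherever the query token and key token belong to different documents
seg_ids = torch.repeat_interleave(torch.arange(len(seq_sizes)), seq_sizes)
blocked = seg_ids[:, None] != seg_ids[None, :]
attn_mask = torch.where(blocked, float("-inf"), 0.0).view(1, 1, L, L)

packed = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)

# Reference: attend within each document slice independently, then re-pack
ref = torch.cat(
    [
        F.scaled_dot_product_attention(q[..., s:e, :], k[..., s:e, :], v[..., s:e, :])
        for s, e in zip(cu_seqlens[:-1].tolist(), cu_seqlens[1:].tolist())
    ],
    dim=-2,
)
assert torch.allclose(packed, ref, atol=1e-5)

The trade-off is the explicit L×L mask, which the FlashAttention varlen path avoids; in the diff, FlashAttention stays the default when no implementation is configured (the getattr fallback is "flash_attention_3").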

Files changed (1):
  modular_isaac.py  +131 -58
modular_isaac.py CHANGED
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 from collections import defaultdict
-from typing import Any, Union, TypedDict
+from typing import Any, TypedDict
 
 import math
 import numpy as np
@@ -81,6 +81,91 @@ def create_cumulative_seq_lengths(seq_sizes: torch.Tensor, device: torch.device)
     return cu_seqlens, max_seqlen
 
 
+def _max_from_cu(cu: torch.Tensor | None, fallback: int) -> int:
+    """Helper to compute max sequence length from cumulative sequence lengths."""
+    if cu is None or len(cu) < 2:
+        return fallback
+    return int((cu[1:] - cu[:-1]).max().item())
+
+
+def flash_attention_document_mask_forward(
+    q_lhd: torch.Tensor,  # (L, H, D)
+    k_lhd: torch.Tensor,  # (L, H, D)
+    v_lhd: torch.Tensor,  # (L, H, D)
+    attention_mask: torch.Tensor | None = None,  # unused for FA path
+    dropout: float = 0.0,
+    scaling: float | None = None,
+    cum_seq_q: torch.Tensor | None = None,
+    cum_seq_k: torch.Tensor | None = None,
+    max_seqlen: int | None = None,
+    is_causal: bool = False,
+    **kwargs,
+) -> tuple[torch.Tensor, None]:
+    """FlashAttention that consumes (L, H, D) directly to avoid layout churn."""
+    L, H, D = q_lhd.shape
+
+    # Compute max block length once (honor caller when provided)
+    if max_seqlen is not None:
+        max_q = max_k = int(max_seqlen)
+    else:
+        max_q = _max_from_cu(cum_seq_q, L)
+        max_k = _max_from_cu(cum_seq_k, L)
+
+    # Ensure contiguity only if needed
+    if not q_lhd.is_contiguous():
+        q_lhd = q_lhd.contiguous()
+    if not k_lhd.is_contiguous():
+        k_lhd = k_lhd.contiguous()
+    if not v_lhd.is_contiguous():
+        v_lhd = v_lhd.contiguous()
+
+    out_lhd, *_ = torch.ops.aten._flash_attention_forward(
+        query=q_lhd,  # (L, H, D)
+        key=k_lhd,  # (L, H, D)
+        value=v_lhd,  # (L, H, D)
+        cum_seq_q=cum_seq_q,
+        cum_seq_k=cum_seq_k,
+        max_q=max_q,
+        max_k=max_k,
+        dropout_p=dropout,
+        is_causal=is_causal,
+        return_debug_mask=False,
+        scale=scaling,
+        window_size_left=-1,
+        window_size_right=-1,
+        alibi_slopes=None,
+    )
+    return out_lhd, None  # (L, H, D)
+
+
+def sdpa_document_mask_forward(
+    q_lhd: torch.Tensor,  # (L, H, D)
+    k_lhd: torch.Tensor,  # (L, H, D)
+    v_lhd: torch.Tensor,  # (L, H, D)
+    dropout: float,
+    scaling: float | None,
+    cu_seqlens: torch.Tensor | None,
+) -> torch.Tensor:
+    """SDPA with block-diagonal masking for variable-length sequences."""
+    L, H, D = q_lhd.shape
+
+    # Transpose to (1, H, L, D) format for SDPA
+    Q = q_lhd.permute(1, 0, 2).unsqueeze(0)
+    K = k_lhd.permute(1, 0, 2).unsqueeze(0)
+    V = v_lhd.permute(1, 0, 2).unsqueeze(0)
+
+    # Build block-diagonal mask for variable-length sequences
+    attn_mask = None
+    if cu_seqlens is not None:
+        seq_sizes = (cu_seqlens[1:] - cu_seqlens[:-1]).long()
+        seg_ids = torch.repeat_interleave(torch.arange(len(seq_sizes), device=q_lhd.device), seq_sizes)
+        block_mask = seg_ids[:, None] != seg_ids[None, :]  # Cross-document attention blocked
+        attn_mask = torch.where(block_mask, -torch.inf, 0.0).to(q_lhd.dtype).view(1, 1, L, L)
+
+    Y = F.scaled_dot_product_attention(Q, K, V, attn_mask=attn_mask, dropout_p=dropout, scale=scaling)
+    return Y.squeeze(0).permute(1, 0, 2)  # Back to (L, H, D)
+
+
 class Siglip2VariableSequenceEmbeddings(nn.Module):
     def __init__(self, config: PixelShuffleSiglip2VisionConfig):
         super().__init__()
@@ -172,58 +257,42 @@ class Siglip2VariableLengthAttention(nn.Module):
         self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
 
     def forward(self, hidden_states, cu_seqlens=None, max_seqlen=None):
-        batch_size, seq_len, _ = hidden_states.size()
-
-        # For variable-length attention, we need to reshape to (total_tokens, embed_dim)
+        # Expect packed sequences with batch_size == 1
+        batch_size, L, _ = hidden_states.shape
         if batch_size != 1:
-            raise ValueError("Variable-length attention expects batch_size=1 for packed sequences")
-        hidden_states = hidden_states.squeeze(0)  # Remove batch dimension: (seq_len, embed_dim)
-
-        # Store original dtype
-        orig_dtype = hidden_states.dtype
-
-        # 1. Linear projections
-        Q = self.q_proj(hidden_states)  # (seq_len, embed_dim)
-        K = self.k_proj(hidden_states)  # (seq_len, embed_dim)
-        V = self.v_proj(hidden_states)  # (seq_len, embed_dim)
-
-        # 2. Reshape for multi-head attention: (seq_len, n_heads, head_dim)
-        Q = Q.view(-1, self.num_heads, self.embed_dim // self.num_heads)
-        K = K.view(-1, self.num_heads, self.embed_dim // self.num_heads)
-        V = V.view(-1, self.num_heads, self.embed_dim // self.num_heads)
-
-        # 3. Apply variable-length attention using flash attention
-        attn_output, _, _, _, _ = torch.ops.aten._flash_attention_forward(
-            query=Q,
-            key=K,
-            value=V,
-            cum_seq_q=cu_seqlens,
-            cum_seq_k=cu_seqlens,
-            max_q=max_seqlen,
-            max_k=max_seqlen,
-            dropout_p=self.dropout if self.training else 0.0,
-            is_causal=False,
-            return_debug_mask=False,
-            scale=self.scale,
-            window_size_left=-1,
-            window_size_right=-1,
-            alibi_slopes=None,
-        )
-
-        # 4. Reshape attention output from (seq_len, n_heads, head_dim) to (seq_len, embed_dim)
-        attn_output = attn_output.reshape(seq_len, self.embed_dim)
-
-        # 5. Convert back to original dtype if needed
-        if attn_output.dtype != orig_dtype:
-            attn_output = attn_output.to(orig_dtype)
-
-        # 6. Project output
-        attn_output = self.out_proj(attn_output)  # (seq_len, embed_dim)
-
-        # 7. Add back batch dimension for compatibility
-        attn_output = attn_output.unsqueeze(0)  # (1, seq_len, embed_dim)
+            raise ValueError("packed variable-length attention expects batch_size=1")
+        x = hidden_states[0]  # (L, E)
+
+        H = self.num_heads
+        D = self.head_dim
+        p_drop = self.dropout if self.training else 0.0
+
+        # Project and reshape to (L, H, D)
+        q = self.q_proj(x).view(L, H, D)
+        k = self.k_proj(x).view(L, H, D)
+        v = self.v_proj(x).view(L, H, D)
+
+        attn_impl = getattr(self.config, "_attn_implementation", "flash_attention_3")
+
+        if attn_impl in ("flash_attention_2", "flash_attention_3"):
+            y_lhd, _ = flash_attention_document_mask_forward(
+                q,
+                k,
+                v,
+                attention_mask=None,
+                dropout=p_drop,
+                scaling=self.scale,
+                cum_seq_q=cu_seqlens,
+                cum_seq_k=cu_seqlens,
+                max_seqlen=max_seqlen,
+                is_causal=False,
+            )
+        else:
+            y_lhd = sdpa_document_mask_forward(q, k, v, dropout=p_drop, scaling=self.scale, cu_seqlens=cu_seqlens)
 
-        return attn_output, None
+        # Merge heads and project
+        y = self.out_proj(y_lhd.reshape(L, self.embed_dim))
+        return y.unsqueeze(0), None  # (1, L, E)
 
 
 class IsaacSiglip2EncoderLayer(nn.Module):
@@ -805,6 +874,7 @@ class IsaacConfig(Qwen3Config):
         pixel_shuffle_scale: int = 1,
         max_sequence_length: int = 16384,
         vision_token: str = "<image>",
+        vision_attn_implementation: str | None = None,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -826,6 +896,7 @@ class IsaacConfig(Qwen3Config):
         # Processing parameters
         self.max_sequence_length = max_sequence_length
        self.vision_token = vision_token
+        self.vision_attn_implementation = vision_attn_implementation
 
 
 # ============================================================================
@@ -880,7 +951,6 @@ class IsaacProcessor(ProcessorMixin):
     attributes = ["tokenizer"]
     tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast")
 
-
     def __init__(
         self,
         tokenizer: Qwen2Tokenizer,
@@ -992,8 +1062,8 @@ class IsaacProcessor(ProcessorMixin):
 
     def __call__(
         self,
-        text: Union[str, list[str]],
-        images: Union[PIL.Image.Image, list[PIL.Image.Image], None] = None,
+        text: str | list[str],
+        images: PIL.Image.Image | list[PIL.Image.Image] | None = None,
         return_tensors: str | TensorType | None = TensorType.PYTORCH,
         **kwargs,
     ) -> BatchFeature:
@@ -1135,6 +1205,12 @@ class IsaacModel(Qwen3Model):
         self.rotary_emb = IsaacRotaryEmbedding(config, device=self.device)
 
         vision_cfg = config.vision_config
+        # Use vision_attn_implementation if specified, otherwise fall back to general attn_implementation
+        vision_cfg._attn_implementation = (
+            config.vision_attn_implementation
+            if config.vision_attn_implementation is not None
+            else config._attn_implementation
+        )
         if vision_cfg is None:
             raise ValueError("IsaacConfig should always have vision_config")
 
@@ -1418,9 +1494,7 @@
             causal_mask = attention_mask
         else:
             min_dtype = torch.finfo(dtype).min
-            causal_mask = torch.full(
-                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
-            )
+            causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
             diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
             if config.sliding_window is not None:
                 # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
@@ -1447,7 +1521,6 @@
         return causal_mask
 
 
-
 class IsaacForConditionalGeneration(Qwen3ForCausalLM, GenerationMixin):
     """Isaac multimodal model for conditional generation."""
 
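Usage note (not shown in this commit): the new IsaacConfig.vision_attn_implementation field overrides the model-wide _attn_implementation for the vision tower only; leaving it as None keeps the fallback behaviour. The sketch below is hedged: the checkpoint id is a placeholder, and loading through the Auto classes with trust_remote_code is an assumption about how this repository is packaged, not something the diff establishes.

from transformers import AutoConfig, AutoModelForCausalLM

# "org/isaac-checkpoint" is a hypothetical repo id, used only for illustration.
config = AutoConfig.from_pretrained("org/isaac-checkpoint", trust_remote_code=True)
config.vision_attn_implementation = "sdpa"  # None falls back to config._attn_implementation

model = AutoModelForCausalLM.from_pretrained(
    "org/isaac-checkpoint",
    config=config,
    trust_remote_code=True,
)
# IsaacModel.__init__ copies this value onto vision_config._attn_implementation,
# so Siglip2VariableLengthAttention.forward takes the sdpa_document_mask_forward branch
# instead of the FlashAttention path.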