Update to add SDPA support

modular_isaac.py  (+131 -58)
@@ -1,7 +1,7 @@
 from __future__ import annotations

 from collections import defaultdict
-from typing import Any,
+from typing import Any, TypedDict

 import math
 import numpy as np
@@ -81,6 +81,91 @@ def create_cumulative_seq_lengths(seq_sizes: torch.Tensor, device: torch.device)
     return cu_seqlens, max_seqlen


+def _max_from_cu(cu: torch.Tensor | None, fallback: int) -> int:
+    """Helper to compute max sequence length from cumulative sequence lengths."""
+    if cu is None or len(cu) < 2:
+        return fallback
+    return int((cu[1:] - cu[:-1]).max().item())
+
+
+def flash_attention_document_mask_forward(
+    q_lhd: torch.Tensor,  # (L, H, D)
+    k_lhd: torch.Tensor,  # (L, H, D)
+    v_lhd: torch.Tensor,  # (L, H, D)
+    attention_mask: torch.Tensor | None = None,  # unused for FA path
+    dropout: float = 0.0,
+    scaling: float | None = None,
+    cum_seq_q: torch.Tensor | None = None,
+    cum_seq_k: torch.Tensor | None = None,
+    max_seqlen: int | None = None,
+    is_causal: bool = False,
+    **kwargs,
+) -> tuple[torch.Tensor, None]:
+    """FlashAttention that consumes (L, H, D) directly to avoid layout churn."""
+    L, H, D = q_lhd.shape
+
+    # Compute max block length once (honor caller when provided)
+    if max_seqlen is not None:
+        max_q = max_k = int(max_seqlen)
+    else:
+        max_q = _max_from_cu(cum_seq_q, L)
+        max_k = _max_from_cu(cum_seq_k, L)
+
+    # Ensure contiguity only if needed
+    if not q_lhd.is_contiguous():
+        q_lhd = q_lhd.contiguous()
+    if not k_lhd.is_contiguous():
+        k_lhd = k_lhd.contiguous()
+    if not v_lhd.is_contiguous():
+        v_lhd = v_lhd.contiguous()
+
+    out_lhd, *_ = torch.ops.aten._flash_attention_forward(
+        query=q_lhd,  # (L, H, D)
+        key=k_lhd,  # (L, H, D)
+        value=v_lhd,  # (L, H, D)
+        cum_seq_q=cum_seq_q,
+        cum_seq_k=cum_seq_k,
+        max_q=max_q,
+        max_k=max_k,
+        dropout_p=dropout,
+        is_causal=is_causal,
+        return_debug_mask=False,
+        scale=scaling,
+        window_size_left=-1,
+        window_size_right=-1,
+        alibi_slopes=None,
+    )
+    return out_lhd, None  # (L, H, D)
+
+
+def sdpa_document_mask_forward(
+    q_lhd: torch.Tensor,  # (L, H, D)
+    k_lhd: torch.Tensor,  # (L, H, D)
+    v_lhd: torch.Tensor,  # (L, H, D)
+    dropout: float,
+    scaling: float | None,
+    cu_seqlens: torch.Tensor | None,
+) -> torch.Tensor:
+    """SDPA with block-diagonal masking for variable-length sequences."""
+    L, H, D = q_lhd.shape
+
+    # Transpose to (1, H, L, D) format for SDPA
+    Q = q_lhd.permute(1, 0, 2).unsqueeze(0)
+    K = k_lhd.permute(1, 0, 2).unsqueeze(0)
+    V = v_lhd.permute(1, 0, 2).unsqueeze(0)
+
+    # Build block-diagonal mask for variable-length sequences
+    attn_mask = None
+    if cu_seqlens is not None:
+        seq_sizes = (cu_seqlens[1:] - cu_seqlens[:-1]).long()
+        seg_ids = torch.repeat_interleave(torch.arange(len(seq_sizes), device=q_lhd.device), seq_sizes)
+        block_mask = seg_ids[:, None] != seg_ids[None, :]  # Cross-document attention blocked
+        attn_mask = torch.where(block_mask, -torch.inf, 0.0).to(q_lhd.dtype).view(1, 1, L, L)
+
+    Y = F.scaled_dot_product_attention(Q, K, V, attn_mask=attn_mask, dropout_p=dropout, scale=scaling)
+    return Y.squeeze(0).permute(1, 0, 2)  # Back to (L, H, D)
+
+
 class Siglip2VariableSequenceEmbeddings(nn.Module):
     def __init__(self, config: PixelShuffleSiglip2VisionConfig):
         super().__init__()
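As a quick sanity check (not part of the diff above), the block-diagonal mask that `sdpa_document_mask_forward` builds from `cu_seqlens` should make the packed SDPA call match running attention on each document slice independently. A minimal sketch, assuming the function can be imported from `modular_isaac.py` and using made-up shapes:

import torch
import torch.nn.functional as F

from modular_isaac import sdpa_document_mask_forward  # assumed import path

torch.manual_seed(0)
L, H, D = 8, 4, 16                               # two packed documents: 3 + 5 tokens (illustrative)
cu = torch.tensor([0, 3, 8], dtype=torch.int32)  # cumulative boundaries
q, k, v = (torch.randn(L, H, D) for _ in range(3))

packed = sdpa_document_mask_forward(q, k, v, dropout=0.0, scaling=None, cu_seqlens=cu)

# Reference: run SDPA on each document slice separately and re-pack.
chunks = []
for start, end in zip(cu[:-1].tolist(), cu[1:].tolist()):
    qi, ki, vi = (t[start:end].permute(1, 0, 2).unsqueeze(0) for t in (q, k, v))  # (1, H, Li, D)
    yi = F.scaled_dot_product_attention(qi, ki, vi)
    chunks.append(yi.squeeze(0).permute(1, 0, 2))  # back to (Li, H, D)

assert torch.allclose(packed, torch.cat(chunks, dim=0), atol=1e-5)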
@@ -172,58 +257,42 @@ class Siglip2VariableLengthAttention(nn.Module):
         self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)

     def forward(self, hidden_states, cu_seqlens=None, max_seqlen=None):
-        # For variable-length attention, we need to reshape to (total_tokens, embed_dim)
+        # Expect packed sequences with batch_size == 1
+        batch_size, L, _ = hidden_states.shape
         if batch_size != 1:
-            raise ValueError("
-            window_size_left=-1,
-            window_size_right=-1,
-            alibi_slopes=None,
-        )
-
-        # 4. Reshape attention output from (seq_len, n_heads, head_dim) to (seq_len, embed_dim)
-        attn_output = attn_output.reshape(seq_len, self.embed_dim)
-
-        # 5. Convert back to original dtype if needed
-        if attn_output.dtype != orig_dtype:
-            attn_output = attn_output.to(orig_dtype)
-
-        # 6. Project output
-        attn_output = self.out_proj(attn_output)  # (seq_len, embed_dim)
-
-        # 7. Add back batch dimension for compatibility
-        attn_output = attn_output.unsqueeze(0)  # (1, seq_len, embed_dim)
+            raise ValueError("packed variable-length attention expects batch_size=1")
+        x = hidden_states[0]  # (L, E)
+
+        H = self.num_heads
+        D = self.head_dim
+        p_drop = self.dropout if self.training else 0.0
+
+        # Project and reshape to (L, H, D)
+        q = self.q_proj(x).view(L, H, D)
+        k = self.k_proj(x).view(L, H, D)
+        v = self.v_proj(x).view(L, H, D)
+
+        attn_impl = getattr(self.config, "_attn_implementation", "flash_attention_3")
+
+        if attn_impl in ("flash_attention_2", "flash_attention_3"):
+            y_lhd, _ = flash_attention_document_mask_forward(
+                q,
+                k,
+                v,
+                attention_mask=None,
+                dropout=p_drop,
+                scaling=self.scale,
+                cum_seq_q=cu_seqlens,
+                cum_seq_k=cu_seqlens,
+                max_seqlen=max_seqlen,
+                is_causal=False,
+            )
+        else:
+            y_lhd = sdpa_document_mask_forward(q, k, v, dropout=p_drop, scaling=self.scale, cu_seqlens=cu_seqlens)

+        # Merge heads and project
+        y = self.out_proj(y_lhd.reshape(L, self.embed_dim))
+        return y.unsqueeze(0), None  # (1, L, E)


 class IsaacSiglip2EncoderLayer(nn.Module):
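The rewritten forward assumes the vision tokens arrive packed: a single batch element whose length is the sum of all per-image token counts, with `cu_seqlens` marking document boundaries and `max_seqlen` the longest block. A small illustration of that layout with made-up token counts and embed dim (in the real code, `create_cumulative_seq_lengths` produces these values):

import torch

seq_sizes = torch.tensor([16, 64, 256])          # tokens contributed by each image (illustrative)
cu_seqlens = torch.zeros(len(seq_sizes) + 1, dtype=torch.int32)
cu_seqlens[1:] = torch.cumsum(seq_sizes, dim=0)  # tensor([  0,  16,  80, 336])
max_seqlen = int(seq_sizes.max())                # 256

hidden_states = torch.randn(1, int(seq_sizes.sum()), 768)  # (1, L, E); E=768 is made up
# attn_out, _ = attention_layer(hidden_states, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)
# where attention_layer is a Siglip2VariableLengthAttention instance (hypothetical here)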
@@ -805,6 +874,7 @@ class IsaacConfig(Qwen3Config):
         pixel_shuffle_scale: int = 1,
         max_sequence_length: int = 16384,
         vision_token: str = "<image>",
+        vision_attn_implementation: str | None = None,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -826,6 +896,7 @@ class IsaacConfig(Qwen3Config):
         # Processing parameters
         self.max_sequence_length = max_sequence_length
         self.vision_token = vision_token
+        self.vision_attn_implementation = vision_attn_implementation


 # ============================================================================
@@ -880,7 +951,6 @@ class IsaacProcessor(ProcessorMixin):
     attributes = ["tokenizer"]
     tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast")

-
     def __init__(
         self,
         tokenizer: Qwen2Tokenizer,
@@ -992,8 +1062,8 @@ class IsaacProcessor(ProcessorMixin):

     def __call__(
         self,
-        text:
-        images:
+        text: str | list[str],
+        images: PIL.Image.Image | list[PIL.Image.Image] | None = None,
         return_tensors: str | TensorType | None = TensorType.PYTORCH,
         **kwargs,
     ) -> BatchFeature:
@@ -1135,6 +1205,12 @@ class IsaacModel(Qwen3Model):
         self.rotary_emb = IsaacRotaryEmbedding(config, device=self.device)

         vision_cfg = config.vision_config
+        # Use vision_attn_implementation if specified, otherwise fall back to general attn_implementation
+        vision_cfg._attn_implementation = (
+            config.vision_attn_implementation
+            if config.vision_attn_implementation is not None
+            else config._attn_implementation
+        )
         if vision_cfg is None:
             raise ValueError("IsaacConfig should always have vision_config")

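Taken together with the config changes above, the attention backend of the vision tower can diverge from the language model's. A hedged usage sketch, assuming `config` is an already-loaded `IsaacConfig` and that `IsaacForConditionalGeneration` can be constructed from it directly:

# config: an already-loaded IsaacConfig (loading path omitted here)
config.vision_attn_implementation = "sdpa"          # vision tower takes the SDPA block-diagonal path
config._attn_implementation = "flash_attention_2"   # language model keeps FlashAttention

model = IsaacForConditionalGeneration(config)

# With vision_attn_implementation left as None (the default), the vision config
# inherits config._attn_implementation instead.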
@@ -1418,9 +1494,7 @@ class IsaacModel(Qwen3Model):
             causal_mask = attention_mask
         else:
             min_dtype = torch.finfo(dtype).min
-            causal_mask = torch.full(
-                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
-            )
+            causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
             diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
             if config.sliding_window is not None:
                 # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
@@ -1447,7 +1521,6 @@ class IsaacModel(Qwen3Model):
         return causal_mask


-
 class IsaacForConditionalGeneration(Qwen3ForCausalLM, GenerationMixin):
     """Isaac multimodal model for conditional generation."""

