---
license: cc-by-nc-4.0
---

# SMB-RAD-Encoder-v1

SMB-RAD-Encoder-v1 is a pure-vision backbone designed for medical imaging foundation models, with a strong focus on radiology modalities such as CT, MRI, and X‑ray. It implements efficient 3D patch embedding, rotary position encodings, scalable Transformer blocks, multi‑scale deep feature extraction, and two self‑supervised objectives tailored for medical imagery: masked image modeling (MIM) and a joint embedding predictive architecture (JEPA).

## Architecture Overview

The implementation lives in `modeling_smb_vision.py` and exposes three main classes:

- `SMBVisionEncoder`: vision encoder with 3D patch embedding and stacked Transformer blocks
- `SMBVisionPredictor`: lightweight Transformer for JEPA next‑embedding prediction
- `SMBVisionModel`: wrapper that combines the encoder and predictor and computes the MIM and JEPA losses

Key components and how they map to the code:

- **3D Patch Embedding (`SMBVisionPatchEmbed`)**
  - A `Conv3d` with kernel = stride = `[temporal_patch_size, patch_size, patch_size]`, applied to each patch tensor
  - Supports `in_channels` = 1 (grayscale), 3 (RGB), or 4; radiology typically uses 1
  - Produces one embedding of size `hidden_size` per patch
- **Learned 2D Positional Embedding + Fast Interpolation**
  - `pos_embed: nn.Embedding(num_position_embeddings, hidden_size)` with bilinear‑style interpolation (`fast_pos_embed_interpolate`) to each frame's target grid size (height × width)
- **Rotary Position Embedding (RoPE) in Space (and Time)**
  - `SMBVisionRotaryEmbedding` generates the frequencies, which are applied inside attention via `apply_rotary_pos_emb_vision`
  - Injects spatial (and slice/temporal) position into the query and key vectors, so attention scores depend on relative geometry
- **Transformer Blocks (`SMBVisionBlock`)**
  - Pre‑norm residual blocks with `SMBVisionAttention` and `SMBVisionMLP`
  - Attention backends: eager, SDPA, or FlashAttention‑2 (selectable in the config)
- **DeepStack Multi‑scale Features**
  - `deepstack_visual_indexes` selects the block indices whose outputs are merged by
    `SMBVisionPatchMerger`
  - Produces multi‑level visual descriptors for downstream tasks (e.g., detection, retrieval)
- **Masked Image Modeling (MIM)**
  - Randomly masks a fraction (the mask ratio) of patch tokens and reconstructs their pixels via `to_pixels: Linear(hidden_size -> patch_volume)`
  - Reconstruction loss: L1 (mean absolute error) computed on the masked patches
  - Note: for grayscale medical data, set `in_channels=1` so the reconstruction target matches the output shape
- **JEPA Next‑Embedding Prediction**
  - Context and target partitions are specified at the study level and expanded to patch tokens internally
  - `SMBVisionPredictor` predicts the target encoder embeddings; the loss is MSE on the target tokens

## Radiology‑centric Design Notes

- **Modalities**: CT/MRI volumes (slice stacks) and X‑ray images are supported via patch tokenization
- **Through‑plane handling**: `temporal_patch_size` serves as the patch depth (in slices) for 3D patching along the through‑plane (Z) axis
- **Grayscale emphasis**: use `in_channels=1` for CT/MRI/X‑ray so the MIM reconstruction shapes line up
- **Scalability**: the SDPA and FlashAttention‑2 backends help scale to large studies and high‑resolution inputs
- **Multi‑scale features**: `deepstack_visual_indexes` provides hooks for detection/segmentation heads

## Installation

```bash
pip install torch torchvision transformers 'monai[all]'
pip install git+https://github.com/standardmodelbio/smb-biopan-utils.git
```

## Quick Start (CT volumes)

The encoder expects a 2D tensor of patch tokens (all studies concatenated along the first dimension) and a per‑study grid descriptor `grid_thw = [T, H, W]`, where:

- `T = num_slices / temporal_patch_size`
- `H = image_height / patch_size`
- `W = image_width / patch_size`

You must first patchify each volume into non‑overlapping 3D patches of shape `[in_channels, temporal_patch_size, patch_size, patch_size]`, flatten each patch into a token, and concatenate all tokens for the batch.

Example using the `process_mm_info` helper with NIfTI volumes:

```python
import torch
from smb_biopan_utils import process_mm_info
from transformers import AutoModel

# Prepare the message spec for your volume(s).
# Each "image" entry can be a path to a NIfTI or DICOM file.
messages = [
    {
        "content": [
            {"type": "image", "image": "dummy.nii.gz"},  # volume size is [1, 64, 160, 160]
            {"type": "image", "image": "dummy.nii.gz"},
        ]
    }
]

# Convert to the patch tokens and grid descriptor expected by SMB-Vision.
# The default patch size is 16 along all three dimensions:
# (64/16) * (160/16) * (160/16) = 400 patches per volume, each 1 * 16 * 16 * 16 = 4096 values.
images, grid_thw = process_mm_info(messages)  # images: [800 (= 400 * 2), 4096]

# Optional: dummy inputs with the same shapes, for a quick smoke test
# images, grid_thw = torch.randn(800, 4096), torch.tensor([[4, 10, 10], [4, 10, 10]])

# Load the backbone from the HF Hub (uses this repo's modeling code via trust_remote_code)
model = AutoModel.from_pretrained(
    "standardmodelbio/SMB-RAD-Encoder-v1",
    trust_remote_code=True,
    dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
).encoder
model.to("cuda")
model.eval()

# Encode features (inputs cast to bfloat16 to match the weights and FlashAttention-2)
with torch.no_grad():
    encoded_patches, deepstack_features = model(
        images.to("cuda", dtype=torch.bfloat16), grid_thw=grid_thw.to("cuda")
    )
print(encoded_patches.shape)  # (800, 1152)
```

## API Summary

- `SMBVisionEncoder.forward(hidden_states, grid_thw)` → `(encoded_patches, deepstack_features)`
  - `hidden_states`: float tensor of shape `(num_patches, in_channels * temporal_patch_size * patch_size^2)`
  - `grid_thw`: integer tensor of shape `(num_studies, 3)` holding `[T, H, W]` for each study
- `SMBVisionModel.forward(hidden_states, grid_thw, context_mask, target_mask)` → `SMBVisionModelOutput`
  - Always computes the MIM loss; computes the JEPA loss only when masks are provided
  - The output contains the losses and, optionally, the encoder and predicted hidden states
- `SMBVisionModel.forward_features(hidden_states, grid_thw)` → `(encoded_patches, deepstack_features)`
  - Convenience wrapper that calls the encoder directly for feature extraction

## Recommended Radiology Settings

- **CT chest/abdomen**: `patch_size=16`, `temporal_patch_size=16`, `in_channels=1`
- **MRI brain**: `patch_size=16`, `temporal_patch_size=16` (or per‑sequence 2D with `temporal_patch_size=1`)
- **X‑ray**: `patch_size=16`, `temporal_patch_size=1`, `in_channels=1`

## Notes

- FlashAttention‑2 can be enabled via the attention implementation setting in the vision config (for example, `attn_implementation="flash_attention_2"` in `from_pretrained`, as in the Quick Start)
- Ensure volume dimensions are divisible by `patch_size` and `temporal_patch_size` (otherwise center‑crop or pad before patchifying)
- For multi‑sequence MRI or 4‑channel inputs, set `in_channels=4` and adapt the reconstruction paths accordingly

## Citation

If you use SMB-RAD-Encoder-v1 in your research, please cite this repository.

```
@software{standardmodel2025smbrad,
  author = {Chen, Zekai and Adam, Irsyad and Laprade, David and Brown, Kevin and others},
  title = {SMB-RAD-Encoder-v1},
  year = {2025},
  publisher = {Standard Model Biomedicine, Inc.},
  journal = {Standard Model Blog},
  url = {https://huggingface.co/standardmodelbio/SMB-RAD-Encoder-v1}
}
```
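## Appendix: Manual Patchify Sketch

The Quick Start delegates patchification to `process_mm_info`. To make the expected token layout concrete, here is a minimal PyTorch sketch that turns `[C, D, H, W]` volumes into `(num_patches, in_channels * temporal_patch_size * patch_size^2)` tokens plus a `grid_thw` tensor. `patchify_volume` and `patchify_batch` are hypothetical helpers and are not part of this repository or `smb_biopan_utils`. The sketch assumes that each volume is already cropped or padded to multiples of the patch sizes, and it emits tokens in plain raster order (slice, row, column). The released preprocessing may order tokens differently, for example by grouping neighbouring patches for `SMBVisionPatchMerger`, so use `process_mm_info` when running the released checkpoint.

```python
import torch


def patchify_volume(
    volume: torch.Tensor, patch_size: int = 16, temporal_patch_size: int = 16
) -> tuple[torch.Tensor, torch.Tensor]:
    """Split one [C, D, H, W] volume into flattened, non-overlapping 3D patches.

    Hypothetical helper (not part of this repo). Tokens are emitted in raster
    order (slice, row, column); process_mm_info may use a different order.
    """
    c, d, h, w = volume.shape
    if d % temporal_patch_size or h % patch_size or w % patch_size:
        raise ValueError("crop or pad so D, H, W are divisible by the patch sizes")
    t, gh, gw = d // temporal_patch_size, h // patch_size, w // patch_size
    # [C, D, H, W] -> [C, T, tp, H', p, W', p]
    patches = volume.reshape(c, t, temporal_patch_size, gh, patch_size, gw, patch_size)
    # Move the grid axes to the front: [T, H', W', C, tp, p, p]
    patches = patches.permute(1, 3, 5, 0, 2, 4, 6)
    # One row per patch, flattened channel-first as [C, tp, p, p]
    tokens = patches.reshape(t * gh * gw, c * temporal_patch_size * patch_size * patch_size)
    return tokens, torch.tensor([t, gh, gw])


def patchify_batch(
    volumes: list[torch.Tensor], patch_size: int = 16, temporal_patch_size: int = 16
) -> tuple[torch.Tensor, torch.Tensor]:
    """Patchify several volumes and concatenate their tokens along dim 0."""
    results = [patchify_volume(v, patch_size, temporal_patch_size) for v in volumes]
    tokens = torch.cat([tok for tok, _ in results], dim=0)
    grid_thw = torch.stack([grid for _, grid in results])  # (num_studies, 3)
    return tokens, grid_thw


if __name__ == "__main__":
    volumes = [torch.randn(1, 64, 160, 160) for _ in range(2)]
    images, grid_thw = patchify_batch(volumes)
    print(images.shape, grid_thw.tolist())  # torch.Size([800, 4096]) [[4, 10, 10], [4, 10, 10]]
```

With two `[1, 64, 160, 160]` volumes, this produces tokens of shape `(800, 4096)` and `grid_thw = [[4, 10, 10], [4, 10, 10]]`, the same shapes used in the Quick Start. Each patch is flattened channel-first, which matches the `[in_channels, temporal_patch_size, patch_size, patch_size]` patch shape listed there.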