Ozan Oktay
		
	committed on
		
		
					Commit 
							
							·
						
						1cb4998
	
1
								Parent(s):
							
							2194015
								
add model
Browse files
- config.json +30 -0
- configuration_cxrbert.py +26 -0
- modeling_cxrbert.py +129 -0
- pytorch_model.bin +3 -0
    	
        config.json
    ADDED
    
    | @@ -0,0 +1,30 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            {
         | 
| 2 | 
            +
              "_name_or_path": "microsoft/BiomedVLP-BioViL-T",
         | 
| 3 | 
            +
              "architectures": [
         | 
| 4 | 
            +
                "CXRBertModel"
         | 
| 5 | 
            +
              ],
         | 
| 6 | 
            +
              "attention_probs_dropout_prob": 0.25,
         | 
| 7 | 
            +
              "auto_map": {
         | 
| 8 | 
            +
                "AutoModel": "modeling_cxrbert.CXRBertModel"
         | 
| 9 | 
            +
              },
         | 
| 10 | 
            +
              "classifier_dropout": null,
         | 
| 11 | 
            +
              "gradient_checkpointing": false,
         | 
| 12 | 
            +
              "hidden_act": "gelu",
         | 
| 13 | 
            +
              "hidden_dropout_prob": 0.25,
         | 
| 14 | 
            +
              "hidden_size": 768,
         | 
| 15 | 
            +
              "initializer_range": 0.02,
         | 
| 16 | 
            +
              "intermediate_size": 3072,
         | 
| 17 | 
            +
              "layer_norm_eps": 1e-12,
         | 
| 18 | 
            +
              "max_position_embeddings": 512,
         | 
| 19 | 
            +
              "model_type": "bert",
         | 
| 20 | 
            +
              "num_attention_heads": 12,
         | 
| 21 | 
            +
              "num_hidden_layers": 12,
         | 
| 22 | 
            +
              "pad_token_id": 0,
         | 
| 23 | 
            +
              "position_embedding_type": "absolute",
         | 
| 24 | 
            +
              "projection_size": 128,
         | 
| 25 | 
            +
              "torch_dtype": "float32",
         | 
| 26 | 
            +
              "transformers_version": "4.17.0",
         | 
| 27 | 
            +
              "type_vocab_size": 2,
         | 
| 28 | 
            +
              "use_cache": true,
         | 
| 29 | 
            +
              "vocab_size": 30522
         | 
| 30 | 
            +
            }
         | 
    	
        configuration_cxrbert.py
    ADDED
    
    | @@ -0,0 +1,26 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            #  ------------------------------------------------------------------------------------------
         | 
| 2 | 
            +
            #  Copyright (c) Microsoft Corporation. All rights reserved.
         | 
| 3 | 
            +
            #  Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
         | 
| 4 | 
            +
            #  ------------------------------------------------------------------------------------------
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            from typing import Any
         | 
| 7 | 
            +
             | 
| 8 | 
            +
            from transformers import BertConfig, BertTokenizer
         | 
| 9 | 
            +
             | 
| 10 | 
            +
             | 
| 11 | 
            +
            class CXRBertConfig(BertConfig):
                """
                Configuration for the CXR-BERT model: a ``BertConfig`` extended with one extra field.

                :param projection_size: Dimensionality of the joint latent space (defaults to 128).
                :param kwargs: Any further keyword arguments, forwarded unchanged to ``BertConfig.__init__``.
                """

                model_type = "cxr-bert"

                def __init__(self, projection_size: int = 128, **kwargs: Any) -> None:
                    # Let the parent consume every standard BERT option first, then record the
                    # CXR-BERT specific projection width on the config instance.
                    super().__init__(**kwargs)
                    self.projection_size = projection_size
         | 
| 22 | 
            +
             | 
| 23 | 
            +
             | 
| 24 | 
            +
            class CXRBertTokenizer(BertTokenizer):
                """
                Tokenizer for CXR-BERT. A thin subclass of ``BertTokenizer`` that adds no behaviour of its own.

                :param kwargs: Keyword arguments forwarded unchanged to ``BertTokenizer.__init__``.
                    Only keyword arguments are accepted by this constructor.
                """

                def __init__(self, **kwargs: Any) -> None:
                    super().__init__(**kwargs)
         | 
    	
        modeling_cxrbert.py
    ADDED
    
    | @@ -0,0 +1,129 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            #  ------------------------------------------------------------------------------------------
         | 
| 2 | 
            +
            #  Copyright (c) Microsoft Corporation. All rights reserved.
         | 
| 3 | 
            +
            #  Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
         | 
| 4 | 
            +
            #  ------------------------------------------------------------------------------------------
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            from typing import Any, Optional, Tuple, Union
         | 
| 7 | 
            +
             | 
| 8 | 
            +
            import torch
         | 
| 9 | 
            +
            import torch.nn.functional as F
         | 
| 10 | 
            +
            from torch import nn
         | 
| 11 | 
            +
            from torch import Tensor as T
         | 
| 12 | 
            +
            from transformers import BertForMaskedLM
         | 
| 13 | 
            +
            from transformers.modeling_outputs import ModelOutput
         | 
| 14 | 
            +
             | 
| 15 | 
            +
            from .configuration_cxrbert import CXRBertConfig
         | 
| 16 | 
            +
             | 
| 17 | 
            +
            # Return annotation used by ``CXRBertModel.forward`` for its positional (non-dict) result:
            # a fixed-length 5-tuple.
            BERTTupleOutput = Tuple[T, T, T, T, T]
         | 
| 18 | 
            +
             | 
| 19 | 
            +
            class CXRBertOutput(ModelOutput):
                """
                Dictionary-style result returned by ``CXRBertModel.forward`` when ``return_dict`` is truthy.

                The two leading fields are always populated by ``forward``; the remaining three default to
                ``None`` and are only filled in when the corresponding output is requested/available.
                """

                # Final entry of the hidden-state tuple produced by the underlying masked-LM model.
                last_hidden_state: torch.FloatTensor
                # ``logits`` of the underlying masked-LM model output.
                logits: torch.FloatTensor
                # Projection-head output for the first token position; ``None`` unless requested.
                cls_projected_embedding: Optional[torch.FloatTensor] = None
                # Full hidden-state tuple; ``None`` unless ``output_hidden_states`` was truthy.
                hidden_states: Optional[Tuple[torch.FloatTensor]] = None
                # ``attentions`` of the underlying masked-LM model output, passed through as-is.
                attentions: Optional[Tuple[torch.FloatTensor]] = None
         | 
| 25 | 
            +
             | 
| 26 | 
            +
             | 
| 27 | 
            +
            class BertProjectionHead(nn.Module):
                '''
                Projection head applied to the BERT CLS token representation; it is similar to
                `BertPredictionHeadTransform` in the HuggingFace library.

                The head is: linear (hidden_size -> projection_size), GELU, LayerNorm, then a second
                linear layer (projection_size -> projection_size).

                :param config: CXRBertConfig; ``hidden_size`` and ``projection_size`` are read from it.
                :return: (batch_size, output_size)
                '''
                def __init__(self, config: CXRBertConfig) -> None:
                    super().__init__()
                    projection_dim = config.projection_size
                    # Attribute names are kept exactly as-is: they determine the parameter names
                    # under which this module's weights are stored.
                    self.dense_to_hidden = nn.Linear(config.hidden_size, projection_dim)
                    self.transform_act_fn = nn.functional.gelu
                    self.LayerNorm = nn.LayerNorm(projection_dim, eps=1e-12)
                    self.dense_to_output = nn.Linear(projection_dim, projection_dim)

                def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
                    """Run the input through linear -> GELU -> LayerNorm -> linear and return the result."""
                    projected = self.dense_to_hidden(hidden_states)
                    activated = self.transform_act_fn(projected)
                    normalised = self.LayerNorm(activated)
                    return self.dense_to_output(normalised)
         | 
| 47 | 
            +
             | 
| 48 | 
            +
             | 
| 49 | 
            +
            class CXRBertModel(BertForMaskedLM):
                """
                Implements the CXR-BERT model outlined in the manuscript:
                Boecking et al. "Making the Most of Text Semantics to Improve Biomedical Vision-Language Processing", 2022
                https://arxiv.org/abs/2204.09817

                Extends the HuggingFace BertForMaskedLM model by adding a separate projection head. The projected
                "[CLS]" token embedding is used to align the latent vectors of image and text modalities.
                """

                config_class = CXRBertConfig

                def __init__(self, config: CXRBertConfig):
                    """
                    Build the masked-LM backbone and attach the CLS projection head.

                    :param config: CXRBertConfig describing both the BERT backbone and the projection size.
                    """
                    super().__init__(config)

                    self.cls_projection_head = BertProjectionHead(config)
                    # Re-run weight initialisation so that the newly attached head is covered as well.
                    self.init_weights()

                def forward(
                    self,
                    input_ids: torch.Tensor,
                    attention_mask: torch.Tensor,
                    token_type_ids: Optional[torch.Tensor] = None,
                    position_ids: Optional[torch.Tensor] = None,
                    head_mask: Optional[torch.Tensor] = None,
                    inputs_embeds: Optional[torch.Tensor] = None,
                    output_attentions: Optional[bool] = None,
                    output_hidden_states: Optional[bool] = None,
                    output_cls_projected_embedding: Optional[bool] = None,
                    return_dict: Optional[bool] = None,
                    **kwargs: Any
                ) -> Union[BERTTupleOutput, CXRBertOutput]:
                    """
                    Run the masked-LM backbone and, optionally, the CLS projection head.

                    :param input_ids: Passed through to ``BertForMaskedLM.forward``.
                    :param attention_mask: Passed through to ``BertForMaskedLM.forward``.
                    :param token_type_ids: Passed through to ``BertForMaskedLM.forward``.
                    :param position_ids: Passed through to ``BertForMaskedLM.forward``.
                    :param head_mask: Passed through to ``BertForMaskedLM.forward``.
                    :param inputs_embeds: Passed through to ``BertForMaskedLM.forward``.
                    :param output_attentions: Passed through to ``BertForMaskedLM.forward``.
                    :param output_hidden_states: If truthy, the dictionary-style result carries the full tuple of
                        hidden states; otherwise that field is ``None``.
                    :param output_cls_projected_embedding: If truthy, the projection head is applied to the first
                        token position of the last hidden state; otherwise the projected embedding is ``None``.
                    :param return_dict: If ``None``, falls back to ``self.config.use_return_dict``.
                    :param kwargs: Accepted but not used by this method.
                    :return: A ``CXRBertOutput`` when ``return_dict`` is truthy, otherwise a 5-tuple of
                        (last hidden state, logits, projected CLS embedding, hidden states, attentions).
                    """
                    if return_dict is None:
                        return_dict = self.config.use_return_dict

                    # Hidden states are always requested from the backbone (and a dict-style output is always
                    # used internally) because the last hidden state is needed below regardless of caller flags.
                    mlm_output = super().forward(
                        input_ids=input_ids,
                        attention_mask=attention_mask,
                        token_type_ids=token_type_ids,
                        position_ids=position_ids,
                        head_mask=head_mask,
                        inputs_embeds=inputs_embeds,
                        output_attentions=output_attentions,
                        output_hidden_states=True,
                        return_dict=True,
                    )

                    all_hidden_states = mlm_output.hidden_states
                    final_layer = all_hidden_states[-1]

                    projected_cls = None
                    if output_cls_projected_embedding:
                        # Position 0 along the second axis is taken as the CLS token representation.
                        projected_cls = self.cls_projection_head(final_layer[:, 0, :])

                    if not return_dict:
                        # NOTE: unlike the dictionary-style result, the tuple always carries the full hidden
                        # states, irrespective of ``output_hidden_states``.
                        return (
                            final_layer,
                            mlm_output.logits,
                            projected_cls,
                            all_hidden_states,
                            mlm_output.attentions,
                        )

                    return CXRBertOutput(
                        last_hidden_state=final_layer,
                        logits=mlm_output.logits,
                        cls_projected_embedding=projected_cls,
                        hidden_states=all_hidden_states if output_hidden_states else None,
                        attentions=mlm_output.attentions,
                    )

                def get_projected_text_embeddings(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
                    """
                    Returns l2-normalised projected cls token embeddings for the given input token ids and attention mask.
                    The joint latent space is trained using a contrastive objective between image and text data modalities.

                    :param input_ids: (batch_size, sequence_length)
                    :param attention_mask: (batch_size, sequence_length)
                    :return: (batch_size, projection_size)
                    """
                    model_output = self.forward(
                        input_ids=input_ids,
                        attention_mask=attention_mask,
                        output_cls_projected_embedding=True,
                        return_dict=True,
                    )
                    # The two asserts narrow the Union/Optional types; with the flags passed above
                    # ``forward`` returns a CXRBertOutput with the projected embedding populated.
                    assert isinstance(model_output, CXRBertOutput)

                    projected_cls = model_output.cls_projected_embedding
                    assert projected_cls is not None
                    return F.normalize(projected_cls, dim=1)
         | 
    	
        pytorch_model.bin
    ADDED
    
    | @@ -0,0 +1,3 @@ | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            version https://git-lfs.github.com/spec/v1
         | 
| 2 | 
            +
            oid sha256:6d86a8d760eaa09c9a55d57cc6f6bb01b0cbccb8b827fc775a79f37a8fbda76c
         | 
| 3 | 
            +
            size 440966107
         | 
