Matt
		
	committed on
		
		
					Commit 
							
							·
						
						b5c7055
	
1
								Parent(s):
							
							f285b2c
								
Tie weights correctly
Browse files - modeling_florence2.py +8 -2
    	
        modeling_florence2.py
    CHANGED
    
    | @@ -2066,6 +2066,12 @@ class Florence2LanguageForConditionalGeneration(Florence2LanguagePreTrainedModel | |
| 2066 | 
             
                    # Initialize weights and apply final processing
         | 
| 2067 | 
             
                    self.post_init()
         | 
| 2068 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 2069 | 
             
                def get_encoder(self):
         | 
| 2070 | 
             
                    return self.model.get_encoder()
         | 
| 2071 |  | 
| @@ -2523,6 +2529,8 @@ class Florence2VisionModelWithProjection(Florence2PreTrainedModel): | |
| 2523 | 
             
                FLORENCE2_START_DOCSTRING,
         | 
| 2524 | 
             
            )
         | 
| 2525 | 
             
            class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
         | 
|  | |
|  | |
| 2526 | 
             
                def __init__(self, config: Florence2Config):
         | 
| 2527 | 
             
                    super().__init__(config)
         | 
| 2528 | 
             
                    assert config.vision_config.model_type == 'davit', 'only DaViT is supported for now'
         | 
| @@ -2537,8 +2545,6 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel): | |
| 2537 |  | 
| 2538 | 
             
                    language_model = Florence2LanguageForConditionalGeneration(config=config.text_config)
         | 
| 2539 |  | 
| 2540 | 
            -
                    if language_model._tied_weights_keys is not None:
         | 
| 2541 | 
            -
                        self._tied_weights_keys = [f"language_model.{k}" for k in language_model._tied_weights_keys]
         | 
| 2542 | 
             
                    self.language_model = language_model
         | 
| 2543 |  | 
| 2544 | 
             
                    self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
         | 
|  | |
| 2066 | 
             
                    # Initialize weights and apply final processing
         | 
| 2067 | 
             
                    self.post_init()
         | 
| 2068 |  | 
| 2069 | 
            +
                def _tie_weights(self):
         | 
| 2070 | 
            +
                    if self.config.tie_word_embeddings:
         | 
| 2071 | 
            +
                        self._tie_or_clone_weights(self.model.encoder.embed_tokens, self.model.shared)
         | 
| 2072 | 
            +
                        self._tie_or_clone_weights(self.model.decoder.embed_tokens, self.model.shared)
         | 
| 2073 | 
            +
                        self._tie_or_clone_weights(self.lm_head, self.model.shared)
         | 
| 2074 | 
            +
             | 
| 2075 | 
             
                def get_encoder(self):
         | 
| 2076 | 
             
                    return self.model.get_encoder()
         | 
| 2077 |  | 
|  | |
| 2529 | 
             
                FLORENCE2_START_DOCSTRING,
         | 
| 2530 | 
             
            )
         | 
| 2531 | 
             
            class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
         | 
| 2532 | 
            +
                _tied_weights_keys = ["language_model.encoder.embed_tokens.weight", "language_model.decoder.embed_tokens.weight", "language_model.lm_head.weight"]
         | 
| 2533 | 
            +
             | 
| 2534 | 
             
                def __init__(self, config: Florence2Config):
         | 
| 2535 | 
             
                    super().__init__(config)
         | 
| 2536 | 
             
                    assert config.vision_config.model_type == 'davit', 'only DaViT is supported for now'
         | 
|  | |
| 2545 |  | 
| 2546 | 
             
                    language_model = Florence2LanguageForConditionalGeneration(config=config.text_config)
         | 
| 2547 |  | 
|  | |
|  | |
| 2548 | 
             
                    self.language_model = language_model
         | 
| 2549 |  | 
| 2550 | 
             
                    self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
         | 
