| import copy | |
| from transformers.utils import logging | |
| from transformers.configuration_utils import PretrainedConfig | |
| from transformers import AutoConfig, PerceptionLMConfig | |
| logger = logging.get_logger(__name__) | |
| class GARConfig(PretrainedConfig): | |
| model_type = 'GAR' | |
| is_composition = True | |
| def __init__( | |
| self, | |
| mllm_config=None, | |
| prompt_numbers=5, | |
| crop_tokens_ids=[128004, 128005, 128008, 128010, 128011], | |
| use_flash_attn=True, | |
| **kwargs, | |
| ): | |
| super().__init__(**kwargs) | |
| if mllm_config is None: | |
| mllm_config = {} | |
| logger.info('mllm_config is None. Initializing the PerceptionLM with default values.') | |
| if mllm_config is None: | |
| self.mllm_config = AutoConfig.from_pretrained("facebook/Perception-LM-8B") | |
| else: | |
| self.mllm_config = PerceptionLMConfig(**mllm_config) | |
| self.prompt_numbers = prompt_numbers | |
| self.crop_tokens_ids = crop_tokens_ids | |
| assert len(self.crop_tokens_ids) == self.prompt_numbers, f'{self.crop_tokens_ids} crop_tokens_ids length should be {self.prompt_numbers}' | |
| try: | |
| self.patch_size_h = self.mllm_config.vision_config.model_args["img_size"][0] // self.mllm_config.vision_config.model_args["ref_feat_shape"][0] | |
| self.patch_size_w = self.mllm_config.vision_config.model_args["img_size"][1] // self.mllm_config.vision_config.model_args["ref_feat_shape"][1] | |
| self.kernel_size = [self.patch_size_h, self.patch_size_w] | |
| except: | |
| self.patch_size_h = 16 | |
| self.patch_size_w = 16 | |
| self.kernel_size = [self.patch_size_h, self.patch_size_w] | |
| try: | |
| self.mask_path_embedding_out_channels = self.mllm_config.vision_config.num_features | |
| except: | |
| self.mask_path_embedding_out_channels = 1280 | |
| self.mllm_config.use_flash_attn = True if use_flash_attn else False | |
| self.mllm_config.text_config.use_flash_attn = True if use_flash_attn else False | |
| self.mllm_config.vision_config.use_flash_attn = False | |
| def to_dict(self): | |
| """ | |
| Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`]. | |
| Returns: | |
| `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance, | |
| """ | |
| output = copy.deepcopy(self.__dict__) | |
| output['mllm_config'] = self.mllm_config.to_dict() | |
| output['model_type'] = self.__class__.model_type | |
| return output |