import copy
from transformers.utils import logging
from transformers.configuration_utils import PretrainedConfig
from transformers import PerceptionLMConfig

logger = logging.get_logger(__name__)


class GARConfig(PretrainedConfig):
    model_type = 'GAR'
    is_composition = True

    def __init__(
        self,
        mllm_config=None,
        prompt_numbers=5,
        crop_tokens_ids=None,
        use_flash_attn=True,
        **kwargs,
    ):
        super().__init__(**kwargs)
        if crop_tokens_ids is None:
            # Avoid a mutable default argument; fall back to the original default ids.
            crop_tokens_ids = [128004, 128005, 128008, 128010, 128011]
        if mllm_config is None:
            mllm_config = {}
            logger.info('`mllm_config` is None. Initializing the PerceptionLM config with default values.')
        self.mllm_config = PerceptionLMConfig(**mllm_config)
        self.prompt_numbers = prompt_numbers

        self.crop_tokens_ids = crop_tokens_ids
        if len(self.crop_tokens_ids) != self.prompt_numbers:
            raise ValueError(
                f'crop_tokens_ids has length {len(self.crop_tokens_ids)}, '
                f'but it must match prompt_numbers={self.prompt_numbers}.'
            )

        try:
            # Derive the patch size from the vision tower: image size divided by
            # the reference feature-map shape, per spatial dimension.
            vision_args = self.mllm_config.vision_config.model_args
            self.patch_size_h = vision_args["img_size"][0] // vision_args["ref_feat_shape"][0]
            self.patch_size_w = vision_args["img_size"][1] // vision_args["ref_feat_shape"][1]
        except (AttributeError, KeyError, TypeError, IndexError):
            # Fall back to a 16x16 patch size when the vision config does not
            # expose `model_args`.
            self.patch_size_h = 16
            self.patch_size_w = 16
        self.kernel_size = [self.patch_size_h, self.patch_size_w]
            
        try:
            self.mask_path_embedding_out_channels = self.mllm_config.vision_config.num_features
        except AttributeError:
            # Fall back to 1280 channels when the vision config does not define
            # `num_features`.
            self.mask_path_embedding_out_channels = 1280

        self.mllm_config.use_flash_attn = bool(use_flash_attn)
        self.mllm_config.text_config.use_flash_attn = bool(use_flash_attn)
        # Flash attention is always disabled for the vision tower.
        self.mllm_config.vision_config.use_flash_attn = False

    def to_dict(self):
        """
        Serializes this instance to a Python dictionary. Overrides the default [`~PretrainedConfig.to_dict`].

        Returns:
            `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
        """
        output = copy.deepcopy(self.__dict__)
        output['mllm_config'] = self.mllm_config.to_dict()
        output['model_type'] = self.__class__.model_type
        return output
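

# --------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original file). It
# assumes a transformers release that ships `PerceptionLMConfig` with a timm
# vision config exposing `model_args`; the `img_size`/`ref_feat_shape` values
# below are made-up examples, not the real Perception-LM-8B settings.
# --------------------------------------------------------------------------
if __name__ == '__main__':
    config = GARConfig(
        mllm_config={
            'vision_config': {
                # Hypothetical 448x448 input with a 28x28 feature map -> 16x16 patches.
                'model_args': {'img_size': [448, 448], 'ref_feat_shape': [28, 28]},
            },
        },
    )
    print(config.kernel_size)              # [16, 16]
    print(config.to_dict()['model_type'])  # GAR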