| from transformers import ( | |
| PretrainedConfig, | |
| PreTrainedModel | |
| ) | |
| import torch | |
class StarVectorConfig(PretrainedConfig):
    """Configuration object for StarVector models.

    Every named constructor argument except ``torch_dtype`` is stored
    unchanged as an attribute of the same name. ``torch_dtype`` and any
    additional keyword arguments are handed to ``PretrainedConfig.__init__``.

    The per-field notes below are inferred from the parameter names; how each
    value is actually consumed is decided by the code that reads this config.

    Args:
        starcoder_model_name: Identifier of the StarCoder checkpoint.
        image_encoder_type: Which image encoder to use (default ``"clip"``).
        adapter_norm: Normalization used by the adapter
            (default ``"layer_norm"``).
        image_size: Input image size (default ``224``).
        max_length: Maximum sequence length (default ``8192``).
        max_length_train: Maximum sequence length during training
            (default ``8192``).
        use_flash_attn: Whether flash attention is requested.
        use_cache: Whether caching is requested.
        num_attention_heads: Number of attention heads.
        num_hidden_layers: Number of hidden layers.
        vocab_size: Vocabulary size.
        hidden_size: Hidden dimension.
        num_kv_heads: Number of key/value heads.
        torch_dtype: Desired dtype, given as a string such as
            ``"bfloat16"``. Forwarded to ``PretrainedConfig``.
        **kwargs: Extra options passed through to ``PretrainedConfig``.
    """

    model_type = "starvector"

    def __init__(
        self,
        starcoder_model_name: str = "bigcode/starcoderbase-1b",
        image_encoder_type: str = "clip",
        adapter_norm: str = "layer_norm",
        image_size: int = 224,
        max_length: int = 8192,
        max_length_train: int = 8192,
        use_flash_attn: bool = True,
        use_cache: bool = True,
        num_attention_heads: int = 16,
        num_hidden_layers: int = 24,
        vocab_size: int = 49152,
        hidden_size: int = 2048,
        num_kv_heads: int = 4,
        torch_dtype: str = "bfloat16",
        **kwargs,
    ):
        self.starcoder_model_name = starcoder_model_name
        self.image_encoder_type = image_encoder_type
        self.adapter_norm = adapter_norm
        self.image_size = image_size
        # NOTE(review): some transformers releases also manage ``max_length``
        # as a generation default inside PretrainedConfig.__init__ and may
        # reset it there -- confirm against the pinned transformers version.
        self.max_length = max_length
        self.max_length_train = max_length_train
        self.use_flash_attn = use_flash_attn
        self.use_cache = use_cache
        self.num_attention_heads = num_attention_heads
        self.num_hidden_layers = num_hidden_layers
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_kv_heads = num_kv_heads
        # PretrainedConfig.__init__ assigns ``torch_dtype`` itself from its
        # keyword arguments (falling back to None), so a value set on ``self``
        # beforehand is overwritten. Forward it instead. Because
        # ``torch_dtype`` is a named parameter here it can never also appear
        # in ``kwargs``, so no duplicate-keyword error is possible.
        super().__init__(torch_dtype=torch_dtype, **kwargs)
class StarVectorForCausalLM(PreTrainedModel):
    """``PreTrainedModel`` wrapper that delegates to a StarVector backbone.

    The backbone is selected in ``__init__`` from
    ``config.starcoder_model_name`` and stored as ``self.model``; every
    public method simply forwards to it.
    """

    config_class = StarVectorConfig
    _no_split_modules = []

    def __init__(self, config: StarVectorConfig, **kwargs):
        """Build the backbone that matches the configured StarCoder name.

        Args:
            config: Configuration whose ``starcoder_model_name`` decides
                which backbone class is instantiated.
            **kwargs: Passed unchanged to the backbone constructor.
        """
        super().__init__(config)

        # Import lazily so that only the selected backbone's module is loaded.
        if 'starcoder2' in config.starcoder_model_name:
            from starvector.model.models.starvector_v2 import (
                StarVectorStarCoder2 as backbone_cls,
            )
        else:
            from starvector.model.models.starvector_v1 import (
                StarVectorStarCoder as backbone_cls,
            )

        self.model = backbone_cls(config=config, **kwargs)

    def forward(self, batch):
        """Call the wrapped backbone on ``batch`` and return its result."""
        backbone = self.model
        return backbone(batch)

    def generate_im2svg(self, batch, **kwargs):
        """Forward to the backbone's ``generate_im2svg`` with the same arguments."""
        backbone = self.model
        return backbone.generate_im2svg(batch, **kwargs)

    def generate_im2text(self, batch, **kwargs):
        """Forward to the backbone's ``generate_im2text`` with the same arguments."""
        backbone = self.model
        return backbone.generate_im2text(batch, **kwargs)

    def process_images(self, images):
        """Forward ``images`` to the backbone image encoder's ``process_images``."""
        encoder = self.model.image_encoder
        return encoder.process_images(images)