Spaces:
Runtime error
Runtime error
import copy

from transformers import GPT2Config, ViTConfig
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging

# Module-level logger obtained through transformers' logging utilities.
logger = logging.get_logger(__name__)
class ViTGPT2Config(PretrainedConfig):
    """Composite configuration pairing a ViT config with a GPT-2 config.

    The two sub-configurations are stored as ``self.vit_config`` (a
    ``ViTConfig``) and ``self.gpt2_config`` (a ``GPT2Config``). Every other
    keyword argument is forwarded to ``PretrainedConfig.__init__``.

    Args:
        vit_config: Keyword arguments for ``ViTConfig`` given as a mapping, or
            an existing ``PretrainedConfig`` instance (expanded via
            ``to_dict()``). Required.
        gpt2_config: Same as ``vit_config``, for ``GPT2Config``. Required.
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.

    Raises:
        ValueError: If ``vit_config`` or ``gpt2_config`` is missing or ``None``.
    """

    model_type = "vit-gpt2"
    # Flags this config as built from other configs. Presumably this stops
    # transformers from instantiating it with no arguments, which would fail
    # here because both sub-configs are required. Confirm against the
    # installed transformers version.
    is_composition = True

    def __init__(self, **kwargs):
        # Pop the sub-configs before delegating, so PretrainedConfig does not
        # stash the raw inputs as attributes that are immediately overwritten.
        # An explicit None is treated as "missing", matching the error text.
        vit_config = kwargs.pop("vit_config", None)
        gpt2_config = kwargs.pop("gpt2_config", None)
        if vit_config is None:
            raise ValueError("`vit_config` can not be `None`.")
        if gpt2_config is None:
            raise ValueError("`gpt2_config` can not be `None`.")
        super().__init__(**kwargs)
        self.vit_config = ViTConfig(**self._as_kwargs(vit_config))
        self.gpt2_config = GPT2Config(**self._as_kwargs(gpt2_config))

    @staticmethod
    def _as_kwargs(config):
        """Return *config* as constructor kwargs (accepts a mapping or a PretrainedConfig)."""
        if isinstance(config, PretrainedConfig):
            return config.to_dict()
        return config

    @classmethod
    def from_vit_gpt2_configs(
        cls, vit_config: PretrainedConfig, gpt2_config: PretrainedConfig, **kwargs
    ):
        """Build a combined config from existing ViT and GPT-2 config objects.

        Args:
            vit_config: The ViT sub-configuration; serialized with ``to_dict()``.
            gpt2_config: The GPT-2 sub-configuration; serialized with ``to_dict()``.
            **kwargs: Extra arguments forwarded to the constructor.

        Returns:
            A new instance of ``cls``.
        """
        return cls(
            vit_config=vit_config.to_dict(),
            gpt2_config=gpt2_config.to_dict(),
            **kwargs,
        )

    def to_dict(self):
        """Serialize this config to a plain dict.

        Both sub-configs are expanded with their own ``to_dict()``, and the
        class-level ``model_type`` is recorded under ``"model_type"``.
        """
        output = copy.deepcopy(self.__dict__)
        output["vit_config"] = self.vit_config.to_dict()
        output["gpt2_config"] = self.gpt2_config.to_dict()
        output["model_type"] = self.__class__.model_type
        return output