| """qwen.py. | |
| File for providing the Qwen model implementation. | |
| """ | |
| from transformers import Qwen2VLForConditionalGeneration | |
| from src.models.base import ModelBase | |
| from src.models.config import Config | |
| class QwenModel(ModelBase): | |
| """Qwen model implementation.""" | |
| def __init__(self, config: Config) -> None: | |
| """Initialization of the qwen model. | |
| Args: | |
| config (Config): Parsed config | |
| """ | |
| # initialize the parent class | |
| super().__init__(config) | |
| def _load_specific_model(self) -> None: | |
| """Overridden function to populate self.model.""" | |
| self.model = Qwen2VLForConditionalGeneration.from_pretrained( | |
| self.model_path, **self.config.model | |
| ) if hasattr(self.config, 'model') else ( | |
| Qwen2VLForConditionalGeneration.from_pretrained( | |
| self.model_path | |
| ) | |
| ) | |