|
|
"""internlm.py. |
|
|
|
|
|
File for providing the InternLM-XComposer model implementation. |
|
|
""" |
|
|
import logging |
|
|
|
|
|
import torch |
|
|
from transformers import AutoModel, AutoProcessor |
|
|
|
|
|
from src.models.base import ModelBase |
|
|
from src.models.config import Config |
|
|
|
|
|
|
|
|
class InternLMXComposerModel(ModelBase):
    """InternLM-XComposer model implementation.

    Loads the checkpoint at ``self.model_path`` through ``AutoModel`` /
    ``AutoProcessor`` with ``trust_remote_code=True`` and runs inference via
    the model's ``chat`` method.
    """

    def __init__(self, config: Config) -> None:
        """Initialize the InternLM-XComposer model.

        Args:
            config (Config): Parsed config.
        """
        super().__init__(config)

    def _load_specific_model(self) -> None:
        """Overridden function to populate self.model.

        Extra ``from_pretrained`` keyword arguments are read from
        ``self.config.model`` when the config defines that attribute. A missing
        or ``None`` section means "no extra arguments".
        """
        model_kwargs = dict(getattr(self.config, 'model', None) or {})
        # Always load with remote code enabled. Setting it in the merged dict,
        # rather than as a separate keyword, avoids a duplicate-keyword
        # TypeError when the config section also specifies it.
        model_kwargs['trust_remote_code'] = True
        self.model = AutoModel.from_pretrained(self.model_path, **model_kwargs)

    def _init_processor(self) -> None:
        """Overridden function to instantiate the model's processor.

        The loaded processor is stored on ``self.processor`` and is also
        attached to ``self.model.tokenizer``.
        """
        self.processor = AutoProcessor.from_pretrained(
            self.model_path, trust_remote_code=True)
        # NOTE(review): presumably the remote-code model reads its own
        # `tokenizer` attribute inside chat(); confirm against the checkpoint.
        self.model.tokenizer = self.processor

    def _generate_prompt(self, prompt: str) -> str:
        """Overridden function to generate the prompt for the model.

        No templating is applied; the prompt is returned unchanged.

        Args:
            prompt (str): The input prompt to be processed.

        Returns:
            str: The same prompt, ready for model input.
        """
        return prompt

    def _generate_processor_output(self, prompt: str, img_path: str) -> dict:
        """Overridden function to format the prompt for the processor.

        Args:
            prompt (str): The input prompt to be processed.
            img_path (str): The path to the image to be processed. Only used
                when the config has images.

        Returns:
            dict: Always contains ``'query'``. When the config has images it
                also contains ``'image'`` (a one-element list holding
                ``img_path``). The mapping is presumably passed on to
                ``_forward``, which unpacks it into ``self.model.chat``.

        Raises:
            ValueError: If the prompt is empty or None.
        """
        logging.debug('Loading data...')

        if not prompt:
            raise ValueError(
                'No input prompt was provided for the InternLM-XC model')

        if self.config.has_images():
            # NOTE(review): `<ImageHere>` is presumably the placeholder the
            # remote-code chat() replaces with the image; confirm.
            return {'query': f'<ImageHere>; {prompt}', 'image': [img_path]}
        return {'query': prompt}

    def _forward(self, data: dict) -> None:
        """Overridden function to run the model forward pass.

        Calls ``self.model.chat`` under ``torch.autocast``. The chat outputs
        are not returned.

        Args:
            data (dict): Keyword arguments for ``self.model.chat``, merged with
                ``self.config.forward``.
        """
        # torch.autocast only accepts a bare device type ('cuda', 'cpu', ...).
        # Normalise indexed devices such as 'cuda:0' or torch.device objects,
        # which str() would pass through unchanged and autocast would reject.
        device_type = torch.device(self.config.device).type
        logging.debug('DATA: %s', data)
        with torch.autocast(device_type=device_type):
            # chat() is unpacked as a 2-tuple (presumably response, history).
            # The values are unused here; outputs are presumably collected
            # elsewhere (e.g. by the base class) — confirm.
            _, _ = self.model.chat(
                self.processor, **data, **self.config.forward)
|
|
|