Spaces:
Runtime error
Runtime error
| from abc import ABC, abstractmethod | |
| import gradio as gr | |
| from fuse.data.tokenizers.modular_tokenizer.op import ModularTokenizerOp | |
| from mammal.model import Mammal | |
| class MammalObjectBroker: | |
| def __init__( | |
| self, | |
| model_path: str, | |
| name: str | None = None, | |
| task_list: list[str] | None = None, | |
| *, | |
| force_preload=False, | |
| ) -> None: | |
| self.model_path = model_path | |
| if name is None: | |
| name = model_path | |
| self.name = name | |
| self.tasks: list[str] = [] | |
| if task_list is not None: | |
| self.tasks = task_list | |
| self._model: Mammal | None = None | |
| self._tokenizer_op = None | |
| if force_preload: | |
| self.force_preload() | |
| def model(self) -> Mammal: | |
| if self._model is None: | |
| self._model = Mammal.from_pretrained(self.model_path) | |
| self._model.eval() | |
| return self._model | |
| def tokenizer_op(self): | |
| if self._tokenizer_op is None: | |
| self._tokenizer_op = ModularTokenizerOp.from_pretrained(self.model_path) | |
| return self._tokenizer_op | |
| def force_preload(self): | |
| """pre-load the model and tokenizer (in this order)""" | |
| _ = self.model | |
| _ = self.tokenizer_op | |
class MammalTask(ABC):
    """Base class for a task exposed through a gradio demo.

    Subclasses override the hooks below; each hook raises
    ``NotImplementedError`` in this base class.

    Args:
        name: name of the task.
        model_dict: mapping from model name to its ``MammalObjectBroker``.
    """

    def __init__(self, name: str, model_dict: "dict[str, MammalObjectBroker]") -> None:
        self.name = name
        self.description = None
        self._demo = None  # cache filled by demo()
        self.model_dict = model_dict

    def crate_sample_dict(
        self, sample_inputs: dict, model_holder: "MammalObjectBroker"
    ) -> dict:
        """Format the inputs to match the pre-training syntax.

        Args:
            sample_inputs: dictionary with the inputs of the sample.
            model_holder: broker of the model the sample is prepared for.

        Returns:
            dict: sample_dict for feeding into the model.

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError()

    # NOTE(review): the @abstractmethod marker on run_model was commented out
    # in the original; it is kept as a regular overridable hook.
    def run_model(self, sample_dict, model: "Mammal"):
        """Run ``model`` on ``sample_dict``; must be overridden by the subclass.

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError()

    def create_demo(self, model_name_widget: "gr.components.Component") -> "gr.Group":
        """Create a gradio demo group.

        Args:
            model_name_widget: widget holding the model name to use. This is
                needed to create gradio actions with the current model name
                as an input.

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError()

    def demo(self, model_name_widgit: "gr.components.Component | None" = None):
        """Return the demo of this task, building it with ``create_demo`` on first use.

        The result is cached: once a demo exists, later calls return it
        unchanged and their ``model_name_widgit`` argument is not used.
        """
        if self._demo is None:
            self._demo = self.create_demo(model_name_widget=model_name_widgit)
        return self._demo

    def decode_output(self, batch_dict, model: "Mammal") -> list:
        """Decode the model output in ``batch_dict``; must be overridden by the subclass.

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError()

    # classification helpers
    # These take no ``self``; they are static so they can be called both on
    # the class and on task instances.

    @staticmethod
    def positive_token_id(tokenizer_op: "ModularTokenizerOp") -> int:
        """Token for positive binding.

        Args:
            tokenizer_op: tokenizer op used to look up the ``"<1>"`` token.

        Returns:
            int: id of the positive binding token.
        """
        return tokenizer_op.get_token_id("<1>")

    @staticmethod
    def negative_token_id(tokenizer_op: "ModularTokenizerOp") -> int:
        """Token for negative binding.

        Args:
            tokenizer_op: tokenizer op used to look up the ``"<0>"`` token.

        Returns:
            int: id of the negative binding token.
        """
        return tokenizer_op.get_token_id("<0>")

    @staticmethod
    def get_label_from_token(tokenizer_op: "ModularTokenizerOp", token_id):
        """Translate a token id into ``"negative"`` / ``"positive"``.

        Args:
            tokenizer_op: tokenizer op used to look up the two label tokens.
            token_id: token id to translate.

        Returns:
            The label string when ``token_id`` is one of the two label
            tokens, otherwise ``token_id`` itself, unchanged.
        """
        label_mapping = {
            MammalTask.negative_token_id(tokenizer_op): "negative",
            MammalTask.positive_token_id(tokenizer_op): "positive",
        }
        return label_mapping.get(token_id, token_id)
class TaskRegistry(dict[str, MammalTask]):
    """Dictionary of tasks keyed by task name, with a registration helper."""

    def register_task(self, task: MammalTask):
        """Store ``task`` under its ``name`` and return that name.

        An entry already stored under the same name is replaced.
        """
        task_name = task.name
        self[task_name] = task
        return task_name
class ModelRegistry(dict[str, MammalObjectBroker]):
    """Dictionary of model brokers keyed by model name, with a registration helper."""

    def register_model(
        self, model_path, task_list=None, name=None, *, force_preload=False
    ):
        """Create a broker for a model, store it under its name and return the name.

        Args:
            model_path: path passed to the ``MammalObjectBroker`` constructor.
            task_list (optional): task list passed to the broker.
            name (optional str): explicit name for the model; the broker
                picks its own name when this is ``None``.
            force_preload: keyword-only flag forwarded to the broker.

        Returns:
            str: model name (the key the broker was stored under).
        """
        broker = MammalObjectBroker(
            model_path=model_path,
            name=name,
            task_list=task_list,
            force_preload=force_preload,
        )
        registered_name = broker.name
        self[registered_name] = broker
        return registered_name