Spaces:
Runtime error
Runtime error
| import os | |
| from .configs.model2path import MODEL2PATH | |
class MInferenceConfig:
    """Configuration holder for MInference attention settings.

    Validates the requested attention type against ``ATTENTION_TYPES`` and
    resolves the config path, either from an explicit ``config_path`` or by
    looking ``model_name`` up in ``MODEL2PATH``.
    """

    # Values accepted for the ``attn_type`` constructor argument.
    ATTENTION_TYPES = [
        "minference",
        "minference_with_dense",
        "static",
        "dilated1",
        "dilated2",
        "streaming",
        "inf_llm",
        "vllm",
    ]

    def __init__(
        self,
        attn_type: str = "minference",
        model_name: str = None,
        config_path: str = None,
        starting_layer: int = -1,
        kv_cache_cpu: bool = False,
        use_snapkv: bool = False,
        is_search: bool = False,
        attn_kwargs: dict = None,
        **kwargs,
    ):
        """Validate and store the configuration.

        Args:
            attn_type: Must be one of ``ATTENTION_TYPES``.
            model_name: Key into ``MODEL2PATH``; only consulted when
                ``config_path`` is None.
            config_path: Explicit config path; takes precedence over
                ``model_name``.
            starting_layer: Stored on the instance unchanged.
            kv_cache_cpu: Stored on the instance unchanged.
            use_snapkv: Stored on the instance unchanged.
            is_search: Stored on the instance unchanged.
            attn_kwargs: Extra attention options stored on the instance.
                When omitted (or None), each instance gets its own fresh
                empty dict.
            **kwargs: Accepted and ignored.

        Raises:
            AssertionError: If ``attn_type`` is not supported, or if
                ``config_path`` is None and ``model_name`` is not a key of
                ``MODEL2PATH``.
        """
        super().__init__()
        # Raise explicitly rather than using `assert`, so the check is not
        # stripped under `python -O`. AssertionError is kept for backward
        # compatibility with callers of the original implementation.
        if attn_type not in self.ATTENTION_TYPES:
            raise AssertionError(
                f"The attention_type {attn_type} you specified is not supported."
            )
        self.attn_type = attn_type
        self.config_path = self.update_config_path(config_path, model_name)
        self.model_name = model_name
        self.is_search = is_search
        self.starting_layer = starting_layer
        self.kv_cache_cpu = kv_cache_cpu
        self.use_snapkv = use_snapkv
        # Create a fresh dict per instance: a shared `{}` default would make
        # mutations through one config visible in every other default config.
        self.attn_kwargs = {} if attn_kwargs is None else attn_kwargs

    def update_config_path(self, config_path: str, model_name: str):
        """Return the config path to use.

        Args:
            config_path: Returned as-is when not None.
            model_name: Looked up in ``MODEL2PATH`` when ``config_path`` is
                None.

        Returns:
            ``config_path`` if given, otherwise ``MODEL2PATH[model_name]``.

        Raises:
            AssertionError: If ``config_path`` is None and ``model_name`` is
                not a key of ``MODEL2PATH``.
        """
        if config_path is not None:
            return config_path
        # Explicit raise (not `assert`) so the check survives `python -O`.
        if model_name not in MODEL2PATH:
            raise AssertionError(
                f"The model {model_name} you specified is not supported. You are welcome to add it and open a PR :)"
            )
        return MODEL2PATH[model_name]