import os

from .minference_configuration import MInferenceConfig
from .patch import minference_patch, minference_patch_vllm, patch_hf


class MInference:
    """Builds an MInferenceConfig and patches a model's attention to the requested type."""

    def __init__(
        self,
        attn_type: str = "minference",
        model_name: str = None,
        config_path: str = None,
        starting_layer: int = -1,
        kv_cache_cpu: bool = False,
        use_snapkv: bool = False,
        is_search: bool = False,
        attn_kwargs: dict = {},
        **kwargs,
    ):
        super(MInference, self).__init__()
        self.config = MInferenceConfig(
            attn_type=attn_type,
            model_name=model_name,
            config_path=config_path,
            starting_layer=starting_layer,
            kv_cache_cpu=kv_cache_cpu,
            use_snapkv=use_snapkv,
            is_search=is_search,
            attn_kwargs=attn_kwargs,
            **kwargs,
        )

    def __call__(self, model):
        return self.patch_model(model)

    def patch_model(self, model):
        # vLLM engines are patched through their own entry point below and do not
        # carry a Hugging Face config with these fields.
        if self.config.attn_type != "vllm":
            model.config.starting_layer = self.config.starting_layer
            model.config.config_path = self.config.config_path

        # Dispatch on the requested attention type.
        if self.config.attn_type == "minference":
            model.config.is_search = self.config.is_search
            model = minference_patch(model, self.config)
        elif self.config.attn_type == "minference_with_dense":
            model.config.dense = True
            model = minference_patch(model, self.config)
        elif self.config.attn_type == "dilated1":
            model.config.dilated1 = True
            model = minference_patch(model, self.config)
        elif self.config.attn_type == "static":
            model.config.static_pattern = True
            model = minference_patch(model, self.config)
        elif self.config.attn_type == "dilated2":
            model.config.dilated2 = True
            model = minference_patch(model, self.config)
        elif self.config.attn_type == "streaming":
            model.config.streaming = True
            model.config.streaming_kwargs = {
                "n_local": 3968,
                "n_init": 128,
                **self.config.attn_kwargs,
            }
            model = minference_patch(model, self.config)
        elif self.config.attn_type == "streaming2":
            model = patch_hf(
                model,
                attn_type="streaming",
                attn_kwargs={"n_local": 3968, "n_init": 128, **self.config.attn_kwargs},
            )
        elif self.config.attn_type == "inf_llm":
            model = patch_hf(
                model,
                attn_type="inf_llm",
                attn_kwargs={
                    "block_size": 128,
                    "n_init": 128,
                    "n_local": 4096,
                    "topk": 16,
                    "repr_topk": 4,
                    "max_cached_block": 32,
                    "exc_block_size": 512,
                    "base": 1000000,
                    "distance_scale": 1.0,
                    "dense_decoding": True,
                    **self.config.attn_kwargs,
                },
            )
        elif self.config.attn_type == "vllm":
            model = minference_patch_vllm(model, self.config.config_path)
        else:
            raise ValueError(
                f"The attention type {self.config.attn_type} you specified is not supported."
            )
        return model
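For reference, a minimal usage sketch follows. It assumes the class is exported as `from minference import MInference`, that `transformers` and `torch` are installed, and that the checkpoint name is only an illustrative long-context model; none of this is prescribed by the module above. The patcher is instantiated with an attn_type and then called on a loaded model, which is equivalent to calling patch_model directly.

# Minimal sketch, assuming the package exposes MInference at the top level
# and that the checkpoint below is just an illustrative long-context model.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from minference import MInference

model_name = "gradientai/Llama-3-8B-Instruct-262k"

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

# Build the patcher with the default sparse "minference" attention and apply it;
# calling the instance routes through __call__ -> patch_model.
minference_patch = MInference(attn_type="minference", model_name=model_name)
model = minference_patch(model)

prompt = "Summarize the following document: ..."
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))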