# imports
import logging
import time

import torch
from transformers import GenerationConfig, pipeline

# Setting up logging
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)


class BatchAggregator:
    """Aggregates a batch of texts into a single summary via a text2text pipeline."""

    def __init__(
        self, model_name: str = "pszemraj/bart-large-mnli-dolly_hhrlhf-v1", **kwargs
    ):
        self.logger = logging.getLogger(__name__)
        self.model_name = model_name
        self.logger.info(f"Initializing aggregator with model {model_name}")
        self.aggregator = pipeline(
            "text2text-generation",
            model_name,
            device=0 if torch.cuda.is_available() else -1,
            torch_dtype=torch.float32,
        )
        # torch.compile (PyTorch 2.0+) is optional; fall back gracefully if unavailable
        try:
            self.aggregator.model = torch.compile(self.aggregator.model)
        except Exception as e:
            self.logger.warning(f"Could not compile model with Torch 2.0: {e}")
        # Prefer the generation config shipped with the model; otherwise fall back to defaults
        try:
            self.aggregator.model.generation_config = GenerationConfig.from_pretrained(
                self.model_name
            )
        except Exception as e:
            self.logger.warning(
                f"Could not load generation config, using defaults: {e}"
            )
            self.aggregator.model.generation_config = GenerationConfig(
                num_beams=4,
                early_stopping=True,
                do_sample=False,
                min_new_tokens=32,
                max_new_tokens=192,
                repetition_penalty=1.1,
                length_penalty=1.5,
                no_repeat_ngram_size=4,
                encoder_no_repeat_ngram_size=5,
                decoder_start_token_id=0,
                eos_token_id=1,
                pad_token_id=0,
            )
        # BART checkpoints respond better to wider beam search and a stronger
        # repetition penalty, so override those settings here
        if "bart" in model_name.lower():
            self.logger.info("Using BART model, updating generation config")
            upd = {
                "num_beams": 8,
                "repetition_penalty": 1.3,
                "length_penalty": 1.0,
                "_from_model_config": False,
                "max_new_tokens": 256,
                "min_new_tokens": 32,
                "no_repeat_ngram_size": 3,
                "encoder_no_repeat_ngram_size": 6,
            }
            self.aggregator.model.generation_config.update(**upd)
        if self.model_name != "pszemraj/bart-large-mnli-dolly_hhrlhf-v1":
            self.logger.info("Updating generation config with defaults")
            self.update_generation_config()
        self.logger.info(self.aggregator.model.generation_config.to_json_string())

    def update_generation_config(self, **kwargs):
        """Reset the generation config to defaults, then apply any overrides in kwargs."""
        self.logger.info(f"Updating generation config with {kwargs}")
        default = GenerationConfig(
            num_beams=4,
            early_stopping=True,
            do_sample=False,
            min_new_tokens=32,
            max_new_tokens=192,
            repetition_penalty=1.1,
            length_penalty=1.5,
            no_repeat_ngram_size=4,
            encoder_no_repeat_ngram_size=5,
            decoder_start_token_id=0,
            eos_token_id=1,
            pad_token_id=0,
        ).to_dict()
        # caller-supplied overrides take precedence over the defaults
        default.update(kwargs)
        self.aggregator.model.generation_config.update(**default)

    def _replace_pipeline(self, model_name: str) -> None:
        """Swap the underlying pipeline for a different model.

        Minimal reconstruction of an incomplete stub; assumes the same
        pipeline settings as __init__.
        """
        self.model_name = model_name
        self.aggregator = pipeline(
            "text2text-generation",
            model_name,
            device=0 if torch.cuda.is_available() else -1,
            torch_dtype=torch.float32,
        )

    def infer_aggregate(
        self,
        text_list: list,
        instruction: str = "Write a comprehensive yet concise summary in paragraph form that pulls together the main points of the following text:",
        **kwargs,
    ):
        """Join the texts under an instruction prompt and generate one aggregate summary."""
        joined_text = "\n".join(text_list)
        prompt = f"{instruction}\n\n{joined_text}\n"
        if kwargs:
            self.update_generation_config(**kwargs)
        st = time.perf_counter()
        self.logger.info(f"Running inference on {len(text_list)} texts")
        result = self.aggregator(
            prompt,
            generation_config=self.aggregator.model.generation_config,
        )[0]["generated_text"]
        self.logger.info(f"Done. runtime:\t{round(time.perf_counter() - st, 2)}s")
        self.logger.info(
            f"Input tokens:\t{self.count_tokens(prompt)}. Output tokens:\t{self.count_tokens(result)}"
        )
        return result

    def count_tokens(self, text: str) -> int:
        """Return the number of tokens in `text` (0 for empty input)."""
        return (
            len(self.aggregator.tokenizer.encode(text, truncation=False, padding=False))
            if text
            else 0
        )
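

# Minimal usage sketch: instantiates the aggregator with the default checkpoint
# and summarizes a small batch. The example texts and the max_new_tokens
# override are illustrative placeholders, not part of the original file.
if __name__ == "__main__":
    aggregator = BatchAggregator()
    texts = [
        "The first report covers quarterly revenue and notes growth in the cloud segment.",
        "The second report discusses staffing changes and a new office opening.",
    ]
    summary = aggregator.infer_aggregate(texts, max_new_tokens=128)
    print(summary)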