Spaces:
Runtime error
Runtime error
| from SCRL_new.scrl.model import load_model | |
| from transformers import AutoTokenizer | |
| import re | |
| from abs_compressor import AbstractCompressor | |
class SCRLCompressor(AbstractCompressor):
    """Prompt compressor that delegates summarisation to an SCRL model.

    The prompt is cut into fixed-size character chunks, each chunk is passed
    to the model's ``predict`` method, and the returned pieces are joined
    back together into a single compressed prompt.

    NOTE(review): ``self.gpt_tokenizer`` is read by ``compress`` but never
    assigned in this class; it is presumably provided by
    ``AbstractCompressor`` -- confirm against the base class.
    """

    def __init__(self, model_dir: str, device: str = "cpu", tokenizer_dir: str = "sentence-transformers/paraphrase-distilroberta-base-v2"):
        """Load the SCRL model and the tokenizer passed to it at predict time.

        Args:
            model_dir: Location handed to ``load_model``.
            device: Device identifier handed to ``load_model`` and, later,
                to ``model.predict``.
            tokenizer_dir: Name or path given to
                ``AutoTokenizer.from_pretrained``.
        """
        self.model_dir = model_dir
        self.device = device
        self.model = load_model(self.model_dir, self.device)
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir)

    @staticmethod
    def _split_into_chunks(text: str, chunk_size: int) -> list:
        """Return consecutive ``chunk_size``-character slices covering all of ``text``."""
        return [text[start:start + chunk_size] for start in range(0, len(text), chunk_size)]

    def compress(self, original_prompt: str, ratio: float = 0.5, max_length: int = 256) -> dict:
        """Compress ``original_prompt`` and report token statistics.

        Args:
            original_prompt: Text to compress. Leading and trailing
                whitespace is stripped before chunking.
            ratio: Accepted for interface compatibility; this implementation
                does not use it.
            max_length: Maximum number of characters per chunk sent to the
                model. Must be positive.

        Returns:
            A dict with the keys ``compressed_prompt``, ``ratio``
            (compressed tokens divided by original tokens),
            ``original_tokens`` and ``compressed_tokens``. When the stripped
            prompt is empty, the placeholder result with empty-string token
            fields and a ratio of 0 is returned.

        Raises:
            ValueError: If ``max_length`` is not positive.
        """
        if max_length <= 0:
            raise ValueError("max_length must be a positive integer")

        original_tokens = len(self.gpt_tokenizer.encode(original_prompt))

        # Slice instead of matching r'.{N}': the regex dropped the trailing
        # remainder (and every prompt shorter than max_length), and '.' does
        # not match newlines, so multi-line prompts lost text.
        sources = self._split_into_chunks(original_prompt.strip(), max_length)

        if not sources:
            # Placeholder values kept exactly as before so existing callers
            # that inspect the empty result keep working.
            return {
                'compressed_prompt': "",
                'ratio': 0,
                'original_tokens': "",
                'compressed_tokens': "",
            }

        summaries = self.model.predict(sources, self.tokenizer, self.device)
        compressed_prompt = "".join(summaries)
        compressed_tokens = len(self.gpt_tokenizer.encode(compressed_prompt))

        return {
            'compressed_prompt': compressed_prompt,
            # Guard against a tokenizer that reports zero tokens for the input.
            'ratio': compressed_tokens / original_tokens if original_tokens else 0,
            'original_tokens': original_tokens,
            'compressed_tokens': compressed_tokens,
        }