Spaces:
Running
Running
| from tclogger import logger | |
| from transformers import AutoTokenizer | |
| from constants.models import MODEL_MAP, TOKEN_LIMIT_MAP, TOKEN_RESERVED | |
class TokenChecker:
    """Count prompt tokens for a chat model and validate them against the
    model's context-window limit (minus a reserved output budget)."""

    # Some models are gated on the Hugging Face Hub, so tokenizers are
    # fetched from open alternative repos instead. Class-level constant:
    # the original rebuilt this dict on every __init__ call.
    GATED_MODEL_MAP = {
        "llama3-70b": "NousResearch/Meta-Llama-3-70B",
        "gemma-7b": "unsloth/gemma-7b",
        "mistral-7b": "dfurman/Mistral-7B-Instruct-v0.2",
        "mixtral-8x7b": "dfurman/Mixtral-8x7B-Instruct-v0.1",
    }

    def __init__(self, input_str: str, model: str):
        """Store the prompt and load the tokenizer for `model`.

        Args:
            input_str: The prompt text to be token-counted.
            model: Short model name; unknown names fall back to
                "mixtral-8x7b".
        """
        self.input_str = input_str
        # Fall back to a known-good default for unrecognized model names.
        self.model = model if model in MODEL_MAP else "mixtral-8x7b"
        self.model_fullname = MODEL_MAP[self.model]
        # Prefer the ungated mirror when one exists for this model.
        tokenizer_repo = self.GATED_MODEL_MAP.get(self.model, self.model_fullname)
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_repo)

    def count_tokens(self) -> int:
        """Tokenize the prompt, log the count, and return it."""
        token_count = len(self.tokenizer.encode(self.input_str))
        logger.note(f"Prompt Token Count: {token_count}")
        return token_count

    def get_token_limit(self) -> int:
        """Return the context-window token limit for the selected model."""
        return TOKEN_LIMIT_MAP[self.model]

    def get_token_redundancy(self) -> int:
        """Return tokens remaining after the prompt and the reserved budget."""
        return int(self.get_token_limit() - TOKEN_RESERVED - self.count_tokens())

    def check_token_limit(self) -> bool:
        """Return True if the prompt fits; raise ValueError otherwise.

        Fix: the original called count_tokens() a second time when building
        the error message, re-tokenizing and re-logging the prompt. Count
        once and reuse the value.
        """
        token_count = self.count_tokens()
        token_limit = self.get_token_limit()
        if token_limit - TOKEN_RESERVED - token_count <= 0:
            raise ValueError(
                f"Prompt exceeded token limit: {token_count} > {token_limit}"
            )
        return True