print('Loading dependencies...')

from transformers import GPT2Tokenizer, GPT2LMHeadModel, BertTokenizer, LlamaForCausalLM, LlamaTokenizer
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
import torch
import re
from typing import List, Tuple
import spacy
import numpy as np
import os
from dataclasses import dataclass
from nltk.tokenize import sent_tokenize, word_tokenize
import time

DEVICE = torch.device('cpu')
@dataclass
class LexicalUnits:
    unit_type: str
    text: List[str]
    self_info: List[float] = None

    def __add__(self, other):
        assert self.unit_type == other.unit_type, 'Cannot add two different unit types'
        return LexicalUnits(self.unit_type, self.text + other.text, self.self_info + other.self_info)

    def __radd__(self, other):
        # support `0 + LexicalUnits`, which is what built-in sum() starts with
        if other == 0:
            return self
        return NotImplemented

    def add_to_head(self, token, self_info):
        return LexicalUnits(self.unit_type, [token] + self.text, [self_info] + self.self_info)

    def add_to_tail(self, token, self_info):
        return LexicalUnits(self.unit_type, self.text + [token], self.self_info + [self_info])
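# Illustrative sketch (not part of the pipeline): `__radd__` exists so that
# built-in sum() can merge a list of units, since sum() seeds its fold with 0:
#
#   merged = sum([LexicalUnits('token', ['a'], [1.0]),
#                 LexicalUnits('token', ['b'], [2.0])])
#   merged.text       # ['a', 'b']
#   merged.self_info  # [1.0, 2.0]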
class SelectiveContext:

    def __init__(self, model_type='gpt2', lang='en', device='cpu'):

        self.model_type = model_type
        self.lang = lang

        global DEVICE
        DEVICE = device

        # this means we calculate self-information sentence by sentence
        self.sent_level_self_info = True

        self._prepare_phrase_tokenizer()

        self.sent_tokenize_pattern = r"(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?)\s"
        self.phrase_mask_token = ''
        self.sent_mask_token = "<...some content omitted.>"

        # mask_a_sent references these; the values below are assumed defaults
        self.keep_leading_word = False
        self.num_lead_words = 3

        self._prepare_model()
    def _prepare_phrase_tokenizer(self):
        # we use spacy to tokenize sentences into phrases
        # for English, we use `spacy.load("en_core_web_sm").add_pipe('merge_noun_chunks')`
        # for Chinese, we use `nlp = spacy.load('zh_core_web_sm')` directly
        lang = self.lang
        if lang == "en":
            self.nlp = spacy.load("en_core_web_sm", disable=["ner"])
            self.nlp.add_pipe('merge_noun_chunks')
        elif lang == "zh":
            self.nlp = spacy.load('zh_core_web_sm', disable=["ner"])
        # elif self.model_type == 'llama':
        #     self.nlp = spacy.load('en_core_web_sm', disable=["ner"])
    def _prepare_model(self):
        # Load tokenizer
        if self.lang == 'zh':
            self.tokenizer = BertTokenizer.from_pretrained('uer/gpt2-chinese-cluecorpussmall')
        elif self.lang == 'en':
            self.tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
        else:
            raise NotImplementedError()

        if self.model_type == 'gpt2':
            if self.lang == 'zh':
                self.model = GPT2LMHeadModel.from_pretrained('uer/gpt2-chinese-cluecorpussmall')
            else:
                self.model = GPT2LMHeadModel.from_pretrained('gpt2')
            self.model.to(DEVICE)
            self.model.eval()
            print('model loaded')
            self.max_token_length = self.model.config.n_positions
            self.get_self_information = self._get_self_info_via_gpt2

        elif self.model_type == 'curie':
            global openai
            import openai
            self.max_token_length = 2048
            self.get_self_information = self._get_self_info_via_curie

        elif self.model_type == 'llama':
            # 'LLaMA TOKEN' is a placeholder; replace it with a real
            # Hugging Face access token before using this path
            print("Before tokenizer")
            self.tokenizer = LlamaTokenizer.from_pretrained('meta-llama/Llama-2-7b-chat-hf', token='LLaMA TOKEN')
            print("Before model")
            config = AutoConfig.from_pretrained('meta-llama/Llama-2-7b-chat-hf', token='LLaMA TOKEN')
            print("After config")
            self.model = LlamaForCausalLM.from_pretrained('meta-llama/Llama-2-7b-chat-hf', config=config, token='LLaMA TOKEN')
            print("Before DEVICE")
            self.model.to(DEVICE)
            print("Before eval")
            self.model.eval()
            print('model loaded')
            self.max_token_length = self.model.config.max_position_embeddings
            self.get_self_information = self._get_self_info_via_llama
    def get_self_information(self, text: str) -> Tuple[List[str], List[float]]:
        # takes text as input, and returns a list of tokens and a list of self-information scores
        raise NotImplementedError
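    # Self-information (surprisal) of a token is I(x_t) = -log p(x_t | x_<t):
    # the less expected a token is under the language model, the more
    # information it carries, and the more likely it is to be kept when the
    # context is reduced.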
    def _get_self_info_via_gpt2(self, text: str) -> Tuple[List[str], List[float]]:
        if self.lang == 'en':
            text = f"<|endoftext|>{text}"
        elif self.lang == 'zh':
            text = f"[CLS]{text}"
        with torch.no_grad():
            encoding = self.tokenizer(text, add_special_tokens=False, return_tensors='pt')
            encoding = encoding.to(DEVICE)
            outputs = self.model(**encoding)
            logits = outputs.logits
            probs = torch.softmax(logits, dim=-1)
            self_info = -torch.log(probs)

        input_ids = encoding['input_ids']
        input_ids_expanded = input_ids[:, 1:].unsqueeze(-1)

        tokens = [self.tokenizer.decode(token_) for token_ in input_ids.squeeze().tolist()[1:]]
        return tokens, self_info[:, :-1].gather(-1, input_ids_expanded).squeeze(-1).squeeze(0).tolist()
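    # Alignment sketch (illustrative): for input ids [t0, t1, t2], the logits
    # at position 0 score t1 and the logits at position 1 score t2, so the
    # gather above drops the last logit position and the first token:
    # self_info[:, :-1] is paired with input_ids[:, 1:].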
    def _get_self_info_via_curie(self, text: str) -> Tuple[List[str], List[float]]:
        num_retry = 3
        openai.api_key = os.environ["OPENAI_API_KEY"]

        r = None
        for _ in range(num_retry):
            try:
                # max_tokens=0 with echo=True and logprobs asks the legacy
                # Completions API to score the prompt instead of generating
                r = openai.Completion.create(
                    model="curie",
                    prompt=f"<|endoftext|>{text}",
                    max_tokens=0,
                    temperature=0,
                    echo=True,
                    logprobs=0,
                )
                break
            except Exception as e:
                print(e)
                time.sleep(1)
        if r is None:
            raise RuntimeError(f"OpenAI API call failed after {num_retry} retries")

        result = r['choices'][0]
        tokens, logprobs = result["logprobs"]["tokens"][1:], result["logprobs"]["token_logprobs"][1:]

        assert len(tokens) == len(logprobs), f"Expected {len(tokens)} logprobs, got {len(logprobs)}"

        self_info = [-logprob for logprob in logprobs]
        return tokens, self_info
    def _get_self_info_via_llama(self, text: str) -> Tuple[List[str], List[float]]:
        inputs = self.tokenizer.encode_plus(text, return_tensors="pt")
        input_ids = inputs.input_ids.to(DEVICE)
        attention_mask = inputs.attention_mask.to(DEVICE)

        with torch.no_grad():
            outputs = self.model(input_ids, attention_mask=attention_mask)
            logits = outputs.logits
            probs = torch.softmax(logits, dim=-1)
            self_info = -torch.log(probs)

        # shift by one and gather the probability assigned to each observed
        # token, mirroring the GPT-2 path above
        input_ids_expanded = input_ids[:, 1:].unsqueeze(-1)
        tokens = self.tokenizer.convert_ids_to_tokens(input_ids.squeeze().tolist())[1:]
        return tokens, self_info[:, :-1].gather(-1, input_ids_expanded).squeeze(-1).squeeze(0).tolist()
    def _lexical_unit(self, sents):

        if self.sent_level_self_info:
            sent_self_info = []
            all_noun_phrases = []
            all_noun_phrases_info = []
            all_tokens = []
            all_token_self_info = []

            for sent in sents:
                # print(sent)
                tokens, self_info = self.get_self_information(sent)
                sent_self_info.append(np.mean(self_info))

                all_tokens.extend(tokens)
                all_token_self_info.extend(self_info)

                noun_phrases, noun_phrases_info = self._calculate_lexical_unit(tokens, self_info)

                # We need to add a space before the first noun phrase for
                # every sentence except the first one
                if len(all_noun_phrases) != 0:
                    noun_phrases[0] = f" {noun_phrases[0]}"
                all_noun_phrases.extend(noun_phrases)
                all_noun_phrases_info.extend(noun_phrases_info)

            return [
                LexicalUnits('sent', text=sents, self_info=sent_self_info),
                LexicalUnits('phrase', text=all_noun_phrases, self_info=all_noun_phrases_info),
                LexicalUnits('token', text=all_tokens, self_info=all_token_self_info),
            ]
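    # The single pass above yields all three granularities at once: sentence
    # self-info is the mean over its tokens, phrase self-info is the mean over
    # the tokens inside each merged noun chunk, and token self-info is used
    # as-is. Only one of the three LexicalUnits objects is consumed
    # downstream, selected by `reduce_level` in __call__.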
    def _calculate_lexical_unit(self, tokens, self_info):
        def _unit_info(tokens, self_info, units):
            current_unit_idx = 0
            current_position = 0
            unit_self_info = [[] for _ in range(len(units))]

            for idx, (token, info) in enumerate(zip(tokens, self_info)):
                current_position += len(token)
                if current_position == len(units[current_unit_idx]):
                    unit_self_info[current_unit_idx].append(info)
                    current_position = current_position - len(units[current_unit_idx])
                    current_unit_idx += 1
                elif current_position > len(units[current_unit_idx]):
                    counter_ = 1
                    current_position = current_position - len(units[current_unit_idx])
                    current_unit_idx += 1
                    while current_position >= len(units[current_unit_idx]):
                        counter_ += 1
                        current_position = current_position - len(units[current_unit_idx])
                        current_unit_idx += 1
                        if current_unit_idx >= len(units):
                            break
                    # a token that straddles several units contributes an
                    # equal share of its self-information to each of them
                    partial_info = info / counter_
                    for _ in range(counter_):
                        unit_self_info[(current_unit_idx - 1) - _].append(partial_info)
                else:
                    if token == " ":
                        continue
                    unit_self_info[current_unit_idx].append(info)

            unit_self_info_ = [np.mean(info) for info in unit_self_info]
            return unit_self_info_

        def _noun_phrases(sent):
            noun_phrases = []
            doc = self.nlp(sent)
            for index, chunk in enumerate(doc):
                if index == 0:
                    noun_phrases.append(chunk.text)
                else:
                    noun_phrases.append(doc[index - 1].whitespace_ + chunk.text)
            return noun_phrases

        if self.sent_level_self_info:
            # in this case, the self_info is for each sentence;
            # we only need to calculate the self_info for each phrase
            sent = ''.join(tokens)
            # noun_phrases = [chunk.text for chunk in self.nlp(sent).noun_chunks]
            noun_phrases = _noun_phrases(sent)
            # noun_phrases[-1] = noun_phrases[-1] + ' '
            noun_phrases_info = _unit_info(tokens, self_info, noun_phrases)
            return noun_phrases, noun_phrases_info
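    # Worked example (hypothetical values) of _unit_info's aggregation:
    #   tokens = ["The", " cat", " sat"], self_info = [2.0, 4.0, 1.0]
    #   units  = ["The cat", " sat"]
    # "The" ends inside unit 0 (position 3 < 7) and " cat" closes it
    # (position 7 == 7), so unit 0 gets mean(2.0, 4.0) = 3.0; " sat" exactly
    # covers unit 1, which gets 1.0.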
    def beautify_context(self, context: str) -> str:
        context = re.sub(r"\s+", " ", context)
        return context
    def self_info_mask(self, sents: List[str], self_info: List[float], mask_level):
        # mask_level: mask sentences, phrases, or tokens
        sents_after_mask = []
        masked_sents = []

        self.ppl_threshold = np.nanpercentile(self_info, self.mask_ratio * 100)

        # if title is not None:
        #     with open(os.path.join(self.path, title + '_prob_token.tsv'), 'w', encoding='utf-8') as f:
        #         for token, info in zip(tokens, self_info):
        #             f.write(f"{token}\t{info}\n")
        #     with open(os.path.join(self.path, title + '_prob_sent.tsv'), 'w', encoding='utf-8') as f:
        #         for sent, info in zip(sents, sent_self_info):
        #             f.write(f"{sent}\n{info}\n\n")

        for sent, info in zip(sents, self_info):
            if info < self.ppl_threshold:
                masked_sents.append(sent)
                sents_after_mask.append(self.mask_a_sent(sent, mask_level))
            else:
                sents_after_mask.append(sent)
        masked_context = " ".join(sents_after_mask) if mask_level == 'sent' else "".join(sents_after_mask)

        return masked_context, masked_sents
    def mask_a_sent(self, sent, level):
        if level == 'phrase':
            return self.phrase_mask_token
        elif level == 'sent':
            if self.keep_leading_word:
                leading_few_words = " ".join(word_tokenize(sent)[:self.num_lead_words]) + " "
            else:
                leading_few_words = ""
            return leading_few_words + self.sent_mask_token
        elif level == 'token':
            return ''
    def __call__(self, text: str, reduce_ratio: float = 0.35, reduce_level: str = 'phrase') -> List[str]:
        context = self.beautify_context(text)

        self.mask_ratio = reduce_ratio

        sents = re.split(self.sent_tokenize_pattern, context)
        sents = [sent.strip() for sent in sents if sent.strip()]

        # Should the reduction happen at the sentence, phrase, or token level?
        assert reduce_level in ['sent', 'phrase', 'token'], f"reduce_level should be one of ['sent', 'phrase', 'token'], got {reduce_level}"
        sent_lus, phrase_lus, token_lus = self._lexical_unit(sents)
        # print(phrase_lus, '^^^^')
        lexical_level = {
            'sent': sent_lus,
            'phrase': phrase_lus,
            'token': token_lus,
        }

        # context is the reduced context; masked_sents denotes the content that has been filtered out
        context, masked_sents = self.self_info_mask(lexical_level[reduce_level].text, lexical_level[reduce_level].self_info, reduce_level)
        return context, masked_sents
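    # Minimal usage sketch (the input text is hypothetical):
    #   sc = SelectiveContext(model_type='gpt2', lang='en')
    #   reduced, removed = sc(long_text, reduce_ratio=0.35, reduce_level='phrase')
    # `reduced` is the compressed context; `removed` lists the units whose
    # self-information fell below the 35th-percentile threshold.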
def main(
    model_type='gpt2',  # you can choose from ['gpt2', 'curie', 'llama']
    lang='en',  # currently only en and zh are supported
    file_to_process: str = None,
    file_to_save: str = None,
):
    global DEVICE
    DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Using device: {DEVICE}")

    sc = SelectiveContext(model_type=model_type, lang=lang)

    if file_to_process is None:
        while True:
            text = input("Please input the text you want to reduce: ")
            if text == 'exit':
                break
            context, masked_sents = sc(text)
            print('***********\nThe resulting context is: \n')
            print(context, '\n\n')
            print('***********\nThe content that has been filtered out is: \n')
            print(masked_sents, '\n\n')
    else:
        with open(file_to_process, 'r') as f:
            text = f.read()
        context, masked_sents = sc(text)
        with open(file_to_save, 'w') as f:
            f.write(context)


if __name__ == "__main__":
    main(model_type='gpt2', lang='zh')