Spaces:
Running
Running
| # -*- coding: utf-8 -*- | |
| # @Time : 2022/2/15 7:57 下午 | |
| # @Author : JianingWang | |
| # @File : trie | |
| import logging | |
| from typing import List | |
| from collections import OrderedDict | |
# Module-level logger named after this module; Trie.cut_text uses it to report
# offsets that are out of order.
logger = logging.getLogger(__name__)
class Trie:
    """
    Character-level trie used to locate added words inside a text and to split
    the text around them.

    The trie is stored in `data` as nested dicts keyed by single characters.
    The special key `""` inside a node marks that an added word ends there.
    """

    def __init__(self):
        # Root node of the trie; children are dicts keyed by one character each.
        self.data = {}

    def add(self, word: str):
        """
        Passes over every char (utf-8 char) on word and adds it to the internal `data` trie representation.
        The special key `""` is used to represent termination.

        This function is idempotent, adding twice the same word will leave the trie unchanged.
        Empty words are ignored.

        Example:

        ```python
        >>> trie = Trie()
        >>> trie.add("Hello 友達")
        >>> trie.data
        {"H": {"e": {"l": {"l": {"o": {" ": {"友": {"達": {"": 1}}}}}}}}}

        >>> trie.add("Hello")
        >>> trie.data
        {"H": {"e": {"l": {"l": {"o": {"": 1, " ": {"友": {"達": {"": 1}}}}}}}}}
        ```
        """
        if not word:
            # Prevent empty string
            return
        node = self.data
        for char in word:
            # Reuse the existing child node, or create it on first visit.
            node = node.setdefault(char, {})
        node[""] = 1

    def find(self, text: str) -> List[List[int]]:
        """
        Scan `text` once and return the spans of the added words found in it.

        Args:
            text: The string to search.

        Returns:
            A list of `[start, end]` pairs (end exclusive), in the order they
            are found. When several candidate matches are alive at the same
            position, the candidate with the earliest start is preferred and
            is extended as far as the trie allows before being recorded.

        Example:

        ```python
        >>> trie = Trie()
        >>> trie.add("[CLS]")
        >>> trie.add("extra_id_1")
        >>> trie.add("extra_id_100")
        >>> trie.find("[CLS] This is a extra_id_100")
        [[0, 5], [16, 28]]
        ```
        """
        # `states` maps the start index of every partial match still alive to
        # the trie node reached so far. Insertion order equals start order.
        states = OrderedDict()
        offsets = []
        text_len = len(text)

        # Characters before `skip` were already consumed by a lookahead match
        # and must not be processed again.
        skip = 0
        for current, current_char in enumerate(text):
            if skip and current < skip:
                continue

            # Partial matches that cannot consume `current_char`; they are
            # dropped after the loop because `states` is being iterated.
            to_remove = set()
            # Set when a match is recorded: every partial match is discarded.
            reset = False

            for start, trie_pointer in states.items():
                if "" in trie_pointer:
                    # A complete word ended just before `current`. Before
                    # recording it, check whether this match or an earlier
                    # partial match can be extended into a longer one.
                    for lookstart, looktrie_pointer in states.items():
                        if lookstart > start:
                            # This partial match starts later, stop looking.
                            break
                        elif lookstart < start:
                            # Earlier partial matches were already advanced
                            # with `current_char` in this iteration of the
                            # outer loop, so their next char is one further.
                            lookahead_index = current + 1
                            end = current + 1
                        else:
                            # lookstart == start: this pointer was not advanced
                            # yet, so the indices are the current ones.
                            lookahead_index = current
                            end = current
                        next_char = text[lookahead_index] if lookahead_index < text_len else None
                        if "" in looktrie_pointer:
                            start = lookstart
                            end = lookahead_index
                            skip = lookahead_index

                        # Follow the trie as far as the text allows, keeping
                        # the last position where a complete word ended.
                        while next_char in looktrie_pointer:
                            looktrie_pointer = looktrie_pointer[next_char]
                            lookahead_index += 1
                            if "" in looktrie_pointer:
                                start = lookstart
                                end = lookahead_index
                                skip = lookahead_index

                            if lookahead_index == text_len:
                                # End of text reached.
                                break
                            next_char = text[lookahead_index]

                    offsets.append([start, end])
                    reset = True
                    break
                elif current_char in trie_pointer:
                    # The partial match consumes the current character.
                    trie_pointer = trie_pointer[current_char]
                    # Only the value of an existing key changes, which is
                    # allowed while iterating over the dict.
                    states[start] = trie_pointer
                else:
                    # The partial match cannot continue with this character.
                    to_remove.add(start)

            if reset:
                # A match was recorded: forget every partial match.
                states = OrderedDict()
            else:
                for start in to_remove:
                    del states[start]

            # Start a new partial match here unless this character was already
            # consumed by a lookahead match.
            if current >= skip and current_char in self.data:
                states[current] = self.data[current_char]

        # The text is exhausted: record the first remaining partial match that
        # is a complete word, ending at the end of the text.
        for start, trie_pointer in states.items():
            if "" in trie_pointer:
                end = text_len
                offsets.append([start, end])
                break

        return offsets

    def split(self, text: str) -> List[str]:
        """
        Split `text` on the added words, keeping the words as separate pieces.

        Args:
            text: The string to split.

        Returns:
            The non-empty pieces of `text`, in order. Concatenating them gives
            `text` back as long as the offsets found are in increasing order.

        Example:

        ```python
        >>> trie = Trie()
        >>> trie.split("[CLS] This is a extra_id_100")
        ["[CLS] This is a extra_id_100"]

        >>> trie.add("[CLS]")
        >>> trie.add("extra_id_1")
        >>> trie.add("extra_id_100")
        >>> trie.split("[CLS] This is a extra_id_100")
        ["[CLS]", " This is a ", "extra_id_100"]
        ```
        """
        # Flatten the [start, end] spans into one list of cut points that
        # begins at 0.
        offsets = [0]
        for span in self.find(text):
            offsets.extend(span)
        return self.cut_text(text, offsets)

    def cut_text(self, text: str, offsets: List[int]) -> List[str]:
        """
        Cut `text` at every position listed in `offsets`.

        Args:
            text: The string to cut.
            offsets: Cut positions, expected in non-decreasing order. The end
                of the text is always used as a final cut point. This list is
                not modified.

        Returns:
            The non-empty slices between consecutive cut points. An offset
            smaller than the previous cut point is logged as an error and
            ignored.
        """
        tokens = []
        start = 0
        # Build a new list rather than appending to `offsets`, so the caller's
        # list is left untouched.
        for end in [*offsets, len(text)]:
            if start > end:
                logger.error(
                    "There was a bug in Trie algorithm in tokenization. Attempting to recover. Please report it anyway."
                )
                continue
            elif start == end:
                # Zero-width slice, e.g. a match at index 0 or adjacent
                # matches: nothing to emit.
                continue
            tokens.append(text[start:end])
            start = end

        return tokens

    def __reduce__(self):
        """
        Support `pickle` and `copy`.

        Returns:
            A `(callable, args, state)` tuple: the class is called without
            arguments and the instance `__dict__` is restored as state.
            `__reduce__` must return a string or a tuple; returning `None`
            makes pickling and copying raise `TypeError`.
        """
        return (self.__class__, (), self.__dict__)
if __name__ == "__main__":
    # Ad-hoc demo: fill a trie with a few overlapping words, then print what
    # its __reduce__ method returns.
    demo_trie = Trie()
    for entry in ("A", "AB", "BD", "BWA"):
        demo_trie.add(entry)
    print(demo_trie.__reduce__())