import os
from time import time

os.environ['TOKENIZERS_PARALLELISM'] = 'false'


def get_bpe_groups(token_offsets, bpe_offsets, input_ids, max_bpe_pieces=5):
    """Group BPE piece indices by the word they belong to, keeping at most
    max_bpe_pieces pieces per word."""
    bpe_groups = []
    last_used_bpe = 0
    # find the size of offsets (special/padding tokens are mapped to (0, 0))
    if (0, 0) in bpe_offsets:
        bpe_size = bpe_offsets.index((0, 0))
    else:
        bpe_size = len(bpe_offsets)
    saved_ids = [i for i in range(len(input_ids))]
    redundant_ids = []
    for token_offset in token_offsets:
        start_token, end_token = token_offset
        bpe_group = []
        mapping_is_found = False
        for i in range(last_used_bpe, bpe_size):
            start_bpe, end_bpe = bpe_offsets[i]
            if start_bpe >= start_token and end_bpe <= end_token:
                # check that bpe_group satisfies the max_bpe_pieces constraint
                if len(bpe_group) < max_bpe_pieces:
                    bpe_group.append(i)
                else:
                    redundant_ids.append(i)
                last_used_bpe = i + 1
                mapping_is_found = True
            elif mapping_is_found:
                # stop doing useless iterations
                break
            else:
                continue
        bpe_groups.append(bpe_group)
    saved_ids = [i for i in saved_ids if i not in redundant_ids]
    return bpe_groups, saved_ids


def reduce_input_ids(input_ids, bpe_groups, saved_ids,
                     max_bpe_length=80, max_bpe_pieces=5):
    """Drop surplus BPE pieces until the sequence satisfies the
    max_bpe_length constraint, then compute word-to-piece offsets."""
    # check that the sequence satisfies the max_bpe_length constraint
    while len(saved_ids) > max_bpe_length:
        max_bpe_pieces -= 1
        for token_id in range(len(bpe_groups)):
            if len(bpe_groups[token_id]) > max_bpe_pieces:
                redundant_ids = bpe_groups[token_id][max_bpe_pieces:]
                bpe_groups[token_id] = bpe_groups[token_id][:max_bpe_pieces]
                saved_ids = [i for i in saved_ids if i not in redundant_ids]
    # get offsets
    reduced_ids = [input_ids[i] for i in saved_ids]
    correct_offsets = []
    idx = 0
    for i, bpe_group in enumerate(bpe_groups):
        norm_idx = min(idx, len(reduced_ids) - 1)
        correct_offsets.append(norm_idx)
        idx += len(bpe_group)
    return reduced_ids, correct_offsets
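
# Illustration (hypothetical helper with hand-made offsets, not used by the
# pipeline): for the sentence "hello world" with word offsets (0, 5) and
# (6, 11) split into four subword pieces, max_bpe_pieces=1 keeps only the
# first piece of each word.
def _demo_bpe_reduction():
    input_ids = [11, 12, 13, 14]
    token_offsets = [(0, 5), (6, 11)]
    bpe_offsets = [(0, 3), (3, 5), (6, 9), (9, 11)]
    bpe_groups, saved_ids = get_bpe_groups(token_offsets, bpe_offsets,
                                           input_ids, max_bpe_pieces=1)
    assert bpe_groups == [[0], [2]] and saved_ids == [0, 2]
    reduced_ids, correct_offsets = reduce_input_ids(input_ids, bpe_groups,
                                                    saved_ids,
                                                    max_bpe_length=80,
                                                    max_bpe_pieces=1)
    assert reduced_ids == [11, 13] and correct_offsets == [0, 1]
    return reduced_ids, correct_offsets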

def get_offsets_and_reduce_input_ids(tokenizer_output, token_offset_list,
                                     index_name="bert", max_bpe_length=80,
                                     max_bpe_pieces=5):
    """For every sentence in the batch, group BPE pieces by word, reduce the
    piece sequence, and collect ids, word-level offsets and masks."""
    timings = {"bpe": 0, "reduce": 0, "mask": 0}
    output_ids, output_offsets, output_masks = [], [], []
    for i, token_offsets in enumerate(token_offset_list):
        input_ids = tokenizer_output['input_ids'][i]
        t0 = time()
        # get bpe-level offsets
        bpe_offsets = tokenizer_output['offset_mapping'][i]
        bpe_groups, saved_ids = get_bpe_groups(token_offsets, bpe_offsets,
                                               input_ids,
                                               max_bpe_pieces=max_bpe_pieces)
        t1 = time()
        timings["bpe"] += t1 - t0
        # reduce sequence length
        reduced_ids, correct_offsets = reduce_input_ids(input_ids, bpe_groups,
                                                        saved_ids,
                                                        max_bpe_length=max_bpe_length,
                                                        max_bpe_pieces=max_bpe_pieces)
        t2 = time()
        timings["reduce"] += t2 - t1
        # get mask
        bpe_mask = [1 for _ in correct_offsets]
        output_ids.append(reduced_ids)
        output_offsets.append(correct_offsets)
        output_masks.append(bpe_mask)
        t3 = time()
        timings["mask"] += t3 - t2
    # tt = sum(timings.values())
    # timings = {k: f"{round(v * 100 / tt, 2)}%" for k, v in timings.items()}
    # print(timings)
    output = {index_name: output_ids,
              f"{index_name}-offsets": output_offsets,
              "mask": output_masks}
    return output
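
# Illustration (hypothetical helper with hand-made tokenizer output, not used
# by the pipeline): two words split into four subword pieces, so each offset
# points at the first piece of its word.
def _demo_offsets_and_reduce():
    tokenizer_output = {"input_ids": [[11, 12, 13, 14]],
                        "offset_mapping": [[(0, 3), (3, 5), (6, 9), (9, 11)]]}
    token_offset_list = [[(0, 5), (6, 11)]]
    output = get_offsets_and_reduce_input_ids(tokenizer_output,
                                              token_offset_list,
                                              index_name="bert")
    assert output == {"bert": [[11, 12, 13, 14]],
                      "bert-offsets": [[0, 2]],
                      "mask": [[1, 1]]}
    return output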

def get_offset_for_tokens(tokens):
    """Compute (start, end) character offsets of each token in the
    space-joined sentence."""
    sentence = " ".join(tokens)
    token_offsets = []
    end_idx = 0
    for token in tokens:
        idx = sentence[end_idx:].index(token) + end_idx
        end_idx = idx + len(token)
        offset = (idx, end_idx)
        token_offsets.append(offset)
    return token_offsets
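
# Illustration (hypothetical helper, not used by the pipeline): for
# ["Hello", "world", "!"] the joined sentence is "Hello world !", so the
# character offsets are (0, 5), (6, 11) and (12, 13).
def _demo_token_offsets():
    offsets = get_offset_for_tokens(["Hello", "world", "!"])
    assert offsets == [(0, 5), (6, 11), (12, 13)]
    return offsets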

def get_token_offsets(batch):
    """Compute token-level character offsets for every sentence in the batch."""
    token_offset_list = []
    for tokens in batch:
        token_offsets = get_offset_for_tokens(tokens)
        token_offset_list.append(token_offsets)
    return token_offset_list


def pad_output(output, pad_idx=0):
    """Right-pad every index list in the output dict up to the longest list
    of its key."""
    padded_output = {}
    for input_key in output.keys():
        indexes = output[input_key]
        max_len = max([len(x) for x in indexes])
        padded_indexes = []
        for index_list in indexes:
            cur_len = len(index_list)
            pad_len = max_len - cur_len
            padded_indexes.append(index_list + [pad_idx] * pad_len)
        padded_output[input_key] = padded_indexes
    return padded_output
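
# Illustration (hypothetical helper, not used by the pipeline): ragged rows
# are right-padded with pad_idx up to the longest row under each key.
def _demo_pad_output():
    output = {"bert": [[11, 12, 13], [21]],
              "mask": [[1, 1, 1], [1]]}
    padded = pad_output(output, pad_idx=0)
    assert padded == {"bert": [[11, 12, 13], [21, 0, 0]],
                      "mask": [[1, 1, 1], [1, 0, 0]]}
    return padded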

def tokenize_batch(tokenizer, batch_tokens, index_name="bert",
                   max_bpe_length=80, max_bpe_pieces=5):
    """Tokenize a batch of pre-tokenized sentences and return padded ids,
    word-level offsets and masks."""
    timings = {}
    t0 = time()
    # get batch with sentences
    batch_sentences = [" ".join(x) for x in batch_tokens]
    # get token-level offsets
    token_offset_list = get_token_offsets(batch_tokens)
    # token_offset_list = get_token_offsets_multi(batch_tokens)
    t1 = time()
    timings["offset_time"] = t1 - t0
    # tokenize batch
    tokenizer_output = tokenizer.batch_encode_plus(batch_sentences,
                                                   pad_to_max_length=False,
                                                   return_offsets_mapping=True,
                                                   add_special_tokens=False)
    t2 = time()
    timings["tokenize_time"] = t2 - t1
    # postprocess batch
    output = get_offsets_and_reduce_input_ids(tokenizer_output,
                                              token_offset_list,
                                              index_name=index_name,
                                              max_bpe_length=max_bpe_length,
                                              max_bpe_pieces=max_bpe_pieces)
    t3 = time()
    timings["reduce_time"] = t3 - t2
    # pad output
    output = pad_output(output)
    t4 = time()
    timings["padding_time"] = t4 - t3
    # tt = sum(timings.values())
    # timings = {k: f"{round(v * 100 / tt, 2)}%" for k, v in timings.items()}
    # print(timings)
    return output
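
# Usage sketch (illustrative): assumes the `transformers` package is
# installed and that a *fast* tokenizer is used, since
# return_offsets_mapping=True requires one; "bert-base-cased" below is only
# an example checkpoint.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-cased", use_fast=True)
    batch_tokens = [["Hello", "world", "!"],
                    ["Another", "pre-tokenized", "sentence", "."]]
    batch = tokenize_batch(tokenizer, batch_tokens,
                           index_name="bert",
                           max_bpe_length=80,
                           max_bpe_pieces=5)
    # `batch` holds equal-length rows of reduced subword ids ("bert"),
    # word-level offsets into those ids ("bert-offsets"), and a mask ("mask").
    print({key: values[0] for key, values in batch.items()})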