| """ | |
| This module contains the helper functions to get the word alignment mapping between two sentences. | |
| """ | |
| import torch | |
| import itertools | |
| import transformers | |
| from transformers import logging | |
| # Set the verbosity to error, so that the warning messages are not printed | |
| logging.set_verbosity_warning() | |
| logging.set_verbosity_error() | |


def get_alignment_mapping(source="", target="", model_path="musfiqdehan/bn-en-word-aligner"):
    """
    Get the source words, the target words, and the set of aligned
    (source_index, target_index) word pairs.
    """
    model = transformers.BertModel.from_pretrained(model_path)
    tokenizer = transformers.BertTokenizer.from_pretrained(model_path)

    # Pre-processing: split each sentence into words, tokenize every word
    # into subwords, and convert the subwords to vocabulary ids
    sent_src, sent_tgt = source.strip().split(), target.strip().split()
    token_src = [tokenizer.tokenize(word) for word in sent_src]
    token_tgt = [tokenizer.tokenize(word) for word in sent_tgt]
    wid_src = [tokenizer.convert_tokens_to_ids(x) for x in token_src]
    wid_tgt = [tokenizer.convert_tokens_to_ids(x) for x in token_tgt]
    ids_src = tokenizer.prepare_for_model(
        list(itertools.chain(*wid_src)),
        return_tensors='pt',
        max_length=tokenizer.model_max_length,
        truncation=True,
    )['input_ids']
    ids_tgt = tokenizer.prepare_for_model(
        list(itertools.chain(*wid_tgt)),
        return_tensors='pt',
        max_length=tokenizer.model_max_length,
        truncation=True,
    )['input_ids']

    # Map every subword position back to the index of the word it came from
    sub2word_map_src = []
    for i, word_list in enumerate(token_src):
        sub2word_map_src += [i] * len(word_list)
    sub2word_map_tgt = []
    for i, word_list in enumerate(token_tgt):
        sub2word_map_tgt += [i] * len(word_list)
    # Alignment: take hidden states from a middle encoder layer, compute
    # subword-to-subword similarities, and keep pairs that survive a
    # softmax threshold in both directions
    align_layer = 8
    threshold = 1e-3
    model.eval()
    with torch.no_grad():
        # [0, 1:-1] drops the batch dimension and the [CLS]/[SEP] special tokens
        out_src = model(ids_src.unsqueeze(0), output_hidden_states=True)[2][align_layer][0, 1:-1]
        out_tgt = model(ids_tgt.unsqueeze(0), output_hidden_states=True)[2][align_layer][0, 1:-1]
        dot_prod = torch.matmul(out_src, out_tgt.transpose(-1, -2))
        softmax_srctgt = torch.nn.Softmax(dim=-1)(dot_prod)
        softmax_tgtsrc = torch.nn.Softmax(dim=-2)(dot_prod)
        softmax_inter = (softmax_srctgt > threshold) * (softmax_tgtsrc > threshold)

    # Map the aligned subword pairs back to word-level index pairs
    align_subwords = torch.nonzero(softmax_inter, as_tuple=False)
    align_words = set()
    for i, j in align_subwords:
        align_words.add((sub2word_map_src[i], sub2word_map_tgt[j]))

    return sent_src, sent_tgt, align_words
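
# A minimal usage sketch for get_alignment_mapping. The Bengali-English
# sentence pair below is an illustrative assumption, not project test data,
# and the exact alignments depend on the model weights:
#
#   sent_src, sent_tgt, align_words = get_alignment_mapping(
#       source="আমি ভাত খাই", target="I eat rice")
#   # sent_src -> ['আমি', 'ভাত', 'খাই'], sent_tgt -> ['I', 'eat', 'rice']
#   # align_words -> a set of (source_index, target_index) pairs such as (0, 0)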


def get_word_mapping(source="", target="", model_path="musfiqdehan/bn-en-word-aligner"):
    """
    Get the aligned word pairs, formatted as 'bn:(word) -> en:(word)' strings.
    """
    sent_src, sent_tgt, align_words = get_alignment_mapping(
        source=source, target=target, model_path=model_path)

    result = []
    for i, j in sorted(align_words):
        result.append(f'bn:({sent_src[i]}) -> en:({sent_tgt[j]})')
    return result
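
# A minimal usage sketch for get_word_mapping, assuming the same illustrative
# sentence pair as above:
#
#   get_word_mapping(source="আমি ভাত খাই", target="I eat rice")
#   # e.g. ['bn:(আমি) -> en:(I)', 'bn:(ভাত) -> en:(rice)', 'bn:(খাই) -> en:(eat)']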


def get_word_index_mapping(source="", target="", model_path="musfiqdehan/bn-en-word-aligner"):
    """
    Get the aligned word pairs as position indices, formatted as
    'bn:(i) -> en:(j)' strings.
    """
    sent_src, sent_tgt, align_words = get_alignment_mapping(
        source=source, target=target, model_path=model_path)

    result = []
    for i, j in sorted(align_words):
        result.append(f'bn:({i}) -> en:({j})')
    return result
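

# A minimal, runnable demo of the three helpers. The sentence pair is an
# illustrative assumption; the first call downloads the model from the
# Hugging Face Hub, and the exact output depends on the model weights.
if __name__ == "__main__":
    source = "আমি ভাত খাই"  # assumed Bengali example meaning "I eat rice"
    target = "I eat rice"

    for pair in get_word_mapping(source=source, target=target):
        print(pair)  # e.g. bn:(আমি) -> en:(I)

    for pair in get_word_index_mapping(source=source, target=target):
        print(pair)  # e.g. bn:(0) -> en:(0)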