import itertools

import numpy as np
import torch
from tqdm.auto import tqdm

def get_char_probs(texts, predictions, tokenizer):
    """
    Map token-level predictions back onto the characters of each text.

    predictions: per-token probabilities, shape (batch, sequence length), e.g. 466 tokens
    texts:       the original strings, e.g. up to 768 characters each
    Using the tokenizer's offset mapping (e.g. [(0, 4), ...], one pair per token),
    each token's probability is written to the characters that token covers, so
    every text ends up with an array of per-character probabilities:
        results[i][start:end] = pred  for each (start, end) offset of text i.
    """
    results = [np.zeros(len(t)) for t in texts]
    for i, (text, prediction) in enumerate(zip(texts, predictions)):
        encoded = tokenizer(text,
                            add_special_tokens=True,
                            return_offsets_mapping=True)
        for (start, end), pred in zip(encoded['offset_mapping'], prediction):
            # Special tokens have a (0, 0) offset, so they contribute nothing.
            results[i][start:end] = pred
    return results
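
# Example (an added sketch, not part of the original pipeline): checking the
# character-level broadcast with fake per-token probabilities. The tokenizer
# name 'microsoft/deberta-base' is only an illustration; any Hugging Face fast
# tokenizer that returns an offset mapping behaves the same way.
def _demo_get_char_probs():
    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained('microsoft/deberta-base')
    texts = ["dad with recent heart attack"]
    n_tokens = len(tokenizer(texts[0], add_special_tokens=True)['input_ids'])
    token_preds = np.full((1, n_tokens), 0.9)  # fake probabilities, one per token
    char_probs = get_char_probs(texts, token_preds, tokenizer)
    return char_probs  # char_probs[0] holds one probability per character of texts[0]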

def get_results(char_probs, th=0.5):
    """
    Turn per-character probabilities into "start end" location strings.

    For each text, take the indexes of the characters whose probability is at
    least `th`, e.g.
        char_prob = [0.1, 0.1, 0.9, 0.9, 0.9, 0.9, 0.2, 0.2, 0.2, 0.7, 0.7, 0.7]
        indexes where >= 0.5  ->  [2, 3, 4, 5, 9, 10, 11]
    group consecutive indexes -> [[2, 3, 4, 5], [9, 10, 11]]
    and format each group as "min max". Note that the indexes are shifted by +1
    before grouping, and the groups are joined with ";".
    """
    results = []
    for char_prob in char_probs:
        result = np.where(char_prob >= th)[0] + 1
        # Consecutive indexes share the same (value - running counter), so
        # groupby splits them into runs of consecutive characters.
        result = [list(g) for _, g in itertools.groupby(result, key=lambda n, c=itertools.count(): n - next(c))]
        result = [f"{min(r)} {max(r)}" for r in result]
        result = ";".join(result)
        results.append(result)
    return results
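
# Illustration (added, hypothetical helper): the docstring example run through
# get_results. Note how the +1 shift turns groups starting at 2 and 9 into the
# string "3 6;10 12".
def _demo_get_results():
    char_probs = [np.array([0.1, 0.1, 0.9, 0.9, 0.9, 0.9, 0.2, 0.2, 0.2, 0.7, 0.7, 0.7])]
    return get_results(char_probs, th=0.5)  # -> ["3 6;10 12"]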

def get_predictions(results):
    """
    Parse location strings (as stored in the dataframe) into integer spans.

    results = ["2 5;9 11", ...]  ->  [[[2, 5], [9, 11]], ...]
    Each string is split on ";" and each piece on whitespace, giving
    [start, end] pairs; an empty string yields an empty list.
    """
    predictions = []
    for result in results:
        prediction = []
        if result != "":
            for loc in [s.split() for s in result.split(';')]:
                start, end = int(loc[0]), int(loc[1])
                prediction.append([start, end])
        predictions.append(prediction)
    return predictions
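
# Illustration (added, hypothetical helper): parsing location strings back into
# integer spans; an empty string means no character crossed the threshold.
def _demo_get_predictions():
    return get_predictions(["3 6;10 12", ""])  # -> [[[3, 6], [10, 12]], []]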

def inference_fn(test_loader, model, device):
    """Run the model over the test loader and return sigmoid probabilities as a numpy array."""
    preds = []
    model.eval()
    model.to(device)
    tk0 = tqdm(test_loader, total=len(test_loader))
    for inputs in tk0:
        for k, v in inputs.items():
            inputs[k] = v.to(device)
        with torch.no_grad():
            y_preds = model(inputs)
        # Move back to CPU before converting to numpy (required when device is CUDA).
        preds.append(y_preds.sigmoid().to('cpu').numpy())
    predictions = np.concatenate(preds)
    return predictions
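
# Pipeline sketch (an added illustration): how the helpers above fit together at
# prediction time. `model`, `tokenizer`, `test_dataset` and `test_texts` are
# assumed to be built elsewhere in this Space; only the call order is shown.
def _demo_pipeline(model, tokenizer, test_dataset, test_texts, th=0.5):
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    token_preds = inference_fn(test_loader, model, device)   # (n_samples, seq_len)
    char_probs = get_char_probs(test_texts, token_preds, tokenizer)
    locations = get_results(char_probs, th=th)               # e.g. ["3 6;10 12", ...]
    return get_predictions(locations)                        # e.g. [[[3, 6], [10, 12]], ...]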

def get_text(context, indexes):
    """Return the substring(s) of `context` covered by a "start end;start end" location string."""
    if indexes:
        if ';' in indexes:
            # Several spans: concatenate them, each prefixed with a space.
            list_indexes = indexes.split(';')
            answer = ''
            for idx in list_indexes:
                start_index = int(idx.split(' ')[0])
                end_index = int(idx.split(' ')[1])
                answer += ' '
                answer += context[start_index:end_index]
            return answer
        else:
            # Single span.
            start_index = int(indexes.split(' ')[0])
            end_index = int(indexes.split(' ')[1])
            return context[start_index:end_index]
    else:
        return 'Not found in this Context'
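
# Example (added, hypothetical helper): mapping a predicted location string back
# to the text it points at.
def _demo_get_text():
    context = "dad with recent heart attack"
    return get_text(context, "0 3;16 28")  # -> " dad heart attack" (each span is prefixed with a space)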