Upload 25 files
- utils/__init__.py +0 -0
- utils/__pycache__/bert_model.cpython-39.pyc +0 -0
- utils/__pycache__/callbacks.cpython-39.pyc +0 -0
- utils/__pycache__/file_utils.cpython-39.pyc +0 -0
- utils/__pycache__/finetune.cpython-39.pyc +0 -0
- utils/__pycache__/lightning_base.cpython-39.pyc +0 -0
- utils/__pycache__/sentence_retrieval_model.cpython-39.pyc +0 -0
- utils/__pycache__/sentence_retrieval_module.cpython-39.pyc +0 -0
- utils/__pycache__/textual_entailment_module.cpython-39.pyc +0 -0
- utils/__pycache__/utils_graph2text.cpython-39.pyc +0 -0
- utils/__pycache__/utils_verbalisation_module.cpython-39.pyc +0 -0
- utils/__pycache__/verbalisation_module.cpython-39.pyc +0 -0
- utils/__pycache__/wikidata_utils.cpython-39.pyc +0 -0
- utils/bert_model.py +775 -0
- utils/callbacks.py +140 -0
- utils/file_utils.py +249 -0
- utils/finetune.py +633 -0
- utils/lightning_base.py +418 -0
- utils/sentence_retrieval_model.py +20 -0
- utils/sentence_retrieval_module.py +77 -0
- utils/textual_entailment_module.py +94 -0
- utils/utils_graph2text.py +114 -0
- utils/utils_verbalisation_module.py +610 -0
- utils/verbalisation_module.py +300 -0
- utils/wikidata_utils.py +173 -0
utils/__init__.py
ADDED
File without changes

utils/__pycache__/bert_model.cpython-39.pyc
ADDED
Binary file (30.6 kB)

utils/__pycache__/callbacks.cpython-39.pyc
ADDED
Binary file (4.9 kB)

utils/__pycache__/file_utils.cpython-39.pyc
ADDED
Binary file (6.81 kB)

utils/__pycache__/finetune.cpython-39.pyc
ADDED
Binary file (20.2 kB)

utils/__pycache__/lightning_base.cpython-39.pyc
ADDED
Binary file (13.5 kB)

utils/__pycache__/sentence_retrieval_model.cpython-39.pyc
ADDED
Binary file (1.11 kB)

utils/__pycache__/sentence_retrieval_module.cpython-39.pyc
ADDED
Binary file (2.52 kB)

utils/__pycache__/textual_entailment_module.cpython-39.pyc
ADDED
Binary file (2.65 kB)

utils/__pycache__/utils_graph2text.cpython-39.pyc
ADDED
Binary file (3.12 kB)

utils/__pycache__/utils_verbalisation_module.cpython-39.pyc
ADDED
Binary file (23.9 kB)

utils/__pycache__/verbalisation_module.cpython-39.pyc
ADDED
Binary file (7.37 kB)

utils/__pycache__/wikidata_utils.cpython-39.pyc
ADDED
Binary file (5.29 kB)
utils/bert_model.py
ADDED
@@ -0,0 +1,775 @@
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch BERT model."""

from __future__ import absolute_import, division, print_function, unicode_literals

import copy
import json
import logging
import math
import os
import shutil
import tarfile
import tempfile
import sys
from io import open

import torch
from torch import nn
from torch.nn import CrossEntropyLoss

from utils.file_utils import cached_path

logger = logging.getLogger(__name__)

PRETRAINED_MODEL_ARCHIVE_MAP = {
    'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz",
    'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased.tar.gz",
    'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased.tar.gz",
    'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased.tar.gz",
    'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased.tar.gz",
    'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased.tar.gz",
    'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese.tar.gz",
}
CONFIG_NAME = 'bert_config.json'
WEIGHTS_NAME = 'pytorch_model.bin'
TF_WEIGHTS_NAME = 'model.ckpt'


def load_tf_weights_in_bert(model, tf_checkpoint_path):
    """ Load tf checkpoints in a pytorch model
    """
    try:
        import re
        import numpy as np
        import tensorflow as tf
    except ImportError:
        print("Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see "
              "https://www.tensorflow.org/install/ for installation instructions.")
        raise
    tf_path = os.path.abspath(tf_checkpoint_path)
    print("Converting TensorFlow checkpoint from {}".format(tf_path))
    # Load weights from TF model
    init_vars = tf.train.list_variables(tf_path)
    names = []
    arrays = []
    for name, shape in init_vars:
        print("Loading TF weight {} with shape {}".format(name, shape))
        array = tf.train.load_variable(tf_path, name)
        names.append(name)
        arrays.append(array)

    for name, array in zip(names, arrays):
        name = name.split('/')
        # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
        # which are not required for using pretrained model
        if any(n in ["adam_v", "adam_m"] for n in name):
            print("Skipping {}".format("/".join(name)))
            continue
        pointer = model
        for m_name in name:
            if re.fullmatch(r'[A-Za-z]+_\d+', m_name):
                l = re.split(r'_(\d+)', m_name)
            else:
                l = [m_name]
            if l[0] == 'kernel' or l[0] == 'gamma':
                pointer = getattr(pointer, 'weight')
            elif l[0] == 'output_bias' or l[0] == 'beta':
                pointer = getattr(pointer, 'bias')
            elif l[0] == 'output_weights':
                pointer = getattr(pointer, 'weight')
            else:
                pointer = getattr(pointer, l[0])
            if len(l) >= 2:
                num = int(l[1])
                pointer = pointer[num]
        if m_name[-11:] == '_embeddings':
            pointer = getattr(pointer, 'weight')
        elif m_name == 'kernel':
            array = np.transpose(array)
        try:
            assert pointer.shape == array.shape
        except AssertionError as e:
            e.args += (pointer.shape, array.shape)
            raise
        print("Initialize PyTorch weight {}".format(name))
        pointer.data = torch.from_numpy(array)
    return model


def gelu(x):
    """Implementation of the gelu activation function.
    For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
    0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
    Also see https://arxiv.org/abs/1606.08415
    """
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))


def swish(x):
    return x * torch.sigmoid(x)


ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish}


class BertConfig(object):
    """Configuration class to store the configuration of a `BertModel`.
    """
    def __init__(self,
                 vocab_size_or_config_json_file,
                 hidden_size=768,
                 num_hidden_layers=12,
                 num_attention_heads=12,
                 intermediate_size=3072,
                 hidden_act="gelu",
                 hidden_dropout_prob=0.1,
                 attention_probs_dropout_prob=0.1,
                 max_position_embeddings=512,
                 type_vocab_size=2,
                 initializer_range=0.02):
        """Constructs BertConfig.

        Args:
            vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `BertModel`.
            hidden_size: Size of the encoder layers and the pooler layer.
            num_hidden_layers: Number of hidden layers in the Transformer encoder.
            num_attention_heads: Number of attention heads for each attention layer in
                the Transformer encoder.
            intermediate_size: The size of the "intermediate" (i.e., feed-forward)
                layer in the Transformer encoder.
            hidden_act: The non-linear activation function (function or string) in the
                encoder and pooler. If string, "gelu", "relu" and "swish" are supported.
            hidden_dropout_prob: The dropout probabilitiy for all fully connected
                layers in the embeddings, encoder, and pooler.
            attention_probs_dropout_prob: The dropout ratio for the attention
                probabilities.
            max_position_embeddings: The maximum sequence length that this model might
                ever be used with. Typically set this to something large just in case
                (e.g., 512 or 1024 or 2048).
            type_vocab_size: The vocabulary size of the `token_type_ids` passed into
                `BertModel`.
            initializer_range: The sttdev of the truncated_normal_initializer for
                initializing all weight matrices.
        """
        if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
                                                               and isinstance(vocab_size_or_config_json_file, unicode)):
            with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader:
                json_config = json.loads(reader.read())
            for key, value in json_config.items():
                self.__dict__[key] = value
        elif isinstance(vocab_size_or_config_json_file, int):
            self.vocab_size = vocab_size_or_config_json_file
            self.hidden_size = hidden_size
            self.num_hidden_layers = num_hidden_layers
            self.num_attention_heads = num_attention_heads
            self.hidden_act = hidden_act
            self.intermediate_size = intermediate_size
            self.hidden_dropout_prob = hidden_dropout_prob
            self.attention_probs_dropout_prob = attention_probs_dropout_prob
            self.max_position_embeddings = max_position_embeddings
            self.type_vocab_size = type_vocab_size
            self.initializer_range = initializer_range
        else:
            raise ValueError("First argument must be either a vocabulary size (int)"
                             "or the path to a pretrained model config file (str)")

    @classmethod
    def from_dict(cls, json_object):
        """Constructs a `BertConfig` from a Python dictionary of parameters."""
        config = BertConfig(vocab_size_or_config_json_file=-1)
        for key, value in json_object.items():
            config.__dict__[key] = value
        return config

    @classmethod
    def from_json_file(cls, json_file):
        """Constructs a `BertConfig` from a json file of parameters."""
        with open(json_file, "r", encoding='utf-8') as reader:
            text = reader.read()
        return cls.from_dict(json.loads(text))

    def __repr__(self):
        return str(self.to_json_string())

    def to_dict(self):
        """Serializes this instance to a Python dictionary."""
        output = copy.deepcopy(self.__dict__)
        return output

    def to_json_string(self):
        """Serializes this instance to a JSON string."""
        return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"


try:
    from apex.normalization.fused_layer_norm import FusedLayerNorm as BertLayerNorm
except ImportError:
    print("Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.")

    class BertLayerNorm(nn.Module):
        def __init__(self, hidden_size, eps=1e-12):
            """Construct a layernorm module in the TF style (epsilon inside the square root).
            """
            super(BertLayerNorm, self).__init__()
            self.weight = nn.Parameter(torch.ones(hidden_size))
            self.bias = nn.Parameter(torch.zeros(hidden_size))
            self.variance_epsilon = eps

        def forward(self, x):
            u = x.mean(-1, keepdim=True)
            s = (x - u).pow(2).mean(-1, keepdim=True)
            x = (x - u) / torch.sqrt(s + self.variance_epsilon)
            return self.weight * x + self.bias


class BertEmbeddings(nn.Module):
    """Construct the embeddings from word, position and token_type embeddings.
    """
    def __init__(self, config):
        super(BertEmbeddings, self).__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)

        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file
        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, input_ids, token_type_ids=None):
        seq_length = input_ids.size(1)
        position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
        position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        words_embeddings = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        embeddings = words_embeddings + position_embeddings + token_type_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings


class BertSelfAttention(nn.Module):
    def __init__(self, config):
        super(BertSelfAttention, self).__init__()
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (config.hidden_size, config.num_attention_heads))
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_states, attention_mask):
        mixed_query_layer = self.query(hidden_states)
        mixed_key_layer = self.key(hidden_states)
        mixed_value_layer = self.value(hidden_states)

        query_layer = self.transpose_for_scores(mixed_query_layer)
        key_layer = self.transpose_for_scores(mixed_key_layer)
        value_layer = self.transpose_for_scores(mixed_value_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
        attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.Softmax(dim=-1)(attention_scores)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)
        return context_layer


class BertSelfOutput(nn.Module):
    def __init__(self, config):
        super(BertSelfOutput, self).__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states


class BertAttention(nn.Module):
    def __init__(self, config):
        super(BertAttention, self).__init__()
        self.self = BertSelfAttention(config)
        self.output = BertSelfOutput(config)

    def forward(self, input_tensor, attention_mask):
        self_output = self.self(input_tensor, attention_mask)
        attention_output = self.output(self_output, input_tensor)
        return attention_output


class BertIntermediate(nn.Module):
    def __init__(self, config):
        super(BertIntermediate, self).__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        if isinstance(config.hidden_act, str) or (sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        return hidden_states


class BertOutput(nn.Module):
    def __init__(self, config):
        super(BertOutput, self).__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states


class BertLayer(nn.Module):
    def __init__(self, config):
        super(BertLayer, self).__init__()
        self.attention = BertAttention(config)
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)

    def forward(self, hidden_states, attention_mask):
        attention_output = self.attention(hidden_states, attention_mask)
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        return layer_output


class BertEncoder(nn.Module):
    def __init__(self, config):
        super(BertEncoder, self).__init__()
        layer = BertLayer(config)
        self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])

    def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True):
        all_encoder_layers = []
        for layer_module in self.layer:
            hidden_states = layer_module(hidden_states, attention_mask)
            if output_all_encoded_layers:
                all_encoder_layers.append(hidden_states)
        if not output_all_encoded_layers:
            all_encoder_layers.append(hidden_states)
        return all_encoder_layers


class BertPooler(nn.Module):
    def __init__(self, config):
        super(BertPooler, self).__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token.
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output


class BertPredictionHeadTransform(nn.Module):
    def __init__(self, config):
        super(BertPredictionHeadTransform, self).__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        if isinstance(config.hidden_act, str) or (sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)):
            self.transform_act_fn = ACT2FN[config.hidden_act]
        else:
            self.transform_act_fn = config.hidden_act
        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.transform_act_fn(hidden_states)
        hidden_states = self.LayerNorm(hidden_states)
        return hidden_states


class BertLMPredictionHead(nn.Module):
    def __init__(self, config, bert_model_embedding_weights):
        super(BertLMPredictionHead, self).__init__()
        self.transform = BertPredictionHeadTransform(config)

        # The output weights are the same as the input embeddings, but there is
        # an output-only bias for each token.
        self.decoder = nn.Linear(bert_model_embedding_weights.size(1),
                                 bert_model_embedding_weights.size(0),
                                 bias=False)
        self.decoder.weight = bert_model_embedding_weights
        self.bias = nn.Parameter(torch.zeros(bert_model_embedding_weights.size(0)))

    def forward(self, hidden_states):
        hidden_states = self.transform(hidden_states)
        hidden_states = self.decoder(hidden_states) + self.bias
        return hidden_states


class BertOnlyMLMHead(nn.Module):
    def __init__(self, config, bert_model_embedding_weights):
        super(BertOnlyMLMHead, self).__init__()
        self.predictions = BertLMPredictionHead(config, bert_model_embedding_weights)

    def forward(self, sequence_output):
        prediction_scores = self.predictions(sequence_output)
        return prediction_scores


class BertOnlyNSPHead(nn.Module):
    def __init__(self, config):
        super(BertOnlyNSPHead, self).__init__()
        self.seq_relationship = nn.Linear(config.hidden_size, 2)

    def forward(self, pooled_output):
        seq_relationship_score = self.seq_relationship(pooled_output)
        return seq_relationship_score


class BertPreTrainingHeads(nn.Module):
    def __init__(self, config, bert_model_embedding_weights):
        super(BertPreTrainingHeads, self).__init__()
        self.predictions = BertLMPredictionHead(config, bert_model_embedding_weights)
        self.seq_relationship = nn.Linear(config.hidden_size, 2)

    def forward(self, sequence_output, pooled_output):
        prediction_scores = self.predictions(sequence_output)
        seq_relationship_score = self.seq_relationship(pooled_output)
        return prediction_scores, seq_relationship_score


class BertPreTrainedModel(nn.Module):
    """ An abstract class to handle weights initialization and
        a simple interface for dowloading and loading pretrained models.
    """
    def __init__(self, config, *inputs, **kwargs):
        super(BertPreTrainedModel, self).__init__()
        if not isinstance(config, BertConfig):
            raise ValueError(
                "Parameter config in `{}(config)` should be an instance of class `BertConfig`. "
                "To create a model from a Google pretrained model use "
                "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
                    self.__class__.__name__, self.__class__.__name__
                ))
        self.config = config

    def init_bert_weights(self, module):
        """ Initialize the weights.
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, BertLayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, state_dict=None, cache_dir=None,
                        from_tf=False, *inputs, **kwargs):
        """
        Instantiate a BertPreTrainedModel from a pre-trained model file or a pytorch state dict.
        Download and cache the pre-trained model file if needed.

        Params:
            pretrained_model_name_or_path: either:
                - a str with the name of a pre-trained model to load selected in the list of:
                    . `bert-base-uncased`
                    . `bert-large-uncased`
                    . `bert-base-cased`
                    . `bert-large-cased`
                    . `bert-base-multilingual-uncased`
                    . `bert-base-multilingual-cased`
                    . `bert-base-chinese`
                - a path or url to a pretrained model archive containing:
                    . `bert_config.json` a configuration file for the model
                    . `pytorch_model.bin` a PyTorch dump of a BertForPreTraining instance
                - a path or url to a pretrained model archive containing:
                    . `bert_config.json` a configuration file for the model
                    . `model.chkpt` a TensorFlow checkpoint
            from_tf: should we load the weights from a locally saved TensorFlow checkpoint
            cache_dir: an optional path to a folder in which the pre-trained models will be cached.
            state_dict: an optional state dictionnary (collections.OrderedDict object) to use instead of Google pre-trained models
            *inputs, **kwargs: additional input for the specific Bert class
                (ex: num_labels for BertForSequenceClassification)
        """
        if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
            archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path]
        else:
            archive_file = pretrained_model_name_or_path
        # redirect to the cache, if necessary
        try:
            resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
        except EnvironmentError:
            logger.error(
                "Model name '{}' was not found in model name list ({}). "
                "We assumed '{}' was a path or url but couldn't find any file "
                "associated to this path or url.".format(
                    pretrained_model_name_or_path,
                    ', '.join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()),
                    archive_file))
            return None
        if resolved_archive_file == archive_file:
            logger.info("loading archive file {}".format(archive_file))
        else:
            logger.info("loading archive file {} from cache at {}".format(
                archive_file, resolved_archive_file))
        tempdir = None
        if os.path.isdir(resolved_archive_file) or from_tf:
            serialization_dir = resolved_archive_file
        else:
            # Extract archive to temp dir
            tempdir = tempfile.mkdtemp()
            logger.info("extracting archive file {} to temp dir {}".format(
                resolved_archive_file, tempdir))
            with tarfile.open(resolved_archive_file, 'r:gz') as archive:
                archive.extractall(tempdir)
            serialization_dir = tempdir
        # Load config
        config_file = os.path.join(serialization_dir, CONFIG_NAME)
        config = BertConfig.from_json_file(config_file)
        logger.info("Model config {}".format(config))
        # Instantiate model.
        model = cls(config, *inputs, **kwargs)
        if state_dict is None and not from_tf:
            weights_path = os.path.join(serialization_dir, WEIGHTS_NAME)
            state_dict = torch.load(weights_path, map_location='cpu' if not torch.cuda.is_available() else None)
        if tempdir:
            # Clean up temp dir
            shutil.rmtree(tempdir)
        if from_tf:
            # Directly load from a TensorFlow checkpoint
            weights_path = os.path.join(serialization_dir, TF_WEIGHTS_NAME)
            return load_tf_weights_in_bert(model, weights_path)
        # Load from a PyTorch state_dict
        old_keys = []
        new_keys = []
        for key in state_dict.keys():
            new_key = None
            if 'gamma' in key:
                new_key = key.replace('gamma', 'weight')
            if 'beta' in key:
                new_key = key.replace('beta', 'bias')
            if new_key:
                old_keys.append(key)
                new_keys.append(new_key)
        for old_key, new_key in zip(old_keys, new_keys):
            state_dict[new_key] = state_dict.pop(old_key)

        missing_keys = []
        unexpected_keys = []
        error_msgs = []
        # copy state_dict so _load_from_state_dict can modify it
        metadata = getattr(state_dict, '_metadata', None)
        state_dict = state_dict.copy()
        if metadata is not None:
            state_dict._metadata = metadata

        def load(module, prefix=''):
            local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
            module._load_from_state_dict(
                state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
            for name, child in module._modules.items():
                if child is not None:
                    load(child, prefix + name + '.')
        start_prefix = ''
        if not hasattr(model, 'bert') and any(s.startswith('bert.') for s in state_dict.keys()):
            start_prefix = 'bert.'
        load(model, prefix=start_prefix)
        if len(missing_keys) > 0:
            logger.info("Weights of {} not initialized from pretrained model: {}".format(
                model.__class__.__name__, missing_keys))
        if len(unexpected_keys) > 0:
            logger.info("Weights from pretrained model not used in {}: {}".format(
                model.__class__.__name__, unexpected_keys))
        if len(error_msgs) > 0:
            raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
                model.__class__.__name__, "\n\t".join(error_msgs)))
        return model


class BertModel(BertPreTrainedModel):
    """BERT model ("Bidirectional Embedding Representations from a Transformer").

    Params:
        config: a BertConfig class instance with the configuration to build a new model

    Inputs:
        `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
            with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts
            `extract_features.py`, `run_classifier.py` and `run_squad.py`)
        `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
            types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
            a `sentence B` token (see BERT paper for more details).
        `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
            selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
            input sequence length in the current batch. It's the mask that we typically use for attention when
            a batch has varying length sentences.
        `output_all_encoded_layers`: boolean which controls the content of the `encoded_layers` output as described below. Default: `True`.

    Outputs: Tuple of (encoded_layers, pooled_output)
        `encoded_layers`: controled by `output_all_encoded_layers` argument:
            - `output_all_encoded_layers=True`: outputs a list of the full sequences of encoded-hidden-states at the end
                of each attention block (i.e. 12 full sequences for BERT-base, 24 for BERT-large), each
                encoded-hidden-state is a torch.FloatTensor of size [batch_size, sequence_length, hidden_size],
            - `output_all_encoded_layers=False`: outputs only the full sequence of hidden-states corresponding
                to the last attention block of shape [batch_size, sequence_length, hidden_size],
        `pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] which is the output of a
            classifier pretrained on top of the hidden state associated to the first character of the
            input (`CLS`) to train on the Next-Sentence task (see BERT's paper).

    Example usage:
    ```python
    # Already been converted into WordPiece token ids
    input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
    input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
    token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])

    config = modeling.BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
        num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)

    model = modeling.BertModel(config=config)
    all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
    ```
    """
    def __init__(self, config):
        super(BertModel, self).__init__(config)
        self.embeddings = BertEmbeddings(config)
        self.encoder = BertEncoder(config)
        self.pooler = BertPooler(config)
        self.apply(self.init_bert_weights)

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, output_all_encoded_layers=True):
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        # We create a 3D attention mask from a 2D tensor mask.
        # Sizes are [batch_size, 1, 1, to_seq_length]
        # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
        # this attention mask is more simple than the triangular masking of causal attention
        # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

        embedding_output = self.embeddings(input_ids, token_type_ids)
        encoded_layers = self.encoder(embedding_output,
                                      extended_attention_mask,
                                      output_all_encoded_layers=output_all_encoded_layers)
        sequence_output = encoded_layers[-1]
        pooled_output = self.pooler(sequence_output)
        if not output_all_encoded_layers:
            encoded_layers = encoded_layers[-1]
        return encoded_layers, pooled_output


class BertForSequenceEncoder(BertPreTrainedModel):
    """BERT model for classification.
    This module is composed of the BERT model with a linear layer on top of
    the pooled output.
    Params:
        `config`: a BertConfig class instance with the configuration to build a new model.
        `num_labels`: the number of classes for the classifier. Default = 2.
    Inputs:
        `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
            with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts
            `extract_features.py`, `run_classifier.py` and `run_squad.py`)
        `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
            types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
            a `sentence B` token (see BERT paper for more details).
        `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
            selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
            input sequence length in the current batch. It's the mask that we typically use for attention when
            a batch has varying length sentences.
        `labels`: labels for the classification output: torch.LongTensor of shape [batch_size]
            with indices selected in [0, ..., num_labels].
    Outputs:
        if `labels` is not `None`:
            Outputs the CrossEntropy classification loss of the output with the labels.
        if `labels` is `None`:
            Outputs the classification logits of shape [batch_size, num_labels].
    Example usage:
    ```python
    # Already been converted into WordPiece token ids
    input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
    input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
    token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
    config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
        num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
    num_labels = 2
    model = BertForSequenceClassification(config, num_labels)
    logits = model(input_ids, token_type_ids, input_mask)
    ```
    """
    def __init__(self, config):
        super(BertForSequenceEncoder, self).__init__(config)
        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.apply(self.init_bert_weights)

    def forward(self, input_ids, attention_mask, token_type_ids):
        output, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
        output = self.dropout(output)
        pooled_output = self.dropout(pooled_output)
        return output, pooled_output
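The main entry point of this file is `BertForSequenceEncoder`, presumably consumed by the other modules uploaded here (such as utils/sentence_retrieval_model.py). As a quick orientation aid only, and not part of the upload, here is a minimal sketch of how it can be exercised. It assumes the repository root is on `PYTHONPATH` so `utils.bert_model` imports, and it uses a small randomly initialised config rather than downloading the `bert-base-uncased` archive via `from_pretrained`.

```python
# Sketch (not in the upload): run a toy batch through BertForSequenceEncoder.
import torch

from utils.bert_model import BertConfig, BertForSequenceEncoder

# Tiny config with random weights so the sketch runs without any download.
config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
                    num_hidden_layers=2, num_attention_heads=12,
                    intermediate_size=3072)
model = BertForSequenceEncoder(config)
model.eval()  # disable dropout for a deterministic forward pass

# Inputs are assumed to already be WordPiece ids, as in the docstrings above.
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
attention_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
token_type_ids = torch.zeros_like(input_ids)

with torch.no_grad():
    sequence_output, pooled_output = model(input_ids, attention_mask, token_type_ids)

print(sequence_output.shape)  # [2, 3, 768] - per-token encodings
print(pooled_output.shape)    # [2, 768]   - [CLS]-based sequence encoding
```

With pretrained weights one would instead call `BertForSequenceEncoder.from_pretrained('bert-base-uncased')`, which downloads and caches the archive through `cached_path` from utils/file_utils.py.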
utils/callbacks.py
ADDED
@@ -0,0 +1,140 @@
import logging
import os
from pathlib import Path

import numpy as np
import pytorch_lightning as pl
import torch
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.utilities import rank_zero_only

from utils.utils_verbalisation_module import save_json
from pytorch_lightning.utilities import rank_zero_info


def count_trainable_parameters(model):
    model_parameters = filter(lambda p: p.requires_grad, model.parameters())
    params = sum([np.prod(p.size()) for p in model_parameters])
    return params


logger = logging.getLogger(__name__)


class Seq2SeqLoggingCallback(pl.Callback):
    def on_batch_end(self, trainer, pl_module):
        lrs = {f"lr_group_{i}": param["lr"] for i, param in enumerate(pl_module.trainer.optimizers[0].param_groups)}
        pl_module.logger.log_metrics(lrs)

    @rank_zero_only
    def _write_logs(
        self, trainer: pl.Trainer, pl_module: pl.LightningModule, type_path: str, save_generations=True
    ) -> None:
        logger.info(f"***** {type_path} results at step {trainer.global_step:05d} *****")
        metrics = trainer.callback_metrics
        # print(metrics.keys())
        new_metrics = {}
        ms = ["log", "progress_bar", "preds"]
        for k, v in metrics.items():
            ver = True
            for m in ms:
                if m in k:
                    ver = False
                    break
            if ver:
                new_metrics[k] = v

        print(new_metrics)
        trainer.logger.log_metrics(new_metrics)
        # Log results
        od = Path(pl_module.hparams.output_dir)
        if type_path == "test":
            results_file = od / "test_results.txt"
            generations_file = od / "test_generations.txt"
        else:
            # this never gets hit. I prefer not to save intermediate generations, and results are in metrics.json
            # If people want this it will be easy enough to add back.
            results_file = od / f"{type_path}_results/{trainer.global_step:05d}.txt"
            generations_file = od / f"{type_path}_generations/{trainer.global_step:05d}.txt"
            results_file.parent.mkdir(exist_ok=True)
            generations_file.parent.mkdir(exist_ok=True)
        with open(results_file, "a+") as writer:
            for key in sorted(metrics):
                if key in ["log", "progress_bar", "preds"]:
                    continue
                try:
                    val = metrics[key]
                    if isinstance(val, torch.Tensor):
                        val = val.item()
                    msg = f"{key}: {val:.6f}\n"
                    writer.write(msg)
                except:
                    pass

        if not save_generations:
            return

        if "preds" in metrics:
            content = "\n".join(metrics["preds"])
            generations_file.open("w+").write(content)

    @rank_zero_only
    def on_train_start(self, trainer, pl_module):
        try:
            npars = pl_module.model.model.num_parameters()
        except AttributeError:
            npars = pl_module.model.num_parameters()

        n_trainable_pars = count_trainable_parameters(pl_module)
        # mp stands for million parameters
        trainer.logger.log_metrics({"n_params": npars, "mp": npars / 1e6, "grad_mp": n_trainable_pars / 1e6})

    @rank_zero_only
    def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
        save_json(pl_module.metrics, pl_module.metrics_save_path)
        return self._write_logs(trainer, pl_module, "test")

    @rank_zero_only
    def on_validation_end(self, trainer: pl.Trainer, pl_module):
        save_json(pl_module.metrics, pl_module.metrics_save_path)

        rank_zero_info("***** Validation results *****")
        metrics = trainer.callback_metrics
        # Log results
        for key in sorted(metrics):
            if key not in ["log", "progress_bar", "preds"]:
                rank_zero_info("{} = {}\n".format(key, str(metrics[key])))
        # Uncommenting this will save val generations
        # return self._write_logs(trainer, pl_module, "valid")


def get_checkpoint_callback(output_dir, metric, save_top_k=1, lower_is_better=False):
    """Saves the best model by validation ROUGE2 score."""
    if metric == "rouge2":
        exp = "{val_avg_rouge2:.4f}-{step_count}"
    elif metric == "bleu":
        exp = "{val_avg_bleu:.4f}-{step_count}"
    elif metric == "loss":
        exp = "{val_avg_loss:.4f}-{step_count}"
    else:
        raise NotImplementedError(
            f"seq2seq callbacks only support rouge2, bleu and loss, got {metric}, You can make your own by adding to this function."
        )

    checkpoint_callback = ModelCheckpoint(
        filepath=os.path.join(output_dir, exp),
        monitor=f"val_{metric}",
        mode="min" if "loss" in metric else "max",
        save_top_k=save_top_k,
        period=0,  # maybe save a checkpoint every time val is run, not just end of epoch.
    )
    return checkpoint_callback


def get_early_stopping_callback(metric, patience):
    return EarlyStopping(
        monitor=f"val_{metric}",  # does this need avg?
        mode="min" if "loss" in metric else "max",
        patience=patience,
        verbose=True,
    )
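These helpers target an older pytorch-lightning API (note `ModelCheckpoint(filepath=..., period=...)`). The following is a hedged sketch, not part of the upload, of how they would typically be attached to a `Trainer`; the `output_dir` path and the `module` object are hypothetical, and the `checkpoint_callback`/`early_stop_callback` keyword arguments assume that same older Lightning version.

```python
# Sketch (not in the upload): wiring the callbacks into a Trainer, assuming the
# older pytorch-lightning API these helpers were written against and a
# LightningModule (`module`) that logs a `val_loss` metric.
import pytorch_lightning as pl

from utils.callbacks import (
    Seq2SeqLoggingCallback,
    get_checkpoint_callback,
    get_early_stopping_callback,
)

output_dir = "outputs/verbalisation"  # hypothetical output directory

checkpoint = get_checkpoint_callback(output_dir, metric="loss", save_top_k=1)
early_stop = get_early_stopping_callback(metric="loss", patience=3)

trainer = pl.Trainer(
    max_epochs=10,
    callbacks=[Seq2SeqLoggingCallback()],  # logs learning rates and val/test results
    checkpoint_callback=checkpoint,        # keeps the best checkpoint by val_loss
    early_stop_callback=early_stop,        # stops training when val_loss plateaus
)
# trainer.fit(module)  # `module` would be the seq2seq LightningModule defined elsewhere
```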
utils/file_utils.py
ADDED
@@ -0,0 +1,249 @@
"""
Utilities for working with the local dataset cache.
This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp
Copyright by the AllenNLP authors.
"""
from __future__ import (absolute_import, division, print_function, unicode_literals)

import json
import logging
import os
import shutil
import tempfile
from functools import wraps
from hashlib import sha256
import sys
from io import open

import boto3
import requests
from botocore.exceptions import ClientError
from tqdm import tqdm

try:
    from urllib.parse import urlparse
except ImportError:
    from urlparse import urlparse

try:
    from pathlib import Path
    PYTORCH_PRETRAINED_BERT_CACHE = Path(os.getenv('PYTORCH_PRETRAINED_BERT_CACHE',
                                                   Path.home() / '.pytorch_pretrained_bert'))
except AttributeError:
    PYTORCH_PRETRAINED_BERT_CACHE = os.getenv('PYTORCH_PRETRAINED_BERT_CACHE',
                                              os.path.join(os.path.expanduser("~"), '.pytorch_pretrained_bert'))

logger = logging.getLogger(__name__) # pylint: disable=invalid-name


def url_to_filename(url, etag=None):
    """
    Convert `url` into a hashed filename in a repeatable way.
    If `etag` is specified, append its hash to the url's, delimited
    by a period.
    """
    url_bytes = url.encode('utf-8')
    url_hash = sha256(url_bytes)
    filename = url_hash.hexdigest()

    if etag:
        etag_bytes = etag.encode('utf-8')
        etag_hash = sha256(etag_bytes)
        filename += '.' + etag_hash.hexdigest()

    return filename


def filename_to_url(filename, cache_dir=None):
    """
    Return the url and etag (which may be ``None``) stored for `filename`.
    Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist.
    """
    if cache_dir is None:
        cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
    if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    cache_path = os.path.join(cache_dir, filename)
    if not os.path.exists(cache_path):
        raise EnvironmentError("file {} not found".format(cache_path))

    meta_path = cache_path + '.json'
    if not os.path.exists(meta_path):
        raise EnvironmentError("file {} not found".format(meta_path))

    with open(meta_path, encoding="utf-8") as meta_file:
        metadata = json.load(meta_file)
    url = metadata['url']
    etag = metadata['etag']

    return url, etag


def cached_path(url_or_filename, cache_dir=None):
    """
    Given something that might be a URL (or might be a local path),
    determine which. If it's a URL, download the file and cache it, and
    return the path to the cached file. If it's already a local path,
    make sure the file exists and then return the path.
    """
    if cache_dir is None:
        cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
    if sys.version_info[0] == 3 and isinstance(url_or_filename, Path):
        url_or_filename = str(url_or_filename)
    if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    parsed = urlparse(url_or_filename)

    if parsed.scheme in ('http', 'https', 's3'):
        # URL, so get it from the cache (downloading if necessary)
        return get_from_cache(url_or_filename, cache_dir)
    elif os.path.exists(url_or_filename):
        # File, and it exists.
        return url_or_filename
    elif parsed.scheme == '':
        # File, but it doesn't exist.
        raise EnvironmentError("file {} not found".format(url_or_filename))
    else:
        # Something unknown
        raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename))


def split_s3_path(url):
    """Split a full s3 path into the bucket name and path."""
    parsed = urlparse(url)
    if not parsed.netloc or not parsed.path:
        raise ValueError("bad s3 path {}".format(url))
    bucket_name = parsed.netloc
    s3_path = parsed.path
    # Remove '/' at beginning of path.
    if s3_path.startswith("/"):
        s3_path = s3_path[1:]
    return bucket_name, s3_path


def s3_request(func):
    """
    Wrapper function for s3 requests in order to create more helpful error
    messages.
    """

    @wraps(func)
    def wrapper(url, *args, **kwargs):
        try:
            return func(url, *args, **kwargs)
        except ClientError as exc:
            if int(exc.response["Error"]["Code"]) == 404:
                raise EnvironmentError("file {} not found".format(url))
            else:
                raise

    return wrapper


@s3_request
def s3_etag(url):
    """Check ETag on S3 object."""
    s3_resource = boto3.resource("s3")
    bucket_name, s3_path = split_s3_path(url)
    s3_object = s3_resource.Object(bucket_name, s3_path)
    return s3_object.e_tag


@s3_request
def s3_get(url, temp_file):
    """Pull a file directly from S3."""
    s3_resource = boto3.resource("s3")
    bucket_name, s3_path = split_s3_path(url)
    s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file)


def http_get(url, temp_file):
    req = requests.get(url, stream=True)
    content_length = req.headers.get('Content-Length')
    total = int(content_length) if content_length is not None else None
    progress = tqdm(unit="B", total=total)
    for chunk in req.iter_content(chunk_size=1024):
        if chunk: # filter out keep-alive new chunks
            progress.update(len(chunk))
            temp_file.write(chunk)
    progress.close()


def get_from_cache(url, cache_dir=None):
    """
    Given a URL, look for the corresponding dataset in the local cache.
    If it's not there, download it. Then return the path to the cached file.
    """
    if cache_dir is None:
        cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
    if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    if not os.path.exists(cache_dir):
        os.makedirs(cache_dir)

    # Get eTag to add to filename, if it exists.
    if url.startswith("s3://"):
        etag = s3_etag(url)
    else:
        response = requests.head(url, allow_redirects=True)
        if response.status_code != 200:
            raise IOError("HEAD request failed for url {} with status code {}"
                          .format(url, response.status_code))
        etag = response.headers.get("ETag")

    filename = url_to_filename(url, etag)

    # get cache path to put the file
    cache_path = os.path.join(cache_dir, filename)

    if not os.path.exists(cache_path):
        # Download to temporary file, then copy to cache dir once finished.
        # Otherwise you get corrupt cache entries if the download gets interrupted.
        with tempfile.NamedTemporaryFile() as temp_file:
            logger.info("%s not found in cache, downloading to %s", url, temp_file.name)

            # GET file object
            if url.startswith("s3://"):
                s3_get(url, temp_file)
            else:
                http_get(url, temp_file)

            # we are copying the file before closing it, so flush to avoid truncation
            temp_file.flush()
            # shutil.copyfileobj() starts at the current position, so go to the start
            temp_file.seek(0)

            logger.info("copying %s to cache at %s", temp_file.name, cache_path)
            with open(cache_path, 'wb') as cache_file:
                shutil.copyfileobj(temp_file, cache_file)

            logger.info("creating metadata file for %s", cache_path)
            meta = {'url': url, 'etag': etag}
            meta_path = cache_path + '.json'
            with open(meta_path, 'w', encoding="utf-8") as meta_file:
                json.dump(meta, meta_file)

            logger.info("removing temp file %s", temp_file.name)

    return cache_path


def read_set_from_file(filename):
    '''
    Extract a de-duped collection (set) of text from a file.
    Expected file format is one item per line.
    '''
    collection = set()
    with open(filename, 'r', encoding='utf-8') as file_:
        for line in file_:
            collection.add(line.rstrip())
    return collection


def get_file_extension(path, dot=True, lower=True):
    ext = os.path.splitext(path)[1]
    ext = ext if dot else ext[1:]
    return ext.lower() if lower else ext
utils/finetune.py
ADDED
@@ -0,0 +1,633 @@
#!/usr/bin/env python

import argparse
import glob
import logging
import os
import sys
import time
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Tuple
import pdb

import numpy as np
import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader

from pytorch_lightning.utilities import rank_zero_info

from utils.callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback, get_early_stopping_callback
from transformers import MBartTokenizer, T5ForConditionalGeneration

from transformers.models.bart.modeling_bart import shift_tokens_right
from utils.utils_verbalisation_module import (
    ROUGE_KEYS,
    LegacySeq2SeqDataset,
    Seq2SeqDataset,
    assert_all_frozen,
    calculate_bleu,
    calculate_rouge,
    flatten_list,
    freeze_embeds,
    freeze_params,
    label_smoothed_nll_loss,
    lmap,
    pickle_save,
    save_json,
    use_task_specific_params,
)

from utils.utils_graph2text import convert_text, eval_meteor, eval_bleu, eval_chrf, eval_meteor_test_webnlg, eval_chrf_test_webnlg

# need the parent dir module
sys.path.insert(2, str(Path(__file__).resolve().parents[1]))
from utils.lightning_base import BaseTransformer, add_generic_args, generic_train  # noqa


logger = logging.getLogger(__name__)


class SummarizationModule(BaseTransformer):
    mode = "summarization"
    loss_names = ["loss"]
    metric_names = ROUGE_KEYS
    default_val_metric = "rouge2"

    def __init__(self, hparams, **kwargs):
        if hparams.sortish_sampler and hparams.gpus > 1:
            hparams.replace_sampler_ddp = False
        elif hparams.max_tokens_per_batch is not None:
            if hparams.gpus > 1:
                raise NotImplementedError("Dynamic Batch size does not work for multi-gpu training")
            if hparams.sortish_sampler:
                raise ValueError("--sortish_sampler and --max_tokens_per_batch may not be used simultaneously")

        super().__init__(hparams, num_labels=None, mode=self.mode, **kwargs)
        #use_task_specific_params(self.model, "summarization")

        self.metrics_save_path = Path('base') / "metrics.json"
        self.hparams_save_path = Path('base') / "hparams.pkl"
        pickle_save(self.hparams, self.hparams_save_path)
        self.step_count = -2
        self.metrics = defaultdict(list)
        self.model_type = self.config.model_type
        self.vocab_size = self.config.tgt_vocab_size if self.model_type == "fsmt" else self.config.vocab_size

        if 't5' in hparams.model_name_or_path:
            self.model.config.prefix = 'translate Graph to English: '
        self.dataset_kwargs: dict = dict(
            data_dir=self.hparams.data_dir,
            max_source_length=self.hparams.max_source_length,
            prefix=self.model.config.prefix or "",
        )
        n_observations_per_split = {
            "train": self.hparams.n_train,
            "val": self.hparams.n_val,
            "test_seen": self.hparams.n_test,
            "test_unseen": self.hparams.n_test,
            "test_both": self.hparams.n_test,
        }
        self.n_obs = {k: v if v >= 0 else None for k, v in n_observations_per_split.items()}

        self.target_lens = {
            "train": self.hparams.max_target_length,
            "val": self.hparams.val_max_target_length,
            "test_seen": self.hparams.test_max_target_length,
            "test_unseen": self.hparams.test_max_target_length,
            "test_both": self.hparams.test_max_target_length,
        }
        assert self.target_lens["train"] <= self.target_lens["val"], f"target_lens: {self.target_lens}"
        assert self.target_lens["train"] <= self.target_lens["test_both"], f"target_lens: {self.target_lens}"
        if self.hparams.freeze_embeds:
            freeze_embeds(self.model)
        if self.hparams.freeze_encoder:
            freeze_params(self.model.get_encoder())
            assert_all_frozen(self.model.get_encoder())

        self.num_workers = hparams.num_workers
        self.decoder_start_token_id = None # default to config
        if self.model.config.decoder_start_token_id is None and isinstance(self.tokenizer, MBartTokenizer):
            self.decoder_start_token_id = self.tokenizer.lang_code_to_id[hparams.tgt_lang]
            self.model.config.decoder_start_token_id = self.decoder_start_token_id
        self.dataset_class = (
            Seq2SeqDataset if hasattr(self.tokenizer, "prepare_seq2seq_batch") else LegacySeq2SeqDataset
        )
        self.already_saved_batch = False
        self.eval_beams = self.model.config.num_beams if self.hparams.eval_beams is None else self.hparams.eval_beams
        if self.hparams.eval_max_gen_length is not None:
            self.eval_max_length = self.hparams.eval_max_gen_length
        else:
            self.eval_max_length = self.model.config.max_length
        self.val_metric = self.default_val_metric if self.hparams.val_metric is None else self.hparams.val_metric

    def save_readable_batch(self, batch: Dict[str, torch.Tensor]) -> Dict[str, List[str]]:
        """A debugging utility"""

        readable_batch = {
            k: self.tokenizer.batch_decode(v.tolist()) if "mask" not in k else v.shape for k, v in batch.items()
        }
        save_json(readable_batch, Path(self.output_dir) / "text_batch.json")

        tb = {}
        for k, v in batch.items():
            tb[k] = v.tolist()

        save_json(tb, Path(self.output_dir) / "tok_batch.json")

        self.already_saved_batch = True
        return readable_batch

    def forward(self, input_ids, **kwargs):
        return self.model(input_ids, **kwargs)

    def ids_to_clean_text(self, generated_ids: List[int]):
        gen_text = self.tokenizer.batch_decode(
            generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
        )
        return lmap(str.strip, gen_text)

    def _step(self, batch: dict) -> Tuple:
        pad_token_id = self.tokenizer.pad_token_id
        src_ids, src_mask = batch["input_ids"], batch["attention_mask"]
        if isinstance(self.model, T5ForConditionalGeneration):
            tgt_ids = batch["labels"]
            decoder_input_ids = self.model._shift_right(tgt_ids)
        else:
            #decoder_input_ids = shift_tokens_right(tgt_ids, pad_token_id)
            y = batch["labels"]
            decoder_input_ids = y[:, :-1].contiguous()
            tgt_ids = y[:, 1:].clone()
        if not self.already_saved_batch: # This would be slightly better if it only happened on rank zero
            batch["decoder_input_ids"] = decoder_input_ids
            self.save_readable_batch(batch)

        outputs = self(src_ids, attention_mask=src_mask, decoder_input_ids=decoder_input_ids, use_cache=False)
        lm_logits = outputs[0]
        if self.hparams.label_smoothing == 0:
            # Same behavior as modeling_bart.py, besides ignoring pad_token_id
            ce_loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id)

            assert lm_logits.shape[-1] == self.vocab_size
            loss = ce_loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), tgt_ids.view(-1))
        else:
            lprobs = torch.nn.functional.log_softmax(lm_logits, dim=-1)
            loss, nll_loss = label_smoothed_nll_loss(
                lprobs, tgt_ids, self.hparams.label_smoothing, ignore_index=pad_token_id
            )
        return (loss,)

    @property
    def pad(self) -> int:
        return self.tokenizer.pad_token_id

    def training_step(self, batch, batch_idx) -> Dict:
        loss_tensors = self._step(batch)

        logs = {name: loss for name, loss in zip(self.loss_names, loss_tensors)}
        # tokens per batch
        logs["tpb"] = batch["input_ids"].ne(self.pad).sum() + batch["labels"].ne(self.pad).sum()
        logs["bs"] = batch["input_ids"].shape[0]
        logs["src_pad_tok"] = batch["input_ids"].eq(self.pad).sum()
        logs["src_pad_frac"] = batch["input_ids"].eq(self.pad).float().mean()
        # TODO(SS): make a wandb summary metric for this
        return {"loss": loss_tensors[0], "log": logs}

    def validation_step(self, batch, batch_idx) -> Dict:
        return self._generative_step(batch)

    def validation_epoch_end(self, outputs, prefix="val") -> Dict:

        self.step_count += 1

        val_outputs_folder = "val_outputs"
        os.system("mkdir -p " + os.path.join(self.hparams.output_dir, val_outputs_folder))

        if prefix == "val":
            output_test_predictions_file = os.path.join(self.hparams.output_dir, val_outputs_folder, "validation_predictions_" +
                                                        str(self.step_count) + ".txt")
            output_test_targets_file = os.path.join(self.hparams.output_dir, val_outputs_folder, "validation_targets_" +
                                                    str(self.step_count) + ".txt")
            # write predictions and targets for later rouge evaluation.
            with open(output_test_predictions_file, "w") as p_writer, open(output_test_targets_file, "w") as t_writer:
                for output_batch in outputs:
                    p_writer.writelines(convert_text(s) + "\n" for s in output_batch["preds"])
                    t_writer.writelines(convert_text(s) + "\n" for s in output_batch["target"])
                p_writer.close()
                t_writer.close()

            bleu_info = eval_bleu(self.hparams.data_dir, output_test_predictions_file, 'val')

            rank_zero_info("%s bleu_info: %s", self.step_count, bleu_info)

            if bleu_info == -1:
                bleu_info = float(bleu_info)
            else:
                bleu_info = float(bleu_info.split(",")[0].split("BLEU = ")[1])

            losses = {k: torch.stack([x[k] for x in outputs]).mean() for k in self.loss_names}
            loss = losses["loss"]
            generative_metrics = {
                k: np.array([x[k] for x in outputs]).mean() for k in self.metric_names + ["gen_time", "gen_len"]
            }

            generative_metrics['bleu'] = bleu_info

            metric_val = (
                generative_metrics[self.val_metric] if self.val_metric in generative_metrics else losses[
                    self.val_metric]
            )
            metric_tensor: torch.FloatTensor = torch.tensor(metric_val).type_as(loss)
            generative_metrics.update({k: v.item() for k, v in losses.items()})
            losses.update(generative_metrics)
            all_metrics = {f"{prefix}_avg_{k}": x for k, x in losses.items()}
            all_metrics["step_count"] = self.step_count
            self.metrics[prefix].append(all_metrics) # callback writes this to self.metrics_save_path
            preds = flatten_list([x["preds"] for x in outputs])

            return {
                "bleu": bleu_info,
                "log": all_metrics,
                "preds": preds,
                f"{prefix}_loss": loss,
                f"{prefix}_{self.val_metric}": metric_tensor,
            }
        else:

            data_logs = {}
            for output in outputs:

                dataset_idx = output[0]['dataloader_idx']

                if dataset_idx == 0:
                    dataset_name = 'test_both'
                elif dataset_idx == 1:
                    dataset_name = 'test_seen'
                else:
                    dataset_name = 'test_unseen'

                if output[0]['bleu'] == -1:
                    bleu_info = float(output[0]['bleu'])
                else:
                    bleu_info = float(output[0]['bleu'].split(",")[0].split("BLEU = ")[1])


                losses = {k: torch.stack([x[k] for x in output]).mean() for k in self.loss_names}
                loss = losses["loss"]
                generative_metrics = {
                    k: np.array([x[k] for x in output]).mean() for k in self.metric_names + ["gen_time", "gen_len"]
                }

                generative_metrics['bleu'] = bleu_info

                metric_val = (
                    generative_metrics[self.val_metric] if self.val_metric in generative_metrics else losses[
                        self.val_metric]
                )
                metric_tensor: torch.FloatTensor = torch.tensor(metric_val).type_as(loss)
                generative_metrics.update({k: v.item() for k, v in losses.items()})
                losses.update(generative_metrics)
                all_metrics = {f"{prefix}_avg_{k}": x for k, x in losses.items()}
                all_metrics["step_count"] = self.step_count
                self.metrics[prefix].append(all_metrics) # callback writes this to self.metrics_save_path
                preds = flatten_list([x["preds"] for x in output])

                data_logs.update({
                    "log" + "_" + dataset_name: all_metrics,
                    "preds" + "_" + dataset_name: preds,
                    f"{prefix}_loss" + "_" + dataset_name: loss,
                    f"{prefix}_{self.val_metric}" + "_" + dataset_name: metric_tensor,
                })
            return data_logs


    #######



    def calc_generative_metrics(self, preds, target) -> Dict:
        return calculate_rouge(preds, target)

    def _generative_step(self, batch: dict, batch_idx=None, dataloader_idx=None) -> dict:
        t0 = time.time()

        # parser.add_argument('--eval_max_gen_length', type=int, default=None, help='never generate more than n tokens')
        generated_ids = self.model.generate(
            batch["input_ids"],
            attention_mask=batch["attention_mask"],
            use_cache=True,
            decoder_start_token_id=self.decoder_start_token_id,
            num_beams=self.eval_beams,
            max_length=self.eval_max_length,
            length_penalty=1.0
        )
        gen_time = (time.time() - t0) / batch["input_ids"].shape[0]
        preds: List[str] = self.ids_to_clean_text(generated_ids)
        target: List[str] = self.ids_to_clean_text(batch["labels"])
        loss_tensors = self._step(batch)
        base_metrics = {name: loss for name, loss in zip(self.loss_names, loss_tensors)}
        rouge: Dict = self.calc_generative_metrics(preds, target)
        summ_len = np.mean(lmap(len, generated_ids))
        base_metrics.update(gen_time=gen_time, gen_len=summ_len, preds=preds, target=target, **rouge)

        if dataloader_idx is not None:
            base_metrics.update(batch_idx=batch_idx, dataloader_idx=dataloader_idx)
        return base_metrics

    def test_step(self, batch, batch_idx, dataloader_idx):
        return self._generative_step(batch, batch_idx, dataloader_idx)

    def test_epoch_end(self, outputs_all_testsets):

        val_outputs_folder = "val_outputs"
        os.system("mkdir -p " + os.path.join(self.hparams.output_dir, val_outputs_folder))

        for outputs in outputs_all_testsets:
            dataset_idx = outputs[0]['dataloader_idx']

            if dataset_idx == 0:
                file_name = "test_both_predictions.txt"
                file_name_tgt = "test_both_targets.txt"
                dataset_name = 'test_both'
            elif dataset_idx == 1:
                file_name = "test_seen_predictions.txt"
                file_name_tgt = "test_seen_targets.txt"
                dataset_name = 'test_seen'
            else:
                file_name = "test_unseen_predictions.txt"
                file_name_tgt = "test_unseen_targets.txt"
                dataset_name = 'test_unseen'

            file_name += '.debug'
            file_name_tgt += '.debug'

            output_test_predictions_file = os.path.join(self.hparams.output_dir, val_outputs_folder, file_name)
            output_test_targets_file = os.path.join(self.hparams.output_dir, val_outputs_folder, file_name_tgt)
            # write predictions and targets for later rouge evaluation.
            with open(output_test_predictions_file, "w") as p_writer, open(output_test_targets_file, "w") as t_writer:
                for output_batch in outputs:

                    p_writer.writelines(convert_text(s) + "\n" for s in output_batch["preds"])
                    t_writer.writelines(convert_text(s) + "\n" for s in output_batch["target"])
                p_writer.close()
                t_writer.close()

            bleu_info = eval_bleu(self.hparams.data_dir, output_test_predictions_file, dataset_name)
            meteor_info = eval_meteor_test_webnlg(self.hparams.data_dir, output_test_predictions_file, dataset_name)
            chrf_info = eval_chrf_test_webnlg(self.hparams.data_dir, output_test_predictions_file, dataset_name)

            rank_zero_info(" %s - bleu_info: %s", dataset_name, bleu_info)
            rank_zero_info(" %s - meteor_info: %s", dataset_name, meteor_info)
            rank_zero_info(" %s - chrf_info: %s", dataset_name, chrf_info)

            outputs[0]['bleu'] = bleu_info

        return self.validation_epoch_end(outputs_all_testsets, prefix="test")

    def get_dataset(self, type_path) -> Seq2SeqDataset:
        n_obs = self.n_obs[type_path]
        max_target_length = self.target_lens[type_path]
        dataset = self.dataset_class(
            self.tokenizer,
            type_path=type_path,
            n_obs=n_obs,
            max_target_length=max_target_length,
            **self.dataset_kwargs,
        )
        return dataset

    def get_dataloader(self, type_path: str, batch_size: int, shuffle: bool = False) -> DataLoader:
        dataset = self.get_dataset(type_path)

        if self.hparams.sortish_sampler and type_path != "test":
            sampler = dataset.make_sortish_sampler(batch_size, distributed=self.hparams.gpus > 1)
            return DataLoader(
                dataset,
                batch_size=batch_size,
                collate_fn=dataset.collate_fn,
                shuffle=False,
                num_workers=self.num_workers,
                sampler=sampler,
            )

        elif self.hparams.max_tokens_per_batch is not None and type_path != "test":
            batch_sampler = dataset.make_dynamic_sampler(
                self.hparams.max_tokens_per_batch, distributed=self.hparams.gpus > 1
            )
            return DataLoader(
                dataset,
                batch_sampler=batch_sampler,
                collate_fn=dataset.collate_fn,
                # shuffle=False,
                num_workers=self.num_workers,
                # batch_size=None,
            )
        else:
            return DataLoader(
                dataset,
                batch_size=batch_size,
                collate_fn=dataset.collate_fn,
                shuffle=shuffle,
                num_workers=self.num_workers,
                sampler=None,
            )

    def train_dataloader(self) -> DataLoader:
        dataloader = self.get_dataloader("train", batch_size=self.hparams.train_batch_size, shuffle=True)
        return dataloader

    def val_dataloader(self) -> DataLoader:
        return self.get_dataloader("val", batch_size=self.hparams.eval_batch_size)

    def test_dataloader(self) -> List[DataLoader]:
        test_dataloader = self.get_dataloader("test_both", batch_size=self.hparams.eval_batch_size)
        test_seen_dataloader = self.get_dataloader("test_seen", batch_size=self.hparams.eval_batch_size)
        test_unseen_dataloader = self.get_dataloader("test_unseen", batch_size=self.hparams.eval_batch_size)

        return [test_dataloader, test_seen_dataloader, test_unseen_dataloader]

    @staticmethod
    def add_model_specific_args(parser, root_dir):
        BaseTransformer.add_model_specific_args(parser, root_dir)
        add_generic_args(parser, root_dir)
        parser.add_argument(
            "--max_source_length",
            default=1024,
            type=int,
            help="The maximum total input sequence length after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded.",
        )
        parser.add_argument(
            "--max_target_length",
            default=56,
            type=int,
            help="The maximum total input sequence length after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded.",
        )
        parser.add_argument(
            "--val_max_target_length",
            default=142,  # these defaults are optimized for CNNDM. For xsum, see README.md.
            type=int,
            help="The maximum total input sequence length after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded.",
        )
        parser.add_argument(
            "--test_max_target_length",
            default=142,
            type=int,
            help="The maximum total input sequence length after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded.",
        )
        parser.add_argument("--freeze_encoder", action="store_true")
        parser.add_argument("--freeze_embeds", action="store_true")
        parser.add_argument("--sortish_sampler", action="store_true", default=False)
        parser.add_argument("--max_tokens_per_batch", type=int, default=None)
        parser.add_argument("--logger_name", type=str, choices=["default", "wandb", "wandb_shared"], default="default")
        parser.add_argument("--n_train", type=int, default=-1, required=False, help="# examples. -1 means use all.")
        parser.add_argument("--n_val", type=int, default=-1, required=False, help="# examples. -1 means use all.")
        parser.add_argument("--n_test", type=int, default=-1, required=False, help="# examples. -1 means use all.")
        parser.add_argument(
            "--task", type=str, default="summarization", required=False, help="# examples. -1 means use all."
        )
        parser.add_argument("--label_smoothing", type=float, default=0.0, required=False)
        parser.add_argument("--src_lang", type=str, default="", required=False)
        parser.add_argument("--tgt_lang", type=str, default="", required=False)
        parser.add_argument("--eval_beams", type=int, default=None, required=False)
        parser.add_argument("--checkpoint", type=str, default=None, required=False)
        parser.add_argument(
            "--val_metric", type=str, default=None, required=False, choices=["bleu", "rouge2", "loss", None]
        )
        parser.add_argument("--eval_max_gen_length", type=int, default=None, help="never generate more than n tokens")
        parser.add_argument("--save_top_k", type=int, default=1, required=False, help="How many checkpoints to save")
        parser.add_argument(
            "--early_stopping_patience",
            type=int,
            default=-1,
            required=False,
            help="-1 means never early stop. early_stopping_patience is measured in validation checks, not epochs. So val_check_interval will effect it.",
        )

        return parser


class TranslationModule(SummarizationModule):
    mode = "translation"
    loss_names = ["loss"]
    metric_names = ["bleu"]
    default_val_metric = "bleu"

    def __init__(self, hparams, **kwargs):
        super().__init__(hparams, **kwargs)
        self.dataset_kwargs["src_lang"] = hparams.src_lang
        self.dataset_kwargs["tgt_lang"] = hparams.tgt_lang

    def calc_generative_metrics(self, preds, target) -> dict:
        return calculate_bleu(preds, target)


class Graph2TextModule(SummarizationModule):
    mode = "graph2text"
    loss_names = ["loss"]
    metric_names = ["sacrebleu"]
    default_val_metric = "bleu"

    def __init__(self, hparams, **kwargs):
        if type(hparams) == dict:
            hparams = argparse.Namespace(**hparams)
        print(f'Graph2Text hparams are: {hparams}')
        super().__init__(hparams, **kwargs)

        self.hparams.update(vars(hparams))

        rank_zero_info("parameters %s", hparams)

    def calc_generative_metrics(self, preds, target) -> dict:
        return calculate_bleu(preds, target)


def main(args, model=None) -> SummarizationModule:
    Path(args.output_dir).mkdir(exist_ok=True)
    if len(os.listdir(args.output_dir)) > 3 and args.do_train:
        raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
    if model is None:
        if "summarization" in args.task:
            model: SummarizationModule = SummarizationModule(args)
        elif "translation" in args.task:
            model: SummarizationModule = TranslationModule(args)
        else:
            model: SummarizationModule = Graph2TextModule(args)
    dataset = Path(args.data_dir).name
    if (
        args.logger_name == "default"
        or args.fast_dev_run
        or str(args.output_dir).startswith("/tmp")
        or str(args.output_dir).startswith("/var")
    ):
        logger = True  # don't pollute wandb logs unnecessarily
    elif args.logger_name == "wandb":
        from pytorch_lightning.loggers import WandbLogger

        project = os.environ.get("WANDB_PROJECT", dataset)
        logger = WandbLogger(name=model.output_dir.name, project=project)

    elif args.logger_name == "wandb_shared":
        from pytorch_lightning.loggers import WandbLogger

        logger = WandbLogger(name=model.output_dir.name, project=f"hf_{dataset}")

    if args.early_stopping_patience >= 0:
        es_callback = get_early_stopping_callback(model.val_metric, args.early_stopping_patience)
    else:
        es_callback = False

    lower_is_better = args.val_metric == "loss"
    trainer: pl.Trainer = generic_train(
        model,
        args,
        logging_callback=Seq2SeqLoggingCallback(),
        checkpoint_callback=get_checkpoint_callback(
            args.output_dir, model.val_metric, args.save_top_k, lower_is_better
        ),
        early_stopping_callback=es_callback,
        logger=logger,
    )
    pickle_save(model.hparams, model.output_dir / "hparams.pkl")
    if not args.do_predict:
        return model

    model.hparams.test_checkpoint = ""
    if not args.checkpoint:
        checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "*.ckpt"), recursive=True)))
    else:
        checkpoints = [args.checkpoint]

    if checkpoints:
        model.hparams.test_checkpoint = checkpoints[-1]
        trainer.resume_from_checkpoint = checkpoints[-1]

    if args.do_predict and not args.do_train:

        checkpoint = checkpoints[-1]
        print(checkpoint)
        #trainer.test(ckpt_path=checkpoints[-1])
        trainer.test(model, ckpt_path=checkpoint)
        return model


    trainer.logger.log_hyperparams(model.hparams)

    # test() without a model tests using the best checkpoint automatically
    trainer.test()
    return model


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser = pl.Trainer.add_argparse_args(parser)
    parser = SummarizationModule.add_model_specific_args(parser, os.getcwd())

    args = parser.parse_args()

    main(args)
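A sketch of how finetune.py can be driven programmatically, mirroring the __main__ block above with an explicit argument list; the data directory, output directory, and model name below are placeholders, not values taken from this repository:

# Illustrative only: same parser wiring as the __main__ block, with explicit args.
import os
import argparse
import pytorch_lightning as pl
from utils.finetune import SummarizationModule, main

parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)
parser = SummarizationModule.add_model_specific_args(parser, os.getcwd())
args = parser.parse_args([
    "--data_dir", "data/webnlg",        # placeholder path
    "--output_dir", "outputs/run1",     # placeholder path
    "--model_name_or_path", "t5-base",  # placeholder model name
    "--task", "graph2text",
    "--do_train",
    "--early_stopping_patience", "5",
])
main(args)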
utils/lightning_base.py
ADDED
@@ -0,0 +1,418 @@
import argparse
import logging
import os
from pathlib import Path
from typing import Any, Dict
import sys
import pytorch_lightning as pl
from pytorch_lightning.utilities import rank_zero_info
from pytorch_lightning.callbacks import LearningRateMonitor

from transformers import (
    AdamW,
    AutoConfig,
    AutoModel,
    AutoModelForPreTraining,
    AutoModelForQuestionAnswering,
    AutoModelForSeq2SeqLM,
    AutoModelForSequenceClassification,
    AutoModelForTokenClassification,
    AutoModelWithLMHead,
    AutoTokenizer,
    PretrainedConfig,
    PreTrainedTokenizer,
)
from transformers.optimization import (
    Adafactor,
    get_cosine_schedule_with_warmup,
    get_cosine_with_hard_restarts_schedule_with_warmup,
    get_linear_schedule_with_warmup,
    get_polynomial_decay_schedule_with_warmup,
)

from tokenizers import AddedToken

logger = logging.getLogger(__name__)

MODEL_MODES = {
    "base": AutoModel,
    "sequence-classification": AutoModelForSequenceClassification,
    "question-answering": AutoModelForQuestionAnswering,
    "pretraining": AutoModelForPreTraining,
    "token-classification": AutoModelForTokenClassification,
    "language-modeling": AutoModelWithLMHead,
    "summarization": AutoModelForSeq2SeqLM,
    "translation": AutoModelForSeq2SeqLM,
    "graph2text": AutoModelForSeq2SeqLM,
}


# update this and the import above to support new schedulers from transformers.optimization
arg_to_scheduler = {
    "linear": get_linear_schedule_with_warmup,
    "cosine": get_cosine_schedule_with_warmup,
    "cosine_w_restarts": get_cosine_with_hard_restarts_schedule_with_warmup,
    "polynomial": get_polynomial_decay_schedule_with_warmup,
    # '': get_constant_schedule, # not supported for now
    # '': get_constant_schedule_with_warmup, # not supported for now
}
arg_to_scheduler_choices = sorted(arg_to_scheduler.keys())
arg_to_scheduler_metavar = "{" + ", ".join(arg_to_scheduler_choices) + "}"


class BaseTransformer(pl.LightningModule):
    def __init__(
        self,
        hparams: argparse.Namespace,
        num_labels=None,
        mode="base",
        config=None,
        tokenizer=None,
        model=None,
        **config_kwargs
    ):
        """Initialize a model, tokenizer and config."""
        super().__init__()
        # TODO: move to self.save_hyperparameters()
        # self.save_hyperparameters()
        # can also expand arguments into trainer signature for easier reading
        self.save_hyperparameters(hparams)
        self.step_count = -2
        self.output_dir = Path(self.hparams.output_dir)
        cache_dir = self.hparams.cache_dir if self.hparams.cache_dir else None
        if config is None:
            self.config = AutoConfig.from_pretrained(
                self.hparams.config_name if self.hparams.config_name else self.hparams.model_name_or_path,
                **({"num_labels": num_labels} if num_labels is not None else {}),
                cache_dir=cache_dir,
                **config_kwargs,
            )
        else:
            self.config: PretrainedConfig = config

        extra_model_params = ("encoder_layerdrop", "decoder_layerdrop", "dropout", "attention_dropout")
        for p in extra_model_params:
            if getattr(self.hparams, p, None):
                assert hasattr(self.config, p), f"model config doesn't have a `{p}` attribute"
                setattr(self.config, p, getattr(self.hparams, p))

        if tokenizer is None:
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.hparams.tokenizer_name if self.hparams.tokenizer_name else self.hparams.model_name_or_path,
                cache_dir=cache_dir,
            )
            new_tokens = [
                '<H>','<R>','<T>'
            ]
            new_tokens_vocab = {}
            new_tokens_vocab['additional_special_tokens'] = []
            for idx, t in enumerate(new_tokens):
                new_tokens_vocab['additional_special_tokens'].append(t)
            num_added_toks = self.tokenizer.add_special_tokens(new_tokens_vocab)
            rank_zero_info('We have added %s tokens', num_added_toks)
        else:
            self.tokenizer: PreTrainedTokenizer = tokenizer
        self.model_type = MODEL_MODES[mode]
        if model is None:
            self.model = self.model_type.from_pretrained(
                self.hparams.model_name_or_path,
                from_tf=bool(".ckpt" in self.hparams.model_name_or_path),
                config=self.config,
                cache_dir=cache_dir,
            )
            self.model.resize_token_embeddings(len(self.tokenizer))
        else:
            self.model = model

    def load_hf_checkpoint(self, *args, **kwargs):
        self.model = self.model_type.from_pretrained(*args, **kwargs)

    def get_lr_scheduler(self):
        get_schedule_func = arg_to_scheduler[self.hparams.lr_scheduler]
        scheduler = get_schedule_func(
            self.opt, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=self.total_steps
        )
        scheduler = {"scheduler": scheduler, "interval": "step", "frequency": 1}
        return scheduler

    def configure_optimizers(self):
        """Prepare optimizer and schedule (linear warmup and decay)"""
        model = self.model
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
                "weight_decay": self.hparams.weight_decay,
            },
            {
                "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
                "weight_decay": 0.0,
            },
        ]
        if self.hparams.adafactor:
            optimizer = Adafactor(
                optimizer_grouped_parameters, lr=self.hparams.learning_rate, scale_parameter=False, relative_step=False
            )

        else:
            optimizer = AdamW(
                optimizer_grouped_parameters, lr=self.hparams.learning_rate, eps=self.hparams.adam_epsilon
            )
        self.opt = optimizer

        scheduler = self.get_lr_scheduler()

        return [optimizer], [scheduler]


    def test_step(self, batch, batch_nb):
        return self.validation_step(batch, batch_nb)

    def test_epoch_end(self, outputs):
        return self.validation_end(outputs)

    @property
    def total_steps(self) -> int:
        # print('self.hparams.gpus', self.hparams.gpus)
        # print('self.hparams.accumulate_grad_batches', self.hparams.accumulate_grad_batches)
        # print('self.train_loader.dataset', self.train_loader.dataset)
        # print('self.hparams.max_epochs', self.hparams.max_epochs)
        # print('self.hparams.train_batch_size', self.hparams.train_batch_size)
        # exit()
        """The number of total training steps that will be run. Used for lr scheduler purposes."""
        num_devices = max(1, self.hparams.gpus)  # TODO: consider num_tpu_cores
        effective_batch_size = self.hparams.train_batch_size * self.hparams.accumulate_grad_batches * num_devices
        dataset_size = len(self.train_loader.dataset)
        return (dataset_size / effective_batch_size) * self.hparams.max_epochs

    def setup(self, mode):
        #if mode == "fit":
        self.train_loader = self.get_dataloader("train", self.hparams.train_batch_size, shuffle=True)

    def get_dataloader(self, type_path, batch_size, shuffle=False):
        raise NotImplementedError("You must implement this for your task")

    def train_dataloader(self):
        return self.train_loader

    def val_dataloader(self):
        return self.get_dataloader("dev", self.hparams.eval_batch_size, shuffle=False)

    def test_dataloader(self):
        return self.get_dataloader("test", self.hparams.eval_batch_size, shuffle=False)

    def _feature_file(self, mode):
        return os.path.join(
            self.hparams.data_dir,
            "cached_{}_{}_{}".format(
                mode,
                list(filter(None, self.hparams.model_name_or_path.split("/"))).pop(),
                str(self.hparams.max_seq_length),
            ),
        )

    def get_progress_bar_dict(self):
        #metrics = self.trainer.callback_metrics
        #print(self.trainer.lr_logger.lrs)
        lrs = self.trainer.lr_logger.lrs['lr-AdamW/pg1'][-1]
        running_train_loss = self.trainer.running_loss.mean()
        avg_training_loss = running_train_loss.cpu().item() if running_train_loss is not None else float('NaN')
        tqdm_dict = {"loss": "{:.3f}".format(avg_training_loss), "lr": lrs}
        return tqdm_dict

    @pl.utilities.rank_zero_only
    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        save_path = self.output_dir.joinpath("best_tfmr")
        self.model.config.save_step = self.step_count
        self.model.save_pretrained(save_path)
        self.tokenizer.save_pretrained(save_path)

    @staticmethod
    def add_model_specific_args(parser, root_dir):
        parser.add_argument(
            "--model_name_or_path",
            default=None,
            type=str,
            required=True,
            help="Path to pretrained model or model identifier from huggingface.co/models",
        )
        parser.add_argument(
            "--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name"
        )
        parser.add_argument(
            "--tokenizer_name",
            default=None,
            type=str,
            help="Pretrained tokenizer name or path if not the same as model_name",
        )
        parser.add_argument(
            "--cache_dir",
            default="",
            type=str,
            help="Where do you want to store the pre-trained models downloaded from s3",
        )
        parser.add_argument(
            "--encoder_layerdrop",
            type=float,
            help="Encoder layer dropout probability (Optional). Goes into model.config",
        )
        parser.add_argument(
            "--decoder_layerdrop",
            type=float,
            help="Decoder layer dropout probability (Optional). Goes into model.config",
        )
        parser.add_argument(
            "--dropout",
            type=float,
            help="Dropout probability (Optional). Goes into model.config",
        )
        parser.add_argument(
            "--attention_dropout",
            type=float,
            help="Attention dropout probability (Optional). Goes into model.config",
        )
        parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
        parser.add_argument(
            "--lr_scheduler",
            default="linear",
            choices=arg_to_scheduler_choices,
            metavar=arg_to_scheduler_metavar,
            type=str,
            help="Learning rate scheduler",
        )
        parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
        parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
        parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")
        parser.add_argument("--num_workers", default=4, type=int, help="kwarg passed to DataLoader")
        parser.add_argument("--num_train_epochs", dest="max_epochs", default=3, type=int)
        parser.add_argument("--train_batch_size", default=32, type=int)
        parser.add_argument("--eval_batch_size", default=32, type=int)
        parser.add_argument("--adafactor", action="store_true")


class LoggingCallback(pl.Callback):
    def on_batch_end(self, trainer, pl_module):
        lr_scheduler = trainer.lr_schedulers[0]["scheduler"]
        lrs = {f"lr_group_{i}": lr for i, lr in enumerate(lr_scheduler.get_lr())}
        pl_module.logger.log_metrics(lrs)

    def on_validation_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
        rank_zero_info("***** Validation results *****")
        metrics = trainer.callback_metrics
        rank_zero_info(trainer.logger)
        # Log results
        for key in sorted(metrics):
            if key not in ["log", "progress_bar"]:
                rank_zero_info("{} = {}\n".format(key, str(metrics[key])))

    def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
        rank_zero_info("***** Test results *****")
        metrics = trainer.callback_metrics
        # Log and save results to file
        output_test_results_file = os.path.join(pl_module.hparams.output_dir, "test_results.txt")
        with open(output_test_results_file, "w") as writer:
            for key in sorted(metrics):
                if key not in ["log", "progress_bar"]:
                    rank_zero_info("{} = {}\n".format(key, str(metrics[key])))
                    writer.write("{} = {}\n".format(key, str(metrics[key])))


def add_generic_args(parser, root_dir) -> None:
    # TODO(SS): allow all pl args? parser = pl.Trainer.add_argparse_args(parser)
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument(
        "--fp16",
        action="store_true",
        help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
    )

    parser.add_argument(
        "--fp16_opt_level",
        type=str,
        default="O2",
        help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
        "See details at https://nvidia.github.io/apex/amp.html",
    )
    parser.add_argument("--n_tpu_cores", dest="tpu_cores", type=int)
    parser.add_argument("--max_grad_norm", dest="gradient_clip_val", default=1.0, type=float, help="Max gradient norm")
    parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
    parser.add_argument("--do_predict", action="store_true", help="Whether to run predictions on the test set.")
    parser.add_argument(
        "--gradient_accumulation_steps",
        dest="accumulate_grad_batches",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
    parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        required=True,
        help="The input data dir. Should contain the training files for the CoNLL-2003 NER task.",
    )


def generic_train(
    model: BaseTransformer,
    args: argparse.Namespace,
    early_stopping_callback=False,
    logger=True,  # can pass WandbLogger() here
    extra_callbacks=[],
    checkpoint_callback=None,
    logging_callback=None,
**extra_train_kwargs
|
| 372 |
+
):
|
| 373 |
+
pl.seed_everything(args.seed)
|
| 374 |
+
|
| 375 |
+
# init model
|
| 376 |
+
odir = Path(model.hparams.output_dir)
|
| 377 |
+
odir.mkdir(exist_ok=True)
|
| 378 |
+
|
| 379 |
+
# add custom checkpoints
|
| 380 |
+
if checkpoint_callback is None:
|
| 381 |
+
checkpoint_callback = pl.callbacks.ModelCheckpoint(
|
| 382 |
+
filepath=args.output_dir, prefix="checkpoint", monitor="val_loss", mode="min", save_top_k=1
|
| 383 |
+
)
|
| 384 |
+
if logging_callback is None:
|
| 385 |
+
logging_callback = LoggingCallback()
|
| 386 |
+
|
| 387 |
+
train_params = {}
|
| 388 |
+
|
| 389 |
+
# TODO: remove with PyTorch 1.6 since pl uses native amp
|
| 390 |
+
if args.fp16:
|
| 391 |
+
train_params["precision"] = 16
|
| 392 |
+
train_params["amp_level"] = args.fp16_opt_level
|
| 393 |
+
|
| 394 |
+
if args.gpus > 1:
|
| 395 |
+
train_params["distributed_backend"] = "ddp"
|
| 396 |
+
|
| 397 |
+
train_params["accumulate_grad_batches"] = args.accumulate_grad_batches
|
| 398 |
+
|
| 399 |
+
lr_logger = LearningRateMonitor(logging_interval='step')
|
| 400 |
+
|
| 401 |
+
# deterministic=True,
|
| 402 |
+
trainer = pl.Trainer.from_argparse_args(
|
| 403 |
+
args,
|
| 404 |
+
weights_summary='full',
|
| 405 |
+
callbacks=[logging_callback, lr_logger],
|
| 406 |
+
logger=logger,
|
| 407 |
+
checkpoint_callback=checkpoint_callback,
|
| 408 |
+
early_stop_callback=early_stopping_callback,
|
| 409 |
+
num_sanity_val_steps=4,
|
| 410 |
+
**train_params,
|
| 411 |
+
)
|
| 412 |
+
|
| 413 |
+
trainer.lr_logger = lr_logger
|
| 414 |
+
|
| 415 |
+
if args.do_train:
|
| 416 |
+
trainer.fit(model)
|
| 417 |
+
|
| 418 |
+
return trainer
|
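A minimal sketch (not part of the uploaded diff) of how the helpers above are typically strung together in a training entry point. It assumes a BaseTransformer subclass such as the Graph2TextModule defined in utils/finetune.py, and the command-line values shown in the comment are placeholders.

# sketch: wiring add_generic_args / add_model_specific_args / generic_train together
import argparse, os
from utils.lightning_base import add_generic_args, generic_train
from utils.finetune import Graph2TextModule

def sketch_train():
    parser = argparse.ArgumentParser()
    add_generic_args(parser, os.getcwd())
    Graph2TextModule.add_model_specific_args(parser, os.getcwd())
    # e.g. --data_dir data/ --output_dir out/ --model_name_or_path t5-base --do_train
    args = parser.parse_args()
    model = Graph2TextModule(args)
    return generic_train(model, args)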
utils/sentence_retrieval_model.py
ADDED
|
@@ -0,0 +1,20 @@
import torch
import torch.nn as nn

from utils.bert_model import BertForSequenceEncoder

class sentence_retrieval_model(nn.Module):
    def __init__(self, args):
        super(sentence_retrieval_model, self).__init__()
        self.pred_model = BertForSequenceEncoder.from_pretrained(args['bert_pretrain'])
        self.bert_hidden_dim = args['bert_hidden_dim']
        self.dropout = nn.Dropout(args['dropout'])
        self.proj_match = nn.Linear(self.bert_hidden_dim, 1)


    def forward(self, inp_tensor, msk_tensor, seg_tensor):
        _, inputs = self.pred_model(inp_tensor, msk_tensor, seg_tensor)
        inputs = self.dropout(inputs)
        score = self.proj_match(inputs).squeeze(-1)
        score = torch.tanh(score)
        return score
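A short note (not part of the uploaded diff): the scorer above encodes each claim-evidence pair with BertForSequenceEncoder and projects the pooled representation to a single tanh-squashed relevance score. The dict below only illustrates the expected keys; its values and the checkpoint path are assumptions.

# sketch: constructing the scorer directly (normally done by SentenceRetrievalModule below)
args = {'bert_pretrain': 'base/bert_base', 'bert_hidden_dim': 768, 'dropout': 0.6}
scorer = sentence_retrieval_model(args)
# scores = scorer(input_ids, attention_mask, token_type_ids)  # shape (batch,), values in [-1, 1]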
utils/sentence_retrieval_module.py
ADDED
|
@@ -0,0 +1,77 @@
import re
from typing import List, Tuple
import pathlib

import torch
from transformers import BertTokenizer

from utils.sentence_retrieval_model import sentence_retrieval_model


THIS_DIR = pathlib.Path(__file__).parent.absolute()
ARGS = {
    'batch_size': 32,
    'bert_pretrain': 'base/bert_base',
    'checkpoint': 'base/model.best.32.pt',
    'dropout': 0.6,
    'bert_hidden_dim': 768,
    'max_len': 384,
    'cuda': torch.cuda.is_available()
}

if not ARGS['cuda']:
    print('CUDA NOT AVAILABLE')


def process_sent(sentence):
    sentence = re.sub("LSB.*?RSB", "", sentence)
    sentence = re.sub("LRB\s*?RRB", "", sentence)
    sentence = re.sub("(\s*?)LRB((\s*?))", "\\1(\\2", sentence)
    sentence = re.sub("(\s*?)RRB((\s*?))", "\\1)\\2", sentence)
    sentence = re.sub("--", "-", sentence)
    sentence = re.sub("``", '"', sentence)
    sentence = re.sub("''", '"', sentence)
    return sentence

class SentenceRetrievalModule():

    def __init__(self, max_len=None):

        if max_len:
            ARGS['max_len'] = max_len

        self.tokenizer = BertTokenizer.from_pretrained(ARGS['bert_pretrain'], do_lower_case=False)
        self.model = sentence_retrieval_model(ARGS)
        self.model.load_state_dict(torch.load(ARGS['checkpoint'], map_location=torch.device('cpu'))['model'])
        if ARGS['cuda']:
            self.model = self.model.cuda()

    def score_sentence_pairs(self, inputs: List[Tuple[str]]):
        inputs_processed = [(process_sent(input[0]), process_sent(input[1])) for input in inputs]

        encodings = self.tokenizer(
            inputs_processed,
            padding='max_length',
            truncation='longest_first',
            max_length=ARGS['max_len'],
            return_token_type_ids=True,
            return_attention_mask=True,
            return_tensors='pt',
        )

        inp = encodings['input_ids']
        msk = encodings['attention_mask']
        seg = encodings['token_type_ids']

        if ARGS['cuda']:
            inp = inp.cuda()
            msk = msk.cuda()
            seg = seg.cuda()

        self.model.eval()
        with torch.no_grad():
            outputs = self.model(inp, msk, seg).tolist()

        assert len(outputs) == len(inputs)

        return outputs
|
utils/textual_entailment_module.py
ADDED
|
@@ -0,0 +1,94 @@
import json
import numpy as np
import pandas as pd
from pathlib import Path
import torch
import re

from transformers import BertTokenizer, BertForSequenceClassification

# Constants and paths
HOME = Path('/users/k2031554')
DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'
MAX_LEN = 512
CLASSES = ['SUPPORTS','REFUTES','NOT ENOUGH INFO']
METHODS = ['WEIGHTED_SUM', 'MALON']

def process_sent(sentence):
    sentence = re.sub("LSB.*?RSB", "", sentence)
    sentence = re.sub("LRB\s*?RRB", "", sentence)
    sentence = re.sub("(\s*?)LRB((\s*?))", "\\1(\\2", sentence)
    sentence = re.sub("(\s*?)RRB((\s*?))", "\\1)\\2", sentence)
    sentence = re.sub("--", "-", sentence)
    sentence = re.sub("``", '"', sentence)
    sentence = re.sub("''", '"', sentence)
    return sentence

class TextualEntailmentModule():

    def __init__(
        self,
        model_path = 'base/models/BERT_FEVER_v4_model_PBT',
        tokenizer_path = 'base/models/BERT_FEVER_v4_tok_PBT'
    ):
        self.tokenizer = BertTokenizer.from_pretrained(
            tokenizer_path
        )
        self.model = BertForSequenceClassification.from_pretrained(
            model_path
        )
        self.model.to(DEVICE)

    #def get_pair_scores(self, claim, evidence):
    #
    #    encodings = self.tokenizer(
    #        [claim, evidence],
    #        max_length= MAX_LEN,
    #        return_token_type_ids=False,
    #        padding='max_length',
    #        truncation=True,
    #        return_tensors='pt',
    #    ).to(DEVICE)
    #
    #    self.model.eval()
    #    with torch.no_grad():
    #        probs = self.model(
    #            input_ids=encodings['input_ids'],
    #            attention_mask=encodings['attention_mask']
    #        )
    #
    #    return torch.softmax(probs.logits,dim=1).cpu().numpy()

    def get_batch_scores(self, claims, evidence):

        inputs = list(zip(claims, evidence))

        encodings = self.tokenizer(
            inputs,
            max_length= MAX_LEN,
            return_token_type_ids=False,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        ).to(DEVICE)

        self.model.eval()
        with torch.no_grad():
            probs = self.model(
                input_ids=encodings['input_ids'],
                attention_mask=encodings['attention_mask']
            )

        return torch.softmax(probs.logits,dim=1).cpu().numpy()

    def get_label_from_scores(self, scores):
        return CLASSES[np.argmax(scores)]

    def get_label_malon(self, score_set):
        score_labels = [np.argmax(s) for s in score_set]
        if 1 not in score_labels and 0 not in score_labels:
            return CLASSES[2] #NOT ENOUGH INFO
        elif 0 in score_labels:
            return CLASSES[0] #SUPPORTS
        elif 1 in score_labels:
            return CLASSES[1] #REFUTES
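A short usage sketch (not part of the uploaded diff) for the entailment module above. It assumes the default local 'base/models' checkpoints are present; the claim and evidence strings are made-up.

# sketch: classify claim-evidence pairs into SUPPORTS / REFUTES / NOT ENOUGH INFO
te = TextualEntailmentModule()
claims = ["The Earth orbits the Sun."]
evidence = ["The Earth revolves around the Sun once every 365.25 days."]
scores = te.get_batch_scores(claims, evidence)   # shape (1, 3) softmax probabilities over CLASSES
label = te.get_label_from_scores(scores[0])      # e.g. 'SUPPORTS'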
utils/utils_graph2text.py
ADDED
|
@@ -0,0 +1,114 @@
import re
import os

def convert_text(text):
    #return text
    text = text.lower()
    text = ' '.join(re.split('(\W)', text))
    text = ' '.join(text.split())
    return text

def eval_meteor_test_webnlg(folder_data, pred_file, dataset):

    dir_path = os.path.dirname(os.path.realpath(__file__))
    folder_data_before = dir_path + "/../utils"

    cmd_string = "java -jar " + folder_data_before + "/meteor-1.5.jar " + pred_file + " " \
                 + folder_data + "/" + dataset + ".target_eval_meteor -l en -norm -r 3 > " + pred_file.replace("txt", "meteor")

    os.system(cmd_string)

    meteor_info = open(pred_file.replace("txt", "meteor"), 'r').readlines()[-1].strip()

    return meteor_info


def eval_chrf_test_webnlg(folder_data, pred_file, dataset):

    dir_path = os.path.dirname(os.path.realpath(__file__))
    folder_data_before = dir_path + "/../utils"

    cmd_string = "python " + folder_data_before + "/chrf++.py -H " + pred_file + " -R " \
                 + folder_data + "/" + dataset + ".target_eval_crf > " + pred_file.replace("txt", "chrf")

    os.system(cmd_string)

    chrf_info_1 = open(pred_file.replace("txt", "chrf"), 'r').readlines()[1].strip()
    chrf_info_2 = open(pred_file.replace("txt", "chrf"), 'r').readlines()[2].strip()

    return chrf_info_1 + " " + chrf_info_2

def eval_bleu(folder_data, pred_file, dataset):

    dir_path = os.path.dirname(os.path.realpath(__file__))
    folder_data_before = dir_path + "/data/"

    cmd_string = "perl " + folder_data_before + "/multi-bleu.perl -lc " + folder_data + "/" + dataset + ".target_eval " \
                 + folder_data + "/" + dataset + ".target2_eval " + folder_data + "/" + dataset + ".target3_eval < " \
                 + pred_file + " > " + pred_file.replace("txt", "bleu")

    os.system(cmd_string)

    try:
        bleu_info = open(pred_file.replace("txt", "bleu"), 'r').readlines()[0].strip()
    except:
        bleu_info = -1

    return bleu_info


def eval_bleu_sents_tok(pred_file, folder_data, dataset):

    dir_path = os.path.dirname(os.path.realpath(__file__))
    folder_data_before = dir_path + "/../utils"

    cmd_string = "perl " + folder_data_before + "/tokenizer.perl -threads 4 -no-escape < " + pred_file + " > " +\
                 pred_file + "_tok"
    os.system(cmd_string)

    cmd_string = "perl " + folder_data_before + "/multi-bleu.perl -lc " + folder_data + "/" + dataset + ".target.tok"\
                 + " < " + pred_file + "_tok" + " > " + pred_file.replace("txt", "bleu_data")
    os.system(cmd_string)

    try:
        bleu_info_data = open(pred_file.replace("txt", "bleu_data"), 'r').readlines()[0].strip()
    except:
        bleu_info_data = 'no data'

    return bleu_info_data


def eval_meteor(ref_file, pred_file):

    dir_path = os.path.dirname(os.path.realpath(__file__))
    folder_data_before = dir_path + "/../utils"

    cmd_string = "java -jar " + folder_data_before + "/meteor-1.5.jar " + pred_file + " " \
                 + ref_file + " > " + pred_file.replace("txt", "meteor")

    os.system(cmd_string)

    meteor_info = open(pred_file.replace("txt", "meteor"), 'r').readlines()[-1].strip()

    return meteor_info


def eval_chrf(ref_file, pred_file):

    dir_path = os.path.dirname(os.path.realpath(__file__))
    folder_data_before = dir_path + "/../utils"

    cmd_string = "python " + folder_data_before + "/chrf++.py -H " + pred_file + " -R " \
                 + ref_file + " > " + pred_file.replace("txt", "chrf")

    os.system(cmd_string)

    try:
        chrf_info_1 = open(pred_file.replace("txt", "chrf"), 'r').readlines()[1].strip()
        chrf_info_2 = open(pred_file.replace("txt", "chrf"), 'r').readlines()[2].strip()
        chrf_data = chrf_info_1 + " " + chrf_info_2
    except:
        chrf_data = "no data"


    return chrf_data
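A short note and sketch (not part of the uploaded diff): the eval_* helpers above shell out to external scorers (meteor-1.5.jar, chrf++.py, multi-bleu.perl, tokenizer.perl) that are expected to sit next to this file; only convert_text is self-contained. The example string below is made-up.

# sketch: normalise a hypothesis before writing it to the prediction file read by the scorers
hyp = convert_text("Douglas Adams (1952-2001) was an English author.")
# -> "douglas adams ( 1952 - 2001 ) was an english author ."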
utils/utils_verbalisation_module.py
ADDED
|
@@ -0,0 +1,610 @@
import itertools
import json
import linecache
import math
import os
import pickle
import socket
from logging import getLogger
from pathlib import Path
from typing import Callable, Dict, Iterable, List, Tuple, Union

import numpy as np
import torch
import torch.distributed as dist
from rouge_score import rouge_scorer, scoring
from sacrebleu import corpus_bleu
from torch import nn
from torch.utils.data import Dataset, Sampler

from transformers import BartTokenizer, EvalPrediction, PreTrainedTokenizer, T5Tokenizer
from transformers.file_utils import cached_property
from transformers.models.bart.modeling_bart import shift_tokens_right
from utils.utils_graph2text import convert_text, eval_bleu
from pytorch_lightning.utilities import rank_zero_info
import pdb


try:
    from fairseq.data.data_utils import batch_by_size

    FAIRSEQ_AVAILABLE = True
except (ImportError, ModuleNotFoundError):
    FAIRSEQ_AVAILABLE = False


def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=-100):
    """From fairseq"""
    if target.dim() == lprobs.dim() - 1:
        target = target.unsqueeze(-1)
    nll_loss = -lprobs.gather(dim=-1, index=target)
    smooth_loss = -lprobs.sum(dim=-1, keepdim=True)
    if ignore_index is not None:
        pad_mask = target.eq(ignore_index)
        nll_loss.masked_fill_(pad_mask, 0.0)
        smooth_loss.masked_fill_(pad_mask, 0.0)
    else:
        nll_loss = nll_loss.squeeze(-1)
        smooth_loss = smooth_loss.squeeze(-1)

    nll_loss = nll_loss.sum()  # mean()? Scared to break other math.
    smooth_loss = smooth_loss.sum()
    eps_i = epsilon / lprobs.size(-1)
    loss = (1.0 - epsilon) * nll_loss + eps_i * smooth_loss
    return loss, nll_loss


def lmap(f: Callable, x: Iterable) -> List:
    """list(map(f, x))"""
    return list(map(f, x))


def calculate_bleu(output_lns, refs_lns) -> dict:
    """Uses sacrebleu's corpus_bleu implementation."""
    return {"sacrebleu": round(corpus_bleu(output_lns, [refs_lns]).score, 4)}


def build_compute_metrics_fn(task_name: str, tokenizer: PreTrainedTokenizer) -> Callable[[EvalPrediction], Dict]:
    def non_pad_len(tokens: np.ndarray) -> int:
        return np.count_nonzero(tokens != tokenizer.pad_token_id)

    def decode_pred(pred: EvalPrediction) -> Tuple[List[str], List[str]]:
        pred_str = tokenizer.batch_decode(pred.predictions, skip_special_tokens=True)
        label_str = tokenizer.batch_decode(pred.label_ids, skip_special_tokens=True)
        pred_str = lmap(str.strip, pred_str)
        label_str = lmap(str.strip, label_str)
        return pred_str, label_str

    def summarization_metrics(pred: EvalPrediction) -> Dict:
        pred_str, label_str = decode_pred(pred)
        rouge: Dict = calculate_rouge(pred_str, label_str)
        summ_len = np.round(np.mean(lmap(non_pad_len, pred.predictions)), 1)
        rouge.update({"gen_len": summ_len})
        return rouge

    def translation_metrics(pred: EvalPrediction) -> Dict:
        pred_str, label_str = decode_pred(pred)
        bleu: Dict = calculate_bleu(pred_str, label_str)
        gen_len = np.round(np.mean(lmap(non_pad_len, pred.predictions)), 1)
        bleu.update({"gen_len": gen_len})
        return bleu

    compute_metrics_fn = summarization_metrics if "summarization" in task_name else translation_metrics
    return compute_metrics_fn


def trim_batch(
    input_ids,
    pad_token_id,
    attention_mask=None,
):
    """Remove columns that are populated exclusively by pad_token_id"""
    keep_column_mask = input_ids.ne(pad_token_id).any(dim=0)
    if attention_mask is None:
        return input_ids[:, keep_column_mask]
    else:
        return (input_ids[:, keep_column_mask], attention_mask[:, keep_column_mask])


class AbstractSeq2SeqDataset(Dataset):
    def __init__(
        self,
        tokenizer,
        data_dir,
        max_source_length,
        max_target_length,
        type_path="train",
        n_obs=None,
        prefix="",
        **dataset_kwargs
    ):
        super().__init__()
        self.src_file = Path(data_dir).joinpath(type_path + ".source")
        self.tgt_file = Path(data_dir).joinpath(type_path + ".target")
        self.len_file = Path(data_dir).joinpath(type_path + ".len")
        if os.path.exists(self.len_file):
            self.src_lens = pickle_load(self.len_file)
            self.used_char_len = False
        else:
            self.src_lens = self.get_char_lens(self.src_file)
            self.used_char_len = True
        self.max_source_length = max_source_length
        self.max_target_length = max_target_length
        assert min(self.src_lens) > 0, f"found empty line in {self.src_file}"
        self.tokenizer = tokenizer
        self.prefix = prefix if prefix is not None else ""

        if n_obs is not None:
            self.src_lens = self.src_lens[:n_obs]
        self.pad_token_id = self.tokenizer.pad_token_id
        self.dataset_kwargs = dataset_kwargs
        dataset_kwargs.update({"add_prefix_space": True} if isinstance(self.tokenizer, BartTokenizer) else {})

    def __len__(self):
        return len(self.src_lens)

    @staticmethod
    def get_char_lens(data_file):
        return [len(x) for x in Path(data_file).open().readlines()]

    @cached_property
    def tgt_lens(self):
        """Length in characters of target documents"""
        return self.get_char_lens(self.tgt_file)

    def make_sortish_sampler(self, batch_size, distributed=False, shuffle=True, **kwargs):
        if distributed:
            return DistributedSortishSampler(self, batch_size, shuffle=shuffle, **kwargs)
        else:
            return SortishSampler(self.src_lens, batch_size, shuffle=shuffle)

    def make_dynamic_sampler(self, max_tokens_per_batch=1024, **kwargs):
        assert FAIRSEQ_AVAILABLE, "Dynamic batch size requires `pip install fairseq`"
        assert not self.used_char_len, "You must call python make_len_file.py before calling make_dynamic_sampler"
        sorted_indices = list(self.make_sortish_sampler(1024, shuffle=False))

        def num_tokens_in_example(i):
            return min(self.src_lens[i], self.max_target_length)

        # call fairseq cython function
        batch_sampler: List[List[int]] = batch_by_size(
            sorted_indices,
            num_tokens_fn=num_tokens_in_example,
            max_tokens=max_tokens_per_batch,
            required_batch_size_multiple=64,
        )
        shuffled_batches = [batch_sampler[i] for i in np.random.permutation(range(len(batch_sampler)))]
        # move the largest batch to the front to OOM quickly (uses an approximation for padding)
        approximate_toks_per_batch = [max(self.src_lens[i] for i in batch) * len(batch) for batch in shuffled_batches]
        largest_batch_idx = np.argmax(approximate_toks_per_batch)
        shuffled_batches[0], shuffled_batches[largest_batch_idx] = (
            shuffled_batches[largest_batch_idx],
            shuffled_batches[0],
        )
        return shuffled_batches

    def __getitem__(self, item):
        raise NotImplementedError("You must implement this")

    def collate_fn(self, batch):
        raise NotImplementedError("You must implement this")


class LegacySeq2SeqDataset(AbstractSeq2SeqDataset):
    def __getitem__(self, index) -> Dict[str, torch.Tensor]:
        """Call tokenizer on src and tgt_lines"""

        index = index + 1  # linecache starts at 1
        source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip("\n")
        tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n")
        assert source_line, f"empty source line for index {index}"
        assert tgt_line, f"empty tgt line for index {index}"
        source_inputs = self.encode_line(self.tokenizer, source_line, self.max_source_length)
        target_inputs = self.encode_line(self.tokenizer, tgt_line, self.max_target_length)

        source_ids = source_inputs["input_ids"].squeeze()
        target_ids = target_inputs["input_ids"].squeeze()
        src_mask = source_inputs["attention_mask"].squeeze()
        return {
            "input_ids": source_ids,
            "attention_mask": src_mask,
            "labels": target_ids,
        }

    def encode_line(self, tokenizer, line, max_length, pad_to_max_length=True, return_tensors="pt"):
        """Only used by LegacyDataset"""
        return tokenizer(
            [line],
            max_length=max_length,
            padding="max_length" if pad_to_max_length else None,
            truncation=True,
            return_tensors=return_tensors,
            **self.dataset_kwargs,
        )

    def collate_fn(self, batch) -> Dict[str, torch.Tensor]:
        input_ids = torch.stack([x["input_ids"] for x in batch])
        masks = torch.stack([x["attention_mask"] for x in batch])
        target_ids = torch.stack([x["labels"] for x in batch])
        pad_token_id = self.pad_token_id
        y = trim_batch(target_ids, pad_token_id)
        source_ids, source_mask = trim_batch(input_ids, pad_token_id, attention_mask=masks)
        batch = {
            "input_ids": source_ids,
            "attention_mask": source_mask,
            "labels": y,
        }
        return batch


class Seq2SeqDataset(AbstractSeq2SeqDataset):
    """A dataset that calls prepare_seq2seq_batch."""

    def __getitem__(self, index) -> Dict[str, str]:

        #print(self.dataset_kwargs['model_t'])
        # if 't5' in self.dataset_kwargs['model_t']:
        #     self.prefix = 'translate Graph to English: '
        #     print('aac')
        #     exit()

        index = index + 1  # linecache starts at 1
        source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip("\n")
        tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n")
        assert source_line, f"empty source line for index {index}"
        assert tgt_line, f"empty tgt line for index {index}"
        return {"tgt_texts": tgt_line, "src_texts": source_line, "id": index - 1}

    def collate_fn(self, batch):
        """Call prepare_seq2seq_batch."""
        batch_encoding: Dict[str, torch.Tensor] = self.tokenizer.prepare_seq2seq_batch(
            [x["src_texts"] for x in batch],
            tgt_texts=[x["tgt_texts"] for x in batch],
            max_length=self.max_source_length,
            max_target_length=self.max_target_length,
            return_tensors="pt",
            **self.dataset_kwargs,
        ).data
        #lens = (batch_encoding['attention_mask'] == 1.).sum(dim=1).tolist()

        batch_encoding["ids"] = torch.tensor([x["id"] for x in batch])

        return batch_encoding



class Seq2SeqDataCollator:
    def __init__(self, tokenizer, data_args, tpu_num_cores=None):
        self.tokenizer = tokenizer
        self.pad_token_id = tokenizer.pad_token_id
        assert (
            self.pad_token_id is not None
        ), f"pad_token_id is not defined for ({self.tokenizer.__class__.__name__}), it must be defined."
        self.data_args = data_args
        self.tpu_num_cores = tpu_num_cores
        self.dataset_kwargs = {"add_prefix_space": isinstance(tokenizer, BartTokenizer)}
        if data_args.src_lang is not None:
            self.dataset_kwargs["src_lang"] = data_args.src_lang
        if data_args.tgt_lang is not None:
            self.dataset_kwargs["tgt_lang"] = data_args.tgt_lang

    def __call__(self, batch) -> Dict[str, torch.Tensor]:
        if hasattr(self.tokenizer, "prepare_seq2seq_batch"):
            batch = self._encode(batch)
            input_ids, attention_mask, labels = (
                batch["input_ids"],
                batch["attention_mask"],
                batch["labels"],
            )
        else:
            input_ids = torch.stack([x["input_ids"] for x in batch])
            attention_mask = torch.stack([x["attention_mask"] for x in batch])
            labels = torch.stack([x["labels"] for x in batch])

            labels = trim_batch(labels, self.pad_token_id)
            input_ids, attention_mask = trim_batch(input_ids, self.pad_token_id, attention_mask=attention_mask)

        if isinstance(self.tokenizer, T5Tokenizer):
            decoder_input_ids = self._shift_right_t5(labels)
        else:
            decoder_input_ids = shift_tokens_right(labels, self.pad_token_id)

        batch = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "decoder_input_ids": decoder_input_ids,
            "labels": labels,
        }
        return batch

    def _shift_right_t5(self, input_ids):
        # shift inputs to the right
        shifted_input_ids = input_ids.new_zeros(input_ids.shape)
        shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
        shifted_input_ids[..., 0] = self.pad_token_id
        return shifted_input_ids

    def _encode(self, batch) -> Dict[str, torch.Tensor]:
        batch_encoding = self.tokenizer.prepare_seq2seq_batch(
            [x["src_texts"] for x in batch],
            tgt_texts=[x["tgt_texts"] for x in batch],
            max_length=self.data_args.max_source_length,
            max_target_length=self.data_args.max_target_length,
            padding="max_length" if self.tpu_num_cores is not None else "longest",  # TPU hack
            return_tensors="pt",
            **self.dataset_kwargs,
        )
        return batch_encoding.data


class SortishSampler(Sampler):
    "Go through the text data by order of src length with a bit of randomness. From fastai repo."

    def __init__(self, data, batch_size, shuffle=True):
        self.data, self.bs, self.shuffle = data, batch_size, shuffle

    def __len__(self) -> int:
        return len(self.data)

    def __iter__(self):
        return iter(sortish_sampler_indices(self.data, self.bs, shuffle=self.shuffle))


def sortish_sampler_indices(data: List, bs: int, shuffle=True) -> np.array:
    "Go through the text data by order of src length with a bit of randomness. From fastai repo."
    if not shuffle:
        return np.argsort(np.array(data) * -1)

    def key_fn(i):
        return data[i]

    idxs = np.random.permutation(len(data))
    sz = bs * 50
    ck_idx = [idxs[i : i + sz] for i in range(0, len(idxs), sz)]
    sort_idx = np.concatenate([sorted(s, key=key_fn, reverse=True) for s in ck_idx])
    sz = bs
    ck_idx = [sort_idx[i : i + sz] for i in range(0, len(sort_idx), sz)]
    max_ck = np.argmax([key_fn(ck[0]) for ck in ck_idx])  # find the chunk with the largest key,
    ck_idx[0], ck_idx[max_ck] = ck_idx[max_ck], ck_idx[0]  # then make sure it goes first.
    sort_idx = np.concatenate(np.random.permutation(ck_idx[1:])) if len(ck_idx) > 1 else np.array([], dtype=np.int)
    sort_idx = np.concatenate((ck_idx[0], sort_idx))
    return sort_idx


class DistributedSortishSampler(Sampler):
    """Copied from torch DistributedSampler"""

    def __init__(self, dataset, batch_size, num_replicas=None, rank=None, add_extra_examples=True, shuffle=True):
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        if add_extra_examples:
            self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
            self.total_size = self.num_samples * self.num_replicas
        else:
            self.total_size = len(dataset)
            self.num_samples = len(self.available_indices)
        self.batch_size = batch_size
        self.add_extra_examples = add_extra_examples
        self.shuffle = shuffle

    def __iter__(self) -> Iterable:
        g = torch.Generator()
        g.manual_seed(self.epoch)

        sortish_data = [self.dataset.src_lens[i] for i in self.available_indices]
        sortish_indices = sortish_sampler_indices(sortish_data, self.batch_size, shuffle=self.shuffle)
        indices = [self.available_indices[i] for i in sortish_indices]
        assert len(indices) == self.num_samples
        return iter(indices)

    @cached_property
    def available_indices(self) -> np.array:
        indices = list(range(len(self.dataset)))
        # add extra samples to make it evenly divisible
        indices += indices[: (self.total_size - len(indices))]
        assert len(indices) == self.total_size
        # subsample
        available_indices = indices[self.rank : self.total_size : self.num_replicas]
        return available_indices

    def __len__(self):
        return self.num_samples

    def set_epoch(self, epoch):
        self.epoch = epoch


logger = getLogger(__name__)


def use_task_specific_params(model, task):
    """Update config with summarization specific params."""
    task_specific_params = model.config.task_specific_params

    if task_specific_params is not None:
        pars = task_specific_params.get(task, {})
        logger.info(f"using task specific params for {task}: {pars}")
        model.config.update(pars)


def pickle_load(path):
    """pickle.load(path)"""
    with open(path, "rb") as f:
        return pickle.load(f)


def pickle_save(obj, path):
    """pickle.dump(obj, path)"""
    with open(path, "wb") as f:
        return pickle.dump(obj, f)


def flatten_list(summary_ids: List[List]):
    return [x for x in itertools.chain.from_iterable(summary_ids)]


def save_json(content, path, indent=4, **json_dump_kwargs):
    with open(path, "w") as f:
        json.dump(content, f, indent=indent, **json_dump_kwargs)


def load_json(path):
    with open(path) as f:
        return json.load(f)


ROUGE_KEYS = ["rouge1", "rouge2", "rougeL", "rougeLsum"]


def extract_rouge_mid_statistics(dct):
    new_dict = {}
    for k1, v1 in dct.items():
        mid = v1.mid
        new_dict[k1] = {stat: round(getattr(mid, stat), 4) for stat in ["precision", "recall", "fmeasure"]}
    return new_dict


def calculate_rouge(
    pred_lns: List[str],
    tgt_lns: List[str],
    use_stemmer=True,
    rouge_keys=ROUGE_KEYS,
    return_precision_and_recall=False,
    bootstrap_aggregation=True,
    newline_sep=True,
) -> Dict:
    """Calculate rouge using rouge_scorer package.

    Args:
        pred_lns: list of summaries generated by model
        tgt_lns: list of groundtruth summaries (e.g. contents of val.target)
        use_stemmer: Bool indicating whether Porter stemmer should be used to
            strip word suffixes to improve matching.
        rouge_keys: which metrics to compute, defaults to rouge1, rouge2, rougeL, rougeLsum
        return_precision_and_recall: (False) whether to also return precision and recall.
        bootstrap_aggregation: whether to do the typical bootstrap resampling of scores. Defaults to True, if False
            this function returns a collections.defaultdict[metric: list of values for each observation for each subscore]``
        newline_sep:(default=True) whether to add newline between sentences. This is essential for calculation rougeL
            on multi sentence summaries (CNN/DM dataset).

    Returns:
        Dict[score: value] if aggregate else defaultdict(list) keyed by rouge_keys

    """
    scorer = rouge_scorer.RougeScorer(rouge_keys, use_stemmer=use_stemmer)
    aggregator = scoring.BootstrapAggregator()
    for pred, tgt in zip(tgt_lns, pred_lns):
        # rougeLsum expects "\n" separated sentences within a summary
        if newline_sep:
            pred = add_newline_to_end_of_each_sentence(pred)
            tgt = add_newline_to_end_of_each_sentence(tgt)
        scores = scorer.score(pred, tgt)
        aggregator.add_scores(scores)

    if bootstrap_aggregation:
        result = aggregator.aggregate()
        if return_precision_and_recall:
            return extract_rouge_mid_statistics(result)  # here we return dict
        else:
            return {k: round(v.mid.fmeasure * 100, 4) for k, v in result.items()}

    else:
        return aggregator._scores  # here we return defaultdict(list)


# Utilities for freezing parameters and checking whether they are frozen


def freeze_params(model: nn.Module):
    """Set requires_grad=False for each of model.parameters()"""
    for par in model.parameters():
        par.requires_grad = False


def freeze_embeds(model):
    """Freeze token embeddings and positional embeddings for bart, just token embeddings for t5."""
    model_type = model.config.model_type

    if model_type == "t5":
        freeze_params(model.shared)
        for d in [model.encoder, model.decoder]:
            freeze_params(d.embed_tokens)
    elif model_type == "fsmt":
        for d in [model.model.encoder, model.model.decoder]:
            freeze_params(d.embed_positions)
            freeze_params(d.embed_tokens)
    else:
        freeze_params(model.model.shared)
        for d in [model.model.encoder, model.model.decoder]:
            freeze_params(d.embed_positions)
            freeze_params(d.embed_tokens)


def grad_status(model: nn.Module) -> Iterable:
    return (par.requires_grad for par in model.parameters())


def any_requires_grad(model: nn.Module) -> bool:
    return any(grad_status(model))


def assert_all_frozen(model):
    model_grads: List[bool] = list(grad_status(model))
    n_require_grad = sum(lmap(int, model_grads))
    npars = len(model_grads)
    assert not any(model_grads), f"{n_require_grad/npars:.1%} of {npars} weights require grad"


def assert_not_all_frozen(model):
    model_grads: List[bool] = list(grad_status(model))
    npars = len(model_grads)
    assert any(model_grads), f"none of {npars} weights require grad"


def parse_numeric_n_bool_cl_kwargs(unparsed_args: List[str]) -> Dict[str, Union[int, float, bool]]:
    """
    Parse an argv list of unspecified command line args to a dict.
    Assumes all values are either numeric or boolean in the form of true/false.
    """
    result = {}
    assert len(unparsed_args) % 2 == 0, f"got odd number of unparsed args: {unparsed_args}"
    num_pairs = len(unparsed_args) // 2
    for pair_num in range(num_pairs):
        i = 2 * pair_num
        assert unparsed_args[i].startswith("--")
        if unparsed_args[i + 1].lower() == "true":
            value = True
        elif unparsed_args[i + 1].lower() == "false":
            value = False
        else:
            try:
                value = int(unparsed_args[i + 1])
            except ValueError:
                value = float(unparsed_args[i + 1])  # this can raise another informative ValueError

        result[unparsed_args[i][2:]] = value
    return result


def write_txt_file(ordered_tgt, path):
    f = Path(path).open("w")
    for ln in ordered_tgt:
        f.write(ln + "\n")
    f.flush()


def chunks(lst, n):
    """Yield successive n-sized chunks from lst."""
    for i in range(0, len(lst), n):
        yield lst[i : i + n]
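A short usage sketch (not part of the uploaded diff) for two of the metric helpers above; the strings are made-up. newline_sep=False is passed here because add_newline_to_end_of_each_sentence is not imported in this file, so the default newline_sep=True path would fail.

# sketch: corpus-level generation metrics
preds = ["Berlin is the capital of Germany."]
refs = ["Berlin is the capital and largest city of Germany."]
print(calculate_bleu(preds, refs))                      # {'sacrebleu': ...}
print(calculate_rouge(preds, refs, newline_sep=False))  # {'rouge1': ..., 'rouge2': ..., ...}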
utils/verbalisation_module.py
ADDED
|
@@ -0,0 +1,300 @@
from utils.finetune import Graph2TextModule
from typing import Dict, List, Tuple, Union, Optional
import torch
import re

if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'
    print('CUDA NOT AVAILABLE')

CHECKPOINT = 'base/t5-base_13881_val_avg_bleu=68.1000-step_count=5.ckpt'
MAX_LENGTH = 384
SEED = 42


class VerbModule():

    def __init__(self, override_args: Dict[str, str] = None):
        # Model
        if not override_args:
            override_args = {}
        self.g2t_module = Graph2TextModule.load_from_checkpoint(CHECKPOINT, strict=False, **override_args)
        self.tokenizer = self.g2t_module.tokenizer
        # Unk replacer
        self.vocab = self.tokenizer.get_vocab()
        self.convert_some_japanese_characters = True
        self.unk_char_replace_sliding_window_size = 2
        self.unknowns = []

    def __generate_verbalisations_from_inputs(self, inputs: Union[str, List[str]]):
        try:
            inputs_encoding = self.tokenizer.prepare_seq2seq_batch(
                inputs, truncation=True, max_length=MAX_LENGTH, return_tensors='pt'
            )
            inputs_encoding = {k: v.to(DEVICE) for k, v in inputs_encoding.items()}

            self.g2t_module.model.eval()
            with torch.no_grad():
                gen_output = self.g2t_module.model.generate(
                    inputs_encoding['input_ids'],
                    attention_mask=inputs_encoding['attention_mask'],
                    use_cache=True,
                    decoder_start_token_id = self.g2t_module.decoder_start_token_id,
                    num_beams= self.g2t_module.eval_beams,
                    max_length= self.g2t_module.eval_max_length,
                    length_penalty=1.0
                )
        except Exception:
            print(inputs)
            raise

        return gen_output

    '''
    We create this function as an alteration from [this one](https://github.com/huggingface/transformers/blob/198c335d219a5eb4d3f124fdd1ce1a9cd9f78a9b/src/transformers/tokenization_utils_fast.py#L537), mainly because the official 'tokenizer.decode' treats all special tokens the same, while we want to drop all special tokens from the decoded sentence EXCEPT for the <unk> token, which we will replace later on.
    '''
    def __decode_ids_to_string_custom(
        self, token_ids: List[int], skip_special_tokens: bool = False, clean_up_tokenization_spaces: bool = True
    ) -> str:
        filtered_tokens = self.tokenizer.convert_ids_to_tokens(token_ids, skip_special_tokens=False)
        # Do not remove special tokens yet

        # To avoid mixing byte-level and unicode for byte-level BPT
        # we need to build string separatly for added tokens and byte-level tokens
        # cf. https://github.com/huggingface/transformers/issues/1133
        sub_texts = []
        current_sub_text = []
        for token in filtered_tokens:
            if skip_special_tokens and\
                token != self.tokenizer.unk_token and\
                token in self.tokenizer.all_special_tokens:

                continue
            else:
                current_sub_text.append(token)
        if current_sub_text:
            sub_texts.append(self.tokenizer.convert_tokens_to_string(current_sub_text))
        text = " ".join(sub_texts)

        if clean_up_tokenization_spaces:
            clean_text = self.tokenizer.clean_up_tokenization(text)
            return clean_text
        else:
            return text

    def __decode_sentences(self, encoded_sentences: Union[str, List[str]]):
        if type(encoded_sentences) == str:
            encoded_sentences = [encoded_sentences]
        decoded_sentences = [self.__decode_ids_to_string_custom(i, skip_special_tokens=True) for i in encoded_sentences]
        return decoded_sentences

    def verbalise_sentence(self, inputs: Union[str, List[str]]):
        if type(inputs) == str:
            inputs = [inputs]

        gen_output = self.__generate_verbalisations_from_inputs(inputs)

        decoded_sentences = self.__decode_sentences(gen_output)

        if len(decoded_sentences) == 1:
            return decoded_sentences[0]
        else:
            return decoded_sentences

    def verbalise_triples(self, input_triples: Union[Dict[str, str], List[Dict[str, str]], List[List[Dict[str, str]]]]):
        if type(input_triples) == dict:
            input_triples = [input_triples]

        verbalisation_inputs = []
        for triple in input_triples:
            if type(triple) == dict:
                assert 'subject' in triple
                assert 'predicate' in triple
                assert 'object' in triple
                verbalisation_inputs.append(
                    f'translate Graph to English: <H> {triple["subject"]} <R> {triple["predicate"]} <T> {triple["object"]}'
                )
            elif type(triple) == list:
                input_sentence = ['translate Graph to English:']
                for subtriple in triple:
                    assert 'subject' in subtriple
                    assert 'predicate' in subtriple
                    assert 'object' in subtriple
                    input_sentence.append(f'<H> {subtriple["subject"]}')
                    input_sentence.append(f'<R> {subtriple["predicate"]}')
                    input_sentence.append(f'<T> {subtriple["object"]}')
                verbalisation_inputs.append(
                    ' '.join(input_sentence)
                )

        return self.verbalise_sentence(verbalisation_inputs)

    def verbalise(self, input: Union[str, List, Dict]):
        try:
            if (type(input) == str) or (type(input) == list and type(input[0]) == str):
                return self.verbalise_sentence(input)
            elif (type(input) == dict) or (type(input) == list and type(input[0]) == dict):
                return self.verbalise_triples(input)
            else:
                return self.verbalise_triples(input)
        except Exception:
            print(f'ERROR VERBALISING {input}')
            raise

    def add_label_to_unk_replacer(self, label: str):
        N = self.unk_char_replace_sliding_window_size
        self.unknowns.append({})

        # Some pre-processing of labels to normalise some characters
        if self.convert_some_japanese_characters:
            label = label.replace('(','(')
            label = label.replace(')',')')
            label = label.replace('〈','<')
            label = label.replace('/','/')
            label = label.replace('〉','>')

        label_encoded = self.tokenizer.encode(label)
        label_tokens = self.tokenizer.convert_ids_to_tokens(label_encoded)

        # Here, we also remove </s> (eos) and <pad> tokens in the replacing key, because:
        # 1) When the whole label is all unk:
        # label_token_to_string would be '<unk></s>', meaning the replacing key (which is the same) only replaces
        # the <unk> if it appears at the end of the sentence, which is not the desired effect.
        # But since this means ANY <unk> will be replaced by this, it would be good to only replace keys that are <unk>
        # on the last replacing pass.
        # 2) On other cases, then the unk is in the label but not in its entirety, like in the start/end, it might
        # involve the starting <pad> token or the ending <eos> token on the replacing key, again forcing the replacement
        # to only happen if the label appears in the end of the sentence.
        label_tokens = [t for t in label_tokens if t not in [
            self.tokenizer.eos_token, self.tokenizer.pad_token
        ]]

        label_token_to_string = self.tokenizer.convert_tokens_to_string(label_tokens)
        unk_token_to_string = self.tokenizer.convert_tokens_to_string([self.tokenizer.unk_token])

        #print(label_encoded,label_tokens,label_token_to_string)

        match_unks_in_label = re.findall('(?:(?: )*<unk>(?: )*)+', label_token_to_string)
        if len(match_unks_in_label) > 0:
            # If the whole label is made of UNK
            if (match_unks_in_label[0]) == label_token_to_string:
                #print('Label is all unks')
                self.unknowns[-1][label_token_to_string.strip()] = label
            # Else, there should be non-UNK characters in the label
            else:
                #print('Label is NOT all unks')
                # Analyse the label with a sliding window of size N (N before, N ahead)
                for idx, token in enumerate(label_tokens):
                    idx_before = max(0,idx-N)
                    idx_ahead = min(len(label_tokens), idx+N+1)


                    # Found a UNK
                    if token == self.tokenizer.unk_token:
|
| 196 |
+
|
| 197 |
+
# In case multiple UNK, exclude UNKs seen after this one, expand window to other side if possible
|
| 198 |
+
if len(match_unks_in_label) > 1:
|
| 199 |
+
#print(idx)
|
| 200 |
+
#print(label_tokens)
|
| 201 |
+
#print(label_tokens[idx_before:idx_ahead])
|
| 202 |
+
#print('HERE!')
|
| 203 |
+
# Reduce on the right, expanding on the left
|
| 204 |
+
while self.tokenizer.unk_token in label_tokens[idx+1:idx_ahead]:
|
| 205 |
+
idx_before = max(0,idx_before-1)
|
| 206 |
+
idx_ahead = min(idx+2, idx_ahead-1)
|
| 207 |
+
#print(label_tokens[idx_before:idx_ahead])
|
| 208 |
+
# Now just reduce on the left
|
| 209 |
+
while self.tokenizer.unk_token in label_tokens[idx_before:idx]:
|
| 210 |
+
idx_before = min(idx-1,idx_before+2)
|
| 211 |
+
#print(label_tokens[idx_before:idx_ahead])
|
| 212 |
+
|
| 213 |
+
span = self.tokenizer.convert_tokens_to_string(label_tokens[idx_before:idx_ahead])
|
| 214 |
+
# First token of the label is UNK
|
| 215 |
+
if idx == 1 and label_tokens[0] == '▁':
|
| 216 |
+
#print('Label begins with unks')
|
| 217 |
+
to_replace = '^' + re.escape(span).replace(
|
| 218 |
+
re.escape(unk_token_to_string),
|
| 219 |
+
'.+?'
|
| 220 |
+
)
|
| 221 |
+
|
| 222 |
+
replaced_span = re.search(
|
| 223 |
+
to_replace,
|
| 224 |
+
label
|
| 225 |
+
)[0]
|
| 226 |
+
self.unknowns[-1][span.strip()] = replaced_span
|
| 227 |
+
# Last token of the label is UNK
|
| 228 |
+
elif idx == len(label_tokens)-2 and label_tokens[-1] == self.tokenizer.eos_token:
|
| 229 |
+
#print('Label ends with unks')
|
| 230 |
+
pre_idx = self.tokenizer.convert_tokens_to_string(label_tokens[idx_before:idx])
|
| 231 |
+
pre_idx_unk_counts = pre_idx.count(unk_token_to_string)
|
| 232 |
+
to_replace = re.escape(span).replace(
|
| 233 |
+
re.escape(unk_token_to_string),
|
| 234 |
+
f'[^{re.escape(pre_idx)}]+?'
|
| 235 |
+
) + '$'
|
| 236 |
+
|
| 237 |
+
if pre_idx.strip() == '':
|
| 238 |
+
to_replace = to_replace.replace('[^]', '(?<=\s)[^a-zA-Z0-9]')
|
| 239 |
+
|
| 240 |
+
replaced_span = re.search(
|
| 241 |
+
to_replace,
|
| 242 |
+
label
|
| 243 |
+
)[0]
|
| 244 |
+
self.unknowns[-1][span.strip()] = replaced_span
|
| 245 |
+
|
| 246 |
+
# A token in-between the label is UNK
|
| 247 |
+
else:
|
| 248 |
+
#print('Label has unks in the middle')
|
| 249 |
+
pre_idx = self.tokenizer.convert_tokens_to_string(label_tokens[idx_before:idx])
|
| 250 |
+
|
| 251 |
+
to_replace = re.escape(span).replace(
|
| 252 |
+
re.escape(unk_token_to_string),
|
| 253 |
+
f'[^{re.escape(pre_idx)}]+?'
|
| 254 |
+
)
|
| 255 |
+
#If there is nothing behind the ??, because it is in the middle but the previous token is also
|
| 256 |
+
#a ??, then we would end up with to_replace beginning with [^], which we can't have
|
| 257 |
+
if pre_idx.strip() == '':
|
| 258 |
+
to_replace = to_replace.replace('[^]', '(?<=\s)[^a-zA-Z0-9]')
|
| 259 |
+
|
| 260 |
+
replaced_span = re.search(
|
| 261 |
+
to_replace,
|
| 262 |
+
label
|
| 263 |
+
)
|
| 264 |
+
|
| 265 |
+
if replaced_span:
|
| 266 |
+
span = re.sub(r'\s([?.!",](?:\s|$))', r'\1', span.strip())
|
| 267 |
+
self.unknowns[-1][span] = replaced_span[0]
|
| 268 |
+
|
| 269 |
+
def replace_unks_on_sentence(self, sentence: str, loop_n : int = 3, empty_after : bool = False):
|
| 270 |
+
# Loop through in case the labels are repeated, maximum of three times
|
| 271 |
+
while '<unk>' in sentence and loop_n > 0:
|
| 272 |
+
loop_n -= 1
|
| 273 |
+
for unknowns in self.unknowns:
|
| 274 |
+
for k,v in unknowns.items():
|
| 275 |
+
# Leave to replace all-unk labels at the last pass
|
| 276 |
+
if k == '<unk>' and loop_n > 0:
|
| 277 |
+
continue
|
| 278 |
+
# In case it is because the first letter of the sentence has been uppercased
|
| 279 |
+
if not k in sentence and k[0] == k[0].lower() and k[0].upper() == sentence[0]:
|
| 280 |
+
k = k[0].upper() + k[1:]
|
| 281 |
+
v = v[0].upper() + v[1:]
|
| 282 |
+
# In case it is because a double space is found where it should not be
|
| 283 |
+
elif not k in sentence and len(re.findall(r'\s{2,}',k))>0:
|
| 284 |
+
k = re.sub(r'\s+', ' ', k)
|
| 285 |
+
#print(k,'/',v,'/',sentence)
|
| 286 |
+
sentence = sentence.replace(k.strip(),v.strip(),1)
|
| 287 |
+
#sentence = re.sub(k, v, sentence)
|
| 288 |
+
# Removing final doublespaces
|
| 289 |
+
sentence = re.sub(r'\s+', ' ', sentence).strip()
|
| 290 |
+
# Removing spaces before punctuation
|
| 291 |
+
sentence = re.sub(r'\s([?.!",](?:\s|$))', r'\1', sentence)
|
| 292 |
+
if empty_after:
|
| 293 |
+
self.unknowns = []
|
| 294 |
+
return sentence
|
| 295 |
+
|
| 296 |
+
if __name__ == '__main__':
|
| 297 |
+
|
| 298 |
+
verb_module = VerbModule()
|
| 299 |
+
verbs = verb_module.verbalise('translate Graph to English: <H> World Trade Center <R> height <T> 200 meter <H> World Trade Center <R> is a <T> tower')
|
| 300 |
+
print(verbs)
|
utils/wikidata_utils.py
ADDED
@@ -0,0 +1,173 @@
import json
import random
import uuid
import numpy as np
import time
import requests
import traceback
import pdb
import math
import ast
import pandas as pd
import pickle
from qwikidata.linked_data_interface import get_entity_dict_from_api
from qwikidata.sparql import return_sparql_query_results

from urllib3.exceptions import MaxRetryError, ConnectionError
from qwikidata.linked_data_interface import LdiResponseNotOk

import hashlib

class CachedWikidataAPI():

    def __init__(self, cache_path = 'entity_cache.p', save_every_x_queries=1):
        self.save_every_x_queries = save_every_x_queries
        self.x_queries_passed = 0
        self.languages = ['en','fr','es','pt','pt-br','it','de']
        self.cache_path = cache_path
        try:
            with open(self.cache_path,'rb') as f:
                self.entity_cache = pickle.load(f)
        except FileNotFoundError:
            self.entity_cache = {}

    def get_unique_id_from_str(self, my_str):
        return hashlib.md5(str.encode(my_str)).hexdigest()

    def save_entity_cache(self, force=False):
        if force:
            self.x_queries_passed = self.save_every_x_queries
        self.x_queries_passed = self.x_queries_passed+1
        if self.x_queries_passed >= self.save_every_x_queries:
            with open(self.cache_path,'wb') as f:
                pickle.dump(self.entity_cache,f)
            self.x_queries_passed = 0

    def get_entity(self, item_id):
        if item_id in self.entity_cache:
            return self.entity_cache[item_id]
        while True:
            try:
                entity = get_entity_dict_from_api(item_id)
                self.entity_cache[item_id] = entity
                self.save_entity_cache()
                return entity
            except (ConnectionError, MaxRetryError) as e:
                #traceback.print_exc()
                time.sleep(1)
                continue
            except LdiResponseNotOk:
                #traceback.print_exc()
                self.entity_cache[item_id] = 'deleted'
                self.save_entity_cache()
                return 'deleted'

    def get_label(self, item, non_language_set=False):
        if type(item) == str:
            entity = self.get_entity(item)
            if entity == 'deleted':
                return (entity, 'none')
            labels = entity['labels' if 'labels' in entity else 'lemmas']
        elif type(item) == dict:
            if 'labels' in item:
                labels = item['labels']
            elif 'lemmas' in item:
                labels = item['lemmas']
        for l in self.languages:
            if l in labels:
                return (labels[l]['value'], l)
        if non_language_set:
            all_labels = list(labels.keys())
            if len(all_labels)>0:
                return (labels[all_labels[0]]['value'], all_labels[0])
        return ('no-label', 'none')

    def get_desc(self, item, non_language_set=False):
        if type(item) == str:
            entity = self.get_entity(item)
            if entity == 'deleted':
                return (entity, 'none')
            descriptions = entity['descriptions']
        elif type(item) == dict:
            if 'descriptions' in item:
                descriptions = item['descriptions']
        for l in self.languages:
            if l in descriptions:
                return (descriptions[l]['value'], l)
        if non_language_set:
            all_descriptions = list(descriptions.keys())
            if len(all_descriptions)>0:
                return (descriptions[all_descriptions[0]]['value'], all_descriptions[0])
        return ('no-desc', 'none')

    def get_alias(self, item, non_language_set=False):
        if type(item) == str:
            entity = self.get_entity(item)
            if entity == 'deleted':
                return ([entity], 'none')
            aliases = entity['aliases']
        elif type(item) == dict:
            if 'aliases' in item:
                aliases = item['aliases']
        for l in self.languages:
            if l in aliases:
                return ([alias['value'] for alias in aliases[l]], l)
        if non_language_set:
            all_aliases = list(aliases.keys())
            if len(all_aliases)>0:
                # Aliases are stored as a list of {'value': ...} dicts per language,
                # so return all alias values for the first available language
                return ([alias['value'] for alias in aliases[all_aliases[0]]], all_aliases[0])
        return ('no-alias', 'none')

    def get_datatype(self, item):
        try:
            if type(item) == str:
                entity = self.get_entity(item)
                if entity == 'deleted':
                    return entity
                datatype = entity['datatype']
            elif type(item) == dict:
                datatype = item['datatype']
            return datatype
        except KeyError:
            return 'none'

    def get_claim_values_of(self, item, property_id):
        if type(item) == str:
            entity = self.get_entity(item)
            if entity == 'deleted':
                return entity
            claims = entity['claims']
        elif type(item) == dict:
            claims = item['claims']
        if property_id in claims:
            instance_of_claims = claims[property_id]
            return [i['mainsnak']['datavalue']['value']['id'] for i in instance_of_claims]
        else:
            return []

    def query_sparql_endpoint(self, sparql_query):
        sparql_query_id = self.get_unique_id_from_str(sparql_query)
        if sparql_query_id in self.entity_cache:
            return self.entity_cache[sparql_query_id]
        else:
            wikidata_sparql_url = 'https://query.wikidata.org/sparql'
            try:
                while True:
                    res = requests.get(wikidata_sparql_url, params={"query": sparql_query, "format": "json"})
                    if res.status_code in (429,504):
                        time.sleep(1)
                        continue
                    elif res.status_code == 200:
                        res = res.json()
                        self.entity_cache[sparql_query_id] = res
                        self.save_entity_cache()
                        return res
                    else:
                        print(res.status_code)
                        raise Exception
            except json.JSONDecodeError as e:
                #pdb.set_trace()
                print(res, res.__dict__)
                raise e
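A brief usage sketch of the cached client above (not part of the uploaded file), assuming network access to the Wikidata APIs, write permission for the pickle cache file, and that the repository root is on the Python path; the Q-id and SPARQL query are illustrative.

# Illustrative sketch only; mirrors the class defined above.
from utils.wikidata_utils import CachedWikidataAPI

api = CachedWikidataAPI(cache_path='entity_cache.p', save_every_x_queries=10)

# Label/description lookup falls back through the configured language list
# and returns a (value, language) tuple, e.g. ('Douglas Adams', 'en') for Q42.
label, lang = api.get_label('Q42')
desc, _ = api.get_desc('Q42')
print(label, '-', desc)

# Claim values of P31 (instance of) come back as a list of Q-ids
print(api.get_claim_values_of('Q42', 'P31'))

# SPARQL results are cached under an MD5 hash of the query text
res = api.query_sparql_endpoint('SELECT ?p ?o WHERE { wd:Q42 ?p ?o } LIMIT 5')
print(len(res['results']['bindings']))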