Enzo Reis de Oliveira committed
Commit · b60e08a
1 Parent(s): 30f063f
Fixing again
This view is limited to 50 files because it contains too many changes; see the raw diff for the remaining files.
- smi-ted/inference/smi_ted_light/.gitattributes +2 -0
- smi-ted/inference/smi_ted_light/fast_transformers/__init__.py +15 -0
- smi-ted/inference/smi_ted_light/fast_transformers/aggregate/__init__.py +128 -0
- smi-ted/inference/smi_ted_light/fast_transformers/aggregate/aggregate_cpu.cpython-39-x86_64-linux-gnu.so +3 -0
- smi-ted/inference/smi_ted_light/fast_transformers/attention/__init__.py +20 -0
- smi-ted/inference/smi_ted_light/fast_transformers/attention/__pycache__/__init__.cpython-310.pyc +0 -0
- smi-ted/inference/smi_ted_light/fast_transformers/attention/__pycache__/attention_layer.cpython-310.pyc +0 -0
- smi-ted/inference/smi_ted_light/fast_transformers/attention/__pycache__/full_attention.cpython-310.pyc +0 -0
- smi-ted/inference/smi_ted_light/fast_transformers/attention/__pycache__/linear_attention.cpython-310.pyc +0 -0
- smi-ted/inference/smi_ted_light/fast_transformers/attention/attention_layer.py +113 -0
- smi-ted/inference/smi_ted_light/fast_transformers/attention/causal_linear_attention.py +116 -0
- smi-ted/inference/smi_ted_light/fast_transformers/attention/clustered_attention.py +195 -0
- smi-ted/inference/smi_ted_light/fast_transformers/attention/conditional_full_attention.py +66 -0
- smi-ted/inference/smi_ted_light/fast_transformers/attention/exact_topk_attention.py +88 -0
- smi-ted/inference/smi_ted_light/fast_transformers/attention/full_attention.py +95 -0
- smi-ted/inference/smi_ted_light/fast_transformers/attention/improved_clustered_attention.py +268 -0
- smi-ted/inference/smi_ted_light/fast_transformers/attention/improved_clustered_causal_attention.py +257 -0
- smi-ted/inference/smi_ted_light/fast_transformers/attention/linear_attention.py +92 -0
- smi-ted/inference/smi_ted_light/fast_transformers/attention/local_attention.py +101 -0
- smi-ted/inference/smi_ted_light/fast_transformers/attention/reformer_attention.py +166 -0
- smi-ted/inference/smi_ted_light/fast_transformers/attention_registry/__init__.py +17 -0
- smi-ted/inference/smi_ted_light/fast_transformers/attention_registry/__pycache__/__init__.cpython-310.pyc +0 -0
- smi-ted/inference/smi_ted_light/fast_transformers/attention_registry/__pycache__/registry.cpython-310.pyc +0 -0
- smi-ted/inference/smi_ted_light/fast_transformers/attention_registry/__pycache__/spec.cpython-310.pyc +0 -0
- smi-ted/inference/smi_ted_light/fast_transformers/attention_registry/registry.py +61 -0
- smi-ted/inference/smi_ted_light/fast_transformers/attention_registry/spec.py +126 -0
- smi-ted/inference/smi_ted_light/fast_transformers/builders/__init__.py +59 -0
- smi-ted/inference/smi_ted_light/fast_transformers/builders/__pycache__/__init__.cpython-310.pyc +0 -0
- smi-ted/inference/smi_ted_light/fast_transformers/builders/__pycache__/attention_builders.cpython-310.pyc +0 -0
- smi-ted/inference/smi_ted_light/fast_transformers/builders/__pycache__/base.cpython-310.pyc +0 -0
- smi-ted/inference/smi_ted_light/fast_transformers/builders/__pycache__/transformer_builders.cpython-310.pyc +0 -0
- smi-ted/inference/smi_ted_light/fast_transformers/builders/attention_builders.py +139 -0
- smi-ted/inference/smi_ted_light/fast_transformers/builders/base.py +67 -0
- smi-ted/inference/smi_ted_light/fast_transformers/builders/transformer_builders.py +550 -0
- smi-ted/inference/smi_ted_light/fast_transformers/causal_product/__init__.py +78 -0
- smi-ted/inference/smi_ted_light/fast_transformers/causal_product/causal_product_cpu.cpython-39-x86_64-linux-gnu.so +3 -0
- smi-ted/inference/smi_ted_light/fast_transformers/clustering/__init__.py +0 -0
- smi-ted/inference/smi_ted_light/fast_transformers/clustering/hamming/__init__.py +115 -0
- smi-ted/inference/smi_ted_light/fast_transformers/clustering/hamming/cluster_cpu.cpython-39-x86_64-linux-gnu.so +3 -0
- smi-ted/inference/smi_ted_light/fast_transformers/events/__init__.py +10 -0
- smi-ted/inference/smi_ted_light/fast_transformers/events/__pycache__/__init__.cpython-310.pyc +0 -0
- smi-ted/inference/smi_ted_light/fast_transformers/events/__pycache__/event.cpython-310.pyc +0 -0
- smi-ted/inference/smi_ted_light/fast_transformers/events/__pycache__/event_dispatcher.cpython-310.pyc +0 -0
- smi-ted/inference/smi_ted_light/fast_transformers/events/__pycache__/filters.cpython-310.pyc +0 -0
- smi-ted/inference/smi_ted_light/fast_transformers/events/event.py +51 -0
- smi-ted/inference/smi_ted_light/fast_transformers/events/event_dispatcher.py +92 -0
- smi-ted/inference/smi_ted_light/fast_transformers/events/filters.py +141 -0
- smi-ted/inference/smi_ted_light/fast_transformers/feature_maps/__init__.py +12 -0
- smi-ted/inference/smi_ted_light/fast_transformers/feature_maps/__pycache__/__init__.cpython-310.pyc +0 -0
- smi-ted/inference/smi_ted_light/fast_transformers/feature_maps/__pycache__/base.cpython-310.pyc +0 -0
smi-ted/inference/smi_ted_light/.gitattributes
ADDED
@@ -0,0 +1,2 @@
+smi-ted/inference/smi_ted_light/fast_transformers/**/*.so filter=lfs diff=lfs merge=lfs -text
+*.so filter=lfs diff=lfs merge=lfs -text
smi-ted/inference/smi_ted_light/fast_transformers/__init__.py
ADDED
@@ -0,0 +1,15 @@
+#
+# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
+# Written by Angelos Katharopoulos <angelos.katharopoulos@idiap.ch>,
+# Apoorv Vyas <avyas@idiap.ch>
+#
+
+"""Provide a library with fast transformer implementations."""
+
+__author__ = "Angelos Katharopoulos, Apoorv Vyas"
+__copyright__ = "Copyright (c) 2020 Idiap Research Institute"
+__license__ = "MIT"
+__maintainer__ = "Angelos Katharopoulos, Apoorv Vyas"
+__email__ = "angelos.katharopoulos@idiap.ch, avyas@idiap.ch"
+__url__ = "https://github.com/idiap/fast-transformers"
+__version__ = "0.4.0"
smi-ted/inference/smi_ted_light/fast_transformers/aggregate/__init__.py
ADDED
@@ -0,0 +1,128 @@
+#
+# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
+# Written by Angelos Katharopoulos <angelos.katharopoulos@idiap.ch>,
+# Apoorv Vyas <avyas@idiap.ch>
+#
+
+
+import torch
+
+from .aggregate_cpu import aggregate as aggregate_cpu, \
+    broadcast as broadcast_cpu
+try:
+    from .aggregate_cuda import aggregate as aggregate_gpu, \
+        broadcast as broadcast_gpu
+    from .clustered_aggregate_cuda import \
+        clustered_broadcast as clustered_broadcast_gpu, \
+        clustered_aggregate as clustered_aggregate_gpu
+
+except ImportError:
+    pass
+
+
+def aggregate(X, G, F, Y=None):
+    device = X.device
+    if Y is None:
+        Y = torch.zeros(
+            F.shape + (X.shape[-1],),
+            device=device,
+            dtype=X.dtype
+        )
+    else:
+        Y.zero_()
+
+    if device.type == "cpu":
+        aggregate_cpu(X, G, F, Y)
+    else:
+        aggregate_gpu(X, G, F, Y)
+
+    return Y
+
+
+def broadcast(Y, G, F, X=None):
+    device = Y.device
+    if X is None:
+        X = torch.zeros(
+            G.shape + (Y.shape[-1],),
+            device=device,
+            dtype=Y.dtype
+        )
+
+    if device.type == "cpu":
+        broadcast_cpu(Y, G, F, X)
+    else:
+        broadcast_gpu(Y, G, F, X)
+
+    return X
+
+
+# Divide the cluster into groups of equal size
+# as constrained by the shared memory
+def set_group(C, E):
+    C_per_block = int(192 * 64 / (E+1))
+    G_min = (C + C_per_block - 1) // C_per_block
+    for G in range(G_min, C+1):
+        if C % G == 0:
+            return G
+
+
+def clustered_broadcast(Y, groups, counts, factors, X=None):
+    device = Y.device
+    if X is None:
+        X = torch.zeros(
+            groups.shape + (Y.shape[-1],),
+            device=device,
+            dtype=Y.dtype
+        )
+    if device.type == "cpu":
+        broadcast_cpu(Y, groups, factors, X)
+    else:
+        N, H, C, E = Y.shape
+        _, _, L, _ = X.shape
+
+        # Following are some booking keeping parameters to facilitate the
+        # broadcast kernel that takes advantage of clustering
+        # More information can be found in the cuda file
+        with torch.no_grad():
+            threads = 256
+            G = set_group(C, E)
+            group_counts = counts.view(N, H, G, -1).sum(-1)
+            block_counts = (group_counts + threads - 1) // threads
+            total_blocks = block_counts.sum().item()
+            indx_maps = torch.ones(
+                (total_blocks, 5),
+                device=X.device,
+                dtype=torch.int32
+            )
+
+        clustered_broadcast_gpu(
+            Y,
+            groups,
+            factors,
+            X,
+            block_counts.int(),
+            group_counts.int(),
+            threads,
+            G,
+            total_blocks,
+            indx_maps
+        )
+    return X
+
+
+def clustered_aggregate(X, G, F, lengths, Y=None):
+    device = X.device
+    if Y is None:
+        Y = torch.zeros(
+            F.shape + (X.shape[-1],),
+            device=device,
+            dtype=X.dtype
+        )
+    else:
+        Y.zero_()
+
+    if device.type == "cpu":
+        aggregate_cpu(X, G, F, Y)
+    else:
+        clustered_aggregate_gpu(X, G, F, lengths, Y)
+    return Y
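The commit only vendors the compiled kernels, so as a reading aid here is a minimal pure-PyTorch sketch of what aggregate and broadcast appear to compute, reduced to a single (batch, head) slice. It is not part of the commit; the function names and 2-D shapes are illustrative assumptions, and the real CPU/CUDA extensions handle the full batched layout.

import torch

def aggregate_reference(X, G, F):
    """Sum the rows of X that share a group id into one row per group, scaled by F.

    X: (L, E) features, G: (L,) integer group ids in [0, C), F: (C,) per-group
    factors (e.g. 1/count to produce a mean).  Returns Y: (C, E).
    """
    C = F.shape[0]
    Y = torch.zeros(C, X.shape[-1], dtype=X.dtype)
    Y.index_add_(0, G, X)              # sum the members of each group
    return Y * F[:, None]              # scale each group row

def broadcast_reference(Y, G, F):
    """Copy each group's row back to all of its members, scaled by F."""
    return Y[G] * F[G][:, None]

# Toy check: two groups over four positions.
X = torch.arange(8, dtype=torch.float32).view(4, 2)
G = torch.tensor([0, 0, 1, 1])
F = torch.tensor([0.5, 0.5])           # 1/count -> per-group means
Y = aggregate_reference(X, G, F)       # (2, 2) group means
X_back = broadcast_reference(Y, G, torch.ones(2))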
smi-ted/inference/smi_ted_light/fast_transformers/aggregate/aggregate_cpu.cpython-39-x86_64-linux-gnu.so
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6bccb1a374d4649aaef6361cc41c9ffb471086464cc07a0d6d21c5b65adb0711
+size 138248
smi-ted/inference/smi_ted_light/fast_transformers/attention/__init__.py
ADDED
@@ -0,0 +1,20 @@
+#
+# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
+# Written by Angelos Katharopoulos <angelos.katharopoulos@idiap.ch>,
+# Apoorv Vyas <avyas@idiap.ch>
+#
+
+"""Implementations of different types of attention mechanisms."""
+
+
+from .attention_layer import AttentionLayer
+from .full_attention import FullAttention
+from .linear_attention import LinearAttention
+#from .causal_linear_attention import CausalLinearAttention
+#from .clustered_attention import ClusteredAttention
+#from .improved_clustered_attention import ImprovedClusteredAttention
+#from .reformer_attention import ReformerAttention
+#from .conditional_full_attention import ConditionalFullAttention
+#from .exact_topk_attention import ExactTopKAttention
+#from .improved_clustered_causal_attention import ImprovedClusteredCausalAttention
+#from .local_attention import LocalAttention
smi-ted/inference/smi_ted_light/fast_transformers/attention/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (502 Bytes)

smi-ted/inference/smi_ted_light/fast_transformers/attention/__pycache__/attention_layer.cpython-310.pyc
ADDED
Binary file (4.14 kB)

smi-ted/inference/smi_ted_light/fast_transformers/attention/__pycache__/full_attention.cpython-310.pyc
ADDED
Binary file (3.32 kB)

smi-ted/inference/smi_ted_light/fast_transformers/attention/__pycache__/linear_attention.cpython-310.pyc
ADDED
Binary file (2.96 kB)
smi-ted/inference/smi_ted_light/fast_transformers/attention/attention_layer.py
ADDED
@@ -0,0 +1,113 @@
+#
+# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
+# Written by Angelos Katharopoulos <angelos.katharopoulos@idiap.ch>,
+# Apoorv Vyas <avyas@idiap.ch>
+#
+
+"""The base attention layer performs all the query key value projections and
+output projections leaving the implementation of the attention to the inner
+attention module.
+
+The transformer layers, however, are agnostic of the attention implementation
+and any layer that implements the same interface can substitute for the
+attention layer.
+"""
+
+from torch.nn import Linear, Module
+
+from ..events import EventDispatcher, QKVEvent
+
+
+class AttentionLayer(Module):
+    """Implement the attention layer. Namely project the inputs to multi-head
+    queries, keys and values, call the attention implementation and then
+    reproject the output.
+
+    It can be thought of as a decorator (see decorator design patter) of an
+    attention layer.
+
+    Arguments
+    ---------
+        attention: Specific inner attention implementation that just computes a
+                   weighted average of values given a similarity of queries and
+                   keys.
+        d_model: The input feature dimensionality
+        n_heads: The number of heads for the multi head attention
+        d_keys: The dimensionality of the keys/queries
+                (default: d_model/n_heads)
+        d_values: The dimensionality of the values (default: d_model/n_heads)
+        event_dispatcher: str or EventDispatcher instance to be used by this
+                          module for dispatching events (default: the default
+                          global dispatcher)
+    """
+    def __init__(self, attention, d_model, n_heads, d_keys=None,
+                 d_values=None, event_dispatcher=""):
+        super(AttentionLayer, self).__init__()
+
+        # Fill d_keys and d_values
+        d_keys = d_keys or (d_model//n_heads)
+        d_values = d_values or (d_model//n_heads)
+
+        self.inner_attention = attention
+        self.query_projection = Linear(d_model, d_keys * n_heads)
+        self.key_projection = Linear(d_model, d_keys * n_heads)
+        self.value_projection = Linear(d_model, d_values * n_heads)
+        self.out_projection = Linear(d_values * n_heads, d_model)
+        self.n_heads = n_heads
+        self.event_dispatcher = EventDispatcher.get(event_dispatcher)
+
+    def forward(self, queries, keys, values, attn_mask, query_lengths,
+                key_lengths):
+        """Apply attention to the passed in queries/keys/values after
+        projecting them to multiple heads.
+
+        In the argument description we make use of the following sizes
+
+            - N: the batch size
+            - L: The maximum length of the queries
+            - S: The maximum length of the keys (the actual length per sequence
+              is given by the length mask)
+            - D: The input feature dimensionality passed in the constructor as
+              'd_model'
+
+        Arguments
+        ---------
+            queries: (N, L, D) The tensor containing the queries
+            keys: (N, S, D) The tensor containing the keys
+            values: (N, S, D) The tensor containing the values
+            attn_mask: An implementation of BaseMask that encodes where each
+                       query can attend to
+            query_lengths: An implementation of BaseMask that encodes how
+                           many queries each sequence in the batch consists of
+            key_lengths: An implementation of BaseMask that encodes how
+                         many queries each sequence in the batch consists of
+
+        Returns
+        -------
+            The new value for each query as a tensor of shape (N, L, D).
+        """
+        # Extract the dimensions into local variables
+        N, L, _ = queries.shape
+        _, S, _ = keys.shape
+        H = self.n_heads
+
+        # Project the queries/keys/values
+        queries = self.query_projection(queries).view(N, L, H, -1)
+        keys = self.key_projection(keys).view(N, S, H, -1)
+        values = self.value_projection(values).view(N, S, H, -1)
+
+        # Let the world know of the qkv
+        self.event_dispatcher.dispatch(QKVEvent(self, queries, keys, values))
+
+        # Compute the attention
+        new_values = self.inner_attention(
+            queries,
+            keys,
+            values,
+            attn_mask,
+            query_lengths,
+            key_lengths
+        ).view(N, L, -1)
+
+        # Project the output and return
+        return self.out_projection(new_values)
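For orientation, a typical self-attention call through this layer looks roughly like the sketch below. The import root and the FullMask/LengthMask constructors are assumptions: the masking module is part of fast-transformers and is imported by other files in this diff, but it is not among the 50 files shown, and the actual import path depends on how the Space puts the vendored package on sys.path.

import torch
from fast_transformers.attention import AttentionLayer, FullAttention
from fast_transformers.masking import FullMask, LengthMask  # assumed to ship with the vendored package

N, L, d_model, n_heads = 2, 10, 64, 4
layer = AttentionLayer(FullAttention(), d_model, n_heads)

x = torch.randn(N, L, d_model)
attn_mask = FullMask(L, L)                                    # every query may attend to every key
lengths = LengthMask(torch.full((N,), L, dtype=torch.int64))  # no padding in this toy batch

out = layer(x, x, x, attn_mask, lengths, lengths)             # self-attention, shape (N, L, d_model)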
smi-ted/inference/smi_ted_light/fast_transformers/attention/causal_linear_attention.py
ADDED
@@ -0,0 +1,116 @@
+#
+# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
+# Written by Angelos Katharopoulos <angelos.katharopoulos@idiap.ch>,
+# Apoorv Vyas <avyas@idiap.ch>
+#
+
+"""Implement causally masked linear attention."""
+
+import torch
+from torch.nn import Module
+
+from ..attention_registry import AttentionRegistry, Optional, Callable, Int, \
+    EventDispatcherInstance
+from ..events import EventDispatcher
+from ..causal_product import causal_dot_product
+from ..feature_maps import elu_feature_map
+
+
+def causal_linear(Q, K, V):
+    Q = Q.permute(0,2,1,3).contiguous()
+    K = K.permute(0,2,1,3).contiguous()
+    V = V.permute(0,2,1,3).contiguous()
+    V_new = causal_dot_product(Q, K, V)
+    return V_new.permute(0,2,1,3).contiguous()
+
+
+class CausalLinearAttention(Module):
+    """Implement causally masked attention using dot product of feature maps in
+    O(N D^2) complexity.
+
+    See fast_transformers.attention.linear_attention.LinearAttention for the
+    general concept of replacing the softmax with feature maps. In addition to
+    that, we also make use of the fact that causal masking is a triangular mask
+    which allows us to apply the masking and still compute the attention in O(N
+    D^2) complexity.
+
+    Arguments
+    ---------
+        feature_map: callable, a callable that applies the feature map to the
+                     last dimension of a tensor (default: elu(x)+1)
+        eps: float, a small number to ensure the numerical stability of the
+             denominator (default: 1e-6)
+        event_dispatcher: str or EventDispatcher instance to be used by this
+                          module for dispatching events (default: the default
+                          global dispatcher)
+    """
+    def __init__(self, query_dimensions, feature_map=None, eps=1e-6,
+                 event_dispatcher=""):
+        super(CausalLinearAttention, self).__init__()
+        self.feature_map = (
+            feature_map(query_dimensions) if feature_map else
+            elu_feature_map(query_dimensions)
+        )
+        self.eps = eps
+        self.event_dispatcher = EventDispatcher.get(event_dispatcher)
+
+    def _make_sizes_compatible(self, Q, K):
+        """Either slice or pad K in case that the sizes do not match between Q
+        and K."""
+        N, L, H, E = Q.shape
+        _, S, _, _ = K.shape
+        if L == S:
+            return Q, K
+
+        if L < S:
+            return Q, K[:, :L, :, :]
+
+        if L > S:
+            return Q, torch.cat([K, K.new_zeros(N, L-S, H, E)], dim=1)
+
+    def forward(self, queries, keys, values, attn_mask, query_lengths,
+                key_lengths):
+        # Apply the feature map to the queries and keys
+        self.feature_map.new_feature_map(queries.device)
+        Q = self.feature_map.forward_queries(queries)
+        K = self.feature_map.forward_keys(keys)
+
+        # Apply the key padding mask and make sure the attn_mask is a
+        # lower triangular causal mask
+        if not attn_mask.lower_triangular:
+            raise RuntimeError(("CausalLinearAttention only supports full "
+                                "lower triangular masks"))
+        K = K * key_lengths.float_matrix[:, :, None, None]
+
+        # Ensure that Q and K have compatible sizes for the following
+        # computations, namely L == S
+        Q, K = self._make_sizes_compatible(Q, K)
+
+        # TODO: Shall we divide the Q and K with a relatively large number to
+        #       avoid numerical instabilities in computing the denominator?
+        #       We used to divide each with the max norm of all q and k but
+        #       that seems relatively costly for a simple normalization.
+
+        # Compute the normalizers
+        Z = 1/(torch.einsum("nlhi,nlhi->nlh", Q, K.cumsum(1)) + self.eps)
+
+        # Compute the unnormalized result
+        V = causal_linear(
+            Q,
+            K,
+            values
+        )
+
+        return V * Z[:, :, :, None]
+
+
+# Register the attention implementation so that it becomes available in our
+# builders
+AttentionRegistry.register(
+    "causal-linear", CausalLinearAttention,
+    [
+        ("query_dimensions", Int),
+        ("feature_map", Optional(Callable)),
+        ("event_dispatcher", Optional(EventDispatcherInstance, ""))
+    ]
+)
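To make the normalizer Z and the causal_dot_product kernel easier to follow, here is an illustrative quadratic-time PyTorch reference (not part of the commit, with a hypothetical function name). Given already feature-mapped Q and K, it reproduces the core computation V * Z of the forward pass above, ignoring padding masks; the compiled kernel obtains the same sums in linear time by keeping running prefix sums instead of looping.

import torch

def causal_linear_attention_reference(Q, K, V, eps=1e-6):
    """Naive O(L^2) reference for causal linear attention.

    Q, K: (N, L, H, E) feature-mapped queries/keys (e.g. elu(x) + 1),
    V: (N, L, H, D).  Position i only attends to positions j <= i.
    """
    N, L, H, E = Q.shape
    out = []
    for i in range(L):
        # Running sums over the prefix j <= i
        S = torch.einsum("njhe,njhd->nhed", K[:, :i+1], V[:, :i+1])   # (N, H, E, D)
        Z = K[:, :i+1].sum(dim=1)                                     # (N, H, E)
        num = torch.einsum("nhe,nhed->nhd", Q[:, i], S)
        den = torch.einsum("nhe,nhe->nh", Q[:, i], Z) + eps
        out.append(num / den[..., None])
    return torch.stack(out, dim=1)                                    # (N, L, H, D)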
smi-ted/inference/smi_ted_light/fast_transformers/attention/clustered_attention.py
ADDED
@@ -0,0 +1,195 @@
+#
+# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
+# Written by Angelos Katharopoulos <angelos.katharopoulos@idiap.ch>,
+# Apoorv Vyas <avyas@idiap.ch>
+#
+
+"""Implement clustered self attention."""
+
+from math import sqrt
+
+import torch
+import torch.autograd
+from torch.nn import Dropout, Module
+from torch.nn.init import normal_
+
+from ..attention_registry import AttentionRegistry, Optional, Float, Int, \
+    Bool, EventDispatcherInstance
+from ..events import EventDispatcher
+from ..masking import FullMask
+from ..aggregate import clustered_aggregate, clustered_broadcast
+from ..clustering.hamming import cluster
+from ..hashing import compute_hashes
+
+
+class _GroupQueries(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, Q, clusters, counts, lengths):
+        factors = 1./counts.float()
+        q_grouped = clustered_aggregate(Q, clusters, factors, lengths)
+        ctx.save_for_backward(clusters, counts, factors)
+
+        return q_grouped
+
+    @staticmethod
+    def backward(ctx, grad_q_grouped):
+        clusters, counts, factors = ctx.saved_tensors
+        grad_q = clustered_broadcast(grad_q_grouped, clusters, counts, factors)
+
+        return grad_q, None, None, None
+
+
+class _BroadcastValues(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, v_grouped, clusters, counts, lengths):
+        factors = torch.ones_like(counts, dtype=v_grouped.dtype)
+        V = clustered_broadcast(v_grouped, clusters, counts, factors)
+        ctx.save_for_backward(clusters, counts, factors, lengths)
+
+        return V
+
+    @staticmethod
+    def backward(ctx, grad_v):
+        clusters, counts, factors, lengths = ctx.saved_tensors
+        grad_v_grouped = clustered_aggregate(grad_v, clusters, factors, lengths)
+
+        return grad_v_grouped, None, None, None
+
+
+class ClusteredAttention(Module):
+    """Use LSH and clustering in the resulting Hamming space to group queries
+    that will have minimal L2 distance from each other.
+
+    Given the queries, keys, and values as Q, K, and V respectively, we
+    first cluster the queries in "C" groups and compute the "C" query centroids
+    Q_c.
+
+    We now use to the centroids Q_c to compute the attention using:
+
+        V'_c = softmax(Q_c.mm(K.t()), dim=-1).mm(V).
+
+    Now the computed values V'_c are "broadcasted" back to the query members
+    of the corresponding cluster.
+
+    Arguments
+    ---------
+        clusters: How many clusters to group the queries into
+        iterations: The number of lloyd iterations to perform (default: 10)
+        bits: How many bits to use for the hash (default: 32)
+        hash_bias: If true, hamming distance proportional to L2 distance
+                   If false, hamming distance proportional to cosine distance
+                   (default: True)
+        softmax_temp: The temperature to use for the softmax attention.
+                      (default: 1/sqrt(d_keys) where d_keys is computed at
+                      runtime)
+        attention_dropout: The dropout rate to apply to the attention
+                           (default: 0.1)
+        event_dispatcher: str or EventDispatcher instance to be used by this
+                          module for dispatching events (default: the default
+                          global dispatcher)
+    """
+    def __init__(self, clusters, iterations=10, bits=32,
+                 hash_bias=True, softmax_temp=None, attention_dropout=0.1,
+                 event_dispatcher=""):
+        super(ClusteredAttention, self).__init__()
+        self.clusters = clusters
+        self.iterations = iterations
+        self.bits = bits
+        self.hash_bias = hash_bias
+        self.softmax_temp = softmax_temp
+        self.dropout = Dropout(attention_dropout)
+        self.event_dispatcher = EventDispatcher.get(event_dispatcher)
+
+    def _create_query_groups(self, Q, query_lengths):
+        N, H, L, E = Q.shape
+
+        # Compute the hashes for all the queries
+        planes = Q.new_empty((self.bits, E+1))
+        normal_(planes)
+        if not self.hash_bias:
+            planes[:, -1] = 0
+        hashes = compute_hashes(Q.view(N*H*L, E), planes).view(N, H, L)
+
+        # Cluster the hashes and return the cluster index per query
+        clusters, counts = cluster(
+            hashes,
+            query_lengths._lengths.int(),
+            clusters=self.clusters,
+            iterations=self.iterations,
+            bits=self.bits
+        )
+        sorted_clusters, sorted_indx = torch.sort(clusters, dim=-1)
+        return (sorted_clusters, counts), sorted_indx
+
+    def _group_queries(self, Q, groups, lengths):
+        """Aggregate the Qs based on the index of cluster they belong to. Make
+        sure to allow for gradient propagation backwards from the grouped
+        queries to each query."""
+        q_grouped = _GroupQueries.apply(Q, *groups, lengths)
+        return q_grouped
+
+    def _broadcast_values(self, V, groups, lengths):
+        """Broadcast the values back to the correct positions but make sure
+        that the gradient flows properly."""
+        V_new = _BroadcastValues.apply(V.contiguous(), *groups, lengths)
+        return V_new
+
+    def forward(self, queries, keys, values, attn_mask, query_lengths,
+                key_lengths):
+        # Make sure that there is no attention mask
+        assert attn_mask.all_ones, ("Clustered attention cannot use an "
+                                    "arbitrary attention mask.")
+
+        queries = queries.permute(0,2,1,3).contiguous()
+        keys = keys.permute(0,2,1,3).contiguous()
+        values = values.permute(0,2,1,3).contiguous()
+
+        N, H, L, E = queries.shape
+        _, _, S, D = values.shape
+        softmax_temp = self.softmax_temp or 1./sqrt(E)
+
+        # Cluster the queries into groups
+        groups, sorted_indx = self._create_query_groups(queries, query_lengths)
+        # Re-organize queries so that first group belong to first cluster
+        # next to second cluster and so on. This improves kernel implementations.
+        # Note that this step is introduced after NeurIPS submission and
+        # now the complexity is O(N log(N)).
+        q_offset = torch.arange(N*H, device=queries.device).unsqueeze(-1) * L
+        q_flat = (sorted_indx.view(N*H, -1) + q_offset).reshape(-1)
+        s_queries = queries.reshape(-1, E).index_select(0, q_flat).view(N,H,L,E)
+
+        # Aggregate the re-arranged queries.
+        Q_grouped = self._group_queries(s_queries, groups, query_lengths._lengths.int())
+        # Compute the attention
+        QK = torch.einsum("nhle,nhse->nhls", Q_grouped, keys)
+        QK = QK + key_lengths.additive_matrix[:, None, None, :]
+        A = self.dropout(torch.softmax(softmax_temp * QK, dim=-1))
+        V = torch.einsum("nhls,nhsd->nhld", A, values)
+
+        # Broadcast grouped attention
+        V_broadcast = self._broadcast_values(V, groups, query_lengths._lengths.int())
+
+        # Reverse the previous mapping
+        rev_indx = torch.argsort(sorted_indx, dim=-1)
+        q_rev_flat = (rev_indx.view(N*H, -1) + q_offset).reshape(-1)
+        V_new = V_broadcast.reshape(-1, D).index_select(0, q_rev_flat).view(N,H,L,D)
+        V_new = V_new.permute(0, 2, 1, 3).contiguous()
+        return V_new
+
+
+
+
+# Register the attention implementation so that it becomes available in our
+# builders
+AttentionRegistry.register(
+    "clustered", ClusteredAttention,
+    [
+        ("clusters", Int),
+        ("iterations", Optional(Int, 10)),
+        ("bits", Optional(Int, 63)),
+        ("hash_bias", Optional(Bool, True)),
+        ("softmax_temp", Optional(Float)),
+        ("attention_dropout", Optional(Float, 0.1)),
+        ("event_dispatcher", Optional(EventDispatcherInstance, ""))
+    ]
+)
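The group-then-broadcast approximation in the docstring is easier to see without the LSH, length masks, and custom kernels. The following single-head toy reference (not part of the commit, with an illustrative function name) implements V'_c = softmax(Q_c K^T) V and the broadcast back to the members of each cluster in plain PyTorch.

import torch

def clustered_attention_reference(Q, K, V, groups, softmax_temp):
    """Toy, single-head reference of the clustered approximation.

    Q: (L, E) queries, K: (S, E) keys, V: (S, D) values, groups: (L,) cluster
    id per query.  Queries in the same cluster share one attention
    distribution computed from their centroid.
    """
    C = int(groups.max().item()) + 1
    counts = torch.bincount(groups, minlength=C).clamp(min=1)
    centroids = torch.zeros(C, Q.shape[-1]).index_add_(0, groups, Q) / counts[:, None]
    A = torch.softmax(softmax_temp * centroids @ K.t(), dim=-1)   # (C, S)
    V_c = A @ V                                                   # (C, D): one value row per cluster
    return V_c[groups]                                            # broadcast back to every query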
smi-ted/inference/smi_ted_light/fast_transformers/attention/conditional_full_attention.py
ADDED
@@ -0,0 +1,66 @@
+#
+# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
+# Written by Angelos Katharopoulos <angelos.katharopoulos@idiap.ch>,
+# Apoorv Vyas <avyas@idiap.ch>
+#
+
+"""Implement a self attention that delegates to full attention or another
+attention depending on the input sequence length."""
+
+import torch
+from torch.nn import Module
+
+from ..attention_registry import AttentionRegistry, Optional, Int, Float, \
+    EventDispatcherInstance
+from ..events import EventDispatcher
+from .full_attention import FullAttention
+
+
+class ConditionalFullAttention(Module):
+    """"Delegate to full attention if the input sequence is short.
+
+    Arguments
+    ---------
+        other_attention: Use the passed attention module if the sequence is
+                         longer than 'length_limit'.
+        length_limit: An integer denoting the maximum sequence length to
+                      consider.
+        softmax_temp: See fast_transformers.attention.full_attention.
+        attention_dropout: See fast_transformers.attention.full_attention.
+        event_dispatcher: str or EventDispatcher instance to be used by this
+                          module for dispatching events (default: the default
+                          global dispatcher)
+    """
+    def __init__(self, other_attention, length_limit=512, softmax_temp=None,
+                 attention_dropout=0.1, event_dispatcher=""):
+        super(ConditionalFullAttention, self).__init__()
+        self.full_attention = FullAttention(softmax_temp, attention_dropout)
+        self.other_attention = other_attention
+        self.length_limit = length_limit
+        self.event_dispatcher = EventDispatcher.get(event_dispatcher)
+
+    def forward(self, queries, keys, values, attn_mask, query_lengths,
+                key_lengths):
+        # Extract some shapes to compare with the length limit
+        L = queries.shape[1]
+        S = values.shape[1]
+
+        if L > self.length_limit or S > self.length_limit:
+            return self.other_attention(queries, keys, values, attn_mask,
+                                        query_lengths, key_lengths)
+        else:
+            return self.full_attention(queries, keys, values, attn_mask,
+                                       query_lengths, key_lengths)
+
+
+# Register the attention implementation so that it becomes available in our
+# builders
AttentionRegistry.register(
+    "conditional-full", ConditionalFullAttention,
+    [
+        ("length_limit", Optional(Int, 512)),
+        ("softmax_temp", Optional(Float)),
+        ("attention_dropout", Optional(Float, 0.1)),
+        ("event_dispatcher", Optional(EventDispatcherInstance, ""))
+    ]
+)
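A plausible way to use this wrapper is sketched below. It assumes the vendored LinearAttention keeps the upstream query_dimensions constructor argument, and since ConditionalFullAttention is commented out of attention/__init__.py in this commit, it is imported from its module directly; the import root depends on how the Space configures sys.path.

from fast_transformers.attention import AttentionLayer, LinearAttention
from fast_transformers.attention.conditional_full_attention import ConditionalFullAttention

d_model, n_heads = 512, 8
# Exact softmax attention up to 512 tokens, linear attention beyond that.
inner = ConditionalFullAttention(
    LinearAttention(query_dimensions=d_model // n_heads),
    length_limit=512,
)
layer = AttentionLayer(inner, d_model, n_heads)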
smi-ted/inference/smi_ted_light/fast_transformers/attention/exact_topk_attention.py
ADDED
@@ -0,0 +1,88 @@
+#
+# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
+# Written by Angelos Katharopoulos <angelos.katharopoulos@idiap.ch>,
+# Apoorv Vyas <avyas@idiap.ch>
+#
+
+"""Implement the oracle top-k attention. The top-k keys are exact ones.
+MultiHeadAttention module. Note that this module is to be used in conjuction
+with the AttentionLayer in order to work."""
+
+from math import sqrt
+
+import torch
+from torch.nn import Dropout, Module
+
+from ..attention_registry import AttentionRegistry, Optional, Int, Float, \
+    EventDispatcherInstance
+from ..events import EventDispatcher
+
+
+class ExactTopKAttention(Module):
+    """Implement the oracle top-k softmax attention.
+
+    Arguments
+    ---------
+        top-k: The top k keys to attend to (default: 32)
+        softmax_temp: The temperature to use for the softmax attention.
+                      (default: 1/sqrt(d_keys) where d_keys is computed at
+                      runtime)
+        attention_dropout: The dropout rate to apply to the attention
+                           (default: 0.1)
+        event_dispatcher: str or EventDispatcher instance to be used by this
+                          module for dispatching events (default: the default
+                          global dispatcher)
+    """
+    def __init__(self, topk=32, softmax_temp=None, attention_dropout=0.1,
+                 event_dispatcher=""):
+        super(ExactTopKAttention, self).__init__()
+        self.topk = topk
+        self.softmax_temp = softmax_temp
+        self.dropout = Dropout(attention_dropout)
+        self.event_dispatcher = EventDispatcher.get(event_dispatcher)
+
+    def forward(self, queries, keys, values, attn_mask, query_lengths,
+                key_lengths):
+        # Extract some shapes and compute the temperature
+        N, L, H, E = queries.shape
+        _, S, _, D = values.shape
+        softmax_temp = self.softmax_temp or 1./sqrt(E)
+
+        # Compute the unnormalized attention and apply the masks
+        QK = torch.einsum("nlhe,nshe->nhls", queries, keys)
+        topk = min(self.topk, S)
+
+        if not attn_mask.all_ones:
+            QK = QK + attn_mask.additive_matrix
+        QK = QK + key_lengths.additive_matrix[:, None, None]
+
+        topk_values, topk_idx = torch.topk(QK, topk, sorted=False, dim=-1)
+        mask = QK.new_ones(QK.shape) * float("-inf")
+        mask[
+            torch.arange(N, device=QK.device).view(N, 1, 1, 1),
+            torch.arange(H, device=QK.device).view(1, H, 1, 1),
+            torch.arange(L, device=QK.device).view(1, 1, L, 1),
+            topk_idx,
+        ] = 0.
+
+        QK = QK + mask
+
+        # Compute the attention and the weighted average
+        A = self.dropout(torch.softmax(softmax_temp * QK, dim=-1))
+        V = torch.einsum("nhls,nshd->nlhd", A, values)
+
+        # Make sure that what we return is contiguous
+        return V.contiguous()
+
+
+# Register the attention implementation so that it becomes available in our
+# builders
+AttentionRegistry.register(
+    "exact-topk", ExactTopKAttention,
+    [
+        ("topk", Optional(Int, 32)),
+        ("softmax_temp", Optional(Float)),
+        ("attention_dropout", Optional(Float, 0.1)),
+        ("event_dispatcher", Optional(EventDispatcherInstance, ""))
+    ]
+)
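The advanced-indexing mask in the forward pass simply zeroes the positions that survive torch.topk and leaves -inf everywhere else, so non-top-k keys receive exactly zero weight after the softmax. A one-row illustration (not from the commit):

import torch

QK = torch.tensor([0.1, 2.0, -1.0, 0.7, 1.5])
topk_values, topk_idx = torch.topk(QK, k=3, sorted=False)
mask = torch.full_like(QK, float("-inf"))
mask[topk_idx] = 0.0                    # keep only the 3 largest scores
masked = QK + mask                      # [-inf, 2.0, -inf, 0.7, 1.5]
attn = torch.softmax(masked, dim=-1)    # weights on the masked keys are exactly 0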
smi-ted/inference/smi_ted_light/fast_transformers/attention/full_attention.py
ADDED
@@ -0,0 +1,95 @@
+#
+# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
+# Written by Angelos Katharopoulos <angelos.katharopoulos@idiap.ch>,
+# Apoorv Vyas <avyas@idiap.ch>
+#
+
+"""Implement the full attention similar to the one implemented by PyTorch's
+MultiHeadAttention module. Note that this module is to be used in conjuction
+with the `fast_transformers.attention.attention_layer.AttentionLayer` in order
+to work."""
+
+from math import sqrt
+
+import torch
+from torch.nn import Dropout, Module
+
+from ..attention_registry import AttentionRegistry, Optional, Float, \
+    EventDispatcherInstance
+from ..events import EventDispatcher, AttentionEvent
+
+
+class FullAttention(Module):
+    """Implement the scaled dot product attention with softmax.
+
+    Arguments
+    ---------
+        softmax_temp: The temperature to use for the softmax attention.
+                      (default: 1/sqrt(d_keys) where d_keys is computed at
+                      runtime)
+        attention_dropout: The dropout rate to apply to the attention
+                           (default: 0.1)
+        event_dispatcher: str or EventDispatcher instance to be used by this
+                          module for dispatching events (default: the default
+                          global dispatcher)
+    """
+    def __init__(self, softmax_temp=None, attention_dropout=0.1,
+                 event_dispatcher=""):
+        super(FullAttention, self).__init__()
+        self.softmax_temp = softmax_temp
+        self.dropout = Dropout(attention_dropout)
+        self.event_dispatcher = EventDispatcher.get(event_dispatcher)
+
+    def forward(self, queries, keys, values, attn_mask, query_lengths,
+                key_lengths):
+        """Implements the multihead softmax attention.
+
+        Arguments
+        ---------
+            queries: (N, L, H, E) The tensor containing the queries
+            keys: (N, S, H, E) The tensor containing the keys
+            values: (N, S, H, D) The tensor containing the values
+            attn_mask: An implementation of BaseMask that encodes where each
+                       query can attend to
+            query_lengths: An implementation of BaseMask that encodes how
+                           many queries each sequence in the batch consists of
+            key_lengths: An implementation of BaseMask that encodes how
+                         many queries each sequence in the batch consists of
+        """
+        # Extract some shapes and compute the temperature
+        N, L, H, E = queries.shape
+        _, S, _, D = values.shape
+        softmax_temp = self.softmax_temp or 1./sqrt(E)
+
+        # Scale the queries instead of applying the softmax temperature to the
+        # dot products
+        queries = queries * softmax_temp
+
+        # Compute the unnormalized attention and apply the masks
+        QK = torch.einsum("nlhe,nshe->nhls", queries, keys)
+        if not attn_mask.all_ones:
+            QK = QK + attn_mask.additive_matrix
+        if not key_lengths.all_ones:
+            QK = QK + key_lengths.additive_matrix[:, None, None]
+
+        # Compute the attention and the weighted average
+        A = self.dropout(torch.softmax(QK, dim=-1))
+        V = torch.einsum("nhls,nshd->nlhd", A, values)
+
+        # Let the world know of the attention matrix
+        self.event_dispatcher.dispatch(AttentionEvent(self, A))
+
+        # Make sure that what we return is contiguous
+        return V.contiguous()
+
+
+# Register the attention implementation so that it becomes available in our
+# builders
+AttentionRegistry.register(
+    "full", FullAttention,
+    [
+        ("softmax_temp", Optional(Float)),
+        ("attention_dropout", Optional(Float, 0.1)),
+        ("event_dispatcher", Optional(EventDispatcherInstance, ""))
+    ]
+)
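The registration at the end of the file is what lets the builders construct this module by the name "full". Assuming the vendored builders package mirrors upstream fast-transformers (its transformer_builders.py is added in this same commit but not shown in this 50-file view), usage would look roughly like the sketch below; the import root is again an assumption.

import torch
from fast_transformers.builders import TransformerEncoderBuilder

# Because FullAttention is registered under "full", the builder can pick it by name.
encoder = TransformerEncoderBuilder.from_kwargs(
    n_layers=2,
    n_heads=8,
    query_dimensions=64,
    value_dimensions=64,
    feed_forward_dimensions=1024,
    attention_type="full",
).get()

x = torch.randn(4, 100, 8 * 64)   # (batch, sequence, n_heads * query_dimensions)
y = encoder(x)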
smi-ted/inference/smi_ted_light/fast_transformers/attention/improved_clustered_attention.py
ADDED
|
@@ -0,0 +1,268 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#
|
| 2 |
+
# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
|
| 3 |
+
# Written by Angelos Katharopoulos <angelos.katharopoulos@idiap.ch>,
|
| 4 |
+
# Apoorv Vyas <avyas@idiap.ch>
|
| 5 |
+
#
|
| 6 |
+
|
| 7 |
+
"""Implement improved clustered self attention."""
|
| 8 |
+
|
| 9 |
+
from math import sqrt
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
import torch.autograd
|
| 13 |
+
from torch.nn import Dropout, Module
|
| 14 |
+
from torch.nn.init import normal_
|
| 15 |
+
|
| 16 |
+
from ..attention_registry import AttentionRegistry, Optional, Float, Int, \
|
| 17 |
+
Bool, EventDispatcherInstance
|
| 18 |
+
from ..events import EventDispatcher
|
| 19 |
+
from ..masking import FullMask
|
| 20 |
+
from ..aggregate import clustered_aggregate, clustered_broadcast
|
| 21 |
+
from ..clustering.hamming import cluster
|
| 22 |
+
from ..hashing import compute_hashes
|
| 23 |
+
from ..sparse_product import sparse_dot_product, sparse_weighted_average
|
| 24 |
+
from ..sparse_product import clustered_sparse_dot_product, \
|
| 25 |
+
clustered_sparse_weighted_average
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class _GroupQueries(torch.autograd.Function):
|
| 29 |
+
@staticmethod
|
| 30 |
+
def forward(ctx, Q, clusters, counts, lengths):
|
| 31 |
+
factors = 1./counts.float()
|
| 32 |
+
q_grouped = clustered_aggregate(Q, clusters, factors, lengths)
|
| 33 |
+
ctx.save_for_backward(clusters, counts, factors)
|
| 34 |
+
|
| 35 |
+
return q_grouped
|
| 36 |
+
|
| 37 |
+
@staticmethod
|
| 38 |
+
def backward(ctx, grad_q_grouped):
|
| 39 |
+
clusters, counts, factors = ctx.saved_tensors
|
| 40 |
+
grad_q = clustered_broadcast(grad_q_grouped, clusters, counts, factors)
|
| 41 |
+
|
| 42 |
+
return grad_q, None, None, None
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class _BroadcastValues(torch.autograd.Function):
|
| 46 |
+
@staticmethod
|
| 47 |
+
def forward(ctx, v_grouped, clusters, counts, lengths):
|
| 48 |
+
factors = torch.ones_like(counts, dtype=v_grouped.dtype)
|
| 49 |
+
V = clustered_broadcast(v_grouped, clusters, counts, factors)
|
| 50 |
+
ctx.save_for_backward(clusters, counts, factors, lengths)
|
| 51 |
+
|
| 52 |
+
return V
|
| 53 |
+
|
| 54 |
+
@staticmethod
|
| 55 |
+
def backward(ctx, grad_v):
|
| 56 |
+
clusters, counts, factors, lengths = ctx.saved_tensors
|
| 57 |
+
grad_v_grouped = clustered_aggregate(grad_v, clusters, factors, lengths)
|
| 58 |
+
|
| 59 |
+
return grad_v_grouped, None, None, None, None
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
class ImprovedClusteredAttention(Module):
|
| 63 |
+
"""
|
| 64 |
+
Immproved clustered attention approximation by recompution attention
|
| 65 |
+
for each query with the top-k keys for the corresponding cluster.
|
| 66 |
+
|
| 67 |
+
Given the queries, keys, and values as Q, K, and V respectively, we
|
| 68 |
+
first cluster the queries in "C" groups and compute the "C" query centroids
|
| 69 |
+
Q_c.
|
| 70 |
+
|
| 71 |
+
We now use to the centroids Q_c to identify the top-k keys with highest
|
| 72 |
+
dot products.
|
| 73 |
+
|
| 74 |
+
Subsequently, for each query we compute the sparse dot product with
|
| 75 |
+
the corresponding top-k keys to improve the attention approximation.
|
| 76 |
+
|
| 77 |
+
Arguments
|
| 78 |
+
---------
|
| 79 |
+
clusters: How many clusters to group the queries into
|
| 80 |
+
iterations: The number of lloyd iterations to perform (default: 10)
|
| 81 |
+
bits: How many bits to use for the hash (default: 32)
|
| 82 |
+
hash_bias: If true, hamming distance proportional to L2 distance
|
| 83 |
+
If false, hamming distance proportional to cosine distance
|
| 84 |
+
(default: True)
|
| 85 |
+
topk: Number of top-k keys to for improved approximation (default: 32)
|
| 86 |
+
softmax_temp: The temperature to use for the softmax attention.
|
| 87 |
+
(default: 1/sqrt(d_keys) where d_keys is computed at
|
| 88 |
+
runtime)
|
| 89 |
+
attention_dropout: The dropout rate to apply to the attention
|
| 90 |
+
(default: 0.1)
|
| 91 |
+
event_dispatcher: str or EventDispatcher instance to be used by this
|
| 92 |
+
module for dispatching events (default: the default
|
| 93 |
+
global dispatcher)
|
| 94 |
+
"""
|
| 95 |
+
def __init__(self, clusters, iterations=10, bits=32,
|
| 96 |
+
hash_bias=True, topk=32, softmax_temp=None,
|
| 97 |
+
attention_dropout=0.1, event_dispatcher=""):
|
| 98 |
+
super(ImprovedClusteredAttention, self).__init__()
|
| 99 |
+
self.clusters = clusters
|
| 100 |
+
self.iterations = iterations
|
| 101 |
+
self.bits = bits
|
| 102 |
+
self.hash_bias = hash_bias
|
| 103 |
+
self.topk = topk
|
| 104 |
+
self.softmax_temp = softmax_temp
|
| 105 |
+
        self.dropout = Dropout(attention_dropout)
        self.event_dispatcher = EventDispatcher.get(event_dispatcher)

    def _create_query_groups(self, Q, query_lengths):
        N, H, L, E = Q.shape

        # Compute the hashes for all the queries
        planes = Q.new_empty((self.bits, E+1))
        normal_(planes)
        if not self.hash_bias:
            planes[:, -1] = 0
        hashes = compute_hashes(Q.view(N*H*L, E), planes).view(N, H, L)

        # Cluster the hashes and return the cluster index per query
        clusters, counts = cluster(
            hashes,
            query_lengths._lengths.int(),
            clusters=self.clusters,
            iterations=self.iterations,
            bits=self.bits
        )
        sorted_clusters, sorted_indx = torch.sort(clusters, dim=-1)
        return (sorted_clusters, counts), sorted_indx

    def _topk_attention(self, Q, K, V,
                        clusters, counts,
                        topk, topk_values,
                        A_bottomk, softmax_temp,
                        query_lengths):
        """Return the attention with just the topk heads."""
        # Extract some indices
        N, H, L, E = Q.shape
        _, _, S, _ = K.shape
        _, _, C, k = topk.shape

        # We need to pass the output tensor to initialize to 0
        QK = clustered_sparse_dot_product(
            Q, K, topk,
            clusters, counts,
            query_lengths._lengths.int()
        )
        # We need to mask the topk dot products if topk > input_length
        QK = QK.masked_fill(
            torch.isinf(topk_values[:,0,0,:]).view(N, 1, 1, k),
            float("-inf")
        )
        A = torch.softmax(softmax_temp * QK, dim=-1)
        assert A_bottomk.is_contiguous()
        A_bottomk = clustered_broadcast(
            A_bottomk.unsqueeze(3),
            clusters,
            counts,
            torch.ones_like(counts, dtype=torch.float32)
        )
        A = A * (1.0 - A_bottomk)
        A = self.dropout(A)
        assert A.is_contiguous()
        V_new = clustered_sparse_weighted_average(A, V, topk, clusters, counts)

        return V_new

    def _broadcast_values(self, V, clusters, counts, lengths):
        """Broadcast the values back to the correct positions but make sure
        that the gradient flows properly."""
        V_new = _BroadcastValues.apply(V.contiguous(), clusters, counts, lengths)
        return V_new

    def _bottomk_attention(self, QK, V, clusters, counts, query_lengths, topk, softmax_temp):
        """Return the attention with just the bottomk keys."""
        N, H, C, S = QK.shape

        A = torch.softmax(softmax_temp * QK, dim=-1)
        mask = QK.new_ones(QK.shape)
        mask[
            torch.arange(N, device=QK.device).view(N, 1, 1, 1),
            torch.arange(H, device=QK.device).view(1, H, 1, 1),
            torch.arange(C, device=QK.device).view(1, 1, C, 1),
            topk,
        ] = 0
        A = A * mask
        A_bottomk = A.sum(-1)
        A = self.dropout(A)
        # Compute the values
        V_new = torch.einsum("nhls,nhse->nhle", A, V)
        # Broadcast the values back depending on the groups
        V_new = self._broadcast_values(V_new, clusters, counts, query_lengths._lengths.int())

        return V_new, A_bottomk

    def forward(self, queries, keys, values, attn_mask, query_lengths,
                key_lengths):
        # Make sure that there is no attention mask
        assert attn_mask.all_ones, ("Improved-clustered attention cannot "
                                    "use an arbitrary attention mask.")

        queries = queries.permute(0,2,1,3).contiguous()
        keys = keys.permute(0,2,1,3).contiguous()
        values = values.permute(0,2,1,3).contiguous()
        N, H, L, E = queries.shape
        _, _, S, D = values.shape
        softmax_temp = self.softmax_temp or 1./sqrt(E)

        # Cluster the queries into groups
        groups, sorted_indx = self._create_query_groups(queries, query_lengths)
        clusters, counts = groups

        # Re-organize queries so that first group belong to first cluster
        # next to second cluster and so on. This improves kernel implementations.
        # Note that this step is introduced after NeurIPS submission and
        # now the complexity is O(N log(N)).
        q_offset = torch.arange(N*H, device=queries.device).unsqueeze(-1) * L
        q_flat = (sorted_indx.view(N*H, -1) + q_offset).reshape(-1)
        s_queries = queries.reshape(-1, E).index_select(0, q_flat).view(N,H,L,E)

        # Aggregate the re-arranged queries.
        Q_grouped = _GroupQueries.apply(s_queries, *groups, query_lengths.lengths.int())
        # Compute the attention
        QK = torch.einsum("nhle,nhse->nhls", Q_grouped, keys)
        QK = QK + key_lengths.additive_matrix[:, None, None, :]
        topk_values, topk = torch.topk(QK, min(self.topk, S), sorted=False, dim=-1)
        assert topk.is_contiguous()

        # Now compute the attention with only the bottom keys
        V_bottomk, A_bottomk = self._bottomk_attention(
            QK, values,
            clusters, counts,
            query_lengths,
            topk,
            softmax_temp
        )

        # Now compute the attention with only the top keys
        V_topk = self._topk_attention(
            s_queries, keys, values,
            clusters, counts,
            topk, topk_values,
            A_bottomk,
            softmax_temp,
            query_lengths
        )
        V_sorted_new = V_topk + V_bottomk

        # Reverse the previous mapping
        sorted_rev_indx = torch.argsort(sorted_indx, dim=-1)
        q_rev_flat = (sorted_rev_indx.view(N*H, -1) + q_offset).reshape(-1)
        V_new = V_sorted_new.reshape(-1, D).index_select(0, q_rev_flat).view(N,H,L,D)
        return V_new.permute(0, 2, 1, 3).contiguous()


# Register the attention implementation so that it becomes available in our
# builders
AttentionRegistry.register(
    "improved-clustered", ImprovedClusteredAttention,
    [
        ("clusters", Int),
        ("iterations", Optional(Int, 10)),
        ("bits", Optional(Int, 63)),
        ("hash_bias", Optional(Bool, True)),
        ("topk", Optional(Int, 32)),
        ("softmax_temp", Optional(Float)),
        ("attention_dropout", Optional(Float, 0.1)),
        ("event_dispatcher", Optional(EventDispatcherInstance, ""))
    ]
)
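A toy illustration (not part of the diff) of how the two pieces computed above recombine: the exact softmax over the top-k keys is rescaled by the probability mass that the clustered attention assigns to the remaining bottom keys, which is exactly the `A = A * (1.0 - A_bottomk)` step. In the module the bottom-k mass comes from the grouped (centroid) scores; the sketch below uses one set of scores for both, i.e. the case where a query coincides with its centroid, so the recombination recovers the full softmax exactly.

    import torch

    S, k = 8, 3
    scores = torch.randn(S)
    topk_idx = scores.topk(k).indices

    A_full = torch.softmax(scores, dim=-1)
    A_bottomk = A_full.clone()
    A_bottomk[topk_idx] = 0
    bottom_mass = A_bottomk.sum()                     # mass outside the top-k

    A_topk = torch.softmax(scores[topk_idx], dim=-1)  # exact softmax on the top-k only
    recombined = A_topk * (1 - bottom_mass)
    assert torch.allclose(recombined, A_full[topk_idx], atol=1e-6)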
smi-ted/inference/smi_ted_light/fast_transformers/attention/improved_clustered_causal_attention.py
ADDED
@@ -0,0 +1,257 @@
#
# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
# Written by Angelos Katharopoulos <angelos.katharopoulos@idiap.ch>,
# Apoorv Vyas <avyas@idiap.ch>
#

"""Implement improved clustered causal self attention."""

from math import sqrt

import torch
import torch.autograd
from torch.nn import Dropout, Module
from torch.nn.init import normal_

from ..attention_registry import AttentionRegistry, Optional, Float, Int, \
    Bool, EventDispatcherInstance
from ..events import EventDispatcher
from ..masking import FullMask
from ..aggregate import clustered_aggregate, clustered_broadcast
from ..clustering.hamming import cluster
from ..hashing import compute_hashes
from ..sparse_product import sparse_dot_product, sparse_weighted_average
from ..sparse_product import clustered_sparse_dot_product, \
    clustered_sparse_weighted_average


class _GroupQueries(torch.autograd.Function):
    @staticmethod
    def forward(ctx, Q, clusters, counts, lengths):
        factors = 1./counts.float()
        q_grouped = clustered_aggregate(Q, clusters, factors, lengths)
        ctx.save_for_backward(clusters, counts, factors)

        return q_grouped

    @staticmethod
    def backward(ctx, grad_q_grouped):
        clusters, counts, factors = ctx.saved_tensors
        grad_q = clustered_broadcast(grad_q_grouped, clusters, counts, factors)

        return grad_q, None, None, None


class _BroadcastValues(torch.autograd.Function):
    @staticmethod
    def forward(ctx, v_grouped, clusters, counts, lengths):
        factors = torch.ones_like(counts, dtype=v_grouped.dtype)
        V = clustered_broadcast(v_grouped, clusters, counts, factors)
        ctx.save_for_backward(clusters, counts, factors, lengths)

        return V

    @staticmethod
    def backward(ctx, grad_v):
        clusters, counts, factors, lengths = ctx.saved_tensors
        grad_v_grouped = clustered_aggregate(grad_v, clusters, factors, lengths)

        return grad_v_grouped, None, None, None, None


class ImprovedClusteredCausalAttention(Module):
    """
    Improved clustered causal attention approximation by recomputing attention
    for each query with the top-k keys for the corresponding cluster.

    Given the queries, keys, and values as Q, K, and V respectively, we
    first cluster the queries in "C" groups and compute the "C" query centroids
    Q_c.

    We now use the centroids Q_c to identify the top-k keys with highest
    dot products.

    Subsequently, for each query we compute the sparse dot product with
    the corresponding top-k keys to improve the attention approximation.

    Key difference with improved clustered attention is that we only use
    top-k keys with causal mask, we do not compute attention on the
    bottom-k keys.

    Arguments
    ---------
        clusters: How many clusters to group the queries into
        iterations: The number of lloyd iterations to perform (default: 10)
        bits: How many bits to use for the hash (default: 32)
        hash_bias: If true, hamming distance proportional to L2 distance
                   If false, hamming distance proportional to cosine distance
                   (default: True)
        topk: Number of top-k keys for the improved approximation (default: 32)
        softmax_temp: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.1)
        event_dispatcher: str or EventDispatcher instance to be used by this
                          module for dispatching events (default: the default
                          global dispatcher)
    """
    def __init__(self, clusters, iterations=10, bits=32,
                 hash_bias=True, topk=32, softmax_temp=None,
                 attention_dropout=0.1, event_dispatcher=""):
        super(ImprovedClusteredCausalAttention, self).__init__()
        self.clusters = clusters
        self.iterations = iterations
        self.bits = bits
        self.hash_bias = hash_bias
        self.topk = topk
        self.softmax_temp = softmax_temp
        self.dropout = Dropout(attention_dropout)
        self.event_dispatcher = EventDispatcher.get(event_dispatcher)

    def _create_query_groups(self, Q, query_lengths):
        N, H, L, E = Q.shape

        # Compute the hashes for all the queries
        planes = Q.new_empty((self.bits, E+1))
        normal_(planes)
        if not self.hash_bias:
            planes[:, -1] = 0
        hashes = compute_hashes(Q.view(N*H*L, E), planes).view(N, H, L)

        # Cluster the hashes and return the cluster index per query
        clusters, counts = cluster(
            hashes,
            query_lengths.lengths.int(),
            clusters=self.clusters,
            iterations=self.iterations,
            bits=self.bits
        )
        sorted_clusters, sorted_indx = torch.sort(clusters, dim=-1)
        return (sorted_clusters, counts), sorted_indx

    def _topk_attention(self, Q, K, V,
                        q_flat, q_rev_flat,
                        clusters, counts,
                        topk, topk_values,
                        softmax_temp,
                        query_lengths):
        """Return the attention with just the topk heads."""
        # Extract some indices
        N, H, L, E = Q.shape
        _, _, S, _ = K.shape
        _, _, C, k = topk.shape

        # We need to pass the output tensor to initialize to 0
        QK = clustered_sparse_dot_product(
            Q, K, topk,
            clusters, counts,
            query_lengths.lengths.int()
        )
        # We need to mask out the future
        assert topk.is_contiguous()
        topk_broadcast = clustered_broadcast(
            topk.float(),
            clusters,
            counts,
            torch.ones_like(counts, dtype=torch.float32)
        )
        # Need to be careful here we changed the order of the keys the
        # masking on future needs to be applied in the same way
        seq_ids = torch.arange(L, device=QK.device).view(1, 1, L, 1).repeat(N, H, 1, 1)
        # permute the ids in the same way as input so as to mask the right
        # entries for each query
        s_seq_ids = seq_ids.reshape(-1, 1).index_select(0, q_flat).view(N,H,L,1)
        future_mask = topk_broadcast.long() > s_seq_ids
        QK = QK.masked_fill(
            future_mask,
            float("-1e7")
        )
        A = torch.softmax(softmax_temp * QK, dim=-1)
        # Mask again to ensure no probabilities leak due to float(-1e7)
        # Leakage could be very high as we use a small top-k
        A = A * (1. - future_mask.float())
        A = self.dropout(A)
        assert A.is_contiguous()
        V_new = clustered_sparse_weighted_average(A, V, topk, clusters, counts)

        return V_new

    def _broadcast_values(self, V, clusters, counts, lengths):
        """Broadcast the values back to the correct positions but make sure
        that the gradient flows properly."""
        V_new = _BroadcastValues.apply(V.contiguous(), clusters, counts, lengths)
        return V_new

    def forward(self, queries, keys, values, attn_mask, query_lengths,
                key_lengths):

        # Apply the key padding mask and make sure the attn_mask is a
        # lower triangular causal mask
        if not attn_mask.lower_triangular:
            raise RuntimeError(("ImprovedClusteredCausalAttention only supports "
                                "lower triangular masks"))
        queries = queries.permute(0,2,1,3).contiguous()
        keys = keys.permute(0,2,1,3).contiguous()
        values = values.permute(0,2,1,3).contiguous()
        N, H, L, E = queries.shape
        _, _, S, D = values.shape
        softmax_temp = self.softmax_temp or 1./sqrt(E)

        # Cluster the queries into groups
        groups, sorted_indx = self._create_query_groups(queries, query_lengths)
        clusters, counts = groups

        # Re-organize queries so that first group belong to first cluster
        # next to second cluster and so on. This improves kernel implementations.
        # Note that this step is introduced after NeurIPS submission and
        # now the complexity is O(N log(N)).
        q_offset = torch.arange(N*H, device=queries.device).unsqueeze(-1) * L
        q_flat = (sorted_indx.view(N*H, -1) + q_offset).reshape(-1)
        s_queries = queries.reshape(-1, E).index_select(0, q_flat).view(N,H,L,E)

        # Aggregate the re-arranged queries.
        Q_grouped = _GroupQueries.apply(s_queries, *groups, query_lengths.lengths.int())
        # Compute the attention
        QK = torch.einsum("nhle,nhse->nhls", Q_grouped, keys)
        QK = QK + key_lengths.additive_matrix[:, None, None, :]
        # Set topk to minimum of key lengths if it is smaller than self.topk
        cur_topk = min(self.topk, min(key_lengths.lengths).item())
        topk_values, topk = torch.topk(QK, cur_topk, sorted=False, dim=-1)
        assert topk.is_contiguous()

        # Reverse mapping
        sorted_rev_indx = torch.argsort(sorted_indx, dim=-1)
        q_rev_flat = (sorted_rev_indx.view(N*H, -1) + q_offset).reshape(-1)

        # Compute the attention with only the top keys
        V_topk = self._topk_attention(
            s_queries, keys, values,
            q_flat, q_rev_flat,
            clusters, counts,
            topk, topk_values,
            softmax_temp,
            query_lengths
        )
        V_sorted_new = V_topk

        # Reverse the mapping to get correct values
        V_new = V_sorted_new.reshape(-1, D).index_select(0, q_rev_flat).view(N,H,L,D)
        return V_new.permute(0, 2, 1, 3).contiguous()


# Register the attention implementation so that it becomes available in our
# builders
AttentionRegistry.register(
    "causal-improved-clustered", ImprovedClusteredCausalAttention,
    [
        ("clusters", Int),
        ("iterations", Optional(Int, 10)),
        ("bits", Optional(Int, 63)),
        ("hash_bias", Optional(Bool, True)),
        ("topk", Optional(Int, 32)),
        ("softmax_temp", Optional(Float)),
        ("attention_dropout", Optional(Float, 0.1)),
        ("event_dispatcher", Optional(EventDispatcherInstance, ""))
    ]
)
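A small, self-contained illustration (not from the diff) of the future-masking step used in `_topk_attention` above: after the queries are re-ordered by cluster, a query at original position i may only attend to top-k key indices j with j <= i, so any broadcast top-k index greater than the query's original position is masked out. The tensors below are made up for illustration.

    import torch

    L, k = 6, 3
    # hypothetical top-k key indices broadcast back to each of the L queries
    topk_per_query = torch.tensor([[0, 2, 5]] * L)        # (L, k)
    positions = torch.arange(L).view(L, 1)                # original position i of each query
    future_mask = topk_per_query > positions              # True where key index j > i
    scores = torch.randn(L, k).masked_fill(future_mask, float("-1e7"))
    A = torch.softmax(scores, dim=-1) * (~future_mask).float()  # second mask, as in the module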
smi-ted/inference/smi_ted_light/fast_transformers/attention/linear_attention.py
ADDED
@@ -0,0 +1,92 @@
#
# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
# Written by Angelos Katharopoulos <angelos.katharopoulos@idiap.ch>,
# Apoorv Vyas <avyas@idiap.ch>
#

"""Implement unmasked linear attention."""

import torch
from torch.nn import Module

from ..attention_registry import AttentionRegistry, Optional, Callable, Int, \
    EventDispatcherInstance
from ..events import EventDispatcher
from ..feature_maps import elu_feature_map


class LinearAttention(Module):
    """Implement unmasked attention using dot product of feature maps in
    O(N D^2) complexity.

    Given the queries, keys and values as Q, K, V instead of computing

        V' = softmax(Q.mm(K.t()), dim=-1).mm(V),

    we make use of a feature map function Φ(.) and perform the following
    computation

        V' = normalize(Φ(Q).mm(Φ(K).t())).mm(V).

    The above can be computed in O(N D^2) complexity where D is the
    dimensionality of Q, K and V and N is the sequence length. Depending on the
    feature map, however, the complexity of the attention might be limited.

    Arguments
    ---------
        feature_map: callable, a callable that applies the feature map to the
                     last dimension of a tensor (default: elu(x)+1)
        eps: float, a small number to ensure the numerical stability of the
             denominator (default: 1e-6)
        event_dispatcher: str or EventDispatcher instance to be used by this
                          module for dispatching events (default: the default
                          global dispatcher)
    """
    def __init__(self, query_dimensions, feature_map=None, eps=1e-6,
                 event_dispatcher=""):
        super(LinearAttention, self).__init__()
        self.feature_map = (
            feature_map(query_dimensions) if feature_map else
            elu_feature_map(query_dimensions)
        )
        self.eps = eps
        self.event_dispatcher = EventDispatcher.get(event_dispatcher)

    def forward(self, queries, keys, values, attn_mask, query_lengths,
                key_lengths):
        # Apply the feature map to the queries and keys
        self.feature_map.new_feature_map(queries.device)
        Q = self.feature_map.forward_queries(queries)
        K = self.feature_map.forward_keys(keys)

        # Apply the key padding mask and make sure that the attn_mask is
        # all_ones
        if not attn_mask.all_ones:
            raise RuntimeError(("LinearAttention does not support arbitrary "
                                "attention masks"))
        K = K * key_lengths.float_matrix[:, :, None, None]

        # Compute the KV matrix, namely the dot product of keys and values so
        # that we never explicitly compute the attention matrix and thus
        # decrease the complexity
        KV = torch.einsum("nshd,nshm->nhmd", K, values)

        # Compute the normalizer
        Z = 1/(torch.einsum("nlhd,nhd->nlh", Q, K.sum(dim=1))+self.eps)

        # Finally compute and return the new values
        V = torch.einsum("nlhd,nhmd,nlh->nlhm", Q, KV, Z)

        return V.contiguous()


# Register the attention implementation so that it becomes available in our
# builders
AttentionRegistry.register(
    "linear", LinearAttention,
    [
        ("query_dimensions", Int),
        ("feature_map", Optional(Callable)),
        ("event_dispatcher", Optional(EventDispatcherInstance, ""))
    ]
)
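A toy numerical check (not from the diff) of why the einsums above give the same result as normalized quadratic attention: with a positive feature map Φ, computing Φ(Q)(Φ(K)ᵀV) and then dividing by Φ(Q)(Φ(K)ᵀ1) matches the row-normalized Φ(Q)Φ(K)ᵀ applied to V. The small dimensions and the `phi` lambda are made up for illustration; the default feature map in the module is elu(x)+1.

    import torch

    N, L, S, H, D, M = 2, 5, 7, 1, 4, 3
    phi = lambda x: torch.nn.functional.elu(x) + 1          # default feature map
    Q, K = phi(torch.randn(N, L, H, D)), phi(torch.randn(N, S, H, D))
    V = torch.randn(N, S, H, M)

    # linear-complexity path, mirroring the forward() above
    KV = torch.einsum("nshd,nshm->nhmd", K, V)
    Z = 1 / (torch.einsum("nlhd,nhd->nlh", Q, K.sum(dim=1)) + 1e-6)
    out_linear = torch.einsum("nlhd,nhmd,nlh->nlhm", Q, KV, Z)

    # quadratic reference: explicit (normalized) attention matrix
    A = torch.einsum("nlhd,nshd->nhls", Q, K)
    out_quadratic = torch.einsum("nhls,nshm->nlhm", A / A.sum(-1, keepdim=True), V)

    assert torch.allclose(out_linear, out_quadratic, atol=1e-5)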
smi-ted/inference/smi_ted_light/fast_transformers/attention/local_attention.py
ADDED
@@ -0,0 +1,101 @@
#
# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
# Written by Angelos Katharopoulos <angelos.katharopoulos@idiap.ch>
#

"""Implement local context attention."""

from math import sqrt

import torch
from torch.nn import Module, Dropout
from torch.nn import functional as F

from ..attention_registry import AttentionRegistry, Optional, Int, Float, \
    EventDispatcherInstance
from ..events import EventDispatcher
from ..local_product import local_dot_product, local_weighted_average


class LocalAttention(Module):
    """Implement fast local attention where a query can only attend to
    neighboring keys.

    In this attention module the query Q_i can only attend to a key K_j if
    |i-j| < local_context/2.

    Arguments
    ---------
        local_context: The neighborhood to consider for local attention.
        softmax_temp: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.1)
        event_dispatcher: str or EventDispatcher instance to be used by this
                          module for dispatching events (default: the default
                          global dispatcher)
    """
    def __init__(self, local_context, softmax_temp=None, attention_dropout=0.1,
                 event_dispatcher=""):
        super(LocalAttention, self).__init__()
        self.local_context = local_context
        self.softmax_temp = softmax_temp
        self.dropout = Dropout(attention_dropout)
        self.event_dispatcher = EventDispatcher.get(event_dispatcher)

    def forward(self, queries, keys, values, attn_mask, query_lengths,
                key_lengths):
        """Implements the local attention.

        The attn_mask can be anything but the only values that will be
        considered will be the ones in the neighborhood of each query.

        Arguments
        ---------
            queries: (N, L, H, E) The tensor containing the queries
            keys: (N, S, H, E) The tensor containing the keys
            values: (N, S, H, D) The tensor containing the values
            attn_mask: An implementation of BaseMask that encodes where each
                       query can attend to
            query_lengths: An implementation of BaseMask that encodes how
                           many queries each sequence in the batch consists of
            key_lengths: An implementation of BaseMask that encodes how
                         many keys each sequence in the batch consists of
        """
        # Extract some shapes and compute the temperature
        N, L, H, E = queries.shape
        _, S, _, D = values.shape
        context = self.local_context
        softmax_temp = self.softmax_temp or 1./sqrt(E)

        # Permute the dimensions to NHLE instead of NLHE
        queries = queries.permute(0, 2, 1, 3).contiguous()
        keys = keys.permute(0, 2, 1, 3).contiguous()
        values = values.permute(0, 2, 1, 3).contiguous()

        QK = local_dot_product(
            queries,
            keys,
            attn_mask.additive_matrix_finite,
            key_lengths.lengths,
            self.local_context
        )
        A = self.dropout(torch.softmax(softmax_temp * QK, dim=-1))

        V_new = local_weighted_average(A, values)

        return V_new.permute(0, 2, 1, 3).contiguous()


# Register the attention implementation so that it becomes available in our
# builders
AttentionRegistry.register(
    "local", LocalAttention,
    [
        ("local_context", Int),
        ("softmax_temp", Optional(Float)),
        ("attention_dropout", Optional(Float, 0.1)),
        ("event_dispatcher", Optional(EventDispatcherInstance, ""))
    ]
)
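A tiny illustration (not from the diff) of the neighborhood rule in the docstring above, |i-j| < local_context/2, realized as a dense band mask on made-up sizes; the optimized `local_dot_product` kernel computes only the entries inside this band.

    import torch

    L, local_context = 6, 3
    i = torch.arange(L).view(L, 1)
    j = torch.arange(L).view(1, L)
    band = (i - j).abs() < local_context / 2   # True inside each query's local window
    print(band.int())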
smi-ted/inference/smi_ted_light/fast_transformers/attention/reformer_attention.py
ADDED
@@ -0,0 +1,166 @@
#
# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
# Written by Angelos Katharopoulos <angelos.katharopoulos@idiap.ch>,
# Apoorv Vyas <avyas@idiap.ch>
#

"""Implement the Reformer attention from the paper
"Reformer: The Efficient Transformer"."""

from math import sqrt

import torch
from torch.nn import Dropout, Module
from torch.nn.init import normal_

from ..attention_registry import AttentionRegistry, Optional, Int, Float, \
    Bool, EventDispatcherInstance
from ..events import EventDispatcher
from ..masking import FullMask


class ReformerAttention(Module):
    """Implement the attention module of the paper "Reformer: The Efficient
    Transformer".

    Arguments
    ---------
        chunk_size : Chunk size for each block (default: 32)
        bits : Number of bits for hashing (default: 8)
        rounds : Number of rounds of attention computation (default: 4)
        masked : If true, the query does not attend to itself (default: False)
        softmax_temp: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.1)
        event_dispatcher: str or EventDispatcher instance to be used by this
                          module for dispatching events (default: the default
                          global dispatcher)
    """

    def __init__(self, chunk_size=32, bits=8, rounds=4, masked=False,
                 softmax_temp=None, attention_dropout=0.1,
                 event_dispatcher=""):
        super(ReformerAttention, self).__init__()

        self.chunk_size = chunk_size
        self.bits = bits
        self.rounds = rounds
        self.masked = masked
        self.softmax_temp = softmax_temp
        self.dropout = Dropout(attention_dropout)
        self.event_dispatcher = EventDispatcher.get(event_dispatcher)

    def _normalize(self, x):
        norms = torch.sqrt(torch.einsum("nlhe,nlhe->nlh", x, x))
        x_normed = x / norms.unsqueeze(-1)
        return x_normed

    def _look_back(self, x):
        xshape = x.shape

        return torch.cat([
            x.new_zeros((xshape[0], 1) + xshape[2:]),
            torch.repeat_interleave(x, 2, dim=1)[:,:-1]
        ], dim=1).view(xshape[0], xshape[1], 2*xshape[2], *xshape[3:])

    def _reformer_round(self, Q, K, V, mask, softmax_temp):
        # Hash the queries
        N, L, H, E = Q.shape
        planes = Q.new_empty(self.bits, E)
        normal_(planes)
        projected = torch.einsum("nlhe,be->nlhb", K, planes)
        hashes = torch.argmax(
            torch.cat([projected, -projected], dim=-1),
            dim=-1
        )

        # Sort the queries in order to group them
        group = torch.argsort(hashes, dim=1)

        invert_group = torch.empty_like(group)
        batch_indices = torch.arange(N, device=hashes.device).view(N, 1, 1)
        sequence_indices = torch.arange(L, device=hashes.device).view(1, L, 1)
        head_indices = torch.arange(H, device=hashes.device).view(1, 1, H)
        invert_group[batch_indices, group, head_indices] = sequence_indices
        group = group.view(N, -1, self.chunk_size, H)
        invert_group = invert_group.view(N, -1, self.chunk_size, H)
        batch_indices = batch_indices.unsqueeze(1)
        head_indices = head_indices.unsqueeze(0)

        # Reorder Q, V and mask
        Q_grouped = Q[batch_indices, group, head_indices]
        K_grouped = K[batch_indices, group, head_indices]
        V_grouped = V[batch_indices, group, head_indices]
        mask_grouped = mask[
            batch_indices.unsqueeze(1),
            group.unsqueeze(3),
            self._look_back(group).unsqueeze(2)
        ]

        mask_grouped[:, 0, :, :Q_grouped.shape[2]] = float("-inf")

        # When everything is masked just unmask everything because it doesn't
        # matter what the output is at those positions
        # This is to avoid inf/nans in the new values at masked positions
        infmask = torch.isinf(mask_grouped)
        infmask = torch.all(infmask, dim=3, keepdims=True)
        mask_grouped = mask_grouped.masked_fill(infmask, 0.)

        # Attention
        K_grouped = self._look_back(K_grouped)
        QQ = torch.einsum("nblhe,nbshe->nbhls", Q_grouped, K_grouped)
        QQ = QQ + mask_grouped.permute(0, 1, 4, 2, 3)
        A = torch.softmax(softmax_temp * QQ, dim=-1)
        A = self.dropout(A)

        # Values
        V_grouped = self._look_back(V_grouped)
        V_new = torch.einsum("nbhls,nbshe->nblhe", A, V_grouped)
        V_new = V_new.contiguous().view(N, -1, H, E)
        V_new = V_new[batch_indices, invert_group, head_indices]
        V_new = V_new.contiguous().view(N, L, H, E)
        return V_new

    def forward(self, queries, keys, values, attn_mask, query_lengths,
                key_lengths):
        # Extract the dimensions of query, key, value
        N, L, H, E = queries.shape

        softmax_temp = self.softmax_temp or 1./sqrt(E)
        # Create the mask
        mask = key_lengths.additive_matrix.unsqueeze(1).expand(N, L, L)
        if self.masked:
            mask = mask + torch.eye(L, device=queries.device).unsqueeze(0)*float(-1e9)

        if not attn_mask.all_ones:
            mask = mask + attn_mask.additive_matrix.unsqueeze(0)
        # Get normalized Queries as Keys
        K = self._normalize(queries)
        # Zero the masked out keys
        K = K * key_lengths.float_matrix.view(N, L, 1, 1)

        V_new = 0
        factor = 1/self.rounds
        for i in range(self.rounds):
            V_new = V_new + \
                factor * self._reformer_round(queries, K, values, mask, softmax_temp)

        return V_new


# Register the attention implementation so that it becomes available in our
# builders
AttentionRegistry.register(
    "reformer", ReformerAttention,
    [
        ("chunk_size", Optional(Int, 32)),
        ("bits", Optional(Int, 63)),
        ("rounds", Optional(Int, 4)),
        ("masked", Optional(Bool, False)),
        ("softmax_temp", Optional(Float)),
        ("attention_dropout", Optional(Float, 0.1)),
        ("event_dispatcher", Optional(EventDispatcherInstance, ""))
    ]
)
smi-ted/inference/smi_ted_light/fast_transformers/attention_registry/__init__.py
ADDED
@@ -0,0 +1,17 @@
#
# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
# Written by Angelos Katharopoulos <angelos.katharopoulos@idiap.ch>
#

"""Allow for the dynamic registration of new attention implementations.

This module provides a Registry implementation that other modules can use to
register attention implementations for the builders.
"""

from .registry import \
    AttentionRegistry, \
    RecurrentAttentionRegistry, \
    RecurrentCrossAttentionRegistry
from .spec import Spec, Choice, Optional, Int, Float, Bool, Callable, \
    EventDispatcherInstance
smi-ted/inference/smi_ted_light/fast_transformers/attention_registry/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (786 Bytes).
smi-ted/inference/smi_ted_light/fast_transformers/attention_registry/__pycache__/registry.cpython-310.pyc
ADDED
Binary file (2.27 kB).
smi-ted/inference/smi_ted_light/fast_transformers/attention_registry/__pycache__/spec.cpython-310.pyc
ADDED
Binary file (4.73 kB).
smi-ted/inference/smi_ted_light/fast_transformers/attention_registry/registry.py
ADDED
@@ -0,0 +1,61 @@
#
# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
# Written by Angelos Katharopoulos <angelos.katharopoulos@idiap.ch>
#


class Registry(object):
    """Hold the available attention implementations and their required
    parameters."""
    def __init__(self):
        self._classes = {}
        self._class_params = {}
        self._parameters = {}

    def register(self, key, class_object, parameter_tuples):
        # register the class if the key is new
        if key in self._classes:
            raise ValueError("{} is already registered".format(key))
        self._classes[key] = class_object

        # register the parameters
        for parameter, spec in parameter_tuples:
            if (
                parameter in self._parameters and
                self._parameters[parameter] != spec
            ):
                raise ValueError(("{} is already registered with "
                                  "spec {!r} instead of {!r}").format(
                                      parameter,
                                      self._parameters[parameter],
                                      spec
                                  ))
            self._parameters[parameter] = spec

        # note which parameters are needed by this class
        self._class_params[key] = [p for p, s in parameter_tuples]

    def __contains__(self, key):
        return key in self._classes

    def __getitem__(self, key):
        return self._classes[key], self._class_params[key]

    @property
    def keys(self):
        return list(self._classes.keys())

    def contains_parameter(self, key):
        return key in self._parameters

    def validate_parameter(self, key, value):
        try:
            return self._parameters[key].get(value)
        except Exception as e:
            raise ValueError(("Invalid value {!r} for "
                              "parameter {!r}").format(value, key)) from e


AttentionRegistry = Registry()
RecurrentAttentionRegistry = Registry()
RecurrentCrossAttentionRegistry = Registry()
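A hypothetical sketch (not part of the diff) of how a registry instance is used; `MyAttention` and the "my-attention" key are made up for illustration, and `Registry`, `Optional` and `Float` are assumed to be imported from the registry and spec modules in this directory.

    from torch.nn import Module

    class MyAttention(Module):
        def __init__(self, softmax_temp=None):
            super(MyAttention, self).__init__()
            self.softmax_temp = softmax_temp

    registry = Registry()
    registry.register(
        "my-attention", MyAttention,
        [("softmax_temp", Optional(Float))]
    )
    cls, params = registry["my-attention"]             # (MyAttention, ["softmax_temp"])
    registry.validate_parameter("softmax_temp", 0.5)   # -> 0.5 (coerced through the spec)
    "my-attention" in registry                         # -> True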
smi-ted/inference/smi_ted_light/fast_transformers/attention_registry/spec.py
ADDED
@@ -0,0 +1,126 @@
#
# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
# Written by Angelos Katharopoulos <angelos.katharopoulos@idiap.ch>
#

"""Spec instances allow describing and checking the type and value of
parameters."""

from ..events import EventDispatcher


class Spec(object):
    """Describe and validate a parameter type.

    Arguments
    ---------
        predicate: A callable that checks if the value is acceptable and
                   returns its canonical value or raises ValueError.
        name: A name to create a human readable description of the Spec
    """
    def __init__(self, predicate, name="CustomSpec"):
        self._predicate = predicate
        self._name = name

    def __repr__(self):
        return self._name

    def check(self, x):
        try:
            self._predicate(x)
            return True
        except ValueError:
            return False

    def get(self, x):
        return self._predicate(x)

    def __eq__(self, y):
        return self is y


class Choice(Spec):
    """A parameter type for a set of options.

    Arguments
    ---------
        choices: A set or list of possible values for this parameter
    """
    def __init__(self, choices):
        self._choices = choices

    def get(self, x):
        if x in self._choices:
            return x
        raise ValueError("{!r} is not in {!r}".format(x, self._choices))

    def __repr__(self):
        return "Choice({!r})".format(self._choices)

    def __eq__(self, x):
        if isinstance(x, Choice):
            return self._choices == x._choices
        return False


class _Callable(Spec):
    def __init__(self):
        super(_Callable, self).__init__(None, "Callable")

    def get(self, x):
        if callable(x):
            return x
        raise ValueError("{!r} is not a callable".format(x))


class _EventDispatcherInstance(Spec):
    def __init__(self):
        super(_EventDispatcherInstance, self).__init__(
            _EventDispatcherInstance._get_event_dispatcher,
            "EventDispatcherInstance"
        )

    @staticmethod
    def _get_event_dispatcher(x):
        if isinstance(x, str):
            return x
        if isinstance(x, EventDispatcher):
            return x
        raise ValueError("{!r} is not an event dispatcher".format(x))


class Optional(Spec):
    """Represent an optional parameter that can either have a value or it can
    be None.

    Arguments
    ---------
        spec: The spec for the value if it is not None
        default: The returned value in case it is None
    """
    def __init__(self, spec, default=None):
        self._other_spec = spec
        self._default = default

    def __repr__(self):
        return "Optional[{!r}, {!r}]".format(self._other_spec, self._default)

    def get(self, x):
        if x is None:
            return self._default
        return self._other_spec.get(x)

    def __eq__(self, x):
        if isinstance(x, Optional):
            return (
                self._other_spec == x._other_spec and
                self._default == x._default
            )
        return False


Int = Spec(int, "Int")
Float = Spec(float, "Float")
Bool = Spec(bool, "Bool")
Callable = _Callable()
EventDispatcherInstance = _EventDispatcherInstance()
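A short illustration (not part of the diff) of how these Spec types behave, assuming the names below are evaluated in the context of the module above or imported from it.

    Int.get(10)                               # -> 10
    Int.check("ten")                          # -> False, because int("ten") raises ValueError
    Optional(Int, 10).get(None)               # -> 10, the default is returned for None
    Optional(Int, 10).get("5")                # -> 5, coerced through the inner Int spec
    Choice(["linear", "full"]).get("linear")  # -> "linear"; anything else raises ValueError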
smi-ted/inference/smi_ted_light/fast_transformers/builders/__init__.py
ADDED
@@ -0,0 +1,59 @@
#
# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
# Written by Angelos Katharopoulos <angelos.katharopoulos@idiap.ch>,
# Apoorv Vyas <avyas@idiap.ch>
#

"""This module implements builders that simplify building complex transformer
architectures with different attention mechanisms.

The main idea is to facilitate the construction of various attention layers and
transformer encoder layers and simplify their assembly into one transformer
module. It also allows for flexibility in the scripts as many builder
parameters can correspond 1-1 with command line arguments.

Example usage:

    builder = TransformerEncoderBuilder()
    builder.n_layers = 12
    builder.n_heads = 8
    builder.feed_forward_dimensions = 1024
    builder.query_dimensions = 64
    builder.value_dimensions = 64
    builder.dropout = 0.1
    builder.attention_dropout = 0.1
    builder.attention_type = "linear"
    transformer = builder.get()
"""

__all__ = [
    "AttentionBuilder",
    "RecurrentAttentionBuilder",
    "RecurrentCrossAttentionBuilder"
]

# Import the attention implementations so that they register themselves with
# the builder. Attention implementations external to the library should be
# imported before using the builders.
#
# TODO: Should this behaviour change? Namely, should all attention
#       implementations be imported in order to be useable? This also allows
#       using the library even partially built, for instance.
from ..attention import \
    FullAttention, \
    LinearAttention

del FullAttention, \
    LinearAttention


from .attention_builders import \
    AttentionBuilder, \
    RecurrentAttentionBuilder, \
    RecurrentCrossAttentionBuilder

from .transformer_builders import \
    TransformerEncoderBuilder, \
    RecurrentEncoderBuilder, \
    TransformerDecoderBuilder, \
    RecurrentDecoderBuilder
smi-ted/inference/smi_ted_light/fast_transformers/builders/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (1.46 kB).
smi-ted/inference/smi_ted_light/fast_transformers/builders/__pycache__/attention_builders.cpython-310.pyc
ADDED
Binary file (6.49 kB).
smi-ted/inference/smi_ted_light/fast_transformers/builders/__pycache__/base.cpython-310.pyc
ADDED
Binary file (2.3 kB).
smi-ted/inference/smi_ted_light/fast_transformers/builders/__pycache__/transformer_builders.cpython-310.pyc
ADDED
Binary file (18.7 kB).
smi-ted/inference/smi_ted_light/fast_transformers/builders/attention_builders.py
ADDED
@@ -0,0 +1,139 @@
#
# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
# Written by Angelos Katharopoulos <angelos.katharopoulos@idiap.ch>
#

from collections import defaultdict

from .base import BaseBuilder
from ..attention_registry import \
    AttentionRegistry, \
    RecurrentAttentionRegistry, \
    RecurrentCrossAttentionRegistry


class BaseAttentionBuilder(BaseBuilder):
    def __init__(self, registry):
        self._registry = registry
        self._parameters = defaultdict(lambda: None)

    @property
    def available_attentions(self):
        """Return a list with the available attention implementations."""
        return self._registry.keys

    def validate_attention_type(self, attention_type):
        """Parse the attention type according to the rules used by `get()` and
        check if the requested attention is constructible."""
        return all(
            all(t in self._registry for t in a.split(","))
            for a in attention_type.split(":")
        )

    def __setattr__(self, key, value):
        # Make sure we have normal behaviour for the class members _registry
        # and _parameters
        if key in ["_registry", "_parameters"]:
            return object.__setattr__(self, key, value)

        # Assign everything else in the parameters dictionary
        if not self._registry.contains_parameter(key):
            raise AttributeError(("{!r} is not a valid attention "
                                  "parameter name").format(key))
        self._parameters[key] = self._registry.validate_parameter(key, value)

    def __getattr__(self, key):
        if key in self._parameters:
            return self._parameters[key]
        else:
            raise AttributeError()

    def __repr__(self):
        return (
            "{}.from_kwargs(\n".format(self.__class__.__name__) +
            "\n".join(["    {}={!r},".format(k, v)
                       for k, v in self._parameters.items()])[:-1] +
            "\n)"
        )

    def get(self, attention_type):
        """Construct the attention implementation object and return it.

        The passed in attention_type argument defines the attention to be
        created. It should be a string and in its simplest form it should
        be one of the available choices from `available_attentions`.

        However, to enable attention decoration, namely an attention
        implementation augmenting the functionality of another implementation,
        the attention type can be a colon separated list of compositions like
        the following examples:

            - 'att1' means instantiate att1
            - 'att2:att1' means instantiate att1 and decorate it with att2
            - 'att3:att1,att4' means instantiate att1 and att4 and decorate
              them with att3

        Arguments
        ---------
            attention_type: A string that contains one or more keys from
                            `available_attentions` separated with a colon to
                            denote the decoration pattern.
        """
        compositions = reversed(attention_type.split(":"))
        attentions = []
        for c in compositions:
            attentions = [
                self._construct_attention(t, attentions)
                for t in c.split(",")
            ]
        if len(attentions) > 1:
            raise ValueError(("Invalid attention_type argument "
                              "{!r}").format(attention_type))
        return attentions[0]

    def _construct_attention(self, attention_type, decorated=[]):
        """Construct an attention implementation object.

        Arguments
        ---------
            attention_type: A string that contains a single key from the
                            `available_attentions`
            decorated: A list of attention implementations to pass as arguments
                       to be decorated
        """
        if attention_type not in self._registry:
            raise ValueError(("Unknown attention type "
                              "{!r}").format(attention_type))

        attention, parameters = self._registry[attention_type]
        parameter_dictionary = {
            p: self._registry.validate_parameter(p, self._parameters[p])
            for p in parameters
        }

        return attention(*decorated, **parameter_dictionary)


class AttentionBuilder(BaseAttentionBuilder):
    """Build attention implementations for batch sequence processing or
    training."""
    def __init__(self):
        super(AttentionBuilder, self).__init__(AttentionRegistry)


class RecurrentAttentionBuilder(BaseAttentionBuilder):
    """Build attention implementations for autoregressive sequence
    processing."""
    def __init__(self):
        super(RecurrentAttentionBuilder, self).__init__(
            RecurrentAttentionRegistry
        )


class RecurrentCrossAttentionBuilder(BaseAttentionBuilder):
    """Build attention implementations for autoregressive cross attention
    computation."""
    def __init__(self):
        super(RecurrentCrossAttentionBuilder, self).__init__(
            RecurrentCrossAttentionRegistry
        )
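A minimal usage sketch (not part of the diff), assuming `AttentionBuilder` is imported from this package and that the "local" attention module above has been imported so its key is registered; the colon-separated decoration syntax described in the `get()` docstring works the same way once decorator-style implementations are registered.

    builder = AttentionBuilder.from_kwargs(
        local_context=16,          # required by the "local" key ("local_context", Int)
        attention_dropout=0.1,     # optional parameters fall back to their spec defaults
    )
    builder.available_attentions          # e.g. ["linear", "local", ...], depends on imports
    builder.validate_attention_type("local")   # -> True if "local" is registered
    attention = builder.get("local")            # a LocalAttention instance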
smi-ted/inference/smi_ted_light/fast_transformers/builders/base.py
ADDED
@@ -0,0 +1,67 @@
#
# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
# Written by Angelos Katharopoulos <angelos.katharopoulos@idiap.ch>,
# Apoorv Vyas <avyas@idiap.ch>
#

"""Provide a class for the others to inherit some useful functionality."""


class BaseBuilder(object):
    @classmethod
    def from_kwargs(cls, **kwargs):
        """Construct a builder and set all the keyword arguments as parameters.

        The keyword argument strict is passed to
        BaseBuilder.from_dictionary separately.

        See BaseBuilder.from_dictionary().
        """
        strict = kwargs.pop("strict", True)
        return cls.from_dictionary(kwargs, strict=strict)

    @classmethod
    def from_namespace(cls, args, strict=False):
        """Construct a builder from an argparse Namespace.

        To be used for building transformers from command line arguments.

        See BaseBuilder.from_dictionary().
        """
        return cls.from_dictionary(vars(args), strict=strict)

    @classmethod
    def from_dictionary(cls, dictionary, strict=True):
        """Construct a builder and set all the parameters in the dictionary.

        Given a dictionary

            d = {"foo": "bar"}

        then

            builder = TransformerEncoderBuilder.from_dictionary(d)

        is equivalent to

            builder = TransformerEncoderBuilder()
            builder.foo = "bar"

        Arguments
        ---------
            dictionary: A dictionary of parameters to set to the builder.
            strict: bool, If a key is not a parameter and strict is set to True
                    then a ValueError is raised, otherwise that dictionary key
                    is ignored (default: True)
        """
        builder = cls()
        for k, v in dictionary.items():
            try:
                setattr(builder, k, v)
            except AttributeError:
                if strict:
                    raise ValueError(("The builder has no "
                                      "parameter {!r}").format(k))
                else:
                    continue
        return builder
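A short sketch (not part of the diff) showing that the three constructors above share one code path; `AttentionBuilder` is assumed to be imported from this package and the argparse flag is made up for illustration. `from_namespace` defaults to strict=False, so extra namespace attributes that are not builder parameters are silently ignored.

    import argparse

    params = {"local_context": 16, "attention_dropout": 0.1}
    b1 = AttentionBuilder.from_dictionary(params)
    b2 = AttentionBuilder.from_kwargs(local_context=16, attention_dropout=0.1)  # same result

    parser = argparse.ArgumentParser()
    parser.add_argument("--local_context", type=int, default=16)
    args = parser.parse_args([])
    b3 = AttentionBuilder.from_namespace(args)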
smi-ted/inference/smi_ted_light/fast_transformers/builders/transformer_builders.py
ADDED
|
@@ -0,0 +1,550 @@
#
# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
# Written by Angelos Katharopoulos <angelos.katharopoulos@idiap.ch>
#

"""Build complex transformer architectures for inference or training easily."""

from torch.nn import LayerNorm

from ..attention import AttentionLayer
from ..transformers import TransformerEncoder, TransformerEncoderLayer, \
    TransformerDecoder, TransformerDecoderLayer
from ..recurrent.attention import \
    RecurrentAttentionLayer, \
    RecurrentCrossAttentionLayer
from ..recurrent.transformers import \
    RecurrentTransformerEncoder, RecurrentTransformerEncoderLayer, \
    RecurrentTransformerDecoder, RecurrentTransformerDecoderLayer
from .base import BaseBuilder
from .attention_builders import AttentionBuilder, RecurrentAttentionBuilder, \
    RecurrentCrossAttentionBuilder


class BaseTransformerBuilder(BaseBuilder):
    """Contains all the parameters for building a transformer other than the
    attention part.

    Classes extending the BaseTransformerBuilder should implement the `get()`
    method that actually builds the transformer.
    """
    def __init__(self):
        # transformer parameters
        self._n_layers = 4
        self._n_heads = 4
        self._d_query = 64
        self._d_value = 64
        self._d_ff = 1024
        self._dropout = 0.1
        self._activation = "relu"
        self._final_norm = True
        self._event_dispatcher = ""  # the default global dispatcher

    @property
    def n_layers(self):
        """The number of transformer layers."""
        return self._n_layers

    @n_layers.setter
    def n_layers(self, val):
        self._n_layers = val

    @property
    def n_heads(self):
        """The number of heads in each transformer layer."""
        return self._n_heads

    @n_heads.setter
    def n_heads(self, val):
        self._n_heads = val

    @property
    def feed_forward_dimensions(self):
        """The dimensions of the fully connected layer in the transformer
        layers."""
        return self._d_ff

    @feed_forward_dimensions.setter
    def feed_forward_dimensions(self, val):
        self._d_ff = val

    @property
    def query_dimensions(self):
        """The dimensions of the queries and keys in each attention layer."""
        return self._d_query

    @query_dimensions.setter
    def query_dimensions(self, val):
        self._d_query = val

    @property
    def value_dimensions(self):
        """The dimensions of the values in each attention layer."""
        return self._d_value

    @value_dimensions.setter
    def value_dimensions(self, val):
        self._d_value = val

    @property
    def dropout(self):
        """The dropout rate to be applied in the transformer encoder layer."""
        return self._dropout

    @dropout.setter
    def dropout(self, val):
        self._dropout = val

    @property
    def activation(self):
        """The activation function for the transformer layer.

        One of {'relu', 'gelu'}.
        """
        return self._activation

    @activation.setter
    def activation(self, val):
        activations = ["relu", "gelu"]
        if val not in activations:
            raise ValueError(("{!r} is not one of the available activation "
                              "types {!r}").format(val, activations))
        self._activation = val

    @property
    def final_normalization(self):
        """Whether to add LayerNorm as the final layer of the
        TransformerEncoder."""
        return self._final_norm

    @final_normalization.setter
    def final_normalization(self, val):
        self._final_norm = bool(val)

    @property
    def event_dispatcher(self):
        """The transformer event dispatcher either as a string or as an
        EventDispatcher object."""
        return self._event_dispatcher

    @event_dispatcher.setter
    def event_dispatcher(self, event_dispatcher):
        self._event_dispatcher = event_dispatcher

    def get(self):
        """Build the transformer and return it."""
        raise NotImplementedError()


class BaseTransformerEncoderBuilder(BaseTransformerBuilder):
    """Implement the logic of building a transformer encoder but leave the
    specific layers open for changing by the inheriting classes. This allows us
    to reuse the logic for creating both the TransformerEncoder and the
    RecurrentTransformerEncoder.

    Inheriting classes should implement the following:

    - _get_attention_builder()
    - _get_attention_layer_class()
    - _get_encoder_class()
    - _get_encoder_layer_class()
    """
    def __init__(self):
        super(BaseTransformerEncoderBuilder, self).__init__()
        self._attention_builder = self._get_attention_builder()
        self._attention_type = "full"

    def _get_attention_builder(self):
        """Return an instance of the appropriate attention builder."""
        raise NotImplementedError()

    def _get_attention_layer_class(self):
        """Return the class for the layer that projects queries, keys and
        values."""
        raise NotImplementedError()

    def _get_encoder_class(self):
        """Return the class for the transformer encoder."""
        raise NotImplementedError()

    def _get_encoder_layer_class(self):
        """Return the class for the transformer encoder layer."""
        raise NotImplementedError()

    @property
    def attention(self):
        """The attention builder instance."""
        return self._attention_builder

    @property
    def attention_type(self):
        """The attention implementation chosen."""
        return self._attention_type

    @attention_type.setter
    def attention_type(self, val):
        if not self._attention_builder.validate_attention_type(val):
            raise ValueError(("{!r} is not an available attention "
                              "type").format(val))
        self._attention_type = val

    def __setattr__(self, key, val):
        # "protected" attributes are settable (probably from within the class)
        if key[0] == "_":
            return super().__setattr__(key, val)

        # Existing attributes are settable but they might also be attention
        # parameters so try that as well
        fail_on_exception = True
        if hasattr(self, key):
            super().__setattr__(key, val)
            fail_on_exception = False

        # Non-existing "public" attributes may be attention parameters
        try:
            setattr(self._attention_builder, key, val)
        except:
            if fail_on_exception:
                raise

    def get(self):
        """Build the transformer and return it."""
        # Set the event dispatcher to the attention builder
        self.attention.event_dispatcher = self.event_dispatcher

        # Extract into local variables the classes to be used
        Encoder = self._get_encoder_class()
        EncoderLayer = self._get_encoder_layer_class()
        Attention = self._get_attention_layer_class()

        model_dimensions = self.value_dimensions*self.n_heads
        return Encoder(
            [
                EncoderLayer(
                    Attention(
                        self.attention.get(self.attention_type),
                        model_dimensions,
                        self.n_heads,
                        d_keys=self.query_dimensions,
                        d_values=self.value_dimensions,
                        event_dispatcher=self.event_dispatcher
                    ),
                    model_dimensions,
                    self.feed_forward_dimensions,
                    self.dropout,
                    self.activation,
                    event_dispatcher=self.event_dispatcher
                )
                for _ in range(self.n_layers)
            ],
            (LayerNorm(model_dimensions) if self.final_normalization else None),
            event_dispatcher=self.event_dispatcher
        )


class TransformerEncoderBuilder(BaseTransformerEncoderBuilder):
    """Build a batch transformer encoder for training or processing of
    sequences all elements at a time.

    Example usage:

        builder = TransformerEncoderBuilder()
        builder.n_layers = 12
        builder.n_heads = 8
        builder.feed_forward_dimensions = 1024
        builder.query_dimensions = 64
        builder.value_dimensions = 64
        builder.dropout = 0.1
        builder.attention_dropout = 0.1
        builder.attention_type = "linear"
        transformer = builder.get()
    """
    def _get_attention_builder(self):
        """Return an instance of the appropriate attention builder."""
        return AttentionBuilder()

    def _get_attention_layer_class(self):
        """Return the class for the layer that projects queries, keys and
        values."""
        return AttentionLayer

    def _get_encoder_class(self):
        """Return the class for the transformer encoder."""
        return TransformerEncoder

    def _get_encoder_layer_class(self):
        """Return the class for the transformer encoder layer."""
        return TransformerEncoderLayer


class RecurrentEncoderBuilder(BaseTransformerEncoderBuilder):
    """Build a transformer encoder for autoregressive processing of sequences.

    Example usage:

        builder = RecurrentEncoderBuilder()
        builder.n_layers = 12
        builder.n_heads = 8
        builder.feed_forward_dimensions = 1024
        builder.query_dimensions = 64
        builder.value_dimensions = 64
        builder.dropout = 0.1
        builder.attention_dropout = 0.1
        builder.attention_type = "linear"
        transformer = builder.get()
    """
    def _get_attention_builder(self):
        """Return an attention builder for recurrent attention."""
        return RecurrentAttentionBuilder()

    def _get_attention_layer_class(self):
        """Return the class for the recurrent layer that projects queries, keys
        and values."""
        return RecurrentAttentionLayer

    def _get_encoder_class(self):
        """Return the class for the recurrent transformer encoder."""
        return RecurrentTransformerEncoder

    def _get_encoder_layer_class(self):
        """Return the class for the recurrent transformer encoder layer."""
        return RecurrentTransformerEncoderLayer


class BaseTransformerDecoderBuilder(BaseTransformerBuilder):
    """Similar to BaseTransformerEncoderBuilder, implement the logic of
    building the transformer decoder without defining concrete layers.

    Inheriting classes should implement the following:

    - _get_self_attention_builder() and _get_cross_attention_builder()
    - _get_self_attention_layer_class() and _get_cross_attention_layer_class()
    - _get_decoder_class()
    - _get_decoder_layer_class()
    """
    def __init__(self):
        super(BaseTransformerDecoderBuilder, self).__init__()
        self._self_attention_builder = self._get_self_attention_builder()
        self._cross_attention_builder = self._get_cross_attention_builder()
        self._self_attention_type = "full"
        self._cross_attention_type = "full"

    def _get_self_attention_builder(self):
        """Return an instance of attention builder."""
        raise NotImplementedError()

    def _get_cross_attention_builder(self):
        """Return an instance of attention builder."""
        raise NotImplementedError()

    def _get_self_attention_layer_class(self):
        """Return a class to project the queries, keys and values to
        multi-head versions."""
        raise NotImplementedError()

    def _get_cross_attention_layer_class(self):
        """Return a class to project the queries, keys and values to
        multi-head versions."""
        raise NotImplementedError()

    def _get_decoder_class(self):
        """Return the class for the transformer decoder."""
        raise NotImplementedError()

    def _get_decoder_layer_class(self):
        """Return the class for the transformer decoder layer."""
        raise NotImplementedError()

    @property
    def self_attention(self):
        """The attention builder instance that will be used for the self
        attention modules."""
        return self._self_attention_builder

    @property
    def self_attention_type(self):
        """The attention implementation used for self attention."""
        return self._self_attention_type

    @self_attention_type.setter
    def self_attention_type(self, val):
        if not self._self_attention_builder.validate_attention_type(val):
            raise ValueError(("{!r} is not an available self attention "
                              "type").format(val))
        self._self_attention_type = val

    @property
    def cross_attention(self):
        """The attention builder instance that will be used for the cross
        attention modules."""
        return self._cross_attention_builder

    @property
    def cross_attention_type(self):
        """The attention implementation used for cross attention."""
        return self._cross_attention_type

    @cross_attention_type.setter
    def cross_attention_type(self, val):
        if not self._cross_attention_builder.validate_attention_type(val):
            raise ValueError(("{!r} is not an available cross attention "
                              "type").format(val))
        self._cross_attention_type = val

    def __setattr__(self, key, val):
        # "protected" attributes are settable (probably from within the class)
        if key[0] == "_":
            return super().__setattr__(key, val)

        # Existing attributes are settable but they might also be attention
        # parameters so try that as well
        fail_on_exception = True
        if hasattr(self, key):
            super().__setattr__(key, val)
            fail_on_exception = False

        # Non-existing "public" attributes may be attention parameters
        try:
            setattr(self._self_attention_builder, key, val)
            setattr(self._cross_attention_builder, key, val)
        except:
            if fail_on_exception:
                raise

    def get(self):
        """Build the transformer and return it."""
        # Set the event dispatcher to attention builders
        self.self_attention.event_dispatcher = self.event_dispatcher
        self.cross_attention.event_dispatcher = self.event_dispatcher

        # Extract into local variables the classes to be used
        Decoder = self._get_decoder_class()
        DecoderLayer = self._get_decoder_layer_class()
        SelfAttention = self._get_self_attention_layer_class()
        CrossAttention = self._get_cross_attention_layer_class()

        model_dimensions = self.value_dimensions*self.n_heads
        return Decoder(
            [
                DecoderLayer(
                    SelfAttention(
                        self.self_attention.get(self.self_attention_type),
                        model_dimensions,
                        self.n_heads,
                        d_keys=self.query_dimensions,
                        d_values=self.value_dimensions,
                        event_dispatcher=self.event_dispatcher
                    ),
                    CrossAttention(
                        self.cross_attention.get(self.cross_attention_type),
                        model_dimensions,
                        self.n_heads,
                        d_keys=self.query_dimensions,
                        d_values=self.value_dimensions,
                        event_dispatcher=self.event_dispatcher
                    ),
                    model_dimensions,
                    self.feed_forward_dimensions,
                    self.dropout,
                    self.activation,
                    event_dispatcher=self.event_dispatcher
                )
                for _ in range(self.n_layers)
            ],
            (LayerNorm(model_dimensions) if self.final_normalization else None),
            event_dispatcher=self.event_dispatcher
        )


class TransformerDecoderBuilder(BaseTransformerDecoderBuilder):
    """Build a transformer decoder for training or processing of sequences all
    elements at a time.

    Example usage:

        builder = TransformerDecoderBuilder()
        builder.n_layers = 12
        builder.n_heads = 8
        builder.feed_forward_dimensions = 1024
        builder.query_dimensions = 64
        builder.value_dimensions = 64
        builder.dropout = 0.1
        builder.attention_dropout = 0.1
        builder.self_attention_type = "full"
        builder.cross_attention_type = "full"
        transformer = builder.get()
    """
    def _get_self_attention_builder(self):
        """Return an attention builder for creating non-recurrent attention
        variants."""
        return AttentionBuilder()

    def _get_cross_attention_builder(self):
        """Return an attention builder for creating non-recurrent attention
        variants."""
        return AttentionBuilder()

    def _get_self_attention_layer_class(self):
        """Return the non-recurrent attention layer to project queries, keys
        and values."""
        return AttentionLayer

    def _get_cross_attention_layer_class(self):
        """Return the non-recurrent attention layer to project queries, keys
        and values."""
        return AttentionLayer

    def _get_decoder_class(self):
        """Return the transformer decoder class."""
        return TransformerDecoder

    def _get_decoder_layer_class(self):
        """Return the transformer decoder layer class."""
        return TransformerDecoderLayer


class RecurrentDecoderBuilder(BaseTransformerDecoderBuilder):
    """Build a transformer decoder for processing of sequences in
    autoregressive fashion.

    Example usage:

        builder = RecurrentDecoderBuilder()
        builder.n_layers = 12
        builder.n_heads = 8
        builder.feed_forward_dimensions = 1024
        builder.query_dimensions = 64
        builder.value_dimensions = 64
        builder.dropout = 0.1
        builder.attention_dropout = 0.1
        builder.self_attention_type = "full"
        builder.cross_attention_type = "full"
        transformer = builder.get()
    """
    def _get_self_attention_builder(self):
        """Return an attention builder for creating recurrent attention
        variants."""
        return RecurrentAttentionBuilder()

    def _get_cross_attention_builder(self):
        """Return an attention builder for creating recurrent cross attention
        variants."""
        return RecurrentCrossAttentionBuilder()

    def _get_self_attention_layer_class(self):
        """Return the recurrent attention layer to project queries, keys
        and values."""
        return RecurrentAttentionLayer

    def _get_cross_attention_layer_class(self):
        """Return the recurrent cross attention layer to project queries, keys
        and values."""
        return RecurrentCrossAttentionLayer

    def _get_decoder_class(self):
        """Return the transformer decoder class."""
        return RecurrentTransformerDecoder

    def _get_decoder_layer_class(self):
        """Return the transformer decoder layer class."""
        return RecurrentTransformerDecoderLayer

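A short end-to-end sketch of the builder pattern above; the import path and the availability of the "full" attention type are assumptions (the vendored attention modules register it on import), and the shapes are illustrative only:

    import torch

    # Assumed import path for the vendored package
    from fast_transformers.builders import TransformerEncoderBuilder

    builder = TransformerEncoderBuilder.from_kwargs(
        n_layers=2,
        n_heads=4,
        query_dimensions=32,
        value_dimensions=32,
        feed_forward_dimensions=256,
        attention_type="full",   # e.g. "linear" if that variant is registered
    )
    encoder = builder.get()

    # The model dimension is value_dimensions * n_heads = 128
    x = torch.randn(8, 100, 128)   # (batch, sequence, features)
    y = encoder(x)
    print(y.shape)                 # torch.Size([8, 100, 128])
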
smi-ted/inference/smi_ted_light/fast_transformers/causal_product/__init__.py
ADDED
|
@@ -0,0 +1,78 @@
#
# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
# Written by Angelos Katharopoulos <angelos.katharopoulos@idiap.ch>,
# Apoorv Vyas <avyas@idiap.ch>
#

import torch

from .causal_product_cpu import causal_dot_product as causal_dot_product_cpu, \
    causal_dot_backward as causal_dot_backward_cpu

try:
    from .causal_product_cuda import \
        causal_dot_product as causal_dot_product_cuda, \
        causal_dot_backward as causal_dot_backward_cuda
except ImportError:
    causal_dot_product_cuda = causal_dot_backward_cuda = None


class CausalDotProduct(torch.autograd.Function):
    """Compute the weighted sum of values but attending only to previous
    values."""
    dot = {
        "cpu": causal_dot_product_cpu,
        "cuda": causal_dot_product_cuda
    }
    dot_backward = {
        "cpu": causal_dot_backward_cpu,
        "cuda": causal_dot_backward_cuda
    }

    @staticmethod
    def forward(ctx, Q, K, V):
        # Save the inputs for the gradient computation
        ctx.save_for_backward(Q, K, V)

        # Create the output tensor
        device = Q.device
        N, H, L, _ = Q.shape
        _, _, _, M = V.shape
        product = torch.zeros((N, H, L, M), device=device)

        # Actually perform the dot product
        CausalDotProduct.dot[device.type](
            Q.data,
            K.data,
            V.data,
            product
        )

        return product

    @staticmethod
    def backward(ctx, grad_out):
        # Extract the saved tensors
        Q, K, V = ctx.saved_tensors

        # Allocate memory for the gradients
        grad_Q = torch.zeros_like(Q)
        grad_K = torch.zeros_like(K)
        grad_V = torch.zeros_like(V)

        # Actually compute the gradients
        CausalDotProduct.dot_backward[Q.device.type](
            Q.data,
            K.data,
            V.data,
            grad_out,
            grad_Q,
            grad_K,
            grad_V
        )

        return grad_Q, grad_K, grad_V


# Alias the autograd functions to python style snake case naming
causal_dot_product = CausalDotProduct.apply

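A rough usage sketch of the autograd function above; shapes follow the (N, H, L, E) convention used in forward(), the import path is an assumption, and on CPU the call goes through the compiled extension shipped alongside this file:

    import torch

    # Assumed import path for the vendored package
    from fast_transformers.causal_product import causal_dot_product

    N, H, L, E, M = 2, 4, 100, 32, 32
    Q = torch.rand(N, H, L, E, requires_grad=True)
    K = torch.rand(N, H, L, E, requires_grad=True)
    V = torch.rand(N, H, L, M, requires_grad=True)

    # out[:, :, i] only aggregates values at positions <= i (causal masking)
    out = causal_dot_product(Q, K, V)
    out.sum().backward()             # gradients flow through the custom backward()
    print(out.shape, Q.grad.shape)   # (N, H, L, M) and (N, H, L, E)
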
smi-ted/inference/smi_ted_light/fast_transformers/causal_product/causal_product_cpu.cpython-39-x86_64-linux-gnu.so
ADDED
|
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:84f32370e707beebd8fee88f356fb62721096142265895a5a8e9872063c04595
size 140928

smi-ted/inference/smi_ted_light/fast_transformers/clustering/__init__.py
ADDED
|
File without changes
|
smi-ted/inference/smi_ted_light/fast_transformers/clustering/hamming/__init__.py
ADDED
|
@@ -0,0 +1,115 @@
#
# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
# Written by Angelos Katharopoulos <angelos.katharopoulos@idiap.ch>,
# Apoorv Vyas <avyas@idiap.ch>
#


import numpy as np

import torch

from .cluster_cpu import cluster as cluster_cpu
try:
    from .cluster_cuda import cluster as cluster_gpu
except ImportError:
    pass


def cluster(
    hashes,
    lengths,
    groups=None,
    counts=None,
    centroids=None,
    distances=None,
    bitcounts=None,
    clusters=30,
    iterations=10,
    bits=32
):
    """Cluster hashes using a few iterations of K-Means with hamming distance.

    All the tensors that default to None are optional buffers provided to
    avoid memory allocations. distances and bitcounts are only used by the
    CUDA version of this call. clusters will be ignored if centroids is
    provided.

    Arguments
    ---------
        hashes: A long tensor of shape (N, H, L) containing a hashcode for each
                query.
        lengths: An int tensor of shape (N,) containing the sequence length for
                 each sequence in hashes.
        groups: An int tensor buffer of shape (N, H, L) containing the cluster
                to which the corresponding hash belongs.
        counts: An int tensor buffer of shape (N, H, K) containing the number
                of elements in each cluster.
        centroids: A long tensor buffer of shape (N, H, K) containing the
                   centroid for each cluster.
        distances: An int tensor of shape (N, H, L) containing the distance to
                   the closest centroid for each hash.
        bitcounts: An int tensor of shape (N, H, K, bits) containing the number
                   of elements that have 1 for a given bit.
        clusters: The number of clusters to use for each sequence. It is
                  ignored if centroids is not None.
        iterations: How many k-means iterations to perform.
        bits: How many of the least-significant bits in hashes to consider.

    Returns
    -------
        groups and counts as defined above.
    """
    device = hashes.device
    N, H, L = hashes.shape

    # Unfortunately cpu and gpu have different APIs so the entire call must be
    # surrounded by an if-then-else
    if device.type == "cpu":
        if groups is None:
            groups = torch.empty((N, H, L), dtype=torch.int32)
        if centroids is None:
            centroids = torch.empty((N, H, clusters), dtype=torch.int64)
            centroids = hashes[:, :, np.random.choice(L, size=[clusters], replace=False)]
        K = centroids.shape[2]
        if counts is None:
            counts = torch.empty((N, H, K), dtype=torch.int32)

        cluster_cpu(
            hashes, lengths,
            centroids, groups, counts,
            iterations, bits
        )

        return groups, counts

    else:
        if groups is None:
            groups = torch.empty((N, H, L), dtype=torch.int32, device=device)
        if centroids is None:
            centroids = torch.empty((N, H, clusters), dtype=torch.int64,
                                    device=device)
            centroids = hashes[:, :, np.random.choice(L, size=[clusters], replace=False)]
        K = centroids.numel() // N // H
        # K = clusters
        if counts is None:
            counts = torch.empty((N, H, K), dtype=torch.int32, device=device)
        if distances is None:
            distances = torch.empty((N, H, L), dtype=torch.int32,
                                    device=device)
        if bitcounts is None:
            bitcounts = torch.empty((N, H, K, bits), dtype=torch.int32,
                                    device=device)
        groups = groups.view(N, H, L)
        counts = counts.view(N, H, K)
        centroids = centroids.view(N, H, K)
        distances = distances.view(N, H, L)
        bitcounts = bitcounts.view(N, H, K, -1)

        cluster_gpu(
            hashes, lengths,
            centroids, distances, bitcounts, groups, counts,
            iterations, bits
        )

        return groups, counts

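A minimal CPU sketch of calling cluster() with random hash codes; the import path, the int32 dtype for lengths, and the sizes are assumptions made for illustration:

    import torch

    # Assumed import path for the vendored package
    from fast_transformers.clustering.hamming import cluster

    N, H, L = 2, 4, 256
    hashes = torch.randint(0, 2**16, (N, H, L), dtype=torch.int64)
    lengths = torch.full((N,), L, dtype=torch.int32)  # dtype assumed by the CPU kernel

    # groups[n, h, l] is the cluster index of the l-th hash,
    # counts[n, h, k] the number of members of cluster k.
    groups, counts = cluster(hashes, lengths, clusters=16, iterations=5, bits=16)
    print(groups.shape, counts.shape)  # (N, H, L) and (N, H, 16)
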
smi-ted/inference/smi_ted_light/fast_transformers/clustering/hamming/cluster_cpu.cpython-39-x86_64-linux-gnu.so
ADDED
|
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f2bd8f761d6e1efdeea33665cad8702b5c07d1a0db728d19cf332c4383510d45
size 139824

smi-ted/inference/smi_ted_light/fast_transformers/events/__init__.py
ADDED
|
@@ -0,0 +1,10 @@
#
# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
# Written by Angelos Katharopoulos <angelos.katharopoulos@idiap.ch>
#

"""This module implements a basic event system that allows the transformer
internal components to make available any tensor with minimal overhead."""

from .event import Event, AttentionEvent, QKVEvent
from .event_dispatcher import EventDispatcher

smi-ted/inference/smi_ted_light/fast_transformers/events/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (556 Bytes).
|
|
|
smi-ted/inference/smi_ted_light/fast_transformers/events/__pycache__/event.cpython-310.pyc
ADDED
|
Binary file (2.21 kB).
|
|
|
smi-ted/inference/smi_ted_light/fast_transformers/events/__pycache__/event_dispatcher.cpython-310.pyc
ADDED
|
Binary file (3.5 kB).
|
|
|
smi-ted/inference/smi_ted_light/fast_transformers/events/__pycache__/filters.cpython-310.pyc
ADDED
|
Binary file (5.82 kB).
|
|
|
smi-ted/inference/smi_ted_light/fast_transformers/events/event.py
ADDED
|
@@ -0,0 +1,51 @@
#
# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
# Written by Angelos Katharopoulos <angelos.katharopoulos@idiap.ch>
#


class Event(object):
    """The Event is the base class for all events that are dispatched from any
    transformer module.

    This class defines only the basic attributes of an event without any
    payload.

    Arguments
    ---------
        source: torch.nn.Module instance that dispatched this event
    """
    def __init__(self, source):
        self.source = source


class AttentionEvent(Event):
    """An event containing an attention matrix.

    Arguments
    ---------
        source: torch.nn.Module instance that dispatched this event
        attention_matrix: torch.tensor of the multihead attention matrix
                          computed in the corresponding attention layer
    """
    def __init__(self, source, attention_matrix):
        super(AttentionEvent, self).__init__(source)
        self.attention_matrix = attention_matrix


class QKVEvent(Event):
    """An event containing the queries, keys and values projected in their
    multiple heads.

    Arguments
    ---------
        source: torch.nn.Module instance that dispatched this event
        queries: torch.tensor containing the queries in shape NLHE
        keys: torch.tensor containing the keys in shape NSHE
        values: torch.tensor containing the values in shape NSHD
    """
    def __init__(self, source, queries, keys, values):
        super(QKVEvent, self).__init__(source)
        self.queries = queries
        self.keys = keys
        self.values = values

smi-ted/inference/smi_ted_light/fast_transformers/events/event_dispatcher.py
ADDED
|
@@ -0,0 +1,92 @@
#
# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
# Written by Angelos Katharopoulos <angelos.katharopoulos@idiap.ch>
#

from collections import OrderedDict

from .event import Event
from .filters import event_class


class EventDispatcher(object):
    """An EventDispatcher is a simple way to implement an observer pattern for
    loose coupling of components. In our case it is used so that the internals
    of large neural networks can communicate with the outside world in an
    agnostic and efficient way.

    Example usage
    -------------

        from fast_transformers.events import EventDispatcher, AttentionEvent
        from fast_transformers.events.filters import \
            layer_name_contains

        def attention_event_handler(event):
            print(event.attention_matrix)

        ed = EventDispatcher()
        ed.listen(AttentionEvent, attention_event_handler)
        ed.listen(
            AttentionEvent & layer_name_contains("layers.12"),
            attention_event_handler
        )
    """
    _dispatchers = {}

    def __init__(self):
        self._listeners = OrderedDict()

    def listen(self, event_filter, event_handler):
        """Add an event handler for the events that pass the event filter.

        Arguments
        ---------
            event_filter: callable or Event class to define for which events
                          this handler will be called
            event_handler: callable that accepts an instance of Event
        """
        if isinstance(event_filter, type) and issubclass(event_filter, Event):
            event_filter = event_class(event_filter)

        self._listeners[event_handler] = event_filter

    def remove(self, event_handler):
        """Remove the event_handler from the listeners so that no more events
        are dispatched to this handler."""
        self._listeners.pop(event_handler, None)

    def clear(self):
        """Remove all listeners from the event dispatcher."""
        self._listeners.clear()

    def dispatch(self, event):
        """Dispatch an event to the listeners.

        Arguments
        ---------
            event: Event instance
        """
        for event_handler, event_filter in self._listeners.items():
            if event_filter(event):
                event_handler(event)

    @classmethod
    def get(cls, key=""):
        """Factory method for creating global event dispatchers for loosely
        coupling parts of a larger codebase.

        Since global objects are a complete antipattern, we suggest that this
        is only used to set a default value for an event dispatcher passed as
        an argument.

        Argument
        --------
            key: A key to uniquely identify a dispatcher or an instance of a
                 dispatcher to be returned as is
        """
        if isinstance(key, cls):
            return key
        if key not in cls._dispatchers:
            cls._dispatchers[key] = cls()
        return cls._dispatchers[key]

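A small sketch of the full listen/dispatch round trip using the classes above; the import path is an assumption and the Identity module is just a stand-in for an attention layer that would normally emit the event:

    import torch

    # Assumed import path for the vendored package
    from fast_transformers.events import EventDispatcher, AttentionEvent

    def handler(event):
        print("attention matrix with shape", event.attention_matrix.shape)

    ed = EventDispatcher.get()      # the default global dispatcher ("")
    ed.listen(AttentionEvent, handler)

    # Somewhere inside an attention module an event would be dispatched:
    layer = torch.nn.Identity()     # stand-in for the emitting module
    ed.dispatch(AttentionEvent(layer, torch.rand(2, 4, 10, 10)))

    ed.remove(handler)              # stop receiving events
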
smi-ted/inference/smi_ted_light/fast_transformers/events/filters.py
ADDED
|
@@ -0,0 +1,141 @@
#
# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
# Written by Angelos Katharopoulos <angelos.katharopoulos@idiap.ch>
#

"""Define composable functions to filter events."""

import weakref

from .event import Event


class EventFilter(object):
    """EventFilter instances are predicates (i.e. functions that return True or
    False) to be used with an event dispatcher for filtering event
    instances.

    The main benefit from using raw functions is that an EventFilter composes
    very easily using operators such as &, |, ~.

    Example
    --------

        event_filter = AttentionEvent | layer_name_contains("layers.1")
        event_filter = from_layer(transformer.layers[2].attention)
        event_filter = (
            AttentionEvent &
            lambda ev: torch.isnan(ev.attention_matrix).any()
        )
    """
    def __call__(self, event):
        raise NotImplementedError()

    def _to_event_filter(self, other):
        if isinstance(other, EventFilter):
            return other
        if isinstance(other, type) and issubclass(other, Event):
            return event_class(other)
        if callable(other):
            return CallableEventFilter(other)

        return NotImplemented

    def __and__(self, other):
        other = self._to_event_filter(other)
        if other is NotImplemented:
            return other
        return CallableEventFilter(lambda ev: self(ev) and other(ev))

    def __rand__(self, other):
        other = self._to_event_filter(other)
        if other is NotImplemented:
            return other
        return CallableEventFilter(lambda ev: other(ev) and self(ev))

    def __or__(self, other):
        other = self._to_event_filter(other)
        if other is NotImplemented:
            return other
        return CallableEventFilter(lambda ev: self(ev) or other(ev))

    def __ror__(self, other):
        other = self._to_event_filter(other)
        if other is NotImplemented:
            return other
        return CallableEventFilter(lambda ev: other(ev) or self(ev))

    def __invert__(self):
        return CallableEventFilter(lambda ev: not self(ev))


class CallableEventFilter(EventFilter):
    """Wrap a function with an EventFilter object."""
    def __init__(self, event_filter):
        self._event_filter = event_filter

    def __call__(self, event):
        return self._event_filter(event)


class LayerNameEventFilter(EventFilter):
    """A LayerNameEventFilter allows to filter events based on a human readable
    name of the layer that emitted them.

    Note that LayerNameEventFilter keeps a weak reference to all modules which
    means that it cannot be used to prevent modules from being garbage
    collected.

    Arguments
    ---------
        root: torch.nn.Module instance that represents the root container
        name_filter: callable that returns True if a layer name should be
                     accepted
    """
    def __init__(self, root, name_filter):
        self._names = {
            weakref.ref(m): n
            for n, m in root.named_modules()
        }
        self._name_filter = name_filter

    def __call__(self, event):
        name = self._names.get(weakref.ref(event.source), None)
        if name is None:
            return False
        return self._name_filter(name)


def event_class(klass):
    """Select events that are instances of `klass`.

    Arguments
    ---------
        klass: A class to check the event instance against

    Returns
    -------
        An instance of EventFilter
    """
    return CallableEventFilter(lambda ev: isinstance(ev, klass))


def from_layer(layer):
    """Select events that are dispatched from the `layer`.

    Arguments
    ---------
        layer: An instance of torch.nn.Module to check against the event source

    Returns
    -------
        An instance of EventFilter
    """
    return CallableEventFilter(lambda ev: ev.source is layer)


def layer_name_contains(root, name):
    """Select events that contain `name` in their human readable name.

    We use root.named_modules() to get human readable names for the layers.
    """
    return LayerNameEventFilter(root, lambda n: name in n)

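A brief sketch of composing the filters above; the import path and the toy model are placeholders chosen for illustration:

    import torch

    # Assumed import path for the vendored package
    from fast_transformers.events import EventDispatcher, AttentionEvent
    from fast_transformers.events.filters import from_layer, layer_name_contains

    model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Linear(4, 4))

    # Only attention events emitted by modules whose name contains "1"
    event_filter = AttentionEvent & layer_name_contains(model, "1")

    # Or: only events coming from one specific submodule
    event_filter = from_layer(model[0]) & AttentionEvent

    EventDispatcher.get().listen(event_filter, print)
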
smi-ted/inference/smi_ted_light/fast_transformers/feature_maps/__init__.py
ADDED
|
@@ -0,0 +1,12 @@
#
# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
# Written by Angelos Katharopoulos <angelos.katharopoulos@idiap.ch>
#

"""Implementations of feature maps to be used with linear attention and causal
linear attention."""


from .base import elu_feature_map, ActivationFunctionFeatureMap
from .fourier_features import RandomFourierFeatures, Favor, \
    SmoothedRandomFourierFeatures, GeneralizedRandomFeatures

smi-ted/inference/smi_ted_light/fast_transformers/feature_maps/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (614 Bytes).
|
|
|
smi-ted/inference/smi_ted_light/fast_transformers/feature_maps/__pycache__/base.cpython-310.pyc
ADDED
|
Binary file (3.42 kB).
|
|
|