# Original Copyright 2020 Google
# Additional modifications Copyright 2020 Reid Swanson
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Python Modules
import logging
from typing import Any, Dict
# 3rd Party Modules
import tensorflow as tf
import tensorflow.keras.layers as tf_layers
# Project Modules
# In the transformer tutorial n_model is the number of dimensions used for
# the word embeddings. Since we are currently only working with dense inputs
# and no embeddings (yet), n_model is the number of input features
# (i.e., n_features).
from deletor.constants import MIN_FLOAT_32
# Module-level logger, named after this module.
log = logging.getLogger(name=__name__)
class ModelParameter:
    """
    Names of the keys read from the ``params`` dictionary passed to the
    network constructors in this module.
    """
    # Number of stacked attention / encoder layers.
    N_LAYERS = 'n_layers'
    # Model dimension read by GroupwiseAttentionNetwork.
    N_MODEL = 'n_model'
    # Number of input features; GroupwiseMultiHeadAttentionNetwork uses it as its model dimension.
    N_FEATURES = 'n_features'
    # Number of items scored together in one sample.
    GROUP_SIZE = 'group_size'
    # Number of attention heads.
    N_HEADS = 'n_heads'
    # Hidden width of the point-wise feed-forward sub-network.
    N_FF_UNITS = 'n_feed_forward_units'
    # Dropout rate passed to the Dropout layers.
    DROPOUT_RATE = 'dropout_rate'
    # Whether AttentionLayer applies residual + layer normalization.
    USE_LAYER_NORM = 'use_layer_norm'
    # Whether document scores are divided by their sample counts.
    USE_AVERAGE = 'use_average'
    # Whether items pass through a shared dense input layer first.
    SHARE_WEIGHTS = 'share_weights'
def make_point_wise_feed_forward_network(n_model, dff):
    """
    Build the two-layer feed-forward sub-network applied independently at
    every sequence position.

    :param n_model: Width of the output projection.
    :param dff: Width of the ReLU hidden layer.
    :return: A ``tf.keras.Sequential`` of ``Dense(dff, relu)`` then ``Dense(n_model)``.
    """
    # (batch_size, seq_len, dff)
    expand = tf_layers.Dense(dff, activation='relu')
    # (batch_size, seq_len, n_model)
    project = tf_layers.Dense(n_model)
    return tf.keras.Sequential([expand, project])
class MultiHeadAttention(tf_layers.Layer):
    """
    Multi-head scaled dot-product attention, adapted from the TensorFlow
    transformer tutorial.

    Queries, keys and values are each projected to ``n_model`` dimensions.
    When ``num_heads > 1`` the projections are split into ``num_heads`` heads
    of size ``n_model // num_heads``; with a single head the split is skipped
    and attention is computed over the full projection.
    """

    def __init__(self, n_model, num_heads):
        """
        :param n_model: Output dimension of the query, key, value and final
            dense projections. Must be an exact multiple of ``num_heads``.
        :param num_heads: Number of attention heads.
        :raises ValueError: If ``n_model`` is not divisible by ``num_heads``.
        """
        super().__init__()
        self.num_heads = num_heads
        self.n_model = n_model

        if n_model % num_heads != 0:
            raise ValueError(
                f"n_model ({n_model}) must be an exact multiple of the number "
                f"of heads ({num_heads})"
            )

        # Size of each head's slice of the model dimension.
        self.depth = n_model // num_heads

        self.wq = tf_layers.Dense(n_model)
        self.wk = tf_layers.Dense(n_model)
        self.wv = tf_layers.Dense(n_model)
        self.dense = tf_layers.Dense(n_model)

    def split_heads(self, x, batch_size):
        """
        Split the last dimension into ``(num_heads, depth)`` and transpose the
        result so it has shape ``(batch_size, num_heads, seq_len, depth)``.

        :param x: Tensor whose last dimension is ``n_model``; assumed to be
            ``(batch_size, seq_len, n_model)``.
        :param batch_size: The (possibly dynamic) batch size.
        :return: The reshaped and transposed tensor.
        """
        x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
        return tf.transpose(x, perm=[0, 2, 1, 3])

    @classmethod
    def make_attention_mask(cls, x, y, multi_head: bool = True, pad_value: float = MIN_FLOAT_32):
        """
        Build a padding mask from the labels and store it in ``x['attention_mask']``.

        Positions whose label equals ``pad_value`` get 1.0 in the mask and all
        others 0.0. :meth:`scaled_dot_product_attention` turns the 1.0 entries
        into large negative logits.

        :param x: The feature dictionary. It is mutated in place by adding an
            ``'attention_mask'`` entry.
        :param y: The labels, or a tuple/list whose second element holds them.
        :param multi_head: If ``True`` the mask gets two extra axes so it
            broadcasts against ``(rows, n_heads, seq_len_q, seq_len_k)``
            logits. Otherwise it gets one extra axis, for
            ``(rows, seq_len_q, seq_len_k)`` logits (single-head attention).
        :param pad_value: The label value that marks padding.
        :return: The ``(x, y)`` pair (presumably so this can be used directly
            as a ``tf.data.Dataset.map`` function).
        """
        ys = y[1] if isinstance(y, (tuple, list)) else y
        y_rank = tf.rank(ys)
        if y_rank == 3:
            # Flatten (batch_size, n_samples, group_size) labels to one row per
            # sample. This matches the (batch_size * n_samples, group_size, ...)
            # reshape the models apply to their inputs.
            batch_size = tf.shape(ys)[0]
            n_samples = tf.shape(ys)[1]
            group_size = tf.shape(ys)[2]
            ys = tf.reshape(ys, (batch_size * n_samples, group_size))

        mask = tf.cast(tf.math.equal(ys, pad_value), tf.float32)

        # Assuming ``mask`` is 2-D (rows, seq_len):
        if multi_head is True:
            # (rows, 1, 1, seq_len): broadcasts over heads and query positions.
            x['attention_mask'] = mask[:, tf.newaxis, tf.newaxis, :]
        else:
            # (rows, 1, seq_len): broadcasts over query positions.
            x['attention_mask'] = mask[:, tf.newaxis, :]

        return x, y

    @classmethod
    def scaled_dot_product_attention(
            cls,
            q: tf.Tensor,
            k: tf.Tensor,
            v: tf.Tensor,
            mask: tf.Tensor = None
    ):
        """
        Calculate the attention output and weights.

        q, k, v must have the same rank and matching leading dimensions.
        k and v must have a matching penultimate dimension, i.e.
        seq_len_k == seq_len_v. The mask must be broadcastable to the logits
        shape for addition.

        :param q: query, shape ``(..., seq_len_q, depth)``
        :param k: key, shape ``(..., seq_len_k, depth)``
        :param v: value, shape ``(..., seq_len_v, depth_v)``
        :param mask: Optional float tensor broadcastable to
            ``(..., seq_len_q, seq_len_k)``. Entries equal to 1.0 are
            effectively excluded from the softmax. Defaults to None.
        :return: ``(output, attention_weights)`` with shapes
            ``(..., seq_len_q, depth_v)`` and ``(..., seq_len_q, seq_len_k)``.
        """
        q.shape.assert_same_rank(k.shape)
        k.shape.assert_same_rank(v.shape)

        # (..., seq_len_q, seq_len_k)
        matmul_qk = tf.matmul(q, k, transpose_b=True)

        # Scale by sqrt(depth) to keep the softmax out of its saturated region.
        dk = tf.cast(tf.shape(k)[-1], tf.float32)
        scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)

        if mask is not None:
            # Masked positions get a huge negative logit, so ~0 weight.
            # noinspection PyTypeChecker
            scaled_attention_logits += (mask * -1e9)

        # Normalized over the last axis (seq_len_k) so the weights sum to 1.
        # (..., seq_len_q, seq_len_k)
        attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)
        output = tf.matmul(attention_weights, v)

        return output, attention_weights

    @classmethod
    def reshape_output(cls, output: tf.Tensor, q: tf.Tensor):
        """
        Reshape ``output`` to the dynamic shape of the rank-4 tensor ``q``.

        :param output: Tensor with the same number of elements as ``q``.
        :param q: Reference tensor; only its dynamic shape is used.
        :return: ``output`` reshaped to
            ``(batch_size, n_samples, group_size, n_features)`` taken from ``q``.
        """
        batch_size = tf.shape(q)[0]
        n_samples = tf.shape(q)[1]
        group_size = tf.shape(q)[2]
        n_features = tf.shape(q)[3]
        return tf.reshape(output, (batch_size, n_samples, group_size, n_features))

    # noinspection PyMethodOverriding
    def call(self, v, k, q, mask, **kwargs):
        """
        Apply multi-head attention.

        :param v: Value input, assumed ``(batch_size, seq_len_v, features)``.
        :param k: Key input, assumed ``(batch_size, seq_len_k, features)``.
        :param q: Query input, assumed ``(batch_size, seq_len_q, features)``.
        :param mask: Mask passed to :meth:`scaled_dot_product_attention`.
            With ``num_heads == 1`` the tensors stay rank 3, so the mask must
            broadcast against ``(batch_size, seq_len_q, seq_len_k)`` (e.g.
            one built with ``make_attention_mask(..., multi_head=False)``).
        :param kwargs: Ignored.
        :return: ``(output, attention_weights)``, where ``output`` is
            ``(batch_size, seq_len_q, n_model)``.
        """
        batch_size = tf.shape(q)[0]

        q = self.wq(q)  # (batch_size, seq_len, n_model)
        k = self.wk(k)  # (batch_size, seq_len, n_model)
        v = self.wv(v)  # (batch_size, seq_len, n_model)

        seqlen = tf.shape(q)[1]

        if self.num_heads > 1:
            q = self.split_heads(q, batch_size)  # (batch_size, num_heads, seq_len_q, depth)
            k = self.split_heads(k, batch_size)  # (batch_size, num_heads, seq_len_k, depth)
            v = self.split_heads(v, batch_size)  # (batch_size, num_heads, seq_len_v, depth)

        # Multi-head shapes:
        #   scaled_attention: (batch_size, num_heads, seq_len_q, depth)
        #   attention_weights: (batch_size, num_heads, seq_len_q, seq_len_k)
        scaled_attention, attention_weights = self.scaled_dot_product_attention(q, k, v, mask)

        if self.num_heads > 1:
            # (batch_size, seq_len_q, num_heads, depth)
            scaled_attention = tf.transpose(scaled_attention, perm=[0, 2, 1, 3])

        # Merge the heads back together: (batch_size, seq_len_q, n_model).
        # With a single head this reshape is a no-op.
        concat_attention = tf.reshape(scaled_attention, (batch_size, seqlen, self.n_model))

        # (batch_size, seq_len_q, n_model)
        output = self.dense(concat_attention)

        return output, attention_weights
class EncoderLayer(tf_layers.Layer):
    """
    One transformer encoder block.

    The block applies multi-head self-attention and then a point-wise
    feed-forward network. Each sub-layer is followed by dropout, a residual
    connection and layer normalization.
    """

    def __init__(self, n_model: int, num_heads: int, d_ff: int, dropout_rate: float = 0.1):
        """
        :param n_model: Model dimension; the input's last dimension must match
            it for the residual additions.
        :param num_heads: Number of attention heads.
        :param d_ff: Hidden width of the feed-forward sub-network.
        :param dropout_rate: Rate for both dropout layers.
        """
        super().__init__()
        self.multi_head_attention = MultiHeadAttention(n_model, num_heads)
        self.feed_forward = make_point_wise_feed_forward_network(n_model, d_ff)
        self.layer_norm_1 = tf_layers.LayerNormalization(epsilon=1e-6)
        self.layer_norm_2 = tf_layers.LayerNormalization(epsilon=1e-6)
        self.dropout_1 = tf_layers.Dropout(dropout_rate)
        self.dropout_2 = tf_layers.Dropout(dropout_rate)

    # noinspection PyMethodOverriding
    def call(self, x, training, mask, **kwargs):
        """
        :param x: Input, assumed ``(batch_size, input_seq_len, n_model)``.
        :param training: Whether dropout is active.
        :param mask: Attention mask forwarded to the attention layer.
        :param kwargs: Ignored.
        :return: Tensor of the same shape as ``x``.
        """
        # Self-attention sub-layer: x serves as values, keys and queries.
        attended, _weights = self.multi_head_attention(x, x, x, mask)
        attended = self.dropout_1(attended, training=training)
        normed = self.layer_norm_1(x + attended)

        # Feed-forward sub-layer.
        transformed = self.dropout_2(self.feed_forward(normed), training=training)
        return self.layer_norm_2(normed + transformed)
class Encoder(tf_layers.Layer):
    """
    Scales its input by ``sqrt(n_model)``, applies dropout, then runs a
    stack of :class:`EncoderLayer` blocks.
    """

    def __init__(
            self,
            num_layers: int,
            n_model: int,
            num_heads: int,
            d_ff: int,
            dropout_rate: float = 0.1
    ):
        """
        :param num_layers: Number of stacked encoder blocks.
        :param n_model: Model dimension.
        :param num_heads: Number of attention heads per block.
        :param d_ff: Hidden width of each block's feed-forward network.
        :param dropout_rate: Dropout rate for the input and every block.
        """
        super().__init__()
        self.n_model = n_model
        self.num_layers = num_layers
        self.encoder_layers = [
            EncoderLayer(n_model, num_heads, d_ff, dropout_rate)
            for _layer_index in range(num_layers)
        ]
        self.dropout = tf_layers.Dropout(dropout_rate)

    # noinspection PyMethodOverriding
    def call(self, x, mask, training, **kwargs):
        """
        :param x: Input, assumed ``(batch_size, input_seq_len, n_model)``.
        :param mask: Attention mask forwarded to every block.
        :param training: Whether dropout is active.
        :param kwargs: Ignored.
        :return: Tensor of the same shape as ``x``.
        """
        # No positional information is added here.
        scale = tf.math.sqrt(tf.cast(self.n_model, tf.float32))
        hidden = self.dropout(x * scale, training=training)

        for block in self.encoder_layers:
            hidden = block(hidden, training, mask)

        return hidden
class GroupwiseMultiHeadAttentionNetwork(tf.keras.Model):
    """
    This model implements a groupwise scoring model using normalized dot
    product attention mechanism based on Attention Is All You Need.

    Each sampled group of items is encoded by a transformer :class:`Encoder`.
    The flattened encoding is scored by a dense layer with one output per
    group position. Sample scores are then summed (and optionally averaged)
    per document via ``scatter_idx``.
    """
    default_use_average = True
    default_share_weights = False

    def __init__(self, params: Dict[str, Any], **kwargs):
        """
        :param params: See :class:`.ModelParameter` for the valid parameters.
            This model reads ``N_FEATURES`` as its model dimension.
        :param kwargs: Forwarded to ``tf.keras.Model``.
        """
        super().__init__(**kwargs)

        use_average = params[ModelParameter.USE_AVERAGE]
        n_layers = params[ModelParameter.N_LAYERS]
        n_model = params[ModelParameter.N_FEATURES]
        group_size = params[ModelParameter.GROUP_SIZE]
        n_heads = params[ModelParameter.N_HEADS]
        d_ff = params[ModelParameter.N_FF_UNITS]
        dropout_rate = params[ModelParameter.DROPOUT_RATE]

        self.model_params = params
        self.group_size = tf.constant(group_size, tf.int32)
        self.use_average = use_average
        self.encoder = Encoder(n_layers, n_model, n_heads, d_ff, dropout_rate)
        self.scoring_layer = tf_layers.Dense(group_size)

    # noinspection PyMethodOverriding,DuplicatedCode
    def call(self, x, training: bool = True, **kwargs):
        """
        :param x: Feature dictionary with ``'sequence_dense'``,
            ``'sample_dense'``, ``'attention_mask'``, ``'scatter_idx'`` and,
            when averaging, ``'document_counts'``.
        :param training: Whether dropout is active.
        :param kwargs: Ignored.
        :return: Per-document scores of shape ``(batch_size, n_documents)``.
        """
        xd = x['sequence_dense']
        xs = x['sample_dense']
        encoder_mask = x['attention_mask']
        scatter_idx = x['scatter_idx']

        batch_size = tf.shape(xd)[0]
        n_documents = tf.shape(xd)[1]
        n_samples = tf.shape(xs)[1]
        group_size = tf.shape(xs)[2]
        n_features = tf.shape(xd)[2]

        # Treat every sample as its own batch element for the encoder.
        input_lyr = tf.reshape(xs, (batch_size * n_samples, group_size, n_features))
        attn_output = self.encoder(input_lyr, encoder_mask, training)
        attn_output = tf.reshape(attn_output, (batch_size * n_samples, group_size * n_features))

        # Some queries may have fewer documents than the configured group
        # size. If every group in the batch is that small, the flattened
        # encoding is narrower than the ``self.group_size * n_features``
        # inputs the scoring layer was built for. Zero-pad it up to that
        # width; padding by 0 is a no-op when the sizes already agree.
        padlen = (self.group_size - group_size) * n_features
        attn_output = tf.pad(attn_output, [[0, 0], [0, tf.maximum(padlen, 0)]])

        # Keep only the scores for the group positions actually present.
        sample_scores = self.scoring_layer(attn_output)
        sample_scores = sample_scores[..., :group_size]
        sample_scores = tf.reshape(sample_scores, (batch_size, n_samples, group_size))

        # Sum the sample scores into their documents.
        scores = tf.scatter_nd(scatter_idx, sample_scores, (batch_size, n_documents))

        if self.use_average is True:
            # Divide by how often each document was sampled (0 counts -> 0).
            document_counts = x['document_counts']
            scores = tf.math.divide_no_nan(scores, document_counts)

        return scores
# region Simple Attention Network
class AttentionLayer(tf_layers.Layer):
    """
    A single-head self-attention block followed by a point-wise
    feed-forward network.

    Dropout follows each sub-layer when ``dropout_rate > 0``. When
    ``use_layer_norm`` is enabled, each sub-layer is wrapped in a residual
    connection followed by layer normalization. When it is disabled, the
    sub-layer outputs are used directly, without a residual connection.
    """

    def __init__(
            self, n_model: int,
            n_ff_units: int,
            use_layer_norm: bool = False,
            dropout_rate: float = 0.1
    ):
        """
        :param n_model: Model dimension. With ``use_layer_norm`` the input's
            last dimension must equal it for the residual additions.
        :param n_ff_units: Hidden width of the feed-forward sub-network.
        :param use_layer_norm: Whether to add residual + layer normalization.
        :param dropout_rate: Dropout rate; dropout layers are only created
            when it is greater than 0.
        """
        super().__init__()

        # Coerce to plain Python bools. Truthy values such as numpy.bool_
        # (e.g. the result of ``np.float64(0.1) > 0``) then enable the
        # feature instead of silently failing an ``is True`` identity check.
        self.use_dropout = bool(dropout_rate > 0)
        self.use_layer_norm = bool(use_layer_norm)

        self.multi_head_attention = MultiHeadAttention(n_model, 1)
        self.feed_forward = make_point_wise_feed_forward_network(n_model, n_ff_units)

        if self.use_layer_norm:
            self.layer_norm_1 = tf_layers.LayerNormalization(epsilon=1e-6)
            self.layer_norm_2 = tf_layers.LayerNormalization(epsilon=1e-6)

        if self.use_dropout:
            self.dropout_1 = tf_layers.Dropout(dropout_rate)
            self.dropout_2 = tf_layers.Dropout(dropout_rate)

    # noinspection PyMethodOverriding
    def call(self, x, training, mask, **kwargs):
        """
        :param x: Input, assumed ``(batch_size, input_seq_len, features)``.
        :param training: Whether dropout is active.
        :param mask: Attention mask. Attention here is single-head, so the
            mask presumably needs to broadcast against
            ``(batch_size, seq_len_q, seq_len_k)`` logits.
        :param kwargs: Ignored.
        :return: Tensor of shape ``(batch_size, input_seq_len, n_model)``.
        """
        # (batch_size, input_seq_len, n_model)
        attn_output, _ = self.multi_head_attention(x, x, x, mask)

        if self.use_dropout:
            attn_output = self.dropout_1(attn_output, training=training)

        if self.use_layer_norm:
            # (batch_size, input_seq_len, n_model)
            output_1 = self.layer_norm_1(x + attn_output)
        else:
            output_1 = attn_output

        # (batch_size, input_seq_len, n_model)
        ff_output = self.feed_forward(output_1)

        if self.use_dropout:
            ff_output = self.dropout_2(ff_output, training=training)

        if self.use_layer_norm:
            # (batch_size, input_seq_len, n_model)
            return self.layer_norm_2(output_1 + ff_output)

        return ff_output
class SelfAttention(tf_layers.Layer):
    """
    Scales its input by ``sqrt(n_model)``, applies dropout, then runs a
    stack of :class:`AttentionLayer` blocks.
    """

    def __init__(
            self,
            n_layers: int,
            n_model: int,
            n_ff_units: int,
            use_layer_norm: bool = False,
            dropout_rate: float = 0.1
    ):
        """
        :param n_layers: Number of stacked attention blocks.
        :param n_model: Model dimension of every block.
        :param n_ff_units: Hidden width of each block's feed-forward network.
        :param use_layer_norm: Forwarded to every :class:`AttentionLayer`.
        :param dropout_rate: Dropout rate for the input and every block.
        """
        super().__init__()
        self.n_model = n_model
        self.n_layers = n_layers
        self.attention_layers = [
            AttentionLayer(n_model, n_ff_units, use_layer_norm, dropout_rate)
            for _layer_index in range(n_layers)
        ]
        self.dropout = tf_layers.Dropout(dropout_rate)

    # noinspection PyMethodOverriding
    def call(self, x, mask, training, **kwargs):
        """
        :param x: Input, assumed ``(batch_size, input_seq_len, features)``.
        :param mask: Attention mask forwarded to every block.
        :param training: Whether dropout is active.
        :param kwargs: Ignored.
        :return: Tensor of shape ``(batch_size, input_seq_len, n_model)``.
        """
        # No positional information is added here.
        scale = tf.math.sqrt(tf.cast(self.n_model, tf.float32))
        hidden = self.dropout(x * scale, training=training)

        for block in self.attention_layers:
            hidden = block(hidden, training, mask)

        return hidden
class GroupwiseAttentionNetwork(tf.keras.Model):
    """
    Groupwise scoring model built on a stack of single-head
    :class:`AttentionLayer` blocks.

    Items can optionally be passed through a shared dense input layer first.
    Each sampled group is scored by a dense layer with one output per group
    position, and sample scores are summed (optionally averaged) per document.
    """

    def __init__(self, params: Dict[str, Any], **kwargs):
        """
        :param params: See :class:`.ModelParameter` for the valid parameters.
            This model reads ``N_MODEL`` as its model dimension.
        :param kwargs: Forwarded to ``tf.keras.Model``.
        """
        super().__init__(**kwargs)

        use_average = params[ModelParameter.USE_AVERAGE]
        share_weights = params[ModelParameter.SHARE_WEIGHTS]
        use_layer_norm = params[ModelParameter.USE_LAYER_NORM]
        n_layers = params[ModelParameter.N_LAYERS]
        n_model = params[ModelParameter.N_MODEL]
        group_size = params[ModelParameter.GROUP_SIZE]
        n_ff_units = params[ModelParameter.N_FF_UNITS]
        dropout_rate = params[ModelParameter.DROPOUT_RATE]

        self.model_params = params
        self.group_size = tf.constant(group_size, tf.int32)
        self.use_average = use_average
        self.share_weights = share_weights

        if self.share_weights is True:
            self.shared_input_layer = self.make_shared_layer(n_model)

        attention_args = [n_layers, n_model, n_ff_units, use_layer_norm, dropout_rate]
        self.self_attention = SelfAttention(*attention_args)
        self.scoring_layer = tf_layers.Dense(group_size)

    @classmethod
    def make_shared_layer(cls, n_model: int):
        """
        Build the per-item input layer shared across all group positions.

        :param n_model: Output width of the dense projection.
        :return: ``Sequential([Dense(n_model), BatchNormalization(), PReLU()])``.
        """
        return tf.keras.Sequential([
            tf_layers.Dense(n_model),
            tf_layers.BatchNormalization(),
            tf_layers.PReLU()
        ])

    def call(self, x: Dict[str, tf.Tensor], training: bool = True, **kwargs):
        """
        :param x: Feature dictionary with ``'sequence_dense'``,
            ``'sample_dense'``, ``'attention_mask'``, ``'scatter_idx'`` and,
            when averaging, ``'document_counts'``.
        :param training: Whether dropout / batch normalization run in
            training mode.
        :param kwargs: Ignored.
        :return: Per-document scores of shape ``(batch_size, n_documents)``.
        """
        xd = x['sequence_dense']  # (batch_size, n_documents, n_features)
        xs = x['sample_dense']  # (batch_size, n_samples, group_size, n_features)
        # NOTE(review): presumably built by MultiHeadAttention.make_attention_mask
        # with multi_head=False, since attention here is single-head; confirm.
        encoder_mask = x['attention_mask']
        scatter_idx = x['scatter_idx']

        batch_size = tf.shape(xd)[0]
        n_documents = tf.shape(xd)[1]
        n_samples = tf.shape(xs)[1]
        group_size = tf.shape(xs)[2]
        n_features = tf.shape(xs)[3]

        if self.share_weights is True:
            # Run every individual item through the shared dense layer:
            # one row per item, then back to one batch element per sample.
            n_rows = batch_size * n_samples * group_size
            in_lyr = tf.reshape(xs, [n_rows, n_features])
            in_lyr = self.shared_input_layer(in_lyr, training=training)
            input_lyr = tf.reshape(
                in_lyr, (batch_size * n_samples, group_size, tf.shape(in_lyr)[1])
            )
        else:
            # Treat every sample as its own batch element.
            input_lyr = tf.reshape(xs, (batch_size * n_samples, group_size, n_features))

        # (batch_size * n_samples, group_size, n_model)
        attn_output = self.self_attention(input_lyr, encoder_mask, training)

        # The attention stack projects to n_model channels (its dense layers
        # output n_model). Read the width from its output rather than
        # assuming it equals n_features.
        n_model = tf.shape(attn_output)[2]
        attn_output = tf.reshape(attn_output, (batch_size * n_samples, group_size * n_model))

        # Some queries may have fewer documents than the group size. If every
        # group in the batch is that small, the flattened output is narrower
        # than the ``self.group_size * n_model`` inputs the scoring layer was
        # built for. Zero-pad it up to that width; padding by 0 is a no-op
        # when the sizes already agree. (It may be possible to do this in the
        # input pipeline instead.)
        padlen = (self.group_size - group_size) * n_model
        attn_output = tf.pad(attn_output, [[0, 0], [0, tf.maximum(padlen, 0)]])

        # Keep only the scores for the group positions actually present.
        sample_scores = self.scoring_layer(attn_output)
        sample_scores = sample_scores[..., :group_size]
        sample_scores = tf.reshape(sample_scores, (batch_size, n_samples, group_size))

        # Sum the sample scores into their documents.
        scores = tf.scatter_nd(scatter_idx, sample_scores, (batch_size, n_documents))

        if self.use_average is True:
            # Divide by how often each document was sampled (0 counts -> 0).
            document_counts = x['document_counts']
            scores = tf.math.divide_no_nan(scores, document_counts)

        return scores
# endregion Simple Attention Network