# Source code for deletor.models.attn

# Original Copyright 2020 Google
# Additional modifications Copyright 2020 Reid Swanson
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Python Modules
import logging

from typing import Any, Dict

# 3rd Party Modules
import tensorflow as tf
import tensorflow.keras.layers as tf_layers

# Project Modules
# In the transformer tutorial n_model is the number of dimensions used for
# the word embeddings. Since we are currently only working with dense inputs
# and no embeddings (yet), n_model is the number of input features
# (i.e., n_features).
from deletor.constants import MIN_FLOAT_32


log = logging.getLogger(__name__)


class ModelParameter:
    """
    Names of the keys read from the ``params`` dictionary given to the models
    in this module.

    Each attribute is the string key under which the corresponding
    hyper-parameter is stored.
    """

    # Number of stacked encoder / attention layers.
    N_LAYERS = 'n_layers'
    # Model dimensionality read by GroupwiseAttentionNetwork.
    N_MODEL = 'n_model'
    # Number of input features; GroupwiseMultiHeadAttentionNetwork uses it
    # as its model dimensionality.
    N_FEATURES = 'n_features'
    # Number of items in each sampled group.
    GROUP_SIZE = 'group_size'
    # Number of attention heads.
    N_HEADS = 'n_heads'
    # Number of hidden units in the point-wise feed-forward networks.
    N_FF_UNITS = 'n_feed_forward_units'
    # Rate passed to the dropout layers.
    DROPOUT_RATE = 'dropout_rate'
    # Whether the simple attention layers apply layer normalization.
    USE_LAYER_NORM = 'use_layer_norm'
    # Whether summed document scores are divided by the document counts.
    USE_AVERAGE = 'use_average'
    # Whether every item is first passed through a shared input layer.
    SHARE_WEIGHTS = 'share_weights'
def make_point_wise_feed_forward_network(n_model, dff):
    """
    Build the two-layer point-wise feed-forward network.

    :param n_model: The number of units of the second (output) dense layer.
    :param dff: The number of units of the first dense layer, which uses a
        ReLU activation.
    :return: A ``tf.keras.Sequential`` model holding the two dense layers.
    """
    hidden = tf_layers.Dense(dff, activation='relu')  # (batch_size, seq_len, dff)
    projection = tf_layers.Dense(n_model)  # (batch_size, seq_len, n_model)

    return tf.keras.Sequential([hidden, projection])
class MultiHeadAttention(tf_layers.Layer):
    """
    Scaled dot-product attention with dense query, key and value projections.

    The projected tensors are split into heads only when ``num_heads > 1``;
    with a single head the projected tensors are attended over directly.
    """

    def __init__(self, n_model, num_heads):
        """
        :param n_model: The number of units of the query, key, value and
            output projections.
        :param num_heads: The number of attention heads. It must be positive
            and divide ``n_model`` evenly.
        :raises ValueError: If ``num_heads`` is not positive or ``n_model``
            is not an exact multiple of ``num_heads``.
        """
        super().__init__()

        # Reject a non-positive head count up front; zero would otherwise
        # surface as an unrelated ZeroDivisionError in the modulo below.
        if num_heads < 1:
            raise ValueError("The number of heads must be a positive integer")

        if n_model % num_heads != 0:
            raise ValueError("n_model must be an exact multiple of the number of heads")

        self.num_heads = num_heads
        self.n_model = n_model

        # Number of features handled by each individual head.
        self.depth = n_model // num_heads

        self.wq = tf_layers.Dense(n_model)
        self.wk = tf_layers.Dense(n_model)
        self.wv = tf_layers.Dense(n_model)

        self.dense = tf_layers.Dense(n_model)

    def split_heads(self, x, batch_size):
        """
        Split the last dimension into (num_heads, depth).

        Transpose the result such that the shape is
        (batch_size, num_heads, seq_len, depth)

        :param x: The tensor to split.
        :param batch_size: The size of the first dimension of the result.
        :return: The reshaped and transposed tensor.
        """
        x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
        x = tf.transpose(x, perm=[0, 2, 1, 3])

        return x

    @classmethod
    def make_attention_mask(
            cls,
            x,
            y,
            multi_head: bool = True,
            pad_value: float = MIN_FLOAT_32
    ):
        """
        Add an ``'attention_mask'`` entry to ``x`` that flags padded items.

        The mask is 1.0 wherever the labels equal ``pad_value`` and 0.0
        everywhere else. Note that ``x`` is modified in place.

        :param x: A mapping supporting item assignment that receives the mask.
        :param y: The labels, or a tuple/list whose second element holds them.
            Rank 3 labels are flattened to
            (batch_size * n_samples, group_size) first.
        :param multi_head: Whether the mask is built for multi-head attention,
            which needs one more broadcast axis.
        :param pad_value: The label value that marks a padded item.
        :return: The tuple ``(x, y)`` with ``y`` unchanged.
        """
        ys = y[1] if isinstance(y, (tuple, list)) else y

        y_rank = tf.rank(ys)
        if y_rank == 3:
            batch_size = tf.shape(ys)[0]
            n_samples = tf.shape(ys)[1]
            group_size = tf.shape(ys)[2]
            ys = tf.reshape(tf.identity(ys), (batch_size * n_samples, group_size))

        mask = tf.cast(tf.math.equal(ys, pad_value), tf.float32)

        # The multi-head mask gets two singleton axes, (rows, 1, 1, n_items),
        # so it can broadcast against attention logits that carry an extra
        # heads axis. The single-head mask only needs one: (rows, 1, n_items).
        if multi_head is True:
            x['attention_mask'] = mask[:, tf.newaxis, tf.newaxis, :]
        else:
            x['attention_mask'] = mask[:, tf.newaxis, :]

        return x, y

    @classmethod
    def scaled_dot_product_attention(
            cls,
            q: tf.Tensor,
            k: tf.Tensor,
            v: tf.Tensor,
            mask: tf.Tensor = None
    ):
        """
        Calculate the attention weights.

        q, k, v must have matching leading dimensions.
        k, v must have matching penultimate dimension, i.e.:
        seq_len_k = seq_len_v.
        The mask has different shapes depending on its type (padding or look
        ahead) but it must be broadcastable for addition.

        :param q: query shape == (..., seq_len_q, depth)
        :param k: key shape == (..., seq_len_k, depth)
        :param v: value shape == (..., seq_len_v, depth_v)
        :param mask: Float tensor with shape broadcastable
            to (..., seq_len_q, seq_len_k). Defaults to None.
        :return: output, attention_weights
        """
        q.shape.assert_same_rank(k.shape)
        k.shape.assert_same_rank(v.shape)

        # (..., seq_len_q, seq_len_k)
        matmul_qk = tf.matmul(q, k, transpose_b=True)

        # scale matmul_qk
        dk = tf.cast(tf.shape(k)[-1], tf.float32)
        scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)

        # Masked positions (mask == 1) receive a large negative logit so that
        # the softmax gives them a weight close to zero.
        if mask is not None:
            # noinspection PyTypeChecker
            scaled_attention_logits += (mask * -1e9)

        # softmax is normalized on the last axis (seq_len_k) so that the scores
        # add up to 1.
        # (..., seq_len_q, seq_len_k)
        attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)

        output = tf.matmul(attention_weights, v)

        return output, attention_weights

    @classmethod
    def reshape_input(cls, q: tf.Tensor, k: tf.Tensor, v: tf.Tensor):
        """
        Merge the first two dimensions of ``q``, ``k`` and ``v``.

        All three tensors are reshaped with the dimensions read from ``q``,
        which is indexed as a rank 4 tensor
        (batch_size, n_samples, group_size, n_features).

        :param q: The query tensor, which also supplies the dimensions.
        :param k: The key tensor.
        :param v: The value tensor.
        :return: The tuple ``(q, k, v)``, each with shape
            (batch_size * n_samples, group_size, n_features).
        """
        batch_size = tf.shape(q)[0]
        n_samples = tf.shape(q)[1]
        group_size = tf.shape(q)[2]
        n_features = tf.shape(q)[3]

        q = tf.reshape(q, (batch_size * n_samples, group_size, n_features))
        k = tf.reshape(k, (batch_size * n_samples, group_size, n_features))
        v = tf.reshape(v, (batch_size * n_samples, group_size, n_features))

        return q, k, v

    @classmethod
    def reshape_output(cls, output: tf.Tensor, q: tf.Tensor):
        """
        Reshape ``output`` to the four dimensions of ``q``.

        :param output: The tensor to reshape.
        :param q: The tensor whose first four dimensions define the new shape.
        :return: ``output`` with shape
            (batch_size, n_samples, group_size, n_features).
        """
        batch_size = tf.shape(q)[0]
        n_samples = tf.shape(q)[1]
        group_size = tf.shape(q)[2]
        n_features = tf.shape(q)[3]

        return tf.reshape(output, (batch_size, n_samples, group_size, n_features))

    # noinspection PyMethodOverriding
    def call(self, v, k, q, mask, **kwargs):
        """
        Project the inputs, attend over them and project the result.

        Note the argument order: values first, then keys, then queries.

        :param v: The value tensor.
        :param k: The key tensor.
        :param q: The query tensor.
        :param mask: The attention mask passed on to
            :meth:`scaled_dot_product_attention`.
        :param kwargs: Unused.
        :return: The tuple ``(output, attention_weights)``.
        """
        batch_size = tf.shape(q)[0]

        q = self.wq(q)  # (batch_size, seq_len, n_model)
        k = self.wk(k)  # (batch_size, seq_len, n_model)
        v = self.wv(v)  # (batch_size, seq_len, n_model)

        # Read before the heads are split, while axis 1 is still seq_len_q.
        seqlen = tf.shape(q)[1]

        if self.num_heads > 1:
            q = self.split_heads(q, batch_size)  # (batch_size, num_heads, seq_len_q, depth)
            k = self.split_heads(k, batch_size)  # (batch_size, num_heads, seq_len_k, depth)
            v = self.split_heads(v, batch_size)  # (batch_size, num_heads, seq_len_v, depth)

        # scaled_attention.shape == (batch_size, num_heads, seq_len_q, depth)
        # attention_weights.shape == (batch_size, num_heads, seq_len_q, seq_len_k)
        scaled_attention, attention_weights = self.scaled_dot_product_attention(q, k, v, mask)

        if self.num_heads > 1:
            # (batch_size, seq_len_q, num_heads, depth)
            scaled_attention = tf.transpose(scaled_attention, perm=[0, 2, 1, 3])

        # (batch_size, seq_len_q, n_model)
        concat_attention = tf.reshape(scaled_attention, (batch_size, seqlen, self.n_model))

        # (batch_size, seq_len_q, n_model)
        output = self.dense(concat_attention)

        return output, attention_weights
class EncoderLayer(tf_layers.Layer):
    """
    A single encoder layer: multi-head self attention followed by a
    point-wise feed-forward network, each wrapped with dropout, a residual
    connection and layer normalization.
    """

    def __init__(self, n_model: int, num_heads: int, d_ff: int, dropout_rate: float = 0.1):
        """
        :param n_model: The dimensionality of the layer inputs and outputs.
        :param num_heads: The number of attention heads.
        :param d_ff: The number of hidden units of the feed-forward network.
        :param dropout_rate: The rate used by both dropout layers.
        """
        super().__init__()

        self.multi_head_attention = MultiHeadAttention(n_model, num_heads)
        self.feed_forward = make_point_wise_feed_forward_network(n_model, d_ff)

        norm_epsilon = 1e-6
        self.layer_norm_1 = tf_layers.LayerNormalization(epsilon=norm_epsilon)
        self.layer_norm_2 = tf_layers.LayerNormalization(epsilon=norm_epsilon)

        self.dropout_1 = tf_layers.Dropout(dropout_rate)
        self.dropout_2 = tf_layers.Dropout(dropout_rate)

    # noinspection PyMethodOverriding
    def call(self, x, training, mask, **kwargs):
        """
        Run the attention and feed-forward sub-layers on ``x``.

        :param x: The layer input, used as value, key and query.
        :param training: Forwarded to the dropout layers.
        :param mask: The attention mask forwarded to the attention layer.
        :param kwargs: Unused.
        :return: The layer output.
        """
        # Attention sub-layer: (batch_size, input_seq_len, n_model)
        attended, _ = self.multi_head_attention(x, x, x, mask)
        attended = self.dropout_1(attended, training=training)
        attention_block = self.layer_norm_1(x + attended)

        # Feed-forward sub-layer: (batch_size, input_seq_len, n_model)
        transformed = self.feed_forward(attention_block)
        transformed = self.dropout_2(transformed, training=training)

        return self.layer_norm_2(attention_block + transformed)
class Encoder(tf_layers.Layer):
    """
    A stack of :class:`EncoderLayer` instances applied one after another to a
    scaled input that has gone through dropout.
    """

    def __init__(
            self,
            num_layers: int,
            n_model: int,
            num_heads: int,
            d_ff: int,
            dropout_rate: float = 0.1
    ):
        """
        :param num_layers: The number of encoder layers in the stack.
        :param n_model: The dimensionality of the inputs and outputs.
        :param num_heads: The number of attention heads of each layer.
        :param d_ff: The number of hidden units of each feed-forward network.
        :param dropout_rate: The rate used by every dropout layer.
        """
        super().__init__()

        self.n_model = n_model
        self.num_layers = num_layers

        layers = []
        for _ in range(num_layers):
            layers.append(EncoderLayer(n_model, num_heads, d_ff, dropout_rate))
        self.encoder_layers = layers

        self.dropout = tf_layers.Dropout(dropout_rate)

    # noinspection PyMethodOverriding
    def call(self, x, mask, training, **kwargs):
        """
        Encode ``x`` with every layer of the stack.

        :param x: The input tensor.
        :param mask: The attention mask handed to every layer.
        :param training: Forwarded to the dropout layers.
        :param kwargs: Unused.
        :return: The encoded tensor.
        """
        # Note we don't need any positional information
        scale = tf.math.sqrt(tf.cast(self.n_model, tf.float32))
        x = self.dropout(x * scale, training=training)

        for encoder_layer in self.encoder_layers:
            x = encoder_layer(x, training, mask)

        # (batch_size, input_seq_len, n_model)
        return x
class GroupwiseMultiHeadAttentionNetwork(tf.keras.Model):
    """
    This model implements a groupwise scoring model using normalized dot
    product attention mechanism based on Attention Is All You Need.
    """

    default_use_average = True
    default_share_weights = False

    def __init__(self, params: Dict[str, Any], **kwargs):
        """
        :param params: See :class:`.ModelParameter` for the valid parameters.
        :param kwargs: Forwarded to ``tf.keras.Model``.
        """
        super().__init__(**kwargs)

        group_size = params[ModelParameter.GROUP_SIZE]

        self.model_params = params
        self.group_size = tf.constant(group_size, tf.int32)
        self.use_average = params[ModelParameter.USE_AVERAGE]

        # The number of input features doubles as the model dimensionality.
        self.encoder = Encoder(
            params[ModelParameter.N_LAYERS],
            params[ModelParameter.N_FEATURES],
            params[ModelParameter.N_HEADS],
            params[ModelParameter.N_FF_UNITS],
            params[ModelParameter.DROPOUT_RATE]
        )
        self.scoring_layer = tf_layers.Dense(group_size)

    # noinspection PyMethodOverriding,DuplicatedCode
    def call(self, x, training: bool = True, **kwargs):
        """
        Score every document by summing the scores of the groups it is in.

        :param x: A mapping with the entries ``'sequence_dense'``,
            ``'sample_dense'``, ``'attention_mask'``, ``'scatter_idx'`` and,
            when averaging is enabled, ``'document_counts'``.
        :param training: Forwarded to the encoder.
        :param kwargs: Unused.
        :return: The document scores with shape (batch_size, n_documents).
        """
        documents = x['sequence_dense']
        samples = x['sample_dense']
        attention_mask = x['attention_mask']
        scatter_idx = x['scatter_idx']

        batch_size = tf.shape(documents)[0]
        n_documents = tf.shape(documents)[1]
        n_features = tf.shape(documents)[2]
        n_samples = tf.shape(samples)[1]
        group_size = tf.shape(samples)[2]

        # Every sampled group becomes its own row of the encoder input.
        n_rows = batch_size * n_samples

        encoder_input = tf.reshape(samples, (n_rows, group_size, n_features))
        encoded = self.encoder(encoder_input, attention_mask, training)

        # Concatenate the features of all the items of a group for scoring.
        encoded = tf.reshape(encoded, (n_rows, group_size * n_features))
        sample_scores = self.scoring_layer(encoded)
        sample_scores = tf.reshape(sample_scores, (batch_size, n_samples, group_size))

        # Sum the group scores into their documents.
        scores = tf.scatter_nd(scatter_idx, sample_scores, (batch_size, n_documents))

        if self.use_average is True:
            scores = tf.math.divide_no_nan(scores, x['document_counts'])

        return scores
# region Simple Attention Network
class AttentionLayer(tf_layers.Layer):
    """
    A single-head self attention layer followed by a point-wise feed-forward
    network, with optional dropout and optional layer normalization.
    """

    def __init__(
            self,
            n_model: int,
            n_ff_units: int,
            use_layer_norm: bool = False,
            dropout_rate: float = 0.1
    ):
        """
        :param n_model: The dimensionality of the layer inputs and outputs.
        :param n_ff_units: The number of hidden units of the feed-forward
            network.
        :param use_layer_norm: Whether to add the residual connections and
            apply layer normalization.
        :param dropout_rate: The dropout rate; dropout is only used when the
            rate is greater than zero.
        """
        super().__init__()

        self.use_dropout = dropout_rate > 0
        self.use_layer_norm = use_layer_norm

        # A single head: the attention layer does not split its inputs.
        self.multi_head_attention = MultiHeadAttention(n_model, 1)
        self.feed_forward = make_point_wise_feed_forward_network(n_model, n_ff_units)

        if self.use_layer_norm is True:
            norm_epsilon = 1e-6
            self.layer_norm_1 = tf_layers.LayerNormalization(epsilon=norm_epsilon)
            self.layer_norm_2 = tf_layers.LayerNormalization(epsilon=norm_epsilon)

        if self.use_dropout is True:
            self.dropout_1 = tf_layers.Dropout(dropout_rate)
            self.dropout_2 = tf_layers.Dropout(dropout_rate)

    # noinspection PyMethodOverriding
    def call(self, x, training, mask, **kwargs):
        """
        Run the attention and feed-forward sub-layers on ``x``.

        Without layer normalization no residual connection is added: each
        sub-layer output is passed on as-is.

        :param x: The layer input, used as value, key and query.
        :param training: Forwarded to the dropout layers.
        :param mask: The attention mask forwarded to the attention layer.
        :param kwargs: Unused.
        :return: The layer output.
        """
        normalize = self.use_layer_norm is True
        drop = self.use_dropout is True

        # (batch_size, input_seq_len, n_model)
        attended, _ = self.multi_head_attention(x, x, x, mask)
        if drop:
            attended = self.dropout_1(attended, training=training)

        # (batch_size, input_seq_len, n_model)
        attention_block = self.layer_norm_1(x + attended) if normalize else attended

        # (batch_size, input_seq_len, n_model)
        transformed = self.feed_forward(attention_block)
        if drop:
            transformed = self.dropout_2(transformed, training=training)

        if normalize:
            # (batch_size, input_seq_len, n_model)
            return self.layer_norm_2(attention_block + transformed)

        return transformed
class SelfAttention(tf_layers.Layer):
    """
    A stack of :class:`AttentionLayer` instances applied one after another to
    a scaled input that has gone through dropout.
    """

    def __init__(
            self,
            n_layers: int,
            n_model: int,
            n_ff_units: int,
            use_layer_norm: bool = False,
            dropout_rate: float = 0.1
    ):
        """
        :param n_layers: The number of attention layers in the stack.
        :param n_model: The dimensionality of the inputs and outputs.
        :param n_ff_units: The number of hidden units of each feed-forward
            network.
        :param use_layer_norm: Whether the layers apply layer normalization.
        :param dropout_rate: The rate used by every dropout layer.
        """
        super().__init__()

        self.n_model = n_model
        self.n_layers = n_layers

        layers = []
        for _ in range(n_layers):
            layers.append(AttentionLayer(n_model, n_ff_units, use_layer_norm, dropout_rate))
        self.attention_layers = layers

        self.dropout = tf_layers.Dropout(dropout_rate)

    # noinspection PyMethodOverriding
    def call(self, x, mask, training, **kwargs):
        """
        Apply every attention layer of the stack to ``x``.

        :param x: The input tensor.
        :param mask: The attention mask handed to every layer.
        :param training: Forwarded to the dropout layers.
        :param kwargs: Unused.
        :return: The output of the last attention layer.
        """
        # Note we don't need any positional information
        scale = tf.math.sqrt(tf.cast(self.n_model, tf.float32))
        x = self.dropout(x * scale, training=training)

        for attention_layer in self.attention_layers:
            x = attention_layer(x, training, mask)

        # (batch_size, input_seq_len, n_model)
        return x
class GroupwiseAttentionNetwork(tf.keras.Model):
    """
    A groupwise scoring model built on single-head self attention, with an
    optional input layer shared by all the items.
    """

    def __init__(self, params: Dict[str, Any], **kwargs):
        """
        :param params: See :class:`.ModelParameter` for the valid parameters.
        :param kwargs: Forwarded to ``tf.keras.Model``.
        """
        super().__init__(**kwargs)

        n_model = params[ModelParameter.N_MODEL]
        group_size = params[ModelParameter.GROUP_SIZE]

        self.model_params = params
        self.group_size = tf.constant(group_size, tf.int32)
        self.use_average = params[ModelParameter.USE_AVERAGE]
        self.share_weights = params[ModelParameter.SHARE_WEIGHTS]

        if self.share_weights is True:
            self.shared_input_layer = self.make_shared_layer(n_model)

        self.self_attention = SelfAttention(
            params[ModelParameter.N_LAYERS],
            n_model,
            params[ModelParameter.N_FF_UNITS],
            params[ModelParameter.USE_LAYER_NORM],
            params[ModelParameter.DROPOUT_RATE]
        )
        self.scoring_layer = tf_layers.Dense(group_size)

    @classmethod
    def make_shared_layer(cls, n_model: int):
        """
        Build the input layer that is shared by all the items.

        :param n_model: The number of units of the dense layer.
        :return: A ``tf.keras.Sequential`` model made of a dense layer, a
            batch normalization and a PReLU activation.
        """
        dense = tf_layers.Dense(n_model)
        batch_norm = tf_layers.BatchNormalization()
        activation = tf_layers.PReLU()

        return tf.keras.Sequential([dense, batch_norm, activation])

    def call(self, x: Dict[str, tf.Tensor], training: bool = True, **kwargs):
        """
        Score every document by summing the scores of the groups it is in.

        :param x: A mapping with the entries ``'sequence_dense'``,
            ``'sample_dense'``, ``'attention_mask'``, ``'scatter_idx'`` and,
            when averaging is enabled, ``'document_counts'``.
        :param training: Forwarded to the self attention stack.
        :param kwargs: Unused.
        :return: The document scores with shape (batch_size, n_documents).
        """
        documents = x['sequence_dense']  # (batch_size, n_documents, n_features)
        samples = x['sample_dense']  # (batch_size, n_samples, group_size, n_features)
        attention_mask = x['attention_mask']  # (batch_size, group_size, ?)?
        scatter_idx = x['scatter_idx']

        batch_size = tf.shape(documents)[0]
        n_documents = tf.shape(documents)[1]
        n_samples = tf.shape(samples)[1]
        group_size = tf.shape(samples)[2]
        n_features = tf.shape(samples)[3]

        # Every sampled group becomes its own row of the attention input.
        n_rows = batch_size * n_samples

        if self.share_weights is True:
            # With a shared input layer every individual item goes through
            # the same dense layer (with a BatchNorm and a PReLU activation),
            # so each item first gets a row of its own.
            items = tf.reshape(samples, [n_rows * group_size, n_features])
            items = self.shared_input_layer(items)

            # Then the items are grouped again, one sample per row.
            n_model = tf.shape(items)[1]
            attention_input = tf.reshape(items, (n_rows, group_size, n_model))
        else:
            # Treat each sample like a batch element.
            n_model = n_features
            attention_input = tf.reshape(samples, (n_rows, group_size, n_features))

        # Apply the self attention mechanism, which has shape:
        # (batch_size, group_size, n_model/n_features)
        attended = self.self_attention(attention_input, attention_mask, training)

        # Concatenate the features of all the items of a group for scoring.
        flattened = tf.reshape(attended, (n_rows, group_size * n_model))

        # It's possible that some queries have fewer documents than the group
        # size. If all documents in the batch have fewer documents than the
        # group size then there will be a mismatch in the expected input shape
        # for the final layer (which expects exactly
        # self.group_size * n_features as input). To fix this issue we pad the
        # attention output to have the correct size. It may be possible to
        # apply this padding in the input pipeline for better efficiency, but
        # it is surprisingly complicated.
        padlen = (self.group_size * n_model) - (group_size * n_model)

        def pad_attn():
            # Only pad at the end of the last axis: paddings[1, 1] = padlen.
            paddings = tf.tensor_scatter_nd_add(
                tf.zeros([2, 2], tf.int32), [[1, 1]], [padlen]
            )
            return tf.pad(flattened, paddings)

        scoring_input = tf.cond(tf.greater(padlen, 0), pad_attn, lambda: flattened)

        # Apply the final layer and drop the scores of the padded items.
        sample_scores = self.scoring_layer(scoring_input)[..., :group_size]

        # Reshape the output so it is compatible with the scatter index
        sample_scores = tf.reshape(sample_scores, (batch_size, n_samples, group_size))

        # Update (sum) the scores for each document
        scores = tf.scatter_nd(scatter_idx, sample_scores, (batch_size, n_documents))

        # If we are using the average, divide each score by the number of times
        # the given document was sampled.
        if self.use_average is True:
            scores = tf.math.divide_no_nan(scores, x['document_counts'])

        return scores
# endregion Simple Attention Network