#
#
# Copyright 2020 Reid Swanson
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Documentation
"""
# Python Modules
import argparse
import os
import pathlib
import tempfile
import unittest
# 3rd Party Modules
import numpy as np
import tensorflow as tf
# Project Modules
import examples.build_tfrecords as build
import examples.pipeline as pipeline
import deletor.models.attn as attn
import deletor.tfutils as tfutils
from deletor.constants import MIN_FLOAT_32 as PAD
from deletor.random.sample import IndependentMultiOutputSampler
from test.random.test_sample import TestSampleAfterBatching
# Test-environment setup, executed once at import time.
# NOTE(review): presumably hides GPU devices so the tests run on CPU only —
# TODO confirm against deletor.tfutils.disable_gpu. Keep it first: device
# visibility generally has to be configured before TensorFlow initializes
# its devices.
tfutils.disable_gpu()
# Execute tf.function-decorated code eagerly instead of as traced graphs,
# which makes stepping through / printing intermediate values possible.
tf.config.experimental_run_functions_eagerly(True)
# Print arrays with 6 decimals (matching the precision of the expected values
# in the tests), without scientific notation, and effectively without
# truncation (large edgeitems/linewidth) so the debug output is readable.
np.set_printoptions(precision=6, suppress=True, edgeitems=200, linewidth=1000000)
class TestAttentionModel(unittest.TestCase):
    """Tests for ``MultiHeadAttention`` and ``GroupwiseMultiHeadAttentionNetwork``."""

    # Number of dense features per sequence item; passed both to the data
    # pipeline (``load_dataset``) and to the model parameters in ``test_call``.
    n_features = 6
    # Fixtures shared with the sampler tests. They are not referenced by the
    # tests in this class — presumably kept for parity/reuse; TODO confirm.
    x = TestSampleAfterBatching.x
    y = TestSampleAfterBatching.y
    r = TestSampleAfterBatching.r
    a = TestSampleAfterBatching.a

    # noinspection DuplicatedCode
    @classmethod
    def setUpClass(cls) -> None:
        """Convert ``test_data.svm`` to tfrecords and cache the dataset in memory.

        The tfrecords file only exists inside a temporary directory. The
        dataset is iterated to exhaustion while the file still exists, so that
        the in-memory ``cache()`` holds every element and the tests never read
        the (by then deleted) file.
        """
        scriptpath = os.path.realpath(__file__)
        directory = pathlib.Path(scriptpath).parents[1]
        # The raw data is stored here.
        svmpath = os.path.join(directory, 'test_data.svm')

        # Use a temporary *directory* rather than NamedTemporaryFile: an open
        # NamedTemporaryFile cannot be reopened by name on every platform
        # (notably Windows), and the tfrecords writer needs to open the path.
        with tempfile.TemporaryDirectory() as tmpdir:
            tfrpath = os.path.join(tmpdir, 'test_data.tfrecords')
            # Namespace mimicking the command-line arguments expected by
            # build_tfrecords.write_data.
            args = argparse.Namespace(
                input_file=svmpath,
                output_file=tfrpath,
                compression_type=None,
                compression_level=None,
            )
            # Write the tfrecords file.
            build.write_data(args)

            dataset = pipeline.load_dataset(tfrpath, n_features=cls.n_features).cache()
            # Exhaust the dataset once so the in-memory cache is complete
            # before the temporary directory (and the file in it) is removed.
            for _ in dataset:
                pass

        cls.dataset = dataset

    def test_scaled_dot_product_attention(self):
        """Reproduce the scaled dot-product attention example from the TF transformer tutorial."""
        # From tensorflow.org tutorial on the transformer model.
        temp_q = tf.constant([[0, 0, 10],
                              [0, 10, 0],
                              [10, 10, 0]], dtype=tf.float32)  # (3, 3)
        temp_k = tf.constant([[10, 0, 0],
                              [0, 10, 0],
                              [0, 0, 10],
                              [0, 0, 10]], dtype=tf.float32)  # (4, 3)
        temp_v = tf.constant([[1, 0],
                              [10, 0],
                              [100, 5],
                              [1000, 6]], dtype=tf.float32)  # (4, 2)

        attention_fn = attn.MultiHeadAttention.scaled_dot_product_attention
        act_output, act_attn_weights = attention_fn(temp_q, temp_k, temp_v)

        exp_output = [[550.0, 5.5],
                      [10.00, 0.0],
                      [5.500, 0.0]]
        exp_attn_weights = [[0.0, 0.0, 0.5, 0.5],
                            [0.0, 1.0, 0.0, 0.0],
                            [0.5, 0.5, 0.0, 0.0]]

        np.testing.assert_array_almost_equal(act_output, exp_output)
        np.testing.assert_array_almost_equal(act_attn_weights, exp_attn_weights)

    def test_scaled_dot_product_attention_with_weights(self):
        """Check that padded items (``PAD`` in ``y``) receive zero attention weight."""
        # Batch of two sequences, each with 3 items and 2 features.
        x0 = tf.sqrt(tf.reshape(tf.range(6, dtype=tf.float32), [1, 3, 2]))
        x1 = tf.reshape(tf.range(6, dtype=tf.float32), [1, 3, 2])
        x = {
            'sequence_dense': tf.concat([x0, x1], axis=0)
        }
        # The first sequence has its third item padded; the second is full.
        y = tf.constant([[3., 1., PAD],
                         [1., 2., 3.]])

        attention_fn = attn.MultiHeadAttention.scaled_dot_product_attention
        mask_fn = attn.MultiHeadAttention.make_attention_mask

        x, y = mask_fn(x, y, multi_head=False)
        mask = x['attention_mask']
        # Self-attention: queries, keys and values are the same tensor.
        q, k, v = x['sequence_dense'], x['sequence_dense'], x['sequence_dense']

        exp_output = [[[0.886140, 1.458700],
                       [1.286604, 1.665995],
                       [1.356517, 1.702185]],
                      [[3.445059, 4.445059],
                       [3.998300, 4.998300],
                       [3.999994, 4.999994]]]
        # The last column of the first sequence is 0 because that item is padded.
        exp_attn_weights = [[[0.373405, 0.626595, 0.000000],
                             [0.090233, 0.909767, 0.000000],
                             [0.040798, 0.959202, 0.000000]],
                            [[0.045388, 0.186694, 0.767918],
                             [0.000001, 0.000849, 0.999151],
                             [0.000000, 0.000003, 0.999997]]]

        act_output, act_attn_weights = attention_fn(q, k, v, mask=mask)
        np.testing.assert_array_almost_equal(act_output, exp_output)
        np.testing.assert_array_almost_equal(act_attn_weights, exp_attn_weights)

    def test_call(self):
        """Smoke test: run the groupwise attention network on one padded, sampled batch."""
        group_size = 3
        dataset = self.dataset
        sampler = IndependentMultiOutputSampler(group_size, multiple=1)
        mask_fn = attn.MultiHeadAttention.make_attention_mask

        # Padded shapes / padding values for padded_batch: only the dense
        # sequence features and the labels have a variable-length dimension.
        shapes = (
            {
                'context_one_hot': (),
                'context_multi_hot': (),
                'context_dense': (),
                'sequence_one_hot': (),
                'sequence_multi_hot': (),
                'sequence_dense': tf.TensorShape([None, self.n_features])
            },
            tf.TensorShape([None])
        )
        values = (
            {
                'context_one_hot': 0.,
                'context_multi_hot': 0.,
                'context_dense': 0.,
                'sequence_one_hot': 0.,
                'sequence_multi_hot': 0.,
                'sequence_dense': 0.
            },
            # Labels are padded with PAD so the mask function can detect them.
            PAD
        )

        dataset = dataset.padded_batch(2, shapes, values)
        dataset = dataset.map(sampler)
        dataset = dataset.map(lambda a, b: mask_fn(a, b))
        x, y = tf.data.experimental.get_single_element(dataset.take(1))

        model_params = {
            attn.ModelParameter.USE_AVERAGE: True,
            attn.ModelParameter.N_LAYERS: 2,
            attn.ModelParameter.N_FEATURES: self.n_features,
            attn.ModelParameter.GROUP_SIZE: group_size,
            attn.ModelParameter.N_HEADS: 3,
            attn.ModelParameter.N_FF_UNITS: 5,
            attn.ModelParameter.DROPOUT_RATE: 0.3
        }
        model = attn.GroupwiseMultiHeadAttentionNetwork(model_params)
        # NOTE(review): calls ``call`` directly, bypassing the Keras
        # ``__call__`` machinery (build hooks, training flag); kept as in the
        # original.
        scores = model.call(x)

        print(f"x:\n{x}")
        print(f"y:\n{y}")
        print(f"scores:\n{scores}")

        # The original test asserted nothing. At minimum, the model must
        # produce an output.
        # TODO: pin the expected shape of ``scores`` (e.g. against ``y``)
        # once the model's output contract is confirmed.
        self.assertIsNotNone(scores)