"""
Documentation
"""
# Python Modules
import argparse
import os
import tempfile
import unittest
# 3rd Party Modules
import numpy as np
import tensorflow as tf
# Project Modules
import examples.build_tfrecords as build
import examples.pipeline as pipeline
import deletor.tfutils as tfutils
from deletor.random.sample import IndependentMultiOutputSampler
# Module-level test configuration; runs once at import time.
# Presumably hides GPU devices from TensorFlow so the tests run on CPU --
# TODO confirm against deletor.tfutils.
tfutils.disable_gpu()
# Execute tf.function-wrapped code eagerly instead of as traced graphs.
tf.config.experimental_run_functions_eagerly(True)
# Make printed arrays readable in test output: 6 decimals, no scientific
# notation, up to 200 items per edge when summarizing, and no line wrapping.
np.set_printoptions(precision=6, suppress=True, edgeitems=200, linewidth=1000000)
class TestPipeline(unittest.TestCase):
    """Exercise the example pipeline helpers on a small cached dataset.

    The dataset is built once per class in :meth:`setUpClass` from
    ``test_data.svm`` (located next to this file) and shared by all tests
    through the ``dataset`` class attribute.
    """

    # noinspection DuplicatedCode
    @classmethod
    def setUpClass(cls) -> None:
        """Convert the raw SVM file to tfrecords and cache the loaded dataset.

        The tfrecords file is written to a temporary file that is removed as
        soon as the ``with`` block exits, so the dataset is fully iterated
        once inside the block to fill the cache before the file disappears.
        """
        scriptpath = os.path.realpath(__file__)
        directory = os.path.dirname(scriptpath)

        # The raw data is stored here
        svmpath = os.path.join(directory, 'test_data.svm')

        # Write a tfrecords file based on the raw data
        with tempfile.NamedTemporaryFile() as tfrfile:
            # Create a Namespace with the necessary arguments to write the
            # tfrecords file to disk.
            args = {
                'input_file': svmpath,
                'output_file': tfrfile.name,
                'compression_type': None,
                'compression_level': None
            }

            # Write the tfrecords file
            build.write_data(argparse.Namespace(**args))
            dataset = pipeline.load_dataset(tfrfile.name, n_features=6).cache()

            # Load the dataset into the cache while the backing file still
            # exists.
            for _ in dataset:
                pass

        cls.dataset = dataset

    def test_unbatch(self):
        """Run the sample -> expand -> unbatch -> squeeze chain and print one element.

        NOTE(review): this is a smoke test only; it makes no assertions and
        passes as long as the pipeline executes without raising.
        """
        sampler = IndependentMultiOutputSampler(3, multiple=1, sample_pre_batch=True)
        dataset = self.dataset.map(sampler)
        dataset = dataset.map(pipeline.expand_dims_for_unbatch)
        dataset = dataset.unbatch()
        dataset = dataset.map(pipeline.squeeze_for_unbatch)

        # Pull the first (x, y) pair out of the unbatched dataset for inspection.
        x, y = tf.data.experimental.get_single_element(dataset.take(1))
        print(f"x:\n{x}")
        print(f"y.shape: {tf.shape(y)}")
        print(f"y:\n{y}")