# Copyright 2020 Reid Swanson
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Python Modules
import argparse
import logging
import os
import shutil
# from typing import Dict, Any, Optional
# 3rd Party Modules
import numpy as np
import tensorflow as tf
import deletor.tfutils as tfutils
from examples.utils import train, evaluate, make_optimizer, make_loss
from deletor.random.sample import IndependentMultiOutputSampler
# NOTE(review): presumably enables TensorFlow GPU memory growth (inferred from the
# name only). It runs before the project modules below are imported — confirm
# whether that ordering is required before moving it.
tfutils.grow_memory()
# Project Modules
from examples.pipeline import load_dataset, is_valid_query, truncate_document_list, \
make_padded_shapes, make_padding_values, expand_dims_for_unbatch, squeeze_for_unbatch, \
N_FEATURES
from deletor.models.gsf import GroupwiseScoringNetwork, ModelParameter, GroupwiseScoringNetwork2
from deletor.metrics import NormalizedDiscountedCumulativeGain
# Compact numpy printing: 6 decimals, no scientific notation for small values,
# 10 edge items per axis when summarizing, and effectively no line wrapping.
np.set_printoptions(linewidth=10000, edgeitems=10, suppress=True, precision=6)

# Root logging: DEBUG and above, with timestamp, logger name and line number.
logging.basicConfig(
    format='%(asctime)-15s [%(name)s]:%(lineno)d %(levelname)s %(message)s',
    level=logging.DEBUG,
)
log = logging.getLogger('gsf/mltr30k')

# Sentinel that lets tf.data pick parallelism / prefetch buffer sizes itself.
AUTOTUNE = tf.data.experimental.AUTOTUNE
def prepare_data(args: argparse.Namespace):
    """Build the train, validation and test ``tf.data`` pipelines.

    Each split is loaded from its tfrecords file with ``args.scaler``,
    filtered with ``is_valid_query``, truncated to ``args.list_size``
    documents when that option is set, and cached. The splits are then
    padded, batched and mapped through an ``IndependentMultiOutputSampler``.

    Validation and test data are always padded and batched first, then
    sampled. Training data depends on ``args.sample_pre_batch``:

    * set: the sampler is applied to each query before batching. The result
      is unbatched (via ``expand_dims_for_unbatch`` / ``squeeze_for_unbatch``),
      shuffled with a 10000-element buffer, and re-batched with padding
      shapes built from ``args.group_size``.
    * not set: queries are shuffled with a 1000-element buffer, padded and
      batched with shapes built from ``args.list_size``, and then sampled.

    Args:
        args: Parsed command line options (see ``make_command_line_options``).

    Returns:
        A ``(train_data, valid_data, test_data)`` tuple of prefetched datasets.
    """
    list_size = args.list_size
    group_size = args.group_size
    pre_batch = args.sample_pre_batch
    drop = args.drop_remainder

    # Apply each preprocessing stage to all three splits before moving on to
    # the next stage: load -> filter -> (truncate) -> cache.
    sources = (args.train_file, args.valid_file, args.test_file)
    splits = [load_dataset(path, args.scaler) for path in sources]
    splits = [ds.filter(is_valid_query) for ds in splits]
    if list_size:
        splits = [
            ds.map(lambda x, y: truncate_document_list(x, y, list_size))
            for ds in splits
        ]
    splits = [ds.cache() for ds in splits]
    train_ds, valid_ds, test_ds = splits

    train_sampler = IndependentMultiOutputSampler(
        group_size,
        multiple=args.multiples,
        sample_pre_batch=pre_batch
    )
    # Evaluation data is always sampled after padding/batching. When the
    # training sampler works per query (pre-batch mode), evaluation needs its
    # own sampler without that flag.
    if pre_batch:
        eval_sampler = IndependentMultiOutputSampler(group_size, multiple=args.multiples)
    else:
        eval_sampler = train_sampler

    # The code for bucketing by sequence length should be in branch v0.2.
    # It doesn't really seem to improve the accuracy, and any efficiency gains
    # appear to be small. The primary reason for not using it is that it
    # seems to break when using an unbounded list_size (i.e., None). It's
    # possible this could be fixed, but given the lack of accuracy/efficiency
    # gains it's probably not worth it.
    list_shapes = make_padded_shapes(list_size)
    list_padding = make_padding_values()
    eval_bsz = args.evaluation_batch_size
    valid_ds = valid_ds.padded_batch(eval_bsz, list_shapes, list_padding, drop)
    test_ds = test_ds.padded_batch(eval_bsz, list_shapes, list_padding, drop)

    train_bsz = args.training_batch_size
    if pre_batch:
        train_ds = train_ds.map(train_sampler)
        valid_ds = valid_ds.map(eval_sampler)
        test_ds = test_ds.map(eval_sampler)
        train_ds = (
            train_ds
            .map(expand_dims_for_unbatch)
            .unbatch()
            .map(squeeze_for_unbatch)
            .shuffle(10000, args.random_seed, reshuffle_each_iteration=True)
            .padded_batch(
                train_bsz,
                make_padded_shapes(group_size),
                make_padding_values(),
                drop
            )
        )
    else:
        train_ds = (
            train_ds
            .shuffle(1000, args.random_seed, reshuffle_each_iteration=True)
            .padded_batch(train_bsz, list_shapes, list_padding, drop)
            .map(train_sampler)
        )
        valid_ds = valid_ds.map(eval_sampler)
        test_ds = test_ds.map(eval_sampler)

    return tuple(ds.prefetch(AUTOTUNE) for ds in (train_ds, valid_ds, test_ds))
def setup_model(args: argparse.Namespace):
    """Create and compile the groupwise scoring network.

    ``GroupwiseScoringNetwork2`` is used when ``args.sample_pre_batch`` is
    set; otherwise ``GroupwiseScoringNetwork``. The model is compiled with
    the optimizer from ``make_optimizer(args)``, the loss from
    ``make_loss(args)``, and NDCG@1, NDCG@5 and NDCG@10 metrics.

    Args:
        args: Parsed command line options.

    Returns:
        The compiled model.
    """
    params = {
        ModelParameter.N_FEATURES: N_FEATURES,
        ModelParameter.N_UNITS: args.n_units,
        ModelParameter.GROUP_SIZE: args.group_size,
        ModelParameter.USE_AVERAGE: args.use_average,
        ModelParameter.SHARE_WEIGHTS: args.share_weights,
        ModelParameter.DROPOUT_RATE: args.dropout_rate,
    }

    if args.sample_pre_batch:
        network = GroupwiseScoringNetwork2(params)
    else:
        network = GroupwiseScoringNetwork(params)

    # Keyword arguments are evaluated left to right: optimizer, loss, metrics.
    network.compile(
        optimizer=make_optimizer(args),
        loss=make_loss(args),
        metrics=[NormalizedDiscountedCumulativeGain(k=k) for k in (1, 5, 10)],
    )
    return network
def _make_meta(args: argparse.Namespace, max_epochs: int) -> dict:
    """Return a fresh bookkeeping dict for the ``train`` / ``evaluate`` helpers.

    A new set of variables and metric accumulators is created on every call,
    so training and test bookkeeping never share state.
    """
    return {
        'sample_pre_batch': args.sample_pre_batch,
        'max_epochs': tf.constant(max_epochs),
        'step': tf.Variable(0),
        'elapsed_time': tf.Variable(0., tf.float32),
        'train_time': tf.Variable(0., tf.float32),
        'valid_time': tf.Variable(0., tf.float32),
        'secs_step': tf.Variable(0., tf.float32),
        'train_loss': tf.keras.metrics.Mean(),
        'metrics': {k: tf.keras.metrics.Mean() for k in (1, 5, 10)},
        'best_result': tf.Variable(0., tf.float32)
    }


# noinspection PyTypeChecker
def main(args: argparse.Namespace):
    """Train a groupwise scoring model and evaluate its best checkpoint.

    Any existing ``args.checkpoint_dir`` is deleted first. The model is then
    trained for ``args.max_epochs`` epochs. After each epoch it is evaluated
    on the validation split, and a checkpoint is written whenever validation
    NDCG@5 improves. Finally the best checkpoint is restored and evaluated
    on the test split. If no checkpoint was written, the final model is
    evaluated and a warning is logged.

    Args:
        args: Parsed command line options.
    """
    # Best parameters so far:
    # max_epochs: 1000
    # multiples: 3
    # training_batch_size: 64
    # evaluation_batch_size: 256
    # optimizer: adam
    # learning_rate: 0.00025
    # n_units: 64 128 64 32
    # use_average: True
    # share_weights: True
    # epoch: 1000 step: 287000 elapsed time: 47496.81s val time: 10.22 train/loss: -47.7925
    # val/ndcg@01: 0.4646 val/ndcg@05: 0.4541 val/ndcg@10: 0.4716
    # test/ndcg@01: 0.4646 test/ndcg@05: 0.4541 test/ndcg@10: 0.4716
    max_epochs = args.max_epochs
    train_data, valid_data, test_data = prepare_data(args)

    if os.path.exists(args.checkpoint_dir):
        log.info(f"Removing existing checkpoint directory: {args.checkpoint_dir}")
        shutil.rmtree(args.checkpoint_dir, ignore_errors=True)

    tf.config.experimental_run_functions_eagerly(args.run_eagerly)
    model = setup_model(args)
    # Call the model once on a training batch so its variables exist before
    # the summary is printed.
    model(tf.data.experimental.get_single_element(train_data.take(1))[0], training=True)
    model.summary(print_fn=log.info)

    # Training in keras fails with:
    # "ValueError: Cannot take the length of shape with unknown rank."
    # Note this should work now, but if training with keras, then
    # the `reduce` parameter for the loss function should be set to `False`.
    # model.fit(train_data, validation_data=valid_data, epochs=10)

    # Custom loop
    train_meta = _make_meta(args, max_epochs)
    ckpt = tf.train.Checkpoint(epoch=tf.Variable(0), optimizer=model.optimizer, model=model)
    manager = tf.train.CheckpointManager(ckpt, args.checkpoint_dir, max_to_keep=1)

    start_time = tf.timestamp()
    for epoch in range(max_epochs):
        train_meta['train_loss'].reset_states()
        train(model, train_data, train_meta)
        evaluate(model, valid_data, train_meta)

        # Record the epoch *before* a possible save, so the checkpoint stores
        # the epoch that produced it rather than the previous one.
        ckpt.epoch.assign(epoch)

        flag_best_result = ''
        val_ndcg5 = train_meta['metrics'][5].result()
        if val_ndcg5 > train_meta['best_result']:
            manager.save()
            train_meta['best_result'].assign(val_ndcg5)
            flag_best_result = ' *'

        train_meta['elapsed_time'].assign(tf.cast(tf.timestamp() - start_time, tf.float32))
        log.info(
            f"epoch: {epoch+1:5d} "
            f"step: {train_meta['step'].numpy():8d} "
            f"elapsed time: {train_meta['elapsed_time'].numpy():8.2f}s "
            f"train time: {train_meta['train_time'].numpy():6.2f}s "
            f"secs/step: {train_meta['secs_step'].numpy():6.3f} "
            f"val time: {train_meta['valid_time'].numpy():6.2f} "
            f"train/loss: {train_meta['train_loss'].result():10.4f} "
            f"val/ndcg@01: {train_meta['metrics'][1].result():10.4f} "
            f"val/ndcg@05: {train_meta['metrics'][5].result():10.4f} "
            f"val/ndcg@10: {train_meta['metrics'][10].result():10.4f}"
            f"{flag_best_result}"
        )

    # Evaluate on the test data using the best model found during training.
    best_checkpoint = tf.train.latest_checkpoint(args.checkpoint_dir)
    if best_checkpoint is None:
        # Happens when max_epochs is 0 or validation NDCG@5 never exceeded 0.
        log.warning(
            f"No checkpoint found in {args.checkpoint_dir}; "
            f"evaluating the final model instead."
        )
    else:
        log.info(f"Loading checkpoint from: {best_checkpoint}")
        ckpt.restore(best_checkpoint)

    eval_meta = _make_meta(args, max_epochs)
    evaluate(model, test_data, eval_meta)
    log.info(
        f"test/ndcg@01: {eval_meta['metrics'][1].result():10.4f} "
        f"test/ndcg@05: {eval_meta['metrics'][5].result():10.4f} "
        f"test/ndcg@10: {eval_meta['metrics'][10].result():10.4f}"
    )
# noinspection DuplicatedCode
def make_command_line_options():
    """Build the command line parser for this training script.

    Arguments may also be read from a file by prefixing its path with ``@``
    (``fromfile_prefix_chars='@'``). The parser sets ``func=main`` as a
    default, so callers can run ``args.func(args)`` after parsing.

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    cli = argparse.ArgumentParser(fromfile_prefix_chars='@')
    cli.add_argument(
        '--train-file',
        required=True,
        type=str,
        help="The training tfrecords file."
    )
    cli.add_argument(
        '--valid-file',
        required=True,
        type=str,
        help="The validation tfrecords file."
    )
    cli.add_argument(
        '--test-file',
        required=True,
        type=str,
        help="The test tfrecords file."
    )
    cli.add_argument(
        '--checkpoint-dir',
        required=True,
        type=str,
        help=(
            "The directory where model checkpoints will be saved. "
            "Any existing directory at this path is deleted before training."
        )
    )
    cli.add_argument(
        '--scaler',
        required=False,
        type=str,
        nargs=2,
        help=(
            "This argument requires two parameters. The first is the path to "
            "a scaler file created with the build dataset script. The second "
            "is the name of the scaler to use. Choose one of: "
            "minmax, standard, robust, power."
        )
    )
    cli.add_argument(
        '--run-eagerly',
        required=False,
        action='store_true',
        help="Run tf.functions eagerly instead of as graphs (useful for debugging)."
    )
    cli.add_argument(
        '--max-epochs',
        required=False,
        type=int,
        default=500,
        help="The maximum number of epochs before the training terminates no matter what."
    )
    cli.add_argument(
        '--optimizer',
        required=False,
        type=str,
        default='adagrad',
        choices=['adagrad', 'adam', 'sgd', 'nesterov', 'rmsprop'],
        help="The optimizer used for training."
    )
    cli.add_argument(
        '--learning-rate',
        required=False,
        type=float,
        default=0.001,
        help="The learning rate for the optimizer."
    )
    cli.add_argument(
        '--loss',
        required=False,
        type=str,
        choices=['ndcg', 'bidi_ndcg', 'softmax', 'cross_entropy', 'mse'],
        default='ndcg',
        help="The loss function to optimize."
    )
    cli.add_argument(
        '--list-size',
        required=False,
        type=int,
        default=None,
        help="The maximum number of documents per query or no maximum if not set."
    )
    cli.add_argument(
        '--group-size',
        required=False,
        type=int,
        default=16,
        help="The group size to use."
    )
    cli.add_argument(
        '--sample-pre-batch',
        required=False,
        action='store_true',
        default=False,
        help=(
            "If this flag is set then the alternate form of training will be "
            "performed where documents are sampled before batching."
        )
    )
    cli.add_argument(
        '--multiples',
        required=False,
        type=int,
        default=1,
        help="The sampling multiplier."
    )
    cli.add_argument(
        '--training-batch-size',
        required=False,
        type=int,
        default=128,
        help="The batch size used for training."
    )
    cli.add_argument(
        '--evaluation-batch-size',
        required=False,
        type=int,
        default=128,
        help="The number of queries per validation/test batch."
    )
    cli.add_argument(
        '--use-average',
        required=False,
        action='store_true',
        default=False,
        help=(
            "According to the paper, when a document is sampled more than once its scores are "
            "summed. When this option is set the scores are averaged over the number of times "
            "each document is seen instead."
        )
    )
    cli.add_argument(
        '--share-weights',
        required=False,
        action='store_true',
        default=False,
        help="Apply each document through a shared dense layer before concatenating them."
    )
    cli.add_argument(
        '--n-units',
        required=False,
        type=int,
        nargs='+',
        default=[64, 32, 16],
        help="Layer sizes for the scoring network (one integer per layer)."
    )
    cli.add_argument(
        '--dropout-rate',
        required=False,
        type=float,
        default=0.0,
        help="The dropout rate used by the scoring network."
    )
    cli.add_argument(
        '--drop-remainder',
        required=False,
        action='store_true',
        default=False,
        help=(
            "Drop the final partial batch of each dataset. "
            "This is necessary when using the keras training/eval loops."
        )
    )
    cli.add_argument(
        '--random-seed',
        required=False,
        type=int,
        help="The random seed used when shuffling the training data."
    )
    cli.set_defaults(func=main)
    return cli
if __name__ == '__main__':
    # Parse the command line and dispatch to the handler (main) set as the
    # parser's `func` default.
    parser = make_command_line_options()
    parsed = parser.parse_args()
    parsed.func(parsed)