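"""
Helpers for configuring and running model training: log-level parsing for the
command line, optimizer and loss construction, and the train/evaluate loops.
"""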
# Python Modules
import argparse
import logging
from typing import Dict, Any
# 3rd Party Modules
import tensorflow as tf
# Project Modules
import deletor.metrics as lwmetrics
from deletor.losses import ApproximateNormalizedDiscountedCumulativeGain, RankingSoftmax, \
RankingCrossEntropy, MeanSquaredError, ApproximateBiDiNormalizedDiscountedCumulativeGain
log = logging.getLogger(__name__)
def log_level_type(s: str) -> int:
    """
    Convert a log level name into the matching ``logging`` level constant.

    The lookup is case-insensitive. Failures are reported with
    ``argparse.ArgumentTypeError`` so the function can be used as an argparse
    ``type=`` callable.

    :param s: The level name: ``debug``, ``info``, ``warning``, ``error`` or
        ``critical`` (any capitalisation).
    :return: The corresponding integer constant from the ``logging`` module.
    :raises argparse.ArgumentTypeError: If ``s`` does not name a known level.
    """
    levels = {
        'debug': logging.DEBUG,
        'info': logging.INFO,
        'warning': logging.WARNING,
        'error': logging.ERROR,
        'critical': logging.CRITICAL,
    }
    try:
        return levels[s.lower()]
    except KeyError:
        # Echo the value exactly as the user typed it, not the lower-cased
        # form, so the message matches what is on the command line.
        raise argparse.ArgumentTypeError(f"{s} is not a valid log level") from None
def make_optimizer(args: argparse.Namespace):
    """
    Build the Keras optimizer selected by the parsed arguments.

    :param args: Parsed arguments. ``args.optimizer`` selects the optimizer
        (``adagrad``, ``adam``, ``sgd``, ``nesterov`` or ``rmsprop``) and
        ``args.learning_rate`` is passed through as its learning rate. For
        ``nesterov``, ``args.momentum`` is used when the attribute exists and
        is not ``None``; otherwise a momentum of 0.9 is used.
    :return: A ``tf.keras.optimizers`` instance.
    :raises ValueError: If ``args.optimizer`` is not one of the known names.
    """
    optimizer_type = args.optimizer
    learning_rate = args.learning_rate

    if optimizer_type == 'adagrad':
        return tf.keras.optimizers.Adagrad(learning_rate=learning_rate)
    if optimizer_type == 'adam':
        return tf.keras.optimizers.Adam(learning_rate=learning_rate)
    if optimizer_type == 'sgd':
        return tf.keras.optimizers.SGD(learning_rate=learning_rate)
    if optimizer_type == 'nesterov':
        # Nesterov momentum has no effect while momentum is 0, which is the
        # Keras SGD default; without an explicit momentum this option would
        # be indistinguishable from plain 'sgd'.
        momentum = getattr(args, 'momentum', None)
        if momentum is None:
            momentum = 0.9
        return tf.keras.optimizers.SGD(
            learning_rate=learning_rate,
            momentum=momentum,
            nesterov=True,
        )
    if optimizer_type == 'rmsprop':
        return tf.keras.optimizers.RMSprop(learning_rate=learning_rate)

    raise ValueError(f"Unknown optimizer type: {optimizer_type}")
def make_loss(args: argparse.Namespace, reduce: bool = True, **kwargs):
    """
    Instantiate the loss object named by ``args.loss``.

    :param args: Parsed arguments; ``args.loss`` must be one of ``ndcg``,
        ``bidi_ndcg``, ``softmax``, ``cross_entropy`` or ``mse``.
    :param reduce: Forwarded unchanged to the loss constructor as ``reduce``.
    :param kwargs: Optional extras. Only ``beta`` is read, and only for
        ``bidi_ndcg`` (defaulting to 1.0 when absent).
    :return: The constructed loss object.
    :raises ValueError: If ``args.loss`` is not a known loss name.
    """
    selected = args.loss

    if selected == 'ndcg':
        loss = ApproximateNormalizedDiscountedCumulativeGain(reduce=reduce)
    elif selected == 'bidi_ndcg':
        loss = ApproximateBiDiNormalizedDiscountedCumulativeGain(
            reduce=reduce,
            beta=kwargs.get('beta', 1.0),
        )
    elif selected == 'softmax':
        loss = RankingSoftmax(reduce=reduce)
    elif selected == 'cross_entropy':
        loss = RankingCrossEntropy(reduce=reduce)
    elif selected == 'mse':
        loss = MeanSquaredError(reduce=reduce)
    else:
        raise ValueError(f"Unknown loss type: {selected}")

    return loss
# TensorFlow warns about passing lists into @tf.function-annotated functions,
# but I don't see any performance penalty when passing in a list/dict of metric
# values. Maybe it's only lists of scalars that cause problems.
@tf.function
def train(model: tf.keras.Model, train_data: tf.data.Dataset, train_meta: Dict[str, Any]):
    """
    Train a model with the given training data.

    Makes one pass over ``train_data``, applying one optimizer update per
    element, and records progress in ``train_meta``.

    :param model: The model to train. Its ``optimizer`` and ``loss`` attributes
        supply the optimizer and the loss function.
    :param train_data: The training data, iterated as ``(x, y)`` pairs. ``y``
        may itself be a two-element ``(y, y_sample)`` tuple or list.
    :param train_meta: A dictionary containing variables to store the results
        in. ``sample_pre_batch`` is read; ``train_loss`` is updated with each
        step's loss, ``step`` is incremented once per step, and ``train_time``
        and ``secs_step`` are assigned after the pass completes.
    """
    sample_pre_batch = train_meta['sample_pre_batch']
    optimizer = model.optimizer
    loss_fn = model.loss

    steps = 0.
    train_start = tf.timestamp()
    for x, y in train_data:
        if isinstance(y, (tuple, list)):
            # Paired labels: sample_pre_batch decides which of the two is
            # used as the training target.
            y, y_sample = y
            y_true = y_sample if sample_pre_batch else y
        else:
            y_true = y

        # Only the forward pass and the loss need to be recorded on the tape.
        with tf.GradientTape() as tape:
            train_scores = model(x, training=True)
            train_loss = loss_fn(y_true, train_scores)

        train_meta['train_loss'].update_state([train_loss])
        weights = model.trainable_weights
        grads = tape.gradient(train_loss, weights)
        optimizer.apply_gradients(zip(grads, weights))
        train_meta['step'].assign_add(1)
        steps += 1.

    train_time = tf.cast(tf.timestamp() - train_start, tf.float32)
    train_meta['train_time'].assign(train_time)
    # divide_no_nan yields 0 instead of inf/nan when the dataset produced no
    # batches (steps == 0).
    train_meta['secs_step'].assign(tf.math.divide_no_nan(train_time, steps))
@tf.function
def evaluate(net: tf.keras.Model, eval_data: tf.data.Dataset, train_meta: Dict[str, Any]):
    """
    Run a model over evaluation data and accumulate NDCG metrics.

    :param net: The model to evaluate; called with ``training=False``.
    :param eval_data: The evaluation data, iterated as ``(x, y)`` pairs. When
        ``y`` is a two-element tuple or list, only its first element is used.
    :param train_meta: A dictionary containing variables to store the results
        in. Each entry of ``train_meta['metrics']`` (keyed by the ``k`` handed
        to ``NormalizedDiscountedCumulativeGain``) is reset and then updated
        once per element; ``train_meta['valid_time']`` is assigned the elapsed
        time of the evaluation.
    """
    began = tf.timestamp()
    accumulators = train_meta['metrics']

    # One NDCG function per key, then start every accumulator from a clean state.
    ndcg_at = {k: lwmetrics.NormalizedDiscountedCumulativeGain(k=k) for k in accumulators}
    for accumulator in accumulators.values():
        accumulator.reset_states()

    for x, y in eval_data:
        if isinstance(y, (tuple, list)):
            # The second element of a paired label is not needed here.
            y, _ = y
        y_pred = net(x, training=False)
        for k, accumulator in accumulators.items():
            accumulator.update_state(ndcg_at[k](y, y_pred))

    elapsed = tf.cast(tf.timestamp() - began, tf.float32)
    train_meta['valid_time'].assign(elapsed)