Source code for examples.utils

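"""
Shared helpers for the example scripts: log-level argument parsing,
optimizer and loss factories, and the training / evaluation loops.
"""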
# Python Modules
import argparse
import logging

from typing import Dict, Any

# 3rd Party Modules
import tensorflow as tf

# Project Modules
import deletor.metrics as lwmetrics
from deletor.losses import ApproximateNormalizedDiscountedCumulativeGain, RankingSoftmax, \
    RankingCrossEntropy, MeanSquaredError, ApproximateBiDiNormalizedDiscountedCumulativeGain

# Module-level logger, named after this module.
log = logging.getLogger(__name__)


def log_level_type(s: str):
    """
    Translate a log level name into the matching ``logging`` constant.

    The lookup is case-insensitive. Failures are reported with
    ``argparse.ArgumentTypeError``, so the function can be passed as the
    ``type=`` callable of an ``argparse`` argument.

    :param s: The level name; one of 'debug', 'info', 'warning' or 'error'.
    :return: ``logging.DEBUG``, ``logging.INFO``, ``logging.WARNING`` or
        ``logging.ERROR``.
    :raises argparse.ArgumentTypeError: If the name is not a supported level.
    """
    name = s.lower()
    levels = {
        'debug': logging.DEBUG,
        'info': logging.INFO,
        'warning': logging.WARNING,
        'error': logging.ERROR,
    }
    level = levels.get(name)
    if level is None:
        # The message reports the lower-cased name, not the raw input.
        raise argparse.ArgumentTypeError(f"{name} is not a valid log level")
    return level
def make_optimizer(args: argparse.Namespace):
    """
    Build the Keras optimizer selected on the command line.

    :param args: Parsed arguments; ``args.optimizer`` names the optimizer
        ('adagrad', 'adam', 'sgd', 'nesterov' or 'rmsprop') and
        ``args.learning_rate`` is passed through as its learning rate.
    :return: The constructed ``tf.keras.optimizers`` instance.
    :raises ValueError: If ``args.optimizer`` is not a known optimizer type.
    """
    optimizer_type = args.optimizer
    learning_rate = args.learning_rate

    # Factories are lazy so that nothing is constructed for an unknown type.
    factories = {
        'adagrad': lambda: tf.keras.optimizers.Adagrad(learning_rate=learning_rate),
        'adam': lambda: tf.keras.optimizers.Adam(learning_rate=learning_rate),
        'sgd': lambda: tf.keras.optimizers.SGD(learning_rate=learning_rate),
        # 'nesterov' is plain SGD with the nesterov flag switched on.
        'nesterov': lambda: tf.keras.optimizers.SGD(learning_rate=learning_rate,
                                                    nesterov=True),
        'rmsprop': lambda: tf.keras.optimizers.RMSprop(learning_rate=learning_rate),
    }

    factory = factories.get(optimizer_type)
    if factory is None:
        raise ValueError(f"Unknown optimizer type: {optimizer_type}")
    return factory()
def make_loss(args: argparse.Namespace, reduce: bool = True, **kwargs):
    """
    Build the ranking loss selected on the command line.

    :param args: Parsed arguments; ``args.loss`` names the loss
        ('ndcg', 'bidi_ndcg', 'softmax', 'cross_entropy' or 'mse').
    :param reduce: Forwarded unchanged as the ``reduce`` argument of the
        loss constructor.
    :param kwargs: Optional extras. Only ``beta`` is read (default 1.0), and
        only for the 'bidi_ndcg' loss; anything else is ignored.
    :return: The constructed loss object.
    :raises ValueError: If ``args.loss`` is not a known loss type.
    """
    loss_type = args.loss

    # Factories are lazy so that nothing is constructed for an unknown type.
    factories = {
        'ndcg': lambda: ApproximateNormalizedDiscountedCumulativeGain(reduce=reduce),
        'bidi_ndcg': lambda: ApproximateBiDiNormalizedDiscountedCumulativeGain(
            reduce=reduce, beta=kwargs.get('beta', 1.0)),
        'softmax': lambda: RankingSoftmax(reduce=reduce),
        'cross_entropy': lambda: RankingCrossEntropy(reduce=reduce),
        'mse': lambda: MeanSquaredError(reduce=reduce),
    }

    factory = factories.get(loss_type)
    if factory is None:
        raise ValueError(f"Unknown loss type: {loss_type}")
    return factory()
# TensorFlow warns about passing lists into @tf.function annotated methods,
# but I don't see any performance penalty passing in a list/dict of metric
# values. Maybe it's just lists of scalars that cause problems.
@tf.function
def train(model: tf.keras.Model, train_data: tf.data.Dataset, train_meta: Dict[str, Any]):
    """
    Run one pass over the training data, updating the model in place.

    Results are written into ``train_meta`` rather than returned: the
    'train_loss' metric is updated with every batch loss, 'step' is
    incremented once per batch, and 'train_time' / 'secs_step' receive the
    elapsed time and the elapsed time per batch.

    :param model: The model to train; its ``optimizer`` and ``loss``
        attributes are used for the update.
    :param train_data: The training data, iterated as ``(x, y)`` pairs.
    :param train_meta: A dictionary containing variables to store the
        results in, plus the 'sample_pre_batch' flag.
    """
    use_sampled_labels = train_meta['sample_pre_batch']
    loss_metric = train_meta['train_loss']
    global_step = train_meta['step']
    opt = model.optimizer
    compute_loss = model.loss

    # Kept as a float so the final time-per-step division needs no cast.
    batches_seen = 0.
    started_at = tf.timestamp()

    for features, labels in train_data:
        if isinstance(labels, (tuple, list)):
            # Labels arrive as a pair; the second element is used only when
            # 'sample_pre_batch' is set.
            full_labels, sampled_labels = labels
            targets = sampled_labels if use_sampled_labels else full_labels
        else:
            targets = labels

        with tf.GradientTape() as tape:
            scores = model(features, training=True)
            batch_loss = compute_loss(targets, scores)
        loss_metric.update_state([batch_loss])

        # Read the weights only after the forward pass above.
        params = model.trainable_weights
        gradients = tape.gradient(batch_loss, params)
        opt.apply_gradients(zip(gradients, params))

        global_step.assign_add(1)
        batches_seen += 1.

    elapsed = tf.cast(tf.timestamp() - started_at, tf.float32)
    train_meta['train_time'].assign(elapsed)
    train_meta['secs_step'].assign(elapsed / batches_seen)
@tf.function
def evaluate(net: tf.keras.Model, eval_data: tf.data.Dataset, train_meta: Dict[str, Any]):
    """
    Score a model on the evaluation data and record the results.

    For every ``k -> metric`` entry in ``train_meta['metrics']`` a
    ``lwmetrics.NormalizedDiscountedCumulativeGain(k=k)`` is created and the
    metric is reset. Each batch's score is then fed into the metric via
    ``update_state``. Nothing is returned; the elapsed time is stored in
    ``train_meta['valid_time']``.

    :param net: The model to evaluate; called with ``training=False``.
    :param eval_data: The evaluation data, iterated as ``(x, y)`` pairs.
        When ``y`` is a tuple or list of two items, only the first is used.
    :param train_meta: A dictionary holding the 'metrics' mapping and the
        'valid_time' variable to store the elapsed time in.
    """
    started_at = tf.timestamp()
    metrics = train_meta['metrics']

    scorers = {}
    for cutoff, metric in metrics.items():
        scorers[cutoff] = lwmetrics.NormalizedDiscountedCumulativeGain(k=cutoff)
        # Keras renamed Metric.reset_states() to reset_state(); newer
        # releases only provide the latter, older ones only the former.
        reset = getattr(metric, 'reset_state', None)
        if reset is None:
            reset = metric.reset_states
        reset()

    for features, labels in eval_data:
        if isinstance(labels, (tuple, list)):
            # Only the first element of the label pair is scored against.
            labels, _ = labels
        predictions = net(features, training=False)
        for cutoff, metric in metrics.items():
            metric.update_state(scorers[cutoff](labels, predictions))

    elapsed = tf.cast(tf.timestamp() - started_at, tf.float32)
    train_meta['valid_time'].assign(elapsed)