# Source code for test.ranking.test_utils

"""
Unit tests for ``deletor.ranking.utils``.
"""
# Python Modules
import logging
import os
import unittest

# 3rd Party Modules
import numpy as np
import tensorflow as tf

# Project Modules
import deletor.tfutils as tfutils

# Runs before the `deletor.ranking.utils` import below -- presumably so the GPU
# is disabled before that module is loaded; TODO confirm against `tfutils`.
tfutils.disable_gpu()

import deletor.ranking.utils as rutils

from deletor.constants import MIN_FLOAT_32

# Set TensorFlow's native logging thresholds through environment variables.
# NOTE(review): these are assigned after `import tensorflow` above; confirm
# that TensorFlow still picks them up at this point.
os.environ['TF_CPP_MIN_VLOG_LEVEL'] = '2'
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

# Only let WARNING and above through the Python-side 'tensorflow' logger.
logging.getLogger('tensorflow').setLevel(logging.WARNING)

# NumPy printing: 4-digit precision, no scientific notation for small values,
# 20 edge items per dimension when summarizing, and 10000-character lines.
np.set_printoptions(precision=4, suppress=True, edgeitems=20, linewidth=10000)


class TestUtils(unittest.TestCase):
    """Test case exercising helpers from ``deletor.ranking.utils``."""

    def test_compute_ranks(self):
        """Smoke-test ``rutils.compute_ranks`` on a padded ragged batch.

        The return value of ``compute_ranks`` is not inspected, so this test
        only fails if the call raises.
        """
        # Five rows of differing length; ``to_tensor`` fills the missing
        # positions of the shorter rows with MIN_FLOAT_32.
        logits = tf.ragged.constant([
            [3.5, 2.2, 0.5, 1.0],
            [1.7, 2.3, 2.1, 1.1, 0.1, 1.0],
            [3.8, 0.8],
            [1.5, -1.0, 1.0],
            [0.1, 1.0, 0.3],
        ]).to_tensor(MIN_FLOAT_32)
        # True wherever the entry differs from the padding value.
        is_valid = tf.math.not_equal(logits, MIN_FLOAT_32)
        rutils.compute_ranks(logits, is_valid)