"""
Unit tests for ``deletor.ranking.utils``.
"""
# Python Modules
import logging
import os
import unittest
# 3rd Party Modules
import numpy as np
import tensorflow as tf
# Project Modules
import deletor.tfutils as tfutils
tfutils.disable_gpu()
import deletor.ranking.utils as rutils
from deletor.constants import MIN_FLOAT_32
# Reduce TensorFlow log noise: the C++ logger is configured through
# environment variables, the Python-side logger through `logging`.
# NOTE(review): these variables are set after `import tensorflow` above; TF
# usually recommends setting them before the import -- confirm they take effect.
# NOTE(review): despite its name, TensorFlow treats TF_CPP_MIN_VLOG_LEVEL as the
# highest VLOG level to emit, so '2' may enable (not suppress) verbose logs.
os.environ.update(
    TF_CPP_MIN_VLOG_LEVEL='2',
    TF_CPP_MIN_LOG_LEVEL='2',
)
logging.getLogger('tensorflow').setLevel(logging.WARNING)
# Compact numpy printing for debugging output: 4 decimals, fixed-point
# notation (suppress=True), wide lines, and up to 20 items at each edge
# before summarising with '...'.
np.set_printoptions(
    linewidth=10000,
    edgeitems=20,
    suppress=True,
    precision=4,
)
class TestUtils(unittest.TestCase):
    """Tests for ``deletor.ranking.utils``."""

    def test_compute_ranks(self):
        """Run ``rutils.compute_ranks`` on a padded batch of ragged logits.

        Five rows of different lengths (4, 6, 2, 3, 3) are densified into one
        tensor whose padding cells hold ``MIN_FLOAT_32``; ``is_valid`` marks
        every cell that is not padding.
        """
        logits = tf.ragged.constant([[3.5, 2.2, 0.5, 1.0],
                                     [1.7, 2.3, 2.1, 1.1, 0.1, 1.0],
                                     [3.8, 0.8],
                                     [1.5, -1.0, 1.0],
                                     [0.1, 1.0, 0.3]]).to_tensor(MIN_FLOAT_32)
        # Padding cells are filled with exactly MIN_FLOAT_32, so comparing
        # against it recovers the validity mask of the original ragged rows.
        is_valid = tf.math.not_equal(logits, MIN_FLOAT_32)
        # Sanity-check the fixture itself: the number of valid cells per row
        # must match the ragged row lengths above.
        tf.debugging.assert_equal(
            tf.reduce_sum(tf.cast(is_valid, tf.int32), axis=1),
            [4, 6, 2, 3, 3])
        # TODO(review): the result of compute_ranks is not checked; add the
        # expected ranks once its ranking convention (tie handling, treatment
        # of padding, 0- vs 1-based) is confirmed.
        rutils.compute_ranks(logits, is_valid)