Source code for deletor.tfutils

"""
Small TensorFlow utilities: GPU visibility and memory-growth configuration,
and a helper that turns per-batch index tensors into batched nd-indices.
"""
# Python Modules
import logging

# 3rd Party Modules
import tensorflow as tf

# Project Modules

# Module-level logger; grow_memory reports memory-growth failures through it.
log = logging.getLogger(__name__)


def disable_gpu():
    """Hide all GPUs from TensorFlow.

    The list of visible ``'GPU'`` devices is replaced with an empty list, so
    TensorFlow will not place work on any GPU. Per the TensorFlow
    documentation, visible devices can only be changed before the runtime
    initializes them; afterwards TensorFlow raises ``RuntimeError``.
    """
    visible_gpus = []
    tf.config.set_visible_devices(visible_gpus, device_type='GPU')
def grow_memory():
    """Enable on-demand memory growth for every physical GPU.

    With memory growth enabled, TensorFlow allocates GPU memory as it is
    needed instead of reserving it up front. This is best-effort: a
    ``RuntimeError`` from TensorFlow (raised, per its documentation, when
    memory growth is changed after the GPUs have been initialized) is logged
    rather than propagated, and any GPUs after the failing one are left
    unchanged.
    """
    gpus = tf.config.experimental.list_physical_devices('GPU')
    if not gpus:
        return
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as e:
        # One %s placeholder for one argument: logging formats lazily, and a
        # '%' inside the error text cannot corrupt the format string.
        log.error("Unable to grow memory: %s", e)
def set_device(device: int):
    """Restrict TensorFlow to a single physical GPU.

    Per the TensorFlow documentation, visible devices can only be changed
    before the runtime initializes them; otherwise TensorFlow raises
    ``RuntimeError``.

    :param device: Zero-based position of the GPU in
        ``tf.config.list_physical_devices('GPU')``.
    :raises IndexError: if ``device`` is negative or not smaller than the
        number of GPUs found (including when no GPU is present).
    """
    gpus = tf.config.list_physical_devices('GPU')
    # Reject negatives explicitly: plain list indexing would silently map
    # e.g. -1 (often meant as "no GPU") to the last GPU.
    if not 0 <= device < len(gpus):
        raise IndexError(
            f"GPU index {device} is out of range: {len(gpus)} GPU(s) available"
        )
    tf.config.set_visible_devices(gpus[device], 'GPU')
def to_nd_indices(indices):
    """Pair every index with its batch row, e.g. for use with ``tf.gather_nd``.

    :param indices: A `Tensor` of shape [batch_size, size] with integer values
        (any value accepted by ``tf.convert_to_tensor`` also works).
    :return: A `Tensor` of shape [batch_size, size, 2] with the same dtype as
        ``indices``, where ``result[b, i] == [b, indices[b, i]]``.
    """
    indices = tf.convert_to_tensor(indices)
    indices.get_shape().assert_has_rank(2)
    # tf.range yields int32; cast it to the dtype of ``indices`` so that
    # int64 indices do not fail with a dtype mismatch in the multiply below.
    batch_range = tf.cast(tf.range(tf.shape(indices)[0]), indices.dtype)
    # Broadcast row numbers across each row: batch_ids[b, i] == b.
    batch_ids = tf.ones_like(indices) * tf.expand_dims(batch_range, 1)
    return tf.stack([batch_ids, indices], axis=-1)