"""
TensorFlow helpers for GPU visibility and memory configuration, and for building N-d index tensors.
"""
# Python Modules
import logging
# 3rd Party Modules
import tensorflow as tf
# Project Modules
# Module-level logger, named after this module via ``__name__``.
log = logging.getLogger(__name__)
def disable_gpu():
    """
    Make no GPU visible to TensorFlow by setting the visible 'GPU' devices to an empty list.
    """
    no_devices = []
    tf.config.set_visible_devices(no_devices, 'GPU')
def grow_memory():
    """
    Enable memory growth on every physical GPU that TensorFlow lists.

    Does nothing when no GPU is listed. A ``RuntimeError`` raised while configuring the devices is
    logged at error level and not re-raised.
    """
    gpus = tf.config.experimental.list_physical_devices('GPU')
    if not gpus:
        return
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as e:
        # The exception is handed over as a %-style argument: logging merges positional arguments
        # into the message, so each one needs a matching placeholder or the record fails to format.
        log.error("Unable to grow memory: %s", e)
def set_device(device: int):
    """
    Make only one physical GPU visible to TensorFlow.

    :param device: Position of the GPU in the list returned by
        ``tf.config.list_physical_devices('GPU')``.
    """
    physical_gpus = tf.config.list_physical_devices('GPU')
    selected_gpu = physical_gpus[device]
    tf.config.experimental.set_visible_devices(selected_gpu, 'GPU')
def to_nd_indices(indices):
    """
    Pair every index with the number of the batch row it belongs to.

    :param indices: A `Tensor` of shape [batch_size, size] with integer values.
    :return: A `Tensor` of shape [batch_size, size, 2] with the dtype of `indices`, where the entry
        at ``[b, i]`` is ``[b, indices[b, i]]``.
    """
    indices.get_shape().assert_has_rank(2)
    # The row ids are cast to the dtype of `indices` so that the element-wise product below also
    # works for integer dtypes other than the default one produced by tf.shape / tf.range.
    row_ids = tf.cast(tf.range(tf.shape(indices)[0]), indices.dtype)
    # Broadcast the [batch_size, 1] column of row ids across the `size` axis.
    batch_ids = tf.ones_like(indices) * tf.expand_dims(row_ids, 1)
    return tf.stack([batch_ids, indices], axis=-1)