"""
Documentation
"""
# Python Modules
from typing import List, Optional, Tuple, Union
# 3rd Party Modules
# Project Modules
# Accepted forms of a dropout specification: nothing (None), one rate used for
# every layer, or per-layer rates given as a list or a tuple.
# ``Tuple[float, ...]`` (not ``Tuple[float]``, which means "a tuple of exactly
# one float") so that per-layer tuples of any length type-check.
DROPOUT_TYPE = Optional[Union[List[float], Tuple[float, ...], float]]


def normalize_dropout(
    dropout_rate: DROPOUT_TYPE, n_units: List[int]
) -> List[float]:
    """
    Return the dropout rate as a list the same size as `len(n_units)`.

    This does some error checking, but is not very robust: the individual
    rates themselves are not validated (e.g. that they lie in ``[0, 1)``).

    :param dropout_rate: ``None`` or any other falsy value (``0``, ``0.0``,
        an empty list/tuple) for no dropout; a single scalar rate to use for
        every layer; a one-element list/tuple whose rate is used for every
        layer; or a list/tuple holding exactly one rate per layer.
    :param n_units: Sequence describing the layers (presumably the unit count
        of each layer); only its length is used here.
    :return: A new ``list`` with one dropout rate per entry of ``n_units``.
        It is always a fresh list, even when ``dropout_rate`` was a tuple or
        already had the right length, so the caller's object is never aliased.
    :raises ValueError: If ``dropout_rate`` is a list/tuple whose length is
        neither 1 nor ``len(n_units)``.
    """
    n_layers = len(n_units)

    # Falsy input (None, 0, 0.0, [] or ()) means "no dropout on any layer".
    if not dropout_rate:
        return [0.] * n_layers

    if isinstance(dropout_rate, (list, tuple)):
        if len(dropout_rate) == 1:
            # A single rate is broadcast to every layer; list() ensures a
            # tuple input still yields a list.
            return list(dropout_rate) * n_layers
        if len(dropout_rate) != n_layers:
            raise ValueError(
                f"The dropout rate is not compatible with the number of layers: "
                f"{len(dropout_rate)} != {n_layers}"
            )
        # Copy: converts tuples to lists and avoids handing back the caller's
        # own list, which either side could later mutate.
        return list(dropout_rate)

    # Scalar rate: same value for every layer.
    return [dropout_rate] * n_layers