custom_prng: better error messages for key validation

This commit is contained in:
Jake VanderPlas 2022-03-04 10:49:29 -08:00
parent f87e6c3a74
commit 8b1d710202

View File

@ -75,14 +75,21 @@ class PRNGImpl(NamedTuple):
# -- PRNG key arrays --
def _is_prng_key_data(impl, keys: jnp.ndarray) -> bool:
def _check_prng_key_data(impl, key_data: jnp.ndarray):
ndim = len(impl.key_shape)
try:
return (keys.ndim >= 1 and
keys.shape[-ndim:] == impl.key_shape and
(keys.dtype == np.uint32 or keys.dtype == float0))
except AttributeError:
return False
if not all(hasattr(key_data, attr) for attr in ['ndim', 'shape', 'dtype']):
raise TypeError("JAX encountered invalid PRNG key data: expected key_data "
f"to have ndim, shape, and dtype attributes. Got {key_data}")
if key_data.ndim < 1:
raise TypeError("JAX encountered invalid PRNG key data: expected "
f"key_data.ndim >= 1; got ndim={key_data.ndim}")
if key_data.shape[-ndim:] != impl.key_shape:
raise TypeError("JAX encountered invalid PRNG key data: expected key_data.shape to "
f"end with {impl.key_shape}; got shape={key_data.shape} for impl={impl}")
if key_data.dtype not in [np.uint32, float0]:
raise TypeError("JAX encountered invalid PRNG key data: expected key_data.dtype = uint32; "
f"got dtype={key_data.dtype}")
@tree_util.register_pytree_node_class
class PRNGKeyArray:
@ -105,10 +112,8 @@ class PRNGKeyArray:
def __init__(self, impl, key_data: jnp.ndarray):
# key_data might be a placeholder python `object` or `bool`
# instead of a jnp.ndarray due to tree_unflatten
if (type(key_data) not in [object, bool] and
not _is_prng_key_data(impl, key_data)):
raise TypeError(
f'Invalid PRNG key data {key_data} for PRNG implementation {impl}')
if type(key_data) not in [object, bool]:
_check_prng_key_data(impl, key_data)
self.impl = impl
self._keys = key_data