mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 21:06:06 +00:00
custom_prng: better error messages for key validation
This commit is contained in:
parent
f87e6c3a74
commit
8b1d710202
@ -75,14 +75,21 @@ class PRNGImpl(NamedTuple):
|
||||
|
||||
# -- PRNG key arrays --
|
||||
|
||||
def _is_prng_key_data(impl, keys: jnp.ndarray) -> bool:
|
||||
def _check_prng_key_data(impl, key_data: jnp.ndarray):
|
||||
ndim = len(impl.key_shape)
|
||||
try:
|
||||
return (keys.ndim >= 1 and
|
||||
keys.shape[-ndim:] == impl.key_shape and
|
||||
(keys.dtype == np.uint32 or keys.dtype == float0))
|
||||
except AttributeError:
|
||||
return False
|
||||
if not all(hasattr(key_data, attr) for attr in ['ndim', 'shape', 'dtype']):
|
||||
raise TypeError("JAX encountered invalid PRNG key data: expected key_data "
|
||||
f"to have ndim, shape, and dtype attributes. Got {key_data}")
|
||||
if key_data.ndim < 1:
|
||||
raise TypeError("JAX encountered invalid PRNG key data: expected "
|
||||
f"key_data.ndim >= 1; got ndim={key_data.ndim}")
|
||||
if key_data.shape[-ndim:] != impl.key_shape:
|
||||
raise TypeError("JAX encountered invalid PRNG key data: expected key_data.shape to "
|
||||
f"end with {impl.key_shape}; got shape={key_data.shape} for impl={impl}")
|
||||
if key_data.dtype not in [np.uint32, float0]:
|
||||
raise TypeError("JAX encountered invalid PRNG key data: expected key_data.dtype = uint32; "
|
||||
f"got dtype={key_data.dtype}")
|
||||
|
||||
|
||||
@tree_util.register_pytree_node_class
|
||||
class PRNGKeyArray:
|
||||
@ -105,10 +112,8 @@ class PRNGKeyArray:
|
||||
def __init__(self, impl, key_data: jnp.ndarray):
|
||||
# key_data might be a placeholder python `object` or `bool`
|
||||
# instead of a jnp.ndarray due to tree_unflatten
|
||||
if (type(key_data) not in [object, bool] and
|
||||
not _is_prng_key_data(impl, key_data)):
|
||||
raise TypeError(
|
||||
f'Invalid PRNG key data {key_data} for PRNG implementation {impl}')
|
||||
if type(key_data) not in [object, bool]:
|
||||
_check_prng_key_data(impl, key_data)
|
||||
self.impl = impl
|
||||
self._keys = key_data
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user