random: clean up fold_in implementation

Addresses an old TODO
This commit is contained in:
Jake VanderPlas 2023-11-16 10:51:59 -08:00
parent b7814352a6
commit 94414ce07e

View File

@ -255,18 +255,6 @@ def unsafe_rbg_key(seed: int | ArrayLike) -> KeyArray:
return _return_prng_keys(True, key)
def _fold_in(key: KeyArray, data: IntegerArray) -> KeyArray:
# Alternative to fold_in() to use within random samplers.
# TODO(frostig): remove and use fold_in() once we always enable_custom_prng
assert jnp.issubdtype(key.dtype, dtypes.prng_key)
if key.ndim:
raise TypeError("fold_in accepts a single key, but was given a key array of"
f"shape {key.shape} != (). Use jax.vmap for batching.")
if np.ndim(data):
raise TypeError("fold_in accepts a scalar, but was given an array of"
f"shape {np.shape(data)} != (). Use jax.vmap for batching.")
return prng.random_fold_in(key, jnp.uint32(data))
def fold_in(key: KeyArrayLike, data: IntegerArray) -> KeyArray:
"""Folds in data to a PRNG key to form a new PRNG key.
@ -279,7 +267,14 @@ def fold_in(key: KeyArrayLike, data: IntegerArray) -> KeyArray:
statistically safe for producing a stream of new pseudo-random values.
"""
key, wrapped = _check_prng_key(key)
return _return_prng_keys(wrapped, _fold_in(key, data))
if np.ndim(key):
raise TypeError("fold_in accepts a single key, but was given a key array of"
f"shape {np.shape(key)} != (). Use jax.vmap for batching.")
if np.ndim(data):
raise TypeError("fold_in accepts a scalar, but was given an array of"
f"shape {np.shape(data)} != (). Use jax.vmap for batching.")
key_out = prng.random_fold_in(key, jnp.uint32(data))
return _return_prng_keys(wrapped, key_out)
def _split(key: KeyArray, num: Union[int, tuple[int, ...]] = 2) -> KeyArray: