mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
random: clean up fold_in implementation
Addresses an old TODO
This commit is contained in:
parent
b7814352a6
commit
94414ce07e
@ -255,18 +255,6 @@ def unsafe_rbg_key(seed: int | ArrayLike) -> KeyArray:
|
||||
return _return_prng_keys(True, key)
|
||||
|
||||
|
||||
def _fold_in(key: KeyArray, data: IntegerArray) -> KeyArray:
|
||||
# Alternative to fold_in() to use within random samplers.
|
||||
# TODO(frostig): remove and use fold_in() once we always enable_custom_prng
|
||||
assert jnp.issubdtype(key.dtype, dtypes.prng_key)
|
||||
if key.ndim:
|
||||
raise TypeError("fold_in accepts a single key, but was given a key array of"
|
||||
f"shape {key.shape} != (). Use jax.vmap for batching.")
|
||||
if np.ndim(data):
|
||||
raise TypeError("fold_in accepts a scalar, but was given an array of"
|
||||
f"shape {np.shape(data)} != (). Use jax.vmap for batching.")
|
||||
return prng.random_fold_in(key, jnp.uint32(data))
|
||||
|
||||
def fold_in(key: KeyArrayLike, data: IntegerArray) -> KeyArray:
|
||||
"""Folds in data to a PRNG key to form a new PRNG key.
|
||||
|
||||
@ -279,7 +267,14 @@ def fold_in(key: KeyArrayLike, data: IntegerArray) -> KeyArray:
|
||||
statistically safe for producing a stream of new pseudo-random values.
|
||||
"""
|
||||
key, wrapped = _check_prng_key(key)
|
||||
return _return_prng_keys(wrapped, _fold_in(key, data))
|
||||
if np.ndim(key):
|
||||
raise TypeError("fold_in accepts a single key, but was given a key array of"
|
||||
f"shape {np.shape(key)} != (). Use jax.vmap for batching.")
|
||||
if np.ndim(data):
|
||||
raise TypeError("fold_in accepts a scalar, but was given an array of"
|
||||
f"shape {np.shape(data)} != (). Use jax.vmap for batching.")
|
||||
key_out = prng.random_fold_in(key, jnp.uint32(data))
|
||||
return _return_prng_keys(wrapped, key_out)
|
||||
|
||||
|
||||
def _split(key: KeyArray, num: Union[int, tuple[int, ...]] = 2) -> KeyArray:
|
||||
|
Loading…
x
Reference in New Issue
Block a user