mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Finalize deprecation of batched keys to PRNG functions
PiperOrigin-RevId: 636196573
This commit is contained in:
parent
949e98c08d
commit
568987af23
@ -22,6 +22,9 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
deprecated and will soon be removed. Use `rtol` instead.
|
||||
* The deprecated `jax.config` submodule has been removed. To configure JAX
|
||||
use `import jax` and then reference the config object via `jax.config`.
|
||||
* {mod}`jax.random` APIs no longer accept batched keys, where previously
|
||||
some did unintentionally. Going forward, we recommend explicit use of
|
||||
{func}`jax.vmap` in such cases.
|
||||
|
||||
## jaxlib 0.4.29
|
||||
|
||||
|
@ -70,11 +70,8 @@ def _isnan(x: ArrayLike) -> Array:
|
||||
return lax.ne(x, x)
|
||||
|
||||
|
||||
# TODO(jakevdp) Finalize batched input deprecation by setting error_on_batched=True.
|
||||
# FutureWarning Added 2024-01-17
|
||||
def _check_prng_key(name: str, key: KeyArrayLike, *,
|
||||
allow_batched: bool = False,
|
||||
error_on_batched: bool = False) -> tuple[KeyArray, bool]:
|
||||
allow_batched: bool = False) -> tuple[KeyArray, bool]:
|
||||
if isinstance(key, Array) and dtypes.issubdtype(key.dtype, dtypes.prng_key):
|
||||
wrapped_key = key
|
||||
wrapped = False
|
||||
@ -102,13 +99,8 @@ def _check_prng_key(name: str, key: KeyArrayLike, *,
|
||||
raise TypeError(f'unexpected PRNG key type {type(key)}')
|
||||
|
||||
if (not allow_batched) and wrapped_key.ndim:
|
||||
msg = (f"{name} accepts a single key, but was given a key array of "
|
||||
f"shape {np.shape(key)} != (). Use jax.vmap for batching.")
|
||||
if error_on_batched:
|
||||
raise ValueError(msg)
|
||||
else:
|
||||
warnings.warn(msg + " In a future JAX version, this will be an error.",
|
||||
FutureWarning, stacklevel=3)
|
||||
raise ValueError(f"{name} accepts a single key, but was given a key array of"
|
||||
f" shape {np.shape(key)} != (). Use jax.vmap for batching.")
|
||||
|
||||
return wrapped_key, wrapped
|
||||
|
||||
@ -252,7 +244,7 @@ def fold_in(key: KeyArrayLike, data: IntegerArray) -> KeyArray:
|
||||
A new PRNG key that is a deterministic function of the inputs and is
|
||||
statistically safe for producing a stream of new pseudo-random values.
|
||||
"""
|
||||
key, wrapped = _check_prng_key("fold_in", key, error_on_batched=True)
|
||||
key, wrapped = _check_prng_key("fold_in", key)
|
||||
if np.ndim(data):
|
||||
raise TypeError("fold_in accepts a scalar, but was given an array of"
|
||||
f"shape {np.shape(data)} != (). Use jax.vmap for batching.")
|
||||
@ -282,7 +274,7 @@ def split(key: KeyArrayLike, num: int | tuple[int, ...] = 2) -> KeyArray:
|
||||
Returns:
|
||||
An array-like object of `num` new PRNG keys.
|
||||
"""
|
||||
typed_key, wrapped = _check_prng_key("split", key, error_on_batched=True)
|
||||
typed_key, wrapped = _check_prng_key("split", key)
|
||||
return _return_prng_keys(wrapped, _split(typed_key, num))
|
||||
|
||||
|
||||
|
@ -1243,29 +1243,27 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
self.assertArraysAllClose(samples2, jnp.array([jnp.nan, 0., jnp.nan, jnp.nan]), check_dtypes=False)
|
||||
self.assertArraysAllClose(samples3, jnp.array([jnp.nan, jnp.nan, jnp.nan]), check_dtypes=False)
|
||||
|
||||
def test_batched_key_warnings(self):
|
||||
def test_batched_key_errors(self):
|
||||
keys = lambda: jax.random.split(self.make_key(0))
|
||||
msg = "{} accepts a single key, but was given a key array of shape.*"
|
||||
|
||||
# Check a handful of functions that are expected to warn.
|
||||
with self.assertWarnsRegex(FutureWarning, msg.format('bits')):
|
||||
# Check a handful of functions that are expected to error.
|
||||
with self.assertRaisesRegex(ValueError, msg.format('bits')):
|
||||
jax.random.bits(keys(), shape=(2,))
|
||||
with self.assertWarnsRegex(FutureWarning, msg.format('chisquare')):
|
||||
with self.assertRaisesRegex(ValueError, msg.format('chisquare')):
|
||||
jax.random.chisquare(keys(), 1.0, shape=(2,))
|
||||
with self.assertWarnsRegex(FutureWarning, msg.format('dirichlet')):
|
||||
with self.assertRaisesRegex(ValueError, msg.format('dirichlet')):
|
||||
jax.random.dirichlet(keys(), jnp.arange(2.0), shape=(2,))
|
||||
with self.assertWarnsRegex(FutureWarning, msg.format('gamma')):
|
||||
with self.assertRaisesRegex(ValueError, msg.format('gamma')):
|
||||
jax.random.gamma(keys(), 1.0, shape=(2,))
|
||||
with self.assertWarnsRegex(FutureWarning, msg.format('loggamma')):
|
||||
with self.assertRaisesRegex(ValueError, msg.format('loggamma')):
|
||||
jax.random.loggamma(keys(), 1.0, shape=(2,))
|
||||
|
||||
# Other functions should error; test a few cases.
|
||||
with self.assertRaisesRegex(ValueError, msg.format('fold_in')):
|
||||
jax.random.fold_in(keys(), 0)
|
||||
with self.assertRaisesRegex(ValueError, msg.format('split')):
|
||||
jax.random.split(keys())
|
||||
|
||||
# Some shouldn't error or warn
|
||||
# Shouldn't error or warn:
|
||||
with self.assertNoWarnings():
|
||||
jax.random.key_data(keys())
|
||||
jax.random.key_impl(keys())
|
||||
|
Loading…
x
Reference in New Issue
Block a user