Finalize deprecation of batched keys to PRNG functions

PiperOrigin-RevId: 636196573
This commit is contained in:
Jake VanderPlas 2024-05-22 09:36:43 -07:00 committed by jax authors
parent 949e98c08d
commit 568987af23
3 changed files with 16 additions and 23 deletions

View File

@ -22,6 +22,9 @@ Remember to align the itemized text with the first line of an item within a list
deprecated and will soon be removed. Use `rtol` instead.
* The deprecated `jax.config` submodule has been removed. To configure JAX
use `import jax` and then reference the config object via `jax.config`.
* {mod}`jax.random` APIs no longer accept batched keys, where previously
some did unintentionally. Going forward, we recommend explicit use of
{func}`jax.vmap` in such cases.
## jaxlib 0.4.29

View File

@ -70,11 +70,8 @@ def _isnan(x: ArrayLike) -> Array:
return lax.ne(x, x)
# TODO(jakevdp) Finalize batched input deprecation by setting error_on_batched=True.
# FutureWarning Added 2024-01-17
def _check_prng_key(name: str, key: KeyArrayLike, *,
allow_batched: bool = False,
error_on_batched: bool = False) -> tuple[KeyArray, bool]:
allow_batched: bool = False) -> tuple[KeyArray, bool]:
if isinstance(key, Array) and dtypes.issubdtype(key.dtype, dtypes.prng_key):
wrapped_key = key
wrapped = False
@ -102,13 +99,8 @@ def _check_prng_key(name: str, key: KeyArrayLike, *,
raise TypeError(f'unexpected PRNG key type {type(key)}')
if (not allow_batched) and wrapped_key.ndim:
msg = (f"{name} accepts a single key, but was given a key array of "
f"shape {np.shape(key)} != (). Use jax.vmap for batching.")
if error_on_batched:
raise ValueError(msg)
else:
warnings.warn(msg + " In a future JAX version, this will be an error.",
FutureWarning, stacklevel=3)
raise ValueError(f"{name} accepts a single key, but was given a key array of"
f" shape {np.shape(key)} != (). Use jax.vmap for batching.")
return wrapped_key, wrapped
@ -252,7 +244,7 @@ def fold_in(key: KeyArrayLike, data: IntegerArray) -> KeyArray:
A new PRNG key that is a deterministic function of the inputs and is
statistically safe for producing a stream of new pseudo-random values.
"""
key, wrapped = _check_prng_key("fold_in", key, error_on_batched=True)
key, wrapped = _check_prng_key("fold_in", key)
if np.ndim(data):
raise TypeError("fold_in accepts a scalar, but was given an array of"
f"shape {np.shape(data)} != (). Use jax.vmap for batching.")
@ -282,7 +274,7 @@ def split(key: KeyArrayLike, num: int | tuple[int, ...] = 2) -> KeyArray:
Returns:
An array-like object of `num` new PRNG keys.
"""
typed_key, wrapped = _check_prng_key("split", key, error_on_batched=True)
typed_key, wrapped = _check_prng_key("split", key)
return _return_prng_keys(wrapped, _split(typed_key, num))

View File

@ -1243,29 +1243,27 @@ class LaxRandomTest(jtu.JaxTestCase):
self.assertArraysAllClose(samples2, jnp.array([jnp.nan, 0., jnp.nan, jnp.nan]), check_dtypes=False)
self.assertArraysAllClose(samples3, jnp.array([jnp.nan, jnp.nan, jnp.nan]), check_dtypes=False)
def test_batched_key_warnings(self):
def test_batched_key_errors(self):
keys = lambda: jax.random.split(self.make_key(0))
msg = "{} accepts a single key, but was given a key array of shape.*"
# Check a handful of functions that are expected to warn.
with self.assertWarnsRegex(FutureWarning, msg.format('bits')):
# Check a handful of functions that are expected to error.
with self.assertRaisesRegex(ValueError, msg.format('bits')):
jax.random.bits(keys(), shape=(2,))
with self.assertWarnsRegex(FutureWarning, msg.format('chisquare')):
with self.assertRaisesRegex(ValueError, msg.format('chisquare')):
jax.random.chisquare(keys(), 1.0, shape=(2,))
with self.assertWarnsRegex(FutureWarning, msg.format('dirichlet')):
with self.assertRaisesRegex(ValueError, msg.format('dirichlet')):
jax.random.dirichlet(keys(), jnp.arange(2.0), shape=(2,))
with self.assertWarnsRegex(FutureWarning, msg.format('gamma')):
with self.assertRaisesRegex(ValueError, msg.format('gamma')):
jax.random.gamma(keys(), 1.0, shape=(2,))
with self.assertWarnsRegex(FutureWarning, msg.format('loggamma')):
with self.assertRaisesRegex(ValueError, msg.format('loggamma')):
jax.random.loggamma(keys(), 1.0, shape=(2,))
# Other functions should error; test a few cases.
with self.assertRaisesRegex(ValueError, msg.format('fold_in')):
jax.random.fold_in(keys(), 0)
with self.assertRaisesRegex(ValueError, msg.format('split')):
jax.random.split(keys())
# Some shouldn't error or warn
# Shouldn't error or warn:
with self.assertNoWarnings():
jax.random.key_data(keys())
jax.random.key_impl(keys())