diff --git a/CHANGELOG.md b/CHANGELOG.md index c7f1e0eab..46351308c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,6 +22,9 @@ Remember to align the itemized text with the first line of an item within a list deprecated and will soon be removed. Use `rtol` instead. * The deprecated `jax.config` submodule has been removed. To configure JAX use `import jax` and then reference the config object via `jax.config`. + * {mod}`jax.random` APIs no longer accept batched keys, where previously + some did unintentionally. Going forward, we recommend explicit use of + {func}`jax.vmap` in such cases. ## jaxlib 0.4.29 diff --git a/jax/_src/random.py b/jax/_src/random.py index fb5d0bf45..f0ea398e7 100644 --- a/jax/_src/random.py +++ b/jax/_src/random.py @@ -70,11 +70,8 @@ def _isnan(x: ArrayLike) -> Array: return lax.ne(x, x) -# TODO(jakevdp) Finalize batched input deprecation by setting error_on_batched=True. -# FutureWarning Added 2024-01-17 def _check_prng_key(name: str, key: KeyArrayLike, *, - allow_batched: bool = False, - error_on_batched: bool = False) -> tuple[KeyArray, bool]: + allow_batched: bool = False) -> tuple[KeyArray, bool]: if isinstance(key, Array) and dtypes.issubdtype(key.dtype, dtypes.prng_key): wrapped_key = key wrapped = False @@ -102,13 +99,8 @@ def _check_prng_key(name: str, key: KeyArrayLike, *, raise TypeError(f'unexpected PRNG key type {type(key)}') if (not allow_batched) and wrapped_key.ndim: - msg = (f"{name} accepts a single key, but was given a key array of " - f"shape {np.shape(key)} != (). Use jax.vmap for batching.") - if error_on_batched: - raise ValueError(msg) - else: - warnings.warn(msg + " In a future JAX version, this will be an error.", - FutureWarning, stacklevel=3) + raise ValueError(f"{name} accepts a single key, but was given a key array of" + f" shape {np.shape(key)} != (). Use jax.vmap for batching.") return wrapped_key, wrapped @@ -252,7 +244,7 @@ def fold_in(key: KeyArrayLike, data: IntegerArray) -> KeyArray: A new PRNG key that is a deterministic function of the inputs and is statistically safe for producing a stream of new pseudo-random values. """ - key, wrapped = _check_prng_key("fold_in", key, error_on_batched=True) + key, wrapped = _check_prng_key("fold_in", key) if np.ndim(data): raise TypeError("fold_in accepts a scalar, but was given an array of" f"shape {np.shape(data)} != (). Use jax.vmap for batching.") @@ -282,7 +274,7 @@ def split(key: KeyArrayLike, num: int | tuple[int, ...] = 2) -> KeyArray: Returns: An array-like object of `num` new PRNG keys. """ - typed_key, wrapped = _check_prng_key("split", key, error_on_batched=True) + typed_key, wrapped = _check_prng_key("split", key) return _return_prng_keys(wrapped, _split(typed_key, num)) diff --git a/tests/random_lax_test.py b/tests/random_lax_test.py index 39c4a0ef9..f69687ddc 100644 --- a/tests/random_lax_test.py +++ b/tests/random_lax_test.py @@ -1243,29 +1243,27 @@ class LaxRandomTest(jtu.JaxTestCase): self.assertArraysAllClose(samples2, jnp.array([jnp.nan, 0., jnp.nan, jnp.nan]), check_dtypes=False) self.assertArraysAllClose(samples3, jnp.array([jnp.nan, jnp.nan, jnp.nan]), check_dtypes=False) - def test_batched_key_warnings(self): + def test_batched_key_errors(self): keys = lambda: jax.random.split(self.make_key(0)) msg = "{} accepts a single key, but was given a key array of shape.*" - # Check a handful of functions that are expected to warn. - with self.assertWarnsRegex(FutureWarning, msg.format('bits')): + # Check a handful of functions that are expected to error. + with self.assertRaisesRegex(ValueError, msg.format('bits')): jax.random.bits(keys(), shape=(2,)) - with self.assertWarnsRegex(FutureWarning, msg.format('chisquare')): + with self.assertRaisesRegex(ValueError, msg.format('chisquare')): jax.random.chisquare(keys(), 1.0, shape=(2,)) - with self.assertWarnsRegex(FutureWarning, msg.format('dirichlet')): + with self.assertRaisesRegex(ValueError, msg.format('dirichlet')): jax.random.dirichlet(keys(), jnp.arange(2.0), shape=(2,)) - with self.assertWarnsRegex(FutureWarning, msg.format('gamma')): + with self.assertRaisesRegex(ValueError, msg.format('gamma')): jax.random.gamma(keys(), 1.0, shape=(2,)) - with self.assertWarnsRegex(FutureWarning, msg.format('loggamma')): + with self.assertRaisesRegex(ValueError, msg.format('loggamma')): jax.random.loggamma(keys(), 1.0, shape=(2,)) - - # Other functions should error; test a few cases. with self.assertRaisesRegex(ValueError, msg.format('fold_in')): jax.random.fold_in(keys(), 0) with self.assertRaisesRegex(ValueError, msg.format('split')): jax.random.split(keys()) - # Some shouldn't error or warn + # Shouldn't error or warn: with self.assertNoWarnings(): jax.random.key_data(keys()) jax.random.key_impl(keys())