jax.random: deprecate passing of batched keys to APIs

This commit is contained in:
Jake VanderPlas 2024-01-17 12:53:24 -08:00
parent 3e8067060e
commit 03ce8ca0ca
6 changed files with 96 additions and 52 deletions

View File

@ -81,6 +81,9 @@ Remember to align the itemized text with the first line of an item within a list
* Support for the mhlo MLIR dialect has been deprecated. JAX no longer uses
the mhlo dialect, in favor of stablehlo. APIs that refer to "mhlo" will be
removed in the future. Use the "stablehlo" dialect instead.
* {mod}`jax.random`: passing batched keys directly to random number generation functions,
such as {func}`~jax.random.bits`, {func}`~jax.random.gamma`, and others, is deprecated
and will emit a `FutureWarning`. Use `jax.vmap` for explicit batching.
## jaxlib 0.4.24

View File

@ -69,12 +69,18 @@ def _isnan(x: ArrayLike) -> Array:
return lax.ne(x, x)
def _check_prng_key(key: KeyArrayLike) -> tuple[KeyArray, bool]:
# TODO(jakevdp) Finalize batched input deprecation by setting error_on_batched=True.
# FutureWarning Added 2024-01-17
def _check_prng_key(name: str, key: KeyArrayLike, *,
allow_batched: bool = False,
error_on_batched: bool = False) -> tuple[KeyArray, bool]:
if isinstance(key, Array) and dtypes.issubdtype(key.dtype, dtypes.prng_key):
return key, False
wrapped_key = key
wrapped = False
elif _arraylike(key):
# Call random_wrap here to surface errors for invalid keys.
wrapped_key = prng.random_wrap(key, impl=default_prng_impl())
wrapped = True
if config.legacy_prng_key.value == 'error':
raise ValueError(
'Legacy uint32 key array passed as key to jax.random function. '
@ -91,10 +97,20 @@ def _check_prng_key(key: KeyArrayLike) -> tuple[KeyArray, bool]:
'Raw arrays as random keys to jax.random functions are deprecated. '
'Assuming valid threefry2x32 key for now.',
FutureWarning)
return wrapped_key, True
else:
raise TypeError(f'unexpected PRNG key type {type(key)}')
if (not allow_batched) and wrapped_key.ndim:
msg = (f"{name} accepts a single key, but was given a key array of "
f"shape {np.shape(key)} != (). Use jax.vmap for batching.")
if error_on_batched:
raise ValueError(msg)
else:
warnings.warn(msg + " In a future JAX version, this will be an error.",
FutureWarning, stacklevel=3)
return wrapped_key, wrapped
def _return_prng_keys(was_wrapped, key):
# TODO(frostig): remove once we always enable_custom_prng
@ -245,10 +261,7 @@ def fold_in(key: KeyArrayLike, data: IntegerArray) -> KeyArray:
A new PRNG key that is a deterministic function of the inputs and is
statistically safe for producing a stream of new pseudo-random values.
"""
key, wrapped = _check_prng_key(key)
if np.ndim(key):
raise TypeError("fold_in accepts a single key, but was given a key array of"
f"shape {np.shape(key)} != (). Use jax.vmap for batching.")
key, wrapped = _check_prng_key("fold_in", key, error_on_batched=True)
if np.ndim(data):
raise TypeError("fold_in accepts a scalar, but was given an array of"
f"shape {np.shape(data)} != (). Use jax.vmap for batching.")
@ -262,7 +275,7 @@ def _split(key: KeyArray, num: int | tuple[int, ...] = 2) -> KeyArray:
# to always enable_custom_prng
assert jnp.issubdtype(key.dtype, dtypes.prng_key)
if key.ndim:
raise TypeError("split accepts a single key, but was given a key array of"
raise TypeError("split accepts a single key, but was given a key array of "
f"shape {key.shape} != (). Use jax.vmap for batching.")
shape = tuple(num) if isinstance(num, Sequence) else (num,)
return prng.random_split(key, shape=shape)
@ -278,7 +291,7 @@ def split(key: KeyArrayLike, num: int | tuple[int, ...] = 2) -> KeyArray:
Returns:
An array-like object of `num` new PRNG keys.
"""
typed_key, wrapped = _check_prng_key(key)
typed_key, wrapped = _check_prng_key("split", key, error_on_batched=True)
return _return_prng_keys(wrapped, _split(typed_key, num))
@ -288,7 +301,7 @@ def _key_impl(keys: KeyArray) -> PRNGImpl:
return keys_dtype._impl
def key_impl(keys: KeyArrayLike) -> Hashable:
typed_keys, _ = _check_prng_key(keys)
typed_keys, _ = _check_prng_key("key_impl", keys, allow_batched=True)
return PRNGSpec(_key_impl(typed_keys))
@ -298,7 +311,7 @@ def _key_data(keys: KeyArray) -> Array:
def key_data(keys: KeyArrayLike) -> Array:
"""Recover the bits of key data underlying a PRNG key array."""
keys, _ = _check_prng_key(keys)
keys, _ = _check_prng_key("key_data", keys, allow_batched=True)
return _key_data(keys)
@ -350,7 +363,7 @@ def bits(key: KeyArrayLike,
Returns:
A random array with the specified shape and dtype.
"""
key, _ = _check_prng_key(key)
key, _ = _check_prng_key("bits", key)
if dtype is None:
dtype = dtypes.canonicalize_dtype(jnp.uint)
else:
@ -383,7 +396,7 @@ def uniform(key: KeyArrayLike,
Returns:
A random array with the specified shape and dtype.
"""
key, _ = _check_prng_key(key)
key, _ = _check_prng_key("uniform", key)
dtypes.check_user_dtype_supported(dtype)
if not dtypes.issubdtype(dtype, np.floating):
@ -452,7 +465,7 @@ def randint(key: KeyArrayLike,
Returns:
A random array with the specified shape and dtype.
"""
key, _ = _check_prng_key(key)
key, _ = _check_prng_key("randint", key)
dtypes.check_user_dtype_supported(dtype)
dtype = dtypes.canonicalize_dtype(dtype)
shape = core.canonicalize_shape(shape)
@ -535,7 +548,7 @@ def shuffle(key: KeyArrayLike, x: ArrayLike, axis: int = 0) -> Array:
msg = ("jax.random.shuffle is deprecated and will be removed in a future release. "
"Use jax.random.permutation with independent=True.")
warnings.warn(msg, FutureWarning)
key, _ = _check_prng_key(key)
key, _ = _check_prng_key("shuffle", key)
return _shuffle(key, x, axis) # type: ignore
@ -556,7 +569,7 @@ def permutation(key: KeyArrayLike,
Returns:
A shuffled version of x or array range
"""
key, _ = _check_prng_key(key)
key, _ = _check_prng_key("permutation", key)
check_arraylike("permutation", x)
axis = canonicalize_axis(axis, np.ndim(x) or 1)
if not np.ndim(x):
@ -630,7 +643,7 @@ def choice(key: KeyArrayLike,
Returns:
An array of shape `shape` containing samples from `a`.
"""
key, _ = _check_prng_key(key)
key, _ = _check_prng_key("choice", key)
if not isinstance(shape, Sequence):
raise TypeError("shape argument of jax.random.choice must be a sequence, "
f"got {shape}")
@ -697,7 +710,7 @@ def normal(key: KeyArrayLike,
Returns:
A random array with the specified shape and dtype.
"""
key, _ = _check_prng_key(key)
key, _ = _check_prng_key("normal", key)
dtypes.check_user_dtype_supported(dtype)
if not dtypes.issubdtype(dtype, np.inexact):
raise ValueError(f"dtype argument to `normal` must be a float or complex dtype, "
@ -764,7 +777,7 @@ def multivariate_normal(key: KeyArrayLike,
``shape + mean.shape[-1:]`` if ``shape`` is not None, or else
``broadcast_shapes(mean.shape[:-1], cov.shape[:-2]) + mean.shape[-1:]``.
"""
key, _ = _check_prng_key(key)
key, _ = _check_prng_key("multivariate_normal", key)
dtypes.check_user_dtype_supported(dtype)
mean, cov = promote_dtypes_inexact(mean, cov)
if method not in {'svd', 'eigh', 'cholesky'}:
@ -843,7 +856,7 @@ def truncated_normal(key: KeyArrayLike,
``shape`` is not None, or else by broadcasting ``lower`` and ``upper``.
Returns values in the open interval ``(lower, upper)``.
"""
key, _ = _check_prng_key(key)
key, _ = _check_prng_key("truncated_normal", key)
dtypes.check_user_dtype_supported(dtype)
if not dtypes.issubdtype(dtype, np.floating):
raise ValueError(f"dtype argument to `truncated_normal` must be a float "
@ -901,7 +914,7 @@ def bernoulli(key: KeyArrayLike,
A random array with boolean dtype and shape given by ``shape`` if ``shape``
is not None, or else ``p.shape``.
"""
key, _ = _check_prng_key(key)
key, _ = _check_prng_key("bernoulli", key)
dtype = dtypes.canonicalize_dtype(lax.dtype(p))
if shape is not None:
shape = core.as_named_shape(shape)
@ -952,7 +965,7 @@ def beta(key: KeyArrayLike,
A random array with the specified dtype and shape given by ``shape`` if
``shape`` is not None, or else by broadcasting ``a`` and ``b``.
"""
key, _ = _check_prng_key(key)
key, _ = _check_prng_key("beta", key)
dtypes.check_user_dtype_supported(dtype)
if not dtypes.issubdtype(dtype, np.floating):
raise ValueError(f"dtype argument to `beta` must be a float "
@ -1005,7 +1018,7 @@ def cauchy(key: KeyArrayLike,
Returns:
A random array with the specified shape and dtype.
"""
key, _ = _check_prng_key(key)
key, _ = _check_prng_key("cauchy", key)
dtypes.check_user_dtype_supported(dtype)
if not dtypes.issubdtype(dtype, np.floating):
raise ValueError(f"dtype argument to `cauchy` must be a float "
@ -1057,7 +1070,7 @@ def dirichlet(key: KeyArrayLike,
``shape + (alpha.shape[-1],)`` if ``shape`` is not None, or else
``alpha.shape``.
"""
key, _ = _check_prng_key(key)
key, _ = _check_prng_key("dirichlet", key)
dtypes.check_user_dtype_supported(dtype)
if not dtypes.issubdtype(dtype, np.floating):
raise ValueError(f"dtype argument to `dirichlet` must be a float "
@ -1116,7 +1129,7 @@ def exponential(key: KeyArrayLike,
Returns:
A random array with the specified shape and dtype.
"""
key, _ = _check_prng_key(key)
key, _ = _check_prng_key("exponential", key)
dtypes.check_user_dtype_supported(dtype)
if not dtypes.issubdtype(dtype, np.floating):
raise ValueError(f"dtype argument to `exponential` must be a float "
@ -1297,7 +1310,7 @@ def gamma(key: KeyArrayLike,
loggamma : sample gamma values in log-space, which can provide improved
accuracy for small values of ``a``.
"""
key, _ = _check_prng_key(key)
key, _ = _check_prng_key("gamma", key)
dtypes.check_user_dtype_supported(dtype)
if not dtypes.issubdtype(dtype, np.floating):
raise ValueError(f"dtype argument to `gamma` must be a float "
@ -1339,7 +1352,7 @@ def loggamma(key: KeyArrayLike,
See Also:
gamma : standard gamma sampler.
"""
key, _ = _check_prng_key(key)
key, _ = _check_prng_key("loggamma", key)
dtypes.check_user_dtype_supported(dtype)
if not dtypes.issubdtype(dtype, np.floating):
raise ValueError(f"dtype argument to `gamma` must be a float "
@ -1475,7 +1488,7 @@ def poisson(key: KeyArrayLike,
A random array with the specified dtype and with shape given by ``shape`` if
``shape`` is not None, or else by ``lam.shape``.
"""
key, _ = _check_prng_key(key)
key, _ = _check_prng_key("poisson", key)
dtypes.check_user_dtype_supported(dtype)
# TODO(frostig): generalize underlying poisson implementation and
# remove this check
@ -1515,7 +1528,7 @@ def gumbel(key: KeyArrayLike,
Returns:
A random array with the specified shape and dtype.
"""
key, _ = _check_prng_key(key)
key, _ = _check_prng_key("gumbel", key)
dtypes.check_user_dtype_supported(dtype)
if not dtypes.issubdtype(dtype, np.floating):
raise ValueError(f"dtype argument to `gumbel` must be a float "
@ -1550,7 +1563,7 @@ def categorical(key: KeyArrayLike,
A random array with int dtype and shape given by ``shape`` if ``shape``
is not None, or else ``np.delete(logits.shape, axis)``.
"""
key, _ = _check_prng_key(key)
key, _ = _check_prng_key("categorical", key)
check_arraylike("categorical", logits)
logits_arr = jnp.asarray(logits)
@ -1593,7 +1606,7 @@ def laplace(key: KeyArrayLike,
Returns:
A random array with the specified shape and dtype.
"""
key, _ = _check_prng_key(key)
key, _ = _check_prng_key("laplace", key)
dtypes.check_user_dtype_supported(dtype)
if not dtypes.issubdtype(dtype, np.floating):
raise ValueError(f"dtype argument to `laplace` must be a float "
@ -1630,7 +1643,7 @@ def logistic(key: KeyArrayLike,
Returns:
A random array with the specified shape and dtype.
"""
key, _ = _check_prng_key(key)
key, _ = _check_prng_key("logistic", key)
dtypes.check_user_dtype_supported(dtype)
if not dtypes.issubdtype(dtype, np.floating):
raise ValueError(f"dtype argument to `logistic` must be a float "
@ -1673,7 +1686,7 @@ def pareto(key: KeyArrayLike,
A random array with the specified dtype and with shape given by ``shape`` if
``shape`` is not None, or else by ``b.shape``.
"""
key, _ = _check_prng_key(key)
key, _ = _check_prng_key("pareto", key)
dtypes.check_user_dtype_supported(dtype)
if not dtypes.issubdtype(dtype, np.floating):
raise ValueError(f"dtype argument to `pareto` must be a float "
@ -1722,7 +1735,7 @@ def t(key: KeyArrayLike,
A random array with the specified dtype and with shape given by ``shape`` if
``shape`` is not None, or else by ``df.shape``.
"""
key, _ = _check_prng_key(key)
key, _ = _check_prng_key("t", key)
dtypes.check_user_dtype_supported(dtype)
if not dtypes.issubdtype(dtype, np.floating):
raise ValueError(f"dtype argument to `t` must be a float "
@ -1775,7 +1788,7 @@ def chisquare(key: KeyArrayLike,
A random array with the specified dtype and with shape given by ``shape`` if
``shape`` is not None, or else by ``df.shape``.
"""
key, _ = _check_prng_key(key)
key, _ = _check_prng_key("chisquare", key)
dtypes.check_user_dtype_supported(dtype)
if not dtypes.issubdtype(dtype, np.floating):
raise ValueError("dtype argument to `chisquare` must be a float "
@ -1833,7 +1846,7 @@ def f(key: KeyArrayLike,
A random array with the specified dtype and with shape given by ``shape`` if
``shape`` is not None, or else by ``df.shape``.
"""
key, _ = _check_prng_key(key)
key, _ = _check_prng_key("f", key)
dtypes.check_user_dtype_supported(dtype)
if not dtypes.issubdtype(dtype, np.floating):
raise ValueError("dtype argument to `f` must be a float "
@ -1885,7 +1898,7 @@ def rademacher(key: KeyArrayLike,
a 50% chance of being 1 or -1.
"""
key, _ = _check_prng_key(key)
key, _ = _check_prng_key("rademacher", key)
dtypes.check_user_dtype_supported(dtype)
dtype = dtypes.canonicalize_dtype(dtype)
shape = core.canonicalize_shape(shape)
@ -1921,7 +1934,7 @@ def maxwell(key: KeyArrayLike,
"""
# Generate samples using:
# sqrt(X^2 + Y^2 + Z^2), X,Y,Z ~N(0,1)
key, _ = _check_prng_key(key)
key, _ = _check_prng_key("maxwell", key)
dtypes.check_user_dtype_supported(dtype)
if not dtypes.issubdtype(dtype, np.floating):
raise ValueError(f"dtype argument to `maxwell` must be a float "
@ -1964,7 +1977,7 @@ def double_sided_maxwell(key: KeyArrayLike,
A jnp.array of samples.
"""
key, _ = _check_prng_key(key)
key, _ = _check_prng_key("double_sided_maxwell", key)
dtypes.check_user_dtype_supported(dtype)
if not dtypes.issubdtype(dtype, np.floating):
raise ValueError(f"dtype argument to `double_sided_maxwell` must be a float"
@ -2016,7 +2029,7 @@ def weibull_min(key: KeyArrayLike,
A jnp.array of samples.
"""
key, _ = _check_prng_key(key)
key, _ = _check_prng_key("weibull_min", key)
dtypes.check_user_dtype_supported(dtype)
if not dtypes.issubdtype(dtype, np.floating):
raise ValueError(f"dtype argument to `weibull_min` must be a float "
@ -2055,7 +2068,7 @@ def orthogonal(
Returns:
A random array of shape `(*shape, n, n)` and specified dtype.
"""
key, _ = _check_prng_key(key)
key, _ = _check_prng_key("orthogonal", key)
dtypes.check_user_dtype_supported(dtype)
_check_shape("orthogonal", shape)
n = core.concrete_or_error(index, n, "The error occurred in jax.random.orthogonal()")
@ -2090,7 +2103,7 @@ def generalized_normal(
Returns:
A random array with the specified shape and dtype.
"""
key, _ = _check_prng_key(key)
key, _ = _check_prng_key("generalized_normal", key)
dtypes.check_user_dtype_supported(dtype)
_check_shape("generalized_normal", shape)
keys = split(key)
@ -2120,7 +2133,7 @@ def ball(
Returns:
A random array of shape `(*shape, d)` and specified dtype.
"""
key, _ = _check_prng_key(key)
key, _ = _check_prng_key("ball", key)
dtypes.check_user_dtype_supported(dtype)
_check_shape("ball", shape)
d = core.concrete_or_error(index, d, "The error occurred in jax.random.ball()")
@ -2158,7 +2171,7 @@ def rayleigh(key: KeyArrayLike,
A random array with the specified dtype and with shape given by ``shape`` if
``shape`` is not None, or else by ``scale.shape``.
"""
key, _ = _check_prng_key(key)
key, _ = _check_prng_key("rayleigh", key)
dtypes.check_user_dtype_supported(dtype)
if not dtypes.issubdtype(dtype, np.floating):
raise ValueError("dtype argument to `rayleigh` must be a float "
@ -2212,7 +2225,7 @@ def wald(key: KeyArrayLike,
A random array with the specified dtype and with shape given by ``shape`` if
``shape`` is not None, or else by ``mean.shape``.
"""
key, _ = _check_prng_key(key)
key, _ = _check_prng_key("wald", key)
dtypes.check_user_dtype_supported(dtype)
if not dtypes.issubdtype(dtype, np.floating):
raise ValueError("dtype argument to `wald` must be a float "
@ -2268,7 +2281,7 @@ def geometric(key: KeyArrayLike,
A random array with the specified dtype and with shape given by ``shape`` if
``shape`` is not None, or else by ``p.shape``.
"""
key, _ = _check_prng_key(key)
key, _ = _check_prng_key("geometric", key)
dtypes.check_user_dtype_supported(dtype)
if not dtypes.issubdtype(dtype, np.integer):
raise ValueError("dtype argument to `geometric` must be an int "
@ -2330,7 +2343,7 @@ def triangular(key: KeyArrayLike,
A random array with the specified dtype and with shape given by ``shape`` if
``shape`` is not None, or else by ``left.shape``, ``mode.shape`` and ``right.shape``.
"""
key, _ = _check_prng_key(key)
key, _ = _check_prng_key("triangular", key)
dtypes.check_user_dtype_supported(dtype)
if not dtypes.issubdtype(dtype, np.floating):
raise ValueError("dtype argument to `triangular` must be a float "
@ -2384,7 +2397,7 @@ def lognormal(key: KeyArrayLike,
Returns:
A random array with the specified dtype and with shape given by ``shape``.
"""
key, _ = _check_prng_key(key)
key, _ = _check_prng_key("lognormal", key)
dtypes.check_user_dtype_supported(dtype)
if not dtypes.issubdtype(dtype, np.inexact):
raise ValueError(f"dtype argument to `lognormal` must be a float or complex dtype, "
@ -2597,7 +2610,7 @@ def binomial(
A random array with the specified dtype and with shape given by
``np.broadcast(n, p).shape``.
"""
key, _ = _check_prng_key(key)
key, _ = _check_prng_key("binomial", key)
check_arraylike("binomial", n, p)
dtypes.check_user_dtype_supported(dtype)
if not dtypes.issubdtype(dtype, np.floating):

View File

@ -2108,7 +2108,7 @@ _POLY_SHAPE_TEST_HARNESSES = [
# non-partitionable), and unsafe_rbg.
[
PolyHarness("random_gamma", f"{flags_name}",
lambda key, a: jax.random.gamma(key, a),
lambda key, a: jax.vmap(jax.random.gamma)(key, a),
arg_descriptors=[RandArg((3, key_size), np.uint32), RandArg((3, 4, 5), _f32)],
polymorphic_shapes=["b, ...", "b, w, ..."], tol=1E-5,
override_jax_config_flags=override_jax_config_flags), # type: ignore

View File

@ -616,7 +616,7 @@ class KeyReuseIntegrationTest(jtu.JaxTestCase):
def f():
key = jax.random.key(0)
key2 = key[None]
return jax.random.bits(key) + jax.random.bits(key2)
return jax.random.bits(key) + jax.vmap(jax.random.bits)(key2)
with self.assertRaisesRegex(KeyReuseError, self.random_bits_error):
self.check_key_reuse(f)

View File

@ -1247,6 +1247,34 @@ class LaxRandomTest(jtu.JaxTestCase):
self.assertArraysAllClose(samples2, jnp.array([jnp.nan, 0., jnp.nan, jnp.nan]), check_dtypes=False)
self.assertArraysAllClose(samples3, jnp.array([jnp.nan, jnp.nan, jnp.nan]), check_dtypes=False)
def test_batched_key_warnings(self):
  """Check how jax.random APIs react to a batched (non-scalar) key array.

  Three behaviors are exercised:
    * a sample of random-number generators must emit a ``FutureWarning``;
    * ``fold_in`` and ``split`` must raise ``ValueError``;
    * ``key_data`` and ``key_impl`` must stay silent.
  """
  # split() returns an array of new keys, which serves as the batched input.
  # NOTE(review): make_key is a test-class helper not visible here; presumably
  # it returns a single PRNG key for the given seed -- confirm.
  keys = jax.random.split(self.make_key(0))
  # Regex template; `{}` is filled with the name of the function under test.
  msg = "{} accepts a single key, but was given a key array of shape.*"

  # Check a handful of functions that are expected to warn.
  with self.assertWarnsRegex(FutureWarning, msg.format('bits')):
    jax.random.bits(keys, shape=(2,))
  with self.assertWarnsRegex(FutureWarning, msg.format('chisquare')):
    jax.random.chisquare(keys, 1.0, shape=(2,))
  with self.assertWarnsRegex(FutureWarning, msg.format('dirichlet')):
    jax.random.dirichlet(keys, jnp.arange(2.0), shape=(2,))
  with self.assertWarnsRegex(FutureWarning, msg.format('gamma')):
    jax.random.gamma(keys, 1.0, shape=(2,))
  with self.assertWarnsRegex(FutureWarning, msg.format('loggamma')):
    jax.random.loggamma(keys, 1.0, shape=(2,))

  # Other functions should error; test a few cases.
  with self.assertRaisesRegex(ValueError, msg.format('fold_in')):
    jax.random.fold_in(keys, 0)
  with self.assertRaisesRegex(ValueError, msg.format('split')):
    jax.random.split(keys)

  # Some shouldn't error or warn
  with self.assertNoWarnings():
    jax.random.key_data(keys)
    jax.random.key_impl(keys)
threefry_seed = prng_internal.threefry_seed
threefry_split = prng_internal.threefry_split
threefry_random_bits = prng_internal.threefry_random_bits

View File

@ -1917,7 +1917,7 @@ _POLY_SHAPE_TEST_HARNESSES = [
# non-partitionable), and unsafe_rbg.
[
PolyHarness("random_gamma", f"{flags_name}",
lambda key, a: jax.random.gamma(
lambda key, a: jax.vmap(jax.random.gamma)(
jax.random.wrap_key_data(key), a),
arg_descriptors=[RandArg((3, key_size), np.uint32),
RandArg((3, 4, 5), _f32)],