mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
jax.random: deprecate passing of batched keys to APIs
This commit is contained in:
parent
3e8067060e
commit
03ce8ca0ca
@ -81,6 +81,9 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
* Support for the mhlo MLIR dialect has been deprecated. JAX no longer uses
|
||||
the mhlo dialect, in favor of stablehlo. APIs that refer to "mhlo" will be
|
||||
removed in the future. Use the "stablehlo" dialect instead.
|
||||
* {mod}`jax.random`: passing batched keys directly to random number generation functions,
|
||||
such as {func}`~jax.random.bits`, {func}`~jax.random.gamma`, and others, is deprecated
|
||||
and will emit a `FutureWarning`. Use `jax.vmap` for explicit batching.
|
||||
|
||||
## jaxlib 0.4.24
|
||||
|
||||
|
@ -69,12 +69,18 @@ def _isnan(x: ArrayLike) -> Array:
|
||||
return lax.ne(x, x)
|
||||
|
||||
|
||||
def _check_prng_key(key: KeyArrayLike) -> tuple[KeyArray, bool]:
|
||||
# TODO(jakevdp) Finalize batched input deprecation by setting error_on_batched=True.
|
||||
# FutureWarning Added 2024-01-17
|
||||
def _check_prng_key(name: str, key: KeyArrayLike, *,
|
||||
allow_batched: bool = False,
|
||||
error_on_batched: bool = False) -> tuple[KeyArray, bool]:
|
||||
if isinstance(key, Array) and dtypes.issubdtype(key.dtype, dtypes.prng_key):
|
||||
return key, False
|
||||
wrapped_key = key
|
||||
wrapped = False
|
||||
elif _arraylike(key):
|
||||
# Call random_wrap here to surface errors for invalid keys.
|
||||
wrapped_key = prng.random_wrap(key, impl=default_prng_impl())
|
||||
wrapped = True
|
||||
if config.legacy_prng_key.value == 'error':
|
||||
raise ValueError(
|
||||
'Legacy uint32 key array passed as key to jax.random function. '
|
||||
@ -91,10 +97,20 @@ def _check_prng_key(key: KeyArrayLike) -> tuple[KeyArray, bool]:
|
||||
'Raw arrays as random keys to jax.random functions are deprecated. '
|
||||
'Assuming valid threefry2x32 key for now.',
|
||||
FutureWarning)
|
||||
return wrapped_key, True
|
||||
else:
|
||||
raise TypeError(f'unexpected PRNG key type {type(key)}')
|
||||
|
||||
if (not allow_batched) and wrapped_key.ndim:
|
||||
msg = (f"{name} accepts a single key, but was given a key array of "
|
||||
f"shape {np.shape(key)} != (). Use jax.vmap for batching.")
|
||||
if error_on_batched:
|
||||
raise ValueError(msg)
|
||||
else:
|
||||
warnings.warn(msg + " In a future JAX version, this will be an error.",
|
||||
FutureWarning, stacklevel=3)
|
||||
|
||||
return wrapped_key, wrapped
|
||||
|
||||
|
||||
def _return_prng_keys(was_wrapped, key):
|
||||
# TODO(frostig): remove once we always enable_custom_prng
|
||||
@ -245,10 +261,7 @@ def fold_in(key: KeyArrayLike, data: IntegerArray) -> KeyArray:
|
||||
A new PRNG key that is a deterministic function of the inputs and is
|
||||
statistically safe for producing a stream of new pseudo-random values.
|
||||
"""
|
||||
key, wrapped = _check_prng_key(key)
|
||||
if np.ndim(key):
|
||||
raise TypeError("fold_in accepts a single key, but was given a key array of"
|
||||
f"shape {np.shape(key)} != (). Use jax.vmap for batching.")
|
||||
key, wrapped = _check_prng_key("fold_in", key, error_on_batched=True)
|
||||
if np.ndim(data):
|
||||
raise TypeError("fold_in accepts a scalar, but was given an array of"
|
||||
f"shape {np.shape(data)} != (). Use jax.vmap for batching.")
|
||||
@ -262,7 +275,7 @@ def _split(key: KeyArray, num: int | tuple[int, ...] = 2) -> KeyArray:
|
||||
# to always enable_custom_prng
|
||||
assert jnp.issubdtype(key.dtype, dtypes.prng_key)
|
||||
if key.ndim:
|
||||
raise TypeError("split accepts a single key, but was given a key array of"
|
||||
raise TypeError("split accepts a single key, but was given a key array of "
|
||||
f"shape {key.shape} != (). Use jax.vmap for batching.")
|
||||
shape = tuple(num) if isinstance(num, Sequence) else (num,)
|
||||
return prng.random_split(key, shape=shape)
|
||||
@ -278,7 +291,7 @@ def split(key: KeyArrayLike, num: int | tuple[int, ...] = 2) -> KeyArray:
|
||||
Returns:
|
||||
An array-like object of `num` new PRNG keys.
|
||||
"""
|
||||
typed_key, wrapped = _check_prng_key(key)
|
||||
typed_key, wrapped = _check_prng_key("split", key, error_on_batched=True)
|
||||
return _return_prng_keys(wrapped, _split(typed_key, num))
|
||||
|
||||
|
||||
@ -288,7 +301,7 @@ def _key_impl(keys: KeyArray) -> PRNGImpl:
|
||||
return keys_dtype._impl
|
||||
|
||||
def key_impl(keys: KeyArrayLike) -> Hashable:
|
||||
typed_keys, _ = _check_prng_key(keys)
|
||||
typed_keys, _ = _check_prng_key("key_impl", keys, allow_batched=True)
|
||||
return PRNGSpec(_key_impl(typed_keys))
|
||||
|
||||
|
||||
@ -298,7 +311,7 @@ def _key_data(keys: KeyArray) -> Array:
|
||||
|
||||
def key_data(keys: KeyArrayLike) -> Array:
|
||||
"""Recover the bits of key data underlying a PRNG key array."""
|
||||
keys, _ = _check_prng_key(keys)
|
||||
keys, _ = _check_prng_key("key_data", keys, allow_batched=True)
|
||||
return _key_data(keys)
|
||||
|
||||
|
||||
@ -350,7 +363,7 @@ def bits(key: KeyArrayLike,
|
||||
Returns:
|
||||
A random array with the specified shape and dtype.
|
||||
"""
|
||||
key, _ = _check_prng_key(key)
|
||||
key, _ = _check_prng_key("bits", key)
|
||||
if dtype is None:
|
||||
dtype = dtypes.canonicalize_dtype(jnp.uint)
|
||||
else:
|
||||
@ -383,7 +396,7 @@ def uniform(key: KeyArrayLike,
|
||||
Returns:
|
||||
A random array with the specified shape and dtype.
|
||||
"""
|
||||
key, _ = _check_prng_key(key)
|
||||
key, _ = _check_prng_key("uniform", key)
|
||||
dtypes.check_user_dtype_supported(dtype)
|
||||
|
||||
if not dtypes.issubdtype(dtype, np.floating):
|
||||
@ -452,7 +465,7 @@ def randint(key: KeyArrayLike,
|
||||
Returns:
|
||||
A random array with the specified shape and dtype.
|
||||
"""
|
||||
key, _ = _check_prng_key(key)
|
||||
key, _ = _check_prng_key("randint", key)
|
||||
dtypes.check_user_dtype_supported(dtype)
|
||||
dtype = dtypes.canonicalize_dtype(dtype)
|
||||
shape = core.canonicalize_shape(shape)
|
||||
@ -535,7 +548,7 @@ def shuffle(key: KeyArrayLike, x: ArrayLike, axis: int = 0) -> Array:
|
||||
msg = ("jax.random.shuffle is deprecated and will be removed in a future release. "
|
||||
"Use jax.random.permutation with independent=True.")
|
||||
warnings.warn(msg, FutureWarning)
|
||||
key, _ = _check_prng_key(key)
|
||||
key, _ = _check_prng_key("shuffle", key)
|
||||
return _shuffle(key, x, axis) # type: ignore
|
||||
|
||||
|
||||
@ -556,7 +569,7 @@ def permutation(key: KeyArrayLike,
|
||||
Returns:
|
||||
A shuffled version of x or array range
|
||||
"""
|
||||
key, _ = _check_prng_key(key)
|
||||
key, _ = _check_prng_key("permutation", key)
|
||||
check_arraylike("permutation", x)
|
||||
axis = canonicalize_axis(axis, np.ndim(x) or 1)
|
||||
if not np.ndim(x):
|
||||
@ -630,7 +643,7 @@ def choice(key: KeyArrayLike,
|
||||
Returns:
|
||||
An array of shape `shape` containing samples from `a`.
|
||||
"""
|
||||
key, _ = _check_prng_key(key)
|
||||
key, _ = _check_prng_key("choice", key)
|
||||
if not isinstance(shape, Sequence):
|
||||
raise TypeError("shape argument of jax.random.choice must be a sequence, "
|
||||
f"got {shape}")
|
||||
@ -697,7 +710,7 @@ def normal(key: KeyArrayLike,
|
||||
Returns:
|
||||
A random array with the specified shape and dtype.
|
||||
"""
|
||||
key, _ = _check_prng_key(key)
|
||||
key, _ = _check_prng_key("normal", key)
|
||||
dtypes.check_user_dtype_supported(dtype)
|
||||
if not dtypes.issubdtype(dtype, np.inexact):
|
||||
raise ValueError(f"dtype argument to `normal` must be a float or complex dtype, "
|
||||
@ -764,7 +777,7 @@ def multivariate_normal(key: KeyArrayLike,
|
||||
``shape + mean.shape[-1:]`` if ``shape`` is not None, or else
|
||||
``broadcast_shapes(mean.shape[:-1], cov.shape[:-2]) + mean.shape[-1:]``.
|
||||
"""
|
||||
key, _ = _check_prng_key(key)
|
||||
key, _ = _check_prng_key("multivariate_normal", key)
|
||||
dtypes.check_user_dtype_supported(dtype)
|
||||
mean, cov = promote_dtypes_inexact(mean, cov)
|
||||
if method not in {'svd', 'eigh', 'cholesky'}:
|
||||
@ -843,7 +856,7 @@ def truncated_normal(key: KeyArrayLike,
|
||||
``shape`` is not None, or else by broadcasting ``lower`` and ``upper``.
|
||||
Returns values in the open interval ``(lower, upper)``.
|
||||
"""
|
||||
key, _ = _check_prng_key(key)
|
||||
key, _ = _check_prng_key("truncated_normal", key)
|
||||
dtypes.check_user_dtype_supported(dtype)
|
||||
if not dtypes.issubdtype(dtype, np.floating):
|
||||
raise ValueError(f"dtype argument to `truncated_normal` must be a float "
|
||||
@ -901,7 +914,7 @@ def bernoulli(key: KeyArrayLike,
|
||||
A random array with boolean dtype and shape given by ``shape`` if ``shape``
|
||||
is not None, or else ``p.shape``.
|
||||
"""
|
||||
key, _ = _check_prng_key(key)
|
||||
key, _ = _check_prng_key("bernoulli", key)
|
||||
dtype = dtypes.canonicalize_dtype(lax.dtype(p))
|
||||
if shape is not None:
|
||||
shape = core.as_named_shape(shape)
|
||||
@ -952,7 +965,7 @@ def beta(key: KeyArrayLike,
|
||||
A random array with the specified dtype and shape given by ``shape`` if
|
||||
``shape`` is not None, or else by broadcasting ``a`` and ``b``.
|
||||
"""
|
||||
key, _ = _check_prng_key(key)
|
||||
key, _ = _check_prng_key("beta", key)
|
||||
dtypes.check_user_dtype_supported(dtype)
|
||||
if not dtypes.issubdtype(dtype, np.floating):
|
||||
raise ValueError(f"dtype argument to `beta` must be a float "
|
||||
@ -1005,7 +1018,7 @@ def cauchy(key: KeyArrayLike,
|
||||
Returns:
|
||||
A random array with the specified shape and dtype.
|
||||
"""
|
||||
key, _ = _check_prng_key(key)
|
||||
key, _ = _check_prng_key("cauchy", key)
|
||||
dtypes.check_user_dtype_supported(dtype)
|
||||
if not dtypes.issubdtype(dtype, np.floating):
|
||||
raise ValueError(f"dtype argument to `cauchy` must be a float "
|
||||
@ -1057,7 +1070,7 @@ def dirichlet(key: KeyArrayLike,
|
||||
``shape + (alpha.shape[-1],)`` if ``shape`` is not None, or else
|
||||
``alpha.shape``.
|
||||
"""
|
||||
key, _ = _check_prng_key(key)
|
||||
key, _ = _check_prng_key("dirichlet", key)
|
||||
dtypes.check_user_dtype_supported(dtype)
|
||||
if not dtypes.issubdtype(dtype, np.floating):
|
||||
raise ValueError(f"dtype argument to `dirichlet` must be a float "
|
||||
@ -1116,7 +1129,7 @@ def exponential(key: KeyArrayLike,
|
||||
Returns:
|
||||
A random array with the specified shape and dtype.
|
||||
"""
|
||||
key, _ = _check_prng_key(key)
|
||||
key, _ = _check_prng_key("exponential", key)
|
||||
dtypes.check_user_dtype_supported(dtype)
|
||||
if not dtypes.issubdtype(dtype, np.floating):
|
||||
raise ValueError(f"dtype argument to `exponential` must be a float "
|
||||
@ -1297,7 +1310,7 @@ def gamma(key: KeyArrayLike,
|
||||
loggamma : sample gamma values in log-space, which can provide improved
|
||||
accuracy for small values of ``a``.
|
||||
"""
|
||||
key, _ = _check_prng_key(key)
|
||||
key, _ = _check_prng_key("gamma", key)
|
||||
dtypes.check_user_dtype_supported(dtype)
|
||||
if not dtypes.issubdtype(dtype, np.floating):
|
||||
raise ValueError(f"dtype argument to `gamma` must be a float "
|
||||
@ -1339,7 +1352,7 @@ def loggamma(key: KeyArrayLike,
|
||||
See Also:
|
||||
gamma : standard gamma sampler.
|
||||
"""
|
||||
key, _ = _check_prng_key(key)
|
||||
key, _ = _check_prng_key("loggamma", key)
|
||||
dtypes.check_user_dtype_supported(dtype)
|
||||
if not dtypes.issubdtype(dtype, np.floating):
|
||||
raise ValueError(f"dtype argument to `gamma` must be a float "
|
||||
@ -1475,7 +1488,7 @@ def poisson(key: KeyArrayLike,
|
||||
A random array with the specified dtype and with shape given by ``shape`` if
|
||||
``shape is not None, or else by ``lam.shape``.
|
||||
"""
|
||||
key, _ = _check_prng_key(key)
|
||||
key, _ = _check_prng_key("poisson", key)
|
||||
dtypes.check_user_dtype_supported(dtype)
|
||||
# TODO(frostig): generalize underlying poisson implementation and
|
||||
# remove this check
|
||||
@ -1515,7 +1528,7 @@ def gumbel(key: KeyArrayLike,
|
||||
Returns:
|
||||
A random array with the specified shape and dtype.
|
||||
"""
|
||||
key, _ = _check_prng_key(key)
|
||||
key, _ = _check_prng_key("gumbel", key)
|
||||
dtypes.check_user_dtype_supported(dtype)
|
||||
if not dtypes.issubdtype(dtype, np.floating):
|
||||
raise ValueError(f"dtype argument to `gumbel` must be a float "
|
||||
@ -1550,7 +1563,7 @@ def categorical(key: KeyArrayLike,
|
||||
A random array with int dtype and shape given by ``shape`` if ``shape``
|
||||
is not None, or else ``np.delete(logits.shape, axis)``.
|
||||
"""
|
||||
key, _ = _check_prng_key(key)
|
||||
key, _ = _check_prng_key("categorical", key)
|
||||
check_arraylike("categorical", logits)
|
||||
logits_arr = jnp.asarray(logits)
|
||||
|
||||
@ -1593,7 +1606,7 @@ def laplace(key: KeyArrayLike,
|
||||
Returns:
|
||||
A random array with the specified shape and dtype.
|
||||
"""
|
||||
key, _ = _check_prng_key(key)
|
||||
key, _ = _check_prng_key("laplace", key)
|
||||
dtypes.check_user_dtype_supported(dtype)
|
||||
if not dtypes.issubdtype(dtype, np.floating):
|
||||
raise ValueError(f"dtype argument to `laplace` must be a float "
|
||||
@ -1630,7 +1643,7 @@ def logistic(key: KeyArrayLike,
|
||||
Returns:
|
||||
A random array with the specified shape and dtype.
|
||||
"""
|
||||
key, _ = _check_prng_key(key)
|
||||
key, _ = _check_prng_key("logistic", key)
|
||||
dtypes.check_user_dtype_supported(dtype)
|
||||
if not dtypes.issubdtype(dtype, np.floating):
|
||||
raise ValueError(f"dtype argument to `logistic` must be a float "
|
||||
@ -1673,7 +1686,7 @@ def pareto(key: KeyArrayLike,
|
||||
A random array with the specified dtype and with shape given by ``shape`` if
|
||||
``shape`` is not None, or else by ``b.shape``.
|
||||
"""
|
||||
key, _ = _check_prng_key(key)
|
||||
key, _ = _check_prng_key("pareto", key)
|
||||
dtypes.check_user_dtype_supported(dtype)
|
||||
if not dtypes.issubdtype(dtype, np.floating):
|
||||
raise ValueError(f"dtype argument to `pareto` must be a float "
|
||||
@ -1722,7 +1735,7 @@ def t(key: KeyArrayLike,
|
||||
A random array with the specified dtype and with shape given by ``shape`` if
|
||||
``shape`` is not None, or else by ``df.shape``.
|
||||
"""
|
||||
key, _ = _check_prng_key(key)
|
||||
key, _ = _check_prng_key("t", key)
|
||||
dtypes.check_user_dtype_supported(dtype)
|
||||
if not dtypes.issubdtype(dtype, np.floating):
|
||||
raise ValueError(f"dtype argument to `t` must be a float "
|
||||
@ -1775,7 +1788,7 @@ def chisquare(key: KeyArrayLike,
|
||||
A random array with the specified dtype and with shape given by ``shape`` if
|
||||
``shape`` is not None, or else by ``df.shape``.
|
||||
"""
|
||||
key, _ = _check_prng_key(key)
|
||||
key, _ = _check_prng_key("chisquare", key)
|
||||
dtypes.check_user_dtype_supported(dtype)
|
||||
if not dtypes.issubdtype(dtype, np.floating):
|
||||
raise ValueError("dtype argument to `chisquare` must be a float "
|
||||
@ -1833,7 +1846,7 @@ def f(key: KeyArrayLike,
|
||||
A random array with the specified dtype and with shape given by ``shape`` if
|
||||
``shape`` is not None, or else by ``df.shape``.
|
||||
"""
|
||||
key, _ = _check_prng_key(key)
|
||||
key, _ = _check_prng_key("f", key)
|
||||
dtypes.check_user_dtype_supported(dtype)
|
||||
if not dtypes.issubdtype(dtype, np.floating):
|
||||
raise ValueError("dtype argument to `f` must be a float "
|
||||
@ -1885,7 +1898,7 @@ def rademacher(key: KeyArrayLike,
|
||||
a 50% change of being 1 or -1.
|
||||
|
||||
"""
|
||||
key, _ = _check_prng_key(key)
|
||||
key, _ = _check_prng_key("rademacher", key)
|
||||
dtypes.check_user_dtype_supported(dtype)
|
||||
dtype = dtypes.canonicalize_dtype(dtype)
|
||||
shape = core.canonicalize_shape(shape)
|
||||
@ -1921,7 +1934,7 @@ def maxwell(key: KeyArrayLike,
|
||||
"""
|
||||
# Generate samples using:
|
||||
# sqrt(X^2 + Y^2 + Z^2), X,Y,Z ~N(0,1)
|
||||
key, _ = _check_prng_key(key)
|
||||
key, _ = _check_prng_key("maxwell", key)
|
||||
dtypes.check_user_dtype_supported(dtype)
|
||||
if not dtypes.issubdtype(dtype, np.floating):
|
||||
raise ValueError(f"dtype argument to `maxwell` must be a float "
|
||||
@ -1964,7 +1977,7 @@ def double_sided_maxwell(key: KeyArrayLike,
|
||||
A jnp.array of samples.
|
||||
|
||||
"""
|
||||
key, _ = _check_prng_key(key)
|
||||
key, _ = _check_prng_key("double_sided_maxwell", key)
|
||||
dtypes.check_user_dtype_supported(dtype)
|
||||
if not dtypes.issubdtype(dtype, np.floating):
|
||||
raise ValueError(f"dtype argument to `double_sided_maxwell` must be a float"
|
||||
@ -2016,7 +2029,7 @@ def weibull_min(key: KeyArrayLike,
|
||||
A jnp.array of samples.
|
||||
|
||||
"""
|
||||
key, _ = _check_prng_key(key)
|
||||
key, _ = _check_prng_key("weibull_min", key)
|
||||
dtypes.check_user_dtype_supported(dtype)
|
||||
if not dtypes.issubdtype(dtype, np.floating):
|
||||
raise ValueError(f"dtype argument to `weibull_min` must be a float "
|
||||
@ -2055,7 +2068,7 @@ def orthogonal(
|
||||
Returns:
|
||||
A random array of shape `(*shape, n, n)` and specified dtype.
|
||||
"""
|
||||
key, _ = _check_prng_key(key)
|
||||
key, _ = _check_prng_key("orthogonal", key)
|
||||
dtypes.check_user_dtype_supported(dtype)
|
||||
_check_shape("orthogonal", shape)
|
||||
n = core.concrete_or_error(index, n, "The error occurred in jax.random.orthogonal()")
|
||||
@ -2090,7 +2103,7 @@ def generalized_normal(
|
||||
Returns:
|
||||
A random array with the specified shape and dtype.
|
||||
"""
|
||||
key, _ = _check_prng_key(key)
|
||||
key, _ = _check_prng_key("generalized_normal", key)
|
||||
dtypes.check_user_dtype_supported(dtype)
|
||||
_check_shape("generalized_normal", shape)
|
||||
keys = split(key)
|
||||
@ -2120,7 +2133,7 @@ def ball(
|
||||
Returns:
|
||||
A random array of shape `(*shape, d)` and specified dtype.
|
||||
"""
|
||||
key, _ = _check_prng_key(key)
|
||||
key, _ = _check_prng_key("ball", key)
|
||||
dtypes.check_user_dtype_supported(dtype)
|
||||
_check_shape("ball", shape)
|
||||
d = core.concrete_or_error(index, d, "The error occurred in jax.random.ball()")
|
||||
@ -2158,7 +2171,7 @@ def rayleigh(key: KeyArrayLike,
|
||||
A random array with the specified dtype and with shape given by ``shape`` if
|
||||
``shape`` is not None, or else by ``scale.shape``.
|
||||
"""
|
||||
key, _ = _check_prng_key(key)
|
||||
key, _ = _check_prng_key("rayleigh", key)
|
||||
dtypes.check_user_dtype_supported(dtype)
|
||||
if not dtypes.issubdtype(dtype, np.floating):
|
||||
raise ValueError("dtype argument to `rayleigh` must be a float "
|
||||
@ -2212,7 +2225,7 @@ def wald(key: KeyArrayLike,
|
||||
A random array with the specified dtype and with shape given by ``shape`` if
|
||||
``shape`` is not None, or else by ``mean.shape``.
|
||||
"""
|
||||
key, _ = _check_prng_key(key)
|
||||
key, _ = _check_prng_key("wald", key)
|
||||
dtypes.check_user_dtype_supported(dtype)
|
||||
if not dtypes.issubdtype(dtype, np.floating):
|
||||
raise ValueError("dtype argument to `wald` must be a float "
|
||||
@ -2268,7 +2281,7 @@ def geometric(key: KeyArrayLike,
|
||||
A random array with the specified dtype and with shape given by ``shape`` if
|
||||
``shape`` is not None, or else by ``p.shape``.
|
||||
"""
|
||||
key, _ = _check_prng_key(key)
|
||||
key, _ = _check_prng_key("geometric", key)
|
||||
dtypes.check_user_dtype_supported(dtype)
|
||||
if not dtypes.issubdtype(dtype, np.integer):
|
||||
raise ValueError("dtype argument to `geometric` must be an int "
|
||||
@ -2330,7 +2343,7 @@ def triangular(key: KeyArrayLike,
|
||||
A random array with the specified dtype and with shape given by ``shape`` if
|
||||
``shape`` is not None, or else by ``left.shape``, ``mode.shape`` and ``right.shape``.
|
||||
"""
|
||||
key, _ = _check_prng_key(key)
|
||||
key, _ = _check_prng_key("triangular", key)
|
||||
dtypes.check_user_dtype_supported(dtype)
|
||||
if not dtypes.issubdtype(dtype, np.floating):
|
||||
raise ValueError("dtype argument to `triangular` must be a float "
|
||||
@ -2384,7 +2397,7 @@ def lognormal(key: KeyArrayLike,
|
||||
Returns:
|
||||
A random array with the specified dtype and with shape given by ``shape``.
|
||||
"""
|
||||
key, _ = _check_prng_key(key)
|
||||
key, _ = _check_prng_key("lognormal", key)
|
||||
dtypes.check_user_dtype_supported(dtype)
|
||||
if not dtypes.issubdtype(dtype, np.inexact):
|
||||
raise ValueError(f"dtype argument to `lognormal` must be a float or complex dtype, "
|
||||
@ -2597,7 +2610,7 @@ def binomial(
|
||||
A random array with the specified dtype and with shape given by
|
||||
``np.broadcast(n, p).shape``.
|
||||
"""
|
||||
key, _ = _check_prng_key(key)
|
||||
key, _ = _check_prng_key("binomial", key)
|
||||
check_arraylike("binomial", n, p)
|
||||
dtypes.check_user_dtype_supported(dtype)
|
||||
if not dtypes.issubdtype(dtype, np.floating):
|
||||
|
@ -2108,7 +2108,7 @@ _POLY_SHAPE_TEST_HARNESSES = [
|
||||
# non-partitionable), and unsafe_rbg.
|
||||
[
|
||||
PolyHarness("random_gamma", f"{flags_name}",
|
||||
lambda key, a: jax.random.gamma(key, a),
|
||||
lambda key, a: jax.vmap(jax.random.gamma)(key, a),
|
||||
arg_descriptors=[RandArg((3, key_size), np.uint32), RandArg((3, 4, 5), _f32)],
|
||||
polymorphic_shapes=["b, ...", "b, w, ..."], tol=1E-5,
|
||||
override_jax_config_flags=override_jax_config_flags), # type: ignore
|
||||
|
@ -616,7 +616,7 @@ class KeyReuseIntegrationTest(jtu.JaxTestCase):
|
||||
def f():
|
||||
key = jax.random.key(0)
|
||||
key2 = key[None]
|
||||
return jax.random.bits(key) + jax.random.bits(key2)
|
||||
return jax.random.bits(key) + jax.vmap(jax.random.bits)(key2)
|
||||
|
||||
with self.assertRaisesRegex(KeyReuseError, self.random_bits_error):
|
||||
self.check_key_reuse(f)
|
||||
|
@ -1247,6 +1247,34 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
self.assertArraysAllClose(samples2, jnp.array([jnp.nan, 0., jnp.nan, jnp.nan]), check_dtypes=False)
|
||||
self.assertArraysAllClose(samples3, jnp.array([jnp.nan, jnp.nan, jnp.nan]), check_dtypes=False)
|
||||
|
||||
def test_batched_key_warnings(self):
|
||||
keys = jax.random.split(self.make_key(0))
|
||||
msg = "{} accepts a single key, but was given a key array of shape.*"
|
||||
|
||||
# Check a handful of functions that are expected to warn.
|
||||
with self.assertWarnsRegex(FutureWarning, msg.format('bits')):
|
||||
jax.random.bits(keys, shape=(2,))
|
||||
with self.assertWarnsRegex(FutureWarning, msg.format('chisquare')):
|
||||
jax.random.chisquare(keys, 1.0, shape=(2,))
|
||||
with self.assertWarnsRegex(FutureWarning, msg.format('dirichlet')):
|
||||
jax.random.dirichlet(keys, jnp.arange(2.0), shape=(2,))
|
||||
with self.assertWarnsRegex(FutureWarning, msg.format('gamma')):
|
||||
jax.random.gamma(keys, 1.0, shape=(2,))
|
||||
with self.assertWarnsRegex(FutureWarning, msg.format('loggamma')):
|
||||
jax.random.loggamma(keys, 1.0, shape=(2,))
|
||||
|
||||
# Other functions should error; test a few cases.
|
||||
with self.assertRaisesRegex(ValueError, msg.format('fold_in')):
|
||||
jax.random.fold_in(keys, 0)
|
||||
with self.assertRaisesRegex(ValueError, msg.format('split')):
|
||||
jax.random.split(keys)
|
||||
|
||||
# Some shouldn't error or warn
|
||||
with self.assertNoWarnings():
|
||||
jax.random.key_data(keys)
|
||||
jax.random.key_impl(keys)
|
||||
|
||||
|
||||
threefry_seed = prng_internal.threefry_seed
|
||||
threefry_split = prng_internal.threefry_split
|
||||
threefry_random_bits = prng_internal.threefry_random_bits
|
||||
|
@ -1917,7 +1917,7 @@ _POLY_SHAPE_TEST_HARNESSES = [
|
||||
# non-partitionable), and unsafe_rbg.
|
||||
[
|
||||
PolyHarness("random_gamma", f"{flags_name}",
|
||||
lambda key, a: jax.random.gamma(
|
||||
lambda key, a: jax.vmap(jax.random.gamma)(
|
||||
jax.random.wrap_key_data(key), a),
|
||||
arg_descriptors=[RandArg((3, key_size), np.uint32),
|
||||
RandArg((3, 4, 5), _f32)],
|
||||
|
Loading…
x
Reference in New Issue
Block a user