[key reuse] add eager checks

This commit is contained in:
Jake VanderPlas 2024-02-29 15:30:19 -08:00
parent 087f99a31c
commit d08e9a03d8
12 changed files with 370 additions and 255 deletions

View File

@ -932,8 +932,11 @@ enable_checks = define_bool_state(
enable_key_reuse_checks = define_bool_state(
name='jax_enable_key_reuse_checks',
default=False,
help="Turn on experimental key reuse checking."
)
help=('Turn on experimental key reuse checking. With this configuration enabled,'
' typed PRNG keys (i.e. keys created with jax.random.key()) will have their'
' usage tracked, and incorrect reuse of a previously-used key will lead to'
' an error. Currently enabling this leads to a small Python overhead on'
' every call to a JIT-compiled function with keys as inputs or outputs.'))
check_tracer_leaks = define_bool_state(
name='jax_check_tracer_leaks',

View File

@ -33,6 +33,7 @@ from jax._src import api_util
from jax._src import config
from jax._src import core
from jax._src import dispatch
from jax._src import dtypes
from jax._src import linear_util as lu
from jax._src import mesh as mesh_lib
from jax._src import op_shardings
@ -181,8 +182,7 @@ def _python_pjit(fun: Callable, infer_params_fn):
def _get_fastpath_data(executable, out_tree, args_flat, out_flat, attrs_tracked,
) -> Optional[pxla.MeshExecutableFastpathData]:
out_flat, out_tree = pxla.reflatten_outputs_for_dispatch(out_tree, out_flat)
out_reflattened, out_tree = pxla.reflatten_outputs_for_dispatch(out_tree, out_flat)
use_fastpath = (
executable is not None and
isinstance(executable, pxla.MeshExecutable) and
@ -191,14 +191,19 @@ def _get_fastpath_data(executable, out_tree, args_flat, out_flat, attrs_tracked,
not executable.unsafe_call.ordered_effects and
not executable.unsafe_call.has_unordered_effects and
not executable.unsafe_call.has_host_callbacks and
all(isinstance(x, xc.ArrayImpl) for x in out_flat) and
all(isinstance(x, xc.ArrayImpl) for x in out_reflattened) and
# no attr state effects
not attrs_tracked
not attrs_tracked and
# no prng reuse checking
not (config.enable_key_reuse_checks.value and any(
hasattr(arg, 'dtype') and dtypes.issubdtype(arg.dtype, dtypes.prng_key)
for arg in (*args_flat, *out_flat)
))
)
if use_fastpath:
out_avals = [o.aval for o in out_flat]
out_committed = [o._committed for o in out_flat]
out_avals = [o.aval for o in out_reflattened]
out_committed = [o._committed for o in out_reflattened]
kept_var_bitvec = [i in executable._kept_var_idx
for i in range(len(args_flat))]
fastpath_data = pxla.MeshExecutableFastpathData(

View File

@ -153,12 +153,14 @@ class PRNGKeyArray(jax.Array):
_impl: PRNGImpl
_base_array: typing.Array
_consumed: bool | np.ndarray # Used in jax.experimental.key_reuse.
def __init__(self, impl, key_data: Any):
assert not isinstance(key_data, core.Tracer)
_check_prng_key_data(impl, key_data)
self._impl = impl
self._base_array = key_data
self._consumed = False # TODO(jakevdp): default to True here?
def block_until_ready(self):
_ = self._base_array.block_until_ready()
@ -269,7 +271,9 @@ class PRNGKeyArray(jax.Array):
pp.nest(2, pp.brk() + pp_keys + pp.brk() + pp_impl)))
def copy(self):
return self.__class__(self._impl, self._base_array.copy())
out = self.__class__(self._impl, self._base_array.copy())
out._consumed = self._consumed # TODO(jakevdp): is this correct?
return out
__hash__ = None # type: ignore[assignment]
__array_priority__ = 100

View File

@ -17,6 +17,7 @@ from __future__ import annotations
from typing import NamedTuple
from jax import core
from jax.interpreters import batching, mlir
from jax._src import prng
import numpy as np
@ -51,6 +52,33 @@ class KeyReuseSignature(NamedTuple):
sources: list[Source]
forwards: list[Forward] = []
def check_signature(self, *args, jaxpr=None):
for sink in self.sinks:
if not isinstance(args[sink.idx], prng.PRNGKeyArray):
continue
if np.any(args[sink.idx]._consumed & sink.mask):
msg = f"Previously-consumed key at index {sink.idx} passed to function"
if jaxpr:
msg += f"\n{jaxpr=}"
raise KeyReuseError(msg)
def update_consumption(self, args_in, args_out):
for sink in self.sinks:
arg = args_in[sink.idx]
if isinstance(arg, prng.PRNGKeyArray):
arg._consumed = arg._consumed | sink.mask
for arg in args_out:
if isinstance(arg, prng.PRNGKeyArray):
arg._consumed = True
for source in self.sources:
if isinstance(args_out[source.idx], prng.PRNGKeyArray):
args_out[source.idx]._consumed = ~np.asarray(source.mask)
for forward in self.forwards:
arg_in = args_in[forward.in_idx]
arg_out = args_out[forward.out_idx]
if isinstance(arg_in, prng.PRNGKeyArray) and isinstance(arg_out, prng.PRNGKeyArray):
arg_out._consumed = arg_in._consumed
class KeyReuseError(RuntimeError):
pass

View File

@ -15,13 +15,14 @@
from __future__ import annotations
from collections import defaultdict
from functools import reduce
from functools import partial, reduce, wraps
from typing import Any, Callable
import jax
from jax import lax
from jax import tree_util
from jax._src import api_util
from jax._src import config
from jax._src import core
from jax._src import linear_util as lu
from jax._src import pjit
@ -32,11 +33,13 @@ from jax._src import util
from jax._src.ad_checkpoint import remat_p
from jax._src.debugging import debug_callback_p
from jax._src.interpreters import partial_eval as pe
from jax._src.util import weakref_lru_cache
from jax.experimental.key_reuse._common import (
consume_p, assert_consumed_value_p, KeyReuseError,
Sink, Source, Forward, KeyReuseSignature
)
from jax.experimental.shard_map import shard_map_p
import numpy as np
@ -77,7 +80,9 @@ key_reuse_signatures[prng.random_wrap_p] = KeyReuseSignature([], [Source(0)], []
key_reuse_signatures[prng.random_unwrap_p] = KeyReuseSignature([], [], [])
key_reuse_signatures[debug_callback_p] = KeyReuseSignature([], [])
key_reuse_signatures[lax.dynamic_slice_p] = KeyReuseSignature([], [], [Forward(0, 0)])
key_reuse_signatures[lax.dynamic_update_slice_p] = KeyReuseSignature([], [], [])
key_reuse_signatures[lax.dynamic_update_slice_p] = KeyReuseSignature([Sink(1)], [], [Forward(0, 0)])
key_reuse_signatures[lax.gather_p] = KeyReuseSignature([], [], [Forward(0, 0)])
key_reuse_signatures[lax.scatter_p] = KeyReuseSignature([Sink(2)], [], [Forward(0, 0)])
# Rules which require more dynamic logic.
key_reuse_signatures_dynamic: dict[core.Primitive, Callable[..., KeyReuseSignature]] = {}
@ -91,6 +96,7 @@ def unknown_signature(eqn):
sources=[],
)
@weakref_lru_cache
def get_jaxpr_type_signature(jaxpr: core.Jaxpr) -> KeyReuseSignature:
"""Parse the jaxpr to determine key reuse signature"""
consumed: dict[core.Atom, bool | np.ndarray] = {}
@ -219,6 +225,11 @@ def _pjit_key_type_signature(eqn):
key_reuse_signatures_dynamic[pjit.pjit_p] = _pjit_key_type_signature
def _shard_map_type_signature(eqn):
return get_jaxpr_type_signature(eqn.params['jaxpr'])
key_reuse_signatures_dynamic[shard_map_p] = _shard_map_type_signature
def _cond_key_type_signature(eqn):
signatures = [get_jaxpr_type_signature(branch.jaxpr) for branch in eqn.params['branches']]
sinks = defaultdict(list)
@ -318,3 +329,33 @@ def _remat_key_type_signature(eqn):
return get_jaxpr_type_signature(eqn.params['jaxpr'])
key_reuse_signatures_dynamic[remat_p] = _remat_key_type_signature
# TODO(jakevdp): when we integrate key reuse checks more tightly with JAX,
# we should move this logic directly into each primitive impl.
def key_reuse_impl_rule(prim, original_rule):
@wraps(original_rule)
def key_reuse_impl(*args, **kwargs):
if config.enable_key_reuse_checks.value:
if prim == pjit.pjit_p:
jaxpr = kwargs['jaxpr'].jaxpr
signature = get_jaxpr_type_signature(jaxpr)
elif prim in key_reuse_signatures:
jaxpr = prim
signature = key_reuse_signatures[prim]
elif prim in key_reuse_signatures_dynamic:
jaxpr = jax.make_jaxpr(partial(prim.bind, **kwargs))(*args).jaxpr
signature = get_jaxpr_type_signature(jaxpr)
else:
raise RuntimeError(f"Internal: no key reuse rule for primitive {prim}")
signature.check_signature(*args, jaxpr=jaxpr)
result = original_rule(*args, **kwargs)
signature.update_consumption(args, result if prim.multiple_results else [result])
return result
else:
return original_rule(*args, **kwargs)
return key_reuse_impl
for prim in (*key_reuse_signatures, *key_reuse_signatures_dynamic):
prim.impl = key_reuse_impl_rule(prim, prim.impl) # type: ignore[method-assign]

View File

@ -574,6 +574,33 @@ class KeyReuseIntegrationTest(jtu.JaxTestCase):
self.check_key_reuse(jax.grad(f_good), x, key)
class KeyReuseEager(jtu.JaxTestCase):
jit_msg = "Previously-consumed key at index 0 passed to function"
bits_msg = "In random_bits, key values a are already consumed."
def test_simple_reuse_nojit(self):
key = jax.random.key(0)
_ = jax.random.bits(key)
with jax.disable_jit():
with self.assertRaisesRegex(KeyReuseError, self.jit_msg):
_ = jax.random.bits(key)
def test_simple_key_reuse_jit(self):
key = jax.random.key(0)
_ = jax.random.bits(key)
with self.assertRaisesRegex(KeyReuseError, self.jit_msg):
_ = jax.random.bits(key)
def test_key_reuse_within_jit(self):
@jax.jit
def f():
key = jax.random.key(0)
return jax.random.bits(key) + jax.random.bits(key)
with self.assertRaisesRegex(KeyReuseError, self.bits_msg):
f()
@jtu.with_config(jax_enable_checks=False)
class KeyReuseGlobalFlags(jtu.JaxTestCase):
def test_key_reuse_flag(self):

View File

@ -1737,10 +1737,10 @@ class PythonPmapTest(jtu.JaxTestCase):
res = res.reshape(res.shape[0] * res.shape[1], *res.shape[2:])
return res
key = random.PRNGKey(1)
x = random.normal(key, (80, 50))
key = lambda: random.PRNGKey(1)
x = random.normal(key(), (80, 50))
batched_mvm = vmap(lambda b: distributed_matrix_vector(x, b), in_axes=0)
y = random.normal(key, (10, 50, 1))
y = random.normal(key(), (10, 50, 1))
result = batched_mvm(y)
expected = jnp.einsum('ij,njk->nik', x, y)
self.assertAllClose(result, expected, check_dtypes=False, atol=1e-3,

View File

@ -126,12 +126,12 @@ class LaxRandomTest(jtu.JaxTestCase):
@jtu.sample_product(dtype=float_dtypes)
def testRngUniform(self, dtype):
key = self.make_key(0)
key = lambda: self.make_key(0)
rand = lambda key: random.uniform(key, (10000,), dtype)
crand = jax.jit(rand)
uncompiled_samples = rand(key)
compiled_samples = crand(key)
uncompiled_samples = rand(key())
compiled_samples = crand(key())
for samples in [uncompiled_samples, compiled_samples]:
self._CheckCollisions(samples, jnp.finfo(dtype).nmant)
@ -142,12 +142,12 @@ class LaxRandomTest(jtu.JaxTestCase):
lo = 5
hi = 10
key = self.make_key(0)
key = lambda: self.make_key(0)
rand = lambda key: random.randint(key, (10000,), lo, hi, dtype)
crand = jax.jit(rand)
uncompiled_samples = rand(key)
compiled_samples = crand(key)
uncompiled_samples = rand(key())
compiled_samples = crand(key())
for samples in [uncompiled_samples, compiled_samples]:
self.assertTrue(np.all(lo <= samples))
@ -155,12 +155,12 @@ class LaxRandomTest(jtu.JaxTestCase):
@jtu.sample_product(dtype=float_dtypes)
def testNormal(self, dtype):
key = self.make_key(0)
key = lambda: self.make_key(0)
rand = lambda key: random.normal(key, (10000,), dtype)
crand = jax.jit(rand)
uncompiled_samples = rand(key)
compiled_samples = crand(key)
uncompiled_samples = rand(key())
compiled_samples = crand(key())
for samples in [uncompiled_samples, compiled_samples]:
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.norm().cdf)
@ -174,12 +174,12 @@ class LaxRandomTest(jtu.JaxTestCase):
@jtu.sample_product(dtype=complex_dtypes)
def testNormalComplex(self, dtype):
key = self.make_key(0)
key = lambda: self.make_key(0)
rand = lambda key: random.normal(key, (10000,), dtype)
crand = jax.jit(rand)
uncompiled_samples = rand(key)
compiled_samples = crand(key)
uncompiled_samples = rand(key())
compiled_samples = crand(key())
for samples in [uncompiled_samples, compiled_samples]:
self._CheckKolmogorovSmirnovCDF(jnp.real(samples), scipy.stats.norm(scale=1/np.sqrt(2)).cdf)
@ -188,12 +188,12 @@ class LaxRandomTest(jtu.JaxTestCase):
@jtu.sample_product(dtype=float_dtypes)
def testTruncatedNormal(self, dtype):
key = self.make_key(0)
key = lambda: self.make_key(0)
rand = lambda key: random.truncated_normal(key, -0.3, 0.3, (10000,), dtype)
crand = jax.jit(rand)
uncompiled_samples = rand(key)
compiled_samples = crand(key)
uncompiled_samples = rand(key())
compiled_samples = crand(key())
min_val = np.min(uncompiled_samples)
max_val = np.max(uncompiled_samples)
@ -204,15 +204,15 @@ class LaxRandomTest(jtu.JaxTestCase):
@jtu.sample_product(dtype=jtu.dtypes.floating + jtu.dtypes.integer)
def testShuffle(self, dtype):
key = self.make_key(0)
key = lambda: self.make_key(0)
x = np.arange(100).astype(dtype)
rand = lambda key: random.shuffle(key, x)
crand = jax.jit(rand)
with self.assertWarns((DeprecationWarning, FutureWarning)):
perm1 = rand(key)
perm1 = rand(key())
with self.assertWarns((DeprecationWarning, FutureWarning)):
perm2 = crand(key)
perm2 = crand(key())
self.assertAllClose(perm1, perm2)
self.assertFalse(np.all(perm1 == x)) # seems unlikely!
@ -238,7 +238,7 @@ class LaxRandomTest(jtu.JaxTestCase):
np_choice = np.random.default_rng(0).choice
p_dtype = dtypes.to_inexact_dtype(dtype)
key = self.make_key(0)
key = lambda: self.make_key(0)
is_range = type(input_range_or_shape) is int
x = (input_range_or_shape if is_range else
self.rng().permutation(np.arange(math.prod(
@ -250,7 +250,7 @@ class LaxRandomTest(jtu.JaxTestCase):
else:
p = None
rand = lambda key, x: random.choice(key, x, shape, replace, p, axis)
sample = rand(key, x)
sample = rand(key(), x)
if not is_range:
self.assertEqual(dtype, sample.dtype)
expected_shape = np.shape(np_choice(x, shape or None, replace, p, axis))
@ -263,9 +263,9 @@ class LaxRandomTest(jtu.JaxTestCase):
ind = np.lexsort(np.swapaxes(x, axis, -1).reshape((-1, x.shape[axis])))
return jnp.take(x, ind, axis)
self.assertArraysEqual(lsort(sample), lsort(np.unique(sample, axis=axis)))
self.assertArraysEqual(sample, rand(key, np.array(x)))
self.assertArraysEqual(sample, rand(key(), np.array(x)))
self.assertArraysEqual(sample, jax.jit(rand, static_argnames=
'x' if is_range else None)(key, x))
'x' if is_range else None)(key(), x))
@jtu.sample_product(
[dict(range_or_shape=range_or_shape, axis=axis)
@ -278,7 +278,7 @@ class LaxRandomTest(jtu.JaxTestCase):
independent=[True, False],
)
def testPermutation(self, dtype, range_or_shape, axis, independent):
key = self.make_key(0)
key = lambda: self.make_key(0)
is_range = type(range_or_shape) is int
x = (range_or_shape if is_range else
self.rng().permutation(np.arange(
@ -286,7 +286,7 @@ class LaxRandomTest(jtu.JaxTestCase):
shape = ((range_or_shape,) if is_range else range_or_shape)
x_ = np.copy(x)
rand = lambda key, x: random.permutation(key, x, axis, independent=independent)
perm = rand(key, x)
perm = rand(key(), x)
if shape[axis] >= 10:
self.assertFalse(np.all(perm == x)) # seems unlikely!
arr = np.arange(x) if is_range else x
@ -302,9 +302,9 @@ class LaxRandomTest(jtu.JaxTestCase):
with self.assertRaises(AssertionError):
self.assertArraysEqual(lsort(arr), lsort(perm), check_dtypes=not is_range)
self.assertArraysEqual(x_, x)
self.assertArraysEqual(perm, rand(key, np.array(x)))
self.assertArraysEqual(perm, rand(key(), np.array(x)))
self.assertArraysEqual(perm, jax.jit(rand, static_argnames=
'x' if is_range else None)(key, x))
'x' if is_range else None)(key(), x))
def testPermutationErrors(self):
key = self.make_key(0)
@ -320,13 +320,13 @@ class LaxRandomTest(jtu.JaxTestCase):
dtype=jtu.dtypes.floating,
)
def testBernoulli(self, p, dtype):
key = self.make_key(0)
key = lambda: self.make_key(0)
p = np.array(p, dtype=dtype)
rand = lambda key, p: random.bernoulli(key, p, (10000,))
crand = jax.jit(rand)
uncompiled_samples = rand(key, p)
compiled_samples = crand(key, p)
uncompiled_samples = rand(key(), p)
compiled_samples = crand(key(), p)
for samples in [uncompiled_samples, compiled_samples]:
self._CheckChiSquared(samples, scipy.stats.bernoulli(p).pmf)
@ -344,7 +344,7 @@ class LaxRandomTest(jtu.JaxTestCase):
dtype=jtu.dtypes.floating,
)
def testCategorical(self, p, axis, dtype, sample_shape):
key = self.make_key(0)
key = lambda: self.make_key(0)
p = np.array(p, dtype=dtype)
logits = np.log(p) - 42 # test unnormalized
out_shape = tuple(np.delete(logits.shape, axis))
@ -352,8 +352,8 @@ class LaxRandomTest(jtu.JaxTestCase):
rand = partial(random.categorical, shape=shape, axis=axis)
crand = jax.jit(rand)
uncompiled_samples = rand(key, logits)
compiled_samples = crand(key, logits)
uncompiled_samples = rand(key(), logits)
compiled_samples = crand(key(), logits)
if axis < 0:
axis += len(logits.shape)
@ -384,12 +384,12 @@ class LaxRandomTest(jtu.JaxTestCase):
def testBeta(self, a, b, dtype):
if not config.enable_x64.value:
raise SkipTest("skip test except on X64")
key = self.make_key(0)
key = lambda: self.make_key(0)
rand = lambda key, a, b: random.beta(key, a, b, (10000,), dtype)
crand = jax.jit(rand)
uncompiled_samples = rand(key, a, b)
compiled_samples = crand(key, a, b)
uncompiled_samples = rand(key(), a, b)
compiled_samples = crand(key(), a, b)
for samples in [uncompiled_samples, compiled_samples]:
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.beta(a, b).cdf)
@ -412,12 +412,12 @@ class LaxRandomTest(jtu.JaxTestCase):
@jtu.sample_product(dtype=float_dtypes)
def testCauchy(self, dtype):
key = self.make_key(0)
key = lambda: self.make_key(0)
rand = lambda key: random.cauchy(key, (10000,), dtype)
crand = jax.jit(rand)
uncompiled_samples = rand(key)
compiled_samples = crand(key)
uncompiled_samples = rand(key())
compiled_samples = crand(key())
for samples in [uncompiled_samples, compiled_samples]:
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.cauchy().cdf)
@ -428,13 +428,13 @@ class LaxRandomTest(jtu.JaxTestCase):
)
@jtu.skip_on_devices("tpu") # TODO(mattjj): slow compilation times
def testDirichlet(self, alpha, dtype):
key = self.make_key(0)
key = lambda: self.make_key(0)
num_samples = 10000
rand = lambda key, alpha: random.dirichlet(key, alpha, (num_samples,), dtype)
crand = jax.jit(rand)
uncompiled_samples = rand(key, alpha)
compiled_samples = crand(key, alpha)
uncompiled_samples = rand(key(), alpha)
compiled_samples = crand(key(), alpha)
for samples in [uncompiled_samples, compiled_samples]:
self.assertAllClose(samples.sum(-1), np.ones(num_samples, dtype=dtype))
@ -462,12 +462,12 @@ class LaxRandomTest(jtu.JaxTestCase):
@jtu.sample_product(dtype=float_dtypes)
def testExponential(self, dtype):
key = self.make_key(0)
key = lambda: self.make_key(0)
rand = lambda key: random.exponential(key, (10000,), dtype)
crand = jax.jit(rand)
uncompiled_samples = rand(key)
compiled_samples = crand(key)
uncompiled_samples = rand(key())
compiled_samples = crand(key())
for samples in [uncompiled_samples, compiled_samples]:
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.expon().cdf)
@ -479,15 +479,15 @@ class LaxRandomTest(jtu.JaxTestCase):
@jtu.skip_on_devices("tpu") # low accuracy leads to failures.
def testGammaVsLogGamma(self, a, dtype):
# Test that gamma() and loggamma() produce equivalent samples.
key = self.make_key(0)
rand_gamma = lambda key, a: random.gamma(key, a, (100,), dtype)
rand_loggamma = lambda key, a: random.loggamma(key, a, (100,), dtype)
crand_loggamma = jax.jit(rand_loggamma)
tol = {np.float32: 1E-6, np.float64: 1E-12}
self.assertAllClose(rand_gamma(key, a), jnp.exp(rand_loggamma(key, a)),
key = lambda: self.make_key(0)
self.assertAllClose(rand_gamma(key(), a), jnp.exp(rand_loggamma(key(), a)),
atol=tol, rtol=tol)
self.assertAllClose(rand_gamma(key, a), jnp.exp(crand_loggamma(key, a)),
self.assertAllClose(rand_gamma(key(), a), jnp.exp(crand_loggamma(key(), a)),
atol=tol, rtol=tol)
@jtu.sample_product(
@ -495,12 +495,12 @@ class LaxRandomTest(jtu.JaxTestCase):
dtype=jtu.dtypes.floating,
)
def testGamma(self, a, dtype):
key = self.make_key(1)
key = lambda: self.make_key(1)
rand = lambda key, a: random.gamma(key, a, (10000,), dtype)
crand = jax.jit(rand)
uncompiled_samples = rand(key, a)
compiled_samples = crand(key, a)
uncompiled_samples = rand(key(), a)
compiled_samples = crand(key(), a)
for samples in [uncompiled_samples, compiled_samples]:
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.gamma(a).cdf)
@ -515,13 +515,13 @@ class LaxRandomTest(jtu.JaxTestCase):
alpha=[1e-4, 1e-3, 1e-2, 1e-1, 1e0, 1e1, 1e2, 1e3, 1e4],
)
def testGammaGrad(self, log_space, alpha):
rng = self.make_key(0)
rng = lambda: self.make_key(0)
alphas = np.full((100,), alpha)
z = random.gamma(rng, alphas)
z = random.gamma(rng(), alphas)
if log_space:
actual_grad = jax.grad(lambda x: lax.exp(random.loggamma(rng, x)).sum())(alphas)
actual_grad = jax.grad(lambda x: lax.exp(random.loggamma(rng(), x)).sum())(alphas)
else:
actual_grad = jax.grad(lambda x: random.gamma(rng, x).sum())(alphas)
actual_grad = jax.grad(lambda x: random.gamma(rng(), x).sum())(alphas)
eps = 0.01 * alpha / (1.0 + np.sqrt(alpha))
cdf_dot = (scipy.stats.gamma.cdf(z, alpha + eps)
@ -548,12 +548,12 @@ class LaxRandomTest(jtu.JaxTestCase):
dtype=jtu.dtypes.supported([np.int16, np.int32, np.int64]),
)
def testPoisson(self, lam, dtype):
key = self.make_key(0)
key = lambda: self.make_key(0)
rand = lambda key, lam: random.poisson(key, lam, (10000,), dtype)
crand = jax.jit(rand)
uncompiled_samples = rand(key, lam)
compiled_samples = crand(key, lam)
uncompiled_samples = rand(key(), lam)
compiled_samples = crand(key(), lam)
for samples in [uncompiled_samples, compiled_samples]:
self._CheckChiSquared(samples, scipy.stats.poisson(lam).pmf)
@ -594,36 +594,36 @@ class LaxRandomTest(jtu.JaxTestCase):
@jtu.sample_product(dtype=jtu.dtypes.floating)
def testGumbel(self, dtype):
key = self.make_key(0)
key = lambda: self.make_key(0)
rand = lambda key: random.gumbel(key, (10000,), dtype)
crand = jax.jit(rand)
uncompiled_samples = rand(key)
compiled_samples = crand(key)
uncompiled_samples = rand(key())
compiled_samples = crand(key())
for samples in [uncompiled_samples, compiled_samples]:
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.gumbel_r().cdf)
@jtu.sample_product(dtype=float_dtypes)
def testLaplace(self, dtype):
key = self.make_key(0)
key = lambda: self.make_key(0)
rand = lambda key: random.laplace(key, (10000,), dtype)
crand = jax.jit(rand)
uncompiled_samples = rand(key)
compiled_samples = crand(key)
uncompiled_samples = rand(key())
compiled_samples = crand(key())
for samples in [uncompiled_samples, compiled_samples]:
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.laplace().cdf)
@jtu.sample_product(dtype=float_dtypes)
def testLogistic(self, dtype):
key = self.make_key(0)
key = lambda: self.make_key(0)
rand = lambda key: random.logistic(key, (10000,), dtype)
crand = jax.jit(rand)
uncompiled_samples = rand(key)
compiled_samples = crand(key)
uncompiled_samples = rand(key())
compiled_samples = crand(key())
for samples in [uncompiled_samples, compiled_samples]:
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.logistic().cdf)
@ -651,12 +651,12 @@ class LaxRandomTest(jtu.JaxTestCase):
dtype=jtu.dtypes.floating,
)
def testGeneralizedNormal(self, p, shape, dtype):
key = self.make_key(2)
key = lambda: self.make_key(2)
rand = lambda key, p: random.generalized_normal(key, p, shape, dtype)
crand = jax.jit(rand)
uncompiled_samples = rand(key, p)
compiled_samples = crand(key, p)
uncompiled_samples = rand(key(), p)
compiled_samples = crand(key(), p)
for samples in [uncompiled_samples, compiled_samples]:
self.assertEqual(samples.shape, shape)
self.assertEqual(samples.dtype, dtype)
@ -669,12 +669,12 @@ class LaxRandomTest(jtu.JaxTestCase):
def testGeneralizedNormalKS(self, p, shape, dtype):
self.skipTest( # test is also sometimes slow, with (300, ...)-shape draws
"sensitive to random key - https://github.com/google/jax/issues/18941")
key = self.make_key(2)
key = lambda: self.make_key(2)
rand = lambda key, p: random.generalized_normal(key, p, (300, *shape), dtype)
crand = jax.jit(rand)
uncompiled_samples = rand(key, p)
compiled_samples = crand(key, p)
uncompiled_samples = rand(key(), p)
compiled_samples = crand(key(), p)
for samples in [uncompiled_samples, compiled_samples]:
self._CheckKolmogorovSmirnovCDF(samples.ravel(), scipy.stats.gennorm(p).cdf)
@ -686,11 +686,11 @@ class LaxRandomTest(jtu.JaxTestCase):
)
@jtu.skip_on_devices("tpu") # TPU precision causes issues.
def testBall(self, d, p, shape, dtype):
key = self.make_key(123)
key = lambda: self.make_key(123)
rand = lambda key, p: random.ball(key, d, p, shape, dtype)
crand = jax.jit(rand)
uncompiled_samples = rand(key, p)
compiled_samples = crand(key, p)
uncompiled_samples = rand(key(), p)
compiled_samples = crand(key(), p)
for samples in [uncompiled_samples, compiled_samples]:
self.assertEqual(samples.shape, (*shape, d))
self.assertEqual(samples.dtype, dtype)
@ -706,11 +706,11 @@ class LaxRandomTest(jtu.JaxTestCase):
def testBallKS(self, d, p, shape, dtype):
self.skipTest(
"sensitive to random key - https://github.com/google/jax/issues/18932")
key = self.make_key(123)
key = lambda: self.make_key(123)
rand = lambda key, p: random.ball(key, d, p, (100, *shape), dtype)
crand = jax.jit(rand)
uncompiled_samples = rand(key, p)
compiled_samples = crand(key, p)
uncompiled_samples = rand(key(), p)
compiled_samples = crand(key(), p)
for samples in [uncompiled_samples, compiled_samples]:
norms = (jnp.abs(samples) ** p).sum(-1) ** (d / p)
self._CheckKolmogorovSmirnovCDF(norms.ravel(), scipy.stats.uniform().cdf)
@ -720,12 +720,12 @@ class LaxRandomTest(jtu.JaxTestCase):
dtype=jtu.dtypes.floating,
)
def testPareto(self, b, dtype):
key = self.make_key(0)
key = lambda: self.make_key(0)
rand = lambda key, b: random.pareto(key, b, (10000,), dtype)
crand = jax.jit(rand)
uncompiled_samples = rand(key, b)
compiled_samples = crand(key, b)
uncompiled_samples = rand(key(), b)
compiled_samples = crand(key(), b)
for samples in [uncompiled_samples, compiled_samples]:
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.pareto(b).cdf)
@ -742,12 +742,12 @@ class LaxRandomTest(jtu.JaxTestCase):
)
@jtu.skip_on_devices("cpu", "tpu") # TODO(phawkins): slow compilation times
def testT(self, df, dtype):
key = self.make_key(1)
key = lambda: self.make_key(1)
rand = lambda key, df: random.t(key, df, (10000,), dtype)
crand = jax.jit(rand)
uncompiled_samples = rand(key, df)
compiled_samples = crand(key, df)
uncompiled_samples = rand(key(), df)
compiled_samples = crand(key(), df)
for samples in [uncompiled_samples, compiled_samples]:
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.t(df).cdf)
@ -763,14 +763,14 @@ class LaxRandomTest(jtu.JaxTestCase):
cov_factor = r.randn(dim, dim)
cov = np.dot(cov_factor, cov_factor.T) + dim * np.eye(dim)
key = self.make_key(0)
key = lambda: self.make_key(0)
rand = partial(random.multivariate_normal, mean=mean, cov=cov,
shape=(10000,), method=method)
crand = jax.jit(rand)
with jax.numpy_rank_promotion('allow'):
uncompiled_samples = np.asarray(rand(key), np.float64)
compiled_samples = np.asarray(crand(key), np.float64)
uncompiled_samples = np.asarray(rand(key()), np.float64)
compiled_samples = np.asarray(crand(key()), np.float64)
inv_scale = scipy.linalg.lapack.dtrtri(np.linalg.cholesky(cov), lower=True)[0]
for samples in [uncompiled_samples, compiled_samples]:
@ -895,17 +895,17 @@ class LaxRandomTest(jtu.JaxTestCase):
def testRandomBroadcast(self):
"""Issue 4033"""
# test for broadcast issue in https://github.com/google/jax/issues/4033
key = self.make_key(0)
key = lambda: self.make_key(0)
shape = (10, 2)
with jax.numpy_rank_promotion('allow'):
x1 = random.uniform(key, shape, minval=jnp.zeros(2), maxval=jnp.ones(2))
x2 = random.randint(key, shape, jnp.array([0, 1]), jnp.array([1, 2]))
x1 = random.uniform(key(), shape, minval=jnp.zeros(2), maxval=jnp.ones(2))
x2 = random.randint(key(), shape, jnp.array([0, 1]), jnp.array([1, 2]))
assert x1.shape == shape
assert x2.shape == shape
def testMaxwellSample(self):
num_samples = 10**5
rng = self.make_key(0)
rng = lambda: self.make_key(0)
rand = lambda x: random.maxwell(x, (num_samples, ))
crand = jax.jit(rand)
@ -913,8 +913,8 @@ class LaxRandomTest(jtu.JaxTestCase):
loc = jtu.to_default_dtype(scipy.stats.maxwell.mean())
std = jtu.to_default_dtype(scipy.stats.maxwell.std())
uncompiled_samples = rand(rng)
compiled_samples = crand(rng)
uncompiled_samples = rand(rng())
compiled_samples = crand(rng())
for samples in [uncompiled_samples, compiled_samples]:
# Check first and second moments.
@ -928,7 +928,7 @@ class LaxRandomTest(jtu.JaxTestCase):
('test2', 2.0, 3.0))
def testWeibullSample(self, concentration, scale):
num_samples = 10**5
rng = self.make_key(0)
rng = lambda: self.make_key(0)
rand = lambda x: random.weibull_min(x, scale, concentration, (num_samples,))
crand = jax.jit(rand)
@ -936,8 +936,8 @@ class LaxRandomTest(jtu.JaxTestCase):
loc = jtu.to_default_dtype(scipy.stats.weibull_min.mean(c=concentration, scale=scale))
std = jtu.to_default_dtype(scipy.stats.weibull_min.std(c=concentration, scale=scale))
uncompiled_samples = rand(rng)
compiled_samples = crand(rng)
uncompiled_samples = rand(rng())
compiled_samples = crand(rng())
for samples in [uncompiled_samples, compiled_samples]:
# Check first and second moments.
@ -952,17 +952,17 @@ class LaxRandomTest(jtu.JaxTestCase):
('test2', 2.0, 3.0))
def testDoublesidedMaxwellSample(self, loc, scale):
num_samples = 10**4
rng = self.make_key(0)
rng = lambda: self.make_key(0)
rand = lambda key: random.double_sided_maxwell(
rng, loc, scale, (num_samples,))
rng(), loc, scale, (num_samples,))
crand = jax.jit(rand)
mean = loc
std = np.sqrt(3.) * scale
uncompiled_samples = rand(rng)
compiled_samples = crand(rng)
uncompiled_samples = rand(rng())
compiled_samples = crand(rng())
# Compute the double sided maxwell CDF through the one sided maxwell cdf.
# This is done as follows:
@ -989,14 +989,14 @@ class LaxRandomTest(jtu.JaxTestCase):
samples, lambda x: double_sided_maxwell_cdf(x, loc, scale))
def testRadamacher(self):
rng = self.make_key(0)
rng = lambda: self.make_key(0)
num_samples = 10**5
rand = lambda x: random.rademacher(x, (num_samples,))
crand = jax.jit(rand)
uncompiled_samples = rand(rng)
compiled_samples = crand(rng)
uncompiled_samples = rand(rng())
compiled_samples = crand(rng())
for samples in [uncompiled_samples, compiled_samples]:
unique_values, counts = np.unique(samples, return_counts=True)
@ -1052,24 +1052,25 @@ class LaxRandomTest(jtu.JaxTestCase):
def test_randint_bounds(self, dtype):
min = np.iinfo(dtype).min
max = np.iinfo(dtype).max
key = self.make_key(1701)
key = lambda: self.make_key(1701)
shape = (10,)
if np.iinfo(dtype).bits < np.iinfo(dtypes.canonicalize_dtype(int)).bits:
expected = random.randint(key, shape, min, max + 1, dtype)
self.assertArraysEqual(expected, random.randint(key, shape, min - 12345, max + 12345, dtype))
expected = random.randint(key(), shape, min, max + 1, dtype)
self.assertArraysEqual(expected, random.randint(key(), shape, min - 12345, max + 12345, dtype))
else:
self.assertRaises(OverflowError, random.randint, key, shape, min - 12345, max + 12345, dtype)
self.assertRaises(OverflowError, random.randint, key(), shape, min - 12345, max + 12345, dtype)
def test_randint_out_of_range(self):
key = self.make_key(0)
r = random.randint(key, (10,), 255, 256, np.uint8)
self.assertAllClose(r, jnp.full_like(r, 255))
key = self.make_key(0)
r = random.randint(key, (1000,), -128, 128, np.int8)
self.assertGreater((r == -128).sum(), 0)
self.assertGreater((r == 127).sum(), 0)
key = self.make_key(0)
r = random.randint(key, (1000,), -1000, 1000, np.uint8)
self.assertGreater((r == 0).sum(), 0)
self.assertGreater((r == 255).sum(), 0)
@ -1103,14 +1104,14 @@ class LaxRandomTest(jtu.JaxTestCase):
df = [0.2, 1., 10., 100.],
dtype=jtu.dtypes.floating)
def testChisquare(self, df, dtype):
key = self.make_key(1)
key = lambda: self.make_key(1)
def rand(key, df):
return random.chisquare(key, df, shape=(10000,), dtype=dtype)
crand = jax.jit(rand)
uncompiled_samples = rand(key, df)
compiled_samples = crand(key, df)
uncompiled_samples = rand(key(), df)
compiled_samples = crand(key(), df)
for samples in [uncompiled_samples, compiled_samples]:
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.chi2(df).cdf)
@ -1120,12 +1121,12 @@ class LaxRandomTest(jtu.JaxTestCase):
dfden = [1. ,2., 10., 100.],
dtype=jtu.dtypes.floating)
def testF(self, dfnum, dfden, dtype):
key = self.make_key(9)
key = lambda: self.make_key(9)
rand = lambda key: random.f(key, dfnum, dfden, shape = (10000, ), dtype = dtype)
crand = jax.jit(rand)
uncompiled_samples = rand(key)
compiled_samples = crand(key)
uncompiled_samples = rand(key())
compiled_samples = crand(key())
for samples in [uncompiled_samples, compiled_samples]:
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.f(dfnum, dfden).cdf)
@ -1134,12 +1135,12 @@ class LaxRandomTest(jtu.JaxTestCase):
scale= [0.2, 1., 2., 10. ,100.],
dtype=jtu.dtypes.floating)
def testRayleigh(self, scale, dtype):
key = self.make_key(0)
key = lambda: self.make_key(0)
rand = lambda key: random.rayleigh(key, scale, shape = (10000, ), dtype = dtype)
crand = jax.jit(rand)
uncompiled_samples = rand(key)
compiled_samples = crand(key)
uncompiled_samples = rand(key())
compiled_samples = crand(key())
for samples in [uncompiled_samples, compiled_samples]:
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.rayleigh(scale=scale).cdf)
@ -1148,12 +1149,12 @@ class LaxRandomTest(jtu.JaxTestCase):
mean= [0.2, 1., 2., 10. ,100.],
dtype=jtu.dtypes.floating)
def testWald(self, mean, dtype):
key = self.make_key(0)
key = lambda: self.make_key(0)
rand = lambda key: random.wald(key, mean, shape=(10000, ), dtype=dtype)
crand = jax.jit(rand)
uncompiled_samples = rand(key)
compiled_samples = crand(key)
uncompiled_samples = rand(key())
compiled_samples = crand(key())
for samples in [uncompiled_samples, compiled_samples]:
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.invgauss(mu=mean).cdf)
@ -1162,12 +1163,12 @@ class LaxRandomTest(jtu.JaxTestCase):
p=[0.2, 0.3, 0.4, 0.5 ,0.6],
dtype=jtu.dtypes.supported([np.int16, np.int32, np.int64]))
def testGeometric(self, p, dtype):
key = self.make_key(1)
key = lambda: self.make_key(1)
rand = lambda key: random.geometric(key, p, shape=(10000, ), dtype=dtype)
crand = jax.jit(rand)
uncompiled_samples = rand(key)
compiled_samples = crand(key)
uncompiled_samples = rand(key())
compiled_samples = crand(key())
for samples in [uncompiled_samples, compiled_samples]:
self._CheckChiSquared(samples, scipy.stats.geom(p).pmf)
@ -1181,13 +1182,13 @@ class LaxRandomTest(jtu.JaxTestCase):
right= [10., 20., 30., 40.],
dtype= jtu.dtypes.floating)
def testTriangular(self, left, mode, right, dtype):
key = self.make_key(1)
key = lambda: self.make_key(1)
rand = lambda key: random.triangular(key, left, mode, right, shape=(10000,),
dtype=dtype)
crand = jax.jit(rand)
uncompiled_samples = rand(key)
compiled_samples = crand(key)
uncompiled_samples = rand(key())
compiled_samples = crand(key())
for samples in [uncompiled_samples, compiled_samples]:
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.triang(
@ -1197,12 +1198,12 @@ class LaxRandomTest(jtu.JaxTestCase):
sigma = [0.2, 0.5, 1., 2.],
dtype=jtu.dtypes.floating)
def testLogNormal(self, sigma, dtype):
key = self.make_key(0)
key = lambda: self.make_key(0)
rand = lambda key: random.lognormal(key, sigma, shape=(10000,), dtype=dtype)
crand = jax.jit(rand)
uncompiled_samples = rand(key)
compiled_samples = crand(key)
uncompiled_samples = rand(key())
compiled_samples = crand(key())
for samples in [uncompiled_samples, compiled_samples]:
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.lognorm(s=sigma).cdf)
@ -1212,11 +1213,11 @@ class LaxRandomTest(jtu.JaxTestCase):
p= [0.1, 0.3, 0.5, 0.7, 0.9],
dtype= jtu.dtypes.floating)
def testBinomialSample(self, n, p, dtype):
key = self.make_key(12)
key = lambda: self.make_key(12)
rand = lambda key: random.binomial(key, n, p, shape=(12000,), dtype=dtype)
crand = jax.jit(rand)
uncompiled_samples = rand(key)
compiled_samples = crand(key)
uncompiled_samples = rand(key())
compiled_samples = crand(key())
pmf = lambda x: scipy.stats.binom(n, p).pmf(x)
@ -1227,52 +1228,52 @@ class LaxRandomTest(jtu.JaxTestCase):
check_dtypes=False)
def testBinomialCornerCases(self):
key = self.make_key(0)
key = lambda: self.make_key(0)
# corner case n
n = jnp.array([-1, 0, jnp.nan, jnp.inf])
samples1 = random.binomial(key, n, 0.5, shape=(4,))
samples1 = random.binomial(key(), n, 0.5, shape=(4,))
# corner case p
p = jnp.array([jnp.nan, 0, -0.1, 1.1])
samples2 = random.binomial(key, 5, p, shape=(4,))
samples2 = random.binomial(key(), 5, p, shape=(4,))
# corner case n and p
# expect nan or illegal will lead to nan
n_cc = jnp.array([jnp.inf, -1, jnp.inf])
p_cc = jnp.array([jnp.nan, jnp.nan, -0.1])
samples3 = random.binomial(key, n_cc, p_cc, shape=(3,))
samples3 = random.binomial(key(), n_cc, p_cc, shape=(3,))
self.assertArraysAllClose(samples1, jnp.array([jnp.nan, 0., jnp.nan, jnp.inf]), check_dtypes=False)
self.assertArraysAllClose(samples2, jnp.array([jnp.nan, 0., jnp.nan, jnp.nan]), check_dtypes=False)
self.assertArraysAllClose(samples3, jnp.array([jnp.nan, jnp.nan, jnp.nan]), check_dtypes=False)
def test_batched_key_warnings(self):
keys = jax.random.split(self.make_key(0))
keys = lambda: jax.random.split(self.make_key(0))
msg = "{} accepts a single key, but was given a key array of shape.*"
# Check a handful of functions that are expected to warn.
with self.assertWarnsRegex(FutureWarning, msg.format('bits')):
jax.random.bits(keys, shape=(2,))
jax.random.bits(keys(), shape=(2,))
with self.assertWarnsRegex(FutureWarning, msg.format('chisquare')):
jax.random.chisquare(keys, 1.0, shape=(2,))
jax.random.chisquare(keys(), 1.0, shape=(2,))
with self.assertWarnsRegex(FutureWarning, msg.format('dirichlet')):
jax.random.dirichlet(keys, jnp.arange(2.0), shape=(2,))
jax.random.dirichlet(keys(), jnp.arange(2.0), shape=(2,))
with self.assertWarnsRegex(FutureWarning, msg.format('gamma')):
jax.random.gamma(keys, 1.0, shape=(2,))
jax.random.gamma(keys(), 1.0, shape=(2,))
with self.assertWarnsRegex(FutureWarning, msg.format('loggamma')):
jax.random.loggamma(keys, 1.0, shape=(2,))
jax.random.loggamma(keys(), 1.0, shape=(2,))
# Other functions should error; test a few cases.
with self.assertRaisesRegex(ValueError, msg.format('fold_in')):
jax.random.fold_in(keys, 0)
jax.random.fold_in(keys(), 0)
with self.assertRaisesRegex(ValueError, msg.format('split')):
jax.random.split(keys)
jax.random.split(keys())
# Some shouldn't error or warn
with self.assertNoWarnings():
jax.random.key_data(keys)
jax.random.key_impl(keys)
jax.random.key_data(keys())
jax.random.key_impl(keys())
threefry_seed = prng_internal.threefry_seed
@ -1323,28 +1324,28 @@ class LaxRandomWithCustomPRNGTest(LaxRandomTest):
def test_vmap_fold_in_shape(self):
# broadcast with scalar
keys = random.split(self.make_key(73), 2)
keys = lambda: random.split(self.make_key(73), 2)
msgs = jnp.arange(3)
out = vmap(lambda i: random.fold_in(keys[0], i))(msgs)
out = vmap(lambda i: random.fold_in(keys()[0], i))(msgs)
self.assertEqual(out.shape, (3,))
out = vmap(lambda k: random.fold_in(k, msgs[0]))(keys)
out = vmap(lambda k: random.fold_in(k, msgs[0]))(keys())
self.assertEqual(out.shape, (2,))
out = vmap(random.fold_in, in_axes=(None, 0))(keys[0], msgs)
out = vmap(random.fold_in, in_axes=(None, 0))(keys()[0], msgs)
self.assertEqual(out.shape, (3,))
out = vmap(random.fold_in, in_axes=(0, None))(keys, msgs[0])
out = vmap(random.fold_in, in_axes=(0, None))(keys(), msgs[0])
self.assertEqual(out.shape, (2,))
# vmap all
msgs = jnp.arange(2)
out = vmap(random.fold_in)(keys, msgs)
out = vmap(random.fold_in)(keys(), msgs)
self.assertEqual(out.shape, (2,))
# nested vmap
keys = random.split(self.make_key(73), 2 * 3).reshape((2, 3))
keys = lambda: random.split(self.make_key(73), 2 * 3).reshape((2, 3))
msgs = jnp.arange(2 * 3).reshape((2, 3))
out = vmap(vmap(random.fold_in), in_axes=(0, 1))(keys, msgs.T)
out = vmap(vmap(random.fold_in), in_axes=(0, 1))(keys(), msgs.T)
self.assertEqual(out.shape, (2, 3))
out = vmap(vmap(random.fold_in), in_axes=(1, 0))(keys, msgs.T)
out = vmap(vmap(random.fold_in), in_axes=(1, 0))(keys(), msgs.T)
self.assertEqual(out.shape, (3, 2))
def test_vmap_split_mapped_key(self):
@ -1381,6 +1382,7 @@ class LaxRandomWithRBGPRNGTest(LaxRandomTest):
keys = random.split(key, 10)
self.assertEqual(keys.shape, (10, *key.shape))
@jax.enable_key_reuse_checks(False)
def test_vmap_fold_in_shape(self):
# broadcast with scalar
keys = random.split(self.make_key(73), 2)
@ -1396,6 +1398,7 @@ class LaxRandomWithRBGPRNGTest(LaxRandomTest):
out = vmap(random.fold_in, in_axes=(0, None))(keys, msgs[0])
self.assertEqual(out.shape, keys.shape)
@jax.enable_key_reuse_checks(False)
def test_vmap_split_not_mapped_key(self):
key = self.make_key(73)
single_split_key = random.split(key)

View File

@ -253,22 +253,22 @@ class PrngTest(jtu.JaxTestCase):
@parameterized.parameters([{'make_key': ctor} for ctor in KEY_CTORS])
def testRngRandomBits(self, make_key):
# Test specific outputs to ensure consistent random values between JAX versions.
key = make_key(1701)
seed = 1701
bits8 = random.bits(key, (3,), 'uint8')
bits8 = random.bits(make_key(seed), (3,), 'uint8')
expected8 = np.array([216, 115, 43], dtype=np.uint8)
self.assertArraysEqual(bits8, expected8)
bits16 = random.bits(key, (3,), 'uint16')
bits16 = random.bits(make_key(seed), (3,), 'uint16')
expected16 = np.array([41682, 1300, 55017], dtype=np.uint16)
self.assertArraysEqual(bits16, expected16)
bits32 = random.bits(key, (3,), 'uint32')
bits32 = random.bits(make_key(seed), (3,), 'uint32')
expected32 = np.array([56197195, 4200222568, 961309823], dtype=np.uint32)
self.assertArraysEqual(bits32, expected32)
with jtu.ignore_warning(category=UserWarning, message="Explicitly requested dtype.*"):
bits64 = random.bits(key, (3,), 'uint64')
bits64 = random.bits(make_key(seed), (3,), 'uint64')
if config.enable_x64.value:
expected64 = np.array([3982329540505020460, 16822122385914693683,
7882654074788531506], dtype=np.uint64)
@ -287,23 +287,23 @@ class PrngTest(jtu.JaxTestCase):
dtype = jnp.dtype(f'uint{width}')
return jax.random.bits(key, shape, dtype)
with jax.default_prng_impl(prng_name):
key = make_key(1701)
seed = 1701
bits8 = random_bits(key, 8, (3,))
with jax.default_prng_impl(prng_name):
bits8 = random_bits(make_key(seed), 8, (3,))
self.assertEqual(bits8.shape, (3,))
self.assertEqual(bits8.dtype, np.dtype('uint8'))
bits16 = random_bits(key, 16, (3,))
bits16 = random_bits(make_key(seed), 16, (3,))
self.assertEqual(bits16.shape, (3,))
self.assertEqual(bits16.dtype, np.dtype('uint16'))
bits32 = random_bits(key, 32, (3,))
bits32 = random_bits(make_key(seed), 32, (3,))
self.assertEqual(bits32.shape, (3,))
self.assertEqual(bits32.dtype, np.dtype('uint32'))
with jtu.ignore_warning(category=UserWarning, message="Explicitly requested dtype.*"):
bits64 = random_bits(key, 64, (3,))
bits64 = random_bits(make_key(seed), 64, (3,))
expected_dtype = np.dtype('uint64' if config.enable_x64.value else 'uint32')
self.assertEqual(bits64.shape, (3,))
self.assertEqual(bits64.dtype, expected_dtype)
@ -319,9 +319,8 @@ class PrngTest(jtu.JaxTestCase):
return jax.random.bits(key, shape, dtype)
N = 10
key = make_key(1701)
nbits = [8, 16, 32]
rand_bits = [random_bits(key, n, (N * 64 // n,)) for n in nbits]
rand_bits = [random_bits(make_key(1701), n, (N * 64 // n,)) for n in nbits]
rand_bits_32 = np.array([np.array(r).view(np.uint32) for r in rand_bits])
assert np.all(rand_bits_32 == rand_bits_32[0])
@ -361,31 +360,30 @@ class PrngTest(jtu.JaxTestCase):
@parameterized.parameters([{'make_key': ctor} for ctor in KEY_CTORS])
def testPRNGValues(self, make_key):
# Test to ensure consistent random values between JAX versions
k = make_key(0)
self.assertEqual(random.randint(k, (3, 3), 0, 8).dtype,
seed = 0
self.assertEqual(random.randint(make_key(seed), (3, 3), 0, 8).dtype,
dtypes.canonicalize_dtype(jnp.int_))
if config.enable_x64.value:
self.assertAllClose(
random.randint(k, (3, 3), 0, 8, dtype='int64'),
random.randint(make_key(seed), (3, 3), 0, 8, dtype='int64'),
np.array([[7, 2, 6],
[2, 1, 0],
[6, 7, 7]], dtype='int64'))
self.assertAllClose(
random.randint(k, (3, 3), 0, 8, dtype='int32'),
random.randint(make_key(seed), (3, 3), 0, 8, dtype='int32'),
np.array([[2, 1, 3],
[6, 1, 5],
[6, 3, 4]], dtype='int32'))
self.assertAllClose(
random.key_data(random.split(k, 4)),
random.key_data(random.split(make_key(seed), 4)),
np.array([[2285895361, 1501764800],
[1518642379, 4090693311],
[ 433833334, 4221794875],
[ 839183663, 3740430601]], dtype='uint32'))
self.assertAllClose(
random.key_data(random.fold_in(k, 4)),
random.key_data(random.fold_in(make_key(seed), 4)),
np.array([2285895361, 433833334], dtype='uint32'))
@parameterized.parameters([{'make_key': ctor} for ctor in KEY_CTORS])
@ -645,7 +643,7 @@ class KeyArrayTest(jtu.JaxTestCase):
def f(k):
g.append(k.dtype)
return random.split(k)
_ = jax.jit(f)(k1)
_ = jax.jit(f)(self.make_keys())
self.assertEqual(g[0], k1.dtype)
self.assertEqual(g[0], k2.dtype)
@ -670,6 +668,8 @@ class KeyArrayTest(jtu.JaxTestCase):
self.assertKeysEqual(key, copy.deepcopy(key))
self.assertKeysEqual(key, jax.jit(lambda k: k.copy())(key))
# TODO(jakevdp) remove this decorator when reuse checks move to C++
@jax.enable_key_reuse_checks(False)
def test_cpp_dispatch_normal(self):
# Ensure we stay on the C++ dispatch path when calling a jitted
# function with a key array as an argument.
@ -685,6 +685,8 @@ class KeyArrayTest(jtu.JaxTestCase):
self.assertEqual(count[0], 1)
# TODO(jakevdp) remove this decorator when reuse checks move to C++
@jax.enable_key_reuse_checks(False)
def test_cpp_dispatch_split(self):
# Ensure we stay on the C++ dispatch path when calling a jitted
# function with a key arrays as inputs and as outputs.
@ -743,17 +745,17 @@ class KeyArrayTest(jtu.JaxTestCase):
def test_random_unwrap(self, use_internal):
unwrap = prng_internal.random_unwrap if use_internal else random.key_data
def f(k): return unwrap(k)
k = self.make_keys(3, 4)
out = f(k)
keys = lambda: self.make_keys(3, 4)
out = f(keys())
self.assertEqual(out.dtype, np.dtype('uint32'))
self.assertEqual(out.shape[:2], (3, 4))
out = jax.jit(f)(k)
out = jax.jit(f)(keys())
self.assertEqual(out.dtype, np.dtype('uint32'))
self.assertEqual(out.shape[:2], (3, 4))
out = jax.vmap(f)(k)
out = jax.vmap(f)(keys())
self.assertEqual(out.dtype, np.dtype('uint32'))
self.assertEqual(out.shape[:2], (3, 4))
out = jax.vmap(jax.jit(f))(k)
out = jax.vmap(jax.jit(f))(keys())
self.assertEqual(out.dtype, np.dtype('uint32'))
self.assertEqual(out.shape[:2], (3, 4))
@ -864,26 +866,25 @@ class KeyArrayTest(jtu.JaxTestCase):
self.assertEqual(ys.shape, (4, 3))
def test_gather(self):
ks = self.make_keys(3, 4)
ys = jax.jit(lambda x: x[1])(ks)
keys = self.make_keys(3, 4)
ys = jax.jit(lambda x: x[1])(keys)
self.assertIsInstance(ys, prng_internal.PRNGKeyArray)
self.assertEqual(ys.shape, (4,))
ks = self.make_keys(3, 4, 5)
ys = jax.jit(lambda x: x[1])(ks)
keys = lambda: self.make_keys(3, 4, 5)
ys = jax.jit(lambda x: x[1])(keys())
self.assertIsInstance(ys, prng_internal.PRNGKeyArray)
self.assertEqual(ys.shape, (4, 5))
ys = jax.jit(lambda x: x[1, 2:4])(ks)
ys = jax.jit(lambda x: x[1, 2:4])(keys())
self.assertIsInstance(ys, prng_internal.PRNGKeyArray)
self.assertEqual(ys.shape, (2, 5))
ys = jax.jit(lambda x: x[1, 2:4, 3])(ks)
ys = jax.jit(lambda x: x[1, 2:4, 3])(keys())
self.assertIsInstance(ys, prng_internal.PRNGKeyArray)
self.assertEqual(ys.shape, (2,))
ys = jax.jit(lambda x: x[:, 2:4, 3:4])(ks)
ys = jax.jit(lambda x: x[:, 2:4, 3:4])(keys())
self.assertIsInstance(ys, prng_internal.PRNGKeyArray)
self.assertEqual(ys.shape, (3, 2, 1))
@ -966,9 +967,8 @@ class KeyArrayTest(jtu.JaxTestCase):
return primal_out, tangent_out
key_dot = None
key = self.make_keys()
default_result = jax.grad(f_raw)(0.0, key)
custom_result = jax.grad(f)(0.0, key)
default_result = jax.grad(f_raw)(0.0, self.make_keys())
custom_result = jax.grad(f)(0.0, self.make_keys())
self.assertAllClose(default_result, custom_result)
self.assertEqual(key_dot.dtype, dtypes.float0)
@ -977,27 +977,28 @@ class KeyArrayTest(jtu.JaxTestCase):
key = self.make_keys()
self.assertEqual(key.shape, ())
self.assertEqual(key[None].shape, (1,))
key = self.make_keys()
self.assertRaisesRegex(IndexError, 'Too many indices.*', lambda: key[0])
def test_key_array_indexing_nd(self):
keys = self.make_keys(2, 3)
self.assertEqual(keys.shape, (2, 3))
self.assertEqual(keys[0, 0].shape, ())
self.assertEqual(keys[0, 1].shape, ())
self.assertEqual(keys[0].shape, (3,))
self.assertEqual(keys[1, :].shape, (3,))
self.assertEqual(keys[:, 1].shape, (2,))
self.assertEqual(keys[None].shape, (1, 2, 3))
self.assertEqual(keys[None, None].shape, (1, 1, 2, 3))
self.assertEqual(keys[None, :, None].shape, (1, 2, 1, 3))
self.assertEqual(keys[None, None, None, 0, None, None, None, 1].shape,
keys = lambda: self.make_keys(2, 3)
self.assertEqual(keys().shape, (2, 3))
self.assertEqual(keys()[0, 0].shape, ())
self.assertEqual(keys()[0, 1].shape, ())
self.assertEqual(keys()[0].shape, (3,))
self.assertEqual(keys()[1, :].shape, (3,))
self.assertEqual(keys()[:, 1].shape, (2,))
self.assertEqual(keys()[None].shape, (1, 2, 3))
self.assertEqual(keys()[None, None].shape, (1, 1, 2, 3))
self.assertEqual(keys()[None, :, None].shape, (1, 2, 1, 3))
self.assertEqual(keys()[None, None, None, 0, None, None, None, 1].shape,
(1,) * 6)
self.assertEqual(keys[..., 1:, None].shape, (2, 2, 1))
self.assertEqual(keys[None, 0, ..., 1, None].shape, (1, 1))
self.assertEqual(keys()[..., 1:, None].shape, (2, 2, 1))
self.assertEqual(keys()[None, 0, ..., 1, None].shape, (1, 1))
self.assertRaisesRegex(IndexError, 'Too many indices.*',
lambda: keys[0, 1, 2])
lambda: keys()[0, 1, 2])
self.assertRaisesRegex(IndexError, 'Too many indices.*',
lambda: keys[0, 1, None, 2])
lambda: keys()[0, 1, None, 2])
def test_not_hashable(self):
key = self.make_keys()
@ -1222,13 +1223,12 @@ class JnpWithKeyArrayTest(jtu.JaxTestCase):
self.check_against_reference(key_func, arr_func, key)
def test_concatenate(self):
key = random.key(123)
args = [random.split(k, 2) for k in random.split(key, 3)]
args = lambda: [random.split(k, 2) for k in random.split(random.key(123), 3)]
key_func = arr_func = partial(jnp.concatenate, axis=0)
self.check_shape(key_func, args)
self.check_against_reference(key_func, arr_func, args)
self.check_shape(key_func, args())
self.check_against_reference(key_func, arr_func, args())
def test_broadcast_to(self):
key = random.key(123)
@ -1259,14 +1259,16 @@ class JnpWithKeyArrayTest(jtu.JaxTestCase):
self.check_against_reference(key_func, arr_func, key, keys)
def test_append(self):
key = random.key(123)
keys = random.split(key, 4)
key = lambda: random.key(123)
keys = lambda: random.split(random.key(123), 4)
key_func = jnp.append
arr_func = lambda keys, key: jnp.append(keys, key[None], axis=0)
self.check_shape(key_func, keys, key)
self.check_against_reference(key_func, arr_func, keys, key)
self.check_shape(key_func, keys(), key())
self.check_shape(arr_func, keys(), key())
with jax.enable_key_reuse_checks(False):
self.check_against_reference(key_func, arr_func, keys(), key())
def test_ravel(self):
key = random.key(123)
@ -1306,13 +1308,12 @@ class JnpWithKeyArrayTest(jtu.JaxTestCase):
(np.array([False, True, True]),)
])
def test_getitem(self, idx):
key = random.key(123)
keys = random.split(key, 3)
keys = lambda: random.split(random.key(123), 3)
key_func = arr_func = lambda x: x[idx]
self.check_shape(key_func, keys)
self.check_against_reference(key_func, arr_func, keys)
self.check_shape(key_func, keys())
with jax.enable_key_reuse_checks(False):
self.check_against_reference(key_func, arr_func, keys())
@parameterized.parameters([
(0,),
@ -1321,14 +1322,14 @@ class JnpWithKeyArrayTest(jtu.JaxTestCase):
(np.array([False, True, True]),)
])
def test_gather(self, idx):
key = random.key(123)
keys = random.split(key, 3)
keys = lambda: random.split(random.key(123), 3)
key_func = arr_func = lambda key: key.at[idx].get()
key_func = arr_func = lambda x: x.at[idx].get()
self.check_shape(key_func, keys)
self.check_against_reference(key_func, arr_func, keys)
self.check_shape(key_func, keys())
with jax.enable_key_reuse_checks(False):
self.check_against_reference(key_func, arr_func, keys())
@jax.enable_key_reuse_checks(False)
def test_equality(self):
key = random.key(123)
key2 = random.key(456)
@ -1354,13 +1355,13 @@ class JnpWithKeyArrayTest(jtu.JaxTestCase):
(np.array([False, True, True]),)
])
def test_scatter(self, idx):
key = random.key(123)
keys = random.split(key, 3)
key = lambda: random.key(123)
keys = lambda: random.split(key(), 3)
key_func = arr_func = lambda x, y: x.at[idx].set(y)
key_func = arr_func = lambda k1, k2: k1.at[idx].set(k2)
self.check_shape(key_func, keys, key)
self.check_against_reference(key_func, arr_func, keys, key)
self.check_shape(key_func, keys(), key())
self.check_against_reference(key_func, arr_func, keys(), key())
def test_errors(self):
key = random.key(123)

View File

@ -1475,17 +1475,18 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
ndim = shape[0] if len(shape) > 1 else 1
args = args_maker()
func = partial(resample, shape=())
self._CompileAndCheck(
func, args_maker, rtol={np.float32: 3e-07, np.float64: 4e-15})
result = func(*args)
with jax.enable_key_reuse_checks(False):
self._CompileAndCheck(
func, args_maker, rtol={np.float32: 3e-07, np.float64: 4e-15})
result = func(*args_maker())
assert result.shape == (ndim,)
func = partial(resample, shape=(4,))
self._CompileAndCheck(
func, args_maker, rtol={np.float32: 3e-07, np.float64: 4e-15})
result = func(*args)
with jax.enable_key_reuse_checks(False):
self._CompileAndCheck(
func, args_maker, rtol={np.float32: 3e-07, np.float64: 4e-15})
result = func(*args_maker())
assert result.shape == (ndim, 4)
@jtu.sample_product(

View File

@ -29,6 +29,7 @@ from jax._src import core
from jax._src import config
from jax._src import linear_util as lu
from jax._src.interpreters import partial_eval as pe
from jax._src import prng
from jax._src import test_util as jtu
from jax._src.util import tuple_insert
import jax.numpy as jnp
@ -1696,8 +1697,8 @@ if CAN_USE_HYPOTHESIS:
y, impl_vjp = jax.vjp(impl, x)
y_ref, ref_vjp = jax.vjp(ref, x)
self.assertAllClose(y, y_ref)
t = random.normal(k2, x.shape)
y2 = random.normal(k1, y.shape)
t = random.normal(prng.reuse_key(k2), x.shape)
y2 = random.normal(prng.reuse_key(k1), y.shape)
self.assertAllClose(impl_vjp(t), ref_vjp(t))
# Second order
@ -1713,7 +1714,7 @@ if CAN_USE_HYPOTHESIS:
(x,), impl_vjp2 = jax.vjp(impl_vjp, t2)
(x_ref,), ref_vjp2 = jax.vjp(ref_vjp, t2)
self.assertAllClose(x, x_ref)
y2 = random.normal(k1, y.shape)
y2 = random.normal(prng.reuse_key(k1), y.shape)
self.assertAllClose(impl_vjp2((y2,)), ref_vjp2((y2,)))
if __name__ == '__main__':

View File

@ -112,6 +112,7 @@ class X64ContextTests(jtu.JaxTestCase):
self.assertEqual(x32.result(), jnp.int32)
@jax.legacy_prng_key('allow')
@jax.enable_key_reuse_checks(False)
@jtu.ignore_warning(category=UserWarning,
message="Explicitly requested dtype float64 is not available")
def test_jit_cache(self):