mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
[key reuse] add eager checks
This commit is contained in:
parent
087f99a31c
commit
d08e9a03d8
@ -932,8 +932,11 @@ enable_checks = define_bool_state(
|
||||
enable_key_reuse_checks = define_bool_state(
|
||||
name='jax_enable_key_reuse_checks',
|
||||
default=False,
|
||||
help="Turn on experimental key reuse checking."
|
||||
)
|
||||
help=('Turn on experimental key reuse checking. With this configuration enabled,'
|
||||
' typed PRNG keys (i.e. keys created with jax.random.key()) will have their'
|
||||
' usage tracked, and incorrect reuse of a previously-used key will lead to'
|
||||
' an error. Currently enabling this leads to a small Python overhead on'
|
||||
' every call to a JIT-compiled function with keys as inputs or outputs.'))
|
||||
|
||||
check_tracer_leaks = define_bool_state(
|
||||
name='jax_check_tracer_leaks',
|
||||
|
@ -33,6 +33,7 @@ from jax._src import api_util
|
||||
from jax._src import config
|
||||
from jax._src import core
|
||||
from jax._src import dispatch
|
||||
from jax._src import dtypes
|
||||
from jax._src import linear_util as lu
|
||||
from jax._src import mesh as mesh_lib
|
||||
from jax._src import op_shardings
|
||||
@ -181,8 +182,7 @@ def _python_pjit(fun: Callable, infer_params_fn):
|
||||
|
||||
def _get_fastpath_data(executable, out_tree, args_flat, out_flat, attrs_tracked,
|
||||
) -> Optional[pxla.MeshExecutableFastpathData]:
|
||||
out_flat, out_tree = pxla.reflatten_outputs_for_dispatch(out_tree, out_flat)
|
||||
|
||||
out_reflattened, out_tree = pxla.reflatten_outputs_for_dispatch(out_tree, out_flat)
|
||||
use_fastpath = (
|
||||
executable is not None and
|
||||
isinstance(executable, pxla.MeshExecutable) and
|
||||
@ -191,14 +191,19 @@ def _get_fastpath_data(executable, out_tree, args_flat, out_flat, attrs_tracked,
|
||||
not executable.unsafe_call.ordered_effects and
|
||||
not executable.unsafe_call.has_unordered_effects and
|
||||
not executable.unsafe_call.has_host_callbacks and
|
||||
all(isinstance(x, xc.ArrayImpl) for x in out_flat) and
|
||||
all(isinstance(x, xc.ArrayImpl) for x in out_reflattened) and
|
||||
# no attr state effects
|
||||
not attrs_tracked
|
||||
not attrs_tracked and
|
||||
# no prng reuse checking
|
||||
not (config.enable_key_reuse_checks.value and any(
|
||||
hasattr(arg, 'dtype') and dtypes.issubdtype(arg.dtype, dtypes.prng_key)
|
||||
for arg in (*args_flat, *out_flat)
|
||||
))
|
||||
)
|
||||
|
||||
if use_fastpath:
|
||||
out_avals = [o.aval for o in out_flat]
|
||||
out_committed = [o._committed for o in out_flat]
|
||||
out_avals = [o.aval for o in out_reflattened]
|
||||
out_committed = [o._committed for o in out_reflattened]
|
||||
kept_var_bitvec = [i in executable._kept_var_idx
|
||||
for i in range(len(args_flat))]
|
||||
fastpath_data = pxla.MeshExecutableFastpathData(
|
||||
|
@ -153,12 +153,14 @@ class PRNGKeyArray(jax.Array):
|
||||
|
||||
_impl: PRNGImpl
|
||||
_base_array: typing.Array
|
||||
_consumed: bool | np.ndarray # Used in jax.experimental.key_reuse.
|
||||
|
||||
def __init__(self, impl, key_data: Any):
|
||||
assert not isinstance(key_data, core.Tracer)
|
||||
_check_prng_key_data(impl, key_data)
|
||||
self._impl = impl
|
||||
self._base_array = key_data
|
||||
self._consumed = False # TODO(jakevdp): default to True here?
|
||||
|
||||
def block_until_ready(self):
|
||||
_ = self._base_array.block_until_ready()
|
||||
@ -269,7 +271,9 @@ class PRNGKeyArray(jax.Array):
|
||||
pp.nest(2, pp.brk() + pp_keys + pp.brk() + pp_impl)))
|
||||
|
||||
def copy(self):
|
||||
return self.__class__(self._impl, self._base_array.copy())
|
||||
out = self.__class__(self._impl, self._base_array.copy())
|
||||
out._consumed = self._consumed # TODO(jakevdp): is this correct?
|
||||
return out
|
||||
|
||||
__hash__ = None # type: ignore[assignment]
|
||||
__array_priority__ = 100
|
||||
|
@ -17,6 +17,7 @@ from __future__ import annotations
|
||||
from typing import NamedTuple
|
||||
from jax import core
|
||||
from jax.interpreters import batching, mlir
|
||||
from jax._src import prng
|
||||
import numpy as np
|
||||
|
||||
|
||||
@ -51,6 +52,33 @@ class KeyReuseSignature(NamedTuple):
|
||||
sources: list[Source]
|
||||
forwards: list[Forward] = []
|
||||
|
||||
def check_signature(self, *args, jaxpr=None):
|
||||
for sink in self.sinks:
|
||||
if not isinstance(args[sink.idx], prng.PRNGKeyArray):
|
||||
continue
|
||||
if np.any(args[sink.idx]._consumed & sink.mask):
|
||||
msg = f"Previously-consumed key at index {sink.idx} passed to function"
|
||||
if jaxpr:
|
||||
msg += f"\n{jaxpr=}"
|
||||
raise KeyReuseError(msg)
|
||||
|
||||
def update_consumption(self, args_in, args_out):
|
||||
for sink in self.sinks:
|
||||
arg = args_in[sink.idx]
|
||||
if isinstance(arg, prng.PRNGKeyArray):
|
||||
arg._consumed = arg._consumed | sink.mask
|
||||
for arg in args_out:
|
||||
if isinstance(arg, prng.PRNGKeyArray):
|
||||
arg._consumed = True
|
||||
for source in self.sources:
|
||||
if isinstance(args_out[source.idx], prng.PRNGKeyArray):
|
||||
args_out[source.idx]._consumed = ~np.asarray(source.mask)
|
||||
for forward in self.forwards:
|
||||
arg_in = args_in[forward.in_idx]
|
||||
arg_out = args_out[forward.out_idx]
|
||||
if isinstance(arg_in, prng.PRNGKeyArray) and isinstance(arg_out, prng.PRNGKeyArray):
|
||||
arg_out._consumed = arg_in._consumed
|
||||
|
||||
|
||||
class KeyReuseError(RuntimeError):
|
||||
pass
|
||||
|
@ -15,13 +15,14 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import defaultdict
|
||||
from functools import reduce
|
||||
from functools import partial, reduce, wraps
|
||||
from typing import Any, Callable
|
||||
|
||||
import jax
|
||||
from jax import lax
|
||||
from jax import tree_util
|
||||
from jax._src import api_util
|
||||
from jax._src import config
|
||||
from jax._src import core
|
||||
from jax._src import linear_util as lu
|
||||
from jax._src import pjit
|
||||
@ -32,11 +33,13 @@ from jax._src import util
|
||||
from jax._src.ad_checkpoint import remat_p
|
||||
from jax._src.debugging import debug_callback_p
|
||||
from jax._src.interpreters import partial_eval as pe
|
||||
from jax._src.util import weakref_lru_cache
|
||||
|
||||
from jax.experimental.key_reuse._common import (
|
||||
consume_p, assert_consumed_value_p, KeyReuseError,
|
||||
Sink, Source, Forward, KeyReuseSignature
|
||||
)
|
||||
from jax.experimental.shard_map import shard_map_p
|
||||
import numpy as np
|
||||
|
||||
|
||||
@ -77,7 +80,9 @@ key_reuse_signatures[prng.random_wrap_p] = KeyReuseSignature([], [Source(0)], []
|
||||
key_reuse_signatures[prng.random_unwrap_p] = KeyReuseSignature([], [], [])
|
||||
key_reuse_signatures[debug_callback_p] = KeyReuseSignature([], [])
|
||||
key_reuse_signatures[lax.dynamic_slice_p] = KeyReuseSignature([], [], [Forward(0, 0)])
|
||||
key_reuse_signatures[lax.dynamic_update_slice_p] = KeyReuseSignature([], [], [])
|
||||
key_reuse_signatures[lax.dynamic_update_slice_p] = KeyReuseSignature([Sink(1)], [], [Forward(0, 0)])
|
||||
key_reuse_signatures[lax.gather_p] = KeyReuseSignature([], [], [Forward(0, 0)])
|
||||
key_reuse_signatures[lax.scatter_p] = KeyReuseSignature([Sink(2)], [], [Forward(0, 0)])
|
||||
|
||||
# Rules which require more dynamic logic.
|
||||
key_reuse_signatures_dynamic: dict[core.Primitive, Callable[..., KeyReuseSignature]] = {}
|
||||
@ -91,6 +96,7 @@ def unknown_signature(eqn):
|
||||
sources=[],
|
||||
)
|
||||
|
||||
@weakref_lru_cache
|
||||
def get_jaxpr_type_signature(jaxpr: core.Jaxpr) -> KeyReuseSignature:
|
||||
"""Parse the jaxpr to determine key reuse signature"""
|
||||
consumed: dict[core.Atom, bool | np.ndarray] = {}
|
||||
@ -219,6 +225,11 @@ def _pjit_key_type_signature(eqn):
|
||||
|
||||
key_reuse_signatures_dynamic[pjit.pjit_p] = _pjit_key_type_signature
|
||||
|
||||
def _shard_map_type_signature(eqn):
|
||||
return get_jaxpr_type_signature(eqn.params['jaxpr'])
|
||||
|
||||
key_reuse_signatures_dynamic[shard_map_p] = _shard_map_type_signature
|
||||
|
||||
def _cond_key_type_signature(eqn):
|
||||
signatures = [get_jaxpr_type_signature(branch.jaxpr) for branch in eqn.params['branches']]
|
||||
sinks = defaultdict(list)
|
||||
@ -318,3 +329,33 @@ def _remat_key_type_signature(eqn):
|
||||
return get_jaxpr_type_signature(eqn.params['jaxpr'])
|
||||
|
||||
key_reuse_signatures_dynamic[remat_p] = _remat_key_type_signature
|
||||
|
||||
|
||||
# TODO(jakevdp): when we integrate key reuse checks more tightly with JAX,
|
||||
# we should move this logic directly into each primitive impl.
|
||||
def key_reuse_impl_rule(prim, original_rule):
|
||||
@wraps(original_rule)
|
||||
def key_reuse_impl(*args, **kwargs):
|
||||
if config.enable_key_reuse_checks.value:
|
||||
if prim == pjit.pjit_p:
|
||||
jaxpr = kwargs['jaxpr'].jaxpr
|
||||
signature = get_jaxpr_type_signature(jaxpr)
|
||||
elif prim in key_reuse_signatures:
|
||||
jaxpr = prim
|
||||
signature = key_reuse_signatures[prim]
|
||||
elif prim in key_reuse_signatures_dynamic:
|
||||
jaxpr = jax.make_jaxpr(partial(prim.bind, **kwargs))(*args).jaxpr
|
||||
signature = get_jaxpr_type_signature(jaxpr)
|
||||
else:
|
||||
raise RuntimeError(f"Internal: no key reuse rule for primitive {prim}")
|
||||
signature.check_signature(*args, jaxpr=jaxpr)
|
||||
result = original_rule(*args, **kwargs)
|
||||
signature.update_consumption(args, result if prim.multiple_results else [result])
|
||||
return result
|
||||
else:
|
||||
return original_rule(*args, **kwargs)
|
||||
return key_reuse_impl
|
||||
|
||||
|
||||
for prim in (*key_reuse_signatures, *key_reuse_signatures_dynamic):
|
||||
prim.impl = key_reuse_impl_rule(prim, prim.impl) # type: ignore[method-assign]
|
||||
|
@ -574,6 +574,33 @@ class KeyReuseIntegrationTest(jtu.JaxTestCase):
|
||||
self.check_key_reuse(jax.grad(f_good), x, key)
|
||||
|
||||
|
||||
class KeyReuseEager(jtu.JaxTestCase):
|
||||
jit_msg = "Previously-consumed key at index 0 passed to function"
|
||||
bits_msg = "In random_bits, key values a are already consumed."
|
||||
|
||||
def test_simple_reuse_nojit(self):
|
||||
key = jax.random.key(0)
|
||||
_ = jax.random.bits(key)
|
||||
with jax.disable_jit():
|
||||
with self.assertRaisesRegex(KeyReuseError, self.jit_msg):
|
||||
_ = jax.random.bits(key)
|
||||
|
||||
def test_simple_key_reuse_jit(self):
|
||||
key = jax.random.key(0)
|
||||
_ = jax.random.bits(key)
|
||||
with self.assertRaisesRegex(KeyReuseError, self.jit_msg):
|
||||
_ = jax.random.bits(key)
|
||||
|
||||
def test_key_reuse_within_jit(self):
|
||||
@jax.jit
|
||||
def f():
|
||||
key = jax.random.key(0)
|
||||
return jax.random.bits(key) + jax.random.bits(key)
|
||||
with self.assertRaisesRegex(KeyReuseError, self.bits_msg):
|
||||
f()
|
||||
|
||||
|
||||
|
||||
@jtu.with_config(jax_enable_checks=False)
|
||||
class KeyReuseGlobalFlags(jtu.JaxTestCase):
|
||||
def test_key_reuse_flag(self):
|
||||
|
@ -1737,10 +1737,10 @@ class PythonPmapTest(jtu.JaxTestCase):
|
||||
res = res.reshape(res.shape[0] * res.shape[1], *res.shape[2:])
|
||||
return res
|
||||
|
||||
key = random.PRNGKey(1)
|
||||
x = random.normal(key, (80, 50))
|
||||
key = lambda: random.PRNGKey(1)
|
||||
x = random.normal(key(), (80, 50))
|
||||
batched_mvm = vmap(lambda b: distributed_matrix_vector(x, b), in_axes=0)
|
||||
y = random.normal(key, (10, 50, 1))
|
||||
y = random.normal(key(), (10, 50, 1))
|
||||
result = batched_mvm(y)
|
||||
expected = jnp.einsum('ij,njk->nik', x, y)
|
||||
self.assertAllClose(result, expected, check_dtypes=False, atol=1e-3,
|
||||
|
@ -126,12 +126,12 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
|
||||
@jtu.sample_product(dtype=float_dtypes)
|
||||
def testRngUniform(self, dtype):
|
||||
key = self.make_key(0)
|
||||
key = lambda: self.make_key(0)
|
||||
rand = lambda key: random.uniform(key, (10000,), dtype)
|
||||
crand = jax.jit(rand)
|
||||
|
||||
uncompiled_samples = rand(key)
|
||||
compiled_samples = crand(key)
|
||||
uncompiled_samples = rand(key())
|
||||
compiled_samples = crand(key())
|
||||
|
||||
for samples in [uncompiled_samples, compiled_samples]:
|
||||
self._CheckCollisions(samples, jnp.finfo(dtype).nmant)
|
||||
@ -142,12 +142,12 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
lo = 5
|
||||
hi = 10
|
||||
|
||||
key = self.make_key(0)
|
||||
key = lambda: self.make_key(0)
|
||||
rand = lambda key: random.randint(key, (10000,), lo, hi, dtype)
|
||||
crand = jax.jit(rand)
|
||||
|
||||
uncompiled_samples = rand(key)
|
||||
compiled_samples = crand(key)
|
||||
uncompiled_samples = rand(key())
|
||||
compiled_samples = crand(key())
|
||||
|
||||
for samples in [uncompiled_samples, compiled_samples]:
|
||||
self.assertTrue(np.all(lo <= samples))
|
||||
@ -155,12 +155,12 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
|
||||
@jtu.sample_product(dtype=float_dtypes)
|
||||
def testNormal(self, dtype):
|
||||
key = self.make_key(0)
|
||||
key = lambda: self.make_key(0)
|
||||
rand = lambda key: random.normal(key, (10000,), dtype)
|
||||
crand = jax.jit(rand)
|
||||
|
||||
uncompiled_samples = rand(key)
|
||||
compiled_samples = crand(key)
|
||||
uncompiled_samples = rand(key())
|
||||
compiled_samples = crand(key())
|
||||
|
||||
for samples in [uncompiled_samples, compiled_samples]:
|
||||
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.norm().cdf)
|
||||
@ -174,12 +174,12 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
|
||||
@jtu.sample_product(dtype=complex_dtypes)
|
||||
def testNormalComplex(self, dtype):
|
||||
key = self.make_key(0)
|
||||
key = lambda: self.make_key(0)
|
||||
rand = lambda key: random.normal(key, (10000,), dtype)
|
||||
crand = jax.jit(rand)
|
||||
|
||||
uncompiled_samples = rand(key)
|
||||
compiled_samples = crand(key)
|
||||
uncompiled_samples = rand(key())
|
||||
compiled_samples = crand(key())
|
||||
|
||||
for samples in [uncompiled_samples, compiled_samples]:
|
||||
self._CheckKolmogorovSmirnovCDF(jnp.real(samples), scipy.stats.norm(scale=1/np.sqrt(2)).cdf)
|
||||
@ -188,12 +188,12 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
|
||||
@jtu.sample_product(dtype=float_dtypes)
|
||||
def testTruncatedNormal(self, dtype):
|
||||
key = self.make_key(0)
|
||||
key = lambda: self.make_key(0)
|
||||
rand = lambda key: random.truncated_normal(key, -0.3, 0.3, (10000,), dtype)
|
||||
crand = jax.jit(rand)
|
||||
|
||||
uncompiled_samples = rand(key)
|
||||
compiled_samples = crand(key)
|
||||
uncompiled_samples = rand(key())
|
||||
compiled_samples = crand(key())
|
||||
|
||||
min_val = np.min(uncompiled_samples)
|
||||
max_val = np.max(uncompiled_samples)
|
||||
@ -204,15 +204,15 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
|
||||
@jtu.sample_product(dtype=jtu.dtypes.floating + jtu.dtypes.integer)
|
||||
def testShuffle(self, dtype):
|
||||
key = self.make_key(0)
|
||||
key = lambda: self.make_key(0)
|
||||
x = np.arange(100).astype(dtype)
|
||||
rand = lambda key: random.shuffle(key, x)
|
||||
crand = jax.jit(rand)
|
||||
|
||||
with self.assertWarns((DeprecationWarning, FutureWarning)):
|
||||
perm1 = rand(key)
|
||||
perm1 = rand(key())
|
||||
with self.assertWarns((DeprecationWarning, FutureWarning)):
|
||||
perm2 = crand(key)
|
||||
perm2 = crand(key())
|
||||
|
||||
self.assertAllClose(perm1, perm2)
|
||||
self.assertFalse(np.all(perm1 == x)) # seems unlikely!
|
||||
@ -238,7 +238,7 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
np_choice = np.random.default_rng(0).choice
|
||||
p_dtype = dtypes.to_inexact_dtype(dtype)
|
||||
|
||||
key = self.make_key(0)
|
||||
key = lambda: self.make_key(0)
|
||||
is_range = type(input_range_or_shape) is int
|
||||
x = (input_range_or_shape if is_range else
|
||||
self.rng().permutation(np.arange(math.prod(
|
||||
@ -250,7 +250,7 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
else:
|
||||
p = None
|
||||
rand = lambda key, x: random.choice(key, x, shape, replace, p, axis)
|
||||
sample = rand(key, x)
|
||||
sample = rand(key(), x)
|
||||
if not is_range:
|
||||
self.assertEqual(dtype, sample.dtype)
|
||||
expected_shape = np.shape(np_choice(x, shape or None, replace, p, axis))
|
||||
@ -263,9 +263,9 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
ind = np.lexsort(np.swapaxes(x, axis, -1).reshape((-1, x.shape[axis])))
|
||||
return jnp.take(x, ind, axis)
|
||||
self.assertArraysEqual(lsort(sample), lsort(np.unique(sample, axis=axis)))
|
||||
self.assertArraysEqual(sample, rand(key, np.array(x)))
|
||||
self.assertArraysEqual(sample, rand(key(), np.array(x)))
|
||||
self.assertArraysEqual(sample, jax.jit(rand, static_argnames=
|
||||
'x' if is_range else None)(key, x))
|
||||
'x' if is_range else None)(key(), x))
|
||||
|
||||
@jtu.sample_product(
|
||||
[dict(range_or_shape=range_or_shape, axis=axis)
|
||||
@ -278,7 +278,7 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
independent=[True, False],
|
||||
)
|
||||
def testPermutation(self, dtype, range_or_shape, axis, independent):
|
||||
key = self.make_key(0)
|
||||
key = lambda: self.make_key(0)
|
||||
is_range = type(range_or_shape) is int
|
||||
x = (range_or_shape if is_range else
|
||||
self.rng().permutation(np.arange(
|
||||
@ -286,7 +286,7 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
shape = ((range_or_shape,) if is_range else range_or_shape)
|
||||
x_ = np.copy(x)
|
||||
rand = lambda key, x: random.permutation(key, x, axis, independent=independent)
|
||||
perm = rand(key, x)
|
||||
perm = rand(key(), x)
|
||||
if shape[axis] >= 10:
|
||||
self.assertFalse(np.all(perm == x)) # seems unlikely!
|
||||
arr = np.arange(x) if is_range else x
|
||||
@ -302,9 +302,9 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
with self.assertRaises(AssertionError):
|
||||
self.assertArraysEqual(lsort(arr), lsort(perm), check_dtypes=not is_range)
|
||||
self.assertArraysEqual(x_, x)
|
||||
self.assertArraysEqual(perm, rand(key, np.array(x)))
|
||||
self.assertArraysEqual(perm, rand(key(), np.array(x)))
|
||||
self.assertArraysEqual(perm, jax.jit(rand, static_argnames=
|
||||
'x' if is_range else None)(key, x))
|
||||
'x' if is_range else None)(key(), x))
|
||||
|
||||
def testPermutationErrors(self):
|
||||
key = self.make_key(0)
|
||||
@ -320,13 +320,13 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
dtype=jtu.dtypes.floating,
|
||||
)
|
||||
def testBernoulli(self, p, dtype):
|
||||
key = self.make_key(0)
|
||||
key = lambda: self.make_key(0)
|
||||
p = np.array(p, dtype=dtype)
|
||||
rand = lambda key, p: random.bernoulli(key, p, (10000,))
|
||||
crand = jax.jit(rand)
|
||||
|
||||
uncompiled_samples = rand(key, p)
|
||||
compiled_samples = crand(key, p)
|
||||
uncompiled_samples = rand(key(), p)
|
||||
compiled_samples = crand(key(), p)
|
||||
|
||||
for samples in [uncompiled_samples, compiled_samples]:
|
||||
self._CheckChiSquared(samples, scipy.stats.bernoulli(p).pmf)
|
||||
@ -344,7 +344,7 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
dtype=jtu.dtypes.floating,
|
||||
)
|
||||
def testCategorical(self, p, axis, dtype, sample_shape):
|
||||
key = self.make_key(0)
|
||||
key = lambda: self.make_key(0)
|
||||
p = np.array(p, dtype=dtype)
|
||||
logits = np.log(p) - 42 # test unnormalized
|
||||
out_shape = tuple(np.delete(logits.shape, axis))
|
||||
@ -352,8 +352,8 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
rand = partial(random.categorical, shape=shape, axis=axis)
|
||||
crand = jax.jit(rand)
|
||||
|
||||
uncompiled_samples = rand(key, logits)
|
||||
compiled_samples = crand(key, logits)
|
||||
uncompiled_samples = rand(key(), logits)
|
||||
compiled_samples = crand(key(), logits)
|
||||
|
||||
if axis < 0:
|
||||
axis += len(logits.shape)
|
||||
@ -384,12 +384,12 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
def testBeta(self, a, b, dtype):
|
||||
if not config.enable_x64.value:
|
||||
raise SkipTest("skip test except on X64")
|
||||
key = self.make_key(0)
|
||||
key = lambda: self.make_key(0)
|
||||
rand = lambda key, a, b: random.beta(key, a, b, (10000,), dtype)
|
||||
crand = jax.jit(rand)
|
||||
|
||||
uncompiled_samples = rand(key, a, b)
|
||||
compiled_samples = crand(key, a, b)
|
||||
uncompiled_samples = rand(key(), a, b)
|
||||
compiled_samples = crand(key(), a, b)
|
||||
|
||||
for samples in [uncompiled_samples, compiled_samples]:
|
||||
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.beta(a, b).cdf)
|
||||
@ -412,12 +412,12 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
|
||||
@jtu.sample_product(dtype=float_dtypes)
|
||||
def testCauchy(self, dtype):
|
||||
key = self.make_key(0)
|
||||
key = lambda: self.make_key(0)
|
||||
rand = lambda key: random.cauchy(key, (10000,), dtype)
|
||||
crand = jax.jit(rand)
|
||||
|
||||
uncompiled_samples = rand(key)
|
||||
compiled_samples = crand(key)
|
||||
uncompiled_samples = rand(key())
|
||||
compiled_samples = crand(key())
|
||||
|
||||
for samples in [uncompiled_samples, compiled_samples]:
|
||||
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.cauchy().cdf)
|
||||
@ -428,13 +428,13 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
)
|
||||
@jtu.skip_on_devices("tpu") # TODO(mattjj): slow compilation times
|
||||
def testDirichlet(self, alpha, dtype):
|
||||
key = self.make_key(0)
|
||||
key = lambda: self.make_key(0)
|
||||
num_samples = 10000
|
||||
rand = lambda key, alpha: random.dirichlet(key, alpha, (num_samples,), dtype)
|
||||
crand = jax.jit(rand)
|
||||
|
||||
uncompiled_samples = rand(key, alpha)
|
||||
compiled_samples = crand(key, alpha)
|
||||
uncompiled_samples = rand(key(), alpha)
|
||||
compiled_samples = crand(key(), alpha)
|
||||
|
||||
for samples in [uncompiled_samples, compiled_samples]:
|
||||
self.assertAllClose(samples.sum(-1), np.ones(num_samples, dtype=dtype))
|
||||
@ -462,12 +462,12 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
|
||||
@jtu.sample_product(dtype=float_dtypes)
|
||||
def testExponential(self, dtype):
|
||||
key = self.make_key(0)
|
||||
key = lambda: self.make_key(0)
|
||||
rand = lambda key: random.exponential(key, (10000,), dtype)
|
||||
crand = jax.jit(rand)
|
||||
|
||||
uncompiled_samples = rand(key)
|
||||
compiled_samples = crand(key)
|
||||
uncompiled_samples = rand(key())
|
||||
compiled_samples = crand(key())
|
||||
|
||||
for samples in [uncompiled_samples, compiled_samples]:
|
||||
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.expon().cdf)
|
||||
@ -479,15 +479,15 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
@jtu.skip_on_devices("tpu") # low accuracy leads to failures.
|
||||
def testGammaVsLogGamma(self, a, dtype):
|
||||
# Test that gamma() and loggamma() produce equivalent samples.
|
||||
key = self.make_key(0)
|
||||
rand_gamma = lambda key, a: random.gamma(key, a, (100,), dtype)
|
||||
rand_loggamma = lambda key, a: random.loggamma(key, a, (100,), dtype)
|
||||
crand_loggamma = jax.jit(rand_loggamma)
|
||||
tol = {np.float32: 1E-6, np.float64: 1E-12}
|
||||
|
||||
self.assertAllClose(rand_gamma(key, a), jnp.exp(rand_loggamma(key, a)),
|
||||
key = lambda: self.make_key(0)
|
||||
self.assertAllClose(rand_gamma(key(), a), jnp.exp(rand_loggamma(key(), a)),
|
||||
atol=tol, rtol=tol)
|
||||
self.assertAllClose(rand_gamma(key, a), jnp.exp(crand_loggamma(key, a)),
|
||||
self.assertAllClose(rand_gamma(key(), a), jnp.exp(crand_loggamma(key(), a)),
|
||||
atol=tol, rtol=tol)
|
||||
|
||||
@jtu.sample_product(
|
||||
@ -495,12 +495,12 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
dtype=jtu.dtypes.floating,
|
||||
)
|
||||
def testGamma(self, a, dtype):
|
||||
key = self.make_key(1)
|
||||
key = lambda: self.make_key(1)
|
||||
rand = lambda key, a: random.gamma(key, a, (10000,), dtype)
|
||||
crand = jax.jit(rand)
|
||||
|
||||
uncompiled_samples = rand(key, a)
|
||||
compiled_samples = crand(key, a)
|
||||
uncompiled_samples = rand(key(), a)
|
||||
compiled_samples = crand(key(), a)
|
||||
|
||||
for samples in [uncompiled_samples, compiled_samples]:
|
||||
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.gamma(a).cdf)
|
||||
@ -515,13 +515,13 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
alpha=[1e-4, 1e-3, 1e-2, 1e-1, 1e0, 1e1, 1e2, 1e3, 1e4],
|
||||
)
|
||||
def testGammaGrad(self, log_space, alpha):
|
||||
rng = self.make_key(0)
|
||||
rng = lambda: self.make_key(0)
|
||||
alphas = np.full((100,), alpha)
|
||||
z = random.gamma(rng, alphas)
|
||||
z = random.gamma(rng(), alphas)
|
||||
if log_space:
|
||||
actual_grad = jax.grad(lambda x: lax.exp(random.loggamma(rng, x)).sum())(alphas)
|
||||
actual_grad = jax.grad(lambda x: lax.exp(random.loggamma(rng(), x)).sum())(alphas)
|
||||
else:
|
||||
actual_grad = jax.grad(lambda x: random.gamma(rng, x).sum())(alphas)
|
||||
actual_grad = jax.grad(lambda x: random.gamma(rng(), x).sum())(alphas)
|
||||
|
||||
eps = 0.01 * alpha / (1.0 + np.sqrt(alpha))
|
||||
cdf_dot = (scipy.stats.gamma.cdf(z, alpha + eps)
|
||||
@ -548,12 +548,12 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
dtype=jtu.dtypes.supported([np.int16, np.int32, np.int64]),
|
||||
)
|
||||
def testPoisson(self, lam, dtype):
|
||||
key = self.make_key(0)
|
||||
key = lambda: self.make_key(0)
|
||||
rand = lambda key, lam: random.poisson(key, lam, (10000,), dtype)
|
||||
crand = jax.jit(rand)
|
||||
|
||||
uncompiled_samples = rand(key, lam)
|
||||
compiled_samples = crand(key, lam)
|
||||
uncompiled_samples = rand(key(), lam)
|
||||
compiled_samples = crand(key(), lam)
|
||||
|
||||
for samples in [uncompiled_samples, compiled_samples]:
|
||||
self._CheckChiSquared(samples, scipy.stats.poisson(lam).pmf)
|
||||
@ -594,36 +594,36 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
|
||||
@jtu.sample_product(dtype=jtu.dtypes.floating)
|
||||
def testGumbel(self, dtype):
|
||||
key = self.make_key(0)
|
||||
key = lambda: self.make_key(0)
|
||||
rand = lambda key: random.gumbel(key, (10000,), dtype)
|
||||
crand = jax.jit(rand)
|
||||
|
||||
uncompiled_samples = rand(key)
|
||||
compiled_samples = crand(key)
|
||||
uncompiled_samples = rand(key())
|
||||
compiled_samples = crand(key())
|
||||
|
||||
for samples in [uncompiled_samples, compiled_samples]:
|
||||
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.gumbel_r().cdf)
|
||||
|
||||
@jtu.sample_product(dtype=float_dtypes)
|
||||
def testLaplace(self, dtype):
|
||||
key = self.make_key(0)
|
||||
key = lambda: self.make_key(0)
|
||||
rand = lambda key: random.laplace(key, (10000,), dtype)
|
||||
crand = jax.jit(rand)
|
||||
|
||||
uncompiled_samples = rand(key)
|
||||
compiled_samples = crand(key)
|
||||
uncompiled_samples = rand(key())
|
||||
compiled_samples = crand(key())
|
||||
|
||||
for samples in [uncompiled_samples, compiled_samples]:
|
||||
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.laplace().cdf)
|
||||
|
||||
@jtu.sample_product(dtype=float_dtypes)
|
||||
def testLogistic(self, dtype):
|
||||
key = self.make_key(0)
|
||||
key = lambda: self.make_key(0)
|
||||
rand = lambda key: random.logistic(key, (10000,), dtype)
|
||||
crand = jax.jit(rand)
|
||||
|
||||
uncompiled_samples = rand(key)
|
||||
compiled_samples = crand(key)
|
||||
uncompiled_samples = rand(key())
|
||||
compiled_samples = crand(key())
|
||||
|
||||
for samples in [uncompiled_samples, compiled_samples]:
|
||||
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.logistic().cdf)
|
||||
@ -651,12 +651,12 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
dtype=jtu.dtypes.floating,
|
||||
)
|
||||
def testGeneralizedNormal(self, p, shape, dtype):
|
||||
key = self.make_key(2)
|
||||
key = lambda: self.make_key(2)
|
||||
rand = lambda key, p: random.generalized_normal(key, p, shape, dtype)
|
||||
crand = jax.jit(rand)
|
||||
|
||||
uncompiled_samples = rand(key, p)
|
||||
compiled_samples = crand(key, p)
|
||||
uncompiled_samples = rand(key(), p)
|
||||
compiled_samples = crand(key(), p)
|
||||
for samples in [uncompiled_samples, compiled_samples]:
|
||||
self.assertEqual(samples.shape, shape)
|
||||
self.assertEqual(samples.dtype, dtype)
|
||||
@ -669,12 +669,12 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
def testGeneralizedNormalKS(self, p, shape, dtype):
|
||||
self.skipTest( # test is also sometimes slow, with (300, ...)-shape draws
|
||||
"sensitive to random key - https://github.com/google/jax/issues/18941")
|
||||
key = self.make_key(2)
|
||||
key = lambda: self.make_key(2)
|
||||
rand = lambda key, p: random.generalized_normal(key, p, (300, *shape), dtype)
|
||||
crand = jax.jit(rand)
|
||||
|
||||
uncompiled_samples = rand(key, p)
|
||||
compiled_samples = crand(key, p)
|
||||
uncompiled_samples = rand(key(), p)
|
||||
compiled_samples = crand(key(), p)
|
||||
for samples in [uncompiled_samples, compiled_samples]:
|
||||
self._CheckKolmogorovSmirnovCDF(samples.ravel(), scipy.stats.gennorm(p).cdf)
|
||||
|
||||
@ -686,11 +686,11 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
)
|
||||
@jtu.skip_on_devices("tpu") # TPU precision causes issues.
|
||||
def testBall(self, d, p, shape, dtype):
|
||||
key = self.make_key(123)
|
||||
key = lambda: self.make_key(123)
|
||||
rand = lambda key, p: random.ball(key, d, p, shape, dtype)
|
||||
crand = jax.jit(rand)
|
||||
uncompiled_samples = rand(key, p)
|
||||
compiled_samples = crand(key, p)
|
||||
uncompiled_samples = rand(key(), p)
|
||||
compiled_samples = crand(key(), p)
|
||||
for samples in [uncompiled_samples, compiled_samples]:
|
||||
self.assertEqual(samples.shape, (*shape, d))
|
||||
self.assertEqual(samples.dtype, dtype)
|
||||
@ -706,11 +706,11 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
def testBallKS(self, d, p, shape, dtype):
|
||||
self.skipTest(
|
||||
"sensitive to random key - https://github.com/google/jax/issues/18932")
|
||||
key = self.make_key(123)
|
||||
key = lambda: self.make_key(123)
|
||||
rand = lambda key, p: random.ball(key, d, p, (100, *shape), dtype)
|
||||
crand = jax.jit(rand)
|
||||
uncompiled_samples = rand(key, p)
|
||||
compiled_samples = crand(key, p)
|
||||
uncompiled_samples = rand(key(), p)
|
||||
compiled_samples = crand(key(), p)
|
||||
for samples in [uncompiled_samples, compiled_samples]:
|
||||
norms = (jnp.abs(samples) ** p).sum(-1) ** (d / p)
|
||||
self._CheckKolmogorovSmirnovCDF(norms.ravel(), scipy.stats.uniform().cdf)
|
||||
@ -720,12 +720,12 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
dtype=jtu.dtypes.floating,
|
||||
)
|
||||
def testPareto(self, b, dtype):
|
||||
key = self.make_key(0)
|
||||
key = lambda: self.make_key(0)
|
||||
rand = lambda key, b: random.pareto(key, b, (10000,), dtype)
|
||||
crand = jax.jit(rand)
|
||||
|
||||
uncompiled_samples = rand(key, b)
|
||||
compiled_samples = crand(key, b)
|
||||
uncompiled_samples = rand(key(), b)
|
||||
compiled_samples = crand(key(), b)
|
||||
|
||||
for samples in [uncompiled_samples, compiled_samples]:
|
||||
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.pareto(b).cdf)
|
||||
@ -742,12 +742,12 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
)
|
||||
@jtu.skip_on_devices("cpu", "tpu") # TODO(phawkins): slow compilation times
|
||||
def testT(self, df, dtype):
|
||||
key = self.make_key(1)
|
||||
key = lambda: self.make_key(1)
|
||||
rand = lambda key, df: random.t(key, df, (10000,), dtype)
|
||||
crand = jax.jit(rand)
|
||||
|
||||
uncompiled_samples = rand(key, df)
|
||||
compiled_samples = crand(key, df)
|
||||
uncompiled_samples = rand(key(), df)
|
||||
compiled_samples = crand(key(), df)
|
||||
|
||||
for samples in [uncompiled_samples, compiled_samples]:
|
||||
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.t(df).cdf)
|
||||
@ -763,14 +763,14 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
cov_factor = r.randn(dim, dim)
|
||||
cov = np.dot(cov_factor, cov_factor.T) + dim * np.eye(dim)
|
||||
|
||||
key = self.make_key(0)
|
||||
key = lambda: self.make_key(0)
|
||||
rand = partial(random.multivariate_normal, mean=mean, cov=cov,
|
||||
shape=(10000,), method=method)
|
||||
crand = jax.jit(rand)
|
||||
|
||||
with jax.numpy_rank_promotion('allow'):
|
||||
uncompiled_samples = np.asarray(rand(key), np.float64)
|
||||
compiled_samples = np.asarray(crand(key), np.float64)
|
||||
uncompiled_samples = np.asarray(rand(key()), np.float64)
|
||||
compiled_samples = np.asarray(crand(key()), np.float64)
|
||||
|
||||
inv_scale = scipy.linalg.lapack.dtrtri(np.linalg.cholesky(cov), lower=True)[0]
|
||||
for samples in [uncompiled_samples, compiled_samples]:
|
||||
@ -895,17 +895,17 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
def testRandomBroadcast(self):
|
||||
"""Issue 4033"""
|
||||
# test for broadcast issue in https://github.com/google/jax/issues/4033
|
||||
key = self.make_key(0)
|
||||
key = lambda: self.make_key(0)
|
||||
shape = (10, 2)
|
||||
with jax.numpy_rank_promotion('allow'):
|
||||
x1 = random.uniform(key, shape, minval=jnp.zeros(2), maxval=jnp.ones(2))
|
||||
x2 = random.randint(key, shape, jnp.array([0, 1]), jnp.array([1, 2]))
|
||||
x1 = random.uniform(key(), shape, minval=jnp.zeros(2), maxval=jnp.ones(2))
|
||||
x2 = random.randint(key(), shape, jnp.array([0, 1]), jnp.array([1, 2]))
|
||||
assert x1.shape == shape
|
||||
assert x2.shape == shape
|
||||
|
||||
def testMaxwellSample(self):
|
||||
num_samples = 10**5
|
||||
rng = self.make_key(0)
|
||||
rng = lambda: self.make_key(0)
|
||||
|
||||
rand = lambda x: random.maxwell(x, (num_samples, ))
|
||||
crand = jax.jit(rand)
|
||||
@ -913,8 +913,8 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
loc = jtu.to_default_dtype(scipy.stats.maxwell.mean())
|
||||
std = jtu.to_default_dtype(scipy.stats.maxwell.std())
|
||||
|
||||
uncompiled_samples = rand(rng)
|
||||
compiled_samples = crand(rng)
|
||||
uncompiled_samples = rand(rng())
|
||||
compiled_samples = crand(rng())
|
||||
|
||||
for samples in [uncompiled_samples, compiled_samples]:
|
||||
# Check first and second moments.
|
||||
@ -928,7 +928,7 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
('test2', 2.0, 3.0))
|
||||
def testWeibullSample(self, concentration, scale):
|
||||
num_samples = 10**5
|
||||
rng = self.make_key(0)
|
||||
rng = lambda: self.make_key(0)
|
||||
|
||||
rand = lambda x: random.weibull_min(x, scale, concentration, (num_samples,))
|
||||
crand = jax.jit(rand)
|
||||
@ -936,8 +936,8 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
loc = jtu.to_default_dtype(scipy.stats.weibull_min.mean(c=concentration, scale=scale))
|
||||
std = jtu.to_default_dtype(scipy.stats.weibull_min.std(c=concentration, scale=scale))
|
||||
|
||||
uncompiled_samples = rand(rng)
|
||||
compiled_samples = crand(rng)
|
||||
uncompiled_samples = rand(rng())
|
||||
compiled_samples = crand(rng())
|
||||
|
||||
for samples in [uncompiled_samples, compiled_samples]:
|
||||
# Check first and second moments.
|
||||
@ -952,17 +952,17 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
('test2', 2.0, 3.0))
|
||||
def testDoublesidedMaxwellSample(self, loc, scale):
|
||||
num_samples = 10**4
|
||||
rng = self.make_key(0)
|
||||
rng = lambda: self.make_key(0)
|
||||
|
||||
rand = lambda key: random.double_sided_maxwell(
|
||||
rng, loc, scale, (num_samples,))
|
||||
rng(), loc, scale, (num_samples,))
|
||||
crand = jax.jit(rand)
|
||||
|
||||
mean = loc
|
||||
std = np.sqrt(3.) * scale
|
||||
|
||||
uncompiled_samples = rand(rng)
|
||||
compiled_samples = crand(rng)
|
||||
uncompiled_samples = rand(rng())
|
||||
compiled_samples = crand(rng())
|
||||
|
||||
# Compute the double sided maxwell CDF through the one sided maxwell cdf.
|
||||
# This is done as follows:
|
||||
@ -989,14 +989,14 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
samples, lambda x: double_sided_maxwell_cdf(x, loc, scale))
|
||||
|
||||
def testRadamacher(self):
|
||||
rng = self.make_key(0)
|
||||
rng = lambda: self.make_key(0)
|
||||
num_samples = 10**5
|
||||
|
||||
rand = lambda x: random.rademacher(x, (num_samples,))
|
||||
crand = jax.jit(rand)
|
||||
|
||||
uncompiled_samples = rand(rng)
|
||||
compiled_samples = crand(rng)
|
||||
uncompiled_samples = rand(rng())
|
||||
compiled_samples = crand(rng())
|
||||
|
||||
for samples in [uncompiled_samples, compiled_samples]:
|
||||
unique_values, counts = np.unique(samples, return_counts=True)
|
||||
@ -1052,24 +1052,25 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
def test_randint_bounds(self, dtype):
|
||||
min = np.iinfo(dtype).min
|
||||
max = np.iinfo(dtype).max
|
||||
key = self.make_key(1701)
|
||||
key = lambda: self.make_key(1701)
|
||||
shape = (10,)
|
||||
if np.iinfo(dtype).bits < np.iinfo(dtypes.canonicalize_dtype(int)).bits:
|
||||
expected = random.randint(key, shape, min, max + 1, dtype)
|
||||
self.assertArraysEqual(expected, random.randint(key, shape, min - 12345, max + 12345, dtype))
|
||||
expected = random.randint(key(), shape, min, max + 1, dtype)
|
||||
self.assertArraysEqual(expected, random.randint(key(), shape, min - 12345, max + 12345, dtype))
|
||||
else:
|
||||
self.assertRaises(OverflowError, random.randint, key, shape, min - 12345, max + 12345, dtype)
|
||||
self.assertRaises(OverflowError, random.randint, key(), shape, min - 12345, max + 12345, dtype)
|
||||
|
||||
def test_randint_out_of_range(self):
|
||||
key = self.make_key(0)
|
||||
|
||||
r = random.randint(key, (10,), 255, 256, np.uint8)
|
||||
self.assertAllClose(r, jnp.full_like(r, 255))
|
||||
|
||||
key = self.make_key(0)
|
||||
r = random.randint(key, (1000,), -128, 128, np.int8)
|
||||
self.assertGreater((r == -128).sum(), 0)
|
||||
self.assertGreater((r == 127).sum(), 0)
|
||||
|
||||
key = self.make_key(0)
|
||||
r = random.randint(key, (1000,), -1000, 1000, np.uint8)
|
||||
self.assertGreater((r == 0).sum(), 0)
|
||||
self.assertGreater((r == 255).sum(), 0)
|
||||
@ -1103,14 +1104,14 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
df = [0.2, 1., 10., 100.],
|
||||
dtype=jtu.dtypes.floating)
|
||||
def testChisquare(self, df, dtype):
|
||||
key = self.make_key(1)
|
||||
key = lambda: self.make_key(1)
|
||||
|
||||
def rand(key, df):
|
||||
return random.chisquare(key, df, shape=(10000,), dtype=dtype)
|
||||
crand = jax.jit(rand)
|
||||
|
||||
uncompiled_samples = rand(key, df)
|
||||
compiled_samples = crand(key, df)
|
||||
uncompiled_samples = rand(key(), df)
|
||||
compiled_samples = crand(key(), df)
|
||||
|
||||
for samples in [uncompiled_samples, compiled_samples]:
|
||||
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.chi2(df).cdf)
|
||||
@ -1120,12 +1121,12 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
dfden = [1. ,2., 10., 100.],
|
||||
dtype=jtu.dtypes.floating)
|
||||
def testF(self, dfnum, dfden, dtype):
|
||||
key = self.make_key(9)
|
||||
key = lambda: self.make_key(9)
|
||||
rand = lambda key: random.f(key, dfnum, dfden, shape = (10000, ), dtype = dtype)
|
||||
crand = jax.jit(rand)
|
||||
|
||||
uncompiled_samples = rand(key)
|
||||
compiled_samples = crand(key)
|
||||
uncompiled_samples = rand(key())
|
||||
compiled_samples = crand(key())
|
||||
|
||||
for samples in [uncompiled_samples, compiled_samples]:
|
||||
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.f(dfnum, dfden).cdf)
|
||||
@ -1134,12 +1135,12 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
scale= [0.2, 1., 2., 10. ,100.],
|
||||
dtype=jtu.dtypes.floating)
|
||||
def testRayleigh(self, scale, dtype):
|
||||
key = self.make_key(0)
|
||||
key = lambda: self.make_key(0)
|
||||
rand = lambda key: random.rayleigh(key, scale, shape = (10000, ), dtype = dtype)
|
||||
crand = jax.jit(rand)
|
||||
|
||||
uncompiled_samples = rand(key)
|
||||
compiled_samples = crand(key)
|
||||
uncompiled_samples = rand(key())
|
||||
compiled_samples = crand(key())
|
||||
|
||||
for samples in [uncompiled_samples, compiled_samples]:
|
||||
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.rayleigh(scale=scale).cdf)
|
||||
@ -1148,12 +1149,12 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
mean= [0.2, 1., 2., 10. ,100.],
|
||||
dtype=jtu.dtypes.floating)
|
||||
def testWald(self, mean, dtype):
|
||||
key = self.make_key(0)
|
||||
key = lambda: self.make_key(0)
|
||||
rand = lambda key: random.wald(key, mean, shape=(10000, ), dtype=dtype)
|
||||
crand = jax.jit(rand)
|
||||
|
||||
uncompiled_samples = rand(key)
|
||||
compiled_samples = crand(key)
|
||||
uncompiled_samples = rand(key())
|
||||
compiled_samples = crand(key())
|
||||
|
||||
for samples in [uncompiled_samples, compiled_samples]:
|
||||
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.invgauss(mu=mean).cdf)
|
||||
@ -1162,12 +1163,12 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
p=[0.2, 0.3, 0.4, 0.5 ,0.6],
|
||||
dtype=jtu.dtypes.supported([np.int16, np.int32, np.int64]))
|
||||
def testGeometric(self, p, dtype):
|
||||
key = self.make_key(1)
|
||||
key = lambda: self.make_key(1)
|
||||
rand = lambda key: random.geometric(key, p, shape=(10000, ), dtype=dtype)
|
||||
crand = jax.jit(rand)
|
||||
|
||||
uncompiled_samples = rand(key)
|
||||
compiled_samples = crand(key)
|
||||
uncompiled_samples = rand(key())
|
||||
compiled_samples = crand(key())
|
||||
|
||||
for samples in [uncompiled_samples, compiled_samples]:
|
||||
self._CheckChiSquared(samples, scipy.stats.geom(p).pmf)
|
||||
@ -1181,13 +1182,13 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
right= [10., 20., 30., 40.],
|
||||
dtype= jtu.dtypes.floating)
|
||||
def testTriangular(self, left, mode, right, dtype):
|
||||
key = self.make_key(1)
|
||||
key = lambda: self.make_key(1)
|
||||
rand = lambda key: random.triangular(key, left, mode, right, shape=(10000,),
|
||||
dtype=dtype)
|
||||
crand = jax.jit(rand)
|
||||
|
||||
uncompiled_samples = rand(key)
|
||||
compiled_samples = crand(key)
|
||||
uncompiled_samples = rand(key())
|
||||
compiled_samples = crand(key())
|
||||
|
||||
for samples in [uncompiled_samples, compiled_samples]:
|
||||
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.triang(
|
||||
@ -1197,12 +1198,12 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
sigma = [0.2, 0.5, 1., 2.],
|
||||
dtype=jtu.dtypes.floating)
|
||||
def testLogNormal(self, sigma, dtype):
|
||||
key = self.make_key(0)
|
||||
key = lambda: self.make_key(0)
|
||||
rand = lambda key: random.lognormal(key, sigma, shape=(10000,), dtype=dtype)
|
||||
crand = jax.jit(rand)
|
||||
|
||||
uncompiled_samples = rand(key)
|
||||
compiled_samples = crand(key)
|
||||
uncompiled_samples = rand(key())
|
||||
compiled_samples = crand(key())
|
||||
|
||||
for samples in [uncompiled_samples, compiled_samples]:
|
||||
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.lognorm(s=sigma).cdf)
|
||||
@ -1212,11 +1213,11 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
p= [0.1, 0.3, 0.5, 0.7, 0.9],
|
||||
dtype= jtu.dtypes.floating)
|
||||
def testBinomialSample(self, n, p, dtype):
|
||||
key = self.make_key(12)
|
||||
key = lambda: self.make_key(12)
|
||||
rand = lambda key: random.binomial(key, n, p, shape=(12000,), dtype=dtype)
|
||||
crand = jax.jit(rand)
|
||||
uncompiled_samples = rand(key)
|
||||
compiled_samples = crand(key)
|
||||
uncompiled_samples = rand(key())
|
||||
compiled_samples = crand(key())
|
||||
|
||||
pmf = lambda x: scipy.stats.binom(n, p).pmf(x)
|
||||
|
||||
@ -1227,52 +1228,52 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
check_dtypes=False)
|
||||
|
||||
def testBinomialCornerCases(self):
|
||||
key = self.make_key(0)
|
||||
key = lambda: self.make_key(0)
|
||||
|
||||
# corner case n
|
||||
n = jnp.array([-1, 0, jnp.nan, jnp.inf])
|
||||
samples1 = random.binomial(key, n, 0.5, shape=(4,))
|
||||
samples1 = random.binomial(key(), n, 0.5, shape=(4,))
|
||||
|
||||
# corner case p
|
||||
p = jnp.array([jnp.nan, 0, -0.1, 1.1])
|
||||
samples2 = random.binomial(key, 5, p, shape=(4,))
|
||||
samples2 = random.binomial(key(), 5, p, shape=(4,))
|
||||
|
||||
# corner case n and p
|
||||
# expect nan or illegal will lead to nan
|
||||
n_cc = jnp.array([jnp.inf, -1, jnp.inf])
|
||||
p_cc = jnp.array([jnp.nan, jnp.nan, -0.1])
|
||||
samples3 = random.binomial(key, n_cc, p_cc, shape=(3,))
|
||||
samples3 = random.binomial(key(), n_cc, p_cc, shape=(3,))
|
||||
|
||||
self.assertArraysAllClose(samples1, jnp.array([jnp.nan, 0., jnp.nan, jnp.inf]), check_dtypes=False)
|
||||
self.assertArraysAllClose(samples2, jnp.array([jnp.nan, 0., jnp.nan, jnp.nan]), check_dtypes=False)
|
||||
self.assertArraysAllClose(samples3, jnp.array([jnp.nan, jnp.nan, jnp.nan]), check_dtypes=False)
|
||||
|
||||
def test_batched_key_warnings(self):
|
||||
keys = jax.random.split(self.make_key(0))
|
||||
keys = lambda: jax.random.split(self.make_key(0))
|
||||
msg = "{} accepts a single key, but was given a key array of shape.*"
|
||||
|
||||
# Check a handful of functions that are expected to warn.
|
||||
with self.assertWarnsRegex(FutureWarning, msg.format('bits')):
|
||||
jax.random.bits(keys, shape=(2,))
|
||||
jax.random.bits(keys(), shape=(2,))
|
||||
with self.assertWarnsRegex(FutureWarning, msg.format('chisquare')):
|
||||
jax.random.chisquare(keys, 1.0, shape=(2,))
|
||||
jax.random.chisquare(keys(), 1.0, shape=(2,))
|
||||
with self.assertWarnsRegex(FutureWarning, msg.format('dirichlet')):
|
||||
jax.random.dirichlet(keys, jnp.arange(2.0), shape=(2,))
|
||||
jax.random.dirichlet(keys(), jnp.arange(2.0), shape=(2,))
|
||||
with self.assertWarnsRegex(FutureWarning, msg.format('gamma')):
|
||||
jax.random.gamma(keys, 1.0, shape=(2,))
|
||||
jax.random.gamma(keys(), 1.0, shape=(2,))
|
||||
with self.assertWarnsRegex(FutureWarning, msg.format('loggamma')):
|
||||
jax.random.loggamma(keys, 1.0, shape=(2,))
|
||||
jax.random.loggamma(keys(), 1.0, shape=(2,))
|
||||
|
||||
# Other functions should error; test a few cases.
|
||||
with self.assertRaisesRegex(ValueError, msg.format('fold_in')):
|
||||
jax.random.fold_in(keys, 0)
|
||||
jax.random.fold_in(keys(), 0)
|
||||
with self.assertRaisesRegex(ValueError, msg.format('split')):
|
||||
jax.random.split(keys)
|
||||
jax.random.split(keys())
|
||||
|
||||
# Some shouldn't error or warn
|
||||
with self.assertNoWarnings():
|
||||
jax.random.key_data(keys)
|
||||
jax.random.key_impl(keys)
|
||||
jax.random.key_data(keys())
|
||||
jax.random.key_impl(keys())
|
||||
|
||||
|
||||
threefry_seed = prng_internal.threefry_seed
|
||||
@ -1323,28 +1324,28 @@ class LaxRandomWithCustomPRNGTest(LaxRandomTest):
|
||||
|
||||
def test_vmap_fold_in_shape(self):
|
||||
# broadcast with scalar
|
||||
keys = random.split(self.make_key(73), 2)
|
||||
keys = lambda: random.split(self.make_key(73), 2)
|
||||
msgs = jnp.arange(3)
|
||||
out = vmap(lambda i: random.fold_in(keys[0], i))(msgs)
|
||||
out = vmap(lambda i: random.fold_in(keys()[0], i))(msgs)
|
||||
self.assertEqual(out.shape, (3,))
|
||||
out = vmap(lambda k: random.fold_in(k, msgs[0]))(keys)
|
||||
out = vmap(lambda k: random.fold_in(k, msgs[0]))(keys())
|
||||
self.assertEqual(out.shape, (2,))
|
||||
out = vmap(random.fold_in, in_axes=(None, 0))(keys[0], msgs)
|
||||
out = vmap(random.fold_in, in_axes=(None, 0))(keys()[0], msgs)
|
||||
self.assertEqual(out.shape, (3,))
|
||||
out = vmap(random.fold_in, in_axes=(0, None))(keys, msgs[0])
|
||||
out = vmap(random.fold_in, in_axes=(0, None))(keys(), msgs[0])
|
||||
self.assertEqual(out.shape, (2,))
|
||||
|
||||
# vmap all
|
||||
msgs = jnp.arange(2)
|
||||
out = vmap(random.fold_in)(keys, msgs)
|
||||
out = vmap(random.fold_in)(keys(), msgs)
|
||||
self.assertEqual(out.shape, (2,))
|
||||
|
||||
# nested vmap
|
||||
keys = random.split(self.make_key(73), 2 * 3).reshape((2, 3))
|
||||
keys = lambda: random.split(self.make_key(73), 2 * 3).reshape((2, 3))
|
||||
msgs = jnp.arange(2 * 3).reshape((2, 3))
|
||||
out = vmap(vmap(random.fold_in), in_axes=(0, 1))(keys, msgs.T)
|
||||
out = vmap(vmap(random.fold_in), in_axes=(0, 1))(keys(), msgs.T)
|
||||
self.assertEqual(out.shape, (2, 3))
|
||||
out = vmap(vmap(random.fold_in), in_axes=(1, 0))(keys, msgs.T)
|
||||
out = vmap(vmap(random.fold_in), in_axes=(1, 0))(keys(), msgs.T)
|
||||
self.assertEqual(out.shape, (3, 2))
|
||||
|
||||
def test_vmap_split_mapped_key(self):
|
||||
@ -1381,6 +1382,7 @@ class LaxRandomWithRBGPRNGTest(LaxRandomTest):
|
||||
keys = random.split(key, 10)
|
||||
self.assertEqual(keys.shape, (10, *key.shape))
|
||||
|
||||
@jax.enable_key_reuse_checks(False)
|
||||
def test_vmap_fold_in_shape(self):
|
||||
# broadcast with scalar
|
||||
keys = random.split(self.make_key(73), 2)
|
||||
@ -1396,6 +1398,7 @@ class LaxRandomWithRBGPRNGTest(LaxRandomTest):
|
||||
out = vmap(random.fold_in, in_axes=(0, None))(keys, msgs[0])
|
||||
self.assertEqual(out.shape, keys.shape)
|
||||
|
||||
@jax.enable_key_reuse_checks(False)
|
||||
def test_vmap_split_not_mapped_key(self):
|
||||
key = self.make_key(73)
|
||||
single_split_key = random.split(key)
|
||||
|
@ -253,22 +253,22 @@ class PrngTest(jtu.JaxTestCase):
|
||||
@parameterized.parameters([{'make_key': ctor} for ctor in KEY_CTORS])
|
||||
def testRngRandomBits(self, make_key):
|
||||
# Test specific outputs to ensure consistent random values between JAX versions.
|
||||
key = make_key(1701)
|
||||
seed = 1701
|
||||
|
||||
bits8 = random.bits(key, (3,), 'uint8')
|
||||
bits8 = random.bits(make_key(seed), (3,), 'uint8')
|
||||
expected8 = np.array([216, 115, 43], dtype=np.uint8)
|
||||
self.assertArraysEqual(bits8, expected8)
|
||||
|
||||
bits16 = random.bits(key, (3,), 'uint16')
|
||||
bits16 = random.bits(make_key(seed), (3,), 'uint16')
|
||||
expected16 = np.array([41682, 1300, 55017], dtype=np.uint16)
|
||||
self.assertArraysEqual(bits16, expected16)
|
||||
|
||||
bits32 = random.bits(key, (3,), 'uint32')
|
||||
bits32 = random.bits(make_key(seed), (3,), 'uint32')
|
||||
expected32 = np.array([56197195, 4200222568, 961309823], dtype=np.uint32)
|
||||
self.assertArraysEqual(bits32, expected32)
|
||||
|
||||
with jtu.ignore_warning(category=UserWarning, message="Explicitly requested dtype.*"):
|
||||
bits64 = random.bits(key, (3,), 'uint64')
|
||||
bits64 = random.bits(make_key(seed), (3,), 'uint64')
|
||||
if config.enable_x64.value:
|
||||
expected64 = np.array([3982329540505020460, 16822122385914693683,
|
||||
7882654074788531506], dtype=np.uint64)
|
||||
@ -287,23 +287,23 @@ class PrngTest(jtu.JaxTestCase):
|
||||
dtype = jnp.dtype(f'uint{width}')
|
||||
return jax.random.bits(key, shape, dtype)
|
||||
|
||||
with jax.default_prng_impl(prng_name):
|
||||
key = make_key(1701)
|
||||
seed = 1701
|
||||
|
||||
bits8 = random_bits(key, 8, (3,))
|
||||
with jax.default_prng_impl(prng_name):
|
||||
bits8 = random_bits(make_key(seed), 8, (3,))
|
||||
self.assertEqual(bits8.shape, (3,))
|
||||
self.assertEqual(bits8.dtype, np.dtype('uint8'))
|
||||
|
||||
bits16 = random_bits(key, 16, (3,))
|
||||
bits16 = random_bits(make_key(seed), 16, (3,))
|
||||
self.assertEqual(bits16.shape, (3,))
|
||||
self.assertEqual(bits16.dtype, np.dtype('uint16'))
|
||||
|
||||
bits32 = random_bits(key, 32, (3,))
|
||||
bits32 = random_bits(make_key(seed), 32, (3,))
|
||||
self.assertEqual(bits32.shape, (3,))
|
||||
self.assertEqual(bits32.dtype, np.dtype('uint32'))
|
||||
|
||||
with jtu.ignore_warning(category=UserWarning, message="Explicitly requested dtype.*"):
|
||||
bits64 = random_bits(key, 64, (3,))
|
||||
bits64 = random_bits(make_key(seed), 64, (3,))
|
||||
expected_dtype = np.dtype('uint64' if config.enable_x64.value else 'uint32')
|
||||
self.assertEqual(bits64.shape, (3,))
|
||||
self.assertEqual(bits64.dtype, expected_dtype)
|
||||
@ -319,9 +319,8 @@ class PrngTest(jtu.JaxTestCase):
|
||||
return jax.random.bits(key, shape, dtype)
|
||||
|
||||
N = 10
|
||||
key = make_key(1701)
|
||||
nbits = [8, 16, 32]
|
||||
rand_bits = [random_bits(key, n, (N * 64 // n,)) for n in nbits]
|
||||
rand_bits = [random_bits(make_key(1701), n, (N * 64 // n,)) for n in nbits]
|
||||
rand_bits_32 = np.array([np.array(r).view(np.uint32) for r in rand_bits])
|
||||
assert np.all(rand_bits_32 == rand_bits_32[0])
|
||||
|
||||
@ -361,31 +360,30 @@ class PrngTest(jtu.JaxTestCase):
|
||||
@parameterized.parameters([{'make_key': ctor} for ctor in KEY_CTORS])
|
||||
def testPRNGValues(self, make_key):
|
||||
# Test to ensure consistent random values between JAX versions
|
||||
k = make_key(0)
|
||||
|
||||
self.assertEqual(random.randint(k, (3, 3), 0, 8).dtype,
|
||||
seed = 0
|
||||
self.assertEqual(random.randint(make_key(seed), (3, 3), 0, 8).dtype,
|
||||
dtypes.canonicalize_dtype(jnp.int_))
|
||||
if config.enable_x64.value:
|
||||
self.assertAllClose(
|
||||
random.randint(k, (3, 3), 0, 8, dtype='int64'),
|
||||
random.randint(make_key(seed), (3, 3), 0, 8, dtype='int64'),
|
||||
np.array([[7, 2, 6],
|
||||
[2, 1, 0],
|
||||
[6, 7, 7]], dtype='int64'))
|
||||
self.assertAllClose(
|
||||
random.randint(k, (3, 3), 0, 8, dtype='int32'),
|
||||
random.randint(make_key(seed), (3, 3), 0, 8, dtype='int32'),
|
||||
np.array([[2, 1, 3],
|
||||
[6, 1, 5],
|
||||
[6, 3, 4]], dtype='int32'))
|
||||
|
||||
self.assertAllClose(
|
||||
random.key_data(random.split(k, 4)),
|
||||
random.key_data(random.split(make_key(seed), 4)),
|
||||
np.array([[2285895361, 1501764800],
|
||||
[1518642379, 4090693311],
|
||||
[ 433833334, 4221794875],
|
||||
[ 839183663, 3740430601]], dtype='uint32'))
|
||||
|
||||
self.assertAllClose(
|
||||
random.key_data(random.fold_in(k, 4)),
|
||||
random.key_data(random.fold_in(make_key(seed), 4)),
|
||||
np.array([2285895361, 433833334], dtype='uint32'))
|
||||
|
||||
@parameterized.parameters([{'make_key': ctor} for ctor in KEY_CTORS])
|
||||
@ -645,7 +643,7 @@ class KeyArrayTest(jtu.JaxTestCase):
|
||||
def f(k):
|
||||
g.append(k.dtype)
|
||||
return random.split(k)
|
||||
_ = jax.jit(f)(k1)
|
||||
_ = jax.jit(f)(self.make_keys())
|
||||
self.assertEqual(g[0], k1.dtype)
|
||||
self.assertEqual(g[0], k2.dtype)
|
||||
|
||||
@ -670,6 +668,8 @@ class KeyArrayTest(jtu.JaxTestCase):
|
||||
self.assertKeysEqual(key, copy.deepcopy(key))
|
||||
self.assertKeysEqual(key, jax.jit(lambda k: k.copy())(key))
|
||||
|
||||
# TODO(jakevdp) remove this decorator when reuse checks move to C++
|
||||
@jax.enable_key_reuse_checks(False)
|
||||
def test_cpp_dispatch_normal(self):
|
||||
# Ensure we stay on the C++ dispatch path when calling a jitted
|
||||
# function with a key array as an argument.
|
||||
@ -685,6 +685,8 @@ class KeyArrayTest(jtu.JaxTestCase):
|
||||
|
||||
self.assertEqual(count[0], 1)
|
||||
|
||||
# TODO(jakevdp) remove this decorator when reuse checks move to C++
|
||||
@jax.enable_key_reuse_checks(False)
|
||||
def test_cpp_dispatch_split(self):
|
||||
# Ensure we stay on the C++ dispatch path when calling a jitted
|
||||
# function with a key arrays as inputs and as outputs.
|
||||
@ -743,17 +745,17 @@ class KeyArrayTest(jtu.JaxTestCase):
|
||||
def test_random_unwrap(self, use_internal):
|
||||
unwrap = prng_internal.random_unwrap if use_internal else random.key_data
|
||||
def f(k): return unwrap(k)
|
||||
k = self.make_keys(3, 4)
|
||||
out = f(k)
|
||||
keys = lambda: self.make_keys(3, 4)
|
||||
out = f(keys())
|
||||
self.assertEqual(out.dtype, np.dtype('uint32'))
|
||||
self.assertEqual(out.shape[:2], (3, 4))
|
||||
out = jax.jit(f)(k)
|
||||
out = jax.jit(f)(keys())
|
||||
self.assertEqual(out.dtype, np.dtype('uint32'))
|
||||
self.assertEqual(out.shape[:2], (3, 4))
|
||||
out = jax.vmap(f)(k)
|
||||
out = jax.vmap(f)(keys())
|
||||
self.assertEqual(out.dtype, np.dtype('uint32'))
|
||||
self.assertEqual(out.shape[:2], (3, 4))
|
||||
out = jax.vmap(jax.jit(f))(k)
|
||||
out = jax.vmap(jax.jit(f))(keys())
|
||||
self.assertEqual(out.dtype, np.dtype('uint32'))
|
||||
self.assertEqual(out.shape[:2], (3, 4))
|
||||
|
||||
@ -864,26 +866,25 @@ class KeyArrayTest(jtu.JaxTestCase):
|
||||
self.assertEqual(ys.shape, (4, 3))
|
||||
|
||||
def test_gather(self):
|
||||
ks = self.make_keys(3, 4)
|
||||
ys = jax.jit(lambda x: x[1])(ks)
|
||||
keys = self.make_keys(3, 4)
|
||||
ys = jax.jit(lambda x: x[1])(keys)
|
||||
self.assertIsInstance(ys, prng_internal.PRNGKeyArray)
|
||||
self.assertEqual(ys.shape, (4,))
|
||||
|
||||
ks = self.make_keys(3, 4, 5)
|
||||
|
||||
ys = jax.jit(lambda x: x[1])(ks)
|
||||
keys = lambda: self.make_keys(3, 4, 5)
|
||||
ys = jax.jit(lambda x: x[1])(keys())
|
||||
self.assertIsInstance(ys, prng_internal.PRNGKeyArray)
|
||||
self.assertEqual(ys.shape, (4, 5))
|
||||
|
||||
ys = jax.jit(lambda x: x[1, 2:4])(ks)
|
||||
ys = jax.jit(lambda x: x[1, 2:4])(keys())
|
||||
self.assertIsInstance(ys, prng_internal.PRNGKeyArray)
|
||||
self.assertEqual(ys.shape, (2, 5))
|
||||
|
||||
ys = jax.jit(lambda x: x[1, 2:4, 3])(ks)
|
||||
ys = jax.jit(lambda x: x[1, 2:4, 3])(keys())
|
||||
self.assertIsInstance(ys, prng_internal.PRNGKeyArray)
|
||||
self.assertEqual(ys.shape, (2,))
|
||||
|
||||
ys = jax.jit(lambda x: x[:, 2:4, 3:4])(ks)
|
||||
ys = jax.jit(lambda x: x[:, 2:4, 3:4])(keys())
|
||||
self.assertIsInstance(ys, prng_internal.PRNGKeyArray)
|
||||
self.assertEqual(ys.shape, (3, 2, 1))
|
||||
|
||||
@ -966,9 +967,8 @@ class KeyArrayTest(jtu.JaxTestCase):
|
||||
return primal_out, tangent_out
|
||||
|
||||
key_dot = None
|
||||
key = self.make_keys()
|
||||
default_result = jax.grad(f_raw)(0.0, key)
|
||||
custom_result = jax.grad(f)(0.0, key)
|
||||
default_result = jax.grad(f_raw)(0.0, self.make_keys())
|
||||
custom_result = jax.grad(f)(0.0, self.make_keys())
|
||||
|
||||
self.assertAllClose(default_result, custom_result)
|
||||
self.assertEqual(key_dot.dtype, dtypes.float0)
|
||||
@ -977,27 +977,28 @@ class KeyArrayTest(jtu.JaxTestCase):
|
||||
key = self.make_keys()
|
||||
self.assertEqual(key.shape, ())
|
||||
self.assertEqual(key[None].shape, (1,))
|
||||
key = self.make_keys()
|
||||
self.assertRaisesRegex(IndexError, 'Too many indices.*', lambda: key[0])
|
||||
|
||||
def test_key_array_indexing_nd(self):
|
||||
keys = self.make_keys(2, 3)
|
||||
self.assertEqual(keys.shape, (2, 3))
|
||||
self.assertEqual(keys[0, 0].shape, ())
|
||||
self.assertEqual(keys[0, 1].shape, ())
|
||||
self.assertEqual(keys[0].shape, (3,))
|
||||
self.assertEqual(keys[1, :].shape, (3,))
|
||||
self.assertEqual(keys[:, 1].shape, (2,))
|
||||
self.assertEqual(keys[None].shape, (1, 2, 3))
|
||||
self.assertEqual(keys[None, None].shape, (1, 1, 2, 3))
|
||||
self.assertEqual(keys[None, :, None].shape, (1, 2, 1, 3))
|
||||
self.assertEqual(keys[None, None, None, 0, None, None, None, 1].shape,
|
||||
keys = lambda: self.make_keys(2, 3)
|
||||
self.assertEqual(keys().shape, (2, 3))
|
||||
self.assertEqual(keys()[0, 0].shape, ())
|
||||
self.assertEqual(keys()[0, 1].shape, ())
|
||||
self.assertEqual(keys()[0].shape, (3,))
|
||||
self.assertEqual(keys()[1, :].shape, (3,))
|
||||
self.assertEqual(keys()[:, 1].shape, (2,))
|
||||
self.assertEqual(keys()[None].shape, (1, 2, 3))
|
||||
self.assertEqual(keys()[None, None].shape, (1, 1, 2, 3))
|
||||
self.assertEqual(keys()[None, :, None].shape, (1, 2, 1, 3))
|
||||
self.assertEqual(keys()[None, None, None, 0, None, None, None, 1].shape,
|
||||
(1,) * 6)
|
||||
self.assertEqual(keys[..., 1:, None].shape, (2, 2, 1))
|
||||
self.assertEqual(keys[None, 0, ..., 1, None].shape, (1, 1))
|
||||
self.assertEqual(keys()[..., 1:, None].shape, (2, 2, 1))
|
||||
self.assertEqual(keys()[None, 0, ..., 1, None].shape, (1, 1))
|
||||
self.assertRaisesRegex(IndexError, 'Too many indices.*',
|
||||
lambda: keys[0, 1, 2])
|
||||
lambda: keys()[0, 1, 2])
|
||||
self.assertRaisesRegex(IndexError, 'Too many indices.*',
|
||||
lambda: keys[0, 1, None, 2])
|
||||
lambda: keys()[0, 1, None, 2])
|
||||
|
||||
def test_not_hashable(self):
|
||||
key = self.make_keys()
|
||||
@ -1222,13 +1223,12 @@ class JnpWithKeyArrayTest(jtu.JaxTestCase):
|
||||
self.check_against_reference(key_func, arr_func, key)
|
||||
|
||||
def test_concatenate(self):
|
||||
key = random.key(123)
|
||||
args = [random.split(k, 2) for k in random.split(key, 3)]
|
||||
args = lambda: [random.split(k, 2) for k in random.split(random.key(123), 3)]
|
||||
|
||||
key_func = arr_func = partial(jnp.concatenate, axis=0)
|
||||
|
||||
self.check_shape(key_func, args)
|
||||
self.check_against_reference(key_func, arr_func, args)
|
||||
self.check_shape(key_func, args())
|
||||
self.check_against_reference(key_func, arr_func, args())
|
||||
|
||||
def test_broadcast_to(self):
|
||||
key = random.key(123)
|
||||
@ -1259,14 +1259,16 @@ class JnpWithKeyArrayTest(jtu.JaxTestCase):
|
||||
self.check_against_reference(key_func, arr_func, key, keys)
|
||||
|
||||
def test_append(self):
|
||||
key = random.key(123)
|
||||
keys = random.split(key, 4)
|
||||
key = lambda: random.key(123)
|
||||
keys = lambda: random.split(random.key(123), 4)
|
||||
|
||||
key_func = jnp.append
|
||||
arr_func = lambda keys, key: jnp.append(keys, key[None], axis=0)
|
||||
|
||||
self.check_shape(key_func, keys, key)
|
||||
self.check_against_reference(key_func, arr_func, keys, key)
|
||||
self.check_shape(key_func, keys(), key())
|
||||
self.check_shape(arr_func, keys(), key())
|
||||
with jax.enable_key_reuse_checks(False):
|
||||
self.check_against_reference(key_func, arr_func, keys(), key())
|
||||
|
||||
def test_ravel(self):
|
||||
key = random.key(123)
|
||||
@ -1306,13 +1308,12 @@ class JnpWithKeyArrayTest(jtu.JaxTestCase):
|
||||
(np.array([False, True, True]),)
|
||||
])
|
||||
def test_getitem(self, idx):
|
||||
key = random.key(123)
|
||||
keys = random.split(key, 3)
|
||||
|
||||
keys = lambda: random.split(random.key(123), 3)
|
||||
key_func = arr_func = lambda x: x[idx]
|
||||
|
||||
self.check_shape(key_func, keys)
|
||||
self.check_against_reference(key_func, arr_func, keys)
|
||||
self.check_shape(key_func, keys())
|
||||
with jax.enable_key_reuse_checks(False):
|
||||
self.check_against_reference(key_func, arr_func, keys())
|
||||
|
||||
@parameterized.parameters([
|
||||
(0,),
|
||||
@ -1321,14 +1322,14 @@ class JnpWithKeyArrayTest(jtu.JaxTestCase):
|
||||
(np.array([False, True, True]),)
|
||||
])
|
||||
def test_gather(self, idx):
|
||||
key = random.key(123)
|
||||
keys = random.split(key, 3)
|
||||
keys = lambda: random.split(random.key(123), 3)
|
||||
key_func = arr_func = lambda key: key.at[idx].get()
|
||||
|
||||
key_func = arr_func = lambda x: x.at[idx].get()
|
||||
|
||||
self.check_shape(key_func, keys)
|
||||
self.check_against_reference(key_func, arr_func, keys)
|
||||
self.check_shape(key_func, keys())
|
||||
with jax.enable_key_reuse_checks(False):
|
||||
self.check_against_reference(key_func, arr_func, keys())
|
||||
|
||||
@jax.enable_key_reuse_checks(False)
|
||||
def test_equality(self):
|
||||
key = random.key(123)
|
||||
key2 = random.key(456)
|
||||
@ -1354,13 +1355,13 @@ class JnpWithKeyArrayTest(jtu.JaxTestCase):
|
||||
(np.array([False, True, True]),)
|
||||
])
|
||||
def test_scatter(self, idx):
|
||||
key = random.key(123)
|
||||
keys = random.split(key, 3)
|
||||
key = lambda: random.key(123)
|
||||
keys = lambda: random.split(key(), 3)
|
||||
|
||||
key_func = arr_func = lambda x, y: x.at[idx].set(y)
|
||||
key_func = arr_func = lambda k1, k2: k1.at[idx].set(k2)
|
||||
|
||||
self.check_shape(key_func, keys, key)
|
||||
self.check_against_reference(key_func, arr_func, keys, key)
|
||||
self.check_shape(key_func, keys(), key())
|
||||
self.check_against_reference(key_func, arr_func, keys(), key())
|
||||
|
||||
def test_errors(self):
|
||||
key = random.key(123)
|
||||
|
@ -1475,17 +1475,18 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
|
||||
ndim = shape[0] if len(shape) > 1 else 1
|
||||
|
||||
args = args_maker()
|
||||
func = partial(resample, shape=())
|
||||
self._CompileAndCheck(
|
||||
func, args_maker, rtol={np.float32: 3e-07, np.float64: 4e-15})
|
||||
result = func(*args)
|
||||
with jax.enable_key_reuse_checks(False):
|
||||
self._CompileAndCheck(
|
||||
func, args_maker, rtol={np.float32: 3e-07, np.float64: 4e-15})
|
||||
result = func(*args_maker())
|
||||
assert result.shape == (ndim,)
|
||||
|
||||
func = partial(resample, shape=(4,))
|
||||
self._CompileAndCheck(
|
||||
func, args_maker, rtol={np.float32: 3e-07, np.float64: 4e-15})
|
||||
result = func(*args)
|
||||
with jax.enable_key_reuse_checks(False):
|
||||
self._CompileAndCheck(
|
||||
func, args_maker, rtol={np.float32: 3e-07, np.float64: 4e-15})
|
||||
result = func(*args_maker())
|
||||
assert result.shape == (ndim, 4)
|
||||
|
||||
@jtu.sample_product(
|
||||
|
@ -29,6 +29,7 @@ from jax._src import core
|
||||
from jax._src import config
|
||||
from jax._src import linear_util as lu
|
||||
from jax._src.interpreters import partial_eval as pe
|
||||
from jax._src import prng
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src.util import tuple_insert
|
||||
import jax.numpy as jnp
|
||||
@ -1696,8 +1697,8 @@ if CAN_USE_HYPOTHESIS:
|
||||
y, impl_vjp = jax.vjp(impl, x)
|
||||
y_ref, ref_vjp = jax.vjp(ref, x)
|
||||
self.assertAllClose(y, y_ref)
|
||||
t = random.normal(k2, x.shape)
|
||||
y2 = random.normal(k1, y.shape)
|
||||
t = random.normal(prng.reuse_key(k2), x.shape)
|
||||
y2 = random.normal(prng.reuse_key(k1), y.shape)
|
||||
self.assertAllClose(impl_vjp(t), ref_vjp(t))
|
||||
|
||||
# Second order
|
||||
@ -1713,7 +1714,7 @@ if CAN_USE_HYPOTHESIS:
|
||||
(x,), impl_vjp2 = jax.vjp(impl_vjp, t2)
|
||||
(x_ref,), ref_vjp2 = jax.vjp(ref_vjp, t2)
|
||||
self.assertAllClose(x, x_ref)
|
||||
y2 = random.normal(k1, y.shape)
|
||||
y2 = random.normal(prng.reuse_key(k1), y.shape)
|
||||
self.assertAllClose(impl_vjp2((y2,)), ref_vjp2((y2,)))
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@ -112,6 +112,7 @@ class X64ContextTests(jtu.JaxTestCase):
|
||||
self.assertEqual(x32.result(), jnp.int32)
|
||||
|
||||
@jax.legacy_prng_key('allow')
|
||||
@jax.enable_key_reuse_checks(False)
|
||||
@jtu.ignore_warning(category=UserWarning,
|
||||
message="Explicitly requested dtype float64 is not available")
|
||||
def test_jit_cache(self):
|
||||
|
Loading…
x
Reference in New Issue
Block a user