[key reuse] rename flag to jax_debug_key_reuse

This commit is contained in:
Jake VanderPlas 2024-03-21 10:47:16 -07:00
parent cd79e71d85
commit 8949a63ce1
18 changed files with 44 additions and 44 deletions

View File

@ -47,7 +47,7 @@ from jax import typing as typing
from jax._src.config import (
config as config,
enable_checks as enable_checks,
enable_key_reuse_checks as enable_key_reuse_checks,
debug_key_reuse as debug_key_reuse,
check_tracer_leaks as check_tracer_leaks,
checking_leaks as checking_leaks,
enable_custom_prng as enable_custom_prng,

View File

@ -213,7 +213,7 @@ def trace_context():
softmax_custom_jvp.value,
enable_memories.value,
disable_jit.value,
enable_key_reuse_checks.value,
debug_key_reuse.value,
jax_xla_profile_version.value,
# Technically this affects jaxpr->stablehlo lowering, not tracing.
hlo_source_file_canonicalization_regex.value)
@ -930,8 +930,8 @@ enable_checks = define_bool_state(
default=False,
help='Turn on invariant checking for JAX internals. Makes things slower.')
enable_key_reuse_checks = define_bool_state(
name='jax_enable_key_reuse_checks',
debug_key_reuse = define_bool_state(
name='jax_debug_key_reuse',
default=False,
help=('Turn on experimental key reuse checking. With this configuration enabled,'
' typed PRNG keys (i.e. keys created with jax.random.key()) will have their'

View File

@ -2861,7 +2861,7 @@ def check_jaxpr(jaxpr: Jaxpr):
raise JaxprTypeError(msg) from None
# Run key reuse checker after validating jaxpr:
if config.enable_key_reuse_checks.value:
if config.debug_key_reuse.value:
# Import here to avoid circular imports
from jax.experimental.key_reuse._core import check_key_reuse_jaxpr # pytype: disable=import-error
check_key_reuse_jaxpr(jaxpr)

View File

@ -661,12 +661,12 @@ class UnexpectedTracerError(JAXTypeError):
class KeyReuseError(JAXTypeError):
"""
This error occurs when a PRNG key is reused in an unsafe manner.
Key reuse is checked only when `jax_enable_key_reuse_checks` is
Key reuse is checked only when `jax_debug_key_reuse` is
set to `True`.
Here is a simple example of code that would lead to such an error::
>>> with jax.enable_key_reuse_checks(True): # doctest: +SKIP
>>> with jax.debug_key_reuse(True): # doctest: +SKIP
... key = jax.random.key(0)
... value = jax.random.uniform(key)
... new_value = jax.random.uniform(key)

View File

@ -236,7 +236,7 @@ def _get_fastpath_data(
# no ref state effects
and not any(isinstance(e, RefEffect) for e in effects)
# no prng reuse checking
and not (config.enable_key_reuse_checks.value and any(
and not (config.debug_key_reuse.value and any(
hasattr(arg, 'dtype') and dtypes.issubdtype(arg.dtype, dtypes.prng_key)
for arg in (*args_flat, *out_flat)))
)
@ -1150,7 +1150,7 @@ def _create_pjit_jaxpr(fun, in_type, debug_info, out_paths, ignored_inline):
if not config.dynamic_shapes.value and not attrs_tracked:
jaxpr = jaxpr_debug_info(jaxpr, debug_info, out_paths())
if config.enable_key_reuse_checks.value:
if config.debug_key_reuse.value:
# Import here to avoid circular imports
from jax.experimental.key_reuse._core import check_key_reuse_jaxpr
check_key_reuse_jaxpr(jaxpr)

View File

@ -42,7 +42,7 @@ from jax._src.internal_test_util import test_harnesses
@jtu.with_config(jax_legacy_prng_key='allow',
jax_enable_key_reuse_checks=False)
jax_debug_key_reuse=False)
class JaxPrimitiveTest(jtu.JaxTestCase):
# This test runs for all primitive harnesses. For each primitive "xxx" the

View File

@ -158,7 +158,7 @@ def ComputeTfValueAndGrad(tf_f: Callable, tf_args: Sequence,
@jtu.with_config(jax_numpy_rank_promotion="allow",
jax_numpy_dtype_promotion='standard',
jax_legacy_prng_key="allow",
jax_enable_key_reuse_checks=False)
jax_debug_key_reuse=False)
class JaxToTfTestCase(jtu.JaxTestCase):
# We want most tests to use the maximum available version, from the locally
# installed tfxla module and export.

View File

@ -20,16 +20,16 @@ This module contains **experimental** functionality for detecting reuse of rando
keys within JAX programs. It is under active development and the APIs here are
likely to change. The usage below requires JAX version 0.4.26 or newer.
Key reuse checking can be enabled using the ``jax_enable_key_reuse_checks`` configuration.
Key reuse checking can be enabled using the ``jax_debug_key_reuse`` configuration.
This can be set globally using::
>>> jax.config.update('jax_enable_key_reuse_checks', True) # doctest: +SKIP
>>> jax.config.update('jax_debug_key_reuse', True) # doctest: +SKIP
Or it can be enabled locally with the :func:`jax.enable_key_reuse_checks` context manager.
Or it can be enabled locally with the :func:`jax.debug_key_reuse` context manager.
When enabled, using the same key twice will result in a :class:`~jax.errors.KeyReuseError`::
>>> import jax
>>> with jax.enable_key_reuse_checks(True):
>>> with jax.debug_key_reuse(True):
... key = jax.random.key(0)
... val1 = jax.random.normal(key)
... val2 = jax.random.normal(key) # doctest: +IGNORE_EXCEPTION_DETAIL

View File

@ -530,7 +530,7 @@ key_reuse_signatures_dynamic[remat_p] = _remat_key_type_signature
def key_reuse_impl_rule(prim, original_rule):
@wraps(original_rule)
def key_reuse_impl(*args, **kwargs):
if config.enable_key_reuse_checks.value:
if config.debug_key_reuse.value:
if prim == pjit.pjit_p:
funcname = "jit-compiled function"
jaxpr = kwargs['jaxpr'].jaxpr

View File

@ -962,7 +962,7 @@ class BatchingTest(jtu.JaxTestCase):
u, _ = lax.while_loop(lambda uk: uk[0] > 0.5, body_fn, (1., key))
return u
with jax.enable_key_reuse_checks(False):
with jax.debug_key_reuse(False):
print(vmap(f)(random.split(random.PRNGKey(0), 2))) # no crash
def testEmptyTuples(self):

View File

@ -750,7 +750,7 @@ class DynamicShapesTest(jtu.JaxTestCase):
core.check_jaxpr(jaxpr)
def test_check_jaxpr_key_reuse(self):
with config.enable_key_reuse_checks(True):
with config.debug_key_reuse(True):
def f(seed):
key = jax.random.key(seed)
return jax.random.uniform(key) + jax.random.normal(key)

View File

@ -66,7 +66,7 @@ config.parse_flags_with_absl()
@jtu.with_config(jax_legacy_prng_key='allow',
jax_enable_key_reuse_checks=False)
jax_debug_key_reuse=False)
class CompatTest(bctu.CompatTestBase):
def test_dummy(self):
# Tests the testing mechanism. Let this test run on all platforms

View File

@ -67,7 +67,7 @@ def apply_unknown_primitive(key):
@jtu.with_config(
jax_enable_custom_prng=False,
jax_enable_key_reuse_checks=False)
jax_debug_key_reuse=False)
class KeyReuseUnitTestWithForwarding(jtu.JaxTestCase):
def check_key_reuse(self, *args):
return _core.check_key_reuse(*args)
@ -353,7 +353,7 @@ class KeyReuseUnitTestWithForwarding(jtu.JaxTestCase):
self.assertEqual(signature, _core.function_type_signature(func, *args))
@jtu.with_config(jax_enable_key_reuse_checks=False)
@jtu.with_config(jax_debug_key_reuse=False)
class KeyReuseIntegrationTest(jtu.JaxTestCase):
random_bits_error = "In random_bits, argument [0-9]+ is already consumed.*"
random_split_error = "In random_split, argument [0-9]+ is already consumed.*"
@ -607,7 +607,7 @@ class KeyReuseIntegrationTest(jtu.JaxTestCase):
self.check_key_reuse(jax.grad(f_good), x, key)
@jtu.with_config(jax_enable_key_reuse_checks=True)
@jtu.with_config(jax_debug_key_reuse=True)
class KeyReuseEagerTest(jtu.JaxTestCase):
jit_msg = "Previously-consumed key passed to jit-compiled function at index 0"
eager_bits_msg = "Previously-consumed key passed to random_bits at index 0"
@ -735,14 +735,14 @@ class KeyReuseGlobalFlagsTest(jtu.JaxTestCase):
key = jax.random.key(0)
with jax.enable_key_reuse_checks(False):
with jax.debug_key_reuse(False):
f_good(key)
f_bad(key) # No failure
f_bad.clear_cache()
f_good.clear_cache()
with jax.enable_key_reuse_checks(True):
with jax.debug_key_reuse(True):
f_good(key)
with self.assertRaisesRegex(KeyReuseError, "In random_bits.*"):
f_bad(key)

View File

@ -1083,7 +1083,7 @@ class LaxRandomTest(jtu.JaxTestCase):
# TODO(jakevdp): key reuse checks for this OOM because of slice masking.
# Can we fix this?
with jax.enable_key_reuse_checks(False):
with jax.debug_key_reuse(False):
# just lower, don't run, takes too long
jax.jit(f).lower()
@ -1348,7 +1348,7 @@ class LaxRandomWithCustomPRNGTest(LaxRandomTest):
out = vmap(vmap(random.fold_in), in_axes=(1, 0))(keys(), msgs.T)
self.assertEqual(out.shape, (3, 2))
@jax.enable_key_reuse_checks(False)
@jax.debug_key_reuse(False)
def test_vmap_split_mapped_key(self):
key = self.make_key(73)
mapped_keys = random.split(key, num=3)
@ -1383,7 +1383,7 @@ class LaxRandomWithRBGPRNGTest(LaxRandomTest):
keys = random.split(key, 10)
self.assertEqual(keys.shape, (10, *key.shape))
@jax.enable_key_reuse_checks(False)
@jax.debug_key_reuse(False)
def test_vmap_fold_in_shape(self):
# broadcast with scalar
keys = random.split(self.make_key(73), 2)
@ -1399,7 +1399,7 @@ class LaxRandomWithRBGPRNGTest(LaxRandomTest):
out = vmap(random.fold_in, in_axes=(0, None))(keys, msgs[0])
self.assertEqual(out.shape, keys.shape)
@jax.enable_key_reuse_checks(False)
@jax.debug_key_reuse(False)
def test_vmap_split_not_mapped_key(self):
key = self.make_key(73)
single_split_key = random.split(key)
@ -1409,14 +1409,14 @@ class LaxRandomWithRBGPRNGTest(LaxRandomTest):
self.assertArraysEqual(random.key_data(vk),
random.key_data(single_split_key))
@jax.enable_key_reuse_checks(False)
@jax.debug_key_reuse(False)
def test_vmap_split_mapped_key_shape(self):
key = self.make_key(73)
mapped_keys = random.split(key, num=3)
vmapped_keys = vmap(random.split)(mapped_keys)
self.assertEqual(vmapped_keys.shape, (3, 2, *key.shape))
@jax.enable_key_reuse_checks(False)
@jax.debug_key_reuse(False)
def test_vmap_split_mapped_key_values(self):
key = self.make_key(73)
mapped_keys = random.split(key, num=3)
@ -1426,7 +1426,7 @@ class LaxRandomWithRBGPRNGTest(LaxRandomTest):
self.assertArraysEqual(random.key_data(rk),
random.key_data(vk))
@jax.enable_key_reuse_checks(False)
@jax.debug_key_reuse(False)
def test_vmap_random_bits_shape(self):
rand_fun = lambda key, shape=(): random.randint(key, shape, 0, 100)
key = self.make_key(73)
@ -1435,7 +1435,7 @@ class LaxRandomWithRBGPRNGTest(LaxRandomTest):
self.assertEqual(rand_nums.shape, (3,))
@jtu.skip_on_devices("tpu")
@jax.enable_key_reuse_checks(False)
@jax.debug_key_reuse(False)
def test_vmap_random_bits_value(self):
rand_fun = lambda key, shape=(): random.randint(key, shape, 0, 100)
key = self.make_key(73)
@ -1491,7 +1491,7 @@ class LaxRandomWithUnsafeRBGPRNGTest(LaxRandomWithRBGPRNGTest):
return random.PRNGKey(seed, impl="unsafe_rbg")
@jtu.skip_on_devices("tpu")
@jax.enable_key_reuse_checks(False)
@jax.debug_key_reuse(False)
def test_vmap_split_mapped_key_values(self):
key = self.make_key(73)
mapped_keys = random.split(key, num=3)

View File

@ -677,7 +677,7 @@ class KeyArrayTest(jtu.JaxTestCase):
self.assertKeysEqual(key, jax.jit(lambda k: k.copy())(key))
# TODO(jakevdp) remove this decorator when reuse checks move to C++
@jax.enable_key_reuse_checks(False)
@jax.debug_key_reuse(False)
def test_cpp_dispatch_normal(self):
# Ensure we stay on the C++ dispatch path when calling a jitted
# function with a key array as an argument.
@ -694,7 +694,7 @@ class KeyArrayTest(jtu.JaxTestCase):
self.assertEqual(count[0], 1)
# TODO(jakevdp) remove this decorator when reuse checks move to C++
@jax.enable_key_reuse_checks(False)
@jax.debug_key_reuse(False)
def test_cpp_dispatch_split(self):
# Ensure we stay on the C++ dispatch path when calling a jitted
# function with a key arrays as inputs and as outputs.
@ -1277,7 +1277,7 @@ class JnpWithKeyArrayTest(jtu.JaxTestCase):
self.check_shape(key_func, keys(), key())
self.check_shape(arr_func, keys(), key())
with jax.enable_key_reuse_checks(False):
with jax.debug_key_reuse(False):
self.check_against_reference(key_func, arr_func, keys(), key())
def test_ravel(self):
@ -1322,7 +1322,7 @@ class JnpWithKeyArrayTest(jtu.JaxTestCase):
key_func = arr_func = lambda x: x[idx]
self.check_shape(key_func, keys())
with jax.enable_key_reuse_checks(False):
with jax.debug_key_reuse(False):
self.check_against_reference(key_func, arr_func, keys())
@parameterized.parameters([
@ -1336,10 +1336,10 @@ class JnpWithKeyArrayTest(jtu.JaxTestCase):
key_func = arr_func = lambda key: key.at[idx].get()
self.check_shape(key_func, keys())
with jax.enable_key_reuse_checks(False):
with jax.debug_key_reuse(False):
self.check_against_reference(key_func, arr_func, keys())
@jax.enable_key_reuse_checks(False)
@jax.debug_key_reuse(False)
def test_equality(self):
key = random.key(123)
key2 = random.key(456)

View File

@ -1476,14 +1476,14 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
ndim = shape[0] if len(shape) > 1 else 1
func = partial(resample, shape=())
with jax.enable_key_reuse_checks(False):
with jax.debug_key_reuse(False):
self._CompileAndCheck(
func, args_maker, rtol={np.float32: 3e-07, np.float64: 4e-15})
result = func(*args_maker())
assert result.shape == (ndim,)
func = partial(resample, shape=(4,))
with jax.enable_key_reuse_checks(False):
with jax.debug_key_reuse(False):
self._CompileAndCheck(
func, args_maker, rtol={np.float32: 3e-07, np.float64: 4e-15})
result = func(*args_maker())

View File

@ -1757,7 +1757,7 @@ class ShapePolyTest(jtu.JaxTestCase):
polymorphic_shapes=["(b,)"])
self.assertAllClose(f(x), res_tf)
@jax.enable_key_reuse_checks(False)
@jax.debug_key_reuse(False)
def test_prng(self):
# The PRNG implementation uses opaque types, test shape polymorphism
with config.enable_custom_prng(True):
@ -3294,7 +3294,7 @@ class ShapePolyHarnessesTest(jtu.JaxTestCase):
# Update this here rather than in harness object because vmap_random_gamma is derived
# from test_harnesses.all_harnesses, which strips override_jax_config_flags.
if "random_gamma" in harness.group_name:
config_flags = {**config_flags, "jax_enable_key_reuse_checks": False}
config_flags = {**config_flags, "jax_debug_key_reuse": False}
prev_jax_config_flags = {fname: getattr(jax.config, fname) for fname in config_flags}
try:

View File

@ -112,7 +112,7 @@ class X64ContextTests(jtu.JaxTestCase):
self.assertEqual(x32.result(), jnp.int32)
@jax.legacy_prng_key('allow')
@jax.enable_key_reuse_checks(False)
@jax.debug_key_reuse(False)
@jtu.ignore_warning(category=UserWarning,
message="Explicitly requested dtype float64 is not available")
def test_jit_cache(self):