mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
[key reuse] rename flag to jax_debug_key_reuse
This commit is contained in:
parent
cd79e71d85
commit
8949a63ce1
@ -47,7 +47,7 @@ from jax import typing as typing
|
||||
from jax._src.config import (
|
||||
config as config,
|
||||
enable_checks as enable_checks,
|
||||
enable_key_reuse_checks as enable_key_reuse_checks,
|
||||
debug_key_reuse as debug_key_reuse,
|
||||
check_tracer_leaks as check_tracer_leaks,
|
||||
checking_leaks as checking_leaks,
|
||||
enable_custom_prng as enable_custom_prng,
|
||||
|
@ -213,7 +213,7 @@ def trace_context():
|
||||
softmax_custom_jvp.value,
|
||||
enable_memories.value,
|
||||
disable_jit.value,
|
||||
enable_key_reuse_checks.value,
|
||||
debug_key_reuse.value,
|
||||
jax_xla_profile_version.value,
|
||||
# Technically this affects jaxpr->stablehlo lowering, not tracing.
|
||||
hlo_source_file_canonicalization_regex.value)
|
||||
@ -930,8 +930,8 @@ enable_checks = define_bool_state(
|
||||
default=False,
|
||||
help='Turn on invariant checking for JAX internals. Makes things slower.')
|
||||
|
||||
enable_key_reuse_checks = define_bool_state(
|
||||
name='jax_enable_key_reuse_checks',
|
||||
debug_key_reuse = define_bool_state(
|
||||
name='jax_debug_key_reuse',
|
||||
default=False,
|
||||
help=('Turn on experimental key reuse checking. With this configuration enabled,'
|
||||
' typed PRNG keys (i.e. keys created with jax.random.key()) will have their'
|
||||
|
@ -2861,7 +2861,7 @@ def check_jaxpr(jaxpr: Jaxpr):
|
||||
raise JaxprTypeError(msg) from None
|
||||
|
||||
# Run key reuse checker after validating jaxpr:
|
||||
if config.enable_key_reuse_checks.value:
|
||||
if config.debug_key_reuse.value:
|
||||
# Import here to avoid circular imports
|
||||
from jax.experimental.key_reuse._core import check_key_reuse_jaxpr # pytype: disable=import-error
|
||||
check_key_reuse_jaxpr(jaxpr)
|
||||
|
@ -661,12 +661,12 @@ class UnexpectedTracerError(JAXTypeError):
|
||||
class KeyReuseError(JAXTypeError):
|
||||
"""
|
||||
This error occurs when a PRNG key is reused in an unsafe manner.
|
||||
Key reuse is checked only when `jax_enable_key_reuse_checks` is
|
||||
Key reuse is checked only when `jax_debug_key_reuse` is
|
||||
set to `True`.
|
||||
|
||||
Here is a simple example of code that would lead to such an error::
|
||||
|
||||
>>> with jax.enable_key_reuse_checks(True): # doctest: +SKIP
|
||||
>>> with jax.debug_key_reuse(True): # doctest: +SKIP
|
||||
... key = jax.random.key(0)
|
||||
... value = jax.random.uniform(key)
|
||||
... new_value = jax.random.uniform(key)
|
||||
|
@ -236,7 +236,7 @@ def _get_fastpath_data(
|
||||
# no ref state effects
|
||||
and not any(isinstance(e, RefEffect) for e in effects)
|
||||
# no prng reuse checking
|
||||
and not (config.enable_key_reuse_checks.value and any(
|
||||
and not (config.debug_key_reuse.value and any(
|
||||
hasattr(arg, 'dtype') and dtypes.issubdtype(arg.dtype, dtypes.prng_key)
|
||||
for arg in (*args_flat, *out_flat)))
|
||||
)
|
||||
@ -1150,7 +1150,7 @@ def _create_pjit_jaxpr(fun, in_type, debug_info, out_paths, ignored_inline):
|
||||
if not config.dynamic_shapes.value and not attrs_tracked:
|
||||
jaxpr = jaxpr_debug_info(jaxpr, debug_info, out_paths())
|
||||
|
||||
if config.enable_key_reuse_checks.value:
|
||||
if config.debug_key_reuse.value:
|
||||
# Import here to avoid circular imports
|
||||
from jax.experimental.key_reuse._core import check_key_reuse_jaxpr
|
||||
check_key_reuse_jaxpr(jaxpr)
|
||||
|
@ -42,7 +42,7 @@ from jax._src.internal_test_util import test_harnesses
|
||||
|
||||
|
||||
@jtu.with_config(jax_legacy_prng_key='allow',
|
||||
jax_enable_key_reuse_checks=False)
|
||||
jax_debug_key_reuse=False)
|
||||
class JaxPrimitiveTest(jtu.JaxTestCase):
|
||||
|
||||
# This test runs for all primitive harnesses. For each primitive "xxx" the
|
||||
|
@ -158,7 +158,7 @@ def ComputeTfValueAndGrad(tf_f: Callable, tf_args: Sequence,
|
||||
@jtu.with_config(jax_numpy_rank_promotion="allow",
|
||||
jax_numpy_dtype_promotion='standard',
|
||||
jax_legacy_prng_key="allow",
|
||||
jax_enable_key_reuse_checks=False)
|
||||
jax_debug_key_reuse=False)
|
||||
class JaxToTfTestCase(jtu.JaxTestCase):
|
||||
# We want most tests to use the maximum available version, from the locally
|
||||
# installed tfxla module and export.
|
||||
|
@ -20,16 +20,16 @@ This module contains **experimental** functionality for detecting reuse of rando
|
||||
keys within JAX programs. It is under active development and the APIs here are
|
||||
likely to change. The usage below requires JAX version 0.4.26 or newer.
|
||||
|
||||
Key reuse checking can be enabled using the ``jax_enable_key_reuse_checks`` configuration.
|
||||
Key reuse checking can be enabled using the ``jax_debug_key_reuse`` configuration.
|
||||
This can be set globally using::
|
||||
|
||||
>>> jax.config.update('jax_enable_key_reuse_checks', True) # doctest: +SKIP
|
||||
>>> jax.config.update('jax_debug_key_reuse', True) # doctest: +SKIP
|
||||
|
||||
Or it can be enabled locally with the :func:`jax.enable_key_reuse_checks` context manager.
|
||||
Or it can be enabled locally with the :func:`jax.debug_key_reuse` context manager.
|
||||
When enabled, using the same key twice will result in a :class:`~jax.errors.KeyReuseError`::
|
||||
|
||||
>>> import jax
|
||||
>>> with jax.enable_key_reuse_checks(True):
|
||||
>>> with jax.debug_key_reuse(True):
|
||||
... key = jax.random.key(0)
|
||||
... val1 = jax.random.normal(key)
|
||||
... val2 = jax.random.normal(key) # doctest: +IGNORE_EXCEPTION_DETAIL
|
||||
|
@ -530,7 +530,7 @@ key_reuse_signatures_dynamic[remat_p] = _remat_key_type_signature
|
||||
def key_reuse_impl_rule(prim, original_rule):
|
||||
@wraps(original_rule)
|
||||
def key_reuse_impl(*args, **kwargs):
|
||||
if config.enable_key_reuse_checks.value:
|
||||
if config.debug_key_reuse.value:
|
||||
if prim == pjit.pjit_p:
|
||||
funcname = "jit-compiled function"
|
||||
jaxpr = kwargs['jaxpr'].jaxpr
|
||||
|
@ -962,7 +962,7 @@ class BatchingTest(jtu.JaxTestCase):
|
||||
u, _ = lax.while_loop(lambda uk: uk[0] > 0.5, body_fn, (1., key))
|
||||
return u
|
||||
|
||||
with jax.enable_key_reuse_checks(False):
|
||||
with jax.debug_key_reuse(False):
|
||||
print(vmap(f)(random.split(random.PRNGKey(0), 2))) # no crash
|
||||
|
||||
def testEmptyTuples(self):
|
||||
|
@ -750,7 +750,7 @@ class DynamicShapesTest(jtu.JaxTestCase):
|
||||
core.check_jaxpr(jaxpr)
|
||||
|
||||
def test_check_jaxpr_key_reuse(self):
|
||||
with config.enable_key_reuse_checks(True):
|
||||
with config.debug_key_reuse(True):
|
||||
def f(seed):
|
||||
key = jax.random.key(seed)
|
||||
return jax.random.uniform(key) + jax.random.normal(key)
|
||||
|
@ -66,7 +66,7 @@ config.parse_flags_with_absl()
|
||||
|
||||
|
||||
@jtu.with_config(jax_legacy_prng_key='allow',
|
||||
jax_enable_key_reuse_checks=False)
|
||||
jax_debug_key_reuse=False)
|
||||
class CompatTest(bctu.CompatTestBase):
|
||||
def test_dummy(self):
|
||||
# Tests the testing mechanism. Let this test run on all platforms
|
||||
|
@ -67,7 +67,7 @@ def apply_unknown_primitive(key):
|
||||
|
||||
@jtu.with_config(
|
||||
jax_enable_custom_prng=False,
|
||||
jax_enable_key_reuse_checks=False)
|
||||
jax_debug_key_reuse=False)
|
||||
class KeyReuseUnitTestWithForwarding(jtu.JaxTestCase):
|
||||
def check_key_reuse(self, *args):
|
||||
return _core.check_key_reuse(*args)
|
||||
@ -353,7 +353,7 @@ class KeyReuseUnitTestWithForwarding(jtu.JaxTestCase):
|
||||
self.assertEqual(signature, _core.function_type_signature(func, *args))
|
||||
|
||||
|
||||
@jtu.with_config(jax_enable_key_reuse_checks=False)
|
||||
@jtu.with_config(jax_debug_key_reuse=False)
|
||||
class KeyReuseIntegrationTest(jtu.JaxTestCase):
|
||||
random_bits_error = "In random_bits, argument [0-9]+ is already consumed.*"
|
||||
random_split_error = "In random_split, argument [0-9]+ is already consumed.*"
|
||||
@ -607,7 +607,7 @@ class KeyReuseIntegrationTest(jtu.JaxTestCase):
|
||||
self.check_key_reuse(jax.grad(f_good), x, key)
|
||||
|
||||
|
||||
@jtu.with_config(jax_enable_key_reuse_checks=True)
|
||||
@jtu.with_config(jax_debug_key_reuse=True)
|
||||
class KeyReuseEagerTest(jtu.JaxTestCase):
|
||||
jit_msg = "Previously-consumed key passed to jit-compiled function at index 0"
|
||||
eager_bits_msg = "Previously-consumed key passed to random_bits at index 0"
|
||||
@ -735,14 +735,14 @@ class KeyReuseGlobalFlagsTest(jtu.JaxTestCase):
|
||||
|
||||
key = jax.random.key(0)
|
||||
|
||||
with jax.enable_key_reuse_checks(False):
|
||||
with jax.debug_key_reuse(False):
|
||||
f_good(key)
|
||||
f_bad(key) # No failure
|
||||
|
||||
f_bad.clear_cache()
|
||||
f_good.clear_cache()
|
||||
|
||||
with jax.enable_key_reuse_checks(True):
|
||||
with jax.debug_key_reuse(True):
|
||||
f_good(key)
|
||||
with self.assertRaisesRegex(KeyReuseError, "In random_bits.*"):
|
||||
f_bad(key)
|
||||
|
@ -1083,7 +1083,7 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
|
||||
# TODO(jakevdp): key reuse checks for this OOM because of slice masking.
|
||||
# Can we fix this?
|
||||
with jax.enable_key_reuse_checks(False):
|
||||
with jax.debug_key_reuse(False):
|
||||
# just lower, don't run, takes too long
|
||||
jax.jit(f).lower()
|
||||
|
||||
@ -1348,7 +1348,7 @@ class LaxRandomWithCustomPRNGTest(LaxRandomTest):
|
||||
out = vmap(vmap(random.fold_in), in_axes=(1, 0))(keys(), msgs.T)
|
||||
self.assertEqual(out.shape, (3, 2))
|
||||
|
||||
@jax.enable_key_reuse_checks(False)
|
||||
@jax.debug_key_reuse(False)
|
||||
def test_vmap_split_mapped_key(self):
|
||||
key = self.make_key(73)
|
||||
mapped_keys = random.split(key, num=3)
|
||||
@ -1383,7 +1383,7 @@ class LaxRandomWithRBGPRNGTest(LaxRandomTest):
|
||||
keys = random.split(key, 10)
|
||||
self.assertEqual(keys.shape, (10, *key.shape))
|
||||
|
||||
@jax.enable_key_reuse_checks(False)
|
||||
@jax.debug_key_reuse(False)
|
||||
def test_vmap_fold_in_shape(self):
|
||||
# broadcast with scalar
|
||||
keys = random.split(self.make_key(73), 2)
|
||||
@ -1399,7 +1399,7 @@ class LaxRandomWithRBGPRNGTest(LaxRandomTest):
|
||||
out = vmap(random.fold_in, in_axes=(0, None))(keys, msgs[0])
|
||||
self.assertEqual(out.shape, keys.shape)
|
||||
|
||||
@jax.enable_key_reuse_checks(False)
|
||||
@jax.debug_key_reuse(False)
|
||||
def test_vmap_split_not_mapped_key(self):
|
||||
key = self.make_key(73)
|
||||
single_split_key = random.split(key)
|
||||
@ -1409,14 +1409,14 @@ class LaxRandomWithRBGPRNGTest(LaxRandomTest):
|
||||
self.assertArraysEqual(random.key_data(vk),
|
||||
random.key_data(single_split_key))
|
||||
|
||||
@jax.enable_key_reuse_checks(False)
|
||||
@jax.debug_key_reuse(False)
|
||||
def test_vmap_split_mapped_key_shape(self):
|
||||
key = self.make_key(73)
|
||||
mapped_keys = random.split(key, num=3)
|
||||
vmapped_keys = vmap(random.split)(mapped_keys)
|
||||
self.assertEqual(vmapped_keys.shape, (3, 2, *key.shape))
|
||||
|
||||
@jax.enable_key_reuse_checks(False)
|
||||
@jax.debug_key_reuse(False)
|
||||
def test_vmap_split_mapped_key_values(self):
|
||||
key = self.make_key(73)
|
||||
mapped_keys = random.split(key, num=3)
|
||||
@ -1426,7 +1426,7 @@ class LaxRandomWithRBGPRNGTest(LaxRandomTest):
|
||||
self.assertArraysEqual(random.key_data(rk),
|
||||
random.key_data(vk))
|
||||
|
||||
@jax.enable_key_reuse_checks(False)
|
||||
@jax.debug_key_reuse(False)
|
||||
def test_vmap_random_bits_shape(self):
|
||||
rand_fun = lambda key, shape=(): random.randint(key, shape, 0, 100)
|
||||
key = self.make_key(73)
|
||||
@ -1435,7 +1435,7 @@ class LaxRandomWithRBGPRNGTest(LaxRandomTest):
|
||||
self.assertEqual(rand_nums.shape, (3,))
|
||||
|
||||
@jtu.skip_on_devices("tpu")
|
||||
@jax.enable_key_reuse_checks(False)
|
||||
@jax.debug_key_reuse(False)
|
||||
def test_vmap_random_bits_value(self):
|
||||
rand_fun = lambda key, shape=(): random.randint(key, shape, 0, 100)
|
||||
key = self.make_key(73)
|
||||
@ -1491,7 +1491,7 @@ class LaxRandomWithUnsafeRBGPRNGTest(LaxRandomWithRBGPRNGTest):
|
||||
return random.PRNGKey(seed, impl="unsafe_rbg")
|
||||
|
||||
@jtu.skip_on_devices("tpu")
|
||||
@jax.enable_key_reuse_checks(False)
|
||||
@jax.debug_key_reuse(False)
|
||||
def test_vmap_split_mapped_key_values(self):
|
||||
key = self.make_key(73)
|
||||
mapped_keys = random.split(key, num=3)
|
||||
|
@ -677,7 +677,7 @@ class KeyArrayTest(jtu.JaxTestCase):
|
||||
self.assertKeysEqual(key, jax.jit(lambda k: k.copy())(key))
|
||||
|
||||
# TODO(jakevdp) remove this decorator when reuse checks move to C++
|
||||
@jax.enable_key_reuse_checks(False)
|
||||
@jax.debug_key_reuse(False)
|
||||
def test_cpp_dispatch_normal(self):
|
||||
# Ensure we stay on the C++ dispatch path when calling a jitted
|
||||
# function with a key array as an argument.
|
||||
@ -694,7 +694,7 @@ class KeyArrayTest(jtu.JaxTestCase):
|
||||
self.assertEqual(count[0], 1)
|
||||
|
||||
# TODO(jakevdp) remove this decorator when reuse checks move to C++
|
||||
@jax.enable_key_reuse_checks(False)
|
||||
@jax.debug_key_reuse(False)
|
||||
def test_cpp_dispatch_split(self):
|
||||
# Ensure we stay on the C++ dispatch path when calling a jitted
|
||||
# function with a key arrays as inputs and as outputs.
|
||||
@ -1277,7 +1277,7 @@ class JnpWithKeyArrayTest(jtu.JaxTestCase):
|
||||
|
||||
self.check_shape(key_func, keys(), key())
|
||||
self.check_shape(arr_func, keys(), key())
|
||||
with jax.enable_key_reuse_checks(False):
|
||||
with jax.debug_key_reuse(False):
|
||||
self.check_against_reference(key_func, arr_func, keys(), key())
|
||||
|
||||
def test_ravel(self):
|
||||
@ -1322,7 +1322,7 @@ class JnpWithKeyArrayTest(jtu.JaxTestCase):
|
||||
key_func = arr_func = lambda x: x[idx]
|
||||
|
||||
self.check_shape(key_func, keys())
|
||||
with jax.enable_key_reuse_checks(False):
|
||||
with jax.debug_key_reuse(False):
|
||||
self.check_against_reference(key_func, arr_func, keys())
|
||||
|
||||
@parameterized.parameters([
|
||||
@ -1336,10 +1336,10 @@ class JnpWithKeyArrayTest(jtu.JaxTestCase):
|
||||
key_func = arr_func = lambda key: key.at[idx].get()
|
||||
|
||||
self.check_shape(key_func, keys())
|
||||
with jax.enable_key_reuse_checks(False):
|
||||
with jax.debug_key_reuse(False):
|
||||
self.check_against_reference(key_func, arr_func, keys())
|
||||
|
||||
@jax.enable_key_reuse_checks(False)
|
||||
@jax.debug_key_reuse(False)
|
||||
def test_equality(self):
|
||||
key = random.key(123)
|
||||
key2 = random.key(456)
|
||||
|
@ -1476,14 +1476,14 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
ndim = shape[0] if len(shape) > 1 else 1
|
||||
|
||||
func = partial(resample, shape=())
|
||||
with jax.enable_key_reuse_checks(False):
|
||||
with jax.debug_key_reuse(False):
|
||||
self._CompileAndCheck(
|
||||
func, args_maker, rtol={np.float32: 3e-07, np.float64: 4e-15})
|
||||
result = func(*args_maker())
|
||||
assert result.shape == (ndim,)
|
||||
|
||||
func = partial(resample, shape=(4,))
|
||||
with jax.enable_key_reuse_checks(False):
|
||||
with jax.debug_key_reuse(False):
|
||||
self._CompileAndCheck(
|
||||
func, args_maker, rtol={np.float32: 3e-07, np.float64: 4e-15})
|
||||
result = func(*args_maker())
|
||||
|
@ -1757,7 +1757,7 @@ class ShapePolyTest(jtu.JaxTestCase):
|
||||
polymorphic_shapes=["(b,)"])
|
||||
self.assertAllClose(f(x), res_tf)
|
||||
|
||||
@jax.enable_key_reuse_checks(False)
|
||||
@jax.debug_key_reuse(False)
|
||||
def test_prng(self):
|
||||
# The PRNG implementation uses opaque types, test shape polymorphism
|
||||
with config.enable_custom_prng(True):
|
||||
@ -3294,7 +3294,7 @@ class ShapePolyHarnessesTest(jtu.JaxTestCase):
|
||||
# Update this here rather than in harness object because vmap_random_gamma is derived
|
||||
# from test_harnesses.all_harnesses, which strips override_jax_config_flags.
|
||||
if "random_gamma" in harness.group_name:
|
||||
config_flags = {**config_flags, "jax_enable_key_reuse_checks": False}
|
||||
config_flags = {**config_flags, "jax_debug_key_reuse": False}
|
||||
|
||||
prev_jax_config_flags = {fname: getattr(jax.config, fname) for fname in config_flags}
|
||||
try:
|
||||
|
@ -112,7 +112,7 @@ class X64ContextTests(jtu.JaxTestCase):
|
||||
self.assertEqual(x32.result(), jnp.int32)
|
||||
|
||||
@jax.legacy_prng_key('allow')
|
||||
@jax.enable_key_reuse_checks(False)
|
||||
@jax.debug_key_reuse(False)
|
||||
@jtu.ignore_warning(category=UserWarning,
|
||||
message="Explicitly requested dtype float64 is not available")
|
||||
def test_jit_cache(self):
|
||||
|
Loading…
x
Reference in New Issue
Block a user