diff --git a/docs/errors.rst b/docs/errors.rst index 4c76f5dcf..23dbaf29c 100644 --- a/docs/errors.rst +++ b/docs/errors.rst @@ -7,6 +7,7 @@ along with representative examples of how one might fix them. .. currentmodule:: jax.errors .. autoclass:: ConcretizationTypeError +.. autoclass:: KeyReuseError .. autoclass:: NonConcreteBooleanIndexError .. autoclass:: TracerArrayConversionError .. autoclass:: TracerBoolConversionError diff --git a/docs/jax.experimental.key_reuse.rst b/docs/jax.experimental.key_reuse.rst index c78f23866..7255afabf 100644 --- a/docs/jax.experimental.key_reuse.rst +++ b/docs/jax.experimental.key_reuse.rst @@ -2,11 +2,3 @@ ===================================== .. automodule:: jax.experimental.key_reuse - -API ---- - -.. autosummary:: - :toctree: _autosummary - - KeyReuseError diff --git a/jax/_src/errors.py b/jax/_src/errors.py index 5594b261a..dd5d83d51 100644 --- a/jax/_src/errors.py +++ b/jax/_src/errors.py @@ -655,3 +655,29 @@ class UnexpectedTracerError(JAXTypeError): def __init__(self, msg: str): super().__init__(msg) + + +@export +class KeyReuseError(JAXTypeError): + """ + This error occurs when a PRNG key is reused in an unsafe manner. + Key reuse is checked only when `jax_enable_key_reuse_checks` is + set to `True`. + + Here is a simple example of code that would lead to such an error:: + + >>> with jax.enable_key_reuse_checks(True): # doctest: +SKIP + ... key = jax.random.key(0) + ... value = jax.random.uniform(key) + ... new_value = jax.random.uniform(key) + ... + --------------------------------------------------------------------------- + KeyReuseError Traceback (most recent call last) + ... + KeyReuseError: Previously-consumed key passed to jit-compiled function at index 0 + + This sort of key reuse is problematic because the JAX PRNG is stateless, and keys + must be manually split; For more information on this see `Sharp Bits: Random Numbers + `_. + """ + pass diff --git a/jax/errors.py b/jax/errors.py index 4b8a0cf75..15a6654fa 100644 --- a/jax/errors.py +++ b/jax/errors.py @@ -24,5 +24,6 @@ from jax._src.errors import ( TracerBoolConversionError as TracerBoolConversionError, TracerIntegerConversionError as TracerIntegerConversionError, UnexpectedTracerError as UnexpectedTracerError, + KeyReuseError as KeyReuseError, ) from jax._src.traceback_util import SimplifiedTraceback as SimplifiedTraceback diff --git a/jax/experimental/key_reuse/__init__.py b/jax/experimental/key_reuse/__init__.py index 75b231ed2..f33020009 100644 --- a/jax/experimental/key_reuse/__init__.py +++ b/jax/experimental/key_reuse/__init__.py @@ -39,6 +39,3 @@ context manager:: ... print(jax.random.normal(key)) -0.20584226 """ -from jax.experimental.key_reuse._core import ( - KeyReuseError as KeyReuseError, -) diff --git a/jax/experimental/key_reuse/_core.py b/jax/experimental/key_reuse/_core.py index 0d832c3dd..275c0e5f2 100644 --- a/jax/experimental/key_reuse/_core.py +++ b/jax/experimental/key_reuse/_core.py @@ -21,6 +21,7 @@ from typing import Any, Callable, NamedTuple import jax from jax import lax from jax import tree_util +from jax.errors import KeyReuseError from jax.interpreters import batching, mlir from jax._src import api_util from jax._src import config @@ -99,9 +100,6 @@ class KeyReuseSignature(NamedTuple): arg_out._consumed = arg_in._consumed -class KeyReuseError(RuntimeError): - pass - consume_p = core.Primitive("consume") consume_p.def_impl(lambda x: x) consume_p.def_abstract_eval(lambda x: x) diff --git a/tests/core_test.py b/tests/core_test.py index 788f61db9..1831c5774 100644 --- a/tests/core_test.py +++ b/tests/core_test.py @@ -751,15 +751,11 @@ class DynamicShapesTest(jtu.JaxTestCase): def test_check_jaxpr_key_reuse(self): with config.enable_key_reuse_checks(True): - try: - from jax.experimental.key_reuse import KeyReuseError - except ImportError: - self.skipTest("Test requires jax.experimental.key_reuse") def f(seed): key = jax.random.key(seed) return jax.random.uniform(key) + jax.random.normal(key) with jax.enable_checks(True): - with self.assertRaises(KeyReuseError): + with self.assertRaises(jax.errors.KeyReuseError): jax.jit(f)(0) diff --git a/tests/key_reuse_test.py b/tests/key_reuse_test.py index d1c390663..725c99958 100644 --- a/tests/key_reuse_test.py +++ b/tests/key_reuse_test.py @@ -23,9 +23,10 @@ import jax.numpy as jnp from jax._src import prng from jax._src import random from jax._src import test_util as jtu +from jax.errors import KeyReuseError from jax.experimental.key_reuse._core import ( assert_consumed, assert_unconsumed, consume, consume_p) -from jax.experimental.key_reuse import _core, KeyReuseError +from jax.experimental.key_reuse import _core from jax import config config.parse_flags_with_absl()