[key reuse] define KeyReuseError in jax.errors

This commit is contained in:
Jake VanderPlas 2024-03-08 10:40:33 -08:00
parent 0f2a89d837
commit 7634708743
8 changed files with 32 additions and 20 deletions

View File

@ -7,6 +7,7 @@ along with representative examples of how one might fix them.
.. currentmodule:: jax.errors
.. autoclass:: ConcretizationTypeError
.. autoclass:: KeyReuseError
.. autoclass:: NonConcreteBooleanIndexError
.. autoclass:: TracerArrayConversionError
.. autoclass:: TracerBoolConversionError

View File

@ -2,11 +2,3 @@
=====================================
.. automodule:: jax.experimental.key_reuse
API
---
.. autosummary::
:toctree: _autosummary
KeyReuseError

View File

@ -655,3 +655,29 @@ class UnexpectedTracerError(JAXTypeError):
def __init__(self, msg: str):
super().__init__(msg)
@export
class KeyReuseError(JAXTypeError):
"""
This error occurs when a PRNG key is reused in an unsafe manner.
Key reuse is checked only when `jax_enable_key_reuse_checks` is
set to `True`.
Here is a simple example of code that would lead to such an error::
>>> with jax.enable_key_reuse_checks(True): # doctest: +SKIP
... key = jax.random.key(0)
... value = jax.random.uniform(key)
... new_value = jax.random.uniform(key)
...
---------------------------------------------------------------------------
KeyReuseError Traceback (most recent call last)
...
KeyReuseError: Previously-consumed key passed to jit-compiled function at index 0
This sort of key reuse is problematic because the JAX PRNG is stateless, and keys
must be manually split; For more information on this see `Sharp Bits: Random Numbers
<https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#random-numbers>`_.
"""
pass

View File

@ -24,5 +24,6 @@ from jax._src.errors import (
TracerBoolConversionError as TracerBoolConversionError,
TracerIntegerConversionError as TracerIntegerConversionError,
UnexpectedTracerError as UnexpectedTracerError,
KeyReuseError as KeyReuseError,
)
from jax._src.traceback_util import SimplifiedTraceback as SimplifiedTraceback

View File

@ -39,6 +39,3 @@ context manager::
... print(jax.random.normal(key))
-0.20584226
"""
from jax.experimental.key_reuse._core import (
KeyReuseError as KeyReuseError,
)

View File

@ -21,6 +21,7 @@ from typing import Any, Callable, NamedTuple
import jax
from jax import lax
from jax import tree_util
from jax.errors import KeyReuseError
from jax.interpreters import batching, mlir
from jax._src import api_util
from jax._src import config
@ -99,9 +100,6 @@ class KeyReuseSignature(NamedTuple):
arg_out._consumed = arg_in._consumed
class KeyReuseError(RuntimeError):
pass
consume_p = core.Primitive("consume")
consume_p.def_impl(lambda x: x)
consume_p.def_abstract_eval(lambda x: x)

View File

@ -751,15 +751,11 @@ class DynamicShapesTest(jtu.JaxTestCase):
def test_check_jaxpr_key_reuse(self):
with config.enable_key_reuse_checks(True):
try:
from jax.experimental.key_reuse import KeyReuseError
except ImportError:
self.skipTest("Test requires jax.experimental.key_reuse")
def f(seed):
key = jax.random.key(seed)
return jax.random.uniform(key) + jax.random.normal(key)
with jax.enable_checks(True):
with self.assertRaises(KeyReuseError):
with self.assertRaises(jax.errors.KeyReuseError):
jax.jit(f)(0)

View File

@ -23,9 +23,10 @@ import jax.numpy as jnp
from jax._src import prng
from jax._src import random
from jax._src import test_util as jtu
from jax.errors import KeyReuseError
from jax.experimental.key_reuse._core import (
assert_consumed, assert_unconsumed, consume, consume_p)
from jax.experimental.key_reuse import _core, KeyReuseError
from jax.experimental.key_reuse import _core
from jax import config
config.parse_flags_with_absl()