mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
[key reuse] define KeyReuseError in jax.errors
This commit is contained in:
parent
0f2a89d837
commit
7634708743
@ -7,6 +7,7 @@ along with representative examples of how one might fix them.
|
||||
|
||||
.. currentmodule:: jax.errors
|
||||
.. autoclass:: ConcretizationTypeError
|
||||
.. autoclass:: KeyReuseError
|
||||
.. autoclass:: NonConcreteBooleanIndexError
|
||||
.. autoclass:: TracerArrayConversionError
|
||||
.. autoclass:: TracerBoolConversionError
|
||||
|
@ -2,11 +2,3 @@
|
||||
=====================================
|
||||
|
||||
.. automodule:: jax.experimental.key_reuse
|
||||
|
||||
API
|
||||
---
|
||||
|
||||
.. autosummary::
|
||||
:toctree: _autosummary
|
||||
|
||||
KeyReuseError
|
||||
|
@ -655,3 +655,29 @@ class UnexpectedTracerError(JAXTypeError):
|
||||
|
||||
def __init__(self, msg: str):
|
||||
super().__init__(msg)
|
||||
|
||||
|
||||
@export
|
||||
class KeyReuseError(JAXTypeError):
|
||||
"""
|
||||
This error occurs when a PRNG key is reused in an unsafe manner.
|
||||
Key reuse is checked only when `jax_enable_key_reuse_checks` is
|
||||
set to `True`.
|
||||
|
||||
Here is a simple example of code that would lead to such an error::
|
||||
|
||||
>>> with jax.enable_key_reuse_checks(True): # doctest: +SKIP
|
||||
... key = jax.random.key(0)
|
||||
... value = jax.random.uniform(key)
|
||||
... new_value = jax.random.uniform(key)
|
||||
...
|
||||
---------------------------------------------------------------------------
|
||||
KeyReuseError Traceback (most recent call last)
|
||||
...
|
||||
KeyReuseError: Previously-consumed key passed to jit-compiled function at index 0
|
||||
|
||||
This sort of key reuse is problematic because the JAX PRNG is stateless, and keys
|
||||
must be manually split; For more information on this see `Sharp Bits: Random Numbers
|
||||
<https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#random-numbers>`_.
|
||||
"""
|
||||
pass
|
||||
|
@ -24,5 +24,6 @@ from jax._src.errors import (
|
||||
TracerBoolConversionError as TracerBoolConversionError,
|
||||
TracerIntegerConversionError as TracerIntegerConversionError,
|
||||
UnexpectedTracerError as UnexpectedTracerError,
|
||||
KeyReuseError as KeyReuseError,
|
||||
)
|
||||
from jax._src.traceback_util import SimplifiedTraceback as SimplifiedTraceback
|
||||
|
@ -39,6 +39,3 @@ context manager::
|
||||
... print(jax.random.normal(key))
|
||||
-0.20584226
|
||||
"""
|
||||
from jax.experimental.key_reuse._core import (
|
||||
KeyReuseError as KeyReuseError,
|
||||
)
|
||||
|
@ -21,6 +21,7 @@ from typing import Any, Callable, NamedTuple
|
||||
import jax
|
||||
from jax import lax
|
||||
from jax import tree_util
|
||||
from jax.errors import KeyReuseError
|
||||
from jax.interpreters import batching, mlir
|
||||
from jax._src import api_util
|
||||
from jax._src import config
|
||||
@ -99,9 +100,6 @@ class KeyReuseSignature(NamedTuple):
|
||||
arg_out._consumed = arg_in._consumed
|
||||
|
||||
|
||||
class KeyReuseError(RuntimeError):
|
||||
pass
|
||||
|
||||
consume_p = core.Primitive("consume")
|
||||
consume_p.def_impl(lambda x: x)
|
||||
consume_p.def_abstract_eval(lambda x: x)
|
||||
|
@ -751,15 +751,11 @@ class DynamicShapesTest(jtu.JaxTestCase):
|
||||
|
||||
def test_check_jaxpr_key_reuse(self):
|
||||
with config.enable_key_reuse_checks(True):
|
||||
try:
|
||||
from jax.experimental.key_reuse import KeyReuseError
|
||||
except ImportError:
|
||||
self.skipTest("Test requires jax.experimental.key_reuse")
|
||||
def f(seed):
|
||||
key = jax.random.key(seed)
|
||||
return jax.random.uniform(key) + jax.random.normal(key)
|
||||
with jax.enable_checks(True):
|
||||
with self.assertRaises(KeyReuseError):
|
||||
with self.assertRaises(jax.errors.KeyReuseError):
|
||||
jax.jit(f)(0)
|
||||
|
||||
|
||||
|
@ -23,9 +23,10 @@ import jax.numpy as jnp
|
||||
from jax._src import prng
|
||||
from jax._src import random
|
||||
from jax._src import test_util as jtu
|
||||
from jax.errors import KeyReuseError
|
||||
from jax.experimental.key_reuse._core import (
|
||||
assert_consumed, assert_unconsumed, consume, consume_p)
|
||||
from jax.experimental.key_reuse import _core, KeyReuseError
|
||||
from jax.experimental.key_reuse import _core
|
||||
|
||||
from jax import config
|
||||
config.parse_flags_with_absl()
|
||||
|
Loading…
x
Reference in New Issue
Block a user