mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
[key reuse] improve some key reuse errors.
This commit is contained in:
parent
28fa88681e
commit
bb91bf2e09
@ -43,6 +43,6 @@ from jax._src.prng import (
|
||||
reuse_key as reuse_key,
|
||||
)
|
||||
|
||||
from jax.experimental.key_reuse._common import (
|
||||
from jax.experimental.key_reuse._core import (
|
||||
KeyReuseError as KeyReuseError,
|
||||
)
|
||||
|
@ -1,113 +0,0 @@
|
||||
# Copyright 2023 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import NamedTuple
|
||||
from jax import core
|
||||
from jax.interpreters import batching, mlir
|
||||
from jax._src import prng
|
||||
import numpy as np
|
||||
|
||||
|
||||
class Sink(NamedTuple):
|
||||
idx: int
|
||||
mask: bool | np.ndarray = True
|
||||
|
||||
def __repr__(self):
|
||||
if isinstance(self.mask, bool) and self.mask:
|
||||
return f"Sink({self.idx})"
|
||||
else:
|
||||
return f"Sink({self.idx}, mask={self.mask})"
|
||||
|
||||
|
||||
class Source(NamedTuple):
|
||||
idx: int
|
||||
mask: bool | np.ndarray = True
|
||||
|
||||
def __repr__(self):
|
||||
if isinstance(self.mask, bool) and self.mask:
|
||||
return f"Source({self.idx})"
|
||||
else:
|
||||
return f"Source({self.idx}, mask={self.mask})"
|
||||
|
||||
class Forward(NamedTuple):
|
||||
in_idx: int
|
||||
out_idx: int
|
||||
|
||||
|
||||
class KeyReuseSignature(NamedTuple):
|
||||
sinks: list[Sink]
|
||||
sources: list[Source]
|
||||
forwards: list[Forward] = []
|
||||
|
||||
def check_signature(self, *args, jaxpr=None):
|
||||
for sink in self.sinks:
|
||||
if not isinstance(args[sink.idx], prng.PRNGKeyArray):
|
||||
continue
|
||||
if np.any(args[sink.idx]._consumed & sink.mask):
|
||||
msg = f"Previously-consumed key at index {sink.idx} passed to function"
|
||||
if jaxpr:
|
||||
msg += f"\n{jaxpr=}"
|
||||
raise KeyReuseError(msg)
|
||||
|
||||
def update_consumption(self, args_in, args_out):
|
||||
for sink in self.sinks:
|
||||
arg = args_in[sink.idx]
|
||||
if isinstance(arg, prng.PRNGKeyArray):
|
||||
arg._consumed = arg._consumed | sink.mask
|
||||
for arg in args_out:
|
||||
if isinstance(arg, prng.PRNGKeyArray):
|
||||
arg._consumed = True
|
||||
for source in self.sources:
|
||||
if isinstance(args_out[source.idx], prng.PRNGKeyArray):
|
||||
args_out[source.idx]._consumed = ~np.asarray(source.mask)
|
||||
for forward in self.forwards:
|
||||
arg_in = args_in[forward.in_idx]
|
||||
arg_out = args_out[forward.out_idx]
|
||||
if isinstance(arg_in, prng.PRNGKeyArray) and isinstance(arg_out, prng.PRNGKeyArray):
|
||||
arg_out._consumed = arg_in._consumed
|
||||
|
||||
|
||||
class KeyReuseError(RuntimeError):
|
||||
pass
|
||||
|
||||
consume_p = core.Primitive("consume")
|
||||
consume_p.def_impl(lambda x: x)
|
||||
consume_p.def_abstract_eval(lambda x: x)
|
||||
batching.defvectorized(consume_p)
|
||||
mlir.register_lowering(
|
||||
consume_p,
|
||||
mlir.lower_fun(lambda x: x, multiple_results=False))
|
||||
|
||||
def consume(key):
|
||||
"""Consume the key and return a consumed copy."""
|
||||
return consume_p.bind(key)
|
||||
|
||||
|
||||
assert_consumed_value_p = core.Primitive("assert_consumed_value")
|
||||
assert_consumed_value_p.def_impl(lambda x, *, value: x)
|
||||
assert_consumed_value_p.def_abstract_eval(lambda x, *, value: x)
|
||||
batching.defvectorized(assert_consumed_value_p)
|
||||
mlir.register_lowering(
|
||||
assert_consumed_value_p,
|
||||
mlir.lower_fun(lambda x, *, value: x, multiple_results=False))
|
||||
|
||||
def assert_unconsumed(key):
|
||||
"""Assert that a key is unconsumed"""
|
||||
assert_consumed_value_p.bind(key, value=False)
|
||||
|
||||
def assert_consumed(key, value=True):
|
||||
"""Assert that a key is consumed"""
|
||||
assert_consumed_value_p.bind(key, value=value)
|
@ -16,17 +16,17 @@ from __future__ import annotations
|
||||
|
||||
from collections import defaultdict
|
||||
from functools import partial, reduce, wraps
|
||||
from typing import Any, Callable
|
||||
from typing import Any, Callable, NamedTuple
|
||||
|
||||
import jax
|
||||
from jax import lax
|
||||
from jax import tree_util
|
||||
from jax.interpreters import batching, mlir
|
||||
from jax._src import api_util
|
||||
from jax._src import config
|
||||
from jax._src import core
|
||||
from jax._src import linear_util as lu
|
||||
from jax._src import pjit
|
||||
from jax._src import pretty_printer as pp
|
||||
from jax._src import prng
|
||||
from jax._src import random
|
||||
from jax._src import util
|
||||
@ -35,14 +35,102 @@ from jax._src.debugging import debug_callback_p
|
||||
from jax._src.interpreters import partial_eval as pe
|
||||
from jax._src.util import weakref_lru_cache
|
||||
|
||||
from jax.experimental.key_reuse._common import (
|
||||
consume_p, assert_consumed_value_p, KeyReuseError,
|
||||
Sink, Source, Forward, KeyReuseSignature
|
||||
)
|
||||
from jax.experimental.shard_map import shard_map_p
|
||||
import numpy as np
|
||||
|
||||
|
||||
class Sink(NamedTuple):
|
||||
idx: int
|
||||
mask: bool | np.ndarray = True
|
||||
|
||||
def __repr__(self):
|
||||
if isinstance(self.mask, bool) and self.mask:
|
||||
return f"Sink({self.idx})"
|
||||
else:
|
||||
return f"Sink({self.idx}, mask={self.mask})"
|
||||
|
||||
|
||||
class Source(NamedTuple):
|
||||
idx: int
|
||||
mask: bool | np.ndarray = True
|
||||
|
||||
def __repr__(self):
|
||||
if isinstance(self.mask, bool) and self.mask:
|
||||
return f"Source({self.idx})"
|
||||
else:
|
||||
return f"Source({self.idx}, mask={self.mask})"
|
||||
|
||||
class Forward(NamedTuple):
|
||||
in_idx: int
|
||||
out_idx: int
|
||||
|
||||
|
||||
class KeyReuseSignature(NamedTuple):
|
||||
sinks: list[Sink]
|
||||
sources: list[Source]
|
||||
forwards: list[Forward] = []
|
||||
|
||||
def check_signature(self, *args, funcname="function", context=None):
|
||||
for sink in self.sinks:
|
||||
if not isinstance(args[sink.idx], prng.PRNGKeyArray):
|
||||
continue
|
||||
if np.any(args[sink.idx]._consumed & sink.mask):
|
||||
msg = f"Previously-consumed key passed to {funcname} at index {sink.idx}"
|
||||
if context:
|
||||
msg += " {context}"
|
||||
raise KeyReuseError(msg)
|
||||
|
||||
def update_consumption(self, args_in, args_out):
|
||||
for sink in self.sinks:
|
||||
arg = args_in[sink.idx]
|
||||
if isinstance(arg, prng.PRNGKeyArray):
|
||||
arg._consumed = arg._consumed | sink.mask
|
||||
for arg in args_out:
|
||||
if isinstance(arg, prng.PRNGKeyArray):
|
||||
arg._consumed = True
|
||||
for source in self.sources:
|
||||
if isinstance(args_out[source.idx], prng.PRNGKeyArray):
|
||||
args_out[source.idx]._consumed = ~np.asarray(source.mask)
|
||||
for forward in self.forwards:
|
||||
arg_in = args_in[forward.in_idx]
|
||||
arg_out = args_out[forward.out_idx]
|
||||
if isinstance(arg_in, prng.PRNGKeyArray) and isinstance(arg_out, prng.PRNGKeyArray):
|
||||
arg_out._consumed = arg_in._consumed
|
||||
|
||||
|
||||
class KeyReuseError(RuntimeError):
|
||||
pass
|
||||
|
||||
consume_p = core.Primitive("consume")
|
||||
consume_p.def_impl(lambda x: x)
|
||||
consume_p.def_abstract_eval(lambda x: x)
|
||||
batching.defvectorized(consume_p)
|
||||
mlir.register_lowering(
|
||||
consume_p,
|
||||
mlir.lower_fun(lambda x: x, multiple_results=False))
|
||||
|
||||
def consume(key):
|
||||
"""Consume the key and return a consumed copy."""
|
||||
return consume_p.bind(key)
|
||||
|
||||
|
||||
assert_consumed_value_p = core.Primitive("assert_consumed_value")
|
||||
assert_consumed_value_p.def_impl(lambda x, *, value: x)
|
||||
assert_consumed_value_p.def_abstract_eval(lambda x, *, value: x)
|
||||
batching.defvectorized(assert_consumed_value_p)
|
||||
mlir.register_lowering(
|
||||
assert_consumed_value_p,
|
||||
mlir.lower_fun(lambda x, *, value: x, multiple_results=False))
|
||||
|
||||
def assert_unconsumed(key):
|
||||
"""Assert that a key is unconsumed"""
|
||||
assert_consumed_value_p.bind(key, value=False)
|
||||
|
||||
def assert_consumed(key, value=True):
|
||||
"""Assert that a key is consumed"""
|
||||
assert_consumed_value_p.bind(key, value=value)
|
||||
|
||||
|
||||
def _check_consumed_value(eqn, consumed):
|
||||
"""Extra check for use with assert_consumed_value_p"""
|
||||
expected = eqn.params['value']
|
||||
@ -341,17 +429,20 @@ def key_reuse_impl_rule(prim, original_rule):
|
||||
def key_reuse_impl(*args, **kwargs):
|
||||
if config.enable_key_reuse_checks.value:
|
||||
if prim == pjit.pjit_p:
|
||||
funcname = "jit-compiled function"
|
||||
jaxpr = kwargs['jaxpr'].jaxpr
|
||||
signature = get_jaxpr_type_signature(jaxpr)
|
||||
elif prim in key_reuse_signatures:
|
||||
jaxpr = prim
|
||||
funcname = str(prim)
|
||||
jaxpr = None
|
||||
signature = key_reuse_signatures[prim]
|
||||
elif prim in key_reuse_signatures_dynamic:
|
||||
funcname = str(prim)
|
||||
jaxpr = jax.make_jaxpr(partial(prim.bind, **kwargs))(*args).jaxpr
|
||||
signature = get_jaxpr_type_signature(jaxpr)
|
||||
else:
|
||||
raise RuntimeError(f"Internal: no key reuse rule for primitive {prim}")
|
||||
signature.check_signature(*args, jaxpr=jaxpr)
|
||||
signature.check_signature(*args, funcname=funcname)
|
||||
result = original_rule(*args, **kwargs)
|
||||
signature.update_consumption(args, result if prim.multiple_results else [result])
|
||||
return result
|
||||
|
@ -22,7 +22,7 @@ from jax import core
|
||||
import jax.numpy as jnp
|
||||
from jax._src import prng
|
||||
from jax._src import test_util as jtu
|
||||
from jax.experimental.key_reuse._common import (
|
||||
from jax.experimental.key_reuse._core import (
|
||||
assert_consumed, assert_unconsumed, consume, consume_p)
|
||||
from jax.experimental.key_reuse import _core, KeyReuseError
|
||||
|
||||
@ -587,28 +587,29 @@ class KeyReuseIntegrationTest(jtu.JaxTestCase):
|
||||
|
||||
|
||||
class KeyReuseEager(jtu.JaxTestCase):
|
||||
jit_msg = "Previously-consumed key at index 0 passed to function"
|
||||
bits_msg = "In random_bits, key values a are already consumed."
|
||||
jit_msg = "Previously-consumed key passed to jit-compiled function at index 0"
|
||||
eager_bits_msg = "Previously-consumed key passed to random_bits at index 0"
|
||||
traced_bits_msg = "In random_bits, key values a are already consumed."
|
||||
|
||||
def test_simple_reuse_nojit(self):
|
||||
key = jax.random.key(0)
|
||||
_ = jax.random.bits(key)
|
||||
with jax.disable_jit():
|
||||
with self.assertRaisesRegex(KeyReuseError, self.jit_msg):
|
||||
with self.assertRaisesRegex(KeyReuseError, self.eager_bits_msg):
|
||||
_ = jax.random.bits(key)
|
||||
|
||||
def test_simple_key_reuse_jit(self):
|
||||
key = jax.random.key(0)
|
||||
_ = jax.random.bits(key)
|
||||
with self.assertRaisesRegex(KeyReuseError, self.jit_msg):
|
||||
_ = jax.random.bits(key)
|
||||
_ = jax.jit(jax.random.bits)(key)
|
||||
|
||||
def test_key_reuse_within_jit(self):
|
||||
@jax.jit
|
||||
def f():
|
||||
key = jax.random.key(0)
|
||||
return jax.random.bits(key) + jax.random.bits(key)
|
||||
with self.assertRaisesRegex(KeyReuseError, self.bits_msg):
|
||||
with self.assertRaisesRegex(KeyReuseError, self.traced_bits_msg):
|
||||
f()
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user