[key reuse] improve some key reuse errors.

This commit is contained in:
Jake VanderPlas 2024-03-04 14:50:39 -08:00
parent 28fa88681e
commit bb91bf2e09
4 changed files with 107 additions and 128 deletions

View File

@ -43,6 +43,6 @@ from jax._src.prng import (
reuse_key as reuse_key,
)
from jax.experimental.key_reuse._common import (
from jax.experimental.key_reuse._core import (
KeyReuseError as KeyReuseError,
)

View File

@ -1,113 +0,0 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import NamedTuple
from jax import core
from jax.interpreters import batching, mlir
from jax._src import prng
import numpy as np
class Sink(NamedTuple):
idx: int
mask: bool | np.ndarray = True
def __repr__(self):
if isinstance(self.mask, bool) and self.mask:
return f"Sink({self.idx})"
else:
return f"Sink({self.idx}, mask={self.mask})"
class Source(NamedTuple):
idx: int
mask: bool | np.ndarray = True
def __repr__(self):
if isinstance(self.mask, bool) and self.mask:
return f"Source({self.idx})"
else:
return f"Source({self.idx}, mask={self.mask})"
class Forward(NamedTuple):
in_idx: int
out_idx: int
class KeyReuseSignature(NamedTuple):
sinks: list[Sink]
sources: list[Source]
forwards: list[Forward] = []
def check_signature(self, *args, jaxpr=None):
for sink in self.sinks:
if not isinstance(args[sink.idx], prng.PRNGKeyArray):
continue
if np.any(args[sink.idx]._consumed & sink.mask):
msg = f"Previously-consumed key at index {sink.idx} passed to function"
if jaxpr:
msg += f"\n{jaxpr=}"
raise KeyReuseError(msg)
def update_consumption(self, args_in, args_out):
for sink in self.sinks:
arg = args_in[sink.idx]
if isinstance(arg, prng.PRNGKeyArray):
arg._consumed = arg._consumed | sink.mask
for arg in args_out:
if isinstance(arg, prng.PRNGKeyArray):
arg._consumed = True
for source in self.sources:
if isinstance(args_out[source.idx], prng.PRNGKeyArray):
args_out[source.idx]._consumed = ~np.asarray(source.mask)
for forward in self.forwards:
arg_in = args_in[forward.in_idx]
arg_out = args_out[forward.out_idx]
if isinstance(arg_in, prng.PRNGKeyArray) and isinstance(arg_out, prng.PRNGKeyArray):
arg_out._consumed = arg_in._consumed
class KeyReuseError(RuntimeError):
pass
consume_p = core.Primitive("consume")
consume_p.def_impl(lambda x: x)
consume_p.def_abstract_eval(lambda x: x)
batching.defvectorized(consume_p)
mlir.register_lowering(
consume_p,
mlir.lower_fun(lambda x: x, multiple_results=False))
def consume(key):
"""Consume the key and return a consumed copy."""
return consume_p.bind(key)
assert_consumed_value_p = core.Primitive("assert_consumed_value")
assert_consumed_value_p.def_impl(lambda x, *, value: x)
assert_consumed_value_p.def_abstract_eval(lambda x, *, value: x)
batching.defvectorized(assert_consumed_value_p)
mlir.register_lowering(
assert_consumed_value_p,
mlir.lower_fun(lambda x, *, value: x, multiple_results=False))
def assert_unconsumed(key):
"""Assert that a key is unconsumed"""
assert_consumed_value_p.bind(key, value=False)
def assert_consumed(key, value=True):
"""Assert that a key is consumed"""
assert_consumed_value_p.bind(key, value=value)

View File

@ -16,17 +16,17 @@ from __future__ import annotations
from collections import defaultdict
from functools import partial, reduce, wraps
from typing import Any, Callable
from typing import Any, Callable, NamedTuple
import jax
from jax import lax
from jax import tree_util
from jax.interpreters import batching, mlir
from jax._src import api_util
from jax._src import config
from jax._src import core
from jax._src import linear_util as lu
from jax._src import pjit
from jax._src import pretty_printer as pp
from jax._src import prng
from jax._src import random
from jax._src import util
@ -35,14 +35,102 @@ from jax._src.debugging import debug_callback_p
from jax._src.interpreters import partial_eval as pe
from jax._src.util import weakref_lru_cache
from jax.experimental.key_reuse._common import (
consume_p, assert_consumed_value_p, KeyReuseError,
Sink, Source, Forward, KeyReuseSignature
)
from jax.experimental.shard_map import shard_map_p
import numpy as np
class Sink(NamedTuple):
idx: int
mask: bool | np.ndarray = True
def __repr__(self):
if isinstance(self.mask, bool) and self.mask:
return f"Sink({self.idx})"
else:
return f"Sink({self.idx}, mask={self.mask})"
class Source(NamedTuple):
idx: int
mask: bool | np.ndarray = True
def __repr__(self):
if isinstance(self.mask, bool) and self.mask:
return f"Source({self.idx})"
else:
return f"Source({self.idx}, mask={self.mask})"
class Forward(NamedTuple):
in_idx: int
out_idx: int
class KeyReuseSignature(NamedTuple):
sinks: list[Sink]
sources: list[Source]
forwards: list[Forward] = []
def check_signature(self, *args, funcname="function", context=None):
for sink in self.sinks:
if not isinstance(args[sink.idx], prng.PRNGKeyArray):
continue
if np.any(args[sink.idx]._consumed & sink.mask):
msg = f"Previously-consumed key passed to {funcname} at index {sink.idx}"
if context:
msg += " {context}"
raise KeyReuseError(msg)
def update_consumption(self, args_in, args_out):
for sink in self.sinks:
arg = args_in[sink.idx]
if isinstance(arg, prng.PRNGKeyArray):
arg._consumed = arg._consumed | sink.mask
for arg in args_out:
if isinstance(arg, prng.PRNGKeyArray):
arg._consumed = True
for source in self.sources:
if isinstance(args_out[source.idx], prng.PRNGKeyArray):
args_out[source.idx]._consumed = ~np.asarray(source.mask)
for forward in self.forwards:
arg_in = args_in[forward.in_idx]
arg_out = args_out[forward.out_idx]
if isinstance(arg_in, prng.PRNGKeyArray) and isinstance(arg_out, prng.PRNGKeyArray):
arg_out._consumed = arg_in._consumed
class KeyReuseError(RuntimeError):
pass
consume_p = core.Primitive("consume")
consume_p.def_impl(lambda x: x)
consume_p.def_abstract_eval(lambda x: x)
batching.defvectorized(consume_p)
mlir.register_lowering(
consume_p,
mlir.lower_fun(lambda x: x, multiple_results=False))
def consume(key):
"""Consume the key and return a consumed copy."""
return consume_p.bind(key)
assert_consumed_value_p = core.Primitive("assert_consumed_value")
assert_consumed_value_p.def_impl(lambda x, *, value: x)
assert_consumed_value_p.def_abstract_eval(lambda x, *, value: x)
batching.defvectorized(assert_consumed_value_p)
mlir.register_lowering(
assert_consumed_value_p,
mlir.lower_fun(lambda x, *, value: x, multiple_results=False))
def assert_unconsumed(key):
"""Assert that a key is unconsumed"""
assert_consumed_value_p.bind(key, value=False)
def assert_consumed(key, value=True):
"""Assert that a key is consumed"""
assert_consumed_value_p.bind(key, value=value)
def _check_consumed_value(eqn, consumed):
"""Extra check for use with assert_consumed_value_p"""
expected = eqn.params['value']
@ -341,17 +429,20 @@ def key_reuse_impl_rule(prim, original_rule):
def key_reuse_impl(*args, **kwargs):
if config.enable_key_reuse_checks.value:
if prim == pjit.pjit_p:
funcname = "jit-compiled function"
jaxpr = kwargs['jaxpr'].jaxpr
signature = get_jaxpr_type_signature(jaxpr)
elif prim in key_reuse_signatures:
jaxpr = prim
funcname = str(prim)
jaxpr = None
signature = key_reuse_signatures[prim]
elif prim in key_reuse_signatures_dynamic:
funcname = str(prim)
jaxpr = jax.make_jaxpr(partial(prim.bind, **kwargs))(*args).jaxpr
signature = get_jaxpr_type_signature(jaxpr)
else:
raise RuntimeError(f"Internal: no key reuse rule for primitive {prim}")
signature.check_signature(*args, jaxpr=jaxpr)
signature.check_signature(*args, funcname=funcname)
result = original_rule(*args, **kwargs)
signature.update_consumption(args, result if prim.multiple_results else [result])
return result

View File

@ -22,7 +22,7 @@ from jax import core
import jax.numpy as jnp
from jax._src import prng
from jax._src import test_util as jtu
from jax.experimental.key_reuse._common import (
from jax.experimental.key_reuse._core import (
assert_consumed, assert_unconsumed, consume, consume_p)
from jax.experimental.key_reuse import _core, KeyReuseError
@ -587,28 +587,29 @@ class KeyReuseIntegrationTest(jtu.JaxTestCase):
class KeyReuseEager(jtu.JaxTestCase):
jit_msg = "Previously-consumed key at index 0 passed to function"
bits_msg = "In random_bits, key values a are already consumed."
jit_msg = "Previously-consumed key passed to jit-compiled function at index 0"
eager_bits_msg = "Previously-consumed key passed to random_bits at index 0"
traced_bits_msg = "In random_bits, key values a are already consumed."
def test_simple_reuse_nojit(self):
key = jax.random.key(0)
_ = jax.random.bits(key)
with jax.disable_jit():
with self.assertRaisesRegex(KeyReuseError, self.jit_msg):
with self.assertRaisesRegex(KeyReuseError, self.eager_bits_msg):
_ = jax.random.bits(key)
def test_simple_key_reuse_jit(self):
key = jax.random.key(0)
_ = jax.random.bits(key)
with self.assertRaisesRegex(KeyReuseError, self.jit_msg):
_ = jax.random.bits(key)
_ = jax.jit(jax.random.bits)(key)
def test_key_reuse_within_jit(self):
@jax.jit
def f():
key = jax.random.key(0)
return jax.random.bits(key) + jax.random.bits(key)
with self.assertRaisesRegex(KeyReuseError, self.bits_msg):
with self.assertRaisesRegex(KeyReuseError, self.traced_bits_msg):
f()