[key reuse] simplify key reuse logic through context-free jaxpr evaluation

The args_consumed and forwarded_inputs context is not actually needed while evaluating a jaxpr,
because key reuse can be checked afterward. The only reason for passing this context was to produce
more granular errors, but arguably it's better to raise the error at the jaxpr input.
This commit is contained in:
Jake VanderPlas 2024-02-15 15:50:50 -08:00
parent 243e7edc56
commit 8eab599530
2 changed files with 39 additions and 65 deletions

View File

@ -16,7 +16,7 @@ from __future__ import annotations
from collections import defaultdict
from functools import reduce
from typing import Any, Callable, NamedTuple
from typing import Any, Callable
import jax
from jax import lax
@ -38,10 +38,24 @@ from jax.experimental.key_reuse._common import (
)
import numpy as np
def _check_consumed_value(eqn, consumed):
  """Extra check for use with assert_consumed_value_p.

  Compares the observed consumption state of a key against the expectation
  stored in ``eqn.params['value']`` and raises when they differ anywhere.

  Args:
    eqn: equation whose ``params['value']`` holds the expected consumed
      state (presumably a bool or boolean array -- TODO confirm).
    consumed: the consumed state actually observed for the key.

  Raises:
    AssertionError: if ``consumed`` does not equal the expected value
      element-wise.
  """
  expected = eqn.params['value']
  if not np.all(consumed == expected):
    # Choose the most specific message: the expectation was "fully consumed",
    # "not consumed at all", or a mixed mask.
    if np.all(expected):
      raise AssertionError(f"Expected key to be consumed in {eqn}")
    elif not np.any(expected):
      raise AssertionError(f"Expected key to not be consumed in {eqn}")
    else:
      raise AssertionError(f"Expected {expected}, got {consumed} in {eqn}")
# The behavior of most primitives can be described via simple signatures.
key_reuse_signatures: dict[core.Primitive, KeyReuseSignature] = {}
key_reuse_signatures[consume_p] = KeyReuseSignature([Sink(0)], [], [Forward(0, 0)])
key_reuse_signatures[assert_consumed_value_p] = KeyReuseSignature([], [], [Forward(0, 0)])
key_reuse_signatures[prng.reuse_key_p] = KeyReuseSignature([], [Source(0)])
key_reuse_signatures[prng.random_bits_p] = KeyReuseSignature([Sink(0)], [])
# TODO(jakevdp): should fold_in sink its input key?
@ -50,6 +64,7 @@ key_reuse_signatures[prng.random_fold_in_p] = KeyReuseSignature([], [Source(0)])
key_reuse_signatures[prng.random_seed_p] = KeyReuseSignature([], [Source(0)])
key_reuse_signatures[prng.random_split_p] = KeyReuseSignature([Sink(0)], [Source(0)])
key_reuse_signatures[random.random_gamma_p] = KeyReuseSignature([Sink(0)], [])
# TODO(jakevdp): broadcast should probably consume the input to avoid implicit duplication
key_reuse_signatures[lax.broadcast_in_dim_p] = KeyReuseSignature([], [], [Forward(0, 0)])
key_reuse_signatures[lax.copy_p] = KeyReuseSignature([], [], [Forward(0, 0)])
key_reuse_signatures[lax.convert_element_type_p] = KeyReuseSignature([], [], [Forward(0, 0)])
@ -67,7 +82,7 @@ key_reuse_signatures[lax.dynamic_update_slice_p] = KeyReuseSignature([], [], [])
key_reuse_signatures_dynamic: dict[core.Primitive, Callable[..., KeyReuseSignature]] = {}
# The default signature will Sink all key inputs, and not Source any.
def unknown_signature(eqn, args_consumed):
def unknown_signature(eqn):
def is_key(var: core.Atom):
return hasattr(var.aval, "dtype") and jax.dtypes.issubdtype(var.aval.dtype, jax.dtypes.prng_key)
return KeyReuseSignature(
@ -75,11 +90,7 @@ def unknown_signature(eqn, args_consumed):
sources=[],
)
def get_jaxpr_type_signature(
jaxpr: core.Jaxpr,
consumed_inputs: list[bool | np.ndarray] | None = None,
forwarded_inputs: dict[int, int] | None = None,
) -> KeyReuseSignature:
def get_jaxpr_type_signature(jaxpr: core.Jaxpr) -> KeyReuseSignature:
"""Parse the jaxpr to determine key reuse signature"""
consumed: dict[core.Atom, bool | np.ndarray] = {}
forwards: dict[core.Atom, core.Atom] = {} # map forwarded outputs to inputs.
@ -122,24 +133,18 @@ def get_jaxpr_type_signature(
return False
return consumed.get(var, False)
if forwarded_inputs:
for i, j in forwarded_inputs.items():
forwards[jaxpr.invars[i]] = jaxpr.invars[j]
if consumed_inputs:
for var, mask in util.safe_zip(jaxpr.invars, consumed_inputs):
if not isinstance(var, core.Literal):
source(var, mask)
for eqn in jaxpr.eqns:
if eqn.primitive in key_reuse_signatures:
signature = key_reuse_signatures[eqn.primitive]
elif eqn.primitive in key_reuse_signatures_dynamic:
args_consumed = [is_consumed(var) for var in eqn.invars]
signature = key_reuse_signatures_dynamic[eqn.primitive](eqn, args_consumed)
signature = key_reuse_signatures_dynamic[eqn.primitive](eqn)
else:
args_consumed = [is_consumed(var) for var in eqn.invars]
signature = unknown_signature(eqn, args_consumed)
signature = unknown_signature(eqn)
if eqn.primitive == assert_consumed_value_p:
# This is a special case that goes beyond normal key reuse logic.
_check_consumed_value(eqn, is_consumed(eqn.invars[0]))
for in_idx, out_idx in signature.forwards:
forwards[eqn.outvars[out_idx]] = eqn.invars[in_idx]
@ -187,8 +192,7 @@ def check_key_reuse(fun: Callable[..., Any], /, *args: Any) -> None:
#----------------------------------------------------------------------------------
# key reuse rules for particular primitives:
def _slice_signature(eqn, args_consumed):
del args_consumed # unused here
def _slice_signature(eqn):
in_aval = eqn.invars[0].aval
if not jax.dtypes.issubdtype(in_aval.dtype, jax.dtypes.prng_key):
return KeyReuseSignature([], [], [Forward(0, 0)])
@ -204,35 +208,13 @@ def _slice_signature(eqn, args_consumed):
key_reuse_signatures_dynamic[lax.slice_p] = _slice_signature
def _pjit_key_type_signature(eqn, args_consumed):
  """Compute the key reuse signature of a pjit equation from its inner jaxpr.

  NOTE(review): this two-argument form appears to be the pre-change side of
  the diff (it still receives ``args_consumed``) -- confirm before relying on it.

  Args:
    eqn: the pjit equation; ``eqn.params['jaxpr']`` holds the wrapped jaxpr.
    args_consumed: per-input consumed state, indexed like ``eqn.invars``.

  Returns:
    The signature obtained by evaluating the inner jaxpr without context.
  """
  jaxpr = eqn.params['jaxpr']
  # Map each repeated input position to the position of its first occurrence.
  forwarded_inputs = {i: eqn.invars.index(var) for i, var in enumerate(eqn.invars)
                      if var in eqn.invars[:i]}
  sig = get_jaxpr_type_signature(jaxpr.jaxpr)
  if args_consumed and any(np.any(args_consumed[s.idx] & s.mask) for s in sig.sinks):
    # Double consumption detected: re-trace with context for better errors.
    get_jaxpr_type_signature(jaxpr.jaxpr, args_consumed, forwarded_inputs)
  return sig
def _pjit_key_type_signature(eqn):
  """Return the key reuse signature of a pjit equation.

  The signature is that of the jaxpr wrapped by the equation, read from
  ``eqn.params['jaxpr'].jaxpr``; no outer context is passed in.
  """
  return get_jaxpr_type_signature(eqn.params['jaxpr'].jaxpr)
key_reuse_signatures_dynamic[pjit.pjit_p] = _pjit_key_type_signature
def _assert_consumed_value_key_type_signature(eqn, args_consumed):
  """Check the consumed state of the first input, then forward it unchanged.

  NOTE(review): this form appears to be the pre-change side of the diff (it
  still receives ``args_consumed``) -- confirm before relying on it.

  Args:
    eqn: equation whose ``params['value']`` holds the expected consumed state.
    args_consumed: per-input consumed state; only element 0 is inspected.

  Returns:
    A signature with no sinks or sources that forwards input 0 to output 0.

  Raises:
    AssertionError: if the observed state of input 0 does not equal the
      expected value element-wise.
  """
  actual = args_consumed[0]
  expected = eqn.params['value']
  if not np.all(actual == expected):
    # Choose the most specific message: the expectation was "fully consumed",
    # "not consumed at all", or a mixed mask.
    if np.all(expected):
      raise AssertionError(f"Expected key to be consumed in {eqn}")
    elif not np.any(expected):
      raise AssertionError(f"Expected key to not be consumed in {eqn}")
    else:
      raise AssertionError(f"Expected {expected}, got {actual} in {eqn}")
  return KeyReuseSignature([], [], [Forward(0, 0)])
key_reuse_signatures_dynamic[assert_consumed_value_p] = _assert_consumed_value_key_type_signature
def _cond_key_type_signature(eqn, args_consumed):
signatures = [get_jaxpr_type_signature(branch.jaxpr, consumed_inputs=args_consumed[1:])
for branch in eqn.params['branches']]
def _cond_key_type_signature(eqn):
signatures = [get_jaxpr_type_signature(branch.jaxpr) for branch in eqn.params['branches']]
sinks = defaultdict(list)
sources = defaultdict(list)
for sig in signatures:
@ -249,11 +231,11 @@ def _cond_key_type_signature(eqn, args_consumed):
key_reuse_signatures_dynamic[lax.cond_p] = _cond_key_type_signature
def _scan_key_type_signature(eqn, args_consumed):
def _scan_key_type_signature(eqn):
jaxpr = eqn.params['jaxpr'].jaxpr
num_consts = eqn.params['num_consts']
num_carry = eqn.params['num_carry']
signature = get_jaxpr_type_signature(jaxpr, args_consumed)
signature = get_jaxpr_type_signature(jaxpr)
# scan body should not consume key in constants
if any(np.any(s.mask) for s in signature.sinks if s.idx < num_consts):
@ -278,13 +260,12 @@ def _scan_key_type_signature(eqn, args_consumed):
key_reuse_signatures_dynamic[jax.lax.scan_p] = _scan_key_type_signature
def _while_key_type_signature(eqn, args_consumed):
def _while_key_type_signature(eqn):
cond_jaxpr = eqn.params['cond_jaxpr'].jaxpr
cond_nconsts = eqn.params['cond_nconsts']
body_jaxpr = eqn.params['body_jaxpr'].jaxpr
body_nconsts = eqn.params['body_nconsts']
# TODO(jakevdp): pass args_consumed here?
cond_signature = get_jaxpr_type_signature(cond_jaxpr)
body_signature = get_jaxpr_type_signature(body_jaxpr)
@ -320,7 +301,7 @@ def _while_key_type_signature(eqn, args_consumed):
key_reuse_signatures_dynamic[jax.lax.while_p] = _while_key_type_signature
def _remat_key_type_signature(eqn, args_consumed):
def _remat_key_type_signature(eqn):
# The assumption here is that the non-differentiated pass contains all relevant
# key usage, and the differentiated pass
# 1) will only consume keys that are already consumed in the non-differentiated pass
@ -328,13 +309,6 @@ def _remat_key_type_signature(eqn, args_consumed):
# Therefore, the differentiated pass is a no-op.
if eqn.params['differentiated']:
return KeyReuseSignature([], [])
jaxpr = eqn.params['jaxpr']
forwarded_inputs = {i: eqn.invars.index(var) for i, var in enumerate(eqn.invars)
if var in eqn.invars[:i]}
sig = get_jaxpr_type_signature(jaxpr)
if args_consumed and any(np.any(args_consumed[s.idx] & s.mask) for s in sig.sinks):
# Double consumption detected: re-trace with context for better errors.
get_jaxpr_type_signature(jaxpr, args_consumed, forwarded_inputs)
return sig
return get_jaxpr_type_signature(eqn.params['jaxpr'])
key_reuse_signatures_dynamic[remat_p] = _remat_key_type_signature

View File

@ -335,7 +335,7 @@ class KeyReuseIntegrationTest(jtu.JaxTestCase):
key = jax.random.key(0)
return jax.random.uniform(key) + jax.random.uniform(key)
with self.assertRaisesRegex(KeyReuseError, self.random_bits_error):
with self.assertRaisesRegex(KeyReuseError, self.pjit_error):
self.check_key_reuse(f)
def test_reuse_after_split(self):
@ -350,7 +350,7 @@ class KeyReuseIntegrationTest(jtu.JaxTestCase):
_ = jax.random.split(key)
return jax.random.uniform(key)
with self.assertRaisesRegex(KeyReuseError, self.random_bits_error):
with self.assertRaisesRegex(KeyReuseError, self.pjit_error):
self.check_key_reuse(f_bad)
def f_bad_2():
@ -418,7 +418,7 @@ class KeyReuseIntegrationTest(jtu.JaxTestCase):
r1 = jax.lax.cond(condition, jax.random.uniform, jax.random.normal, key)
return r1 + jax.random.uniform(key)
with self.assertRaisesRegex(KeyReuseError, self.random_bits_error):
with self.assertRaisesRegex(KeyReuseError, self.pjit_error):
self.check_key_reuse(f_bad, key, True)
# Check where only one branch consumes the key
@ -426,7 +426,7 @@ class KeyReuseIntegrationTest(jtu.JaxTestCase):
r1 = jax.lax.cond(condition, jax.random.uniform, lambda key: 1.0, key)
return r1 + jax.random.uniform(key)
with self.assertRaisesRegex(KeyReuseError, self.random_bits_error):
with self.assertRaisesRegex(KeyReuseError, self.pjit_error):
self.check_key_reuse(f_bad_2, key, True)
def test_simple_scan(self):