mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
[key reuse] simplify key reuse logic through context-free jaxpr evaluation
The args_consumed and forwarded_inputs context is not actually needed, because it can be checked afterward. The only reason for this was to have more granular errors, but arguably it's better to error on jaxpr input.
This commit is contained in:
parent
243e7edc56
commit
8eab599530
@ -16,7 +16,7 @@ from __future__ import annotations
|
||||
|
||||
from collections import defaultdict
|
||||
from functools import reduce
|
||||
from typing import Any, Callable, NamedTuple
|
||||
from typing import Any, Callable
|
||||
|
||||
import jax
|
||||
from jax import lax
|
||||
@ -38,10 +38,24 @@ from jax.experimental.key_reuse._common import (
|
||||
)
|
||||
import numpy as np
|
||||
|
||||
|
||||
def _check_consumed_value(eqn, consumed):
|
||||
"""Extra check for use with assert_consumed_value_p"""
|
||||
expected = eqn.params['value']
|
||||
if not np.all(consumed == expected):
|
||||
if np.all(expected):
|
||||
raise AssertionError(f"Expected key to be consumed in {eqn}")
|
||||
elif not np.any(expected):
|
||||
raise AssertionError(f"Expected key to not be consumed in {eqn}")
|
||||
else:
|
||||
raise AssertionError(f"Expected {expected}, got {consumed} in {eqn}")
|
||||
|
||||
|
||||
# The behavior of most primitives can be described via simple signatures.
|
||||
key_reuse_signatures: dict[core.Primitive, KeyReuseSignature] = {}
|
||||
|
||||
key_reuse_signatures[consume_p] = KeyReuseSignature([Sink(0)], [], [Forward(0, 0)])
|
||||
key_reuse_signatures[assert_consumed_value_p] = KeyReuseSignature([], [], [Forward(0, 0)])
|
||||
key_reuse_signatures[prng.reuse_key_p] = KeyReuseSignature([], [Source(0)])
|
||||
key_reuse_signatures[prng.random_bits_p] = KeyReuseSignature([Sink(0)], [])
|
||||
# TODO(jakevdp): should fold_in sink its input key?
|
||||
@ -50,6 +64,7 @@ key_reuse_signatures[prng.random_fold_in_p] = KeyReuseSignature([], [Source(0)])
|
||||
key_reuse_signatures[prng.random_seed_p] = KeyReuseSignature([], [Source(0)])
|
||||
key_reuse_signatures[prng.random_split_p] = KeyReuseSignature([Sink(0)], [Source(0)])
|
||||
key_reuse_signatures[random.random_gamma_p] = KeyReuseSignature([Sink(0)], [])
|
||||
# TODO(jakevdp): broadcast should probably consume the input to avoid implicit duplication
|
||||
key_reuse_signatures[lax.broadcast_in_dim_p] = KeyReuseSignature([], [], [Forward(0, 0)])
|
||||
key_reuse_signatures[lax.copy_p] = KeyReuseSignature([], [], [Forward(0, 0)])
|
||||
key_reuse_signatures[lax.convert_element_type_p] = KeyReuseSignature([], [], [Forward(0, 0)])
|
||||
@ -67,7 +82,7 @@ key_reuse_signatures[lax.dynamic_update_slice_p] = KeyReuseSignature([], [], [])
|
||||
key_reuse_signatures_dynamic: dict[core.Primitive, Callable[..., KeyReuseSignature]] = {}
|
||||
|
||||
# The default signature will Sink all key inputs, and not Source any.
|
||||
def unknown_signature(eqn, args_consumed):
|
||||
def unknown_signature(eqn):
|
||||
def is_key(var: core.Atom):
|
||||
return hasattr(var.aval, "dtype") and jax.dtypes.issubdtype(var.aval.dtype, jax.dtypes.prng_key)
|
||||
return KeyReuseSignature(
|
||||
@ -75,11 +90,7 @@ def unknown_signature(eqn, args_consumed):
|
||||
sources=[],
|
||||
)
|
||||
|
||||
def get_jaxpr_type_signature(
|
||||
jaxpr: core.Jaxpr,
|
||||
consumed_inputs: list[bool | np.ndarray] | None = None,
|
||||
forwarded_inputs: dict[int, int] | None = None,
|
||||
) -> KeyReuseSignature:
|
||||
def get_jaxpr_type_signature(jaxpr: core.Jaxpr) -> KeyReuseSignature:
|
||||
"""Parse the jaxpr to determine key reuse signature"""
|
||||
consumed: dict[core.Atom, bool | np.ndarray] = {}
|
||||
forwards: dict[core.Atom, core.Atom] = {} # map forwarded outputs to inputs.
|
||||
@ -122,24 +133,18 @@ def get_jaxpr_type_signature(
|
||||
return False
|
||||
return consumed.get(var, False)
|
||||
|
||||
if forwarded_inputs:
|
||||
for i, j in forwarded_inputs.items():
|
||||
forwards[jaxpr.invars[i]] = jaxpr.invars[j]
|
||||
|
||||
if consumed_inputs:
|
||||
for var, mask in util.safe_zip(jaxpr.invars, consumed_inputs):
|
||||
if not isinstance(var, core.Literal):
|
||||
source(var, mask)
|
||||
|
||||
for eqn in jaxpr.eqns:
|
||||
if eqn.primitive in key_reuse_signatures:
|
||||
signature = key_reuse_signatures[eqn.primitive]
|
||||
elif eqn.primitive in key_reuse_signatures_dynamic:
|
||||
args_consumed = [is_consumed(var) for var in eqn.invars]
|
||||
signature = key_reuse_signatures_dynamic[eqn.primitive](eqn, args_consumed)
|
||||
signature = key_reuse_signatures_dynamic[eqn.primitive](eqn)
|
||||
else:
|
||||
args_consumed = [is_consumed(var) for var in eqn.invars]
|
||||
signature = unknown_signature(eqn, args_consumed)
|
||||
signature = unknown_signature(eqn)
|
||||
|
||||
if eqn.primitive == assert_consumed_value_p:
|
||||
# This is a special case that goes beyond normal key reuse logic.
|
||||
_check_consumed_value(eqn, is_consumed(eqn.invars[0]))
|
||||
|
||||
for in_idx, out_idx in signature.forwards:
|
||||
forwards[eqn.outvars[out_idx]] = eqn.invars[in_idx]
|
||||
|
||||
@ -187,8 +192,7 @@ def check_key_reuse(fun: Callable[..., Any], /, *args: Any) -> None:
|
||||
#----------------------------------------------------------------------------------
|
||||
# key reuse rules for particular primitives:
|
||||
|
||||
def _slice_signature(eqn, args_consumed):
|
||||
del args_consumed # unused here
|
||||
def _slice_signature(eqn):
|
||||
in_aval = eqn.invars[0].aval
|
||||
if not jax.dtypes.issubdtype(in_aval.dtype, jax.dtypes.prng_key):
|
||||
return KeyReuseSignature([], [], [Forward(0, 0)])
|
||||
@ -204,35 +208,13 @@ def _slice_signature(eqn, args_consumed):
|
||||
|
||||
key_reuse_signatures_dynamic[lax.slice_p] = _slice_signature
|
||||
|
||||
def _pjit_key_type_signature(eqn, args_consumed):
|
||||
jaxpr = eqn.params['jaxpr']
|
||||
forwarded_inputs = {i: eqn.invars.index(var) for i, var in enumerate(eqn.invars)
|
||||
if var in eqn.invars[:i]}
|
||||
sig = get_jaxpr_type_signature(jaxpr.jaxpr)
|
||||
if args_consumed and any(np.any(args_consumed[s.idx] & s.mask) for s in sig.sinks):
|
||||
# Double consumption detected: re-trace with context for better errors.
|
||||
get_jaxpr_type_signature(jaxpr.jaxpr, args_consumed, forwarded_inputs)
|
||||
return sig
|
||||
def _pjit_key_type_signature(eqn):
|
||||
return get_jaxpr_type_signature(eqn.params['jaxpr'].jaxpr)
|
||||
|
||||
key_reuse_signatures_dynamic[pjit.pjit_p] = _pjit_key_type_signature
|
||||
|
||||
def _assert_consumed_value_key_type_signature(eqn, args_consumed):
|
||||
actual = args_consumed[0]
|
||||
expected = eqn.params['value']
|
||||
if not np.all(actual == expected):
|
||||
if np.all(expected):
|
||||
raise AssertionError(f"Expected key to be consumed in {eqn}")
|
||||
elif not np.any(expected):
|
||||
raise AssertionError(f"Expected key to not be consumed in {eqn}")
|
||||
else:
|
||||
raise AssertionError(f"Expected {expected}, got {actual} in {eqn}")
|
||||
return KeyReuseSignature([], [], [Forward(0, 0)])
|
||||
|
||||
key_reuse_signatures_dynamic[assert_consumed_value_p] = _assert_consumed_value_key_type_signature
|
||||
|
||||
def _cond_key_type_signature(eqn, args_consumed):
|
||||
signatures = [get_jaxpr_type_signature(branch.jaxpr, consumed_inputs=args_consumed[1:])
|
||||
for branch in eqn.params['branches']]
|
||||
def _cond_key_type_signature(eqn):
|
||||
signatures = [get_jaxpr_type_signature(branch.jaxpr) for branch in eqn.params['branches']]
|
||||
sinks = defaultdict(list)
|
||||
sources = defaultdict(list)
|
||||
for sig in signatures:
|
||||
@ -249,11 +231,11 @@ def _cond_key_type_signature(eqn, args_consumed):
|
||||
|
||||
key_reuse_signatures_dynamic[lax.cond_p] = _cond_key_type_signature
|
||||
|
||||
def _scan_key_type_signature(eqn, args_consumed):
|
||||
def _scan_key_type_signature(eqn):
|
||||
jaxpr = eqn.params['jaxpr'].jaxpr
|
||||
num_consts = eqn.params['num_consts']
|
||||
num_carry = eqn.params['num_carry']
|
||||
signature = get_jaxpr_type_signature(jaxpr, args_consumed)
|
||||
signature = get_jaxpr_type_signature(jaxpr)
|
||||
|
||||
# scan body should not consume key in constants
|
||||
if any(np.any(s.mask) for s in signature.sinks if s.idx < num_consts):
|
||||
@ -278,13 +260,12 @@ def _scan_key_type_signature(eqn, args_consumed):
|
||||
|
||||
key_reuse_signatures_dynamic[jax.lax.scan_p] = _scan_key_type_signature
|
||||
|
||||
def _while_key_type_signature(eqn, args_consumed):
|
||||
def _while_key_type_signature(eqn):
|
||||
cond_jaxpr = eqn.params['cond_jaxpr'].jaxpr
|
||||
cond_nconsts = eqn.params['cond_nconsts']
|
||||
body_jaxpr = eqn.params['body_jaxpr'].jaxpr
|
||||
body_nconsts = eqn.params['body_nconsts']
|
||||
|
||||
# TODO(jakevdp): pass args_consumed here?
|
||||
cond_signature = get_jaxpr_type_signature(cond_jaxpr)
|
||||
body_signature = get_jaxpr_type_signature(body_jaxpr)
|
||||
|
||||
@ -320,7 +301,7 @@ def _while_key_type_signature(eqn, args_consumed):
|
||||
|
||||
key_reuse_signatures_dynamic[jax.lax.while_p] = _while_key_type_signature
|
||||
|
||||
def _remat_key_type_signature(eqn, args_consumed):
|
||||
def _remat_key_type_signature(eqn):
|
||||
# The assumption here is that the non-differentiated pass contains all relevant
|
||||
# key usage, and the differentiated pass
|
||||
# 1) will only consume keys that are already consumed in the non-differentiated pass
|
||||
@ -328,13 +309,6 @@ def _remat_key_type_signature(eqn, args_consumed):
|
||||
# Therefore, the differentiated pass is a no-op.
|
||||
if eqn.params['differentiated']:
|
||||
return KeyReuseSignature([], [])
|
||||
jaxpr = eqn.params['jaxpr']
|
||||
forwarded_inputs = {i: eqn.invars.index(var) for i, var in enumerate(eqn.invars)
|
||||
if var in eqn.invars[:i]}
|
||||
sig = get_jaxpr_type_signature(jaxpr)
|
||||
if args_consumed and any(np.any(args_consumed[s.idx] & s.mask) for s in sig.sinks):
|
||||
# Double consumption detected: re-trace with context for better errors.
|
||||
get_jaxpr_type_signature(jaxpr, args_consumed, forwarded_inputs)
|
||||
return sig
|
||||
return get_jaxpr_type_signature(eqn.params['jaxpr'])
|
||||
|
||||
key_reuse_signatures_dynamic[remat_p] = _remat_key_type_signature
|
||||
|
@ -335,7 +335,7 @@ class KeyReuseIntegrationTest(jtu.JaxTestCase):
|
||||
key = jax.random.key(0)
|
||||
return jax.random.uniform(key) + jax.random.uniform(key)
|
||||
|
||||
with self.assertRaisesRegex(KeyReuseError, self.random_bits_error):
|
||||
with self.assertRaisesRegex(KeyReuseError, self.pjit_error):
|
||||
self.check_key_reuse(f)
|
||||
|
||||
def test_reuse_after_split(self):
|
||||
@ -350,7 +350,7 @@ class KeyReuseIntegrationTest(jtu.JaxTestCase):
|
||||
_ = jax.random.split(key)
|
||||
return jax.random.uniform(key)
|
||||
|
||||
with self.assertRaisesRegex(KeyReuseError, self.random_bits_error):
|
||||
with self.assertRaisesRegex(KeyReuseError, self.pjit_error):
|
||||
self.check_key_reuse(f_bad)
|
||||
|
||||
def f_bad_2():
|
||||
@ -418,7 +418,7 @@ class KeyReuseIntegrationTest(jtu.JaxTestCase):
|
||||
r1 = jax.lax.cond(condition, jax.random.uniform, jax.random.normal, key)
|
||||
return r1 + jax.random.uniform(key)
|
||||
|
||||
with self.assertRaisesRegex(KeyReuseError, self.random_bits_error):
|
||||
with self.assertRaisesRegex(KeyReuseError, self.pjit_error):
|
||||
self.check_key_reuse(f_bad, key, True)
|
||||
|
||||
# Check where only one branch consumes the key
|
||||
@ -426,7 +426,7 @@ class KeyReuseIntegrationTest(jtu.JaxTestCase):
|
||||
r1 = jax.lax.cond(condition, jax.random.uniform, lambda key: 1.0, key)
|
||||
return r1 + jax.random.uniform(key)
|
||||
|
||||
with self.assertRaisesRegex(KeyReuseError, self.random_bits_error):
|
||||
with self.assertRaisesRegex(KeyReuseError, self.pjit_error):
|
||||
self.check_key_reuse(f_bad_2, key, True)
|
||||
|
||||
def test_simple_scan(self):
|
||||
|
Loading…
x
Reference in New Issue
Block a user