add mutable array ref error checks to cond and custom_vjp

This commit is contained in:
Matthew Johnson 2024-12-18 23:53:28 +00:00
parent 60ebde89e6
commit b6482f126e
6 changed files with 149 additions and 44 deletions

View File

@ -30,8 +30,10 @@ from jax._src import traceback_util
from jax._src.ad_util import (
stop_gradient_p, SymbolicZero, Zero, zeros_like_aval)
from jax._src.api_util import (
argnums_partial, flatten_fun_nokwargs, resolve_kwargs)
argnums_partial, flatten_fun_nokwargs, resolve_kwargs, fun_signature,
_arg_names)
from jax._src.errors import UnexpectedTracerError
from jax._src.state.types import AbstractRef
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
@ -41,8 +43,8 @@ from jax._src.interpreters.batching import not_mapped
from jax._src.lax import lax
from jax._src.tree_util import (
tree_flatten, tree_unflatten, tree_map, treedef_is_leaf, treedef_tuple,
register_pytree_node_class, tree_leaves, tree_flatten_with_path, keystr,
treedef_children)
register_pytree_node_class, tree_leaves, tree_flatten_with_path,
tree_leaves_with_path, keystr, treedef_children)
from jax._src.util import (cache, safe_zip, safe_map, split_list, Unhashable,
unzip2)
@ -608,9 +610,12 @@ class custom_vjp(Generic[ReturnValue]):
fwd_, bwd = lu.wrap_init(fwd), lu.wrap_init(self.bwd)
args_flat, in_tree = tree_flatten(dyn_args)
in_avals = [core.get_aval(x) for x in args_flat]
if config.mutable_array_checks.value:
f_ = _check_primal_refs(f_, self.nondiff_argnums)
flat_fun, out_type = _flatten_fun_nokwargs(f_, in_tree)
flat_fwd, out_trees = _flatten_fwd(fwd_, self.symbolic_zeros, primal_name,
fwd_name, in_tree, out_type)
flat_fwd, out_trees = _flatten_fwd(
fwd_, self.nondiff_argnums, self.symbolic_zeros, primal_name,
fwd_name, in_tree, out_type)
flat_bwd = _flatten_bwd(bwd, in_tree, in_avals, out_trees).call_wrapped
out_flat = custom_vjp_call_p.bind(flat_fun, flat_fwd, flat_bwd,
*args_flat, out_trees=out_trees,
@ -618,6 +623,37 @@ class custom_vjp(Generic[ReturnValue]):
_, (out_tree, _) = lu.merge_linear_aux(out_type, out_trees)
return tree_unflatten(out_tree, out_flat)
@lu.transformation2
def _check_primal_refs(f, nondiff_argnums, *args):
_check_for_aliased_refs(f, nondiff_argnums, args)
out = f(*args)
_check_for_returned_refs(f, out, 'primal')
return out
def _check_for_aliased_refs(f, nondiff_argnums, args):
leaves = tree_leaves(args)
refs: dict[int, int] = {}
for i, x in enumerate(leaves):
if (isinstance((a := core.get_aval(x)), AbstractRef) and
(dup_idx := refs.setdefault(id(core.get_referent(x)), i)) != i):
arg_names = _arg_names(fun_signature(f), args, {}, nondiff_argnums, ())
if arg_names is None:
arg_names = [f'flat index {j}' for j in range(len(leaves))]
raise ValueError(
"only one reference to a mutable array may be passed as an argument "
f"to a function, but custom_vjp function {f} got the same mutable "
f"array reference of type {a.str_short()} at {arg_names[dup_idx]} and"
f" {arg_names[i]}.")
def _check_for_returned_refs(f, out, kind):
leaves = tree_leaves_with_path(out)
for path, leaf in leaves:
if isinstance((a := core.get_aval(leaf)), AbstractRef):
loc = f' at output tree path {keystr(path)}' if path else ''
raise ValueError(f"custom_vjp {kind} function {f} returned a mutable "
f"a array reference of type {a.str_short()}{loc}, "
"but mutable array references cannot be returned.")
@dataclasses.dataclass
class CustomVJPPrimal:
"""Primal to a ``custom_vjp``'s forward rule when ``symbolic_zeros`` is set"""
@ -655,14 +691,18 @@ def _check_for_tracers(x):
raise UnexpectedTracerError(msg)
@partial(lu.transformation_with_aux2, use_eq_store=True)
def _flatten_fwd(f, store, symbolic_zeros, primal_name, fwd_name, in_tree, maybe_out_type,
*args):
def _flatten_fwd(f, store, nondiff_argnums, symbolic_zeros, primal_name,
fwd_name, in_tree, maybe_out_type, *args):
if symbolic_zeros:
args = [CustomVJPPrimal(x, z) for x, z in zip(args[::2], args[1::2])]
else:
args = args[::2]
py_args = tree_unflatten(in_tree, args)
if config.mutable_array_checks.value:
_check_for_aliased_refs(f, nondiff_argnums, py_args)
pair_out = f(*py_args)
if config.mutable_array_checks.value:
_check_for_returned_refs(f, pair_out, 'fwd')
if not isinstance(pair_out, (list, tuple)) or len(pair_out) != 2:
msg = (f"Custom VJP fwd rule {fwd_name} for function {primal_name} "
"must produce a pair (list or tuple of length two) where the first "
@ -1393,8 +1433,8 @@ def optimize_remat_of_custom_vjp_fwd(
fwd_ = lu.wrap_init(fwd)
args_flat, in_tree = tree_flatten(dyn_args)
flat_fun, out_type = _flatten_fun_nokwargs(f_, in_tree)
flat_fwd, out_trees = _flatten_fwd(fwd_, False, primal_name, fwd_name,
in_tree, out_type)
flat_fwd, out_trees = _flatten_fwd(fwd_, nondiff_argnums, False,
primal_name, fwd_name, in_tree, out_type)
flat_fwd = _fix_fwd_args(flat_fwd)
in_avals = [core.get_aval(x) for x in args_flat]

View File

@ -539,6 +539,9 @@ class JVPTracer(Tracer):
def to_concrete_value(self):
return core.to_concrete_value(self.primal)
def get_referent(self):
return core.get_referent(self.primal)
def _primal_tangent_shapes_match(primal, tangent):
if type(tangent) is not Zero:
primal_aval = get_aval(primal).strip_weak_type()

View File

@ -2010,8 +2010,8 @@ class DynamicJaxprTrace(core.Trace):
def fwd_jaxpr_from_zeros(*zeros):
for store in fwd.stores: store and store.reset()
fwd_ = _interleave_fun(fwd, zeros)
jaxpr, _, consts, atr = trace_to_jaxpr_dynamic(fwd_, in_avals)
if atr: raise NotImplementedError
jaxpr, _, consts, attrs = trace_to_jaxpr_dynamic(fwd_, in_avals)
if attrs: raise NotImplementedError
return jaxpr, consts
out_tracers = [DynamicJaxprTracer(self, a) for a in out_avals]
@ -2154,14 +2154,14 @@ def trace_to_jaxpr_dynamic(
ans = fun.call_wrapped(*in_tracers)
out_tracers = map(trace.to_jaxpr_tracer, ans)
_check_no_refs(debug_info, out_tracers)
_check_no_returned_refs(debug_info, out_tracers)
jaxpr, consts, attrs_tracked = trace.to_jaxpr(out_tracers)
del trace, fun, in_tracers, out_tracers, ans
config.enable_checks.value and core.check_jaxpr(jaxpr)
return jaxpr, [v.aval for v in jaxpr.outvars], consts, attrs_tracked
def _check_no_refs(
def _check_no_returned_refs(
dbg: lu.TracingDebugInfo | None,
out_tracers: Sequence[DynamicJaxprTracer]
) -> None:

View File

@ -89,13 +89,6 @@ def _initial_style_jaxprs_with_common_consts(
jaxprs, all_consts, all_out_trees, all_attrs_tracked = zip(*jaxpr_data)
all_const_avals = [map(core.get_aval, consts) for consts in all_consts]
# If we get a `Ref` in the consts, we know it must come from an outer
# `run_state`. We also know if shouldn't be boxed up in another tracer.
# We assert that it is in fact a DynamicJaxprTracer
for consts, consts_avals in zip(all_consts, all_const_avals):
for c, aval in zip(consts, consts_avals):
if isinstance(aval, state.AbstractRef):
assert isinstance(c, pe.DynamicJaxprTracer)
# TODO(sharadmv,mattjj): we could dedup *all consts* instead of just the Refs.

View File

@ -25,6 +25,8 @@ from typing import Any, TypeVar
from jax.tree_util import tree_flatten, tree_unflatten
from jax._src import ad_util
from jax._src.api_util import (
_check_no_aliased_ref_args, _check_no_aliased_closed_over_refs)
from jax._src import config
from jax._src import core
from jax._src import dispatch
@ -136,8 +138,14 @@ def switch(index, branches: Sequence[Callable], *operands,
ops, ops_tree = tree_flatten(operands)
ops_avals = tuple(map(core.get_aval, ops))
if config.mutable_array_checks.value:
dbg = pe.debug_info(branches[0], ops_tree, None, False, 'switch')
_check_no_aliased_ref_args(dbg, ops_avals, ops)
jaxprs, consts, out_trees = _initial_style_jaxprs_with_common_consts(
branches, ops_tree, ops_avals, primitive_name='switch')
if config.mutable_array_checks.value:
_check_no_aliased_closed_over_refs(dbg, (*jaxprs[0].consts, *consts), ops)
for i, (out_tree, jaxpr) in enumerate(zip(out_trees[1:], jaxprs[1:])):
_check_tree_and_avals(f"branch 0 and {i + 1} outputs",
out_trees[0], jaxprs[0].out_avals,
@ -228,11 +236,14 @@ def _cond(pred, true_fun: Callable, false_fun: Callable, *operands,
ops, ops_tree = tree_flatten(operands)
ops_avals = tuple(map(core.get_aval, ops))
if config.mutable_array_checks.value:
dbg = pe.debug_info(true_fun, ops_tree, None, False, 'cond')
_check_no_aliased_ref_args(dbg, ops_avals, ops)
jaxprs, consts, out_trees = _initial_style_jaxprs_with_common_consts(
(true_fun, false_fun), ops_tree, ops_avals, 'cond')
if any(isinstance(op_aval, AbstractRef) for op_aval in ops_avals):
raise ValueError("Cannot pass `Ref`s into `cond`.")
true_jaxpr, false_jaxpr = jaxprs
if config.mutable_array_checks.value:
_check_no_aliased_closed_over_refs(dbg, (*true_jaxpr.consts, *consts), ops)
out_tree, false_out_tree = out_trees
if any(isinstance(out_aval, AbstractRef) for out_aval in

View File

@ -306,30 +306,88 @@ class MutableArrayErrorsTest(jtu.JaxTestCase):
ValueError, "traced for cond returned a mutable array reference of type"):
jax.lax.cond(True, lambda: core.mutable_array(1.0), lambda: core.mutable_array(2.0))
# TODO test_argument_aliases_cond
# TODO test_closure_and_argument_aliases_cond
def test_argument_aliases_cond(self):
x_ref = core.mutable_array(0.)
with self.assertRaisesRegex( ValueError, r"for cond.*at both x1 and x2"):
jax.lax.cond(True, lambda x1, x2: ..., lambda x1, x2: ..., x_ref, x_ref)
# TODO test_return_from_custom_jvp/vjp
# TODO test_argument_aliases_custom_jvp/vjp
# TODO test_closure_and_argument_aliases_custom_jvp/vjp
def test_closure_and_argument_aliases_cond(self):
x_ref = core.mutable_array(0.)
with self.assertRaisesRegex(
ValueError, r"closed over and passed as the argument y_ref"):
jax.lax.cond(True,
lambda y_ref: x_ref[...] + y_ref[...],
lambda y_ref: x_ref[...] + y_ref[...],
x_ref)
# TODO(mattjj): enable when cond works with mutable arrays
# @parameterized.parameters([False, True])
# def test_cond_both_branches_close_over_same_mutable_array(self, jit):
# # see also test_cond_with_ref_reuse in state_test.py
# x_ref = core.mutable_array(0.)
# def f(pred):
# def true_fun():
# x_ref[()] = 1.
# def false_fun():
# x_ref[()] = 2.
# jax.lax.cond(pred, true_fun, false_fun)
# if jit:
# f = jax.jit(f)
# out_true = f(True)
# self.assertAllClose(x_ref[...], 1.)
# out_false = f(False)
# self.assertAllClose(x_ref[...], 2.)
@parameterized.parameters([False, True])
def test_return_from_custom_vjp_primal(self, jit):
@jax.custom_vjp
def f(ref):
return ref
f.defvjp(lambda ref: ..., lambda *_: ...)
if jit:
f = jax.jit(f)
x_ref = core.mutable_array(0.)
with self.assertRaisesRegex(
ValueError, "custom_vjp primal function"):
f(x_ref)
@parameterized.parameters([False, True])
def test_return_from_custom_vjp_fwd(self, jit):
@jax.custom_vjp
def f(x, ref):
return x
f.defvjp(lambda x, ref: (x, ref), lambda ref, g: g)
if jit:
f = jax.jit(f)
x_ref = core.mutable_array(0.)
with self.assertRaisesRegex(
ValueError, "custom_vjp fwd function"):
jax.vjp(f, 3., x_ref)
@parameterized.parameters([False, True])
def test_argument_aliases_custom_vjp_primal(self, jit):
@jax.custom_vjp
def f(x_ref, y_ref):
...
f.defvjp(lambda x_ref, y_ref: (None, None), lambda _, g: (None, None))
if jit:
f = jax.jit(f)
x_ref = core.mutable_array(0.)
with self.assertRaisesRegex(ValueError, "x_ref and y_ref"):
f(x_ref, x_ref)
@parameterized.parameters([False, True])
def test_argument_aliases_custom_vjp_fwd(self, jit):
@jax.custom_vjp
def f(x_ref, y_ref):
...
f.defvjp(lambda x_ref, y_ref: (None, None), lambda _, g: (None, None))
if jit:
f = jax.jit(f)
x_ref = core.mutable_array(0.)
with self.assertRaisesRegex(ValueError, "x_ref and y_ref"):
jax.vjp(f, x_ref, x_ref)
# TODO(mattjj): add test test_closure_and_argument_aliases_custom_vjp
@parameterized.parameters([False, True])
def test_cond_both_branches_close_over_same_mutable_array(self, jit):
# see also test_cond_with_ref_reuse in state_test.py
x_ref = core.mutable_array(0.)
def f(pred):
def true_fun():
x_ref[()] = 1.
def false_fun():
x_ref[()] = 2.
jax.lax.cond(pred, true_fun, false_fun)
if jit:
f = jax.jit(f)
out_true = f(True)
self.assertAllClose(x_ref[...], 1.)
out_false = f(False)
self.assertAllClose(x_ref[...], 2.)
if __name__ == '__main__':