mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
add mutable array ref error checks to cond and custom_vjp
This commit is contained in:
parent
60ebde89e6
commit
b6482f126e
@ -30,8 +30,10 @@ from jax._src import traceback_util
|
||||
from jax._src.ad_util import (
|
||||
stop_gradient_p, SymbolicZero, Zero, zeros_like_aval)
|
||||
from jax._src.api_util import (
|
||||
argnums_partial, flatten_fun_nokwargs, resolve_kwargs)
|
||||
argnums_partial, flatten_fun_nokwargs, resolve_kwargs, fun_signature,
|
||||
_arg_names)
|
||||
from jax._src.errors import UnexpectedTracerError
|
||||
from jax._src.state.types import AbstractRef
|
||||
from jax._src.interpreters import ad
|
||||
from jax._src.interpreters import batching
|
||||
from jax._src.interpreters import mlir
|
||||
@ -41,8 +43,8 @@ from jax._src.interpreters.batching import not_mapped
|
||||
from jax._src.lax import lax
|
||||
from jax._src.tree_util import (
|
||||
tree_flatten, tree_unflatten, tree_map, treedef_is_leaf, treedef_tuple,
|
||||
register_pytree_node_class, tree_leaves, tree_flatten_with_path, keystr,
|
||||
treedef_children)
|
||||
register_pytree_node_class, tree_leaves, tree_flatten_with_path,
|
||||
tree_leaves_with_path, keystr, treedef_children)
|
||||
from jax._src.util import (cache, safe_zip, safe_map, split_list, Unhashable,
|
||||
unzip2)
|
||||
|
||||
@ -608,9 +610,12 @@ class custom_vjp(Generic[ReturnValue]):
|
||||
fwd_, bwd = lu.wrap_init(fwd), lu.wrap_init(self.bwd)
|
||||
args_flat, in_tree = tree_flatten(dyn_args)
|
||||
in_avals = [core.get_aval(x) for x in args_flat]
|
||||
if config.mutable_array_checks.value:
|
||||
f_ = _check_primal_refs(f_, self.nondiff_argnums)
|
||||
flat_fun, out_type = _flatten_fun_nokwargs(f_, in_tree)
|
||||
flat_fwd, out_trees = _flatten_fwd(fwd_, self.symbolic_zeros, primal_name,
|
||||
fwd_name, in_tree, out_type)
|
||||
flat_fwd, out_trees = _flatten_fwd(
|
||||
fwd_, self.nondiff_argnums, self.symbolic_zeros, primal_name,
|
||||
fwd_name, in_tree, out_type)
|
||||
flat_bwd = _flatten_bwd(bwd, in_tree, in_avals, out_trees).call_wrapped
|
||||
out_flat = custom_vjp_call_p.bind(flat_fun, flat_fwd, flat_bwd,
|
||||
*args_flat, out_trees=out_trees,
|
||||
@ -618,6 +623,37 @@ class custom_vjp(Generic[ReturnValue]):
|
||||
_, (out_tree, _) = lu.merge_linear_aux(out_type, out_trees)
|
||||
return tree_unflatten(out_tree, out_flat)
|
||||
|
||||
@lu.transformation2
|
||||
def _check_primal_refs(f, nondiff_argnums, *args):
|
||||
_check_for_aliased_refs(f, nondiff_argnums, args)
|
||||
out = f(*args)
|
||||
_check_for_returned_refs(f, out, 'primal')
|
||||
return out
|
||||
|
||||
def _check_for_aliased_refs(f, nondiff_argnums, args):
|
||||
leaves = tree_leaves(args)
|
||||
refs: dict[int, int] = {}
|
||||
for i, x in enumerate(leaves):
|
||||
if (isinstance((a := core.get_aval(x)), AbstractRef) and
|
||||
(dup_idx := refs.setdefault(id(core.get_referent(x)), i)) != i):
|
||||
arg_names = _arg_names(fun_signature(f), args, {}, nondiff_argnums, ())
|
||||
if arg_names is None:
|
||||
arg_names = [f'flat index {j}' for j in range(len(leaves))]
|
||||
raise ValueError(
|
||||
"only one reference to a mutable array may be passed as an argument "
|
||||
f"to a function, but custom_vjp function {f} got the same mutable "
|
||||
f"array reference of type {a.str_short()} at {arg_names[dup_idx]} and"
|
||||
f" {arg_names[i]}.")
|
||||
|
||||
def _check_for_returned_refs(f, out, kind):
|
||||
leaves = tree_leaves_with_path(out)
|
||||
for path, leaf in leaves:
|
||||
if isinstance((a := core.get_aval(leaf)), AbstractRef):
|
||||
loc = f' at output tree path {keystr(path)}' if path else ''
|
||||
raise ValueError(f"custom_vjp {kind} function {f} returned a mutable "
|
||||
f"a array reference of type {a.str_short()}{loc}, "
|
||||
"but mutable array references cannot be returned.")
|
||||
|
||||
@dataclasses.dataclass
|
||||
class CustomVJPPrimal:
|
||||
"""Primal to a ``custom_vjp``'s forward rule when ``symbolic_zeros`` is set"""
|
||||
@ -655,14 +691,18 @@ def _check_for_tracers(x):
|
||||
raise UnexpectedTracerError(msg)
|
||||
|
||||
@partial(lu.transformation_with_aux2, use_eq_store=True)
|
||||
def _flatten_fwd(f, store, symbolic_zeros, primal_name, fwd_name, in_tree, maybe_out_type,
|
||||
*args):
|
||||
def _flatten_fwd(f, store, nondiff_argnums, symbolic_zeros, primal_name,
|
||||
fwd_name, in_tree, maybe_out_type, *args):
|
||||
if symbolic_zeros:
|
||||
args = [CustomVJPPrimal(x, z) for x, z in zip(args[::2], args[1::2])]
|
||||
else:
|
||||
args = args[::2]
|
||||
py_args = tree_unflatten(in_tree, args)
|
||||
if config.mutable_array_checks.value:
|
||||
_check_for_aliased_refs(f, nondiff_argnums, py_args)
|
||||
pair_out = f(*py_args)
|
||||
if config.mutable_array_checks.value:
|
||||
_check_for_returned_refs(f, pair_out, 'fwd')
|
||||
if not isinstance(pair_out, (list, tuple)) or len(pair_out) != 2:
|
||||
msg = (f"Custom VJP fwd rule {fwd_name} for function {primal_name} "
|
||||
"must produce a pair (list or tuple of length two) where the first "
|
||||
@ -1393,8 +1433,8 @@ def optimize_remat_of_custom_vjp_fwd(
|
||||
fwd_ = lu.wrap_init(fwd)
|
||||
args_flat, in_tree = tree_flatten(dyn_args)
|
||||
flat_fun, out_type = _flatten_fun_nokwargs(f_, in_tree)
|
||||
flat_fwd, out_trees = _flatten_fwd(fwd_, False, primal_name, fwd_name,
|
||||
in_tree, out_type)
|
||||
flat_fwd, out_trees = _flatten_fwd(fwd_, nondiff_argnums, False,
|
||||
primal_name, fwd_name, in_tree, out_type)
|
||||
flat_fwd = _fix_fwd_args(flat_fwd)
|
||||
|
||||
in_avals = [core.get_aval(x) for x in args_flat]
|
||||
|
@ -539,6 +539,9 @@ class JVPTracer(Tracer):
|
||||
def to_concrete_value(self):
|
||||
return core.to_concrete_value(self.primal)
|
||||
|
||||
def get_referent(self):
|
||||
return core.get_referent(self.primal)
|
||||
|
||||
def _primal_tangent_shapes_match(primal, tangent):
|
||||
if type(tangent) is not Zero:
|
||||
primal_aval = get_aval(primal).strip_weak_type()
|
||||
|
@ -2010,8 +2010,8 @@ class DynamicJaxprTrace(core.Trace):
|
||||
def fwd_jaxpr_from_zeros(*zeros):
|
||||
for store in fwd.stores: store and store.reset()
|
||||
fwd_ = _interleave_fun(fwd, zeros)
|
||||
jaxpr, _, consts, atr = trace_to_jaxpr_dynamic(fwd_, in_avals)
|
||||
if atr: raise NotImplementedError
|
||||
jaxpr, _, consts, attrs = trace_to_jaxpr_dynamic(fwd_, in_avals)
|
||||
if attrs: raise NotImplementedError
|
||||
return jaxpr, consts
|
||||
|
||||
out_tracers = [DynamicJaxprTracer(self, a) for a in out_avals]
|
||||
@ -2154,14 +2154,14 @@ def trace_to_jaxpr_dynamic(
|
||||
ans = fun.call_wrapped(*in_tracers)
|
||||
|
||||
out_tracers = map(trace.to_jaxpr_tracer, ans)
|
||||
_check_no_refs(debug_info, out_tracers)
|
||||
_check_no_returned_refs(debug_info, out_tracers)
|
||||
jaxpr, consts, attrs_tracked = trace.to_jaxpr(out_tracers)
|
||||
del trace, fun, in_tracers, out_tracers, ans
|
||||
|
||||
config.enable_checks.value and core.check_jaxpr(jaxpr)
|
||||
return jaxpr, [v.aval for v in jaxpr.outvars], consts, attrs_tracked
|
||||
|
||||
def _check_no_refs(
|
||||
def _check_no_returned_refs(
|
||||
dbg: lu.TracingDebugInfo | None,
|
||||
out_tracers: Sequence[DynamicJaxprTracer]
|
||||
) -> None:
|
||||
|
@ -89,13 +89,6 @@ def _initial_style_jaxprs_with_common_consts(
|
||||
|
||||
jaxprs, all_consts, all_out_trees, all_attrs_tracked = zip(*jaxpr_data)
|
||||
all_const_avals = [map(core.get_aval, consts) for consts in all_consts]
|
||||
# If we get a `Ref` in the consts, we know it must come from an outer
|
||||
# `run_state`. We also know if shouldn't be boxed up in another tracer.
|
||||
# We assert that it is in fact a DynamicJaxprTracer
|
||||
for consts, consts_avals in zip(all_consts, all_const_avals):
|
||||
for c, aval in zip(consts, consts_avals):
|
||||
if isinstance(aval, state.AbstractRef):
|
||||
assert isinstance(c, pe.DynamicJaxprTracer)
|
||||
|
||||
# TODO(sharadmv,mattjj): we could dedup *all consts* instead of just the Refs.
|
||||
|
||||
|
@ -25,6 +25,8 @@ from typing import Any, TypeVar
|
||||
|
||||
from jax.tree_util import tree_flatten, tree_unflatten
|
||||
from jax._src import ad_util
|
||||
from jax._src.api_util import (
|
||||
_check_no_aliased_ref_args, _check_no_aliased_closed_over_refs)
|
||||
from jax._src import config
|
||||
from jax._src import core
|
||||
from jax._src import dispatch
|
||||
@ -136,8 +138,14 @@ def switch(index, branches: Sequence[Callable], *operands,
|
||||
ops, ops_tree = tree_flatten(operands)
|
||||
ops_avals = tuple(map(core.get_aval, ops))
|
||||
|
||||
if config.mutable_array_checks.value:
|
||||
dbg = pe.debug_info(branches[0], ops_tree, None, False, 'switch')
|
||||
_check_no_aliased_ref_args(dbg, ops_avals, ops)
|
||||
|
||||
jaxprs, consts, out_trees = _initial_style_jaxprs_with_common_consts(
|
||||
branches, ops_tree, ops_avals, primitive_name='switch')
|
||||
if config.mutable_array_checks.value:
|
||||
_check_no_aliased_closed_over_refs(dbg, (*jaxprs[0].consts, *consts), ops)
|
||||
for i, (out_tree, jaxpr) in enumerate(zip(out_trees[1:], jaxprs[1:])):
|
||||
_check_tree_and_avals(f"branch 0 and {i + 1} outputs",
|
||||
out_trees[0], jaxprs[0].out_avals,
|
||||
@ -228,11 +236,14 @@ def _cond(pred, true_fun: Callable, false_fun: Callable, *operands,
|
||||
ops, ops_tree = tree_flatten(operands)
|
||||
ops_avals = tuple(map(core.get_aval, ops))
|
||||
|
||||
if config.mutable_array_checks.value:
|
||||
dbg = pe.debug_info(true_fun, ops_tree, None, False, 'cond')
|
||||
_check_no_aliased_ref_args(dbg, ops_avals, ops)
|
||||
jaxprs, consts, out_trees = _initial_style_jaxprs_with_common_consts(
|
||||
(true_fun, false_fun), ops_tree, ops_avals, 'cond')
|
||||
if any(isinstance(op_aval, AbstractRef) for op_aval in ops_avals):
|
||||
raise ValueError("Cannot pass `Ref`s into `cond`.")
|
||||
true_jaxpr, false_jaxpr = jaxprs
|
||||
if config.mutable_array_checks.value:
|
||||
_check_no_aliased_closed_over_refs(dbg, (*true_jaxpr.consts, *consts), ops)
|
||||
|
||||
out_tree, false_out_tree = out_trees
|
||||
if any(isinstance(out_aval, AbstractRef) for out_aval in
|
||||
|
@ -306,30 +306,88 @@ class MutableArrayErrorsTest(jtu.JaxTestCase):
|
||||
ValueError, "traced for cond returned a mutable array reference of type"):
|
||||
jax.lax.cond(True, lambda: core.mutable_array(1.0), lambda: core.mutable_array(2.0))
|
||||
|
||||
# TODO test_argument_aliases_cond
|
||||
# TODO test_closure_and_argument_aliases_cond
|
||||
def test_argument_aliases_cond(self):
|
||||
x_ref = core.mutable_array(0.)
|
||||
with self.assertRaisesRegex( ValueError, r"for cond.*at both x1 and x2"):
|
||||
jax.lax.cond(True, lambda x1, x2: ..., lambda x1, x2: ..., x_ref, x_ref)
|
||||
|
||||
# TODO test_return_from_custom_jvp/vjp
|
||||
# TODO test_argument_aliases_custom_jvp/vjp
|
||||
# TODO test_closure_and_argument_aliases_custom_jvp/vjp
|
||||
def test_closure_and_argument_aliases_cond(self):
|
||||
x_ref = core.mutable_array(0.)
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r"closed over and passed as the argument y_ref"):
|
||||
jax.lax.cond(True,
|
||||
lambda y_ref: x_ref[...] + y_ref[...],
|
||||
lambda y_ref: x_ref[...] + y_ref[...],
|
||||
x_ref)
|
||||
|
||||
# TODO(mattjj): enable when cond works with mutable arrays
|
||||
# @parameterized.parameters([False, True])
|
||||
# def test_cond_both_branches_close_over_same_mutable_array(self, jit):
|
||||
# # see also test_cond_with_ref_reuse in state_test.py
|
||||
# x_ref = core.mutable_array(0.)
|
||||
# def f(pred):
|
||||
# def true_fun():
|
||||
# x_ref[()] = 1.
|
||||
# def false_fun():
|
||||
# x_ref[()] = 2.
|
||||
# jax.lax.cond(pred, true_fun, false_fun)
|
||||
# if jit:
|
||||
# f = jax.jit(f)
|
||||
# out_true = f(True)
|
||||
# self.assertAllClose(x_ref[...], 1.)
|
||||
# out_false = f(False)
|
||||
# self.assertAllClose(x_ref[...], 2.)
|
||||
@parameterized.parameters([False, True])
|
||||
def test_return_from_custom_vjp_primal(self, jit):
|
||||
@jax.custom_vjp
|
||||
def f(ref):
|
||||
return ref
|
||||
f.defvjp(lambda ref: ..., lambda *_: ...)
|
||||
if jit:
|
||||
f = jax.jit(f)
|
||||
x_ref = core.mutable_array(0.)
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "custom_vjp primal function"):
|
||||
f(x_ref)
|
||||
|
||||
@parameterized.parameters([False, True])
|
||||
def test_return_from_custom_vjp_fwd(self, jit):
|
||||
@jax.custom_vjp
|
||||
def f(x, ref):
|
||||
return x
|
||||
f.defvjp(lambda x, ref: (x, ref), lambda ref, g: g)
|
||||
if jit:
|
||||
f = jax.jit(f)
|
||||
x_ref = core.mutable_array(0.)
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "custom_vjp fwd function"):
|
||||
jax.vjp(f, 3., x_ref)
|
||||
|
||||
@parameterized.parameters([False, True])
|
||||
def test_argument_aliases_custom_vjp_primal(self, jit):
|
||||
@jax.custom_vjp
|
||||
def f(x_ref, y_ref):
|
||||
...
|
||||
f.defvjp(lambda x_ref, y_ref: (None, None), lambda _, g: (None, None))
|
||||
if jit:
|
||||
f = jax.jit(f)
|
||||
x_ref = core.mutable_array(0.)
|
||||
with self.assertRaisesRegex(ValueError, "x_ref and y_ref"):
|
||||
f(x_ref, x_ref)
|
||||
|
||||
@parameterized.parameters([False, True])
|
||||
def test_argument_aliases_custom_vjp_fwd(self, jit):
|
||||
@jax.custom_vjp
|
||||
def f(x_ref, y_ref):
|
||||
...
|
||||
f.defvjp(lambda x_ref, y_ref: (None, None), lambda _, g: (None, None))
|
||||
if jit:
|
||||
f = jax.jit(f)
|
||||
x_ref = core.mutable_array(0.)
|
||||
with self.assertRaisesRegex(ValueError, "x_ref and y_ref"):
|
||||
jax.vjp(f, x_ref, x_ref)
|
||||
|
||||
# TODO(mattjj): add test test_closure_and_argument_aliases_custom_vjp
|
||||
|
||||
@parameterized.parameters([False, True])
|
||||
def test_cond_both_branches_close_over_same_mutable_array(self, jit):
|
||||
# see also test_cond_with_ref_reuse in state_test.py
|
||||
x_ref = core.mutable_array(0.)
|
||||
def f(pred):
|
||||
def true_fun():
|
||||
x_ref[()] = 1.
|
||||
def false_fun():
|
||||
x_ref[()] = 2.
|
||||
jax.lax.cond(pred, true_fun, false_fun)
|
||||
if jit:
|
||||
f = jax.jit(f)
|
||||
out_true = f(True)
|
||||
self.assertAllClose(x_ref[...], 1.)
|
||||
out_false = f(False)
|
||||
self.assertAllClose(x_ref[...], 2.)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
Loading…
x
Reference in New Issue
Block a user