add mutable array ref error checks to scan

This commit is contained in:
Matthew Johnson 2024-12-18 22:11:25 +00:00
parent 74eca1346d
commit e52856261f
5 changed files with 59 additions and 32 deletions

View File

@ -344,6 +344,7 @@ pytype_strict_library(
":traceback_util",
":tree_util",
":util",
":state_types",
] + py_deps("numpy"),
)

View File

@ -23,7 +23,9 @@ from typing import Any
import numpy as np
from jax._src import core
from jax._src import config
from jax._src import dtypes
from jax._src.state.types import AbstractRef
from jax._src.abstract_arrays import numpy_scalar_types
from jax._src.core import ShapedArray
from jax._src.tree_util import (
@ -737,3 +739,31 @@ class _HashableByObjectId:
def register_class_with_attrs(t: type) -> None:
_class_with_attrs.add(t)
_class_with_attrs: set[type] = set()
# TODO(mattjj): make this function faster
def _check_no_aliased_ref_args(dbg, avals, args):
assert config.mutable_array_checks.value
refs: dict[int, int] = {}
for i, (a, x) in enumerate(zip(avals, args)):
if (isinstance(a, AbstractRef) and
(dup_idx := refs.setdefault(id(core.get_referent(x)), i)) != i):
raise ValueError(
"only one reference to a mutable array may be passed as an argument "
f"to a function, but when tracing {dbg.func_src_info} for {dbg.traced_for} "
f"the mutable array reference of type {a.str_short()} appeared at both "
f"{dbg.arg_names[dup_idx]} and {dbg.arg_names[i]}."
if dbg else
f"at both flat index {dup_idx} and flat index {i}") from None
def _check_no_aliased_closed_over_refs(dbg, consts, args) -> None:
assert config.mutable_array_checks.value
refs: set[int] = {id(core.get_referent(c)) for c in consts
if isinstance(core.get_aval(c), AbstractRef)}
for i, x in enumerate(args):
if id(core.get_referent(x)) in refs:
a = shaped_abstractify(x)
raise ValueError(
f"when tracing {dbg.func_src_info} for {dbg.traced_for}, a mutable "
f"array reference of type {a.str_short()} was both closed over and "
f"passed as the argument "
f"{dbg.arg_names[i]}" if dbg else "at flat index {i}")

View File

@ -34,7 +34,9 @@ from jax._src import linear_util as lu
from jax._src import source_info_util
from jax._src import state
from jax._src import util
from jax._src.api_util import shaped_abstractify
from jax._src.api_util import (
shaped_abstractify, _check_no_aliased_ref_args,
_check_no_aliased_closed_over_refs)
from jax._src.core import ShapedArray
from jax._src.interpreters import ad
from jax._src.interpreters import batching
@ -271,13 +273,20 @@ def scan(f: Callable[[Carry, X], tuple[Carry, Y]],
xs_avals = [core.get_aval(x) for x in xs_flat]
x_avals = [core.mapped_aval(length, 0, aval) for aval in xs_avals]
if config.mutable_array_checks.value:
in_flat, in_tree = tree_flatten((init, xs))
dbg = pe.debug_info(f, in_tree, None, False, 'scan')
in_avals = tuple(_map(core.get_aval, in_flat))
_check_no_aliased_ref_args(dbg, in_avals, in_flat)
def _create_jaxpr(init):
init_flat, init_tree = tree_flatten(init)
in_flat, in_tree = tree_flatten((init, xs))
carry_avals = tuple(_map(core.get_aval, init_flat))
jaxpr, consts, out_tree, attrs_tracked = _initial_style_jaxpr_attrs(
f, in_tree, (*carry_avals, *x_avals), "scan")
if config.mutable_array_checks.value:
_check_no_aliased_closed_over_refs(dbg, (*jaxpr.consts, *consts), in_flat)
out_tree_children = out_tree.children()
if len(out_tree_children) != 2:
msg = "scan body output must be a pair, got {}."

View File

@ -49,7 +49,8 @@ from jax._src.api_util import (
argnums_partial_except, flatten_axes, flatten_fun, flatten_fun_nokwargs,
donation_vector, shaped_abstractify, check_callable, resolve_argnums,
argnames_partial_except, debug_info, result_paths, jaxpr_debug_info,
hoist_obj_attrs)
hoist_obj_attrs, _check_no_aliased_ref_args,
_check_no_aliased_closed_over_refs)
from jax._src.interpreters import partial_eval as pe
from jax._src.partition_spec import PartitionSpec
from jax._src.interpreters import xla
@ -627,7 +628,8 @@ def _infer_params_impl(
flat_fun, in_type, attr_token, dbg,
HashableFunction(res_paths, closure=()),
IgnoreKey(ji.inline))
_check_no_aliased_closed_over_refs(dbg, (*jaxpr.consts, *consts), explicit_args)
if config.mutable_array_checks.value:
_check_no_aliased_closed_over_refs(dbg, (*jaxpr.consts, *consts), explicit_args)
_attr_update(flat_fun, in_type, attr_token, attrs_tracked)
out_shardings_flat, out_layouts_flat = _check_and_canonicalize_out_shardings(
@ -764,33 +766,9 @@ def _infer_input_type(fun, dbg, explicit_args) -> tuple[core.AbstractValue, ...]
" static_argnums or static_argnames parameters of jax.jit."
) from None
if config.mutable_array_checks.value:
# TODO(mattjj): make this faster
refs: dict[int, int] = {}
for i, (a, x) in enumerate(zip(avals, explicit_args)):
if (isinstance(a, AbstractRef) and
(dup_idx := refs.setdefault(id(core.get_referent(x)), i)) != i):
raise ValueError(
"only one reference to a mutable array may be passed as an argument "
f"to a function, but when tracing {dbg.func_src_info} for {dbg.traced_for} "
f"the mutable array reference of type {a.str_short()} appeared at both "
f"{dbg.arg_names[dup_idx]} and {dbg.arg_names[i]}."
if dbg else
f"at both flat index {dup_idx} and flat index {i}") from None
_check_no_aliased_ref_args(dbg, avals, explicit_args)
return tuple(avals)
def _check_no_aliased_closed_over_refs(dbg, consts, args) -> None:
if not config.mutable_array_checks.value: return
refs: set[int] = {id(core.get_referent(c)) for c in consts
if isinstance(core.get_aval(c), AbstractRef)}
for i, x in enumerate(args):
if id(core.get_referent(x)) in refs:
a = shaped_abstractify(x)
raise ValueError(
f"when tracing {dbg.func_src_info} for {dbg.traced_for}, a mutable "
f"array reference of type {a.str_short()} was both closed over and "
f"passed as the argument "
f"{dbg.arg_names[i]}" if dbg else "at flat index {i}")
def _extract_implicit_args(
in_type: Sequence[tuple[core.AbstractValue, bool]],
explicit_args: Sequence[Any]

View File

@ -206,7 +206,6 @@ class MutableArrayTest(jtu.JaxTestCase):
def body_fun(_, index_x):
(index, x) = index_x
x[...] += index
# breakpoint()
return ((), x[...])
x_mut = core.mutable_array(np.arange(5))
@ -289,8 +288,18 @@ class MutableArrayErrorsTest(jtu.JaxTestCase):
ValueError, "traced for scan returned a mutable array reference of type"):
jax.lax.scan(lambda c, x: (core.mutable_array(c), x), 0, jnp.arange(3))
# TODO test_argument_aliases_scan
# TODO test_closure_and_argument_aliases_scan
def test_argument_aliases_scan(self):
x_ref = core.mutable_array(0.)
with self.assertRaisesRegex(
ValueError, r"appeared at both c\[0\] and c\[1\]"):
jax.lax.scan(lambda c, _: (None, None), (x_ref, x_ref), None, length=1)
def test_closure_and_argument_aliases_scan(self):
x_ref = core.mutable_array(0.)
with self.assertRaisesRegex(
ValueError, r"closed over and passed as the argument y_ref"):
jax.lax.scan(lambda y_ref, _: (x_ref[...] + y_ref[...], None), x_ref,
None, length=1)
def test_return_from_cond(self):
with self.assertRaisesRegex(