mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
add mutable array ref error checks to scan
This commit is contained in:
parent
74eca1346d
commit
e52856261f
@ -344,6 +344,7 @@ pytype_strict_library(
|
||||
":traceback_util",
|
||||
":tree_util",
|
||||
":util",
|
||||
":state_types",
|
||||
] + py_deps("numpy"),
|
||||
)
|
||||
|
||||
|
@ -23,7 +23,9 @@ from typing import Any
|
||||
import numpy as np
|
||||
|
||||
from jax._src import core
|
||||
from jax._src import config
|
||||
from jax._src import dtypes
|
||||
from jax._src.state.types import AbstractRef
|
||||
from jax._src.abstract_arrays import numpy_scalar_types
|
||||
from jax._src.core import ShapedArray
|
||||
from jax._src.tree_util import (
|
||||
@ -737,3 +739,31 @@ class _HashableByObjectId:
|
||||
def register_class_with_attrs(t: type) -> None:
|
||||
_class_with_attrs.add(t)
|
||||
_class_with_attrs: set[type] = set()
|
||||
|
||||
# TODO(mattjj): make this function faster
|
||||
def _check_no_aliased_ref_args(dbg, avals, args):
|
||||
assert config.mutable_array_checks.value
|
||||
refs: dict[int, int] = {}
|
||||
for i, (a, x) in enumerate(zip(avals, args)):
|
||||
if (isinstance(a, AbstractRef) and
|
||||
(dup_idx := refs.setdefault(id(core.get_referent(x)), i)) != i):
|
||||
raise ValueError(
|
||||
"only one reference to a mutable array may be passed as an argument "
|
||||
f"to a function, but when tracing {dbg.func_src_info} for {dbg.traced_for} "
|
||||
f"the mutable array reference of type {a.str_short()} appeared at both "
|
||||
f"{dbg.arg_names[dup_idx]} and {dbg.arg_names[i]}."
|
||||
if dbg else
|
||||
f"at both flat index {dup_idx} and flat index {i}") from None
|
||||
|
||||
def _check_no_aliased_closed_over_refs(dbg, consts, args) -> None:
|
||||
assert config.mutable_array_checks.value
|
||||
refs: set[int] = {id(core.get_referent(c)) for c in consts
|
||||
if isinstance(core.get_aval(c), AbstractRef)}
|
||||
for i, x in enumerate(args):
|
||||
if id(core.get_referent(x)) in refs:
|
||||
a = shaped_abstractify(x)
|
||||
raise ValueError(
|
||||
f"when tracing {dbg.func_src_info} for {dbg.traced_for}, a mutable "
|
||||
f"array reference of type {a.str_short()} was both closed over and "
|
||||
f"passed as the argument "
|
||||
f"{dbg.arg_names[i]}" if dbg else "at flat index {i}")
|
||||
|
@ -34,7 +34,9 @@ from jax._src import linear_util as lu
|
||||
from jax._src import source_info_util
|
||||
from jax._src import state
|
||||
from jax._src import util
|
||||
from jax._src.api_util import shaped_abstractify
|
||||
from jax._src.api_util import (
|
||||
shaped_abstractify, _check_no_aliased_ref_args,
|
||||
_check_no_aliased_closed_over_refs)
|
||||
from jax._src.core import ShapedArray
|
||||
from jax._src.interpreters import ad
|
||||
from jax._src.interpreters import batching
|
||||
@ -271,13 +273,20 @@ def scan(f: Callable[[Carry, X], tuple[Carry, Y]],
|
||||
xs_avals = [core.get_aval(x) for x in xs_flat]
|
||||
x_avals = [core.mapped_aval(length, 0, aval) for aval in xs_avals]
|
||||
|
||||
if config.mutable_array_checks.value:
|
||||
in_flat, in_tree = tree_flatten((init, xs))
|
||||
dbg = pe.debug_info(f, in_tree, None, False, 'scan')
|
||||
in_avals = tuple(_map(core.get_aval, in_flat))
|
||||
_check_no_aliased_ref_args(dbg, in_avals, in_flat)
|
||||
|
||||
def _create_jaxpr(init):
|
||||
init_flat, init_tree = tree_flatten(init)
|
||||
in_flat, in_tree = tree_flatten((init, xs))
|
||||
|
||||
carry_avals = tuple(_map(core.get_aval, init_flat))
|
||||
jaxpr, consts, out_tree, attrs_tracked = _initial_style_jaxpr_attrs(
|
||||
f, in_tree, (*carry_avals, *x_avals), "scan")
|
||||
if config.mutable_array_checks.value:
|
||||
_check_no_aliased_closed_over_refs(dbg, (*jaxpr.consts, *consts), in_flat)
|
||||
out_tree_children = out_tree.children()
|
||||
if len(out_tree_children) != 2:
|
||||
msg = "scan body output must be a pair, got {}."
|
||||
|
@ -49,7 +49,8 @@ from jax._src.api_util import (
|
||||
argnums_partial_except, flatten_axes, flatten_fun, flatten_fun_nokwargs,
|
||||
donation_vector, shaped_abstractify, check_callable, resolve_argnums,
|
||||
argnames_partial_except, debug_info, result_paths, jaxpr_debug_info,
|
||||
hoist_obj_attrs)
|
||||
hoist_obj_attrs, _check_no_aliased_ref_args,
|
||||
_check_no_aliased_closed_over_refs)
|
||||
from jax._src.interpreters import partial_eval as pe
|
||||
from jax._src.partition_spec import PartitionSpec
|
||||
from jax._src.interpreters import xla
|
||||
@ -627,7 +628,8 @@ def _infer_params_impl(
|
||||
flat_fun, in_type, attr_token, dbg,
|
||||
HashableFunction(res_paths, closure=()),
|
||||
IgnoreKey(ji.inline))
|
||||
_check_no_aliased_closed_over_refs(dbg, (*jaxpr.consts, *consts), explicit_args)
|
||||
if config.mutable_array_checks.value:
|
||||
_check_no_aliased_closed_over_refs(dbg, (*jaxpr.consts, *consts), explicit_args)
|
||||
_attr_update(flat_fun, in_type, attr_token, attrs_tracked)
|
||||
|
||||
out_shardings_flat, out_layouts_flat = _check_and_canonicalize_out_shardings(
|
||||
@ -764,33 +766,9 @@ def _infer_input_type(fun, dbg, explicit_args) -> tuple[core.AbstractValue, ...]
|
||||
" static_argnums or static_argnames parameters of jax.jit."
|
||||
) from None
|
||||
if config.mutable_array_checks.value:
|
||||
# TODO(mattjj): make this faster
|
||||
refs: dict[int, int] = {}
|
||||
for i, (a, x) in enumerate(zip(avals, explicit_args)):
|
||||
if (isinstance(a, AbstractRef) and
|
||||
(dup_idx := refs.setdefault(id(core.get_referent(x)), i)) != i):
|
||||
raise ValueError(
|
||||
"only one reference to a mutable array may be passed as an argument "
|
||||
f"to a function, but when tracing {dbg.func_src_info} for {dbg.traced_for} "
|
||||
f"the mutable array reference of type {a.str_short()} appeared at both "
|
||||
f"{dbg.arg_names[dup_idx]} and {dbg.arg_names[i]}."
|
||||
if dbg else
|
||||
f"at both flat index {dup_idx} and flat index {i}") from None
|
||||
_check_no_aliased_ref_args(dbg, avals, explicit_args)
|
||||
return tuple(avals)
|
||||
|
||||
def _check_no_aliased_closed_over_refs(dbg, consts, args) -> None:
|
||||
if not config.mutable_array_checks.value: return
|
||||
refs: set[int] = {id(core.get_referent(c)) for c in consts
|
||||
if isinstance(core.get_aval(c), AbstractRef)}
|
||||
for i, x in enumerate(args):
|
||||
if id(core.get_referent(x)) in refs:
|
||||
a = shaped_abstractify(x)
|
||||
raise ValueError(
|
||||
f"when tracing {dbg.func_src_info} for {dbg.traced_for}, a mutable "
|
||||
f"array reference of type {a.str_short()} was both closed over and "
|
||||
f"passed as the argument "
|
||||
f"{dbg.arg_names[i]}" if dbg else "at flat index {i}")
|
||||
|
||||
def _extract_implicit_args(
|
||||
in_type: Sequence[tuple[core.AbstractValue, bool]],
|
||||
explicit_args: Sequence[Any]
|
||||
|
@ -206,7 +206,6 @@ class MutableArrayTest(jtu.JaxTestCase):
|
||||
def body_fun(_, index_x):
|
||||
(index, x) = index_x
|
||||
x[...] += index
|
||||
# breakpoint()
|
||||
return ((), x[...])
|
||||
|
||||
x_mut = core.mutable_array(np.arange(5))
|
||||
@ -289,8 +288,18 @@ class MutableArrayErrorsTest(jtu.JaxTestCase):
|
||||
ValueError, "traced for scan returned a mutable array reference of type"):
|
||||
jax.lax.scan(lambda c, x: (core.mutable_array(c), x), 0, jnp.arange(3))
|
||||
|
||||
# TODO test_argument_aliases_scan
|
||||
# TODO test_closure_and_argument_aliases_scan
|
||||
def test_argument_aliases_scan(self):
|
||||
x_ref = core.mutable_array(0.)
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r"appeared at both c\[0\] and c\[1\]"):
|
||||
jax.lax.scan(lambda c, _: (None, None), (x_ref, x_ref), None, length=1)
|
||||
|
||||
def test_closure_and_argument_aliases_scan(self):
|
||||
x_ref = core.mutable_array(0.)
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r"closed over and passed as the argument y_ref"):
|
||||
jax.lax.scan(lambda y_ref, _: (x_ref[...] + y_ref[...], None), x_ref,
|
||||
None, length=1)
|
||||
|
||||
def test_return_from_cond(self):
|
||||
with self.assertRaisesRegex(
|
||||
|
Loading…
x
Reference in New Issue
Block a user