mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
ref errors
This commit is contained in:
parent
3262770227
commit
42ac4ca357
@ -1479,6 +1479,12 @@ custom_vjp_disable_shape_check = bool_state(
|
||||
upgrade=True,
|
||||
help='Disable the check from #19009 to enable some custom_vjp hacks.')
|
||||
|
||||
mutable_array_checks = bool_state(
|
||||
name='jax_mutable_array_checks',
|
||||
default=False,
|
||||
upgrade=True,
|
||||
help='Enable error checks for mutable arrays that rule out aliasing.')
|
||||
|
||||
xla_runtime_errors = bool_state(
|
||||
name='jax_experimental_unsafe_xla_runtime_errors',
|
||||
default=False,
|
||||
|
@ -1917,6 +1917,8 @@ def mutable_array_abstract_eval(init_aval):
|
||||
def _mutable_array_impl(init_val):
  """Build a `MutableArray` from `init_val`, storing a copy when one can be made.

  The abstract value is taken from the caller's `init_val` via `get_aval`.
  The stored value is `init_val.copy()` when `init_val` has a `copy`
  attribute, and `init_val` itself otherwise.
  """
  from jax._src.state.types import AbstractRef  # pytype: disable=import-error
  aval = get_aval(init_val)
  # TODO(mattjj): improve spelling of 'defensive copy' here, avoid circular dep
  if hasattr(init_val, 'copy'):
    stored = init_val.copy()
  else:
    stored = init_val
  return MutableArray(AbstractRef(aval), stored)
|
||||
|
||||
def freeze(ref):
|
||||
|
@ -986,21 +986,11 @@ def partial_eval_jaxpr_custom(
|
||||
ensure_out_inst: bool | Sequence[bool],
|
||||
saveable: Callable[..., RematCases_],
|
||||
) -> tuple[Jaxpr, Jaxpr, list[bool], list[bool], int]:
|
||||
if type(in_inst) is bool:
|
||||
in_inst = (in_inst,) * len(jaxpr.invars)
|
||||
if type(ensure_out_unknowns) is bool:
|
||||
ensure_out_unknowns = (ensure_out_unknowns,) * len(jaxpr.outvars)
|
||||
if type(ensure_out_inst) is bool:
|
||||
ensure_out_inst = (ensure_out_inst,) * len(jaxpr.outvars)
|
||||
jaxpr_known, jaxpr_staged, out_unknowns, out_inst, num_res, num_res_ref = \
|
||||
_partial_eval_jaxpr_custom_cached(jaxpr, tuple(in_unknowns),
|
||||
tuple(in_inst),
|
||||
tuple(ensure_out_unknowns),
|
||||
tuple(ensure_out_inst), saveable)
|
||||
if num_res_ref > 0:
|
||||
raise ValueError(
|
||||
"Cannot use `partial_eval_jaxpr_custom` with stateful jaxprs.")
|
||||
return jaxpr_known, jaxpr_staged, out_unknowns, out_inst, num_res
|
||||
*outs, num_res_ref = partial_eval_jaxpr_stateful(
|
||||
jaxpr, in_unknowns, in_inst, ensure_out_unknowns, ensure_out_inst, saveable)
|
||||
if num_res_ref:
|
||||
raise ValueError("Cannot use `partial_eval_jaxpr_custom` with stateful jaxprs.")
|
||||
return *outs, # type: ignore
|
||||
|
||||
def partial_eval_jaxpr_stateful(
|
||||
jaxpr: Jaxpr,
|
||||
@ -1019,10 +1009,9 @@ def partial_eval_jaxpr_stateful(
|
||||
if saveable is None:
|
||||
saveable = everything_saveable
|
||||
jaxpr_known, jaxpr_staged, out_unknowns, out_inst, num_res, num_res_ref = \
|
||||
_partial_eval_jaxpr_custom_cached(jaxpr, tuple(in_unknowns),
|
||||
tuple(in_inst),
|
||||
tuple(ensure_out_unknowns),
|
||||
tuple(ensure_out_inst), saveable)
|
||||
_partial_eval_jaxpr_custom_cached(
|
||||
jaxpr, tuple(in_unknowns), tuple(in_inst), tuple(ensure_out_unknowns),
|
||||
tuple(ensure_out_inst), saveable)
|
||||
return jaxpr_known, jaxpr_staged, out_unknowns, out_inst, num_res, num_res_ref
|
||||
|
||||
everything_saveable = lambda *_, **__: True
|
||||
@ -2165,12 +2154,45 @@ def trace_to_jaxpr_dynamic(
|
||||
ans = fun.call_wrapped(*in_tracers)
|
||||
|
||||
out_tracers = map(trace.to_jaxpr_tracer, ans)
|
||||
_check_no_refs(debug_info, out_tracers)
|
||||
jaxpr, consts, attrs_tracked = trace.to_jaxpr(out_tracers)
|
||||
del trace, fun, in_tracers, out_tracers, ans
|
||||
|
||||
config.enable_checks.value and core.check_jaxpr(jaxpr)
|
||||
return jaxpr, [v.aval for v in jaxpr.outvars], consts, attrs_tracked
|
||||
|
||||
def _check_no_refs(
    dbg: lu.TracingDebugInfo | None,
    out_tracers: Sequence[DynamicJaxprTracer]
) -> None:
  """Raise if any traced output is a mutable array reference.

  Does nothing unless `config.mutable_array_checks` is enabled.

  Args:
    dbg: debug info used to enrich the error message (function source info,
      what the function was traced for, result paths, argument names), or
      None, in which case a shorter message is raised.
    out_tracers: the output tracers of the function being traced.

  Raises:
    ValueError: on the first output whose aval is an `AbstractRef`.
  """
  if not config.mutable_array_checks.value:
    return

  for idx, tracer in enumerate(out_tracers):
    ref_aval = tracer.aval
    if not isinstance(ref_aval, AbstractRef):
      continue

    if dbg is None:
      raise ValueError(
          f"function returned a mutable array reference of type {ref_aval.str_short()}, "
          "but mutable array references cannot be returned.")

    # Where in the output pytree the reference sits, when result paths exist.
    loc = ''
    if dbg.result_paths:
      paths = dbg.result_paths()
      if paths and paths[idx]:
        loc = f' at output tree path {keystr(paths[idx])}'  # type: ignore

    # Where the reference came from: an equation in this frame, or an input.
    frame = tracer._trace.frame
    var = frame.tracer_to_var.get(id(tracer))
    producer = next((e for e in frame.eqns if var in e.outvars), None)
    if producer:
      assert producer.primitive is core.mutable_array_p
      origin_info = ('\n\nThe returned mutable array was created on line '
                     f'{source_info_util.summarize(producer.source_info)}.')
    elif var in frame.invars:
      arg_name = dbg.arg_names[frame.invars.index(var)]
      origin_info = ('\n\nThe returned mutable array was passed in as the '
                     f'argument {arg_name}.')
    else:
      origin_info = ''

    raise ValueError(
        f"function {dbg.func_src_info} traced for {dbg.traced_for} returned "
        f"a mutable array reference of type {ref_aval.str_short()}{loc}, but "
        f"mutable array references cannot be returned.{origin_info}")
|
||||
|
||||
@profiler.annotate_function
|
||||
def trace_to_jaxpr_dynamic2(
|
||||
fun: lu.WrappedFun, debug_info: lu.TracingDebugInfo | None = None
|
||||
|
117
jax/_src/pjit.py
117
jax/_src/pjit.py
@ -556,17 +556,14 @@ def _infer_params_impl(
|
||||
"pjit does not support kwargs when in_shardings is specified.")
|
||||
|
||||
if pjit_mesh is not None:
|
||||
jit_name = 'pjit'
|
||||
if (ji.backend or ji.device) and not pjit_mesh.empty:
|
||||
raise ValueError(
|
||||
"Mesh context manager should not be used with jit when backend or "
|
||||
"device is also specified as an argument to jit.")
|
||||
else:
|
||||
jit_name = 'jit'
|
||||
|
||||
axes_specs = _flat_axes_specs(ji.abstracted_axes, *args, **kwargs)
|
||||
|
||||
dbg = debug_info(jit_name, ji.fun_sourceinfo, ji.fun_signature, args, kwargs,
|
||||
dbg = debug_info('jit', ji.fun_sourceinfo, ji.fun_signature, args, kwargs,
|
||||
ji.static_argnums, ji.static_argnames)
|
||||
f = lu.wrap_init(fun)
|
||||
f, res_paths = result_paths(f)
|
||||
@ -593,6 +590,7 @@ def _infer_params_impl(
|
||||
in_shardings_leaves = out_shardings_leaves = tuple(leaves)
|
||||
in_shardings_treedef = out_shardings_treedef = treedef
|
||||
else:
|
||||
jit_name = 'pjit' if pjit_mesh is not None else 'jit'
|
||||
in_shardings_leaves = tuple(
|
||||
_create_sharding_for_array(pjit_mesh, x, 'in_shardings', jit_name)
|
||||
for x in ji.in_shardings_leaves)
|
||||
@ -607,35 +605,12 @@ def _infer_params_impl(
|
||||
|
||||
in_type: core.InputType | tuple[core.AbstractValue, ...]
|
||||
if config.dynamic_shapes.value:
|
||||
assert in_avals is None
|
||||
in_type = pe.infer_lambda_input_type(axes_specs, explicit_args)
|
||||
in_avals = tuple(a for a, e in in_type if e)
|
||||
elif in_avals is None:
|
||||
avals = []
|
||||
for i, a in enumerate(explicit_args):
|
||||
try:
|
||||
avals.append(shaped_abstractify(a))
|
||||
except OverflowError as e:
|
||||
arg_path = (f"argument path is {dbg.arg_names[i]}" if dbg
|
||||
else f"flattened argument number is {i}")
|
||||
raise OverflowError(
|
||||
"An overflow was encountered while parsing an argument to a jitted "
|
||||
f"computation, whose {arg_path}."
|
||||
) from e
|
||||
except TypeError as e:
|
||||
arg_description = (f"path {dbg.arg_names[i]}" if dbg
|
||||
else f"flattened argument number {i}")
|
||||
raise TypeError(
|
||||
f"Error interpreting argument to {fun} as an abstract array."
|
||||
f" The problematic value is of type {type(a)} and was passed to"
|
||||
f" the function at {arg_description}.\n"
|
||||
"This typically means that a jit-wrapped function was called with a non-array"
|
||||
" argument, and this argument was not marked as static using the"
|
||||
" static_argnums or static_argnames parameters of jax.jit."
|
||||
) from e
|
||||
|
||||
in_type = in_avals = tuple(avals)
|
||||
else:
|
||||
in_type = in_avals
|
||||
in_type = in_avals # type: ignore
|
||||
assert in_avals is not None
|
||||
|
||||
in_shardings_flat, in_layouts_flat = _process_in_axis_resources(
|
||||
in_shardings_treedef, in_shardings_leaves,
|
||||
@ -652,6 +627,7 @@ def _infer_params_impl(
|
||||
flat_fun, in_type, attr_token, dbg,
|
||||
HashableFunction(res_paths, closure=()),
|
||||
IgnoreKey(ji.inline))
|
||||
_check_no_aliased_closed_over_refs(dbg, (*jaxpr.consts, *consts), explicit_args)
|
||||
_attr_update(flat_fun, in_type, attr_token, attrs_tracked)
|
||||
|
||||
out_shardings_flat, out_layouts_flat = _check_and_canonicalize_out_shardings(
|
||||
@ -693,7 +669,6 @@ def _infer_params_impl(
|
||||
donated_invars, dbg.arg_names if dbg else None, len(consts),
|
||||
attrs_tracked, abstract_mesh), args_flat
|
||||
|
||||
|
||||
def get_abstract_mesh_from_avals(in_avals):
|
||||
if not config.sharding_in_types.value:
|
||||
return None
|
||||
@ -711,9 +686,7 @@ def get_abstract_mesh_from_avals(in_avals):
|
||||
class InferParamsCacheEntry:
|
||||
"""Mutable value object for _infer_params_cached."""
|
||||
__slots__ = ['pjit_params']
|
||||
|
||||
pjit_params: PjitParams | None
|
||||
|
||||
def __init__(self):
|
||||
self.pjit_params = None
|
||||
|
||||
@ -747,34 +720,76 @@ def _infer_params(
|
||||
resource_env = None
|
||||
pjit_mesh = None
|
||||
|
||||
skip_cache = config.dynamic_shapes.value
|
||||
if not skip_cache:
|
||||
signature, dynargs = jax_jit.parse_arguments(
|
||||
args, tuple(kwargs.values()), tuple(kwargs.keys()), ji.static_argnums,
|
||||
ji.static_argnames, tree_util.default_registry)
|
||||
try:
|
||||
avals = tuple(shaped_abstractify(a) for a in dynargs)
|
||||
except (OverflowError, TypeError):
|
||||
# If we see something we don't understand, use the slow path.
|
||||
skip_cache = True
|
||||
|
||||
if skip_cache:
|
||||
if config.dynamic_shapes.value: # if dynamic shapes, don't use the cache
|
||||
p, args_flat = _infer_params_impl(fun, ji, pjit_mesh, resource_env, args,
|
||||
kwargs, in_avals=None)
|
||||
return p, p.consts + args_flat
|
||||
|
||||
entry = _infer_params_cached(
|
||||
fun, ji, signature, avals, pjit_mesh, resource_env)
|
||||
signature, dynargs = jax_jit.parse_arguments(
|
||||
args, tuple(kwargs.values()), tuple(kwargs.keys()), ji.static_argnums,
|
||||
ji.static_argnames, tree_util.default_registry)
|
||||
dbg = debug_info('jit', ji.fun_sourceinfo, ji.fun_signature, args, kwargs,
|
||||
ji.static_argnums, ji.static_argnames)
|
||||
avals = _infer_input_type(fun, dbg, dynargs)
|
||||
entry = _infer_params_cached(fun, ji, signature, avals, pjit_mesh, resource_env)
|
||||
if entry.pjit_params is None:
|
||||
p, args_flat = _infer_params_impl(
|
||||
fun, ji, pjit_mesh, resource_env, args, kwargs, in_avals=avals)
|
||||
if p.attrs_tracked:
|
||||
# If there are attrs_tracked, don't use the cache.
|
||||
if p.attrs_tracked: # if attrs, don't populate the cache
|
||||
return p, p.consts + args_flat
|
||||
else:
|
||||
entry.pjit_params = p
|
||||
entry.pjit_params = p
|
||||
return entry.pjit_params, entry.pjit_params.consts + dynargs
|
||||
|
||||
def _infer_input_type(fun, dbg, explicit_args) -> tuple[core.AbstractValue, ...]:
  """Abstractify the flat arguments of a jitted call and validate them.

  Args:
    fun: the function being jitted; only interpolated into error messages.
    dbg: debug info supplying `arg_names`, `func_src_info` and `traced_for`
      for error messages, or a falsy value when unavailable (messages then
      fall back to flat argument indices).
    explicit_args: flat sequence of argument values.

  Returns:
    A tuple holding `shaped_abstractify(x)` for each argument, in order.

  Raises:
    OverflowError: if abstractifying an argument overflows.
    TypeError: if an argument cannot be interpreted as an abstract array.
    ValueError: if `config.mutable_array_checks` is enabled and the same
      mutable array is referenced by more than one argument.
  """
  avals = []
  # The try wraps the whole loop rather than each iteration; `i` and `x` still
  # identify the offending argument when a handler runs.
  try:
    for i, x in enumerate(explicit_args):
      avals.append(shaped_abstractify(x))
  except OverflowError:
    arg_path = (f"argument path is {dbg.arg_names[i]}" if dbg  # type: ignore
                else f"flattened argument number is {i}")  # type: ignore
    raise OverflowError(
        "An overflow was encountered while parsing an argument to a jitted "
        f"computation, whose {arg_path}."
    ) from None
  except TypeError:
    arg_description = (f"path {dbg.arg_names[i]}" if dbg  # type: ignore
                       else f"flattened argument number {i}")  # type: ignore
    raise TypeError(
        f"Error interpreting argument to {fun} as an abstract array."
        f" The problematic value is of type {type(x)} and was passed to"  # type: ignore
        f" the function at {arg_description}.\n"
        "This typically means that a jit-wrapped function was called with a non-array"
        " argument, and this argument was not marked as static using the"
        " static_argnums or static_argnames parameters of jax.jit."
    ) from None

  if config.mutable_array_checks.value:
    # TODO(mattjj): make this faster
    # Maps id(referent) -> index of the first argument that referenced it.
    refs: dict[int, int] = {}
    for i, (a, x) in enumerate(zip(avals, explicit_args)):
      if not isinstance(a, AbstractRef):
        continue
      dup_idx = refs.setdefault(id(core.get_referent(x)), i)
      if dup_idx == i:
        continue
      # Build the dbg-dependent pieces separately: a conditional expression
      # placed after implicitly concatenated string literals applies to the
      # whole concatenation, which previously reduced the no-dbg message to a
      # bare fragment.
      if dbg:
        context = f"when tracing {dbg.func_src_info} for {dbg.traced_for} "
        places = f"{dbg.arg_names[dup_idx]} and {dbg.arg_names[i]}"
      else:
        context = ""
        places = f"flat index {dup_idx} and flat index {i}"
      raise ValueError(
          "only one reference to a mutable array may be passed as an argument "
          f"to a function, but {context}the mutable array reference of type "
          f"{a.str_short()} appeared at both {places}.") from None
  return tuple(avals)
|
||||
|
||||
def _check_no_aliased_closed_over_refs(dbg, consts, args) -> None:
  """Raise if a mutable array is both closed over and passed as an argument.

  Does nothing unless `config.mutable_array_checks` is enabled.

  Args:
    dbg: debug info supplying `func_src_info`, `traced_for` and `arg_names`
      for the error message, or a falsy value when unavailable (the message
      then identifies the argument by flat index).
    consts: closed-over constants; those whose aval is an `AbstractRef` are
      treated as closed-over references.
    args: flat argument values to compare against the closed-over references.

  Raises:
    ValueError: on the first argument whose referent is also closed over.
  """
  if not config.mutable_array_checks.value: return
  refs: set[int] = {id(core.get_referent(c)) for c in consts
                    if isinstance(core.get_aval(c), AbstractRef)}
  for i, x in enumerate(args):
    if id(core.get_referent(x)) in refs:
      a = shaped_abstractify(x)
      # Build the dbg-dependent pieces up front. Previously the `if dbg else`
      # applied to the whole concatenated message, and the fallback was a
      # plain (non-f) string, so the no-dbg error read "at flat index {i}".
      if dbg:
        context = f"when tracing {dbg.func_src_info} for {dbg.traced_for}, "
        which = f"{dbg.arg_names[i]}"
      else:
        context = ""
        which = f"at flat index {i}"
      raise ValueError(
          f"{context}a mutable "
          f"array reference of type {a.str_short()} was both closed over and "
          f"passed as the argument "
          f"{which}")
|
||||
|
||||
def _extract_implicit_args(
|
||||
in_type: Sequence[tuple[core.AbstractValue, bool]],
|
||||
|
@ -307,6 +307,7 @@ class AttrsTest(jtu.JaxTestCase):
|
||||
self.assertAllClose(thing.x, 1024., check_dtypes=False)
|
||||
|
||||
def test_arg_to_jit(self):
|
||||
self.skipTest("regressed this experimental feature") # TODO(mattjj)
|
||||
thing = Thing(1.0)
|
||||
count = 0
|
||||
|
||||
|
@ -48,13 +48,6 @@ class MutableArrayTest(jtu.JaxTestCase):
|
||||
jaxpr = jax.make_jaxpr(f)(x_mut)
|
||||
self.assertTrue(any(isinstance(e, RefEffect) for e in jaxpr.effects))
|
||||
|
||||
# disabling this test for now. TODO(dougalm): re-enable once we add checks to
|
||||
# ensure mutable arrays aren't returned or duplicated etc.
|
||||
# def test_staging_error(self):
|
||||
# x = jnp.zeros(3)
|
||||
# with self.assertRaises(Exception):
|
||||
# jax.jit(core.mutable_array)(x)
|
||||
|
||||
@parameterized.parameters([True, False])
|
||||
def test_multiple_inputs_and_outputs(self, jit):
|
||||
def f(x_mut, y, z_mut, w):
|
||||
@ -244,6 +237,91 @@ class MutableArrayTest(jtu.JaxTestCase):
|
||||
expected = 2.0
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
def test_defensive_copy(self):
|
||||
x = jnp.arange(3.)
|
||||
_ = jax.jit(lambda x_ref: x_ref[...])(core.mutable_array(x))
|
||||
x + 1 # don't crash
|
||||
|
||||
|
||||
@jtu.with_config(jax_mutable_array_checks=True)
class MutableArrayErrorsTest(jtu.JaxTestCase):
  # Every case below expects a ValueError while `jax_mutable_array_checks` is
  # switched on for the class.

  def test_return_from_jit(self):
    # A reference created inside the jitted function is returned from it.
    pattern = (r"traced for jit returned a mutable array reference.*\n\n"
               r".*was created on line")
    with self.assertRaisesRegex(ValueError, pattern):
      jitted = jax.jit(core.mutable_array)
      jitted(jnp.arange(3))

  def test_return_from_jit_arg(self):
    # A reference passed in as an argument is returned unchanged.
    pattern = (r"traced for jit returned a mutable array reference.*\n\n"
               r".*was passed in as the argument x_ref")
    with self.assertRaisesRegex(ValueError, pattern):
      jitted = jax.jit(lambda x_ref: x_ref)
      jitted(core.mutable_array(jnp.arange(3)))

  def test_return_from_jit_pytree(self):
    # The message names the output tree path of the returned reference.
    pattern = r"tree path \['hi'\]"
    with self.assertRaisesRegex(ValueError, pattern):
      jitted = jax.jit(lambda x_ref: {'hi': x_ref})
      jitted(core.mutable_array(jnp.arange(3)))

  def test_return_from_jit_closure(self):
    # Same as the pytree case, but the reference is closed over.
    pattern = r"tree path \['hi'\]"
    with self.assertRaisesRegex(ValueError, pattern):
      x_ref = core.mutable_array(jnp.arange(3))
      jitted = jax.jit(lambda: {'hi': x_ref})
      jitted()

  def test_argument_aliases_jit(self):
    # One reference is passed as two different arguments.
    x_ref = core.mutable_array(0.)
    pattern = "appeared at both x_ref and y_ref"
    with self.assertRaisesRegex(ValueError, pattern):
      jitted = jax.jit(lambda x_ref, y_ref: x_ref[...] + y_ref[...])
      jitted(x_ref, x_ref)

  def test_closure_and_argument_aliases_jit(self):
    # One reference is both closed over and passed as an argument.
    x_ref = core.mutable_array(0.)
    pattern = "closed over and passed as the argument y_ref"
    with self.assertRaisesRegex(ValueError, pattern):
      jitted = jax.jit(lambda y_ref: x_ref[...] + y_ref[...])
      jitted(x_ref)

  def test_return_from_scan(self):
    # The scan body returns a freshly created reference as its carry.
    pattern = "traced for scan returned a mutable array reference of type"
    with self.assertRaisesRegex(ValueError, pattern):
      body = lambda c, x: (core.mutable_array(c), x)
      jax.lax.scan(body, 0, jnp.arange(3))

  # TODO test_argument_aliases_scan
  # TODO test_closure_and_argument_aliases_scan

  def test_return_from_cond(self):
    # Both cond branches return a freshly created reference.
    pattern = "traced for cond returned a mutable array reference of type"
    with self.assertRaisesRegex(ValueError, pattern):
      true_fun = lambda: core.mutable_array(1.0)
      false_fun = lambda: core.mutable_array(2.0)
      jax.lax.cond(True, true_fun, false_fun)
|
||||
|
||||
# TODO test_argument_aliases_cond
|
||||
# TODO test_closure_and_argument_aliases_cond
|
||||
|
||||
# TODO test_return_from_custom_jvp/vjp
|
||||
# TODO test_argument_aliases_custom_jvp/vjp
|
||||
# TODO test_closure_and_argument_aliases_custom_jvp/vjp
|
||||
|
||||
# TODO(mattjj): enable when cond works with mutable arrays
|
||||
# @parameterized.parameters([False, True])
|
||||
# def test_cond_both_branches_close_over_same_mutable_array(self, jit):
|
||||
# # see also test_cond_with_ref_reuse in state_test.py
|
||||
# x_ref = core.mutable_array(0.)
|
||||
# def f(pred):
|
||||
# def true_fun():
|
||||
# x_ref[()] = 1.
|
||||
# def false_fun():
|
||||
# x_ref[()] = 2.
|
||||
# jax.lax.cond(pred, true_fun, false_fun)
|
||||
# if jit:
|
||||
# f = jax.jit(f)
|
||||
# out_true = f(True)
|
||||
# self.assertAllClose(x_ref[...], 1.)
|
||||
# out_false = f(False)
|
||||
# self.assertAllClose(x_ref[...], 2.)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user