Typecheck avals and sharding for arguments that were DCE'd.

This keeps the promise of AOT that no recompilation will ever happen: arguments whose avals or shardings differ from those the computation was compiled for now raise an error, even when those arguments were DCE'd.

Fixes https://github.com/google/jax/issues/18686

PiperOrigin-RevId: 585855658
Yash Katariya 2023-11-27 22:38:46 -08:00 committed by jax authors
parent 37f11428a3
commit 88d980f164
3 changed files with 93 additions and 13 deletions
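
A minimal sketch of the behavior this change enforces (illustrative code, not part of the patch; the function and shapes are made up for the example):

    import jax
    import jax.numpy as jnp

    @jax.jit
    def f(x, y):
      # y only affects this trace-time branch, so the compiled
      # computation dead-code-eliminates y entirely.
      return x + 1 if y.shape[0] > 2 else x + 2

    x, y1, y2 = jnp.ones(4), jnp.ones(4), jnp.ones(1)

    g = f.lower(x, y1).compile()  # AOT: lower and compile exactly once
    g(x, y1)  # OK: avals match those the computation was compiled with

    # Before this change, y2's mismatched aval slipped through because y
    # was DCE'd, and g silently returned a result for the wrong signature;
    # with the change, the call fails instead of ever recompiling.
    g(x, y2)  # raises TypeError: Argument types differ from the types ...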


@@ -1875,6 +1875,14 @@ def are_all_shardings_default_mem_kind(da_object, shardings):
 MaybeLayout = Sequence[Optional[Union[XLACompatibleLayout, LayoutRequest]]]

+class AllArgsInfo(NamedTuple):
+  """Avals, shardings, layouts and debug_info for all arguments prior to DCE."""
+  in_avals: Sequence[core.ShapedArray]
+  in_shardings: Any
+  debug_info: core.JaxprDebugInfo | None
+
 @profiler.annotate_function
 def lower_sharding_computation(
     closed_jaxpr: core.ClosedJaxpr,
@@ -1904,6 +1912,9 @@ def lower_sharding_computation(
       check_if_any_auto(in_shardings) if is_unspecified(out_shardings) else
       check_if_any_auto(it.chain.from_iterable([in_shardings, out_shardings])))  # type: ignore

+  all_args_info = AllArgsInfo(global_in_avals, in_shardings,
+                              closed_jaxpr.jaxpr.debug_info)
+
   (closed_jaxpr, global_in_avals, global_out_avals, donated_invars,
    kept_var_idx, name_stack) = _dce_jaxpr(
       closed_jaxpr, global_in_avals, api_name, fun_name, keep_unused,
@@ -2004,7 +2015,8 @@ def lower_sharding_computation(
       pmap_nreps=nreps,
       jaxpr_debug_info=closed_jaxpr.jaxpr.debug_info,
       shape_poly_state=shape_poly_state,
-      all_default_mem_kind=all_default_mem_kind)
+      all_default_mem_kind=all_default_mem_kind,
+      all_args_info=all_args_info)


 def _to_logical_sharding(
@@ -2090,6 +2102,8 @@ def lower_mesh_computation(
   out_jaxpr_avals = fun_or_jaxpr.out_avals
   consts = fun_or_jaxpr.consts

+  all_args_info = AllArgsInfo(global_in_avals, in_shardings, jaxpr.debug_info)
+
   assert len(out_shardings) == len(out_jaxpr_avals)
   if spmd_lowering:
     global_out_avals = out_jaxpr_avals
@@ -2179,7 +2193,8 @@ def lower_mesh_computation(
       in_layouts=(None,) * len(global_in_avals),
       out_layouts=(None,) * len(global_out_avals),
       jaxpr_debug_info=closed_jaxpr.jaxpr.debug_info,
-      shape_poly_state=lowering_result.shape_poly_state)
+      shape_poly_state=lowering_result.shape_poly_state,
+      all_args_info=all_args_info)


 class MeshComputation(stages.XlaLowering):
   _hlo: ir.Module | None
@@ -2568,6 +2583,7 @@ class UnloadedMeshExecutable:
   jaxpr_debug_info: core.JaxprDebugInfo | None
   in_layouts: Sequence[SpecifiedLayout | None]
   out_layouts: Sequence[SpecifiedLayout | None]
+  all_args_info: AllArgsInfo | None

   def build_unsafe_call(self):
     input_indices = _get_input_indices(self.input_avals, self.input_shardings,
@@ -2590,7 +2606,7 @@ class UnloadedMeshExecutable:
                           self.input_shardings, self.output_shardings,
                           self.auto_spmd_lowering, self.kept_var_idx,
                           self.in_layouts, self.out_layouts,
-                          self.jaxpr_debug_info, self)
+                          self.jaxpr_debug_info, self.all_args_info, self)

   # May return a MeshExecutable in the compile_replicated case.
   @staticmethod
@@ -2618,6 +2634,7 @@ class UnloadedMeshExecutable:
       jaxpr_debug_info: core.JaxprDebugInfo | None = None,
       shape_poly_state: mlir.ShapePolyLoweringState | None = None,
       all_default_mem_kind: bool = True,
+      all_args_info: AllArgsInfo | None = None,
       compiler_options=None,
   ) -> MeshExecutable:
     if shape_poly_state is not None and shape_poly_state.uses_dim_vars:
@@ -2710,7 +2727,8 @@ class UnloadedMeshExecutable:
         auto_spmd_lowering=auto_spmd_lowering,
         jaxpr_debug_info=jaxpr_debug_info,
         in_layouts=in_layouts,  # type: ignore
-        out_layouts=out_layouts).load()  # type: ignore
+        out_layouts=out_layouts,  # type: ignore
+        all_args_info=all_args_info).load()  # type: ignore


 class MeshExecutableFastpathData(NamedTuple):
@@ -2735,12 +2753,14 @@ class MeshExecutable(stages.XlaExecutable):
   __slots__ = [
       "xla_executable", "_unsafe_call", "build_unsafe_call", "in_avals",
       "_in_shardings", "_out_shardings", "_auto_spmd_lowering", "_kept_var_idx",
-      "_in_layouts", "_out_layouts", "_jaxpr_debug_info", "_unloaded_executable",
+      "_in_layouts", "_out_layouts", "_jaxpr_debug_info",
+      "_all_args_info", "_unloaded_executable",
   ]

   def __init__(self, xla_executable, build_unsafe_call, in_avals, in_shardings,
                out_shardings, auto_spmd_lowering, kept_var_idx,
                in_layouts, out_layouts, jaxpr_debug_info=None,
+               all_args_info: AllArgsInfo | None = None,
                unloaded_executable=None):
     self.xla_executable = xla_executable
     self.build_unsafe_call = build_unsafe_call
@@ -2755,13 +2775,14 @@ class MeshExecutable(stages.XlaExecutable):
     self._in_layouts = in_layouts
     self._out_layouts = out_layouts
     self._jaxpr_debug_info = jaxpr_debug_info
+    self._all_args_info = all_args_info
     self._unloaded_executable = unloaded_executable

   @property
   def unsafe_call(self) -> Callable[..., Any]:
     if self._unsafe_call is None:
       self._unsafe_call = self.build_unsafe_call()
-    return self._unsafe_call
+    return self._unsafe_call  # type: ignore

   # -- stages.XlaExecutable overrides
@@ -2769,13 +2790,23 @@ class MeshExecutable(stages.XlaExecutable):
     return self.xla_executable

   def call(self, *args):
-    kept_args = [a for i, a in enumerate(args) if i in self._kept_var_idx]
+    if self._all_args_info is None:
+      kept_args = [a for i, a in enumerate(args) if i in self._kept_var_idx]
+      ref_avals = self.in_avals
+      in_shardings = self._in_shardings
+      debug_info = self._jaxpr_debug_info
+    else:
+      kept_args = args
+      ref_avals = self._all_args_info.in_avals
+      iter_in_shardings = iter(self._in_shardings)
+      in_shardings = [next(iter_in_shardings) if i in self._kept_var_idx else s
+                      for i, s in enumerate(self._all_args_info.in_shardings)]
+      debug_info = self._all_args_info.debug_info
+
     arg_avals = map(xla.abstractify, kept_args)
-    ref_avals = self.in_avals
-    check_arg_avals_for_call(ref_avals, arg_avals, self._jaxpr_debug_info)
+    check_arg_avals_for_call(ref_avals, arg_avals, debug_info)
     # Check the GDA sharding and the input sharding.
-    check_gda_or_array_xla_sharding_match(kept_args, self._in_shardings,
-                                          self._jaxpr_debug_info)
+    check_gda_or_array_xla_sharding_match(kept_args, in_shardings, debug_info)
     return self.unsafe_call(*args)  # pylint: disable=not-callable

   def input_shardings(self) -> Sequence[sharding_impls.XLACompatibleSharding]:
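
The else-branch above rebuilds a full-length sharding list by interleaving the compiled (post-DCE) shardings back into their pre-DCE argument positions. A standalone sketch of that interleaving, using hypothetical toy values (plain strings rather than real sharding objects):

    # Toy setup: argument 1 was DCE'd, so only arguments 0 and 2 survive.
    kept_var_idx = {0, 2}
    compiled_in_shardings = ["s_x", "s_z"]        # shardings of kept args only
    pre_dce_in_shardings = ["a_x", "a_y", "a_z"]  # recorded before DCE

    # Kept positions take the next compiled sharding; DCE'd positions keep
    # the sharding that was recorded at lowering time.
    it_kept = iter(compiled_in_shardings)
    merged = [next(it_kept) if i in kept_var_idx else s
              for i, s in enumerate(pre_dce_in_shardings)]
    assert merged == ["s_x", "a_y", "a_z"]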
@@ -2922,7 +2953,8 @@ def _compile_replicated_mesh_executable_from_hlo(
   return MeshExecutable(xla_executable, lambda: unsafe_call, global_in_avals,
                         in_shardings, out_shardings, auto_spmd_lowering,
                         kept_var_idx, (None,) * len(global_in_avals),
-                        (None,) * len(global_out_avals), jaxpr_debug_info, None)
+                        (None,) * len(global_out_avals), jaxpr_debug_info,
+                        None, None)


 @lru_cache
@@ -2956,6 +2988,8 @@ def check_gda_or_array_xla_sharding_match(
   for arg, xs, name in safe_zip(args, in_xla_shardings, arg_names):
     if not isinstance(arg, ArrayImpl):
       continue
+    if is_unspecified_or_auto(xs):
+      continue
     db_xs = check_device_backend_on_shardings([xs])
     if not db_xs:


@@ -764,7 +764,7 @@ def _check_lowering(lowering) -> None:
       "tuple_args", "ordered_effects", "unordered_effects",
       "keepalive", "host_callbacks", "pmap_nreps", "committed",
       "device_assignment", "jaxpr_debug_info", "shape_poly_state",
-      "all_default_mem_kind", "in_layouts", "out_layouts"]
+      "all_default_mem_kind", "in_layouts", "out_layouts", "all_args_info"]
   for compile_arg in lowering.compile_args.keys():
     if compile_arg not in allowed_compile_args:
       raise NotImplementedError(f"Unrecognized lowered.compile_args[{compile_arg}]")


@@ -4018,6 +4018,52 @@ class PJitErrorTest(jtu.JaxTestCase):
     x.delete()
     _ = f(x)

+  def test_aot_error_on_dced_avals_mismatch(self):
+    x, y1, y2 = jnp.ones(4), jnp.ones(4), jnp.ones(1)
+
+    @jax.jit
+    def f(x, y):
+      return x + 1 if y.shape[0] > 2 else x + 2
+
+    f_out1 = f(x, y1)
+    f(x, y2)
+
+    g = f.lower(x, y1).compile()
+    g_out1 = g(x, y1)
+    self.assertArraysEqual(f_out1, g_out1)
+
+    with self.assertRaisesRegex(
+        TypeError,
+        'Argument types differ from the types for which this computation was'
+        ' compiled'):
+      g(x, y2)
+
+  def test_aot_error_on_dced_shardings_mismatch(self):
+    mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
+    shape = (8, 2)
+    np_inp = np.arange(math.prod(shape)).reshape(shape)
+
+    x = jax.device_put(np_inp, NamedSharding(mesh, P('x', 'y')))
+    y1 = jax.device_put(np_inp, NamedSharding(mesh, P('x')))
+    y2 = jax.device_put(np_inp, NamedSharding(mesh, P('y')))
+
+    @jax.jit
+    def f(x, y):
+      return x + 1
+
+    f_out1 = f(x, y1)
+    f(x, y2)
+
+    g = f.lower(x, y1).compile()
+    g_out1 = g(x, y1)
+    self.assertArraysEqual(f_out1, g_out1)
+
+    with self.assertRaisesRegex(
+        ValueError,
+        r"Compiled object called with input sharding.*does not match the "
+        r"sharding.*the computation was compiled with"):
+      g(x, y2)
+

 @jtu.pytest_mark_if_available('multiaccelerator')
 class UtilTest(jtu.JaxTestCase):