mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Typecheck avals and sharding for arguments that were DCE'd.

This keeps the promise of AOT that no recompilation will occur: calling a
compiled object with argument types it was not compiled for raises an error,
even for arguments that were eliminated by dead-code elimination.

Fixes https://github.com/google/jax/issues/18686

PiperOrigin-RevId: 585855658
This commit is contained in:
parent
37f11428a3
commit
88d980f164
@ -1875,6 +1875,14 @@ def are_all_shardings_default_mem_kind(da_object, shardings):
|
||||
|
||||
MaybeLayout = Sequence[Optional[Union[XLACompatibleLayout, LayoutRequest]]]
|
||||
|
||||
|
||||
class AllArgsInfo(NamedTuple):
  """Pre-DCE argument metadata: avals, shardings and debug_info.

  Captured before dead-code elimination so a compiled executable can still
  typecheck arguments that DCE removed from the jaxpr.
  """
  # Abstract values of every argument, in original call order.
  in_avals: Sequence[core.ShapedArray]
  # Shardings for those arguments; left as Any since unspecified/auto
  # sentinels may appear alongside concrete shardings.
  in_shardings: Any
  # Debug info from the pre-DCE jaxpr, if it was recorded.
  debug_info: core.JaxprDebugInfo | None
|
||||
|
||||
|
||||
@profiler.annotate_function
|
||||
def lower_sharding_computation(
|
||||
closed_jaxpr: core.ClosedJaxpr,
|
||||
@ -1904,6 +1912,9 @@ def lower_sharding_computation(
|
||||
check_if_any_auto(in_shardings) if is_unspecified(out_shardings) else
|
||||
check_if_any_auto(it.chain.from_iterable([in_shardings, out_shardings]))) # type: ignore
|
||||
|
||||
all_args_info = AllArgsInfo(global_in_avals, in_shardings,
|
||||
closed_jaxpr.jaxpr.debug_info)
|
||||
|
||||
(closed_jaxpr, global_in_avals, global_out_avals, donated_invars,
|
||||
kept_var_idx, name_stack) = _dce_jaxpr(
|
||||
closed_jaxpr, global_in_avals, api_name, fun_name, keep_unused,
|
||||
@ -2004,7 +2015,8 @@ def lower_sharding_computation(
|
||||
pmap_nreps=nreps,
|
||||
jaxpr_debug_info=closed_jaxpr.jaxpr.debug_info,
|
||||
shape_poly_state=shape_poly_state,
|
||||
all_default_mem_kind=all_default_mem_kind)
|
||||
all_default_mem_kind=all_default_mem_kind,
|
||||
all_args_info=all_args_info)
|
||||
|
||||
|
||||
def _to_logical_sharding(
|
||||
@ -2090,6 +2102,8 @@ def lower_mesh_computation(
|
||||
out_jaxpr_avals = fun_or_jaxpr.out_avals
|
||||
consts = fun_or_jaxpr.consts
|
||||
|
||||
all_args_info = AllArgsInfo(global_in_avals, in_shardings, jaxpr.debug_info)
|
||||
|
||||
assert len(out_shardings) == len(out_jaxpr_avals)
|
||||
if spmd_lowering:
|
||||
global_out_avals = out_jaxpr_avals
|
||||
@ -2179,7 +2193,8 @@ def lower_mesh_computation(
|
||||
in_layouts=(None,) * len(global_in_avals),
|
||||
out_layouts=(None,) * len(global_out_avals),
|
||||
jaxpr_debug_info=closed_jaxpr.jaxpr.debug_info,
|
||||
shape_poly_state=lowering_result.shape_poly_state)
|
||||
shape_poly_state=lowering_result.shape_poly_state,
|
||||
all_args_info=all_args_info)
|
||||
|
||||
class MeshComputation(stages.XlaLowering):
|
||||
_hlo: ir.Module | None
|
||||
@ -2568,6 +2583,7 @@ class UnloadedMeshExecutable:
|
||||
jaxpr_debug_info: core.JaxprDebugInfo | None
|
||||
in_layouts: Sequence[SpecifiedLayout | None]
|
||||
out_layouts: Sequence[SpecifiedLayout | None]
|
||||
all_args_info: AllArgsInfo | None
|
||||
|
||||
def build_unsafe_call(self):
|
||||
input_indices = _get_input_indices(self.input_avals, self.input_shardings,
|
||||
@ -2590,7 +2606,7 @@ class UnloadedMeshExecutable:
|
||||
self.input_shardings, self.output_shardings,
|
||||
self.auto_spmd_lowering, self.kept_var_idx,
|
||||
self.in_layouts, self.out_layouts,
|
||||
self.jaxpr_debug_info, self)
|
||||
self.jaxpr_debug_info, self.all_args_info, self)
|
||||
|
||||
# May return a MeshExecutable in the compile_replicated case.
|
||||
@staticmethod
|
||||
@ -2618,6 +2634,7 @@ class UnloadedMeshExecutable:
|
||||
jaxpr_debug_info: core.JaxprDebugInfo | None = None,
|
||||
shape_poly_state: mlir.ShapePolyLoweringState | None = None,
|
||||
all_default_mem_kind: bool = True,
|
||||
all_args_info: AllArgsInfo | None = None,
|
||||
compiler_options=None,
|
||||
) -> MeshExecutable:
|
||||
if shape_poly_state is not None and shape_poly_state.uses_dim_vars:
|
||||
@ -2710,7 +2727,8 @@ class UnloadedMeshExecutable:
|
||||
auto_spmd_lowering=auto_spmd_lowering,
|
||||
jaxpr_debug_info=jaxpr_debug_info,
|
||||
in_layouts=in_layouts, # type: ignore
|
||||
out_layouts=out_layouts).load() # type: ignore
|
||||
out_layouts=out_layouts, # type: ignore
|
||||
all_args_info=all_args_info).load() # type: ignore
|
||||
|
||||
|
||||
class MeshExecutableFastpathData(NamedTuple):
|
||||
@ -2735,12 +2753,14 @@ class MeshExecutable(stages.XlaExecutable):
|
||||
__slots__ = [
|
||||
"xla_executable", "_unsafe_call", "build_unsafe_call", "in_avals",
|
||||
"_in_shardings", "_out_shardings", "_auto_spmd_lowering", "_kept_var_idx",
|
||||
"_in_layouts", "_out_layouts", "_jaxpr_debug_info", "_unloaded_executable",
|
||||
"_in_layouts", "_out_layouts", "_jaxpr_debug_info",
|
||||
"_all_args_info", "_unloaded_executable",
|
||||
]
|
||||
|
||||
def __init__(self, xla_executable, build_unsafe_call, in_avals, in_shardings,
|
||||
out_shardings, auto_spmd_lowering, kept_var_idx,
|
||||
in_layouts, out_layouts, jaxpr_debug_info=None,
|
||||
all_args_info: AllArgsInfo | None = None,
|
||||
unloaded_executable=None):
|
||||
self.xla_executable = xla_executable
|
||||
self.build_unsafe_call = build_unsafe_call
|
||||
@ -2755,13 +2775,14 @@ class MeshExecutable(stages.XlaExecutable):
|
||||
self._in_layouts = in_layouts
|
||||
self._out_layouts = out_layouts
|
||||
self._jaxpr_debug_info = jaxpr_debug_info
|
||||
self._all_args_info = all_args_info
|
||||
self._unloaded_executable = unloaded_executable
|
||||
|
||||
@property
|
||||
def unsafe_call(self) -> Callable[..., Any]:
|
||||
if self._unsafe_call is None:
|
||||
self._unsafe_call = self.build_unsafe_call()
|
||||
return self._unsafe_call
|
||||
return self._unsafe_call # type: ignore
|
||||
|
||||
# -- stages.XlaExecutable overrides
|
||||
|
||||
@ -2769,13 +2790,23 @@ class MeshExecutable(stages.XlaExecutable):
|
||||
return self.xla_executable
|
||||
|
||||
def call(self, *args):
|
||||
kept_args = [a for i, a in enumerate(args) if i in self._kept_var_idx]
|
||||
if self._all_args_info is None:
|
||||
kept_args = [a for i, a in enumerate(args) if i in self._kept_var_idx]
|
||||
ref_avals = self.in_avals
|
||||
in_shardings = self._in_shardings
|
||||
debug_info = self._jaxpr_debug_info
|
||||
else:
|
||||
kept_args = args
|
||||
ref_avals = self._all_args_info.in_avals
|
||||
iter_in_shardings = iter(self._in_shardings)
|
||||
in_shardings = [next(iter_in_shardings) if i in self._kept_var_idx else s
|
||||
for i, s in enumerate(self._all_args_info.in_shardings)]
|
||||
debug_info = self._all_args_info.debug_info
|
||||
|
||||
arg_avals = map(xla.abstractify, kept_args)
|
||||
ref_avals = self.in_avals
|
||||
check_arg_avals_for_call(ref_avals, arg_avals, self._jaxpr_debug_info)
|
||||
check_arg_avals_for_call(ref_avals, arg_avals, debug_info)
|
||||
# Check the GDA sharding and the input sharding.
|
||||
check_gda_or_array_xla_sharding_match(kept_args, self._in_shardings,
|
||||
self._jaxpr_debug_info)
|
||||
check_gda_or_array_xla_sharding_match(kept_args, in_shardings, debug_info)
|
||||
return self.unsafe_call(*args) # pylint: disable=not-callable
|
||||
|
||||
def input_shardings(self) -> Sequence[sharding_impls.XLACompatibleSharding]:
|
||||
@ -2922,7 +2953,8 @@ def _compile_replicated_mesh_executable_from_hlo(
|
||||
return MeshExecutable(xla_executable, lambda: unsafe_call, global_in_avals,
|
||||
in_shardings, out_shardings, auto_spmd_lowering,
|
||||
kept_var_idx, (None,) * len(global_in_avals),
|
||||
(None,) * len(global_out_avals), jaxpr_debug_info, None)
|
||||
(None,) * len(global_out_avals), jaxpr_debug_info,
|
||||
None, None)
|
||||
|
||||
|
||||
@lru_cache
|
||||
@ -2956,6 +2988,8 @@ def check_gda_or_array_xla_sharding_match(
|
||||
for arg, xs, name in safe_zip(args, in_xla_shardings, arg_names):
|
||||
if not isinstance(arg, ArrayImpl):
|
||||
continue
|
||||
if is_unspecified_or_auto(xs):
|
||||
continue
|
||||
|
||||
db_xs = check_device_backend_on_shardings([xs])
|
||||
if not db_xs:
|
||||
|
@ -764,7 +764,7 @@ def _check_lowering(lowering) -> None:
|
||||
"tuple_args", "ordered_effects", "unordered_effects",
|
||||
"keepalive", "host_callbacks", "pmap_nreps", "committed",
|
||||
"device_assignment", "jaxpr_debug_info", "shape_poly_state",
|
||||
"all_default_mem_kind", "in_layouts", "out_layouts"]
|
||||
"all_default_mem_kind", "in_layouts", "out_layouts", "all_args_info"]
|
||||
for compile_arg in lowering.compile_args.keys():
|
||||
if compile_arg not in allowed_compile_args:
|
||||
raise NotImplementedError(f"Unrecognized lowered.compile_args[{compile_arg}]")
|
||||
|
@ -4018,6 +4018,52 @@ class PJitErrorTest(jtu.JaxTestCase):
|
||||
x.delete()
|
||||
_ = f(x)
|
||||
|
||||
def test_aot_error_on_dced_avals_mismatch(self):
  # y's value is unused (only its static shape picks the traced branch),
  # so y is DCE'd out of the jaxpr. The compiled object must still reject
  # a y whose aval differs from the one it was compiled with.
  x = jnp.ones(4)
  y1, y2 = jnp.ones(4), jnp.ones(1)

  @jax.jit
  def f(x, y):
    return x + 1 if y.shape[0] > 2 else x + 2

  expected = f(x, y1)
  f(x, y2)  # Under plain jit this just triggers a fresh trace — no error.

  compiled = f.lower(x, y1).compile()
  self.assertArraysEqual(expected, compiled(x, y1))

  # AOT promises no recompilation, so the mismatched aval must raise.
  with self.assertRaisesRegex(
      TypeError,
      'Argument types differ from the types for which this computation was'
      ' compiled'):
    compiled(x, y2)
|
||||
|
||||
def test_aot_error_on_dced_shardings_mismatch(self):
  # y is unused by f and therefore DCE'd, but the compiled object must
  # still check its sharding against the one it was compiled with.
  mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
  shape = (8, 2)
  np_inp = np.arange(math.prod(shape)).reshape(shape)

  x = jax.device_put(np_inp, NamedSharding(mesh, P('x', 'y')))
  y1 = jax.device_put(np_inp, NamedSharding(mesh, P('x')))
  y2 = jax.device_put(np_inp, NamedSharding(mesh, P('y')))

  @jax.jit
  def f(x, y):
    return x + 1

  expected = f(x, y1)
  f(x, y2)  # Fine under plain jit: a new sharding just retraces.

  compiled = f.lower(x, y1).compile()
  self.assertArraysEqual(expected, compiled(x, y1))

  # AOT path: the DCE'd argument's sharding mismatch must raise.
  with self.assertRaisesRegex(
      ValueError,
      r"Compiled object called with input sharding.*does not match the "
      r"sharding.*the computation was compiled with"):
    compiled(x, y2)
|
||||
|
||||
|
||||
@jtu.pytest_mark_if_available('multiaccelerator')
|
||||
class UtilTest(jtu.JaxTestCase):
|
||||
|
Loading…
x
Reference in New Issue
Block a user