From ef33cf5acee2668b3e847aa19c91a52f6c18328d Mon Sep 17 00:00:00 2001
From: Yash Katariya
Date: Wed, 28 Aug 2024 11:05:45 -0700
Subject: [PATCH] Standardize default layout to `None` in internals (dispatch,
 lowering and compilation) and non-default layouts to concrete layouts. This
 massively simplifies the number of checks we need and improves dispatch time
 too.

It also fixes a donation bug hit in serving code that was caused by layouts
and the non-standardization of the default layout in JAX.

PiperOrigin-RevId: 668527139
---
 jax/_src/array.py             | 11 ++---
 jax/_src/interpreters/mlir.py | 23 +++++----
 jax/_src/interpreters/pxla.py | 93 ++++++++++++-----------------------
 jax/_src/pjit.py              | 21 +++++---
 tests/layout_test.py          | 24 +++++++--
 5 files changed, 80 insertions(+), 92 deletions(-)

diff --git a/jax/_src/array.py b/jax/_src/array.py
index 0f554a86a..7659c180d 100644
--- a/jax/_src/array.py
+++ b/jax/_src/array.py
@@ -751,8 +751,7 @@ def make_array_from_callback(
       and sharding.is_fully_replicated
       and first_value.is_fully_replicated
       and first_value.sharding._device_assignment == tuple(devices)
-      and (first_value.layout.device_local_layout ==
-           pxla._maybe_get_default_layout(Layout(dll, sharding), None, sharding, aval))):
+      and first_value.layout.device_local_layout == dll):
     return first_value

   if dtypes.issubdtype(aval.dtype, dtypes.extended):
@@ -1105,11 +1104,6 @@ def _sharding_indices_and_eq(src_sharding, shape, dst_sharding):
   dst_indices = dst_sharding.addressable_devices_indices_map(shape).values()
   return dst_indices, tuple(src_indices) == tuple(dst_indices)

-def _layout_eq(x, dst_layout, sharding):
-  if pxla.is_default_layout(dst_layout, sharding, x.aval):
-    return True
-  return x.layout.device_local_layout == dst_layout
-

 def _array_shard_arg(xs, shardings, layouts):
   results = []
@@ -1118,7 +1112,8 @@ def _array_shard_arg(xs, shardings, layouts):
   for i, (x, sharding, layout) in enumerate(safe_zip(xs, shardings, layouts)):
     x._check_if_deleted()
     indices, same_indices = _sharding_indices_and_eq(x.sharding, x.shape, sharding)
-    same_layout = _layout_eq(x, layout, sharding)
+    same_layout = (True if layout is None else
+                   x.layout.device_local_layout == layout)

     if not x.is_fully_addressable:
       if same_indices and same_layout:
diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py
index ab2c77833..0e7e0146e 100644
--- a/jax/_src/interpreters/mlir.py
+++ b/jax/_src/interpreters/mlir.py
@@ -1053,6 +1053,7 @@ def lower_jaxpr_to_module(
   result_memory_kinds = (map(_get_mem_kind, result_shardings)
                          if result_shardings is not None else None)

+  # TODO(yashkatariya): Simplify the donation logic.
   xla_donated_args = None
   platforms_with_donation = [p for p in platforms
                              if p in _platforms_with_donation]
@@ -1071,9 +1072,6 @@ def lower_jaxpr_to_module(
     input_output_aliases, donated_args, xla_donated_args = _set_up_aliases(
         input_output_aliases, in_avals, out_avals, donated_args,
         arg_memory_kinds, result_memory_kinds, in_layouts, out_layouts)
-  unlowerable_effects = lowerable_effects.filter_not_in(jaxpr.effects)
-  if unlowerable_effects:
-    raise ValueError(f'Cannot lower jaxpr with effects: {jaxpr.effects}')
   if any(donated_args):
     unused_donations = [str(a) for a, d in zip(in_avals, donated_args) if d]
     msg = "See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer-donation."
@@ -1082,10 +1080,13 @@ def lower_jaxpr_to_module(
     if unused_donations:
       warnings.warn("Some donated buffers were not usable:"
                     f" {', '.join(unused_donations)}.\n{msg}")
-  # Delete donated_args by default here, since it's not needed beyond this point
   del donated_args

+  unlowerable_effects = lowerable_effects.filter_not_in(jaxpr.effects)
+  if unlowerable_effects:
+    raise ValueError(f'Cannot lower jaxpr with effects: {jaxpr.effects}')
+
   # HLO channels need to start at 1
   channel_iter = itertools.count(1)
   # Create a keepalives list that will be mutated during the lowering.
@@ -1167,8 +1168,7 @@ def lower_jaxpr_to_module(


 def _set_up_aliases(input_output_aliases, avals_in, avals_out,
-                    donated_args,
-                    arg_memory_kinds, result_memory_kinds,
+                    donated_args, arg_memory_kinds, result_memory_kinds,
                     in_layouts, out_layouts):
   if input_output_aliases is None:
     input_output_aliases = [None] * len(avals_in)
@@ -1207,15 +1207,14 @@ def _set_up_aliases(input_output_aliases, avals_in, avals_out,
       if donations.get(key, ()):
         input_id = donations[key].popleft()
         out_donated_args[input_id] = False
+        # We can alias if XLA performs layout assignment because XLA will
+        # respect the aliases when assigning layouts. It's only for two
+        # mismatched explicitly assigned layouts that XLA will certainly fail.
         if (in_layouts is None or
             out_layouts is None or
             in_layouts[input_id] == out_layouts[i] or
-            # We can alias if XLA performs layout assignment because XLA will
-            # respect the aliases when assigning layouts. Its only for two
-            # mismatched explicitly assigned layouts that XLA will certainly
-            # fail.
-            isinstance(in_layouts[input_id], (AutoLayout, type(None))) or
-            isinstance(out_layouts[i], (AutoLayout, type(None)))):
+            isinstance(in_layouts[input_id], AutoLayout) or
+            isinstance(out_layouts[i], AutoLayout)):
           input_output_aliases[input_id] = i
         else:
           # Fallback to xla donation if layouts don't match.
diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py
index f6f413307..4e0986cfa 100644
--- a/jax/_src/interpreters/pxla.py
+++ b/jax/_src/interpreters/pxla.py
@@ -150,7 +150,7 @@ shard_arg_handlers: dict[

 @lru_cache(maxsize=2048)
 def is_default_layout(curr_layout, sharding, aval):
-  if curr_layout is None or sharding is None:
+  if curr_layout is None or sharding is None or is_unspecified(sharding):
     return True
   if (aval is core.abstract_token or aval.dtype == dtypes.float0 or
       dtypes.issubdtype(aval.dtype, dtypes.extended)):
@@ -191,7 +191,7 @@ def _shard_np_array(xs, shardings, layouts):
     if x.dtype == dtypes.float0:
       x = np.zeros(x.shape, dtype=np.dtype(bool))
     aval = api_util.shaped_abstractify(x)
-    if not is_default_layout(layout, sharding, aval):
+    if layout is not None:
       results.append(api.device_put(x, Layout(layout, sharding)))
     else:
       if sharding.is_fully_replicated:
@@ -1884,35 +1884,6 @@ def _raise_warnings_or_errors_for_jit_of_pmap(
       "extra data movement anyway, so maybe you don't want it after all).")


-@lru_cache(maxsize=2048)
-def _maybe_get_default_layout(arg_layout, jit_in_layout, sharding, aval
-                              ) -> DeviceLocalLayout | None:
-  if is_unspecified_or_auto(sharding):
-    return None
-  # TODO(yashkatariya): Figure out how layouts work with extended dtypes.
-  if aval is core.abstract_token or dtypes.issubdtype(aval.dtype, dtypes.extended):
-    return None
-  if not core.is_constant_shape(aval.shape):
-    return None
-  shard_shape = sharding.shard_shape(aval.shape)
-  d = sharding._device_assignment[0]
-  # If a backend doesn't implement `get_default_layout` return `None` to avoid
-  # cache misses. This can happen when you have `jit(f, in_shardings=s)`. On
-  # first call you pass it a sharded array with layout and on second call you
-  # pass a numpy array. The layouts should be the same to get cache hits.
-  try:
-    al = DeviceLocalLayout.from_pjrt_layout(
-        d.client.get_default_layout(aval.dtype, shard_shape, d))
-  except:
-    return None
-  # argument does not have `.layout` property. ShapedArray, numpy array, etc
-  # are some examples.
-  if arg_layout is None:
-    return al if jit_in_layout is None else arg_layout  # arg_layout is None
-  # If arg has a `.layout` property, then return device_local_layout as is.
-  return arg_layout.device_local_layout
-
-
 @weakref_lru_cache
 def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
                             semantic_in_shardings, semantic_out_shardings,
@@ -2775,13 +2746,14 @@ class UnloadedMeshExecutable:
   kept_var_idx: set[int]
   mut: MutationData | None
   auto_spmd_lowering: bool
-  in_layouts: Sequence[DeviceLocalLayout | None]
-  out_layouts: Sequence[DeviceLocalLayout | None]
+  xla_in_layouts: Sequence[DeviceLocalLayout | None]
+  dispatch_in_layouts: Sequence[DeviceLocalLayout | None]
+  xla_out_layouts: Sequence[DeviceLocalLayout | None]
   all_args_info: AllArgsInfo | None
   pgle_profiler: profiler.PGLEProfiler | None

   def build_unsafe_call(self):
-    handle_args = InputsHandler(self.input_shardings, self.in_layouts)
+    handle_args = InputsHandler(self.input_shardings, self.dispatch_in_layouts)
     handle_outs = global_avals_to_results_handler(
         self.output_avals, self.output_shardings, self.committed)

@@ -2797,8 +2769,8 @@ class UnloadedMeshExecutable:
         self.input_avals, self.output_avals, self.input_shardings,
         self.output_shardings, self.auto_spmd_lowering,
         self.kept_var_idx,
-        self.in_layouts, self.out_layouts,
-        self.all_args_info, self)
+        self.xla_in_layouts, self.dispatch_in_layouts,
+        self.xla_out_layouts, self.all_args_info, self)

   @staticmethod
   def from_hlo(name: str,
@@ -2881,8 +2853,18 @@ class UnloadedMeshExecutable:
       in_shardings, out_shardings, committed, da = _get_metadata_jit_pmap(
           xla_executable.local_devices(), len(in_shardings), len(out_shardings))

-    in_layouts, out_layouts = _get_layouts_from_executable(
+    # xla_in_layouts are all either None or DeviceLocalLayout. Even default
+    # layouts are concrete layouts and they are used in `compiled.input_layouts`
+    # to return concrete layouts to users.
+    # `dispatch_in_layouts` replaces default layouts with `None` to simplify
+    # dispatch logic downstream.
+    xla_in_layouts, xla_out_layouts = _get_layouts_from_executable(
         xla_executable, in_layouts, out_layouts, len(ordered_effects))
+    del in_layouts, out_layouts
+    dispatch_in_layouts = [
+        None if is_default_layout(l, s, a) else l
+        for l, s, a in safe_zip(xla_in_layouts, in_shardings, global_in_avals)
+    ]

     out_shardings = maybe_recover_user_shardings(
         in_shardings, out_shardings, global_in_avals, global_out_avals,
@@ -2907,8 +2889,9 @@ class UnloadedMeshExecutable:
         kept_var_idx=kept_var_idx,
         mut=mut,
         auto_spmd_lowering=auto_spmd_lowering,
-        in_layouts=in_layouts,
-        out_layouts=out_layouts,
+        xla_in_layouts=xla_in_layouts,
+        dispatch_in_layouts=dispatch_in_layouts,
+        xla_out_layouts=xla_out_layouts,
         all_args_info=all_args_info,
         pgle_profiler=pgle_profiler).load()

@@ -2964,13 +2947,13 @@ class MeshExecutable(stages.XlaExecutable):
   __slots__ = [
       "xla_executable", "_unsafe_call", "build_unsafe_call", "in_avals",
       "out_avals", "_in_shardings", "_out_shardings", "_auto_spmd_lowering",
-      "_kept_var_idx", "_in_layouts", "_out_layouts", "_all_args_info",
-      "_unloaded_executable",
+      "_kept_var_idx", "_xla_in_layouts", "_dispatch_in_layouts",
+      "_xla_out_layouts", "_all_args_info", "_unloaded_executable",
   ]

   def __init__(self, xla_executable, build_unsafe_call, in_avals, out_avals,
                in_shardings, out_shardings, auto_spmd_lowering, kept_var_idx,
-               in_layouts, out_layouts,
+               xla_in_layouts, dispatch_in_layouts, xla_out_layouts,
                all_args_info: AllArgsInfo | None = None,
                unloaded_executable=None):
     self.xla_executable = xla_executable
@@ -2984,8 +2967,9 @@ class MeshExecutable(stages.XlaExecutable):
     self._out_shardings = out_shardings
     self._auto_spmd_lowering = auto_spmd_lowering
     self._kept_var_idx = kept_var_idx
-    self._in_layouts = in_layouts
-    self._out_layouts = out_layouts
+    self._xla_in_layouts = xla_in_layouts
+    self._dispatch_in_layouts = dispatch_in_layouts
+    self._xla_out_layouts = xla_out_layouts
     self._all_args_info = all_args_info
     self._unloaded_executable = unloaded_executable

@@ -3013,9 +2997,8 @@ class MeshExecutable(stages.XlaExecutable):
       all_arg_avals = map(xla.abstractify, kept_args)
       check_arg_avals_for_call(ref_avals, all_arg_avals, debug_info)

-      # Check the GDA sharding and the input sharding.
       check_array_xla_sharding_layout_match(
-          args_after_dce, self._in_shardings, self._in_layouts, debug_info,
+          args_after_dce, self._in_shardings, self._xla_in_layouts, debug_info,
           self._kept_var_idx)
     return self.unsafe_call(*args)  # pylint: disable=not-callable

@@ -3027,11 +3010,11 @@ class MeshExecutable(stages.XlaExecutable):

   def input_layouts(self):
     return [Layout(l, s)
-            for l, s in safe_zip(self._in_layouts, self._in_shardings)]
+            for l, s in safe_zip(self._xla_in_layouts, self._in_shardings)]

   def output_layouts(self):
     return [Layout(l, s)
-            for l, s in safe_zip(self._out_layouts, self._out_shardings)]
+            for l, s in safe_zip(self._xla_out_layouts, self._out_shardings)]

   def create_cpp_call(self, no_kwargs, in_tree, out_tree):
     if not (isinstance(self.unsafe_call, ExecuteReplicated) and
@@ -3057,12 +3040,10 @@ class MeshExecutable(stages.XlaExecutable):
           else s
           for s, a in zip(self._in_shardings, self.in_avals)
       ]
-      in_dlls = get_layouts_for_fasthpath_data(
-          self._in_layouts, in_shardings, self.in_avals)
       fastpath_data = MeshExecutableFastpathData(
           self.xla_executable, out_tree_dispatch, in_shardings,
           self._out_shardings, out_avals, out_committed, kept_var_bitvec,
-          in_dlls)
+          self._dispatch_in_layouts)
     else:
       fastpath_data = None
     return outs, fastpath_data, False  # Do not remove cache entry
@@ -3084,16 +3065,6 @@ else:
     return shard_args([sharding], [layout], [x])[0]


-def get_layouts_for_fasthpath_data(in_layouts, in_shardings, in_avals):
-  in_dlls = []
-  for l, s, a in zip(in_layouts, in_shardings, in_avals):
-    if is_default_layout(l, s, a):
-      in_dlls.append(None)
-    else:
-      in_dlls.append(l)
-  return in_dlls
-
-
 def check_arg_avals_for_call(ref_avals, arg_avals,
                              jaxpr_debug_info: core.JaxprDebugInfo | None = None):
   if len(ref_avals) != len(arg_avals):
diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py
index 197388afe..bcd31ec09 100644
--- a/jax/_src/pjit.py
+++ b/jax/_src/pjit.py
@@ -279,12 +279,10 @@ def _get_fastpath_data(
         else s
         for s, a in zip(executable._in_shardings, executable.in_avals)
     ]
-    in_dlls = pxla.get_layouts_for_fasthpath_data(
-        executable._in_layouts, in_shardings, executable.in_avals)
     fastpath_data = pxla.MeshExecutableFastpathData(
         executable.xla_executable, out_tree, in_shardings,
         executable._out_shardings, out_avals, out_committed, kept_var_bitvec,
-        in_dlls)
+        executable._dispatch_in_layouts)
   else:
     fastpath_data = None
   return fastpath_data
@@ -1479,10 +1477,17 @@ def _resolve_in_layouts(args, jit_in_layouts, resolved_in_shardings, in_avals):
   resolved_in_layouts = []
   for arg, jit_in_l, rs, aval in safe_zip(
       args, jit_in_layouts, resolved_in_shardings, in_avals):
-    arg_layout, committed = (
-        pxla._maybe_get_default_layout(getattr(arg, 'layout', None), jit_in_l,
-                                       rs, aval),
-        getattr(arg, '_committed', True))
+    committed = getattr(arg, '_committed', True)
+    # `arg_layout` is only used for checking in the `else` branch below; it
+    # keeps the concrete default layout so mismatch errors can report actual
+    # layouts. `dispatch_arg_layout` replaces default layouts with `None` to
+    # simplify dispatch and lowering logic downstream.
+    if hasattr(arg, 'layout'):
+      arg_layout = arg.layout.device_local_layout
+      dispatch_arg_layout = (None if pxla.is_default_layout(arg_layout, rs, aval)
+                             else arg_layout)
+    else:
+      arg_layout, dispatch_arg_layout = None, None
     # Sharding can be unspecified when array is committed if it's a PmapSharding.
     is_pmap_sharding = (is_unspecified(rs) or
                         isinstance(getattr(arg, 'sharding', None), PmapSharding))
@@ -1491,7 +1496,7 @@ def _resolve_in_layouts(args, jit_in_layouts, resolved_in_shardings, in_avals):
       if is_pmap_sharding:
         resolved_in_layouts.append(None)
       else:
-        resolved_in_layouts.append(arg_layout)
+        resolved_in_layouts.append(dispatch_arg_layout)
     else:
       resolved_in_layouts.append(None)
   else:
diff --git a/tests/layout_test.py b/tests/layout_test.py
index 33f3318bf..7a587d099 100644
--- a/tests/layout_test.py
+++ b/tests/layout_test.py
@@ -540,7 +540,7 @@ class LayoutTest(jtu.JaxTestCase):
     def f(x):
       return x

-    out = f(arr)
+    f(arr)
     self.assertTrue(arr.is_deleted())

   def test_layout_donation_auto(self):
@@ -555,7 +555,7 @@ class LayoutTest(jtu.JaxTestCase):
     def f(x):
       return x * x

-    out = f(arr)
+    f(arr)
     self.assertTrue(arr.is_deleted())

   def test_layout_donation_matching_in_and_out(self):
@@ -572,9 +572,27 @@ class LayoutTest(jtu.JaxTestCase):
     def f(x):
       return x * x

-    out = f(arr)
+    f(arr)
     self.assertTrue(arr.is_deleted())

+  def test_layout_donation_mismatching_in_and_out_fails(self):
+    mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
+    s = NamedSharding(mesh, P('x', 'y'))
+    shape = (16*2, 32016*2)
+    np_inp = np.arange(math.prod(shape), dtype=jnp.bfloat16).reshape(shape)
+
+    custom_dll1 = DLL(major_to_minor=(1, 0), _tiling=((8,128), (2,1)))
+    l1 = Layout(custom_dll1, s)
+    arr = jax.device_put(np_inp, s)
+
+    @partial(jax.jit, out_shardings=l1, donate_argnums=0)
+    def f(x):
+      return x * x
+
+    sds = jax.ShapeDtypeStruct(np_inp.shape, np_inp.dtype, sharding=s)
+    f.lower(sds).compile()(arr)
+    self.assertFalse(arr.is_deleted())
+

 if __name__ == '__main__':
   absltest.main(testLoader=jtu.JaxTestLoader())
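
Reviewer note (not part of the patch): with the convention this change adopts,
`None` always means "default layout" in internals and anything else is a
concrete DeviceLocalLayout, so the hot-path check in `_array_shard_arg` reduces
to a single comparison. Below is a minimal standalone sketch of that predicate;
`same_layout` as a free function is hypothetical (the real logic is inlined in
jax/_src/array.py):

  # Sketch only: mirrors the check the patch inlines into _array_shard_arg.
  # `layout` follows the new convention: None means "default layout".
  def same_layout(x, layout) -> bool:
    # A None request matches any array, because default layouts were
    # normalized away before dispatch; otherwise compare concrete layouts.
    return True if layout is None else x.layout.device_local_layout == layout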
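
Reviewer note (not part of the patch): the aliasing rule in `_set_up_aliases`
is easiest to see end to end: when a donated input's layout and an explicitly
requested output layout disagree, lowering cannot alias them and the donation
is not honored. The script below is a standalone adaptation of the new
test_layout_donation_mismatching_in_and_out_fails, with the test-only
jtu.create_global_mesh swapped for mesh_utils. It assumes at least four devices
and a backend that accepts the tiled layout (the tiling is TPU-style), so treat
it as a sketch rather than a portable script:

  import math
  from functools import partial
  import numpy as np
  import jax
  import jax.numpy as jnp
  from jax.experimental import mesh_utils
  from jax.experimental.layout import Layout, DeviceLocalLayout as DLL
  from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

  mesh = Mesh(mesh_utils.create_device_mesh((2, 2)), ('x', 'y'))
  s = NamedSharding(mesh, P('x', 'y'))
  shape = (16 * 2, 32016 * 2)
  np_inp = np.arange(math.prod(shape), dtype=jnp.bfloat16).reshape(shape)

  # The input keeps the default layout; the output requests an explicit
  # tiled layout, so input and output cannot be aliased.
  arr = jax.device_put(np_inp, s)
  custom_dll = DLL(major_to_minor=(1, 0), _tiling=((8, 128), (2, 1)))
  l1 = Layout(custom_dll, s)

  @partial(jax.jit, out_shardings=l1, donate_argnums=0)
  def f(x):
    return x * x

  sds = jax.ShapeDtypeStruct(np_inp.shape, np_inp.dtype, sharding=s)
  f.lower(sds).compile()(arr)
  # The donation could not be honored, so the input buffer survives.
  assert not arr.is_deleted()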