diff --git a/jax/BUILD b/jax/BUILD index 4ce401ede..2ceea2c4d 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -228,6 +228,7 @@ py_library_providing_imports_info( ":effects", ":environment_info", ":jaxpr_util", + ":layout", ":lazy_loader", ":mesh", ":mlir", @@ -478,6 +479,7 @@ pytype_strict_library( ":core", ":dtypes", ":effects", + ":layout", ":op_shardings", ":partial_eval", ":pickle_util", @@ -633,6 +635,16 @@ pytype_strict_library( ], ) +pytype_strict_library( + name = "layout", + srcs = ["_src/layout.py"], + deps = [ + ":util", + ":xla_bridge", + "//jax/_src/lib", + ], +) + pytype_strict_library( name = "sharding_impls", srcs = ["_src/sharding_impls.py"], diff --git a/jax/_src/api.py b/jax/_src/api.py index 4dcdff7e4..5fb5b636a 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -304,13 +304,18 @@ def jit( static_argnums, static_argnames, device, backend, abstracted_axes) def infer_params(*args, **kwargs): + # TODO(yashkatariya): Remove this when it's added on jit. Also default to + # layout.DefaultLayout() when out of experimental. + in_layouts = kwargs.pop('_in_layouts', None) + out_layouts = kwargs.pop('_out_layouts', None) pjit_info_args = pjit.PjitInfo( fun=fun, in_shardings=in_shardings, out_shardings=out_shardings, static_argnums=static_argnums, static_argnames=static_argnames, donate_argnums=donate_argnums, donate_argnames=donate_argnames, device=device, backend=backend, keep_unused=keep_unused, inline=inline, resource_env=None, - abstracted_axes=abstracted_axes) + abstracted_axes=abstracted_axes, in_layouts=in_layouts, + out_layouts=out_layouts) return pjit.common_infer_params(pjit_info_args, *args, **kwargs) has_explicit_sharding = pjit._pjit_explicit_sharding( diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py index 2f240abc6..22dbf401e 100644 --- a/jax/_src/dispatch.py +++ b/jax/_src/dispatch.py @@ -180,8 +180,8 @@ def sharded_lowering( return pxla.lower_sharding_computation( fun, 'jit', name, in_shardings_unspec, UNSPECIFIED, donated_invars, in_avals, keep_unused=keep_unused, inline=inline, - devices_from_context=None, - lowering_parameters=lowering_parameters) + devices_from_context=None, lowering_parameters=lowering_parameters, + in_layouts=(None,) * len(in_avals), out_layouts=None) def simple_impl(prim): diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index c1ff74b83..8111b59a2 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -41,6 +41,7 @@ from jax._src import util from jax._src import xla_bridge as xb from jax._src.interpreters import partial_eval as pe from jax._src.interpreters import xla +from jax._src.layout import XLACompatibleLayout, LayoutRequest from jax._src.lib import xla_client as xc from jax._src.lib import xla_extension from jax._src.lib import xla_extension_version @@ -691,6 +692,15 @@ def _to_logical_op_sharding( assert isinstance(aval, (core.ShapedArray, core.DShapedArray)) return sharding._to_xla_hlo_sharding(aval.ndim) + +def _to_xla_layout(layout: XLACompatibleLayout | None | LayoutRequest) -> str | None: + if layout is None: + return None + if isinstance(layout, LayoutRequest): + return "auto" + return layout._to_xla_layout() + + def _get_mem_kind(s: Optional[XLACompatibleSharding]) -> Optional[str]: if s is None: return None @@ -711,6 +721,8 @@ def lower_jaxpr_to_module( replicated_args: Sequence[bool] | None = None, arg_shardings: Sequence[XLACompatibleSharding | None] | None = None, result_shardings: Sequence[XLACompatibleSharding | None] | None = None, + in_layouts: 
Sequence[XLACompatibleLayout | None | LayoutRequest] | None = None, + out_layouts: Sequence[XLACompatibleLayout | None | LayoutRequest] | None = None, arg_names: Sequence[str | None] | None = None, result_names: Sequence[str | None] | None = None, num_replicas: int = 1, @@ -784,6 +796,11 @@ def lower_jaxpr_to_module( map(_to_logical_op_sharding, jaxpr.out_avals, result_shardings) if result_shardings is not None else result_shardings) + arg_layouts = (map(_to_xla_layout, in_layouts) if in_layouts is not None + else in_layouts) + result_layouts = (map(_to_xla_layout, out_layouts) if out_layouts is not None + else out_layouts) + ctx = ModuleContext(backend_or_name=backend_or_name, platforms=platforms, axis_context=axis_context, name_stack=name_stack, @@ -815,7 +832,9 @@ def lower_jaxpr_to_module( arg_names=arg_names, result_names=result_names, arg_memory_kinds=arg_memory_kinds, - result_memory_kinds=result_memory_kinds) + result_memory_kinds=result_memory_kinds, + arg_layouts=arg_layouts, + result_layouts=result_layouts) try: if not ctx.module.operation.verify(): @@ -969,6 +988,8 @@ def lower_jaxpr_to_fun( result_names: Sequence[str | None] | None = None, arg_memory_kinds: Sequence[str | None] | None = None, result_memory_kinds: Sequence[str | None] | None = None, + arg_layouts: Sequence[str | None] | None = None, + result_layouts: Sequence[str | None] | None = None, ) -> func_dialect.FuncOp: """Lowers jaxpr and its callees to an IR function. @@ -1055,6 +1076,12 @@ def lower_jaxpr_to_fun( if result_memory_kinds is not None: token_memory_kinds = [None] * (num_tokens + num_output_tokens) result_memory_kinds = [*token_memory_kinds, *result_memory_kinds] + if arg_layouts is not None: + token_layouts = [None] * (num_dim_vars + num_tokens) + arg_layouts = [*token_layouts, *arg_layouts] + if result_layouts is not None: + token_layouts = [None] * (num_tokens + num_output_tokens) + result_layouts = [*token_layouts, *result_layouts] flat_input_types = util.flatten(input_types) flat_output_types = util.flatten(output_types) @@ -1077,6 +1104,11 @@ def lower_jaxpr_to_fun( ir_arg_memory_kinds = util.flatten( [[mk] * len(types) for mk, types in zip(arg_memory_kinds, input_types)]) + ir_arg_layouts = None + if arg_layouts is not None: + ir_arg_layouts = util.flatten( + [[l] * len(types) for l, types in zip(arg_layouts, input_types)]) + ir_result_shardings = None if result_shardings is not None: out_avals = [None] * (num_tokens + num_output_tokens) + list(jaxpr.out_avals) @@ -1090,9 +1122,15 @@ def lower_jaxpr_to_fun( ir_result_memory_kinds = util.flatten( [[mk] * len(types) for mk, types in zip(result_memory_kinds, output_types)]) + ir_result_layouts = None + if result_layouts is not None: + ir_result_layouts = util.flatten( + [[l] * len(types) for l, types in zip(result_layouts, output_types)]) + if ( replicated_args is not None or ir_arg_shardings is not None + or ir_arg_layouts is not None or input_output_aliases is not None or arg_names is not None or num_tokens > 0 @@ -1113,6 +1151,11 @@ def lower_jaxpr_to_fun( if sharding is not None: attrs["mhlo.sharding"] = get_sharding_attr(sharding) + if ir_arg_layouts is not None: + for attrs, layout in zip(arg_attrs, ir_arg_layouts): + if layout is not None: + attrs["mhlo.layout_mode"] = ir.StringAttr.get(layout) + if input_output_aliases is not None: output_ids = util.unflatten(list(range(len(flat_output_types))), map(len, output_types)) @@ -1162,6 +1205,11 @@ def lower_jaxpr_to_fun( if sharding is not None: attrs['mhlo.sharding'] = get_sharding_attr(sharding) + 
if ir_result_layouts is not None: + for attrs, layout in zip(result_attrs, ir_result_layouts): + if layout is not None: + attrs['mhlo.layout_mode'] = ir.StringAttr.get(layout) + func_op.result_attrs = ir.ArrayAttr.get( [ir.DictAttr.get(attrs) for attrs in result_attrs]) diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 1660549f7..e0b84bb8b 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -59,6 +59,7 @@ from jax._src.interpreters import batching from jax._src.interpreters import partial_eval as pe from jax._src.interpreters import mlir from jax._src.interpreters import xla +from jax._src.layout import XLACompatibleLayout, SpecifiedLayout, LayoutRequest from jax._src.lib import xla_client as xc from jax._src.lib import xla_extension_version from jax._src.lib.mlir import ir @@ -1768,7 +1769,7 @@ def _raise_warnings_or_errors_for_jit_of_pmap( @weakref_lru_cache def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend, semantic_in_shardings, semantic_out_shardings, - da_object, + in_layouts, out_layouts, da_object, donated_invars, name_stack, all_default_mem_kind, lowering_parameters: mlir.LoweringParameters): jaxpr = closed_jaxpr.jaxpr @@ -1842,6 +1843,8 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend, replicated_args=replicated_args, arg_shardings=in_mlir_shardings, result_shardings=out_mlir_shardings, + in_layouts=in_layouts, + out_layouts=out_layouts, arg_names=jaxpr.debug_info and jaxpr.debug_info.arg_names, result_names=jaxpr.debug_info and jaxpr.debug_info.result_paths, num_replicas=nreps, @@ -1939,6 +1942,7 @@ def are_all_shardings_default_mem_kind(da_object, shardings): return False return True +MaybeLayout = Sequence[Optional[Union[XLACompatibleLayout, LayoutRequest]]] @profiler.annotate_function def lower_sharding_computation( @@ -1954,6 +1958,8 @@ def lower_sharding_computation( inline: bool, devices_from_context: Sequence[xc.Device] | None = None, lowering_parameters: mlir.LoweringParameters, + in_layouts: MaybeLayout, + out_layouts: Optional[MaybeLayout], ) -> MeshComputation: """Lowers a computation to XLA. It can take arbitrary shardings as input. @@ -1973,12 +1979,16 @@ def lower_sharding_computation( donated_invars, auto_spmd_lowering) jaxpr = closed_jaxpr.jaxpr in_shardings = tuple(s for i, s in enumerate(in_shardings) if i in kept_var_idx) + in_layouts = tuple(l for i, l in enumerate(in_layouts) if i in kept_var_idx) if is_unspecified(out_shardings): out_shardings = (UNSPECIFIED,) * len(global_out_avals) + if out_layouts is None: + out_layouts = (None,) * len(global_out_avals) assert isinstance(out_shardings, tuple) - assert len(out_shardings) == len(global_out_avals), ( - len(out_shardings), len(global_out_avals)) + assert isinstance(out_layouts, tuple) + assert len(out_shardings) == len(out_layouts) == len(global_out_avals), ( + len(out_shardings), len(out_layouts), len(global_out_avals)) # Device assignment across all inputs, outputs and shardings inside jaxpr # should be the same. 
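Note on the translation step: the _to_xla_layout helper added to mlir.py earlier in this diff is the whole user-to-compiler mapping: None stays unset, layout.AUTO (a LayoutRequest) becomes the string "auto", and a concrete layout serializes its minor_to_major order via xc.Layout. A minimal sketch of how the new layout classes behave, assuming this change is applied; the exact serialized string shown in the comment is illustrative, not guaranteed output:

from jax._src.layout import SpecifiedLayout, DefaultLayout, AUTO, LayoutRequest

# SpecifiedLayout serializes through xc.Layout(...).to_string(); typically something like "{0,2,1}".
print(SpecifiedLayout((0, 2, 1))._to_xla_layout())
# DefaultLayout tells the compiler to use its default (descending) layout.
print(DefaultLayout()._to_xla_layout())   # "default"
# AUTO is a LayoutRequest instance; mlir._to_xla_layout turns it into the string "auto".
print(isinstance(AUTO, LayoutRequest))    # True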
@@ -2030,7 +2040,7 @@ def lower_sharding_computation( (module, keepalive, host_callbacks, unordered_effects, ordered_effects, nreps, tuple_args, shape_poly_state) = _cached_lowering_to_hlo( closed_jaxpr, api_name, fun_name, backend, semantic_in_shardings, - semantic_out_shardings, da_object, + semantic_out_shardings, in_layouts, out_layouts, da_object, donated_invars, name_stack, all_default_mem_kind, lowering_parameters=lowering_parameters) @@ -2058,6 +2068,8 @@ def lower_sharding_computation( backend=backend, device_assignment=da_object, committed=committed, + in_layouts=in_layouts, + out_layouts=out_layouts, pmap_nreps=nreps, jaxpr_debug_info=closed_jaxpr.jaxpr.debug_info, shape_poly_state=shape_poly_state, @@ -2233,6 +2245,8 @@ def lower_mesh_computation( backend=backend, device_assignment=_create_da_object(tuple(mesh.devices.flat)), committed=True, + in_layouts=(None,) * len(global_in_avals), + out_layouts=(None,) * len(global_out_avals), jaxpr_debug_info=closed_jaxpr.jaxpr.debug_info, shape_poly_state=lowering_result.shape_poly_state) @@ -2447,6 +2461,42 @@ def maybe_get_orig_out_sharding( return out_shardings, are_out_shardings_from_xla +def _get_layouts_from_executable( + xla_executable, in_layouts, out_layouts +) -> tuple[Sequence[XLACompatibleLayout | None], Sequence[XLACompatibleLayout | None]]: # type: ignore + if all(i is None for i in in_layouts) and all(o is None for o in out_layouts): + return in_layouts, out_layouts # type: ignore + + in_layouts_xla = xla_executable.get_parameter_layouts() + out_layouts_xla = xla_executable.get_output_layouts() + + new_in_layouts = [] + for x, i in safe_zip(in_layouts_xla, in_layouts): + x = SpecifiedLayout._from_xla_layout(x) + if isinstance(i, SpecifiedLayout): + if i != x: + raise AssertionError( + f"Unexpected XLA layout override: (XLA) {x} != {i} (User layout)") + new_in_layouts.append(i) + else: + new_in_layouts.append(x) + + new_out_layouts = [] + for x, o in safe_zip(out_layouts_xla, out_layouts): + x = SpecifiedLayout._from_xla_layout(x) + if isinstance(o, SpecifiedLayout): + if o != x: + raise AssertionError( + f"Unexpected XLA layout override: (XLA) {x} != {o} (User layout)") + new_out_layouts.append(o) + else: + new_out_layouts.append(x) + + assert all(isinstance(i, SpecifiedLayout) for i in new_in_layouts) + assert all(isinstance(o, SpecifiedLayout) for o in new_out_layouts) + return new_in_layouts, new_out_layouts + + @weakref_lru_cache def _cached_compilation(computation, name, mesh, spmd_lowering, tuple_args, auto_spmd_lowering, @@ -2534,6 +2584,8 @@ class UnloadedMeshExecutable: kept_var_idx: set[int] auto_spmd_lowering: bool jaxpr_debug_info: core.JaxprDebugInfo | None + in_layouts: Sequence[SpecifiedLayout | None] + out_layouts: Sequence[SpecifiedLayout | None] def build_unsafe_call(self): input_indices = _get_input_indices(self.input_avals, self.input_shardings, @@ -2555,6 +2607,7 @@ class UnloadedMeshExecutable: self.input_avals, self.input_shardings, self.output_shardings, self.auto_spmd_lowering, self.kept_var_idx, + self.in_layouts, self.out_layouts, self.jaxpr_debug_info, self) # May return a MeshExecutable in the compile_replicated case.
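The reconciliation rule that _get_layouts_from_executable above applies can be boiled down to the following self-contained sketch (the reconcile helper is hypothetical, not part of this change): a user-supplied SpecifiedLayout must equal the layout XLA actually chose, while AUTO/None entries simply adopt the compiler's choice.

from jax._src.layout import SpecifiedLayout

def reconcile(user, xla_chosen):
  # A concrete user layout must match what XLA chose; anything else takes XLA's choice.
  if isinstance(user, SpecifiedLayout) and user != xla_chosen:
    raise AssertionError(f"Unexpected XLA layout override: {xla_chosen} != {user}")
  return user if isinstance(user, SpecifiedLayout) else xla_chosen

assert reconcile(None, SpecifiedLayout((1, 0))) == SpecifiedLayout((1, 0))
assert reconcile(SpecifiedLayout((0, 1)), SpecifiedLayout((0, 1))) == SpecifiedLayout((0, 1))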
@@ -2577,6 +2630,8 @@ class UnloadedMeshExecutable: backend: xb.XlaBackend, device_assignment: _DeviceAssignment | Sequence[xc.Device], # type: ignore committed: bool, + in_layouts: MaybeLayout, + out_layouts: MaybeLayout, pmap_nreps: int = 1, jaxpr_debug_info: core.JaxprDebugInfo | None = None, shape_poly_state: mlir.ShapePolyLoweringState | None = None, @@ -2659,6 +2714,9 @@ class UnloadedMeshExecutable: else: are_out_shardings_from_xla = (False,) * len(global_out_avals) + in_layouts, out_layouts = _get_layouts_from_executable( + xla_executable, in_layouts, out_layouts) + if pmap_nreps > 1: in_shardings, out_shardings, committed, da = _get_metadata_jit_pmap( xla_executable.local_devices(), len(in_shardings), len(out_shardings)) @@ -2684,7 +2742,9 @@ class UnloadedMeshExecutable: host_callbacks=host_callbacks, kept_var_idx=kept_var_idx, auto_spmd_lowering=auto_spmd_lowering, - jaxpr_debug_info=jaxpr_debug_info).load() + jaxpr_debug_info=jaxpr_debug_info, + in_layouts=in_layouts, # type: ignore + out_layouts=out_layouts).load() # type: ignore class MeshExecutableFastpathData(NamedTuple): @@ -2709,12 +2769,13 @@ class MeshExecutable(stages.XlaExecutable): __slots__ = [ "xla_executable", "_unsafe_call", "build_unsafe_call", "in_avals", "_in_shardings", "_out_shardings", "_auto_spmd_lowering", "_kept_var_idx", - "_jaxpr_debug_info", "_unloaded_executable", + "_in_layouts", "_out_layouts", "_jaxpr_debug_info", "_unloaded_executable", ] def __init__(self, xla_executable, build_unsafe_call, in_avals, in_shardings, out_shardings, auto_spmd_lowering, kept_var_idx, - jaxpr_debug_info=None, unloaded_executable=None): + in_layouts, out_layouts, jaxpr_debug_info=None, + unloaded_executable=None): self.xla_executable = xla_executable self.build_unsafe_call = build_unsafe_call # in_avals is a list of global and local avals. Aval is global if input @@ -2725,6 +2786,8 @@ class MeshExecutable(stages.XlaExecutable): self._out_shardings = out_shardings self._auto_spmd_lowering = auto_spmd_lowering self._kept_var_idx = kept_var_idx + self._in_layouts = in_layouts + self._out_layouts = out_layouts self._jaxpr_debug_info = jaxpr_debug_info self._unloaded_executable = unloaded_executable @@ -2755,6 +2818,12 @@ class MeshExecutable(stages.XlaExecutable): def output_shardings(self) -> Sequence[sharding_impls.XLACompatibleSharding]: return self._out_shardings + def input_layouts(self): + return self._in_layouts + + def output_layouts(self): + return self._out_layouts + def create_cpp_call(self, no_kwargs, in_tree, out_tree): if not (isinstance(self.unsafe_call, ExecuteReplicated) and not self.unsafe_call.has_unordered_effects and @@ -2886,7 +2955,8 @@ def _compile_replicated_mesh_executable_from_hlo( xla_executable = None return MeshExecutable(xla_executable, lambda: unsafe_call, global_in_avals, in_shardings, out_shardings, auto_spmd_lowering, - kept_var_idx, jaxpr_debug_info, None) + kept_var_idx, (None,) * len(global_in_avals), + (None,) * len(global_out_avals), jaxpr_debug_info, None) @lru_cache diff --git a/jax/_src/layout.py b/jax/_src/layout.py new file mode 100644 index 000000000..dc63d7021 --- /dev/null +++ b/jax/_src/layout.py @@ -0,0 +1,78 @@ +# Copyright 2023 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from jax._src.lib import xla_client as xc + + +class Layout: + pass + + +class XLACompatibleLayout(Layout): + @classmethod + def _from_xla_layout(cls, xla_layout) -> XLACompatibleLayout: + raise NotImplementedError("Subclasses should implement this method.") + + def _to_xla_layout(self) -> str: + raise NotImplementedError("Subclasses should implement this method.") + + +class SpecifiedLayout(XLACompatibleLayout): + minor_to_major: tuple[int, ...] + + def __init__(self, minor_to_major: tuple[int, ...]): + self.minor_to_major = minor_to_major + + def __repr__(self): + return f'SpecifiedLayout(minor_to_major={self.minor_to_major})' + + def __hash__(self): + return hash(self.minor_to_major) + + def __eq__(self, other): + if not isinstance(other, SpecifiedLayout): + return False + return self.minor_to_major == other.minor_to_major + + @classmethod + def _from_xla_layout(cls, xla_layout: xc.Layout) -> XLACompatibleLayout: + return cls(xla_layout.minor_to_major()) + + def _to_xla_layout(self) -> str: + return xc.Layout(self.minor_to_major).to_string() + + +class DefaultLayout(XLACompatibleLayout): + + def __repr__(self): + return 'DefaultLayout()' + + def __hash__(self): + return hash(type(self)) + + def __eq__(self, other): + return isinstance(other, DefaultLayout) and type(self) == type(other) + + def _to_xla_layout(self) -> str: + return "default" + + +class LayoutRequest: + + def __repr__(self): + return "Request a layout from the compiler" + +AUTO = LayoutRequest() diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index b80249dce..4e8f33d06 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -159,7 +159,7 @@ def _device_assignment_mismatch_error(fun_name, fails, args_flat, api_name, def _python_pjit_helper(fun, infer_params_fn, *args, **kwargs): - args_flat, _, params, in_tree, out_tree, _ = infer_params_fn( + args_flat, _, params, in_tree, out_tree, _, _, _ = infer_params_fn( *args, **kwargs) for arg in args_flat: dispatch.check_arg(arg) @@ -328,8 +328,13 @@ def post_infer_params(fun, infer_params_fn, static_argnums, static_argnames, def lower(*args, **kwargs): lowering_parameters = kwargs.pop( '_experimental_lowering_parameters', mlir.LoweringParameters()) + # TODO(yashkatariya): Remove this when it's added on jit. Also default to + # layout.DefaultLayout() when out of experimental.
+ in_layouts = kwargs.pop('_in_layouts', None) + out_layouts = kwargs.pop('_out_layouts', None) (args_flat, flat_global_in_avals, params, in_tree, out_tree, - donated_invars) = infer_params_fn(*args, **kwargs) + donated_invars, in_layouts_flat, out_layouts_flat) = infer_params_fn( + *args, **kwargs, _in_layouts=in_layouts, _out_layouts=out_layouts) resource_env = params['resource_env'] mesh = None if resource_env is None else resource_env.physical_mesh try: @@ -338,8 +343,8 @@ def post_infer_params(fun, infer_params_fn, static_argnums, static_argnames, lowering = _pjit_lower( params['jaxpr'], in_shardings, params['out_shardings'], params['resource_env'], params['donated_invars'], params['name'], - params['keep_unused'], params['inline'], - lowering_parameters=lowering_parameters) + params['keep_unused'], params['inline'], in_layouts=in_layouts_flat, + out_layouts=out_layouts_flat, lowering_parameters=lowering_parameters) except pxla.DeviceAssignmentMismatchError as e: fails, = e.args api_name = 'jit' if params['resource_env'] is None else 'pjit' @@ -387,12 +392,14 @@ class PjitInfo(NamedTuple): inline: bool resource_env: Any abstracted_axes: Optional[Any] + in_layouts: Any # pytree[XlaCompatibleLayout] | None + out_layouts: Any # pytree[XlaCompatibleLayout] | None def common_infer_params(pjit_info_args, *args, **kwargs): (fun, user_in_shardings, user_out_shardings, static_argnums, static_argnames, donate_argnums, donate_argnames, device, backend, keep_unused, inline, - resource_env, abstracted_axes) = pjit_info_args + resource_env, abstracted_axes, in_layouts, out_layouts) = pjit_info_args if (kwargs and user_in_shardings is not None and not is_unspecified(user_in_shardings)): @@ -479,16 +486,16 @@ def common_infer_params(pjit_info_args, *args, **kwargs): ) from e in_type = in_avals = tuple(avals) - canonicalized_in_shardings_flat = _process_in_axis_resources( - hashable_pytree(in_shardings), in_avals, in_tree, resource_env, dbg, - device_or_backend_set) + canonicalized_in_shardings_flat, in_layouts_flat = _process_in_axis_resources( + hashable_pytree(in_shardings), hashable_pytree(in_layouts), in_avals, + in_tree, resource_env, dbg, device_or_backend_set) - jaxpr, consts, canonicalized_out_shardings_flat = _pjit_jaxpr( - flat_fun, hashable_pytree(out_shardings), in_type, dbg, - device_or_backend_set, HashableFunction(out_tree, closure=()), + jaxpr, consts, canonicalized_out_shardings_flat, out_layouts_flat = _pjit_jaxpr( + flat_fun, hashable_pytree(out_shardings), hashable_pytree(out_layouts), + in_type, dbg, device_or_backend_set, HashableFunction(out_tree, closure=()), HashableFunction(res_paths, closure=())) - assert len(explicit_args) == len(canonicalized_in_shardings_flat) + assert len(explicit_args) == len(canonicalized_in_shardings_flat) == len(in_layouts_flat) if config.dynamic_shapes.value: implicit_args = _extract_implicit_args(in_type, explicit_args) @@ -499,9 +506,10 @@ def common_infer_params(pjit_info_args, *args, **kwargs): num_extra_args = len(implicit_args) + len(consts) canonicalized_in_shardings_flat = \ (UNSPECIFIED,) * num_extra_args + canonicalized_in_shardings_flat + in_layouts_flat = (None,) * num_extra_args + in_layouts_flat donated_invars = (False,) * num_extra_args + donated_invars - assert (len(canonicalized_in_shardings_flat) == len(donated_invars) == - len(consts) + len(args_flat)) + assert (len(canonicalized_in_shardings_flat) == len(in_layouts_flat) == + len(donated_invars) == len(consts) + len(args_flat)) # in_shardings and out_shardings here are all 
GSPMDSharding. params = dict( @@ -515,7 +523,7 @@ def common_infer_params(pjit_info_args, *args, **kwargs): inline=inline, ) return (consts + args_flat, in_type, params, in_tree, out_tree(), - donated_invars) + donated_invars, in_layouts_flat, out_layouts_flat) def _extract_implicit_args( in_type: Sequence[tuple[core.AbstractValue, bool]], @@ -758,13 +766,18 @@ def pjit( def infer_params(*args, **kwargs): # Putting this outside of wrapped would make resources lexically scoped resource_env = mesh_lib.thread_resources.env + # TODO(yashkatariya): Remove this when it's added on jit. Also default to + # layout.DefaultLayout() when out of experimental. + in_layouts = kwargs.pop('_in_layouts', None) + out_layouts = kwargs.pop('_out_layouts', None) pjit_info_args = PjitInfo( fun=fun, in_shardings=in_shardings, out_shardings=out_shardings, static_argnums=static_argnums, static_argnames=static_argnames, donate_argnums=donate_argnums, donate_argnames=donate_argnames, device=device, backend=backend, keep_unused=keep_unused, inline=inline, resource_env=resource_env, - abstracted_axes=abstracted_axes) + abstracted_axes=abstracted_axes, in_layouts=in_layouts, + out_layouts=out_layouts) return common_infer_params(pjit_info_args, *args, **kwargs) has_explicit_sharding = _pjit_explicit_sharding( @@ -880,8 +893,9 @@ class PytreeLeaf: @lru_cache(maxsize=4096) -def _process_in_axis_resources(in_shardings_thunk, in_avals, in_tree, - resource_env, debug_info, device_or_backend_set): +def _process_in_axis_resources(in_shardings_thunk, in_layouts_thunk, in_avals, + in_tree, resource_env, debug_info, + device_or_backend_set): orig_in_shardings = in_shardings_thunk() # Only do this if original in_shardings are unspecified. If it is AUTO, go # via flatten_axis_resources. @@ -889,8 +903,14 @@ def _process_in_axis_resources(in_shardings_thunk, in_avals, in_tree, in_shardings_flat = (orig_in_shardings,) * len(in_avals) else: in_shardings_flat = flatten_axis_resources( - "pjit in_shardings", in_tree, orig_in_shardings, - tupled_args=True) + "pjit in_shardings", in_tree, orig_in_shardings, tupled_args=True) + + in_layouts = in_layouts_thunk() + if in_layouts is None: + in_layouts_flat = (in_layouts,) * len(in_avals) + else: + in_layouts_flat = flatten_axis_resources( + "pjit in_layouts", in_tree, in_layouts, tupled_args=True) if not config.dynamic_shapes.value: pjit_check_aval_sharding(in_shardings_flat, in_avals, @@ -900,7 +920,7 @@ def _process_in_axis_resources(in_shardings_thunk, in_avals, in_tree, i if is_unspecified_or_auto(i) else to_gspmd_sharding(i, aval.ndim, device_or_backend_set) for i, aval in zip(in_shardings_flat, in_avals)) - return canonicalized_shardings + return canonicalized_shardings, tuple(in_layouts_flat) @lu.cache @@ -930,7 +950,8 @@ def _create_pjit_jaxpr(fun, in_type, debug_info, out_paths): @lru_cache(maxsize=4096) def _check_and_canonicalize_out_shardings( - out_shardings_thunk, out_tree, out_type, debug_info, device_or_backend_set): + out_shardings_thunk, out_layouts_thunk, out_tree, out_type, debug_info, + device_or_backend_set): orig_out_shardings = out_shardings_thunk() # TODO(yashkatariya): Remove the if branch and fix flatten_axis_resources # instead. 
This condition exists because flatten_axis_resources passes in an @@ -944,6 +965,13 @@ def _check_and_canonicalize_out_shardings( "pjit out_shardings", out_tree(), orig_out_shardings, tupled_args=False) + out_layouts = out_layouts_thunk() + if out_layouts is None: + out_layouts_flat = (out_layouts,) * len(out_type) + else: + out_layouts_flat = flatten_axis_resources( + "pjit out_layouts", out_tree(), out_layouts, tupled_args=False) + if not config.dynamic_shapes.value: pjit_check_aval_sharding( out_shardings_flat, out_type, @@ -955,18 +983,18 @@ def _check_and_canonicalize_out_shardings( to_gspmd_sharding(o, aval.ndim, device_or_backend_set) for o, aval in zip(out_shardings_flat, out_type) ) - return canonicalized_out_shardings_flat + return canonicalized_out_shardings_flat, tuple(out_layouts_flat) -def _pjit_jaxpr(fun, out_shardings_thunk, in_type, debug_info, +def _pjit_jaxpr(fun, out_shardings_thunk, out_layouts_thunk, in_type, debug_info, device_or_backend_set, out_tree, result_paths): jaxpr, final_consts, out_type = _create_pjit_jaxpr( fun, in_type, debug_info, result_paths) - canonicalized_out_shardings_flat = _check_and_canonicalize_out_shardings( - out_shardings_thunk, out_tree, tuple(out_type), jaxpr.jaxpr.debug_info, - device_or_backend_set) + canonicalized_out_shardings_flat, out_layouts_flat = _check_and_canonicalize_out_shardings( + out_shardings_thunk, out_layouts_thunk, out_tree, tuple(out_type), + jaxpr.jaxpr.debug_info, device_or_backend_set) # lu.cache needs to be able to create weakrefs to outputs, so we can't return a plain tuple - return jaxpr, final_consts, canonicalized_out_shardings_flat + return jaxpr, final_consts, canonicalized_out_shardings_flat, out_layouts_flat def pjit_check_aval_sharding( @@ -1271,11 +1299,20 @@ def _pjit_lower_cached( keep_unused: bool, inline: bool, *, - lowering_parameters: mlir.LoweringParameters): + lowering_parameters: mlir.LoweringParameters, + in_layouts: Optional[pxla.MaybeLayout] = None, + out_layouts: Optional[pxla.MaybeLayout] = None): in_shardings: tuple[PjitShardingMinusUnspecified, ...] = cast( tuple[PjitShardingMinusUnspecified, ...], sdat_in_shardings.shardings) out_shardings: tuple[PjitSharding, ...] = sdat_out_shardings.shardings + # TODO(yashkatariya): Remove this when layouts are supported on jit and + # passed to params. 
+ if in_layouts is None: + in_layouts = (None,) * len(in_shardings) + if out_layouts is None: + out_layouts = (None,) * len(out_shardings) + if resource_env is not None: pxla.resource_typecheck(jaxpr, resource_env, {}, lambda: "pjit") @@ -1302,8 +1339,8 @@ def _pjit_lower_cached( keep_unused=keep_unused, inline=inline, devices_from_context=( None if mesh is None or mesh.empty else list(mesh.devices.flat)), - lowering_parameters=lowering_parameters, -) + lowering_parameters=lowering_parameters, in_layouts=in_layouts, + out_layouts=out_layouts) def pjit_staging_rule(trace, *args, **params): diff --git a/jax/_src/stages.py b/jax/_src/stages.py index c470d1da2..30d4e1f3c 100644 --- a/jax/_src/stages.py +++ b/jax/_src/stages.py @@ -42,6 +42,7 @@ from jax._src import source_info_util from jax._src import traceback_util from jax._src import tree_util from jax._src import util +from jax._src.layout import SpecifiedLayout from jax._src.interpreters import mlir from jax._src.lib.mlir import ir from jax._src.lib import xla_client as xc @@ -84,6 +85,14 @@ class Executable(Protocol): """ raise NotImplementedError + # Layouts are exposed via jax.experimental.layout + # TODO(frostig,yashkatariya): expose here when no longer experimental. + def _input_layouts(self): + raise NotImplementedError + + def _output_layouts(self): + raise NotImplementedError + def as_text(self) -> str: """A human-readable text representation of this executable. @@ -216,6 +225,14 @@ class XlaExecutable(Executable): raise NotImplementedError( "compiled executable carries no output sharding information") + def _input_layouts(self): + raise NotImplementedError( + "compiled executable carries no input layout information") + + def _output_layouts(self): + raise NotImplementedError( + "compiled executable carries no output layout information") + def as_text(self) -> str: xla_ext_exe = self.xla_extension_executable() err_msg = ("text view unsupported on current XLA backend: " @@ -481,6 +498,16 @@ class Compiled(Stage): shardings_flat = self._executable.output_shardings() return tree_util.tree_unflatten(self.out_tree, shardings_flat) # pytype: disable=attribute-error + def _input_layouts(self): + layouts_flat = self._executable.input_layouts() + assert all(isinstance(l, SpecifiedLayout) for l in layouts_flat) + return tree_util.tree_unflatten(self.in_tree, layouts_flat) # pytype: disable=attribute-error + + def _output_layouts(self): + layouts_flat = self._executable.output_layouts() + assert all(isinstance(l, SpecifiedLayout) for l in layouts_flat) + return tree_util.tree_unflatten(self.out_tree, layouts_flat) # pytype: disable=attribute-error + @staticmethod def call(*args, **kwargs): # This is because `__call__` passes in `self._params` as the first argument.
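With the Compiled._input_layouts/_output_layouts accessors above in place, the intended ahead-of-time usage looks roughly like this. This is a sketch based on the tests added below; _in_layouts/_out_layouts are private keyword arguments and, per those tests, currently exercised only on TPU backends with xla_extension_version >= 215.

import numpy as np
import jax
from jax.experimental.layout import AUTO

x = np.arange(16.0).reshape(4, 4)
compiled = jax.jit(lambda a: a.T).lower(x, _in_layouts=AUTO, _out_layouts=AUTO).compile()
arg_layouts, kwarg_layouts = compiled._input_layouts()   # pytrees of SpecifiedLayout
out_layouts = compiled._output_layouts()
print(arg_layouts[0].minor_to_major, out_layouts.minor_to_major)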
diff --git a/jax/experimental/export/export.py b/jax/experimental/export/export.py index f1eea1326..334e92c90 100644 --- a/jax/experimental/export/export.py +++ b/jax/experimental/export/export.py @@ -764,7 +764,7 @@ def _check_lowering(lowering) -> None: "tuple_args", "ordered_effects", "unordered_effects", "keepalive", "host_callbacks", "pmap_nreps", "committed", "device_assignment", "jaxpr_debug_info", "shape_poly_state", - "all_default_mem_kind"] + "all_default_mem_kind", "in_layouts", "out_layouts"] for compile_arg in lowering.compile_args.keys(): if compile_arg not in allowed_compile_args: raise NotImplementedError(f"Unrecognized lowered.compile_args[{compile_arg}]") diff --git a/jax/experimental/layout.py b/jax/experimental/layout.py new file mode 100644 index 000000000..5551a4ff6 --- /dev/null +++ b/jax/experimental/layout.py @@ -0,0 +1,19 @@ +# Copyright 2023 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from jax._src.layout import ( + DefaultLayout as DefaultLayout, + SpecifiedLayout as SpecifiedLayout, + AUTO as AUTO, +) diff --git a/tests/BUILD b/tests/BUILD index 7e0af6c6a..79f91b5b2 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -222,6 +222,12 @@ jax_test( ], ) +jax_test( + name = "layout_test", + srcs = ["layout_test.py"], + tags = ["multiaccelerator"], +) + jax_test( name = "pgle_test", srcs = ["pgle_test.py"], diff --git a/tests/layout_test.py b/tests/layout_test.py new file mode 100644 index 000000000..09383cb62 --- /dev/null +++ b/tests/layout_test.py @@ -0,0 +1,226 @@ +# Copyright 2023 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import os +from absl.testing import absltest +import numpy as np + +import jax +from jax.sharding import NamedSharding, PartitionSpec as P +from jax._src import config +from jax._src import layout +from jax._src import test_util as jtu +from jax._src import xla_bridge +from jax._src.lib import xla_extension_version + +config.parse_flags_with_absl() + +prev_xla_flags = None + +def setUpModule(): + global prev_xla_flags + prev_xla_flags = os.getenv("XLA_FLAGS") + flags_str = prev_xla_flags or "" + # Don't override user-specified device count, or other XLA flags. + if "xla_force_host_platform_device_count" not in flags_str: + os.environ["XLA_FLAGS"] = (flags_str + + " --xla_force_host_platform_device_count=8") + # Clear any cached backends so new CPU backend will pick up the env var.
+ xla_bridge.get_backend.cache_clear() + + +def tearDownModule(): + if prev_xla_flags is None: + del os.environ["XLA_FLAGS"] + else: + os.environ["XLA_FLAGS"] = prev_xla_flags + xla_bridge.get_backend.cache_clear() + + +class LayoutTest(jtu.JaxTestCase): + + def setUp(self): + if not jtu.test_device_matches(['tpu']): + self.skipTest("Layouts do not work on CPU and GPU backends yet.") + if xla_extension_version < 215: + self.skipTest('All tests require xla_extension_version >= 215') + super().setUp() + + def test_auto_layout(self): + mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + shape1 = (128, 128) + shape2 = (128, 128) + + def apply(x, y): + return x.T, y.T + + def init(x, y): + return x * 2, y * 2 + + np_inp1 = np.arange(math.prod(shape1)).reshape(shape1) + arr1 = jax.device_put(np_inp1, NamedSharding(mesh, P('x', 'y'))) + np_inp2 = np.arange(math.prod(shape2)).reshape(shape2) + arr2 = jax.device_put(np_inp2, NamedSharding(mesh, P('x'))) + + lowered_apply = jax.jit(apply).lower(arr1, arr2, _in_layouts=layout.AUTO, + _out_layouts=layout.AUTO) + compiled_apply = lowered_apply.compile() + + arg_layouts, kw_layouts = compiled_apply._input_layouts() + self.assertEmpty(kw_layouts) + for i, o in zip(arg_layouts, compiled_apply._output_layouts()): + self.assertEqual(i.minor_to_major, o.minor_to_major[::-1]) + + init_compiled = jax.jit(init).lower( + arr1, arr2, _out_layouts=arg_layouts).compile() + + for i, o in zip(init_compiled._input_layouts()[0], + init_compiled._output_layouts()): + self.assertEqual(i.minor_to_major, o.minor_to_major) + + with jtu.count_aot_jit_cpp_cache_miss() as init_count: + init_out = init_compiled(arr1, arr2) + init_compiled(arr1, arr2) + self.assertEqual(init_count[0], 1) + + with jtu.count_aot_jit_cpp_cache_miss() as apply_count: + apply_out = compiled_apply(*init_out) + compiled_apply(*init_out) + self.assertEqual(apply_count[0], 1) + + self.assertArraysEqual(init_out[0], np_inp1 * 2) + self.assertArraysEqual(init_out[1], np_inp2 * 2) + self.assertArraysEqual(apply_out[0], (np_inp1 * 2).T) + self.assertArraysEqual(apply_out[1], (np_inp2 * 2).T) + + def test_specified_layouts_on_jit(self): + mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + shape = (8, 4, 2) + np_inp = np.arange(math.prod(shape)).reshape(shape) + s = NamedSharding(mesh, P('x', 'y')) + arr = jax.device_put(np_inp, s) + + def f(x): + return x.T + + sl = layout.SpecifiedLayout((0, 2, 1)) + + out1 = jax.jit(lambda x: x).lower(arr, _out_layouts=sl).compile()(arr) + + compiled = jax.jit(f).lower(out1, _in_layouts=sl, _out_layouts=sl).compile() + out2 = compiled(out1) + + self.assertEqual(compiled._input_layouts()[0][0], sl) + self.assertEqual(compiled._output_layouts(), sl) + self.assertArraysEqual(out2, np_inp.T) + self.assertEqual(out2.sharding, NamedSharding(mesh, P(None, 'y', 'x'))) + + def test_default_layout(self): + mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + shape = (8, 4, 2) + np_inp = np.arange(math.prod(shape)).reshape(shape) + s = NamedSharding(mesh, P('x', 'y')) + arr = jax.device_put(np_inp, s) + + def f(x): + return x.T + + compiled = jax.jit(f).lower( + arr, + _in_layouts=layout.DefaultLayout(), + _out_layouts=layout.DefaultLayout()).compile() + out = compiled(arr) + + self.assertTupleEqual(compiled._input_layouts()[0][0].minor_to_major, (2, 1, 0)) + self.assertTupleEqual(compiled._output_layouts().minor_to_major, (2, 1, 0)) + self.assertArraysEqual(out, np_inp.T) + self.assertEqual(out.sharding, NamedSharding(mesh, P(None, 'y', 'x'))) + + compiled_auto = 
jax.jit(f).lower(arr, _in_layouts=layout.AUTO, + _out_layouts=layout.AUTO).compile() + self.assertTupleEqual(compiled_auto._input_layouts()[0][0].minor_to_major, + (2, 1, 0)) + self.assertTupleEqual(compiled_auto._output_layouts().minor_to_major, + (0, 1, 2)) + + # TODO(yashkatariya): Enable after mixture of auto, default and specified + # layouts work. + # def test_auto_specified_default_layout_with_sharding(self): + # mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + # shape = (8, 4, 2) + # np_inp = np.arange(math.prod(shape)).reshape(shape) + # s = NamedSharding(mesh, P('x', 'y')) + # arr = jax.device_put(np_inp, s) + + # def f(x, y, z): + # return x.T, y.T, z * 2 + + # lowered = jax.jit(f).lower( + # arr, arr, arr, + # _in_layouts=( + # layout.SpecifiedLayout((0, 2, 1)), + # layout.AUTO, + # layout.DefaultLayout()), + # _out_layouts=layout.AUTO) + # compiled = lowered.compile() + + # il1, il2, il3 = compiled._input_layouts()[0] + # ol1, ol2, ol3 = compiled._output_layouts() + + # self.assertTupleEqual(il1.minor_to_major, ol1.minor_to_major[::-1]) + # self.assertTupleEqual(il2.minor_to_major, ol2.minor_to_major[::-1]) + # self.assertEqual(il3, ol3) + + # out1, out2, out3 = compiled(arr, arr, arr) + # np_inp_t = np_inp.T + # self.assertArraysEqual(out1, np_inp_t) + # self.assertArraysEqual(out2, np_inp_t) + # self.assertArraysEqual(out3, np_inp * 2) + + def test_in_layouts_out_layouts(self): + mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + shape = (8, 8) + np_inp = np.arange(math.prod(shape)).reshape(shape) + s = NamedSharding(mesh, P('x', 'y')) + arr = jax.device_put(np_inp, s) + + def f(x): + return x.T + compiled = jax.jit(f).lower(arr, _in_layouts=layout.DefaultLayout(), + _out_layouts=layout.AUTO).compile() + self.assertTupleEqual(compiled._input_layouts()[0][0].minor_to_major, (1, 0)) + self.assertTupleEqual(compiled._output_layouts().minor_to_major, (0, 1)) + + out = compiled(arr) + self.assertArraysEqual(out, np_inp.T) + self.assertEqual(out.sharding, NamedSharding(mesh, P('y', 'x'))) + + def test_sharding_and_layouts(self): + mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + shape = (4, 8) + np_inp = np.arange(math.prod(shape)).reshape(shape) + s = NamedSharding(mesh, P('x', 'y')) + + compiled = jax.jit(lambda x: x.T, in_shardings=s, out_shardings=s).lower( + np_inp, _in_layouts=layout.AUTO, _out_layouts=layout.AUTO).compile() + out = compiled(np_inp) + self.assertTupleEqual(compiled._input_layouts()[0][0].minor_to_major, (1, 0)) + self.assertTupleEqual(compiled._output_layouts().minor_to_major, (0, 1)) + self.assertArraysEqual(out, np_inp.T) + self.assertEqual(out.sharding, s) + + +if __name__ == '__main__': + absltest.main(testLoader=jtu.JaxTestLoader())
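For reference, the workflow these tests drive at (compile a consumer with AUTO layouts, then pin a producer's output layouts to the consumer's preferred input layouts so no relayout is needed between them) reads roughly as follows. This sketch is adapted from test_auto_layout above and assumes a TPU backend and xla_extension_version >= 215.

import numpy as np
import jax
from jax.experimental.layout import AUTO

def apply(a, b):
  return a.T, b.T

def init(a, b):
  return a * 2, b * 2

x = np.arange(64.0).reshape(8, 8)
y = np.arange(64.0).reshape(8, 8)

# Let XLA pick whatever input/output layouts it prefers for the consumer.
compiled_apply = jax.jit(apply).lower(x, y, _in_layouts=AUTO, _out_layouts=AUTO).compile()
arg_layouts, _ = compiled_apply._input_layouts()

# Ask the producer to emit its results directly in the layouts the consumer wants.
compiled_init = jax.jit(init).lower(x, y, _out_layouts=arg_layouts).compile()
outs = compiled_init(x, y)
final = compiled_apply(*outs)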