Add a private API to allow setting layouts on jitted computations.

We expose three modes:

* `SpecifiedLayout`: The user specifies the `minor_to_major` field of the layout. Tiling is not exposed yet.

* `DefaultLayout`: PJRT chooses the layout; this matches the current behavior.

* `AUTO`: The compiler chooses the layout. This is not a layout per se but a request to obtain one from the compiler. It cannot be attached to an `Array` or other data types; it can only be passed to `jit`.

Public API coming soon.
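
For illustration, a minimal usage sketch distilled from the tests added in this change (these are private keyword arguments on `.lower()` and subject to change; the tests currently run only on TPU):

```python
import numpy as np
import jax
from jax._src import layout  # private module added in this change

x = np.arange(16 * 128, dtype=np.float32).reshape(16, 128)

# AUTO: ask the compiler to choose layouts for inputs and outputs.
compiled = jax.jit(lambda a: a.T).lower(
    x, _in_layouts=layout.AUTO, _out_layouts=layout.AUTO).compile()
arg_layouts, kw_layouts = compiled._input_layouts()   # mirrors (args, kwargs)
print(arg_layouts[0].minor_to_major,
      compiled._output_layouts().minor_to_major)

# SpecifiedLayout / DefaultLayout: pin the input's minor_to_major order (here
# the row-major default) and let PJRT pick the output layout.
pinned = jax.jit(lambda a: a.T).lower(
    x, _in_layouts=layout.SpecifiedLayout((1, 0)),
    _out_layouts=layout.DefaultLayout()).compile()
```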

Co-authored-by: Roy Frostig <frostig@google.com>
PiperOrigin-RevId: 582692036
Author: Yash Katariya, 2023-11-15 08:48:17 -08:00 (committed by jax authors)
Parent: b032a0271e
Commit: 5c3da219c0
12 changed files with 572 additions and 44 deletions


@ -228,6 +228,7 @@ py_library_providing_imports_info(
":effects",
":environment_info",
":jaxpr_util",
":layout",
":lazy_loader",
":mesh",
":mlir",
@ -478,6 +479,7 @@ pytype_strict_library(
":core",
":dtypes",
":effects",
":layout",
":op_shardings",
":partial_eval",
":pickle_util",
@ -633,6 +635,16 @@ pytype_strict_library(
],
)
pytype_strict_library(
name = "layout",
srcs = ["_src/layout.py"],
deps = [
":util",
":xla_bridge",
"//jax/_src/lib",
],
)
pytype_strict_library(
name = "sharding_impls",
srcs = ["_src/sharding_impls.py"],


@ -304,13 +304,18 @@ def jit(
static_argnums, static_argnames, device, backend, abstracted_axes)
def infer_params(*args, **kwargs):
# TODO(yashkatariya): Remove this when it's added on jit. Also default to
# layout.DefaultLayout() when out of experimental.
in_layouts = kwargs.pop('_in_layouts', None)
out_layouts = kwargs.pop('_out_layouts', None)
pjit_info_args = pjit.PjitInfo(
fun=fun, in_shardings=in_shardings,
out_shardings=out_shardings, static_argnums=static_argnums,
static_argnames=static_argnames, donate_argnums=donate_argnums,
donate_argnames=donate_argnames, device=device, backend=backend,
keep_unused=keep_unused, inline=inline, resource_env=None,
abstracted_axes=abstracted_axes)
abstracted_axes=abstracted_axes, in_layouts=in_layouts,
out_layouts=out_layouts)
return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
has_explicit_sharding = pjit._pjit_explicit_sharding(


@ -180,8 +180,8 @@ def sharded_lowering(
return pxla.lower_sharding_computation(
fun, 'jit', name, in_shardings_unspec, UNSPECIFIED, donated_invars,
in_avals, keep_unused=keep_unused, inline=inline,
devices_from_context=None,
lowering_parameters=lowering_parameters)
devices_from_context=None, lowering_parameters=lowering_parameters,
in_layouts=(None,) * len(in_avals), out_layouts=None)
def simple_impl(prim):


@ -41,6 +41,7 @@ from jax._src import util
from jax._src import xla_bridge as xb
from jax._src.interpreters import partial_eval as pe
from jax._src.interpreters import xla
from jax._src.layout import XLACompatibleLayout, LayoutRequest
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension
from jax._src.lib import xla_extension_version
@ -691,6 +692,15 @@ def _to_logical_op_sharding(
assert isinstance(aval, (core.ShapedArray, core.DShapedArray))
return sharding._to_xla_hlo_sharding(aval.ndim)
def _to_xla_layout(layout: XLACompatibleLayout | None | LayoutRequest) -> str | None:
if layout is None:
return None
if isinstance(layout, LayoutRequest):
return "auto"
return layout._to_xla_layout()
def _get_mem_kind(s: Optional[XLACompatibleSharding]) -> Optional[str]:
if s is None:
return None
@ -711,6 +721,8 @@ def lower_jaxpr_to_module(
replicated_args: Sequence[bool] | None = None,
arg_shardings: Sequence[XLACompatibleSharding | None] | None = None,
result_shardings: Sequence[XLACompatibleSharding | None] | None = None,
in_layouts: Sequence[XLACompatibleLayout | None | LayoutRequest] | None = None,
out_layouts: Sequence[XLACompatibleLayout | None | LayoutRequest] | None = None,
arg_names: Sequence[str | None] | None = None,
result_names: Sequence[str | None] | None = None,
num_replicas: int = 1,
@ -784,6 +796,11 @@ def lower_jaxpr_to_module(
map(_to_logical_op_sharding, jaxpr.out_avals, result_shardings)
if result_shardings is not None else result_shardings)
arg_layouts = (map(_to_xla_layout, in_layouts) if in_layouts is not None
else in_layouts)
result_layouts = (map(_to_xla_layout, out_layouts) if out_layouts is not None
else out_layouts)
ctx = ModuleContext(backend_or_name=backend_or_name,
platforms=platforms, axis_context=axis_context,
name_stack=name_stack,
@ -815,7 +832,9 @@ def lower_jaxpr_to_module(
arg_names=arg_names,
result_names=result_names,
arg_memory_kinds=arg_memory_kinds,
result_memory_kinds=result_memory_kinds)
result_memory_kinds=result_memory_kinds,
arg_layouts=arg_layouts,
result_layouts=result_layouts)
try:
if not ctx.module.operation.verify():
@ -969,6 +988,8 @@ def lower_jaxpr_to_fun(
result_names: Sequence[str | None] | None = None,
arg_memory_kinds: Sequence[str | None] | None = None,
result_memory_kinds: Sequence[str | None] | None = None,
arg_layouts: Sequence[str | None] | None = None,
result_layouts: Sequence[str | None] | None = None,
) -> func_dialect.FuncOp:
"""Lowers jaxpr and its callees to an IR function.
@ -1055,6 +1076,12 @@ def lower_jaxpr_to_fun(
if result_memory_kinds is not None:
token_memory_kinds = [None] * (num_tokens + num_output_tokens)
result_memory_kinds = [*token_memory_kinds, *result_memory_kinds]
if arg_layouts is not None:
token_layouts = [None] * (num_dim_vars + num_tokens)
arg_layouts = [*token_layouts, *arg_layouts]
if result_layouts is not None:
token_layouts = [None] * (num_tokens + num_output_tokens)
result_layouts = [*token_layouts, *result_layouts]
flat_input_types = util.flatten(input_types)
flat_output_types = util.flatten(output_types)
@ -1077,6 +1104,11 @@ def lower_jaxpr_to_fun(
ir_arg_memory_kinds = util.flatten(
[[mk] * len(types) for mk, types in zip(arg_memory_kinds, input_types)])
ir_arg_layouts = None
if arg_layouts is not None:
ir_arg_layouts = util.flatten(
[[l] * len(types) for l, types in zip(arg_layouts, input_types)])
ir_result_shardings = None
if result_shardings is not None:
out_avals = [None] * (num_tokens + num_output_tokens) + list(jaxpr.out_avals)
@ -1090,9 +1122,15 @@ def lower_jaxpr_to_fun(
ir_result_memory_kinds = util.flatten(
[[mk] * len(types) for mk, types in zip(result_memory_kinds, output_types)])
ir_result_layouts = None
if result_layouts is not None:
ir_result_layouts = util.flatten(
[[l] * len(types) for l, types in zip(result_layouts, output_types)])
if (
replicated_args is not None
or ir_arg_shardings is not None
or ir_arg_layouts is not None
or input_output_aliases is not None
or arg_names is not None
or num_tokens > 0
@ -1113,6 +1151,11 @@ def lower_jaxpr_to_fun(
if sharding is not None:
attrs["mhlo.sharding"] = get_sharding_attr(sharding)
if ir_arg_layouts is not None:
for attrs, layout in zip(arg_attrs, ir_arg_layouts):
if layout is not None:
attrs["mhlo.layout_mode"] = ir.StringAttr.get(layout)
if input_output_aliases is not None:
output_ids = util.unflatten(list(range(len(flat_output_types))),
map(len, output_types))
@ -1162,6 +1205,11 @@ def lower_jaxpr_to_fun(
if sharding is not None:
attrs['mhlo.sharding'] = get_sharding_attr(sharding)
if ir_result_layouts is not None:
for attrs, layout in zip(result_attrs, ir_result_layouts):
if layout is not None:
attrs['mhlo.layout_mode'] = ir.StringAttr.get(layout)
func_op.result_attrs = ir.ArrayAttr.get(
[ir.DictAttr.get(attrs) for attrs in result_attrs])
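
As a rough reference for the helper above, the strings that end up in the `mhlo.layout_mode` argument/result attributes look like this (the exact layout string comes from `xla_client.Layout.to_string()`; the `"{0,1}"` form shown is an assumption):

```python
from jax._src import layout

layout.DefaultLayout()._to_xla_layout()          # -> "default"
layout.SpecifiedLayout((0, 1))._to_xla_layout()  # -> an XLA layout string, e.g. "{0,1}"
# layout.AUTO is a LayoutRequest, not a layout; _to_xla_layout above maps it to
# the string "auto". None means "no constraint" and attaches no attribute.
```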


@ -59,6 +59,7 @@ from jax._src.interpreters import batching
from jax._src.interpreters import partial_eval as pe
from jax._src.interpreters import mlir
from jax._src.interpreters import xla
from jax._src.layout import XLACompatibleLayout, SpecifiedLayout, LayoutRequest
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension_version
from jax._src.lib.mlir import ir
@ -1768,7 +1769,7 @@ def _raise_warnings_or_errors_for_jit_of_pmap(
@weakref_lru_cache
def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
semantic_in_shardings, semantic_out_shardings,
da_object,
in_layouts, out_layouts, da_object,
donated_invars, name_stack, all_default_mem_kind,
lowering_parameters: mlir.LoweringParameters):
jaxpr = closed_jaxpr.jaxpr
@ -1842,6 +1843,8 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
replicated_args=replicated_args,
arg_shardings=in_mlir_shardings,
result_shardings=out_mlir_shardings,
in_layouts=in_layouts,
out_layouts=out_layouts,
arg_names=jaxpr.debug_info and jaxpr.debug_info.arg_names,
result_names=jaxpr.debug_info and jaxpr.debug_info.result_paths,
num_replicas=nreps,
@ -1939,6 +1942,7 @@ def are_all_shardings_default_mem_kind(da_object, shardings):
return False
return True
MaybeLayout = Sequence[Optional[Union[XLACompatibleLayout, LayoutRequest]]]
@profiler.annotate_function
def lower_sharding_computation(
@ -1954,6 +1958,8 @@ def lower_sharding_computation(
inline: bool,
devices_from_context: Sequence[xc.Device] | None = None,
lowering_parameters: mlir.LoweringParameters,
in_layouts: MaybeLayout,
out_layouts: Optional[MaybeLayout],
) -> MeshComputation:
"""Lowers a computation to XLA. It can take arbitrary shardings as input.
@ -1973,12 +1979,16 @@ def lower_sharding_computation(
donated_invars, auto_spmd_lowering)
jaxpr = closed_jaxpr.jaxpr
in_shardings = tuple(s for i, s in enumerate(in_shardings) if i in kept_var_idx)
in_layouts = tuple(l for i, l in enumerate(in_layouts) if i in kept_var_idx)
if is_unspecified(out_shardings):
out_shardings = (UNSPECIFIED,) * len(global_out_avals)
if out_layouts is None:
out_layouts = (None,) * len(global_out_avals)
assert isinstance(out_shardings, tuple)
assert len(out_shardings) == len(global_out_avals), (
len(out_shardings), len(global_out_avals))
assert isinstance(out_layouts, tuple)
assert len(out_shardings) == len(out_layouts) == len(global_out_avals), (
len(out_shardings), len(out_layouts), len(global_out_avals))
# Device assignment across all inputs, outputs and shardings inside jaxpr
# should be the same.
@ -2030,7 +2040,7 @@ def lower_sharding_computation(
(module, keepalive, host_callbacks, unordered_effects, ordered_effects,
nreps, tuple_args, shape_poly_state) = _cached_lowering_to_hlo(
closed_jaxpr, api_name, fun_name, backend, semantic_in_shardings,
semantic_out_shardings, da_object,
semantic_out_shardings, in_layouts, out_layouts, da_object,
donated_invars, name_stack, all_default_mem_kind,
lowering_parameters=lowering_parameters)
@ -2058,6 +2068,8 @@ def lower_sharding_computation(
backend=backend,
device_assignment=da_object,
committed=committed,
in_layouts=in_layouts,
out_layouts=out_layouts,
pmap_nreps=nreps,
jaxpr_debug_info=closed_jaxpr.jaxpr.debug_info,
shape_poly_state=shape_poly_state,
@ -2233,6 +2245,8 @@ def lower_mesh_computation(
backend=backend,
device_assignment=_create_da_object(tuple(mesh.devices.flat)),
committed=True,
in_layouts=(None,) * len(global_in_avals),
out_layouts=(None,) * len(global_out_avals),
jaxpr_debug_info=closed_jaxpr.jaxpr.debug_info,
shape_poly_state=lowering_result.shape_poly_state)
@ -2447,6 +2461,42 @@ def maybe_get_orig_out_sharding(
return out_shardings, are_out_shardings_from_xla
def _get_layouts_from_executable(
xla_executable, in_layouts, out_layouts
) -> tuple[Sequence[XLACompatibleLayout | None], Sequence[XLACompatibleLayout | None]]:
if all(i is None for i in in_layouts) and all(o is None for o in out_layouts):
return in_layouts, out_layouts # type: ignore
in_layouts_xla = xla_executable.get_parameter_layouts()
out_layouts_xla = xla_executable.get_output_layouts()
new_in_layouts = []
for x, i in safe_zip(in_layouts_xla, in_layouts):
x = SpecifiedLayout._from_xla_layout(x)
if isinstance(i, SpecifiedLayout):
if i != x:
raise AssertionError(
f"Unexpected XLA layout override: (XLA) {x} != {i} (User sharding)")
new_in_layouts.append(i)
else:
new_in_layouts.append(x)
new_out_layouts = []
for x, o in safe_zip(out_layouts_xla, out_layouts):
x = SpecifiedLayout._from_xla_layout(x)
if isinstance(o, SpecifiedLayout):
if o != x:
raise AssertionError(
f"Unexpected XLA layout override: (XLA) {x} != {o} (User sharding)")
new_out_layouts.append(o)
else:
new_out_layouts.append(x)
assert all(isinstance(i, SpecifiedLayout) for i in new_in_layouts)
assert all(isinstance(o, SpecifiedLayout) for o in new_out_layouts)
return new_in_layouts, new_out_layouts
@weakref_lru_cache
def _cached_compilation(computation, name, mesh, spmd_lowering,
tuple_args, auto_spmd_lowering,
@ -2534,6 +2584,8 @@ class UnloadedMeshExecutable:
kept_var_idx: set[int]
auto_spmd_lowering: bool
jaxpr_debug_info: core.JaxprDebugInfo | None
in_layouts: Sequence[SpecifiedLayout | None]
out_layouts: Sequence[SpecifiedLayout | None]
def build_unsafe_call(self):
input_indices = _get_input_indices(self.input_avals, self.input_shardings,
@ -2555,6 +2607,7 @@ class UnloadedMeshExecutable:
self.input_avals,
self.input_shardings, self.output_shardings,
self.auto_spmd_lowering, self.kept_var_idx,
self.in_layouts, self.out_layouts,
self.jaxpr_debug_info, self)
# May return a MeshExecutable in the compile_replicated case.
@ -2577,6 +2630,8 @@ class UnloadedMeshExecutable:
backend: xb.XlaBackend,
device_assignment: _DeviceAssignment | Sequence[xc.Device], # type: ignore
committed: bool,
in_layouts: MaybeLayout,
out_layouts: MaybeLayout,
pmap_nreps: int = 1,
jaxpr_debug_info: core.JaxprDebugInfo | None = None,
shape_poly_state: mlir.ShapePolyLoweringState | None = None,
@ -2659,6 +2714,9 @@ class UnloadedMeshExecutable:
else:
are_out_shardings_from_xla = (False,) * len(global_out_avals)
in_layouts, out_layouts = _get_layouts_from_executable(
xla_executable, in_layouts, out_layouts)
if pmap_nreps > 1:
in_shardings, out_shardings, committed, da = _get_metadata_jit_pmap(
xla_executable.local_devices(), len(in_shardings), len(out_shardings))
@ -2684,7 +2742,9 @@ class UnloadedMeshExecutable:
host_callbacks=host_callbacks,
kept_var_idx=kept_var_idx,
auto_spmd_lowering=auto_spmd_lowering,
jaxpr_debug_info=jaxpr_debug_info).load()
jaxpr_debug_info=jaxpr_debug_info,
in_layouts=in_layouts, # type: ignore
out_layouts=out_layouts).load() # type: ignore
class MeshExecutableFastpathData(NamedTuple):
@ -2709,12 +2769,13 @@ class MeshExecutable(stages.XlaExecutable):
__slots__ = [
"xla_executable", "_unsafe_call", "build_unsafe_call", "in_avals",
"_in_shardings", "_out_shardings", "_auto_spmd_lowering", "_kept_var_idx",
"_jaxpr_debug_info", "_unloaded_executable",
"_in_layouts", "_out_layouts", "_jaxpr_debug_info", "_unloaded_executable",
]
def __init__(self, xla_executable, build_unsafe_call, in_avals, in_shardings,
out_shardings, auto_spmd_lowering, kept_var_idx,
jaxpr_debug_info=None, unloaded_executable=None):
in_layouts, out_layouts, jaxpr_debug_info=None,
unloaded_executable=None):
self.xla_executable = xla_executable
self.build_unsafe_call = build_unsafe_call
# in_avals is a list of global and local avals. Aval is global if input
@ -2725,6 +2786,8 @@ class MeshExecutable(stages.XlaExecutable):
self._out_shardings = out_shardings
self._auto_spmd_lowering = auto_spmd_lowering
self._kept_var_idx = kept_var_idx
self._in_layouts = in_layouts
self._out_layouts = out_layouts
self._jaxpr_debug_info = jaxpr_debug_info
self._unloaded_executable = unloaded_executable
@ -2755,6 +2818,12 @@ class MeshExecutable(stages.XlaExecutable):
def output_shardings(self) -> Sequence[sharding_impls.XLACompatibleSharding]:
return self._out_shardings
def input_layouts(self):
return self._in_layouts
def output_layouts(self):
return self._out_layouts
def create_cpp_call(self, no_kwargs, in_tree, out_tree):
if not (isinstance(self.unsafe_call, ExecuteReplicated) and
not self.unsafe_call.has_unordered_effects and
@ -2886,7 +2955,8 @@ def _compile_replicated_mesh_executable_from_hlo(
xla_executable = None
return MeshExecutable(xla_executable, lambda: unsafe_call, global_in_avals,
in_shardings, out_shardings, auto_spmd_lowering,
kept_var_idx, jaxpr_debug_info, None)
kept_var_idx, (None,) * len(global_in_avals),
(None,) * len(global_out_avals), jaxpr_debug_info, None)
@lru_cache

jax/_src/layout.py (new file, 78 lines)

@ -0,0 +1,78 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from jax._src.lib import xla_client as xc
class Layout:
pass
class XLACompatibleLayout(Layout):
@classmethod
def _from_xla_layout(cls, xla_layout) -> XLACompatibleLayout:
raise NotImplementedError("Subclasses should implement this method.")
def _to_xla_layout(self) -> str:
raise NotImplementedError("Subclasses should implement this method.")
class SpecifiedLayout(XLACompatibleLayout):
minor_to_major: tuple[int, ...]
def __init__(self, minor_to_major: tuple[int, ...]):
self.minor_to_major = minor_to_major
def __repr__(self):
return f'SpecifiedLayout(minor_to_major={self.minor_to_major})'
def __hash__(self):
return hash(self.minor_to_major)
def __eq__(self, other):
if not isinstance(other, SpecifiedLayout):
return False
return self.minor_to_major == other.minor_to_major
@classmethod
def _from_xla_layout(cls, xla_layout: xc.Layout) -> XLACompatibleLayout:
return cls(xla_layout.minor_to_major())
def _to_xla_layout(self) -> str:
return xc.Layout(self.minor_to_major).to_string()
class DefaultLayout(XLACompatibleLayout):
def __repr__(self):
return 'DefaultLayout()'
def __hash__(self):
return hash(type(self))
def __eq__(self, other):
return isinstance(other, DefaultLayout) and type(self) == type(other)
def _to_xla_layout(self) -> str:
return "default"
class LayoutRequest:
def __repr__(self):
return "Request a layout from the compiler"
AUTO = LayoutRequest()
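
A small sketch of the value semantics these classes define (equality and hashing go by `minor_to_major`, so layout objects can be used as cache keys):

```python
from jax._src.layout import AUTO, DefaultLayout, Layout, SpecifiedLayout

a = SpecifiedLayout((0, 2, 1))
b = SpecifiedLayout((0, 2, 1))
assert a == b and hash(a) == hash(b)       # compared and hashed by minor_to_major
assert a != SpecifiedLayout((2, 1, 0))
assert DefaultLayout() == DefaultLayout()  # all defaults are interchangeable
assert not isinstance(AUTO, Layout)        # AUTO is a request, not a layout
```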


@ -159,7 +159,7 @@ def _device_assignment_mismatch_error(fun_name, fails, args_flat, api_name,
def _python_pjit_helper(fun, infer_params_fn, *args, **kwargs):
args_flat, _, params, in_tree, out_tree, _ = infer_params_fn(
args_flat, _, params, in_tree, out_tree, _, _, _ = infer_params_fn(
*args, **kwargs)
for arg in args_flat:
dispatch.check_arg(arg)
@ -328,8 +328,13 @@ def post_infer_params(fun, infer_params_fn, static_argnums, static_argnames,
def lower(*args, **kwargs):
lowering_parameters = kwargs.pop(
'_experimental_lowering_parameters', mlir.LoweringParameters())
# TODO(yashkatariya): Remove this when it's added on jit. Also default to
# layout.DefaultLayout() when out of experimental.
in_layouts = kwargs.pop('_in_layouts', None)
out_layouts = kwargs.pop('_out_layouts', None)
(args_flat, flat_global_in_avals, params, in_tree, out_tree,
donated_invars) = infer_params_fn(*args, **kwargs)
donated_invars, in_layouts_flat, out_layouts_flat) = infer_params_fn(
*args, **kwargs, _in_layouts=in_layouts, _out_layouts=out_layouts)
resource_env = params['resource_env']
mesh = None if resource_env is None else resource_env.physical_mesh
try:
@ -338,8 +343,8 @@ def post_infer_params(fun, infer_params_fn, static_argnums, static_argnames,
lowering = _pjit_lower(
params['jaxpr'], in_shardings, params['out_shardings'],
params['resource_env'], params['donated_invars'], params['name'],
params['keep_unused'], params['inline'],
lowering_parameters=lowering_parameters)
params['keep_unused'], params['inline'], in_layouts=in_layouts_flat,
out_layouts=out_layouts_flat, lowering_parameters=lowering_parameters)
except pxla.DeviceAssignmentMismatchError as e:
fails, = e.args
api_name = 'jit' if params['resource_env'] is None else 'pjit'
@ -387,12 +392,14 @@ class PjitInfo(NamedTuple):
inline: bool
resource_env: Any
abstracted_axes: Optional[Any]
in_layouts: Any # pytree[XLACompatibleLayout] | None
out_layouts: Any # pytree[XLACompatibleLayout] | None
def common_infer_params(pjit_info_args, *args, **kwargs):
(fun, user_in_shardings, user_out_shardings, static_argnums, static_argnames,
donate_argnums, donate_argnames, device, backend, keep_unused, inline,
resource_env, abstracted_axes) = pjit_info_args
resource_env, abstracted_axes, in_layouts, out_layouts) = pjit_info_args
if (kwargs and user_in_shardings is not None and
not is_unspecified(user_in_shardings)):
@ -479,16 +486,16 @@ def common_infer_params(pjit_info_args, *args, **kwargs):
) from e
in_type = in_avals = tuple(avals)
canonicalized_in_shardings_flat = _process_in_axis_resources(
hashable_pytree(in_shardings), in_avals, in_tree, resource_env, dbg,
device_or_backend_set)
canonicalized_in_shardings_flat, in_layouts_flat = _process_in_axis_resources(
hashable_pytree(in_shardings), hashable_pytree(in_layouts), in_avals,
in_tree, resource_env, dbg, device_or_backend_set)
jaxpr, consts, canonicalized_out_shardings_flat = _pjit_jaxpr(
flat_fun, hashable_pytree(out_shardings), in_type, dbg,
device_or_backend_set, HashableFunction(out_tree, closure=()),
jaxpr, consts, canonicalized_out_shardings_flat, out_layouts_flat = _pjit_jaxpr(
flat_fun, hashable_pytree(out_shardings), hashable_pytree(out_layouts),
in_type, dbg, device_or_backend_set, HashableFunction(out_tree, closure=()),
HashableFunction(res_paths, closure=()))
assert len(explicit_args) == len(canonicalized_in_shardings_flat)
assert len(explicit_args) == len(canonicalized_in_shardings_flat) == len(in_layouts_flat)
if config.dynamic_shapes.value:
implicit_args = _extract_implicit_args(in_type, explicit_args)
@ -499,9 +506,10 @@ def common_infer_params(pjit_info_args, *args, **kwargs):
num_extra_args = len(implicit_args) + len(consts)
canonicalized_in_shardings_flat = \
(UNSPECIFIED,) * num_extra_args + canonicalized_in_shardings_flat
in_layouts_flat = (None,) * num_extra_args + in_layouts_flat
donated_invars = (False,) * num_extra_args + donated_invars
assert (len(canonicalized_in_shardings_flat) == len(donated_invars) ==
len(consts) + len(args_flat))
assert (len(canonicalized_in_shardings_flat) == len(in_layouts_flat) ==
len(donated_invars) == len(consts) + len(args_flat))
# in_shardings and out_shardings here are all GSPMDSharding.
params = dict(
@ -515,7 +523,7 @@ def common_infer_params(pjit_info_args, *args, **kwargs):
inline=inline,
)
return (consts + args_flat, in_type, params, in_tree, out_tree(),
donated_invars)
donated_invars, in_layouts_flat, out_layouts_flat)
def _extract_implicit_args(
in_type: Sequence[tuple[core.AbstractValue, bool]],
@ -758,13 +766,18 @@ def pjit(
def infer_params(*args, **kwargs):
# Putting this outside of wrapped would make resources lexically scoped
resource_env = mesh_lib.thread_resources.env
# TODO(yashkatariya): Remove this when it's added on jit. Also default to
# layout.DefaultLayout() when out of experimental.
in_layouts = kwargs.pop('_in_layouts', None)
out_layouts = kwargs.pop('_out_layouts', None)
pjit_info_args = PjitInfo(
fun=fun, in_shardings=in_shardings,
out_shardings=out_shardings, static_argnums=static_argnums,
static_argnames=static_argnames, donate_argnums=donate_argnums,
donate_argnames=donate_argnames, device=device, backend=backend,
keep_unused=keep_unused, inline=inline, resource_env=resource_env,
abstracted_axes=abstracted_axes)
abstracted_axes=abstracted_axes, in_layouts=in_layouts,
out_layouts=out_layouts)
return common_infer_params(pjit_info_args, *args, **kwargs)
has_explicit_sharding = _pjit_explicit_sharding(
@ -880,8 +893,9 @@ class PytreeLeaf:
@lru_cache(maxsize=4096)
def _process_in_axis_resources(in_shardings_thunk, in_avals, in_tree,
resource_env, debug_info, device_or_backend_set):
def _process_in_axis_resources(in_shardings_thunk, in_layouts_thunk, in_avals,
in_tree, resource_env, debug_info,
device_or_backend_set):
orig_in_shardings = in_shardings_thunk()
# Only do this if original in_shardings are unspecified. If it is AUTO, go
# via flatten_axis_resources.
@ -889,8 +903,14 @@ def _process_in_axis_resources(in_shardings_thunk, in_avals, in_tree,
in_shardings_flat = (orig_in_shardings,) * len(in_avals)
else:
in_shardings_flat = flatten_axis_resources(
"pjit in_shardings", in_tree, orig_in_shardings,
tupled_args=True)
"pjit in_shardings", in_tree, orig_in_shardings, tupled_args=True)
in_layouts = in_layouts_thunk()
if in_layouts is None:
in_layouts_flat = (in_layouts,) * len(in_avals)
else:
in_layouts_flat = flatten_axis_resources(
"pjit in_layouts", in_tree, in_layouts, tupled_args=True)
if not config.dynamic_shapes.value:
pjit_check_aval_sharding(in_shardings_flat, in_avals,
@ -900,7 +920,7 @@ def _process_in_axis_resources(in_shardings_thunk, in_avals, in_tree,
i if is_unspecified_or_auto(i) else
to_gspmd_sharding(i, aval.ndim, device_or_backend_set)
for i, aval in zip(in_shardings_flat, in_avals))
return canonicalized_shardings
return canonicalized_shardings, tuple(in_layouts_flat)
@lu.cache
@ -930,7 +950,8 @@ def _create_pjit_jaxpr(fun, in_type, debug_info, out_paths):
@lru_cache(maxsize=4096)
def _check_and_canonicalize_out_shardings(
out_shardings_thunk, out_tree, out_type, debug_info, device_or_backend_set):
out_shardings_thunk, out_layouts_thunk, out_tree, out_type, debug_info,
device_or_backend_set):
orig_out_shardings = out_shardings_thunk()
# TODO(yashkatariya): Remove the if branch and fix flatten_axis_resources
# instead. This condition exists because flatten_axis_resources passes in an
@ -944,6 +965,13 @@ def _check_and_canonicalize_out_shardings(
"pjit out_shardings", out_tree(), orig_out_shardings,
tupled_args=False)
out_layouts = out_layouts_thunk()
if out_layouts is None:
out_layouts_flat = (out_layouts,) * len(out_type)
else:
out_layouts_flat = flatten_axis_resources(
"pjit out_layouts", out_tree(), out_layouts, tupled_args=False)
if not config.dynamic_shapes.value:
pjit_check_aval_sharding(
out_shardings_flat, out_type,
@ -955,18 +983,18 @@ def _check_and_canonicalize_out_shardings(
to_gspmd_sharding(o, aval.ndim, device_or_backend_set)
for o, aval in zip(out_shardings_flat, out_type)
)
return canonicalized_out_shardings_flat
return canonicalized_out_shardings_flat, tuple(out_layouts_flat)
def _pjit_jaxpr(fun, out_shardings_thunk, in_type, debug_info,
def _pjit_jaxpr(fun, out_shardings_thunk, out_layouts_thunk, in_type, debug_info,
device_or_backend_set, out_tree, result_paths):
jaxpr, final_consts, out_type = _create_pjit_jaxpr(
fun, in_type, debug_info, result_paths)
canonicalized_out_shardings_flat = _check_and_canonicalize_out_shardings(
out_shardings_thunk, out_tree, tuple(out_type), jaxpr.jaxpr.debug_info,
device_or_backend_set)
canonicalized_out_shardings_flat, out_layouts_flat = _check_and_canonicalize_out_shardings(
out_shardings_thunk, out_layouts_thunk, out_tree, tuple(out_type),
jaxpr.jaxpr.debug_info, device_or_backend_set)
# lu.cache needs to be able to create weakrefs to outputs, so we can't return a plain tuple
return jaxpr, final_consts, canonicalized_out_shardings_flat
return jaxpr, final_consts, canonicalized_out_shardings_flat, out_layouts_flat
def pjit_check_aval_sharding(
@ -1271,11 +1299,20 @@ def _pjit_lower_cached(
keep_unused: bool,
inline: bool,
*,
lowering_parameters: mlir.LoweringParameters):
lowering_parameters: mlir.LoweringParameters,
in_layouts: Optional[pxla.MaybeLayout] = None,
out_layouts: Optional[pxla.MaybeLayout] = None):
in_shardings: tuple[PjitShardingMinusUnspecified, ...] = cast(
tuple[PjitShardingMinusUnspecified, ...], sdat_in_shardings.shardings)
out_shardings: tuple[PjitSharding, ...] = sdat_out_shardings.shardings
# TODO(yashkatariya): Remove this when layouts are supported on jit and
# passed to params.
if in_layouts is None:
in_layouts = (None,) * len(in_shardings)
if out_layouts is None:
out_layouts = (None,) * len(out_shardings)
if resource_env is not None:
pxla.resource_typecheck(jaxpr, resource_env, {}, lambda: "pjit")
@ -1302,8 +1339,8 @@ def _pjit_lower_cached(
keep_unused=keep_unused, inline=inline,
devices_from_context=(
None if mesh is None or mesh.empty else list(mesh.devices.flat)),
lowering_parameters=lowering_parameters,
)
lowering_parameters=lowering_parameters, in_layouts=in_layouts,
out_layouts=out_layouts)
def pjit_staging_rule(trace, *args, **params):


@ -42,6 +42,7 @@ from jax._src import source_info_util
from jax._src import traceback_util
from jax._src import tree_util
from jax._src import util
from jax._src.layout import SpecifiedLayout
from jax._src.interpreters import mlir
from jax._src.lib.mlir import ir
from jax._src.lib import xla_client as xc
@ -84,6 +85,14 @@ class Executable(Protocol):
"""
raise NotImplementedError
# Layouts are exposed via jax.experimental.layouts
# TODO(frostig,yashkatariya): expose here when no longer experimental.
def _input_layouts(self):
raise NotImplementedError
def _output_layouts(self):
raise NotImplementedError
def as_text(self) -> str:
"""A human-readable text representation of this executable.
@ -216,6 +225,14 @@ class XlaExecutable(Executable):
raise NotImplementedError(
"compiled executable carries no output sharding information")
def _input_layouts(self):
raise NotImplementedError(
"compiled executable carries no input layout information")
def _output_layouts(self):
raise NotImplementedError(
"compiled executable carries no input layout information")
def as_text(self) -> str:
xla_ext_exe = self.xla_extension_executable()
err_msg = ("text view unsupported on current XLA backend: "
@ -481,6 +498,16 @@ class Compiled(Stage):
shardings_flat = self._executable.output_shardings()
return tree_util.tree_unflatten(self.out_tree, shardings_flat) # pytype: disable=attribute-error
def _input_layouts(self):
layouts_flat = self._executable.input_layouts()
assert all(isinstance(l, SpecifiedLayout) for l in layouts_flat)
return tree_util.tree_unflatten(self.in_tree, layouts_flat) # pytype: disable=attribute-error
def _output_layouts(self):
layouts_flat = self._executable.output_layouts()
assert all(isinstance(l, SpecifiedLayout) for l in layouts_flat)
return tree_util.tree_unflatten(self.out_tree, layouts_flat) # pytype: disable=attribute-error
@staticmethod
def call(*args, **kwargs):
# This is because `__call__` passes in `self._params` as the first argument.
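
For illustration, a sketch of how these private accessors round-trip layouts through the call's pytree structure (assumes a backend where layouts are supported, per the tests below):

```python
import numpy as np
import jax
from jax._src import layout

x = np.zeros((8, 8), dtype=np.float32)
y = np.zeros((8, 8), dtype=np.float32)
compiled = jax.jit(lambda a, b: (a.T, b.T)).lower(
    x, y, _in_layouts=layout.AUTO, _out_layouts=layout.AUTO).compile()

(lx, ly), kw = compiled._input_layouts()   # inputs mirror (args, kwargs)
ox, oy = compiled._output_layouts()        # outputs mirror the return pytree
assert kw == {} and isinstance(lx, layout.SpecifiedLayout)
```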


@ -764,7 +764,7 @@ def _check_lowering(lowering) -> None:
"tuple_args", "ordered_effects", "unordered_effects",
"keepalive", "host_callbacks", "pmap_nreps", "committed",
"device_assignment", "jaxpr_debug_info", "shape_poly_state",
"all_default_mem_kind"]
"all_default_mem_kind", "in_layouts", "out_layouts"]
for compile_arg in lowering.compile_args.keys():
if compile_arg not in allowed_compile_args:
raise NotImplementedError(f"Unrecognized lowered.compile_args[{compile_arg}]")


@ -0,0 +1,19 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from jax._src.layout import (
DefaultLayout as DefaultLayout,
SpecifiedLayout as SpecifiedLayout,
AUTO as AUTO,
)


@ -222,6 +222,12 @@ jax_test(
],
)
jax_test(
name = "layout_test",
srcs = ["layout_test.py"],
tags = ["multiaccelerator"],
)
jax_test(
name = "pgle_test",
srcs = ["pgle_test.py"],

tests/layout_test.py (new file, 226 lines)

@ -0,0 +1,226 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import os
from absl.testing import absltest
import numpy as np
import jax
from jax.sharding import NamedSharding, PartitionSpec as P
from jax._src import config
from jax._src import layout
from jax._src import test_util as jtu
from jax._src import xla_bridge
from jax._src.lib import xla_extension_version
config.parse_flags_with_absl()
prev_xla_flags = None
def setUpModule():
global prev_xla_flags
prev_xla_flags = os.getenv("XLA_FLAGS")
flags_str = prev_xla_flags or ""
# Don't override user-specified device count, or other XLA flags.
if "xla_force_host_platform_device_count" not in flags_str:
os.environ["XLA_FLAGS"] = (flags_str +
" --xla_force_host_platform_device_count=8")
# Clear any cached backends so new CPU backend will pick up the env var.
xla_bridge.get_backend.cache_clear()
def tearDownModule():
if prev_xla_flags is None:
del os.environ["XLA_FLAGS"]
else:
os.environ["XLA_FLAGS"] = prev_xla_flags
xla_bridge.get_backend.cache_clear()
class LayoutTest(jtu.JaxTestCase):
def setUp(self):
if not jtu.test_device_matches(['tpu']):
self.skipTest("Layouts do not work on CPU and GPU backends yet.")
if xla_extension_version < 215:
self.skipTest('All tests require xla_extension_version >= 215')
super().setUp()
def test_auto_layout(self):
mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
shape1 = (128, 128)
shape2 = (128, 128)
def apply(x, y):
return x.T, y.T
def init(x, y):
return x * 2, y * 2
np_inp1 = np.arange(math.prod(shape1)).reshape(shape1)
arr1 = jax.device_put(np_inp1, NamedSharding(mesh, P('x', 'y')))
np_inp2 = np.arange(math.prod(shape2)).reshape(shape2)
arr2 = jax.device_put(np_inp2, NamedSharding(mesh, P('x')))
lowered_apply = jax.jit(apply).lower(arr1, arr2, _in_layouts=layout.AUTO,
_out_layouts=layout.AUTO)
compiled_apply = lowered_apply.compile()
arg_layouts, kw_layouts = compiled_apply._input_layouts()
self.assertEmpty(kw_layouts)
for i, o in zip(arg_layouts, compiled_apply._output_layouts()):
self.assertEqual(i.minor_to_major, o.minor_to_major[::-1])
init_compiled = jax.jit(init).lower(
arr1, arr2, _out_layouts=arg_layouts).compile()
for i, o in zip(init_compiled._input_layouts()[0],
init_compiled._output_layouts()):
self.assertEqual(i.minor_to_major, o.minor_to_major)
with jtu.count_aot_jit_cpp_cache_miss() as init_count:
init_out = init_compiled(arr1, arr2)
init_compiled(arr1, arr2)
self.assertEqual(init_count[0], 1)
with jtu.count_aot_jit_cpp_cache_miss() as apply_count:
apply_out = compiled_apply(*init_out)
compiled_apply(*init_out)
self.assertEqual(apply_count[0], 1)
self.assertArraysEqual(init_out[0], np_inp1 * 2)
self.assertArraysEqual(init_out[1], np_inp2 * 2)
self.assertArraysEqual(apply_out[0], (np_inp1 * 2).T)
self.assertArraysEqual(apply_out[1], (np_inp2 * 2).T)
def test_specified_layouts_on_jit(self):
mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
shape = (8, 4, 2)
np_inp = np.arange(math.prod(shape)).reshape(shape)
s = NamedSharding(mesh, P('x', 'y'))
arr = jax.device_put(np_inp, s)
def f(x):
return x.T
sl = layout.SpecifiedLayout((0, 2, 1))
out1 = jax.jit(lambda x: x).lower(arr, _out_layouts=sl).compile()(arr)
compiled = jax.jit(f).lower(out1, _in_layouts=sl, _out_layouts=sl).compile()
out2 = compiled(out1)
self.assertEqual(compiled._input_layouts()[0][0], sl)
self.assertEqual(compiled._output_layouts(), sl)
self.assertArraysEqual(out2, np_inp.T)
self.assertEqual(out2.sharding, NamedSharding(mesh, P(None, 'y', 'x')))
def test_default_layout(self):
mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
shape = (8, 4, 2)
np_inp = np.arange(math.prod(shape)).reshape(shape)
s = NamedSharding(mesh, P('x', 'y'))
arr = jax.device_put(np_inp, s)
def f(x):
return x.T
compiled = jax.jit(f).lower(
arr,
_in_layouts=layout.DefaultLayout(),
_out_layouts=layout.DefaultLayout()).compile()
out = compiled(arr)
self.assertTupleEqual(compiled._input_layouts()[0][0].minor_to_major, (2, 1, 0))
self.assertTupleEqual(compiled._output_layouts().minor_to_major, (2, 1, 0))
self.assertArraysEqual(out, np_inp.T)
self.assertEqual(out.sharding, NamedSharding(mesh, P(None, 'y', 'x')))
compiled_auto = jax.jit(f).lower(arr, _in_layouts=layout.AUTO,
_out_layouts=layout.AUTO).compile()
self.assertTupleEqual(compiled_auto._input_layouts()[0][0].minor_to_major,
(2, 1, 0))
self.assertTupleEqual(compiled_auto._output_layouts().minor_to_major,
(0, 1, 2))
# TODO(yashkatariya): Enable after mixture of auto, default and specified
# layouts work.
# def test_auto_specified_default_layout_with_sharding(self):
# mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
# shape = (8, 4, 2)
# np_inp = np.arange(math.prod(shape)).reshape(shape)
# s = NamedSharding(mesh, P('x', 'y'))
# arr = jax.device_put(np_inp, s)
# def f(x, y, z):
# return x.T, y.T, z * 2
# lowered = jax.jit(f).lower(
# arr, arr, arr,
# _in_layouts=(
# layout.SpecifiedLayout((0, 2, 1)),
# layout.AUTO,
# layout.DefaultLayout()),
# _out_layouts=layout.AUTO)
# compiled = lowered.compile()
# il1, il2, il3 = compiled._input_layouts()[0]
# ol1, ol2, ol3 = compiled._output_layouts()
# self.assertTupleEqual(il1.minor_to_major, ol1.minor_to_major[::-1])
# self.assertTupleEqual(il2.minor_to_major, ol2.minor_to_major[::-1])
# self.assertEqual(il3, ol3)
# out1, out2, out3 = compiled(arr, arr, arr)
# np_inp_t = np_inp.T
# self.assertArraysEqual(out1, np_inp_t)
# self.assertArraysEqual(out2, np_inp_t)
# self.assertArraysEqual(out3, np_inp * 2)
def test_in_layouts_out_layouts(self):
mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
shape = (8, 8)
np_inp = np.arange(math.prod(shape)).reshape(shape)
s = NamedSharding(mesh, P('x', 'y'))
arr = jax.device_put(np_inp, s)
def f(x):
return x.T
compiled = jax.jit(f).lower(arr, _in_layouts=layout.DefaultLayout(),
_out_layouts=layout.AUTO).compile()
self.assertTupleEqual(compiled._input_layouts()[0][0].minor_to_major, (1, 0))
self.assertTupleEqual(compiled._output_layouts().minor_to_major, (0, 1))
out = compiled(arr)
self.assertArraysEqual(out, np_inp.T)
self.assertEqual(out.sharding, NamedSharding(mesh, P('y', 'x')))
def test_sharding_and_layouts(self):
mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
shape = (4, 8)
np_inp = np.arange(math.prod(shape)).reshape(shape)
s = NamedSharding(mesh, P('x', 'y'))
compiled = jax.jit(lambda x: x.T, in_shardings=s, out_shardings=s).lower(
np_inp, _in_layouts=layout.AUTO, _out_layouts=layout.AUTO).compile()
out = compiled(np_inp)
self.assertTupleEqual(compiled._input_layouts()[0][0].minor_to_major, (1, 0))
self.assertTupleEqual(compiled._output_layouts().minor_to_major, (0, 1))
self.assertArraysEqual(out, np_inp.T)
self.assertEqual(out.sharding, s)
if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())