Mirror of https://github.com/ROCm/jax.git (synced 2025-04-15 19:36:06 +00:00)
Add a private API to allow setting layouts on jitted computations.
We expose 3 modes:

* `SpecifiedLayout`: the user specifies the `minor_to_major` field of the layout. Tiling is not exposed yet.
* `DefaultLayout`: PJRT chooses the layout. This is the current default behavior.
* `AUTO`: the compiler chooses the layout. This value is not a layout per se; it is a request to get the layout from the compiler. It cannot be set on an Array or other data types and can only be passed to jit.

Public API coming soon.

Co-authored-by: Roy Frostig <frostig@google.com>
PiperOrigin-RevId: 582692036
parent b032a0271e
commit 5c3da219c0
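For orientation, here is a minimal sketch of how the private API added in this change is exercised, adapted from tests/layout_test.py below. The `_in_layouts`/`_out_layouts` keyword arguments are private, are only accepted by `.lower()` for now, and the new tests only run them on TPU, so treat this as illustrative rather than a stable interface.

import numpy as np
import jax
from jax.experimental import layout

x = np.arange(8 * 128, dtype=np.float32).reshape(8, 128)

# AUTO: ask the compiler to choose input and output layouts.
compiled = jax.jit(lambda a: a.T).lower(
    x, _in_layouts=layout.AUTO, _out_layouts=layout.AUTO).compile()
arg_layouts, kw_layouts = compiled._input_layouts()  # SpecifiedLayouts chosen by XLA
out_layouts = compiled._output_layouts()

# SpecifiedLayout: pin the minor_to_major order of an output explicitly.
sl = layout.SpecifiedLayout((0, 1))
out = jax.jit(lambda a: a).lower(x, _out_layouts=sl).compile()(x)

# DefaultLayout: keep the PJRT-chosen layout (today's default behavior).
compiled_default = jax.jit(lambda a: a.T).lower(
    x, _in_layouts=layout.DefaultLayout(),
    _out_layouts=layout.DefaultLayout()).compile()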
jax/BUILD (12 changed lines)

@@ -228,6 +228,7 @@ py_library_providing_imports_info(
         ":effects",
         ":environment_info",
         ":jaxpr_util",
+        ":layout",
         ":lazy_loader",
         ":mesh",
         ":mlir",
@@ -478,6 +479,7 @@ pytype_strict_library(
         ":core",
         ":dtypes",
         ":effects",
+        ":layout",
         ":op_shardings",
         ":partial_eval",
         ":pickle_util",
@@ -633,6 +635,16 @@ pytype_strict_library(
     ],
 )
 
+pytype_strict_library(
+    name = "layout",
+    srcs = ["_src/layout.py"],
+    deps = [
+        ":util",
+        ":xla_bridge",
+        "//jax/_src/lib",
+    ],
+)
+
 pytype_strict_library(
     name = "sharding_impls",
     srcs = ["_src/sharding_impls.py"],
@@ -304,13 +304,18 @@ def jit(
       static_argnums, static_argnames, device, backend, abstracted_axes)
 
   def infer_params(*args, **kwargs):
+    # TODO(yashkatariya): Remove this when it's added on jit. Also default to
+    # layout.DefaultLayout() when out of experimental.
+    in_layouts = kwargs.pop('_in_layouts', None)
+    out_layouts = kwargs.pop('_out_layouts', None)
     pjit_info_args = pjit.PjitInfo(
         fun=fun, in_shardings=in_shardings,
         out_shardings=out_shardings, static_argnums=static_argnums,
         static_argnames=static_argnames, donate_argnums=donate_argnums,
         donate_argnames=donate_argnames, device=device, backend=backend,
         keep_unused=keep_unused, inline=inline, resource_env=None,
-        abstracted_axes=abstracted_axes)
+        abstracted_axes=abstracted_axes, in_layouts=in_layouts,
+        out_layouts=out_layouts)
     return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
 
   has_explicit_sharding = pjit._pjit_explicit_sharding(
@@ -180,8 +180,8 @@ def sharded_lowering(
   return pxla.lower_sharding_computation(
       fun, 'jit', name, in_shardings_unspec, UNSPECIFIED, donated_invars,
       in_avals, keep_unused=keep_unused, inline=inline,
-      devices_from_context=None,
-      lowering_parameters=lowering_parameters)
+      devices_from_context=None, lowering_parameters=lowering_parameters,
+      in_layouts=(None,) * len(in_avals), out_layouts=None)
 
 
 def simple_impl(prim):
@@ -41,6 +41,7 @@ from jax._src import util
 from jax._src import xla_bridge as xb
 from jax._src.interpreters import partial_eval as pe
 from jax._src.interpreters import xla
+from jax._src.layout import XLACompatibleLayout, LayoutRequest
 from jax._src.lib import xla_client as xc
 from jax._src.lib import xla_extension
 from jax._src.lib import xla_extension_version
@@ -691,6 +692,15 @@ def _to_logical_op_sharding(
   assert isinstance(aval, (core.ShapedArray, core.DShapedArray))
   return sharding._to_xla_hlo_sharding(aval.ndim)
 
 
+def _to_xla_layout(layout: XLACompatibleLayout | None | LayoutRequest) -> str | None:
+  if layout is None:
+    return None
+  if isinstance(layout, LayoutRequest):
+    return "auto"
+  return layout._to_xla_layout()
+
+
 def _get_mem_kind(s: Optional[XLACompatibleSharding]) -> Optional[str]:
   if s is None:
     return None
@@ -711,6 +721,8 @@ def lower_jaxpr_to_module(
     replicated_args: Sequence[bool] | None = None,
     arg_shardings: Sequence[XLACompatibleSharding | None] | None = None,
     result_shardings: Sequence[XLACompatibleSharding | None] | None = None,
+    in_layouts: Sequence[XLACompatibleLayout | None | LayoutRequest] | None = None,
+    out_layouts: Sequence[XLACompatibleLayout | None | LayoutRequest] | None = None,
     arg_names: Sequence[str | None] | None = None,
     result_names: Sequence[str | None] | None = None,
     num_replicas: int = 1,
@@ -784,6 +796,11 @@ def lower_jaxpr_to_module(
       map(_to_logical_op_sharding, jaxpr.out_avals, result_shardings)
       if result_shardings is not None else result_shardings)
 
+  arg_layouts = (map(_to_xla_layout, in_layouts) if in_layouts is not None
+                 else in_layouts)
+  result_layouts = (map(_to_xla_layout, out_layouts) if out_layouts is not None
+                    else out_layouts)
+
   ctx = ModuleContext(backend_or_name=backend_or_name,
                       platforms=platforms, axis_context=axis_context,
                       name_stack=name_stack,
@@ -815,7 +832,9 @@ def lower_jaxpr_to_module(
         arg_names=arg_names,
         result_names=result_names,
         arg_memory_kinds=arg_memory_kinds,
-        result_memory_kinds=result_memory_kinds)
+        result_memory_kinds=result_memory_kinds,
+        arg_layouts=arg_layouts,
+        result_layouts=result_layouts)
 
   try:
     if not ctx.module.operation.verify():
@@ -969,6 +988,8 @@ def lower_jaxpr_to_fun(
     result_names: Sequence[str | None] | None = None,
     arg_memory_kinds: Sequence[str | None] | None = None,
     result_memory_kinds: Sequence[str | None] | None = None,
+    arg_layouts: Sequence[str | None] | None = None,
+    result_layouts: Sequence[str | None] | None = None,
 ) -> func_dialect.FuncOp:
   """Lowers jaxpr and its callees to an IR function.
 
@@ -1055,6 +1076,12 @@ def lower_jaxpr_to_fun(
   if result_memory_kinds is not None:
     token_memory_kinds = [None] * (num_tokens + num_output_tokens)
     result_memory_kinds = [*token_memory_kinds, *result_memory_kinds]
+  if arg_layouts is not None:
+    token_layouts = [None] * (num_dim_vars + num_tokens)
+    arg_layouts = [*token_layouts, *arg_layouts]
+  if result_layouts is not None:
+    token_layouts = [None] * (num_tokens + num_output_tokens)
+    result_layouts = [*token_layouts, *result_layouts]
 
   flat_input_types = util.flatten(input_types)
   flat_output_types = util.flatten(output_types)
@@ -1077,6 +1104,11 @@ def lower_jaxpr_to_fun(
     ir_arg_memory_kinds = util.flatten(
         [[mk] * len(types) for mk, types in zip(arg_memory_kinds, input_types)])
 
+  ir_arg_layouts = None
+  if arg_layouts is not None:
+    ir_arg_layouts = util.flatten(
+        [[l] * len(types) for l, types in zip(arg_layouts, input_types)])
+
   ir_result_shardings = None
   if result_shardings is not None:
     out_avals = [None] * (num_tokens + num_output_tokens) + list(jaxpr.out_avals)
@@ -1090,9 +1122,15 @@ def lower_jaxpr_to_fun(
     ir_result_memory_kinds = util.flatten(
         [[mk] * len(types) for mk, types in zip(result_memory_kinds, output_types)])
 
+  ir_result_layouts = None
+  if result_layouts is not None:
+    ir_result_layouts = util.flatten(
+        [[l] * len(types) for l, types in zip(result_layouts, output_types)])
+
   if (
       replicated_args is not None
       or ir_arg_shardings is not None
+      or ir_arg_layouts is not None
      or input_output_aliases is not None
       or arg_names is not None
       or num_tokens > 0
@@ -1113,6 +1151,11 @@ def lower_jaxpr_to_fun(
         if sharding is not None:
           attrs["mhlo.sharding"] = get_sharding_attr(sharding)
 
+    if ir_arg_layouts is not None:
+      for attrs, layout in zip(arg_attrs, ir_arg_layouts):
+        if layout is not None:
+          attrs["mhlo.layout_mode"] = ir.StringAttr.get(layout)
+
     if input_output_aliases is not None:
       output_ids = util.unflatten(list(range(len(flat_output_types))),
                                   map(len, output_types))
@@ -1162,6 +1205,11 @@ def lower_jaxpr_to_fun(
         if sharding is not None:
           attrs['mhlo.sharding'] = get_sharding_attr(sharding)
 
+    if ir_result_layouts is not None:
+      for attrs, layout in zip(result_attrs, ir_result_layouts):
+        if layout is not None:
+          attrs['mhlo.layout_mode'] = ir.StringAttr.get(layout)
+
     func_op.result_attrs = ir.ArrayAttr.get(
         [ir.DictAttr.get(attrs) for attrs in result_attrs])
 
@@ -59,6 +59,7 @@ from jax._src.interpreters import batching
 from jax._src.interpreters import partial_eval as pe
 from jax._src.interpreters import mlir
 from jax._src.interpreters import xla
+from jax._src.layout import XLACompatibleLayout, SpecifiedLayout, LayoutRequest
 from jax._src.lib import xla_client as xc
 from jax._src.lib import xla_extension_version
 from jax._src.lib.mlir import ir
@@ -1768,7 +1769,7 @@ def _raise_warnings_or_errors_for_jit_of_pmap(
 @weakref_lru_cache
 def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
                             semantic_in_shardings, semantic_out_shardings,
-                            da_object,
+                            in_layouts, out_layouts, da_object,
                             donated_invars, name_stack, all_default_mem_kind,
                             lowering_parameters: mlir.LoweringParameters):
   jaxpr = closed_jaxpr.jaxpr
@@ -1842,6 +1843,8 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
       replicated_args=replicated_args,
       arg_shardings=in_mlir_shardings,
       result_shardings=out_mlir_shardings,
+      in_layouts=in_layouts,
+      out_layouts=out_layouts,
       arg_names=jaxpr.debug_info and jaxpr.debug_info.arg_names,
       result_names=jaxpr.debug_info and jaxpr.debug_info.result_paths,
       num_replicas=nreps,
@@ -1939,6 +1942,7 @@ def are_all_shardings_default_mem_kind(da_object, shardings):
       return False
   return True
 
+MaybeLayout = Sequence[Optional[Union[XLACompatibleLayout, LayoutRequest]]]
 
 @profiler.annotate_function
 def lower_sharding_computation(
@@ -1954,6 +1958,8 @@ def lower_sharding_computation(
     inline: bool,
     devices_from_context: Sequence[xc.Device] | None = None,
     lowering_parameters: mlir.LoweringParameters,
+    in_layouts: MaybeLayout,
+    out_layouts: Optional[MaybeLayout],
 ) -> MeshComputation:
   """Lowers a computation to XLA. It can take arbitrary shardings as input.
 
@@ -1973,12 +1979,16 @@ def lower_sharding_computation(
       donated_invars, auto_spmd_lowering)
   jaxpr = closed_jaxpr.jaxpr
   in_shardings = tuple(s for i, s in enumerate(in_shardings) if i in kept_var_idx)
+  in_layouts = tuple(l for i, l in enumerate(in_layouts) if i in kept_var_idx)
 
   if is_unspecified(out_shardings):
     out_shardings = (UNSPECIFIED,) * len(global_out_avals)
+  if out_layouts is None:
+    out_layouts = (None,) * len(global_out_avals)
   assert isinstance(out_shardings, tuple)
-  assert len(out_shardings) == len(global_out_avals), (
-      len(out_shardings), len(global_out_avals))
+  assert isinstance(out_layouts, tuple)
+  assert len(out_shardings) == len(out_layouts) == len(global_out_avals), (
+      len(out_shardings), len(out_layouts), len(global_out_avals))
 
   # Device assignment across all inputs, outputs and shardings inside jaxpr
   # should be the same.
@@ -2030,7 +2040,7 @@ def lower_sharding_computation(
   (module, keepalive, host_callbacks, unordered_effects, ordered_effects,
    nreps, tuple_args, shape_poly_state) = _cached_lowering_to_hlo(
        closed_jaxpr, api_name, fun_name, backend, semantic_in_shardings,
-       semantic_out_shardings, da_object,
+       semantic_out_shardings, in_layouts, out_layouts, da_object,
        donated_invars, name_stack, all_default_mem_kind,
        lowering_parameters=lowering_parameters)
 
@@ -2058,6 +2068,8 @@ def lower_sharding_computation(
       backend=backend,
       device_assignment=da_object,
       committed=committed,
+      in_layouts=in_layouts,
+      out_layouts=out_layouts,
       pmap_nreps=nreps,
       jaxpr_debug_info=closed_jaxpr.jaxpr.debug_info,
       shape_poly_state=shape_poly_state,
@@ -2233,6 +2245,8 @@ def lower_mesh_computation(
       backend=backend,
       device_assignment=_create_da_object(tuple(mesh.devices.flat)),
       committed=True,
+      in_layouts=(None,) * len(global_in_avals),
+      out_layouts=(None,) * len(global_out_avals),
       jaxpr_debug_info=closed_jaxpr.jaxpr.debug_info,
       shape_poly_state=lowering_result.shape_poly_state)
 
@@ -2447,6 +2461,42 @@ def maybe_get_orig_out_sharding(
   return out_shardings, are_out_shardings_from_xla
 
 
+def _get_layouts_from_executable(
+    xla_executable, in_layouts, out_layouts
+) -> Sequence[Sequence[XLACompatibleLayout | None], Sequence[XLACompatibleLayout | None]]:  # type: ignore
+  if all(i is None for i in in_layouts) and all(o is None for o in out_layouts):
+    return in_layouts, out_layouts  # type: ignore
+
+  in_layouts_xla = xla_executable.get_parameter_layouts()
+  out_layouts_xla = xla_executable.get_output_layouts()
+
+  new_in_layouts = []
+  for x, i in safe_zip(in_layouts_xla, in_layouts):
+    x = SpecifiedLayout._from_xla_layout(x)
+    if isinstance(i, SpecifiedLayout):
+      if i != x:
+        raise AssertionError(
+            f"Unexpected XLA layout override: (XLA) {x} != {i} (User sharding)")
+      new_in_layouts.append(i)
+    else:
+      new_in_layouts.append(x)
+
+  new_out_layouts = []
+  for x, o in safe_zip(out_layouts_xla, out_layouts):
+    x = SpecifiedLayout._from_xla_layout(x)
+    if isinstance(o, SpecifiedLayout):
+      if o != x:
+        raise AssertionError(
+            f"Unexpected XLA layout override: (XLA) {x} != {o} (User sharding)")
+      new_out_layouts.append(o)
+    else:
+      new_out_layouts.append(x)
+
+  assert all(isinstance(i, SpecifiedLayout) for i in new_in_layouts)
+  assert all(isinstance(o, SpecifiedLayout) for o in new_out_layouts)
+  return new_in_layouts, new_out_layouts
+
+
 @weakref_lru_cache
 def _cached_compilation(computation, name, mesh, spmd_lowering,
                         tuple_args, auto_spmd_lowering,
@@ -2534,6 +2584,8 @@ class UnloadedMeshExecutable:
   kept_var_idx: set[int]
   auto_spmd_lowering: bool
   jaxpr_debug_info: core.JaxprDebugInfo | None
+  in_layouts: Sequence[SpecifiedLayout | None]
+  out_layouts: Sequence[SpecifiedLayout | None]
 
   def build_unsafe_call(self):
     input_indices = _get_input_indices(self.input_avals, self.input_shardings,
@@ -2555,6 +2607,7 @@ class UnloadedMeshExecutable:
         self.input_avals,
         self.input_shardings, self.output_shardings,
         self.auto_spmd_lowering, self.kept_var_idx,
+        self.in_layouts, self.out_layouts,
         self.jaxpr_debug_info, self)
 
   # May return a MeshExecutable in the compile_replicated case.
@@ -2577,6 +2630,8 @@ class UnloadedMeshExecutable:
       backend: xb.XlaBackend,
      device_assignment: _DeviceAssignment | Sequence[xc.Device],  # type: ignore
       committed: bool,
+      in_layouts: MaybeLayout,
+      out_layouts: MaybeLayout,
       pmap_nreps: int = 1,
       jaxpr_debug_info: core.JaxprDebugInfo | None = None,
       shape_poly_state: mlir.ShapePolyLoweringState | None = None,
@@ -2659,6 +2714,9 @@ class UnloadedMeshExecutable:
     else:
       are_out_shardings_from_xla = (False,) * len(global_out_avals)
 
+    in_layouts, out_layouts = _get_layouts_from_executable(
+        xla_executable, in_layouts, out_layouts)
+
     if pmap_nreps > 1:
       in_shardings, out_shardings, committed, da = _get_metadata_jit_pmap(
           xla_executable.local_devices(), len(in_shardings), len(out_shardings))
@@ -2684,7 +2742,9 @@ class UnloadedMeshExecutable:
         host_callbacks=host_callbacks,
         kept_var_idx=kept_var_idx,
         auto_spmd_lowering=auto_spmd_lowering,
-        jaxpr_debug_info=jaxpr_debug_info).load()
+        jaxpr_debug_info=jaxpr_debug_info,
+        in_layouts=in_layouts,  # type: ignore
+        out_layouts=out_layouts).load()  # type: ignore
 
 
 class MeshExecutableFastpathData(NamedTuple):
@@ -2709,12 +2769,13 @@ class MeshExecutable(stages.XlaExecutable):
   __slots__ = [
       "xla_executable", "_unsafe_call", "build_unsafe_call", "in_avals",
       "_in_shardings", "_out_shardings", "_auto_spmd_lowering", "_kept_var_idx",
-      "_jaxpr_debug_info", "_unloaded_executable",
+      "_in_layouts", "_out_layouts", "_jaxpr_debug_info", "_unloaded_executable",
   ]
 
   def __init__(self, xla_executable, build_unsafe_call, in_avals, in_shardings,
                out_shardings, auto_spmd_lowering, kept_var_idx,
-               jaxpr_debug_info=None, unloaded_executable=None):
+               in_layouts, out_layouts, jaxpr_debug_info=None,
+               unloaded_executable=None):
     self.xla_executable = xla_executable
     self.build_unsafe_call = build_unsafe_call
     # in_avals is a list of global and local avals. Aval is global if input
@@ -2725,6 +2786,8 @@ class MeshExecutable(stages.XlaExecutable):
     self._out_shardings = out_shardings
     self._auto_spmd_lowering = auto_spmd_lowering
     self._kept_var_idx = kept_var_idx
+    self._in_layouts = in_layouts
+    self._out_layouts = out_layouts
     self._jaxpr_debug_info = jaxpr_debug_info
     self._unloaded_executable = unloaded_executable
 
@@ -2755,6 +2818,12 @@ class MeshExecutable(stages.XlaExecutable):
   def output_shardings(self) -> Sequence[sharding_impls.XLACompatibleSharding]:
     return self._out_shardings
 
+  def input_layouts(self):
+    return self._in_layouts
+
+  def output_layouts(self):
+    return self._out_layouts
+
   def create_cpp_call(self, no_kwargs, in_tree, out_tree):
     if not (isinstance(self.unsafe_call, ExecuteReplicated) and
             not self.unsafe_call.has_unordered_effects and
@@ -2886,7 +2955,8 @@ def _compile_replicated_mesh_executable_from_hlo(
     xla_executable = None
   return MeshExecutable(xla_executable, lambda: unsafe_call, global_in_avals,
                         in_shardings, out_shardings, auto_spmd_lowering,
-                        kept_var_idx, jaxpr_debug_info, None)
+                        kept_var_idx, (None,) * len(global_in_avals),
+                        (None,) * len(global_out_avals), jaxpr_debug_info, None)
 
 
 @lru_cache
jax/_src/layout.py (new file, 78 lines)

# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from jax._src.lib import xla_client as xc


class Layout:
  pass


class XLACompatibleLayout(Layout):

  @classmethod
  def _from_xla_layout(cls, xla_layout) -> XLACompatibleLayout:
    raise NotImplementedError("Subclasses should implement this method.")

  def _to_xla_layout(self) -> str:
    raise NotImplementedError("Subclasses should implement this method.")


class SpecifiedLayout(XLACompatibleLayout):
  minor_to_major: tuple[int, ...]

  def __init__(self, minor_to_major: tuple[int, ...]):
    self.minor_to_major = minor_to_major

  def __repr__(self):
    return f'SpecifiedLayout(minor_to_major={self.minor_to_major})'

  def __hash__(self):
    return hash(self.minor_to_major)

  def __eq__(self, other):
    if not isinstance(other, SpecifiedLayout):
      return False
    return self.minor_to_major == other.minor_to_major

  @classmethod
  def _from_xla_layout(cls, xla_layout: xc.Layout) -> XLACompatibleLayout:
    return cls(xla_layout.minor_to_major())

  def _to_xla_layout(self) -> str:
    return xc.Layout(self.minor_to_major).to_string()


class DefaultLayout(XLACompatibleLayout):

  def __repr__(self):
    return 'DefaultLayout()'

  def __hash__(self):
    return hash(type(self))

  def __eq__(self, other):
    return isinstance(other, DefaultLayout) and type(self) == type(other)

  def _to_xla_layout(self) -> str:
    return "default"


class LayoutRequest:

  def __repr__(self):
    return "Request a layout from the compiler"


AUTO = LayoutRequest()
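As a quick illustration of how these classes are communicated to XLA during lowering: the `_to_xla_layout` helper added in the mlir.py hunk above turns each mode into a layout string that is attached as an `mhlo.layout_mode` attribute. The sketch below mirrors that helper; `to_xla_layout_str` is a hypothetical name used only for illustration.

from jax._src.layout import SpecifiedLayout, DefaultLayout, LayoutRequest, AUTO

def to_xla_layout_str(l):
  # None means no layout constraint is recorded for this argument/result.
  if l is None:
    return None
  # AUTO is a LayoutRequest: ask the compiler to pick the layout.
  if isinstance(l, LayoutRequest):
    return "auto"
  # SpecifiedLayout serializes its minor_to_major (e.g. "{0,2,1}");
  # DefaultLayout returns the literal string "default".
  return l._to_xla_layout()

assert to_xla_layout_str(AUTO) == "auto"
assert to_xla_layout_str(DefaultLayout()) == "default"
print(to_xla_layout_str(SpecifiedLayout((0, 2, 1))))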
@@ -159,7 +159,7 @@ def _device_assignment_mismatch_error(fun_name, fails, args_flat, api_name,
 
 
 def _python_pjit_helper(fun, infer_params_fn, *args, **kwargs):
-  args_flat, _, params, in_tree, out_tree, _ = infer_params_fn(
+  args_flat, _, params, in_tree, out_tree, _, _, _ = infer_params_fn(
       *args, **kwargs)
   for arg in args_flat:
     dispatch.check_arg(arg)
@@ -328,8 +328,13 @@ def post_infer_params(fun, infer_params_fn, static_argnums, static_argnames,
   def lower(*args, **kwargs):
     lowering_parameters = kwargs.pop(
         '_experimental_lowering_parameters', mlir.LoweringParameters())
+    # TODO(yashkatariya): Remove this when it's added on jit. Also default to
+    # layout.DefaultLayout() when out of experimental.
+    in_layouts = kwargs.pop('_in_layouts', None)
+    out_layouts = kwargs.pop('_out_layouts', None)
     (args_flat, flat_global_in_avals, params, in_tree, out_tree,
-     donated_invars) = infer_params_fn(*args, **kwargs)
+     donated_invars, in_layouts_flat, out_layouts_flat) = infer_params_fn(
+         *args, **kwargs, _in_layouts=in_layouts, _out_layouts=out_layouts)
     resource_env = params['resource_env']
     mesh = None if resource_env is None else resource_env.physical_mesh
     try:
@@ -338,8 +343,8 @@ def post_infer_params(fun, infer_params_fn, static_argnums, static_argnames,
       lowering = _pjit_lower(
           params['jaxpr'], in_shardings, params['out_shardings'],
           params['resource_env'], params['donated_invars'], params['name'],
-          params['keep_unused'], params['inline'],
-          lowering_parameters=lowering_parameters)
+          params['keep_unused'], params['inline'], in_layouts=in_layouts_flat,
+          out_layouts=out_layouts_flat, lowering_parameters=lowering_parameters)
     except pxla.DeviceAssignmentMismatchError as e:
       fails, = e.args
       api_name = 'jit' if params['resource_env'] is None else 'pjit'
@@ -387,12 +392,14 @@ class PjitInfo(NamedTuple):
   inline: bool
   resource_env: Any
   abstracted_axes: Optional[Any]
+  in_layouts: Any  # pytree[XlaCompatibleLayout] | None
+  out_layouts: Any  # pytree[XlaCompatibleLayout] | None
 
 
 def common_infer_params(pjit_info_args, *args, **kwargs):
   (fun, user_in_shardings, user_out_shardings, static_argnums, static_argnames,
    donate_argnums, donate_argnames, device, backend, keep_unused, inline,
-   resource_env, abstracted_axes) = pjit_info_args
+   resource_env, abstracted_axes, in_layouts, out_layouts) = pjit_info_args
 
   if (kwargs and user_in_shardings is not None and
       not is_unspecified(user_in_shardings)):
@@ -479,16 +486,16 @@ def common_infer_params(pjit_info_args, *args, **kwargs):
       ) from e
   in_type = in_avals = tuple(avals)
 
-  canonicalized_in_shardings_flat = _process_in_axis_resources(
-      hashable_pytree(in_shardings), in_avals, in_tree, resource_env, dbg,
-      device_or_backend_set)
+  canonicalized_in_shardings_flat, in_layouts_flat = _process_in_axis_resources(
+      hashable_pytree(in_shardings), hashable_pytree(in_layouts), in_avals,
+      in_tree, resource_env, dbg, device_or_backend_set)
 
-  jaxpr, consts, canonicalized_out_shardings_flat = _pjit_jaxpr(
-      flat_fun, hashable_pytree(out_shardings), in_type, dbg,
-      device_or_backend_set, HashableFunction(out_tree, closure=()),
+  jaxpr, consts, canonicalized_out_shardings_flat, out_layouts_flat = _pjit_jaxpr(
+      flat_fun, hashable_pytree(out_shardings), hashable_pytree(out_layouts),
+      in_type, dbg, device_or_backend_set, HashableFunction(out_tree, closure=()),
       HashableFunction(res_paths, closure=()))
 
-  assert len(explicit_args) == len(canonicalized_in_shardings_flat)
+  assert len(explicit_args) == len(canonicalized_in_shardings_flat) == len(in_layouts_flat)
 
   if config.dynamic_shapes.value:
     implicit_args = _extract_implicit_args(in_type, explicit_args)
@@ -499,9 +506,10 @@ def common_infer_params(pjit_info_args, *args, **kwargs):
     num_extra_args = len(implicit_args) + len(consts)
     canonicalized_in_shardings_flat = \
         (UNSPECIFIED,) * num_extra_args + canonicalized_in_shardings_flat
+    in_layouts_flat = (None,) * num_extra_args + in_layouts_flat
     donated_invars = (False,) * num_extra_args + donated_invars
-    assert (len(canonicalized_in_shardings_flat) == len(donated_invars) ==
-            len(consts) + len(args_flat))
+    assert (len(canonicalized_in_shardings_flat) == len(in_layouts_flat) ==
+            len(donated_invars) == len(consts) + len(args_flat))
 
   # in_shardings and out_shardings here are all GSPMDSharding.
   params = dict(
@@ -515,7 +523,7 @@ def common_infer_params(pjit_info_args, *args, **kwargs):
       inline=inline,
   )
   return (consts + args_flat, in_type, params, in_tree, out_tree(),
-          donated_invars)
+          donated_invars, in_layouts_flat, out_layouts_flat)
 
 def _extract_implicit_args(
     in_type: Sequence[tuple[core.AbstractValue, bool]],
@@ -758,13 +766,18 @@ def pjit(
   def infer_params(*args, **kwargs):
     # Putting this outside of wrapped would make resources lexically scoped
     resource_env = mesh_lib.thread_resources.env
+    # TODO(yashkatariya): Remove this when it's added on jit. Also default to
+    # layout.DefaultLayout() when out of experimental.
+    in_layouts = kwargs.pop('_in_layouts', None)
+    out_layouts = kwargs.pop('_out_layouts', None)
     pjit_info_args = PjitInfo(
         fun=fun, in_shardings=in_shardings,
         out_shardings=out_shardings, static_argnums=static_argnums,
         static_argnames=static_argnames, donate_argnums=donate_argnums,
         donate_argnames=donate_argnames, device=device, backend=backend,
         keep_unused=keep_unused, inline=inline, resource_env=resource_env,
-        abstracted_axes=abstracted_axes)
+        abstracted_axes=abstracted_axes, in_layouts=in_layouts,
+        out_layouts=out_layouts)
     return common_infer_params(pjit_info_args, *args, **kwargs)
 
   has_explicit_sharding = _pjit_explicit_sharding(
@@ -880,8 +893,9 @@ class PytreeLeaf:
 
 
 @lru_cache(maxsize=4096)
-def _process_in_axis_resources(in_shardings_thunk, in_avals, in_tree,
-                               resource_env, debug_info, device_or_backend_set):
+def _process_in_axis_resources(in_shardings_thunk, in_layouts_thunk, in_avals,
+                               in_tree, resource_env, debug_info,
+                               device_or_backend_set):
   orig_in_shardings = in_shardings_thunk()
   # Only do this if original in_shardings are unspecified. If it is AUTO, go
   # via flatten_axis_resources.
@@ -889,8 +903,14 @@ def _process_in_axis_resources(in_shardings_thunk, in_avals, in_tree,
     in_shardings_flat = (orig_in_shardings,) * len(in_avals)
   else:
     in_shardings_flat = flatten_axis_resources(
-        "pjit in_shardings", in_tree, orig_in_shardings,
-        tupled_args=True)
+        "pjit in_shardings", in_tree, orig_in_shardings, tupled_args=True)
+
+  in_layouts = in_layouts_thunk()
+  if in_layouts is None:
+    in_layouts_flat = (in_layouts,) * len(in_avals)
+  else:
+    in_layouts_flat = flatten_axis_resources(
+        "pjit in_layouts", in_tree, in_layouts, tupled_args=True)
 
   if not config.dynamic_shapes.value:
     pjit_check_aval_sharding(in_shardings_flat, in_avals,
@@ -900,7 +920,7 @@ def _process_in_axis_resources(in_shardings_thunk, in_avals, in_tree,
       i if is_unspecified_or_auto(i) else
       to_gspmd_sharding(i, aval.ndim, device_or_backend_set)
       for i, aval in zip(in_shardings_flat, in_avals))
-  return canonicalized_shardings
+  return canonicalized_shardings, tuple(in_layouts_flat)
 
 
 @lu.cache
@@ -930,7 +950,8 @@ def _create_pjit_jaxpr(fun, in_type, debug_info, out_paths):
 
 @lru_cache(maxsize=4096)
 def _check_and_canonicalize_out_shardings(
-    out_shardings_thunk, out_tree, out_type, debug_info, device_or_backend_set):
+    out_shardings_thunk, out_layouts_thunk, out_tree, out_type, debug_info,
+    device_or_backend_set):
   orig_out_shardings = out_shardings_thunk()
   # TODO(yashkatariya): Remove the if branch and fix flatten_axis_resources
   # instead. This condition exists because flatten_axis_resources passes in an
@@ -944,6 +965,13 @@ def _check_and_canonicalize_out_shardings(
         "pjit out_shardings", out_tree(), orig_out_shardings,
         tupled_args=False)
 
+  out_layouts = out_layouts_thunk()
+  if out_layouts is None:
+    out_layouts_flat = (out_layouts,) * len(out_type)
+  else:
+    out_layouts_flat = flatten_axis_resources(
+        "pjit out_layouts", out_tree(), out_layouts, tupled_args=False)
+
   if not config.dynamic_shapes.value:
     pjit_check_aval_sharding(
         out_shardings_flat, out_type,
@@ -955,18 +983,18 @@ def _check_and_canonicalize_out_shardings(
           to_gspmd_sharding(o, aval.ndim, device_or_backend_set)
           for o, aval in zip(out_shardings_flat, out_type)
       )
-  return canonicalized_out_shardings_flat
+  return canonicalized_out_shardings_flat, tuple(out_layouts_flat)
 
 
-def _pjit_jaxpr(fun, out_shardings_thunk, in_type, debug_info,
+def _pjit_jaxpr(fun, out_shardings_thunk, out_layouts_thunk, in_type, debug_info,
                 device_or_backend_set, out_tree, result_paths):
   jaxpr, final_consts, out_type = _create_pjit_jaxpr(
       fun, in_type, debug_info, result_paths)
-  canonicalized_out_shardings_flat = _check_and_canonicalize_out_shardings(
-      out_shardings_thunk, out_tree, tuple(out_type), jaxpr.jaxpr.debug_info,
-      device_or_backend_set)
+  canonicalized_out_shardings_flat, out_layouts_flat = _check_and_canonicalize_out_shardings(
+      out_shardings_thunk, out_layouts_thunk, out_tree, tuple(out_type),
+      jaxpr.jaxpr.debug_info, device_or_backend_set)
   # lu.cache needs to be able to create weakrefs to outputs, so we can't return a plain tuple
-  return jaxpr, final_consts, canonicalized_out_shardings_flat
+  return jaxpr, final_consts, canonicalized_out_shardings_flat, out_layouts_flat
 
 
 def pjit_check_aval_sharding(
@@ -1271,11 +1299,20 @@ def _pjit_lower_cached(
     keep_unused: bool,
     inline: bool,
     *,
-    lowering_parameters: mlir.LoweringParameters):
+    lowering_parameters: mlir.LoweringParameters,
+    in_layouts: Optional[pxla.MaybeLayout] = None,
+    out_layouts: Optional[pxla.MaybeLayout] = None):
   in_shardings: tuple[PjitShardingMinusUnspecified, ...] = cast(
       tuple[PjitShardingMinusUnspecified, ...], sdat_in_shardings.shardings)
   out_shardings: tuple[PjitSharding, ...] = sdat_out_shardings.shardings
 
+  # TODO(yashkatariya): Remove this when layouts are supported on jit and
+  # passed to params.
+  if in_layouts is None:
+    in_layouts = (None,) * len(in_shardings)
+  if out_layouts is None:
+    out_layouts = (None,) * len(out_shardings)
+
   if resource_env is not None:
     pxla.resource_typecheck(jaxpr, resource_env, {}, lambda: "pjit")
 
@@ -1302,8 +1339,8 @@ def _pjit_lower_cached(
       keep_unused=keep_unused, inline=inline,
       devices_from_context=(
          None if mesh is None or mesh.empty else list(mesh.devices.flat)),
-      lowering_parameters=lowering_parameters,
-  )
+      lowering_parameters=lowering_parameters, in_layouts=in_layouts,
+      out_layouts=out_layouts)
 
 
 def pjit_staging_rule(trace, *args, **params):
@@ -42,6 +42,7 @@ from jax._src import source_info_util
 from jax._src import traceback_util
 from jax._src import tree_util
 from jax._src import util
+from jax._src.layout import SpecifiedLayout
 from jax._src.interpreters import mlir
 from jax._src.lib.mlir import ir
 from jax._src.lib import xla_client as xc
@@ -84,6 +85,14 @@ class Executable(Protocol):
     """
     raise NotImplementedError
 
+  # Layouts are exposed via jax.experimental.layouts
+  # TODO(frostig,yashkatariya): expose here when no longer experimental.
+  def _input_layouts(self):
+    raise NotImplementedError
+
+  def _output_layouts(self):
+    raise NotImplementedError
+
   def as_text(self) -> str:
     """A human-readable text representation of this executable.
 
@@ -216,6 +225,14 @@ class XlaExecutable(Executable):
     raise NotImplementedError(
         "compiled executable carries no output sharding information")
 
+  def _input_layouts(self):
+    raise NotImplementedError(
+        "compiled executable carries no input layout information")
+
+  def _output_layouts(self):
+    raise NotImplementedError(
+        "compiled executable carries no input layout information")
+
   def as_text(self) -> str:
     xla_ext_exe = self.xla_extension_executable()
     err_msg = ("text view unsupported on current XLA backend: "
@@ -481,6 +498,16 @@ class Compiled(Stage):
     shardings_flat = self._executable.output_shardings()
     return tree_util.tree_unflatten(self.out_tree, shardings_flat)  # pytype: disable=attribute-error
 
+  def _input_layouts(self):
+    layouts_flat = self._executable.input_layouts()
+    assert all(isinstance(l, SpecifiedLayout) for l in layouts_flat)
+    return tree_util.tree_unflatten(self.in_tree, layouts_flat)  # pytype: disable=attribute-error
+
+  def _output_layouts(self):
+    layouts_flat = self._executable.output_layouts()
+    assert all(isinstance(l, SpecifiedLayout) for l in layouts_flat)
+    return tree_util.tree_unflatten(self.out_tree, layouts_flat)  # pytype: disable=attribute-error
+
   @staticmethod
   def call(*args, **kwargs):
     # This is because `__call__` passes in `self._params` as the first argument.
@@ -764,7 +764,7 @@ def _check_lowering(lowering) -> None:
       "tuple_args", "ordered_effects", "unordered_effects",
       "keepalive", "host_callbacks", "pmap_nreps", "committed",
       "device_assignment", "jaxpr_debug_info", "shape_poly_state",
-      "all_default_mem_kind"]
+      "all_default_mem_kind", "in_layouts", "out_layouts"]
   for compile_arg in lowering.compile_args.keys():
     if compile_arg not in allowed_compile_args:
       raise NotImplementedError(f"Unrecognized lowered.compile_args[{compile_arg}]")
jax/experimental/layout.py (new file, 19 lines)

# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from jax._src.layout import (
    DefaultLayout as DefaultLayout,
    SpecifiedLayout as SpecifiedLayout,
    AUTO as AUTO,
)
@@ -222,6 +222,12 @@ jax_test(
     ],
 )
 
+jax_test(
+    name = "layout_test",
+    srcs = ["layout_test.py"],
+    tags = ["multiaccelerator"],
+)
+
 jax_test(
     name = "pgle_test",
     srcs = ["pgle_test.py"],
tests/layout_test.py (new file, 226 lines)

# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import os
from absl.testing import absltest
import numpy as np

import jax
from jax.sharding import NamedSharding, PartitionSpec as P
from jax._src import config
from jax._src import layout
from jax._src import test_util as jtu
from jax._src import xla_bridge
from jax._src.lib import xla_extension_version

config.parse_flags_with_absl()

prev_xla_flags = None


def setUpModule():
  global prev_xla_flags
  prev_xla_flags = os.getenv("XLA_FLAGS")
  flags_str = prev_xla_flags or ""
  # Don't override user-specified device count, or other XLA flags.
  if "xla_force_host_platform_device_count" not in flags_str:
    os.environ["XLA_FLAGS"] = (flags_str +
                               " --xla_force_host_platform_device_count=8")
  # Clear any cached backends so new CPU backend will pick up the env var.
  xla_bridge.get_backend.cache_clear()


def tearDownModule():
  if prev_xla_flags is None:
    del os.environ["XLA_FLAGS"]
  else:
    os.environ["XLA_FLAGS"] = prev_xla_flags
  xla_bridge.get_backend.cache_clear()


class LayoutTest(jtu.JaxTestCase):

  def setUp(self):
    if not jtu.test_device_matches(['tpu']):
      self.skipTest("Layouts do not work on CPU and GPU backends yet.")
    if xla_extension_version < 215:
      self.skipTest('All tests require xla_extension_version >= 215')
    super().setUp()

  def test_auto_layout(self):
    mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
    shape1 = (128, 128)
    shape2 = (128, 128)

    def apply(x, y):
      return x.T, y.T

    def init(x, y):
      return x * 2, y * 2

    np_inp1 = np.arange(math.prod(shape1)).reshape(shape1)
    arr1 = jax.device_put(np_inp1, NamedSharding(mesh, P('x', 'y')))
    np_inp2 = np.arange(math.prod(shape2)).reshape(shape2)
    arr2 = jax.device_put(np_inp2, NamedSharding(mesh, P('x')))

    lowered_apply = jax.jit(apply).lower(arr1, arr2, _in_layouts=layout.AUTO,
                                         _out_layouts=layout.AUTO)
    compiled_apply = lowered_apply.compile()

    arg_layouts, kw_layouts = compiled_apply._input_layouts()
    self.assertEmpty(kw_layouts)
    for i, o in zip(arg_layouts, compiled_apply._output_layouts()):
      self.assertEqual(i.minor_to_major, o.minor_to_major[::-1])

    init_compiled = jax.jit(init).lower(
        arr1, arr2, _out_layouts=arg_layouts).compile()

    for i, o in zip(init_compiled._input_layouts()[0],
                    init_compiled._output_layouts()):
      self.assertEqual(i.minor_to_major, o.minor_to_major)

    with jtu.count_aot_jit_cpp_cache_miss() as init_count:
      init_out = init_compiled(arr1, arr2)
      init_compiled(arr1, arr2)
    self.assertEqual(init_count[0], 1)

    with jtu.count_aot_jit_cpp_cache_miss() as apply_count:
      apply_out = compiled_apply(*init_out)
      compiled_apply(*init_out)
    self.assertEqual(apply_count[0], 1)

    self.assertArraysEqual(init_out[0], np_inp1 * 2)
    self.assertArraysEqual(init_out[1], np_inp2 * 2)
    self.assertArraysEqual(apply_out[0], (np_inp1 * 2).T)
    self.assertArraysEqual(apply_out[1], (np_inp2 * 2).T)

  def test_specified_layouts_on_jit(self):
    mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
    shape = (8, 4, 2)
    np_inp = np.arange(math.prod(shape)).reshape(shape)
    s = NamedSharding(mesh, P('x', 'y'))
    arr = jax.device_put(np_inp, s)

    def f(x):
      return x.T

    sl = layout.SpecifiedLayout((0, 2, 1))

    out1 = jax.jit(lambda x: x).lower(arr, _out_layouts=sl).compile()(arr)

    compiled = jax.jit(f).lower(out1, _in_layouts=sl, _out_layouts=sl).compile()
    out2 = compiled(out1)

    self.assertEqual(compiled._input_layouts()[0][0], sl)
    self.assertEqual(compiled._output_layouts(), sl)
    self.assertArraysEqual(out2, np_inp.T)
    self.assertEqual(out2.sharding, NamedSharding(mesh, P(None, 'y', 'x')))

  def test_default_layout(self):
    mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
    shape = (8, 4, 2)
    np_inp = np.arange(math.prod(shape)).reshape(shape)
    s = NamedSharding(mesh, P('x', 'y'))
    arr = jax.device_put(np_inp, s)

    def f(x):
      return x.T

    compiled = jax.jit(f).lower(
        arr,
        _in_layouts=layout.DefaultLayout(),
        _out_layouts=layout.DefaultLayout()).compile()
    out = compiled(arr)

    self.assertTupleEqual(compiled._input_layouts()[0][0].minor_to_major, (2, 1, 0))
    self.assertTupleEqual(compiled._output_layouts().minor_to_major, (2, 1, 0))
    self.assertArraysEqual(out, np_inp.T)
    self.assertEqual(out.sharding, NamedSharding(mesh, P(None, 'y', 'x')))

    compiled_auto = jax.jit(f).lower(arr, _in_layouts=layout.AUTO,
                                     _out_layouts=layout.AUTO).compile()
    self.assertTupleEqual(compiled_auto._input_layouts()[0][0].minor_to_major,
                          (2, 1, 0))
    self.assertTupleEqual(compiled_auto._output_layouts().minor_to_major,
                          (0, 1, 2))

  # TODO(yashkatariya): Enable after mixture of auto, default and specified
  # layouts work.
  # def test_auto_specified_default_layout_with_sharding(self):
  #   mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
  #   shape = (8, 4, 2)
  #   np_inp = np.arange(math.prod(shape)).reshape(shape)
  #   s = NamedSharding(mesh, P('x', 'y'))
  #   arr = jax.device_put(np_inp, s)

  #   def f(x, y, z):
  #     return x.T, y.T, z * 2

  #   lowered = jax.jit(f).lower(
  #       arr, arr, arr,
  #       _in_layouts=(
  #           layout.SpecifiedLayout((0, 2, 1)),
  #           layout.AUTO,
  #           layout.DefaultLayout()),
  #       _out_layouts=layout.AUTO)
  #   compiled = lowered.compile()

  #   il1, il2, il3 = compiled._input_layouts()[0]
  #   ol1, ol2, ol3 = compiled._output_layouts()

  #   self.assertTupleEqual(il1.minor_to_major, ol1.minor_to_major[::-1])
  #   self.assertTupleEqual(il2.minor_to_major, ol2.minor_to_major[::-1])
  #   self.assertEqual(il3, ol3)

  #   out1, out2, out3 = compiled(arr, arr, arr)
  #   np_inp_t = np_inp.T
  #   self.assertArraysEqual(out1, np_inp_t)
  #   self.assertArraysEqual(out2, np_inp_t)
  #   self.assertArraysEqual(out3, np_inp * 2)

  def test_in_layouts_out_layouts(self):
    mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
    shape = (8, 8)
    np_inp = np.arange(math.prod(shape)).reshape(shape)
    s = NamedSharding(mesh, P('x', 'y'))
    arr = jax.device_put(np_inp, s)

    def f(x):
      return x.T
    compiled = jax.jit(f).lower(arr, _in_layouts=layout.DefaultLayout(),
                                _out_layouts=layout.AUTO).compile()
    self.assertTupleEqual(compiled._input_layouts()[0][0].minor_to_major, (1, 0))
    self.assertTupleEqual(compiled._output_layouts().minor_to_major, (0, 1))

    out = compiled(arr)
    self.assertArraysEqual(out, np_inp.T)
    self.assertEqual(out.sharding, NamedSharding(mesh, P('y', 'x')))

  def test_sharding_and_layouts(self):
    mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
    shape = (4, 8)
    np_inp = np.arange(math.prod(shape)).reshape(shape)
    s = NamedSharding(mesh, P('x', 'y'))

    compiled = jax.jit(lambda x: x.T, in_shardings=s, out_shardings=s).lower(
        np_inp, _in_layouts=layout.AUTO, _out_layouts=layout.AUTO).compile()
    out = compiled(np_inp)
    self.assertTupleEqual(compiled._input_layouts()[0][0].minor_to_major, (1, 0))
    self.assertTupleEqual(compiled._output_layouts().minor_to_major, (0, 1))
    self.assertArraysEqual(out, np_inp.T)
    self.assertEqual(out.sharding, s)


if __name__ == '__main__':
  absltest.main(testLoader=jtu.JaxTestLoader())