Add a concrete layout API to JAX. The API takes `major_to_minor: tuple[int, ...]` and `tiling: tuple[tuple[int, ...], ...]` as arguments, and allows users to pass layouts to `with_sharding_constraint` to constrain the layout together with the sharding.
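
For example, here is a minimal sketch of constraining a concrete layout inside `jit` (it mirrors the new `test_wsc_concrete_layout` test below; it assumes a TPU-like backend with at least 4 devices and the `jax.experimental.layout` exports of this change):

  from functools import partial
  import numpy as np
  import jax
  from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
  from jax.experimental.layout import Layout, DeviceLocalLayout as DLL

  mesh = Mesh(np.asarray(jax.devices()[:4]).reshape(2, 2), ('x', 'y'))
  s = NamedSharding(mesh, P('x'))
  arr = jax.device_put(np.arange(128 * 128).reshape(128, 128), s)

  # Row-major layout with an explicit (8, 128) tile.
  custom_dll = DLL(major_to_minor=(0, 1), tiling=((8, 128),))

  @partial(jax.jit, out_shardings=Layout(DLL.AUTO))
  def f(x):
    # Constrain the transpose back to `custom_dll` + `s`; without the
    # constraint the compiler would be free to keep the transposed layout.
    return jax.lax.with_sharding_constraint(x.T, Layout(custom_dll, s))

  out = f(arr)
  assert out.layout == Layout(custom_dll, s)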

`sub_byte_element_size_in_bits` is lowering-only for now (since JAX knows the dtype of the aval, it can fill in the appropriate value itself). We can expose it in the user API if required.
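
A sketch of the dtype-driven choice, mirroring `_to_xla_layout` in this change (the helper name is illustrative):

  import numpy as np
  from jax._src.dtypes import iinfo, issubdtype

  def sub_byte_element_size_in_bits(dtype) -> int:
    # Sub-byte integer dtypes (e.g. int4) report their bit width; everything
    # else lowers with 0, i.e. "not a sub-byte type".
    if issubdtype(dtype, np.integer) and iinfo(dtype).bits < 8:
      return iinfo(dtype).bits
    return 0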

Memory space is exposed via the JAX memories API, so it doesn't have to be part of the layout API.
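
For context, the memories API already carries the memory space on the sharding, e.g. (a sketch; assumes a backend that supports the 'pinned_host' memory kind):

  import numpy as np
  import jax

  s = jax.sharding.SingleDeviceSharding(jax.devices()[0])
  host_s = s.with_memory_kind('pinned_host')     # memory space lives on the sharding
  x_host = jax.device_put(np.arange(8), host_s)  # ...so the layout API can ignore it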

Also expose `_xla_layout` as a private API on `PjRtLayout` so that JAX can access its fields when creating JAX layouts.
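
A sketch of the resulting round trip (see `DeviceLocalLayout.from_pjrt_layout` below; `arr` stands for any committed `jax.Array`):

  pjrt_layout = arr._pjrt_layout          # xc.PjRtLayout held by the array
  xla_layout = pjrt_layout._xla_layout()  # newly exposed private accessor
  dll = DeviceLocalLayout(xla_layout.minor_to_major()[::-1],  # -> major_to_minor
                          xla_layout.tiling())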

Add constructors to `xla::Layout` so that JAX can create layouts with minor_to_major and tiling information.
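
For example (a sketch; assumes a jaxlib with `xla_extension_version >= 274`, as gated in this change):

  from jax._src.lib import xla_client as xc

  # minor_to_major only:
  plain = xc.Layout((1, 0))
  # minor_to_major, tiling, and sub_byte_element_size_in_bits:
  tiled = xc.Layout((1, 0), ((8, 128), (2, 1)), 0)
  print(str(tiled))  # something like {1,0:T(8,128)(2,1)}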

PiperOrigin-RevId: 647487510
Yash Katariya 2024-06-27 16:46:44 -07:00 committed by jax authors
parent d577e29998
commit e1a496d3b6
9 changed files with 228 additions and 82 deletions


@ -800,10 +800,11 @@ pytype_strict_library(
name = "layout",
srcs = ["_src/layout.py"],
deps = [
":dtypes",
":sharding",
":sharding_impls",
"//jax/_src/lib",
],
] + py_deps("numpy"),
)
pytype_strict_library(


@ -506,7 +506,8 @@ class ArrayImpl(basearray.Array):
if self.is_deleted():
return Layout(None, self.sharding)
try:
return Layout(DeviceLocalLayout(self._pjrt_layout), self.sharding)
return Layout(DeviceLocalLayout.from_pjrt_layout(self._pjrt_layout),
self.sharding)
except xe.XlaRuntimeError as e:
msg, *_ = e.args
if type(msg) is str and msg.startswith("UNIMPLEMENTED"):


@ -824,12 +824,15 @@ def _to_physical_op_sharding(
return sharding._to_xla_hlo_sharding(aval.ndim).to_proto() # type: ignore
def _to_xla_layout(layout: DeviceLocalLayout | None | AutoLayout) -> str | None:
def _to_xla_layout(layout: DeviceLocalLayout | None | AutoLayout,
aval: core.AbstractValue) -> str | None:
if layout is None:
return "default"
if isinstance(layout, AutoLayout):
return "auto"
return layout._to_xla_layout()
if aval is core.abstract_token:
return "default"
return layout._to_xla_layout(aval.dtype) # type: ignore
def _get_mem_kind(s: JSharding | None) -> str | None:
@ -1194,11 +1197,9 @@ def lower_jaxpr_to_fun(
ir_arg_shardings = None
if arg_shardings is not None:
in_avals = [None] * (num_dim_vars + num_tokens) + list(jaxpr.in_avals)
ir_arg_shardings = util.flatten(
[[_to_physical_op_sharding(a, s)] * len(types)
for a, s, types in zip(in_avals, arg_shardings, input_types)])
del in_avals
for a, s, types in zip(input_avals, arg_shardings, input_types)])
ir_arg_memory_kinds = None
if arg_memory_kinds is not None:
@ -1208,8 +1209,8 @@ def lower_jaxpr_to_fun(
ir_arg_layouts = None
if arg_layouts is not None:
ir_arg_layouts = util.flatten(
[[_to_xla_layout(l)] * len(types)
for l, types in zip(arg_layouts, input_types)])
[[_to_xla_layout(l, a)] * len(types)
for l, a, types in zip(arg_layouts, input_avals, input_types)])
ir_donated_args = None
if xla_donated_args is not None:
@ -1244,8 +1245,8 @@ def lower_jaxpr_to_fun(
ir_result_layouts = None
if result_layouts is not None:
ir_result_layouts = util.flatten(
[[_to_xla_layout(l)] * len(types)
for l, types in zip(result_layouts, output_types)])
[[_to_xla_layout(l, a)] * len(types)
for l, a, types in zip(result_layouts, output_avals, output_types)])
if (
replicated_args is not None
@ -2171,6 +2172,29 @@ def get_sharding_attr(sharding_proto: xc.OpSharding):
return ir.StringAttr.get(repr(xc.HloSharding.from_proto(sharding_proto)))
def wrap_with_layout_op(ctx: LoweringRuleContext,
x: ir.Value,
aval_out: core.AbstractValue,
layout: DeviceLocalLayout,
aval_in: core.AbstractValue):
result_type = aval_to_ir_type(aval_out)
out_shape = core.physical_aval(aval_out).shape # type: ignore
if core.is_constant_shape(out_shape):
result_shapes = None
else:
result_shapes = [eval_dynamic_shape_as_tensor(ctx, out_shape)]
op = custom_call('LayoutConstraint', result_types=[result_type], operands=[x],
api_version=1,
result_shapes=result_shapes,
# Set operand layouts to anything. XLA will ignore it.
operand_layouts=[list(range(aval_in.ndim))], # type: ignore
# TODO(yashkatariya): Figure out how to pass tiling to the
# custom call.
result_layouts=[layout.major_to_minor[::-1]])
return op.result
# MLIR lowerings for lax primitives
def cache_lowering(f):


@ -1937,7 +1937,7 @@ def _maybe_get_default_layout(arg_layout, jit_in_layout, sharding, aval
# first call you pass it a sharded array with layout and on second call you
# pass a numpy array. The layouts should be the same to get cache hits.
try:
al = DeviceLocalLayout(
al = DeviceLocalLayout.from_pjrt_layout(
d.client.get_default_layout(aval.dtype, shard_shape, d))
except:
return None
@ -2704,7 +2704,7 @@ def _get_layouts_from_executable(
new_in_layouts = []
for x, i in safe_zip(in_layouts_xla, in_layouts):
x = DeviceLocalLayout(x)
x = DeviceLocalLayout.from_pjrt_layout(x)
if isinstance(i, DeviceLocalLayout):
if i != x:
raise AssertionError(
@ -2716,7 +2716,7 @@ def _get_layouts_from_executable(
new_out_layouts = []
for x, o in safe_zip(out_layouts_xla, out_layouts):
x = DeviceLocalLayout(x)
x = DeviceLocalLayout.from_pjrt_layout(x)
if isinstance(o, DeviceLocalLayout):
if o != x:
raise AssertionError(


@ -16,9 +16,12 @@ from __future__ import annotations
from typing import Union
import numpy as np
from jax._src.dtypes import iinfo, issubdtype
from jax._src.sharding import Sharding
from jax._src.sharding_impls import AUTO as AutoSharding, is_auto
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension_version
class AutoLayout:
@ -27,31 +30,78 @@ class AutoLayout:
return "AUTO"
class DeviceLocalLayout:
layout: xc.PjRtLayout
if xla_extension_version >= 274:
class DeviceLocalLayout:
major_to_minor: tuple[int, ...]
tiling: tuple[tuple[int, ...], ...] | None
AUTO = AutoLayout()
AUTO = AutoLayout()
def __init__(self, layout: xc.PjRtLayout):
self._layout = layout
self._layout_str = str(self._layout)
def __init__(self, major_to_minor: tuple[int, ...],
tiling: tuple[tuple[int, ...], ...] | None = None):
self.major_to_minor = tuple(major_to_minor)
self.tiling = None if tiling is None else tuple(map(tuple, tiling))
def __repr__(self):
return f'DeviceLocalLayout({self._layout_str})'
@staticmethod
def from_pjrt_layout(pjrt_layout: xc.PjRtLayout):
xla_layout = pjrt_layout._xla_layout()
return DeviceLocalLayout(xla_layout.minor_to_major()[::-1], # pytype: disable=wrong-arg-types
xla_layout.tiling())
def __hash__(self):
return hash(self._layout)
def __repr__(self):
return (f'DeviceLocalLayout(major_to_minor={self.major_to_minor},'
f' tiling={self.tiling})')
def __eq__(self, other):
if not isinstance(other, DeviceLocalLayout):
return False
return self._layout == other._layout
def __hash__(self):
return hash((self.major_to_minor, self.tiling))
def _to_xla_layout(self) -> str:
return self._layout_str
def __eq__(self, other):
if not isinstance(other, DeviceLocalLayout):
return False
return (self.major_to_minor == other.major_to_minor and
self.tiling == other.tiling)
def _to_xla_layout(self, dtype) -> str:
if self.tiling is None:
xla_layout = xc.Layout(self.major_to_minor[::-1])
else:
if issubdtype(dtype, np.integer):
sub_byte_size = iinfo(dtype).bits if iinfo(dtype).bits < 8 else 0
else:
sub_byte_size = 0
xla_layout = xc.Layout(self.major_to_minor[::-1], self.tiling, # type: ignore
sub_byte_size)
return str(xla_layout)
else:
class DeviceLocalLayout: # type: ignore
layout: xc.PjRtLayout
AUTO = AutoLayout()
def __init__(self, layout: xc.PjRtLayout):
self._layout = layout
self._layout_str = str(self._layout)
@staticmethod
def from_pjrt_layout(pjrt_layout: xc.PjRtLayout):
return DeviceLocalLayout(pjrt_layout) # type: ignore
def __repr__(self):
return f'DeviceLocalLayout({self._layout_str})'
def __hash__(self):
return hash(self._layout)
def __eq__(self, other):
if not isinstance(other, DeviceLocalLayout):
return False
return self._layout == other._layout
def _to_xla_layout(self, dtype) -> str:
return self._layout_str
LayoutOptions = Union[DeviceLocalLayout, None, AutoLayout]
LayoutOptions = Union[DeviceLocalLayout, None, AutoLayout] # pytype: disable=invalid-annotation
ShardingOptions = Union[Sharding, None, AutoSharding]


@ -1816,6 +1816,7 @@ def _fix_inferred_spmd_sharding(jaxpr, resource_env, gen_fresh_name = None):
[tmpvar], [outvar], sharding_constraint_p,
dict(resource_env=resource_env,
sharding=gspmd_sharding,
layout=None,
unconstrained_dims=unconstrained_dims),
set(),
eqn.source_info, eqn.ctx))


@ -2488,6 +2488,9 @@ def with_sharding_constraint(x, shardings):
.. _Distributed arrays and automatic parallelization: https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html
"""
x_flat, tree = tree_flatten(x)
layouts, shardings = _split_layout_and_sharding(shardings)
user_shardings = prepare_axis_resources(
shardings, "shardings", allow_unconstrained_dims=True)
del shardings
@ -2496,6 +2499,10 @@ def with_sharding_constraint(x, shardings):
flatten_axes("with_sharding_constraint shardings", tree, user_shardings))
del user_shardings
user_layouts_flat = tuple(
flatten_axes("with_sharding_constraint layouts", tree, layouts))
del layouts
resource_env = mesh_lib.thread_resources.env
mesh = resource_env.physical_mesh
@ -2511,19 +2518,27 @@ def with_sharding_constraint(x, shardings):
shardings_flat, x_flat, None, "with_sharding_constraint arguments",
allow_uneven_sharding=True)
outs = [sharding_constraint_p.bind(xf, sharding=s,
outs = [sharding_constraint_p.bind(xf, sharding=s, layout=l,
resource_env=resource_env,
unconstrained_dims=ud)
for xf, s, ud in zip(x_flat, shardings_flat, unconstrained_dims)]
for xf, s, l, ud in zip(x_flat, shardings_flat, user_layouts_flat,
unconstrained_dims)]
return tree_unflatten(tree, outs)
def _identity_fn(x): return x
def _sharding_constraint_impl(x, sharding, resource_env, unconstrained_dims):
if hasattr(x, 'sharding') and x.sharding.is_equivalent_to(sharding, x.ndim):
return x
# Run a jit here to raise good errors when device assignment don't match.
return api.jit(_identity_fn, out_shardings=sharding)(x)
def _sharding_constraint_impl(x, sharding, layout, resource_env,
unconstrained_dims):
if layout is None:
if hasattr(x, 'sharding') and x.sharding.is_equivalent_to(sharding, x.ndim):
return x
# Run a jit here to raise good errors when device assignment don't match.
return api.jit(_identity_fn, out_shardings=sharding)(x)
else:
if (hasattr(x, 'layout') and x.layout.device_local_layout == layout and
x.sharding.is_equivalent_to(sharding, x.ndim)):
return x
return api.jit(_identity_fn, out_shardings=Layout(layout, sharding))(x)
sharding_constraint_p = core.Primitive("sharding_constraint")
@ -2532,7 +2547,7 @@ sharding_constraint_p.def_abstract_eval(lambda x, **_: x)
ad.deflinear2(sharding_constraint_p,
lambda ct, _, **params: (sharding_constraint_p.bind(ct, **params),))
def _sharding_constraint_hlo_lowering(ctx, x_node, *, sharding,
def _sharding_constraint_hlo_lowering(ctx, x_node, *, sharding, layout,
resource_env, unconstrained_dims):
aval, = ctx.avals_in
out_aval, = ctx.avals_out
@ -2547,19 +2562,19 @@ def _sharding_constraint_hlo_lowering(ctx, x_node, *, sharding,
sharding._to_xla_hlo_sharding(aval.ndim), mesh)[0]
sharding = NamedSharding._from_parsed_pspec(
mesh, parsed_pspec, _manual_axes=axis_ctx.manual_axes)
return [
mlir.wrap_with_sharding_op(ctx,
x_node, out_aval,
sharding._to_xla_hlo_sharding(aval.ndim).to_proto(),
unspecified_dims=unconstrained_dims)
]
out = mlir.wrap_with_sharding_op(
ctx, x_node, out_aval, sharding._to_xla_hlo_sharding(aval.ndim).to_proto(),
unspecified_dims=unconstrained_dims)
if layout is not None:
out = mlir.wrap_with_layout_op(ctx, out, out_aval, layout, aval)
return [out]
mlir.register_lowering(sharding_constraint_p,
_sharding_constraint_hlo_lowering)
def _sharding_constraint_batcher(insert_axis, spmd_axis_name, axis_size,
axis_name, main_type, vals_in, dims_in,
sharding, resource_env, unconstrained_dims):
def _sharding_constraint_batcher(
insert_axis, spmd_axis_name, axis_size, axis_name, main_type, vals_in,
dims_in, sharding, layout, resource_env, unconstrained_dims):
if spmd_axis_name is not None and isinstance(sharding, NamedSharding):
used = {n for ns in sharding.spec
for n in (ns if isinstance(ns, tuple) else (ns,))}
@ -2586,9 +2601,14 @@ def _sharding_constraint_batcher(insert_axis, spmd_axis_name, axis_size,
vmapped_sharding = NamedSharding(
vmapped_sharding.mesh, PartitionSpec(*new_spec))
# TODO(yashkatariya): Figure out layouts should change under vmap.
if layout is not None:
raise NotImplementedError
y = sharding_constraint_p.bind(
x,
sharding=vmapped_sharding,
layout=layout,
resource_env=resource_env,
unconstrained_dims=unconstrained_dims)
return y, d


@ -17,7 +17,6 @@ import asyncio
import contextlib
import math
from functools import partial
import re
import os
import pathlib
import tracemalloc as tm
@ -46,13 +45,6 @@ def tearDownModule():
_exit_stack.close()
pattern = re.compile(r"\{(.*?):")
def extract_minor_to_major(l):
match = re.search(pattern, str(l))
return tuple(int(i) for i in match.groups()[0].split(','))
class CheckpointTest(jtu.JaxTestCase):
def _on_commit_callback(self, temp_ckpt_dir, final_ckpt_dir):
@ -436,8 +428,8 @@ class CheckpointTest(jtu.JaxTestCase):
out_layout = jax.jit(lambda x: x.T, out_shardings=Layout(DLL.AUTO)).lower(
arr).compile().output_layouts()
self.assertEqual(extract_minor_to_major(arr.layout),
extract_minor_to_major(out_layout)[::-1])
self.assertEqual(arr.layout.device_local_layout.major_to_minor,
out_layout.device_local_layout.major_to_minor[::-1])
ckpt_dir = pathlib.Path(self.create_tempdir('ckpt').full_path)
ckpt_path = pathlib.Path(self.create_tempdir(f'{ckpt_dir}/first').full_path)


@ -14,9 +14,9 @@
import contextlib
import math
import re
from absl.testing import absltest
import numpy as np
from functools import partial
import jax
import jax.numpy as jnp
@ -37,15 +37,6 @@ def tearDownModule():
_exit_stack.close()
pattern = re.compile(r"\{(.*?):")
# Extract minor_to_major from str(layout) because layout doesn't have a
# minor_to_major property yet.
def extract_minor_to_major(l):
match = re.search(pattern, str(l))
return tuple(int(i) for i in match.groups()[0].split(','))
class LayoutTest(jtu.JaxTestCase):
def setUp(self):
@ -79,8 +70,8 @@ class LayoutTest(jtu.JaxTestCase):
self.assertEmpty(kw_layouts)
for i, o in zip(arg_layouts, compiled_apply.output_layouts()):
self.assertEqual(extract_minor_to_major(i),
extract_minor_to_major(o)[::-1])
self.assertEqual(i.device_local_layout.major_to_minor,
o.device_local_layout.major_to_minor[::-1])
init_compiled = jax.jit(
init, out_shardings=arg_layouts).lower(sds1, sds2).compile()
@ -108,10 +99,10 @@ class LayoutTest(jtu.JaxTestCase):
self.assertEqual(apply_out[0].layout, compiled_apply.output_layouts()[0])
self.assertEqual(apply_out[1].layout, compiled_apply.output_layouts()[1])
self.assertTupleEqual(extract_minor_to_major(apply_out[0].layout),
extract_minor_to_major(init_out[0].layout)[::-1])
self.assertTupleEqual(extract_minor_to_major(apply_out[1].layout),
extract_minor_to_major(init_out[1].layout)[::-1])
self.assertTupleEqual(apply_out[0].layout.device_local_layout.major_to_minor,
init_out[0].layout.device_local_layout.major_to_minor[::-1])
self.assertTupleEqual(apply_out[1].layout.device_local_layout.major_to_minor,
init_out[1].layout.device_local_layout.major_to_minor[::-1])
self.assertArraysEqual(init_out[0], np_inp1 * 2)
self.assertArraysEqual(init_out[1], np_inp2 * 2)
@ -135,18 +126,22 @@ class LayoutTest(jtu.JaxTestCase):
out = compiled(arr)
self.assertTupleEqual(
extract_minor_to_major(compiled.input_layouts()[0][0]), (2, 1, 0))
compiled.input_layouts()[0][0].device_local_layout.major_to_minor[::-1],
(2, 1, 0))
self.assertTupleEqual(
extract_minor_to_major(compiled.output_layouts()), (2, 1, 0))
compiled.output_layouts().device_local_layout.major_to_minor[::-1],
(2, 1, 0))
self.assertArraysEqual(out, np_inp.T)
self.assertEqual(out.sharding, NamedSharding(mesh, P(None, 'y', 'x')))
compiled_auto = jax.jit(f, in_shardings=Layout(DLL.AUTO),
out_shardings=Layout(DLL.AUTO)).lower(sds).compile()
self.assertTupleEqual(
extract_minor_to_major(compiled_auto.input_layouts()[0][0]), (2, 1, 0))
compiled_auto.input_layouts()[0][0].device_local_layout.major_to_minor[::-1],
(2, 1, 0))
self.assertTupleEqual(
extract_minor_to_major(compiled_auto.output_layouts()), (0, 1, 2))
compiled_auto.output_layouts().device_local_layout.major_to_minor[::-1],
(0, 1, 2))
with self.assertRaisesRegex(
ValueError, "jax.jit` does not accept device-local layouts directly"):
@ -166,9 +161,11 @@ class LayoutTest(jtu.JaxTestCase):
compiled = jax.jit(f, in_shardings=Layout(),
out_shardings=Layout(DLL.AUTO)).lower(arr).compile()
self.assertTupleEqual(
extract_minor_to_major(compiled.input_layouts()[0][0]), (1, 0))
compiled.input_layouts()[0][0].device_local_layout.major_to_minor[::-1],
(1, 0))
self.assertTupleEqual(
extract_minor_to_major(compiled.output_layouts()), (0, 1))
compiled.output_layouts().device_local_layout.major_to_minor[::-1],
(0, 1))
out = compiled(arr)
self.assertArraysEqual(out, np_inp.T)
@ -185,9 +182,11 @@ class LayoutTest(jtu.JaxTestCase):
out_shardings=Layout(DLL.AUTO, s)).lower(np_inp).compile()
out = compiled(np_inp)
self.assertTupleEqual(
extract_minor_to_major(compiled.input_layouts()[0][0]), (1, 0))
compiled.input_layouts()[0][0].device_local_layout.major_to_minor[::-1],
(1, 0))
self.assertTupleEqual(
extract_minor_to_major(compiled.output_layouts()), (0, 1))
compiled.output_layouts().device_local_layout.major_to_minor[::-1],
(0, 1))
self.assertArraysEqual(out, np_inp.T)
self.assertEqual(out.sharding, s)
@ -358,6 +357,64 @@ class LayoutTest(jtu.JaxTestCase):
jax.make_array_from_callback(
np_inp.shape, Layout(None, None), lambda idx: np_inp[idx])
def test_wsc_concrete_layout(self):
mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
shape = (128, 128)
s = NamedSharding(mesh, P('x'))
np_inp = np.arange(math.prod(shape)).reshape(shape)
arr = jax.device_put(np_inp, s)
# Create a custom layout instead of using `arr.layout` to test the API.
custom_dll = DLL(major_to_minor=(0, 1), tiling=((8, 128),))
# We need AUTO so that XLA can override the entry computation layout set.
# TODO(yashkatariya): Expose a config that sets out_shardings to AUTO by
# default instead of `None` i.e. default layout and let the compiler choose
# the layout or try setting it to AUTO by default and see if there is chaos.
@partial(jax.jit, out_shardings=Layout(DLL.AUTO))
def f(x):
y = x.T
# Constrain `y` to the original layout of `arr` because without it,
# the layout of `y` would be the transpose of `arr`.
return jax.lax.with_sharding_constraint(y, Layout(custom_dll, s))
out = f(arr)
self.assertEqual(out.layout, Layout(custom_dll, s))
self.assertEqual(out.layout, arr.layout)
self.assertArraysEqual(out, np_inp.T)
def test_wsc_concrete_layout_bfloat16(self):
mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
shape = (128, 128)
s = NamedSharding(mesh, P('x'))
inp = jnp.arange(math.prod(shape), dtype=jnp.bfloat16).reshape(shape)
arr = jax.device_put(inp, s)
# Create a custom layout instead of using `arr.layout` to test the API.
custom_dll = DLL(major_to_minor=(0, 1), tiling=((8, 128), (2, 1)))
@partial(jax.jit, out_shardings=Layout(DLL.AUTO))
def f(x):
y = x.T
# Constrain `y` to the original layout of `arr` because without it,
# the layout of `y` would be the transpose of `arr`.
return jax.lax.with_sharding_constraint(y, Layout(custom_dll, s))
out = f(arr)
self.assertEqual(out.layout, Layout(custom_dll, s))
self.assertEqual(out.layout, arr.layout)
self.assertArraysEqual(out, inp.T)
def test_device_put_user_concrete_layout(self):
shape = (8, 128)
np_inp = np.arange(math.prod(shape)).reshape(shape)
dll = DLL(major_to_minor=(1, 0), tiling=((8, 128),))
s = SingleDeviceSharding(jax.devices()[0])
out = jax.device_put(np_inp, Layout(dll, s))
self.assertEqual(out.layout, Layout(dll, s))
self.assertArraysEqual(out, np_inp)
if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())