Make `tiling` and `sub_byte_element_size_in_bits` private arguments of `DeviceLocalLayout`, because XLA does not currently respect user-supplied values for them.

Once the compiler supports them, we can make them public and allow users to pass those values. Right now, only `major_to_minor` is supported.

A valid question, then, is why keep them as constructor arguments at all?

It's because we need to translate the `PjRtLayout`, which we get from the executable, into a `DeviceLocalLayout` while preserving the `tiling` and `sub_byte_element_size_in_bits` info that we get from the compiler. This has helped catch bugs before, when the compiler was not doing the right thing in the layout propagation pass.

PiperOrigin-RevId: 651644934
This commit is contained in:
Yash Katariya 2024-07-11 22:05:19 -07:00 committed by jax authors
parent 2cbe6caa50
commit ff18dedf99
3 changed files with 77 additions and 31 deletions

View File

@ -62,6 +62,7 @@ from jax._src.interpreters import mlir
from jax._src.interpreters import xla
from jax._src.layout import DeviceLocalLayout, AutoLayout, Layout
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension_version
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo
from jax._src.partition_spec import PartitionSpec
@ -2726,6 +2727,25 @@ def maybe_recover_user_shardings(
return new_shardings
def _check_xla_user_layout(ul, xl, what: str):
  """Raise if the layout reported by XLA disagrees with the user's layout.

  When `xla_extension_version >= 274` and the user layout carries no tiling
  (`ul._tiling is None`), only `major_to_minor` is compared. In every other
  case the two layouts are compared with `!=` as whole objects.

  Args:
    ul: The user-specified layout (callers in this file pass a
      `DeviceLocalLayout`).
    xl: The layout obtained from XLA for the same argument/result.
    what: Label interpolated into the error message; callers in this file
      pass "input" or "output".

  Raises:
    AssertionError: If the applicable comparison finds a mismatch.
  """
  # `and` short-circuits, so `ul._tiling` is only read once the version check
  # has passed (NOTE(review): presumably older layouts lack `_tiling` --
  # confirm against the pre-274 class definition).
  compare_major_to_minor_only = (
      xla_extension_version >= 274 and ul._tiling is None)

  if compare_major_to_minor_only:
    mismatch = ul.major_to_minor != xl.major_to_minor
  else:
    mismatch = ul != xl

  if mismatch:
    raise AssertionError(
        f"Unexpected XLA layout override: (XLA) {xl} != {ul} "
        f"(User {what} layout)")
def _get_layouts_from_executable(
xla_executable, in_layouts, out_layouts, num_ordered_effects
) -> tuple[Sequence[DeviceLocalLayout | None], Sequence[DeviceLocalLayout | None]]:
@ -2743,25 +2763,19 @@ def _get_layouts_from_executable(
for x, i in safe_zip(in_layouts_xla, in_layouts):
x = DeviceLocalLayout.from_pjrt_layout(x)
if isinstance(i, DeviceLocalLayout):
if i != x:
raise AssertionError(
f"Unexpected XLA layout override: (XLA) {x} != {i} (User input"
" layout)")
new_in_layouts.append(i)
else:
new_in_layouts.append(x)
_check_xla_user_layout(i, x, "input")
# Always append the XLA layout because it has the full information
# (tiling, etc) even if the user layout does not specify tiling.
new_in_layouts.append(x)
new_out_layouts = []
for x, o in safe_zip(out_layouts_xla, out_layouts):
x = DeviceLocalLayout.from_pjrt_layout(x)
if isinstance(o, DeviceLocalLayout):
if o != x:
raise AssertionError(
f"Unexpected XLA layout override: (XLA) {x} != {o} (User output"
" layout)")
new_out_layouts.append(o)
else:
new_out_layouts.append(x)
_check_xla_user_layout(o, x, "output")
# Always append the XLA layout because it has the full information
# (tiling, etc) even if the user layout does not specify tiling.
new_out_layouts.append(x)
assert all(isinstance(i, DeviceLocalLayout) for i in new_in_layouts)
assert all(isinstance(o, DeviceLocalLayout) for o in new_out_layouts)

View File

@ -33,43 +33,54 @@ class AutoLayout:
if xla_extension_version >= 274:
class DeviceLocalLayout:
major_to_minor: tuple[int, ...]
tiling: tuple[tuple[int, ...], ...] | None
_tiling: tuple[tuple[int, ...], ...] | None
_sub_byte_element_size_in_bits: int
AUTO = AutoLayout()
def __init__(self, major_to_minor: tuple[int, ...],
tiling: tuple[tuple[int, ...], ...] | None = None):
_tiling: tuple[tuple[int, ...], ...] | None = None,
_sub_byte_element_size_in_bits: int = 0):
self.major_to_minor = tuple(major_to_minor)
self.tiling = None if tiling is None else tuple(map(tuple, tiling))
self._tiling = None if _tiling is None else tuple(map(tuple, _tiling))
self._sub_byte_element_size_in_bits = _sub_byte_element_size_in_bits
@staticmethod
def from_pjrt_layout(pjrt_layout: xc.PjRtLayout):
xla_layout = pjrt_layout._xla_layout()
return DeviceLocalLayout(xla_layout.minor_to_major()[::-1], # pytype: disable=wrong-arg-types
xla_layout.tiling())
xla_layout.tiling(),
xla_layout.element_size_in_bits())
def __repr__(self):
return (f'DeviceLocalLayout(major_to_minor={self.major_to_minor},'
f' tiling={self.tiling})')
return (
f'DeviceLocalLayout(major_to_minor={self.major_to_minor},'
f' _tiling={self._tiling},'
f' _sub_byte_element_size_in_bits={self._sub_byte_element_size_in_bits})'
)
def __hash__(self):
return hash((self.major_to_minor, self.tiling))
return hash((self.major_to_minor, self._tiling,
self._sub_byte_element_size_in_bits))
def __eq__(self, other):
if not isinstance(other, DeviceLocalLayout):
return False
return (self.major_to_minor == other.major_to_minor and
self.tiling == other.tiling)
self._tiling == other._tiling and
self._sub_byte_element_size_in_bits == other._sub_byte_element_size_in_bits)
def _to_xla_layout(self, dtype) -> str:
if self.tiling is None:
if self._tiling is None:
xla_layout = xc.Layout(self.major_to_minor[::-1])
else:
if issubdtype(dtype, np.integer):
if self._sub_byte_element_size_in_bits != 0:
sub_byte_size = self._sub_byte_element_size_in_bits
elif issubdtype(dtype, np.integer):
sub_byte_size = iinfo(dtype).bits if iinfo(dtype).bits < 8 else 0
else:
sub_byte_size = 0
xla_layout = xc.Layout(self.major_to_minor[::-1], self.tiling, # type: ignore
xla_layout = xc.Layout(self.major_to_minor[::-1], self._tiling, # type: ignore
sub_byte_size)
return str(xla_layout)
else:

View File

@ -364,7 +364,7 @@ class LayoutTest(jtu.JaxTestCase):
arr = jax.device_put(np_inp, s)
# Create a custom layout instead of using `arr.layout` to test the API.
custom_dll = DLL(major_to_minor=(0, 1), tiling=((8, 128),))
custom_dll = DLL(major_to_minor=(0, 1))
@jax.jit
def f(x):
@ -374,7 +374,8 @@ class LayoutTest(jtu.JaxTestCase):
return jax.lax.with_sharding_constraint(y, Layout(custom_dll, s))
out = f(arr)
self.assertEqual(out.layout, Layout(custom_dll, s))
self.assertEqual(out.layout.device_local_layout.major_to_minor,
custom_dll.major_to_minor)
self.assertEqual(out.layout, arr.layout)
self.assertArraysEqual(out, np_inp.T)
@ -386,7 +387,7 @@ class LayoutTest(jtu.JaxTestCase):
arr = jax.device_put(inp, s)
# Create a custom layout instead of using `arr.layout` to test the API.
custom_dll = DLL(major_to_minor=(0, 1), tiling=((8, 128), (2, 1)))
custom_dll = DLL(major_to_minor=(0, 1))
@jax.jit
def f(x):
@ -396,20 +397,40 @@ class LayoutTest(jtu.JaxTestCase):
return jax.lax.with_sharding_constraint(y, Layout(custom_dll, s))
out = f(arr)
self.assertEqual(out.layout, Layout(custom_dll, s))
self.assertEqual(out.layout.device_local_layout.major_to_minor,
custom_dll.major_to_minor)
self.assertEqual(out.layout, arr.layout)
self.assertArraysEqual(out, inp.T)
def test_device_put_user_concrete_layout(self):
shape = (8, 128)
np_inp = np.arange(math.prod(shape)).reshape(shape)
dll = DLL(major_to_minor=(1, 0), tiling=((8, 128),))
dll = DLL(major_to_minor=(1, 0))
s = SingleDeviceSharding(jax.devices()[0])
out = jax.device_put(np_inp, Layout(dll, s))
self.assertEqual(out.layout, Layout(dll, s))
self.assertEqual(out.layout.device_local_layout.major_to_minor,
dll.major_to_minor)
self.assertArraysEqual(out, np_inp)
def test_concrete_layout_jit(self):
  """Check a jitted transpose honours a user-supplied concrete output layout.

  Only `major_to_minor` of the resulting layout is compared against the
  requested `DLL`; the full layout object is not asserted on.
  """
  mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
  shape = (16, 128)
  np_inp = np.arange(math.prod(shape)).reshape(shape)
  s = NamedSharding(mesh, P('x'))
  arr = jax.device_put(np_inp, s)

  # Request a concrete device-local layout on the output via `out_shardings`.
  custom_dll = DLL(major_to_minor=(0, 1))
  requested_layout = Layout(custom_dll, s)

  def f(x):
    return x.T

  jitted_f = jax.jit(f, out_shardings=requested_layout)
  out = jitted_f(arr)

  self.assertArraysEqual(out, np_inp.T)
  result_dll = out.layout.device_local_layout
  self.assertEqual(result_dll.major_to_minor,
                   custom_dll.major_to_minor)
if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())