mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Make tiling and sub_byte_element_size_in_bits private arguments of DeviceLocalLayout
. This is because XLA does not respect the values passed to it.
Once the compiler supports it, we can make it public and allow users to pass those values. Right now, only `major_to_minor` is supported. But a valid question is why even keep them as arguments in the constructor? It's because we need to translate `PjRtLayout` which we get from the executable to `DeviceLocalLayout` and preserve the `tiling` and `sub_byte_element_size_in_bits` info that we get from the compiler. This has helped catch bugs before when the compiler was not doing the right thing in layout propagation pass. PiperOrigin-RevId: 651644934
This commit is contained in:
parent
2cbe6caa50
commit
ff18dedf99
@ -62,6 +62,7 @@ from jax._src.interpreters import mlir
|
||||
from jax._src.interpreters import xla
|
||||
from jax._src.layout import DeviceLocalLayout, AutoLayout, Layout
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.lib import xla_extension_version
|
||||
from jax._src.lib.mlir import ir
|
||||
from jax._src.lib.mlir.dialects import hlo
|
||||
from jax._src.partition_spec import PartitionSpec
|
||||
@ -2726,6 +2727,25 @@ def maybe_recover_user_shardings(
|
||||
return new_shardings
|
||||
|
||||
|
||||
def _check_xla_user_layout(ul, xl, what: str):
|
||||
if xla_extension_version >= 274:
|
||||
if ul._tiling is None:
|
||||
if ul.major_to_minor != xl.major_to_minor:
|
||||
raise AssertionError(
|
||||
f"Unexpected XLA layout override: (XLA) {xl} != {ul} "
|
||||
f"(User {what} layout)")
|
||||
else:
|
||||
if ul != xl:
|
||||
raise AssertionError(
|
||||
f"Unexpected XLA layout override: (XLA) {xl} != {ul} "
|
||||
f"(User {what} layout)")
|
||||
else:
|
||||
if ul != xl:
|
||||
raise AssertionError(
|
||||
f"Unexpected XLA layout override: (XLA) {xl} != {ul} "
|
||||
f"(User {what} layout)")
|
||||
|
||||
|
||||
def _get_layouts_from_executable(
|
||||
xla_executable, in_layouts, out_layouts, num_ordered_effects
|
||||
) -> tuple[Sequence[DeviceLocalLayout | None], Sequence[DeviceLocalLayout | None]]:
|
||||
@ -2743,25 +2763,19 @@ def _get_layouts_from_executable(
|
||||
for x, i in safe_zip(in_layouts_xla, in_layouts):
|
||||
x = DeviceLocalLayout.from_pjrt_layout(x)
|
||||
if isinstance(i, DeviceLocalLayout):
|
||||
if i != x:
|
||||
raise AssertionError(
|
||||
f"Unexpected XLA layout override: (XLA) {x} != {i} (User input"
|
||||
" layout)")
|
||||
new_in_layouts.append(i)
|
||||
else:
|
||||
new_in_layouts.append(x)
|
||||
_check_xla_user_layout(i, x, "input")
|
||||
# Always append the XLA layout because it has the full information
|
||||
# (tiling, etc) even if the user layout does not specify tiling.
|
||||
new_in_layouts.append(x)
|
||||
|
||||
new_out_layouts = []
|
||||
for x, o in safe_zip(out_layouts_xla, out_layouts):
|
||||
x = DeviceLocalLayout.from_pjrt_layout(x)
|
||||
if isinstance(o, DeviceLocalLayout):
|
||||
if o != x:
|
||||
raise AssertionError(
|
||||
f"Unexpected XLA layout override: (XLA) {x} != {o} (User output"
|
||||
" layout)")
|
||||
new_out_layouts.append(o)
|
||||
else:
|
||||
new_out_layouts.append(x)
|
||||
_check_xla_user_layout(o, x, "output")
|
||||
# Always append the XLA layout because it has the full information
|
||||
# (tiling, etc) even if the user layout does not specify tiling.
|
||||
new_out_layouts.append(x)
|
||||
|
||||
assert all(isinstance(i, DeviceLocalLayout) for i in new_in_layouts)
|
||||
assert all(isinstance(o, DeviceLocalLayout) for o in new_out_layouts)
|
||||
|
@ -33,43 +33,54 @@ class AutoLayout:
|
||||
if xla_extension_version >= 274:
|
||||
class DeviceLocalLayout:
|
||||
major_to_minor: tuple[int, ...]
|
||||
tiling: tuple[tuple[int, ...], ...] | None
|
||||
_tiling: tuple[tuple[int, ...], ...] | None
|
||||
_sub_byte_element_size_in_bits: int
|
||||
|
||||
AUTO = AutoLayout()
|
||||
|
||||
def __init__(self, major_to_minor: tuple[int, ...],
|
||||
tiling: tuple[tuple[int, ...], ...] | None = None):
|
||||
_tiling: tuple[tuple[int, ...], ...] | None = None,
|
||||
_sub_byte_element_size_in_bits: int = 0):
|
||||
self.major_to_minor = tuple(major_to_minor)
|
||||
self.tiling = None if tiling is None else tuple(map(tuple, tiling))
|
||||
self._tiling = None if _tiling is None else tuple(map(tuple, _tiling))
|
||||
self._sub_byte_element_size_in_bits = _sub_byte_element_size_in_bits
|
||||
|
||||
@staticmethod
|
||||
def from_pjrt_layout(pjrt_layout: xc.PjRtLayout):
|
||||
xla_layout = pjrt_layout._xla_layout()
|
||||
return DeviceLocalLayout(xla_layout.minor_to_major()[::-1], # pytype: disable=wrong-arg-types
|
||||
xla_layout.tiling())
|
||||
xla_layout.tiling(),
|
||||
xla_layout.element_size_in_bits())
|
||||
|
||||
def __repr__(self):
|
||||
return (f'DeviceLocalLayout(major_to_minor={self.major_to_minor},'
|
||||
f' tiling={self.tiling})')
|
||||
return (
|
||||
f'DeviceLocalLayout(major_to_minor={self.major_to_minor},'
|
||||
f' _tiling={self._tiling},'
|
||||
f' _sub_byte_element_size_in_bits={self._sub_byte_element_size_in_bits})'
|
||||
)
|
||||
|
||||
def __hash__(self):
|
||||
return hash((self.major_to_minor, self.tiling))
|
||||
return hash((self.major_to_minor, self._tiling,
|
||||
self._sub_byte_element_size_in_bits))
|
||||
|
||||
def __eq__(self, other):
|
||||
if not isinstance(other, DeviceLocalLayout):
|
||||
return False
|
||||
return (self.major_to_minor == other.major_to_minor and
|
||||
self.tiling == other.tiling)
|
||||
self._tiling == other._tiling and
|
||||
self._sub_byte_element_size_in_bits == other._sub_byte_element_size_in_bits)
|
||||
|
||||
def _to_xla_layout(self, dtype) -> str:
|
||||
if self.tiling is None:
|
||||
if self._tiling is None:
|
||||
xla_layout = xc.Layout(self.major_to_minor[::-1])
|
||||
else:
|
||||
if issubdtype(dtype, np.integer):
|
||||
if self._sub_byte_element_size_in_bits != 0:
|
||||
sub_byte_size = self._sub_byte_element_size_in_bits
|
||||
elif issubdtype(dtype, np.integer):
|
||||
sub_byte_size = iinfo(dtype).bits if iinfo(dtype).bits < 8 else 0
|
||||
else:
|
||||
sub_byte_size = 0
|
||||
xla_layout = xc.Layout(self.major_to_minor[::-1], self.tiling, # type: ignore
|
||||
xla_layout = xc.Layout(self.major_to_minor[::-1], self._tiling, # type: ignore
|
||||
sub_byte_size)
|
||||
return str(xla_layout)
|
||||
else:
|
||||
|
@ -364,7 +364,7 @@ class LayoutTest(jtu.JaxTestCase):
|
||||
arr = jax.device_put(np_inp, s)
|
||||
|
||||
# Create a custom layout instead of using `arr.layout` to test the API.
|
||||
custom_dll = DLL(major_to_minor=(0, 1), tiling=((8, 128),))
|
||||
custom_dll = DLL(major_to_minor=(0, 1))
|
||||
|
||||
@jax.jit
|
||||
def f(x):
|
||||
@ -374,7 +374,8 @@ class LayoutTest(jtu.JaxTestCase):
|
||||
return jax.lax.with_sharding_constraint(y, Layout(custom_dll, s))
|
||||
|
||||
out = f(arr)
|
||||
self.assertEqual(out.layout, Layout(custom_dll, s))
|
||||
self.assertEqual(out.layout.device_local_layout.major_to_minor,
|
||||
custom_dll.major_to_minor)
|
||||
self.assertEqual(out.layout, arr.layout)
|
||||
self.assertArraysEqual(out, np_inp.T)
|
||||
|
||||
@ -386,7 +387,7 @@ class LayoutTest(jtu.JaxTestCase):
|
||||
arr = jax.device_put(inp, s)
|
||||
|
||||
# Create a custom layout instead of using `arr.layout` to test the API.
|
||||
custom_dll = DLL(major_to_minor=(0, 1), tiling=((8, 128), (2, 1)))
|
||||
custom_dll = DLL(major_to_minor=(0, 1))
|
||||
|
||||
@jax.jit
|
||||
def f(x):
|
||||
@ -396,20 +397,40 @@ class LayoutTest(jtu.JaxTestCase):
|
||||
return jax.lax.with_sharding_constraint(y, Layout(custom_dll, s))
|
||||
|
||||
out = f(arr)
|
||||
self.assertEqual(out.layout, Layout(custom_dll, s))
|
||||
self.assertEqual(out.layout.device_local_layout.major_to_minor,
|
||||
custom_dll.major_to_minor)
|
||||
self.assertEqual(out.layout, arr.layout)
|
||||
self.assertArraysEqual(out, inp.T)
|
||||
|
||||
def test_device_put_user_concrete_layout(self):
|
||||
shape = (8, 128)
|
||||
np_inp = np.arange(math.prod(shape)).reshape(shape)
|
||||
dll = DLL(major_to_minor=(1, 0), tiling=((8, 128),))
|
||||
dll = DLL(major_to_minor=(1, 0))
|
||||
s = SingleDeviceSharding(jax.devices()[0])
|
||||
|
||||
out = jax.device_put(np_inp, Layout(dll, s))
|
||||
self.assertEqual(out.layout, Layout(dll, s))
|
||||
self.assertEqual(out.layout.device_local_layout.major_to_minor,
|
||||
dll.major_to_minor)
|
||||
self.assertArraysEqual(out, np_inp)
|
||||
|
||||
def test_concrete_layout_jit(self):
|
||||
mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
|
||||
shape = (16, 128)
|
||||
s = NamedSharding(mesh, P('x'))
|
||||
np_inp = np.arange(math.prod(shape)).reshape(shape)
|
||||
arr = jax.device_put(np_inp, s)
|
||||
|
||||
def f(x):
|
||||
return x.T
|
||||
|
||||
custom_dll = DLL(major_to_minor=(0, 1))
|
||||
f = jax.jit(f, out_shardings=Layout(custom_dll, s))
|
||||
|
||||
out = f(arr)
|
||||
self.assertArraysEqual(out, np_inp.T)
|
||||
self.assertEqual(out.layout.device_local_layout.major_to_minor,
|
||||
custom_dll.major_to_minor)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user