Check for layout mismatch between array's layout and layout specified via in_shardings to jit by only checking major_to_minor if _tiling is None. Otherwise, check the entire layout.

PiperOrigin-RevId: 651796471
This commit is contained in:
Yash Katariya 2024-07-12 09:22:44 -07:00 committed by jax authors
parent ff3dc0f5fb
commit 1e1bca0706
3 changed files with 47 additions and 20 deletions

View File

@ -2726,21 +2726,18 @@ def maybe_recover_user_shardings(
return new_shardings
def _check_xla_user_layout(ul, xl, what: str):
def is_user_xla_layout_equal(ul: DeviceLocalLayout | AutoLayout,
xl: DeviceLocalLayout) -> bool:
if xla_extension_version >= 274:
if ul._tiling is None:
if ul.major_to_minor != xl.major_to_minor:
raise AssertionError(
f"Unexpected XLA layout override: (XLA) {xl} != {ul} "
f"(User {what} layout)")
if isinstance(ul, DeviceLocalLayout) and ul._tiling is None:
return ul.major_to_minor == xl.major_to_minor
else:
if ul != xl:
raise AssertionError(
f"Unexpected XLA layout override: (XLA) {xl} != {ul} "
f"(User {what} layout)")
return ul == xl
else:
if ul != xl:
return ul == xl
def _check_user_xla_layout(ul, xl, what: str):
if not is_user_xla_layout_equal(ul, xl):
raise AssertionError(
f"Unexpected XLA layout override: (XLA) {xl} != {ul} "
f"(User {what} layout)")
@ -2763,7 +2760,7 @@ def _get_layouts_from_executable(
for x, i in safe_zip(in_layouts_xla, in_layouts):
x = DeviceLocalLayout.from_pjrt_layout(x)
if isinstance(i, DeviceLocalLayout):
_check_xla_user_layout(i, x, "input")
_check_user_xla_layout(i, x, "input")
# Always append the XLA layout because it has the full information
# (tiling, etc) even if the user layout does not specify tiling.
new_in_layouts.append(x)
@ -2772,7 +2769,7 @@ def _get_layouts_from_executable(
for x, o in safe_zip(out_layouts_xla, out_layouts):
x = DeviceLocalLayout.from_pjrt_layout(x)
if isinstance(o, DeviceLocalLayout):
_check_xla_user_layout(o, x, "output")
_check_user_xla_layout(o, x, "output")
# Always append the XLA layout because it has the full information
# (tiling, etc) even if the user layout does not specify tiling.
new_out_layouts.append(x)

View File

@ -676,7 +676,6 @@ def _infer_params_impl(
attrs_tracked), args_flat
class InferParamsCacheEntry:
"""Mutable value object for _infer_params_cached."""
__slots__ = ['pjit_params']
@ -1464,8 +1463,10 @@ def _resolve_in_layouts(args, jit_in_layouts, resolved_in_shardings, in_avals):
# arg_layout can be None because some backends don't implement the
# required layout methods. Hence `arr.layout` can return
# `Layout(None, sharding)`
if (committed and not is_pmap_sharding and
arg_layout is not None and arg_layout != jit_in_l):
if (committed
and not is_pmap_sharding
and arg_layout is not None
and not pxla.is_user_xla_layout_equal(jit_in_l, arg_layout)):
extra_msg = ''
if isinstance(jit_in_l, AutoLayout):
extra_msg = (

View File

@ -456,6 +456,35 @@ class LayoutTest(jtu.JaxTestCase):
'.*Length of major_to_minor and the rank of the value should match.*'):
jax.device_put(inp, l)
def test_concrete_layout_in_shardings(self):
mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
s = NamedSharding(mesh, P('x', 'y'))
shape = (16, 128)
np_inp = np.arange(math.prod(shape)).reshape(shape)
arr = jax.device_put(np_inp, s)
custom_dll = DLL(major_to_minor=(0, 1))
@partial(jax.jit, in_shardings=Layout(custom_dll, s))
def f(x):
return x.T
out = f(arr)
self.assertArraysEqual(out, np_inp.T)
self.assertEqual(out.layout.device_local_layout.major_to_minor,
custom_dll.major_to_minor[::-1])
custom_dll2 = DLL(major_to_minor=(1, 0))
@partial(jax.jit, in_shardings=Layout(custom_dll2, s))
def g(x):
return x.T
with self.assertRaisesRegex(
ValueError,
'Layout passed to jit does not match the layout on the respective arg'):
g(arr)
if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())