From 1e1bca0706c92907657d35ddb8f269ce507f61c6 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Fri, 12 Jul 2024 09:22:44 -0700 Subject: [PATCH] Check for layout mismatch between array's layout and layout specified via in_shardings to jit by only checking `major_to_minor` if `_tiling` is None. Otherwise, check the entire layout. PiperOrigin-RevId: 651796471 --- jax/_src/interpreters/pxla.py | 31 ++++++++++++++----------------- jax/_src/pjit.py | 7 ++++--- tests/layout_test.py | 29 +++++++++++++++++++++++++++++ 3 files changed, 47 insertions(+), 20 deletions(-) diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 2c4daf4a8..345665343 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -2726,24 +2726,21 @@ def maybe_recover_user_shardings( return new_shardings - -def _check_xla_user_layout(ul, xl, what: str): +def is_user_xla_layout_equal(ul: DeviceLocalLayout | AutoLayout, + xl: DeviceLocalLayout) -> bool: if xla_extension_version >= 274: - if ul._tiling is None: - if ul.major_to_minor != xl.major_to_minor: - raise AssertionError( - f"Unexpected XLA layout override: (XLA) {xl} != {ul} " - f"(User {what} layout)") + if isinstance(ul, DeviceLocalLayout) and ul._tiling is None: + return ul.major_to_minor == xl.major_to_minor else: - if ul != xl: - raise AssertionError( - f"Unexpected XLA layout override: (XLA) {xl} != {ul} " - f"(User {what} layout)") + return ul == xl else: - if ul != xl: - raise AssertionError( - f"Unexpected XLA layout override: (XLA) {xl} != {ul} " - f"(User {what} layout)") + return ul == xl + +def _check_user_xla_layout(ul, xl, what: str): + if not is_user_xla_layout_equal(ul, xl): + raise AssertionError( + f"Unexpected XLA layout override: (XLA) {xl} != {ul} " + f"(User {what} layout)") def _get_layouts_from_executable( @@ -2763,7 +2760,7 @@ def _get_layouts_from_executable( for x, i in safe_zip(in_layouts_xla, in_layouts): x = DeviceLocalLayout.from_pjrt_layout(x) if isinstance(i, DeviceLocalLayout): - _check_xla_user_layout(i, x, "input") + _check_user_xla_layout(i, x, "input") # Always append the XLA layout because it has the full information # (tiling, etc) even if the user layout does not specify tiling. new_in_layouts.append(x) @@ -2772,7 +2769,7 @@ def _get_layouts_from_executable( for x, o in safe_zip(out_layouts_xla, out_layouts): x = DeviceLocalLayout.from_pjrt_layout(x) if isinstance(o, DeviceLocalLayout): - _check_xla_user_layout(o, x, "output") + _check_user_xla_layout(o, x, "output") # Always append the XLA layout because it has the full information # (tiling, etc) even if the user layout does not specify tiling. new_out_layouts.append(x) diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 6dbd8a8bf..f5892e44c 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -676,7 +676,6 @@ def _infer_params_impl( attrs_tracked), args_flat - class InferParamsCacheEntry: """Mutable value object for _infer_params_cached.""" __slots__ = ['pjit_params'] @@ -1464,8 +1463,10 @@ def _resolve_in_layouts(args, jit_in_layouts, resolved_in_shardings, in_avals): # arg_layout can be None because some backends don't implement the # required layout methods. Hence `arr.layout` can return # `Layout(None, sharding)` - if (committed and not is_pmap_sharding and - arg_layout is not None and arg_layout != jit_in_l): + if (committed + and not is_pmap_sharding + and arg_layout is not None + and not pxla.is_user_xla_layout_equal(jit_in_l, arg_layout)): extra_msg = '' if isinstance(jit_in_l, AutoLayout): extra_msg = ( diff --git a/tests/layout_test.py b/tests/layout_test.py index da38ed948..7972d44d3 100644 --- a/tests/layout_test.py +++ b/tests/layout_test.py @@ -456,6 +456,35 @@ class LayoutTest(jtu.JaxTestCase): '.*Length of major_to_minor and the rank of the value should match.*'): jax.device_put(inp, l) + def test_concrete_layout_in_shardings(self): + mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + s = NamedSharding(mesh, P('x', 'y')) + shape = (16, 128) + np_inp = np.arange(math.prod(shape)).reshape(shape) + arr = jax.device_put(np_inp, s) + + custom_dll = DLL(major_to_minor=(0, 1)) + + @partial(jax.jit, in_shardings=Layout(custom_dll, s)) + def f(x): + return x.T + + out = f(arr) + self.assertArraysEqual(out, np_inp.T) + self.assertEqual(out.layout.device_local_layout.major_to_minor, + custom_dll.major_to_minor[::-1]) + + custom_dll2 = DLL(major_to_minor=(1, 0)) + + @partial(jax.jit, in_shardings=Layout(custom_dll2, s)) + def g(x): + return x.T + + with self.assertRaisesRegex( + ValueError, + 'Layout passed to jit does not match the layout on the respective arg'): + g(arr) + if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader())