mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 20:36:05 +00:00
Check for layout mismatch between array's layout and layout specified via in_shardings to jit by only checking major_to_minor
if _tiling
is None. Otherwise, check the entire layout.
PiperOrigin-RevId: 651796471
This commit is contained in:
parent
ff3dc0f5fb
commit
1e1bca0706
@ -2726,21 +2726,18 @@ def maybe_recover_user_shardings(
|
||||
|
||||
return new_shardings
|
||||
|
||||
|
||||
def _check_xla_user_layout(ul, xl, what: str):
|
||||
def is_user_xla_layout_equal(ul: DeviceLocalLayout | AutoLayout,
|
||||
xl: DeviceLocalLayout) -> bool:
|
||||
if xla_extension_version >= 274:
|
||||
if ul._tiling is None:
|
||||
if ul.major_to_minor != xl.major_to_minor:
|
||||
raise AssertionError(
|
||||
f"Unexpected XLA layout override: (XLA) {xl} != {ul} "
|
||||
f"(User {what} layout)")
|
||||
if isinstance(ul, DeviceLocalLayout) and ul._tiling is None:
|
||||
return ul.major_to_minor == xl.major_to_minor
|
||||
else:
|
||||
if ul != xl:
|
||||
raise AssertionError(
|
||||
f"Unexpected XLA layout override: (XLA) {xl} != {ul} "
|
||||
f"(User {what} layout)")
|
||||
return ul == xl
|
||||
else:
|
||||
if ul != xl:
|
||||
return ul == xl
|
||||
|
||||
def _check_user_xla_layout(ul, xl, what: str):
|
||||
if not is_user_xla_layout_equal(ul, xl):
|
||||
raise AssertionError(
|
||||
f"Unexpected XLA layout override: (XLA) {xl} != {ul} "
|
||||
f"(User {what} layout)")
|
||||
@ -2763,7 +2760,7 @@ def _get_layouts_from_executable(
|
||||
for x, i in safe_zip(in_layouts_xla, in_layouts):
|
||||
x = DeviceLocalLayout.from_pjrt_layout(x)
|
||||
if isinstance(i, DeviceLocalLayout):
|
||||
_check_xla_user_layout(i, x, "input")
|
||||
_check_user_xla_layout(i, x, "input")
|
||||
# Always append the XLA layout because it has the full information
|
||||
# (tiling, etc) even if the user layout does not specify tiling.
|
||||
new_in_layouts.append(x)
|
||||
@ -2772,7 +2769,7 @@ def _get_layouts_from_executable(
|
||||
for x, o in safe_zip(out_layouts_xla, out_layouts):
|
||||
x = DeviceLocalLayout.from_pjrt_layout(x)
|
||||
if isinstance(o, DeviceLocalLayout):
|
||||
_check_xla_user_layout(o, x, "output")
|
||||
_check_user_xla_layout(o, x, "output")
|
||||
# Always append the XLA layout because it has the full information
|
||||
# (tiling, etc) even if the user layout does not specify tiling.
|
||||
new_out_layouts.append(x)
|
||||
|
@ -676,7 +676,6 @@ def _infer_params_impl(
|
||||
attrs_tracked), args_flat
|
||||
|
||||
|
||||
|
||||
class InferParamsCacheEntry:
|
||||
"""Mutable value object for _infer_params_cached."""
|
||||
__slots__ = ['pjit_params']
|
||||
@ -1464,8 +1463,10 @@ def _resolve_in_layouts(args, jit_in_layouts, resolved_in_shardings, in_avals):
|
||||
# arg_layout can be None because some backends don't implement the
|
||||
# required layout methods. Hence `arr.layout` can return
|
||||
# `Layout(None, sharding)`
|
||||
if (committed and not is_pmap_sharding and
|
||||
arg_layout is not None and arg_layout != jit_in_l):
|
||||
if (committed
|
||||
and not is_pmap_sharding
|
||||
and arg_layout is not None
|
||||
and not pxla.is_user_xla_layout_equal(jit_in_l, arg_layout)):
|
||||
extra_msg = ''
|
||||
if isinstance(jit_in_l, AutoLayout):
|
||||
extra_msg = (
|
||||
|
@ -456,6 +456,35 @@ class LayoutTest(jtu.JaxTestCase):
|
||||
'.*Length of major_to_minor and the rank of the value should match.*'):
|
||||
jax.device_put(inp, l)
|
||||
|
||||
def test_concrete_layout_in_shardings(self):
|
||||
mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
|
||||
s = NamedSharding(mesh, P('x', 'y'))
|
||||
shape = (16, 128)
|
||||
np_inp = np.arange(math.prod(shape)).reshape(shape)
|
||||
arr = jax.device_put(np_inp, s)
|
||||
|
||||
custom_dll = DLL(major_to_minor=(0, 1))
|
||||
|
||||
@partial(jax.jit, in_shardings=Layout(custom_dll, s))
|
||||
def f(x):
|
||||
return x.T
|
||||
|
||||
out = f(arr)
|
||||
self.assertArraysEqual(out, np_inp.T)
|
||||
self.assertEqual(out.layout.device_local_layout.major_to_minor,
|
||||
custom_dll.major_to_minor[::-1])
|
||||
|
||||
custom_dll2 = DLL(major_to_minor=(1, 0))
|
||||
|
||||
@partial(jax.jit, in_shardings=Layout(custom_dll2, s))
|
||||
def g(x):
|
||||
return x.T
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
'Layout passed to jit does not match the layout on the respective arg'):
|
||||
g(arr)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user