Make sure default layout is None for input and output layout in all codepaths

PiperOrigin-RevId: 731865511
This commit is contained in:
Yash Katariya 2025-02-27 14:24:27 -08:00 committed by jax authors
parent c7ca35fe32
commit dda62f576f
2 changed files with 34 additions and 0 deletions

View File

@ -1508,9 +1508,22 @@ def _resolve_in_layouts(args, jit_in_layouts, resolved_in_shardings, in_avals):
f'arg layout: {arg_layout} for '
f'arg shape: {core.shaped_abstractify(arg).str_short()}.'
f'{extra_msg}')
jit_in_l = (None if isinstance(jit_in_l, DeviceLocalLayout) and
pxla.is_default_layout(jit_in_l, rs, aval) else jit_in_l)
resolved_in_layouts.append(jit_in_l)
return tuple(resolved_in_layouts)
def _resolve_out_layouts(out_layouts, out_shardings, out_avals):
new_out_layouts = []
for out_l, out_s, out_aval in safe_zip(out_layouts, out_shardings, out_avals):
if out_l is None:
new_out_layouts.append(None)
elif (isinstance(out_l, DeviceLocalLayout) and
pxla.is_default_layout(out_l, out_s, out_aval)):
new_out_layouts.append(None)
else:
new_out_layouts.append(out_l)
return tuple(new_out_layouts)
def _resolve_in_shardings(args, pjit_in_shardings: Sequence[PjitSharding]
) -> Sequence[PjitSharding]:
@ -1612,6 +1625,7 @@ def _resolve_and_lower(
in_shardings = _resolve_in_shardings(args, in_shardings)
in_layouts = _resolve_in_layouts(args, in_layouts, in_shardings,
jaxpr.in_avals)
out_layouts = _resolve_out_layouts(out_layouts, out_shardings, jaxpr.out_avals)
return _pjit_lower(
jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, resource_env,
donated_invars, name, keep_unused, inline, compiler_options_kvs,

View File

@ -724,6 +724,26 @@ class LayoutTest(jtu.JaxTestCase):
self.assertArraysEqual(out, np_inp @ np_inp.T)
self.assertArraysEqual(out2, np_inp @ np_inp.T)
def test_layout_donation_with_default_layout(self):
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
s = NamedSharding(mesh, P('x', 'y'))
shape = (16, 16)
np_inp = np.arange(math.prod(shape)).reshape(shape)
arr = jax.device_put(np_inp, s)
out_layout = Layout(arr.layout.device_local_layout, s)
@partial(jax.jit, out_shardings=out_layout, donate_argnums=0)
def f(x):
return x * 2
lowered_text = f.lower(arr).as_text()
self.assertIn('tf.aliasing_output = 0', lowered_text)
self.assertNotIn('jax.buffer_donor', lowered_text)
out = f(arr)
self.assertArraysEqual(out, np_inp * 2)
self.assertEqual(out.layout, out_layout)
if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())