diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index b12f92110..24e72117b 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -1508,9 +1508,22 @@ def _resolve_in_layouts(args, jit_in_layouts, resolved_in_shardings, in_avals): f'arg layout: {arg_layout} for ' f'arg shape: {core.shaped_abstractify(arg).str_short()}.' f'{extra_msg}') + jit_in_l = (None if isinstance(jit_in_l, DeviceLocalLayout) and + pxla.is_default_layout(jit_in_l, rs, aval) else jit_in_l) resolved_in_layouts.append(jit_in_l) return tuple(resolved_in_layouts) +def _resolve_out_layouts(out_layouts, out_shardings, out_avals): + new_out_layouts = [] + for out_l, out_s, out_aval in safe_zip(out_layouts, out_shardings, out_avals): + if out_l is None: + new_out_layouts.append(None) + elif (isinstance(out_l, DeviceLocalLayout) and + pxla.is_default_layout(out_l, out_s, out_aval)): + new_out_layouts.append(None) + else: + new_out_layouts.append(out_l) + return tuple(new_out_layouts) def _resolve_in_shardings(args, pjit_in_shardings: Sequence[PjitSharding] ) -> Sequence[PjitSharding]: @@ -1612,6 +1625,7 @@ def _resolve_and_lower( in_shardings = _resolve_in_shardings(args, in_shardings) in_layouts = _resolve_in_layouts(args, in_layouts, in_shardings, jaxpr.in_avals) + out_layouts = _resolve_out_layouts(out_layouts, out_shardings, jaxpr.out_avals) return _pjit_lower( jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, resource_env, donated_invars, name, keep_unused, inline, compiler_options_kvs, diff --git a/tests/layout_test.py b/tests/layout_test.py index d98121b53..b9062b8d2 100644 --- a/tests/layout_test.py +++ b/tests/layout_test.py @@ -724,6 +724,26 @@ class LayoutTest(jtu.JaxTestCase): self.assertArraysEqual(out, np_inp @ np_inp.T) self.assertArraysEqual(out2, np_inp @ np_inp.T) + def test_layout_donation_with_default_layout(self): + mesh = jtu.create_mesh((2, 2), ('x', 'y')) + s = NamedSharding(mesh, P('x', 'y')) + shape = (16, 16) + np_inp = np.arange(math.prod(shape)).reshape(shape) + arr = jax.device_put(np_inp, s) + out_layout = Layout(arr.layout.device_local_layout, s) + + @partial(jax.jit, out_shardings=out_layout, donate_argnums=0) + def f(x): + return x * 2 + + lowered_text = f.lower(arr).as_text() + self.assertIn('tf.aliasing_output = 0', lowered_text) + self.assertNotIn('jax.buffer_donor', lowered_text) + + out = f(arr) + self.assertArraysEqual(out, np_inp * 2) + self.assertEqual(out.layout, out_layout) + if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader())