mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
Make sure default layout is None for input and output layout in all codepaths
PiperOrigin-RevId: 731865511
This commit is contained in:
parent
c7ca35fe32
commit
dda62f576f
@ -1508,9 +1508,22 @@ def _resolve_in_layouts(args, jit_in_layouts, resolved_in_shardings, in_avals):
|
||||
f'arg layout: {arg_layout} for '
|
||||
f'arg shape: {core.shaped_abstractify(arg).str_short()}.'
|
||||
f'{extra_msg}')
|
||||
jit_in_l = (None if isinstance(jit_in_l, DeviceLocalLayout) and
|
||||
pxla.is_default_layout(jit_in_l, rs, aval) else jit_in_l)
|
||||
resolved_in_layouts.append(jit_in_l)
|
||||
return tuple(resolved_in_layouts)
|
||||
|
||||
def _resolve_out_layouts(out_layouts, out_shardings, out_avals):
|
||||
new_out_layouts = []
|
||||
for out_l, out_s, out_aval in safe_zip(out_layouts, out_shardings, out_avals):
|
||||
if out_l is None:
|
||||
new_out_layouts.append(None)
|
||||
elif (isinstance(out_l, DeviceLocalLayout) and
|
||||
pxla.is_default_layout(out_l, out_s, out_aval)):
|
||||
new_out_layouts.append(None)
|
||||
else:
|
||||
new_out_layouts.append(out_l)
|
||||
return tuple(new_out_layouts)
|
||||
|
||||
def _resolve_in_shardings(args, pjit_in_shardings: Sequence[PjitSharding]
|
||||
) -> Sequence[PjitSharding]:
|
||||
@ -1612,6 +1625,7 @@ def _resolve_and_lower(
|
||||
in_shardings = _resolve_in_shardings(args, in_shardings)
|
||||
in_layouts = _resolve_in_layouts(args, in_layouts, in_shardings,
|
||||
jaxpr.in_avals)
|
||||
out_layouts = _resolve_out_layouts(out_layouts, out_shardings, jaxpr.out_avals)
|
||||
return _pjit_lower(
|
||||
jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, resource_env,
|
||||
donated_invars, name, keep_unused, inline, compiler_options_kvs,
|
||||
|
@ -724,6 +724,26 @@ class LayoutTest(jtu.JaxTestCase):
|
||||
self.assertArraysEqual(out, np_inp @ np_inp.T)
|
||||
self.assertArraysEqual(out2, np_inp @ np_inp.T)
|
||||
|
||||
def test_layout_donation_with_default_layout(self):
|
||||
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
|
||||
s = NamedSharding(mesh, P('x', 'y'))
|
||||
shape = (16, 16)
|
||||
np_inp = np.arange(math.prod(shape)).reshape(shape)
|
||||
arr = jax.device_put(np_inp, s)
|
||||
out_layout = Layout(arr.layout.device_local_layout, s)
|
||||
|
||||
@partial(jax.jit, out_shardings=out_layout, donate_argnums=0)
|
||||
def f(x):
|
||||
return x * 2
|
||||
|
||||
lowered_text = f.lower(arr).as_text()
|
||||
self.assertIn('tf.aliasing_output = 0', lowered_text)
|
||||
self.assertNotIn('jax.buffer_donor', lowered_text)
|
||||
|
||||
out = f(arr)
|
||||
self.assertArraysEqual(out, np_inp * 2)
|
||||
self.assertEqual(out.layout, out_layout)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user