Reverts 734ebd570891ceaf8c7104e12256a1edfe942b14

PiperOrigin-RevId: 662942100
This commit is contained in:
jax authors 2024-08-14 09:11:13 -07:00 committed by jax authors
parent 229cbae5ea
commit bab70dda97
2 changed files with 0 additions and 72 deletions

View File

@ -1010,23 +1010,6 @@ def _get_mem_kind(s: JSharding | None) -> str | None:
return s.memory_kind
def _is_default_layout(curr_layout, sharding, aval):
if curr_layout is None or sharding is None:
return True
if isinstance(curr_layout, AutoLayout):
return False
d = sharding._device_assignment[0]
try:
return curr_layout == DeviceLocalLayout.from_pjrt_layout(
d.client.get_default_layout(aval.dtype, aval.shape, d))
except xla_extension.XlaRuntimeError as e:
msg, *_ = e.args
if type(msg) is str and msg.startswith("UNIMPLEMENTED"):
return True
else:
raise
def lower_jaxpr_to_module(
module_name: str,
jaxpr: core.ClosedJaxpr,
@ -1081,13 +1064,6 @@ def lower_jaxpr_to_module(
"In multi-platform lowering either all or no lowering platforms "
f"should support donation. Lowering for {platforms} of which "
f"only {platforms_with_donation} support donation")
if (in_layouts is not None and arg_shardings is not None and
out_layouts is not None and result_shardings is not None
) and not (
all(map(_is_default_layout, in_layouts, arg_shardings, in_avals)) and
all(map(_is_default_layout, out_layouts, result_shardings, out_avals))
):
xla_donated_args = donated_args
if num_partitions > 1 and (
result_shardings is None or all(s is None for s in result_shardings)):
xla_donated_args = donated_args

View File

@ -500,54 +500,6 @@ class LayoutTest(jtu.JaxTestCase):
'Layout passed to jit does not match the layout on the respective arg'):
g(arr)
def test_layout_donation(self):
mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
s = NamedSharding(mesh, P('x', 'y'))
shape = (16, 128)
np_inp = np.arange(math.prod(shape)).reshape(shape)
custom_dll = DLL(major_to_minor=(0, 1))
arr = jax.device_put(np_inp, Layout(custom_dll, s))
@partial(jax.jit, in_shardings=Layout(custom_dll, s), donate_argnums=0)
def f(x):
return x
out = f(arr)
self.assertTrue(arr.is_deleted())
def test_layout_donation_auto(self):
mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
s = NamedSharding(mesh, P('x', 'y'))
shape = (128, 16)
np_inp = np.arange(math.prod(shape)).reshape(shape)
arr = jax.device_put(np_inp, s)
@partial(jax.jit, out_shardings=Layout(DLL.AUTO), donate_argnums=0)
def f(x):
return x * x
out = f(arr)
self.assertTrue(arr.is_deleted())
def test_layout_donation_matching_in_and_out(self):
mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
s = NamedSharding(mesh, P('x', 'y'))
shape = (128, 16)
np_inp = np.arange(math.prod(shape)).reshape(shape)
custom_dll = DLL(major_to_minor=(0, 1))
l = Layout(custom_dll, s)
arr = jax.device_put(np_inp, l)
@partial(jax.jit, in_shardings=l, out_shardings=l, donate_argnums=0)
def f(x):
return x * x
out = f(arr)
self.assertTrue(arr.is_deleted())
if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())