From bab70dda97cf3ecab68a915ad9f9261f25a02efc Mon Sep 17 00:00:00 2001 From: jax authors Date: Wed, 14 Aug 2024 09:11:13 -0700 Subject: [PATCH] Reverts 734ebd570891ceaf8c7104e12256a1edfe942b14 PiperOrigin-RevId: 662942100 --- jax/_src/interpreters/mlir.py | 24 ------------------ tests/layout_test.py | 48 ----------------------------------- 2 files changed, 72 deletions(-) diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index 5cfb6a0e0..814c6a988 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -1010,23 +1010,6 @@ def _get_mem_kind(s: JSharding | None) -> str | None: return s.memory_kind -def _is_default_layout(curr_layout, sharding, aval): - if curr_layout is None or sharding is None: - return True - if isinstance(curr_layout, AutoLayout): - return False - d = sharding._device_assignment[0] - try: - return curr_layout == DeviceLocalLayout.from_pjrt_layout( - d.client.get_default_layout(aval.dtype, aval.shape, d)) - except xla_extension.XlaRuntimeError as e: - msg, *_ = e.args - if type(msg) is str and msg.startswith("UNIMPLEMENTED"): - return True - else: - raise - - def lower_jaxpr_to_module( module_name: str, jaxpr: core.ClosedJaxpr, @@ -1081,13 +1064,6 @@ def lower_jaxpr_to_module( "In multi-platform lowering either all or no lowering platforms " f"should support donation. 
Lowering for {platforms} of which " f"only {platforms_with_donation} support donation") - if (in_layouts is not None and arg_shardings is not None and - out_layouts is not None and result_shardings is not None - ) and not ( - all(map(_is_default_layout, in_layouts, arg_shardings, in_avals)) and - all(map(_is_default_layout, out_layouts, result_shardings, out_avals)) - ): - xla_donated_args = donated_args if num_partitions > 1 and ( result_shardings is None or all(s is None for s in result_shardings)): xla_donated_args = donated_args diff --git a/tests/layout_test.py b/tests/layout_test.py index 2ddd72764..c72082d0a 100644 --- a/tests/layout_test.py +++ b/tests/layout_test.py @@ -500,54 +500,6 @@ class LayoutTest(jtu.JaxTestCase): 'Layout passed to jit does not match the layout on the respective arg'): g(arr) - def test_layout_donation(self): - mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) - s = NamedSharding(mesh, P('x', 'y')) - shape = (16, 128) - np_inp = np.arange(math.prod(shape)).reshape(shape) - - custom_dll = DLL(major_to_minor=(0, 1)) - arr = jax.device_put(np_inp, Layout(custom_dll, s)) - - @partial(jax.jit, in_shardings=Layout(custom_dll, s), donate_argnums=0) - def f(x): - return x - - out = f(arr) - self.assertTrue(arr.is_deleted()) - - def test_layout_donation_auto(self): - mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) - s = NamedSharding(mesh, P('x', 'y')) - shape = (128, 16) - np_inp = np.arange(math.prod(shape)).reshape(shape) - - arr = jax.device_put(np_inp, s) - - @partial(jax.jit, out_shardings=Layout(DLL.AUTO), donate_argnums=0) - def f(x): - return x * x - - out = f(arr) - self.assertTrue(arr.is_deleted()) - - def test_layout_donation_matching_in_and_out(self): - mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) - s = NamedSharding(mesh, P('x', 'y')) - shape = (128, 16) - np_inp = np.arange(math.prod(shape)).reshape(shape) - - custom_dll = DLL(major_to_minor=(0, 1)) - l = Layout(custom_dll, s) - arr = jax.device_put(np_inp, l) - 
@partial(jax.jit, in_shardings=l, out_shardings=l, donate_argnums=0) - def f(x): - return x * x - - out = f(arr) - self.assertTrue(arr.is_deleted()) - if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader())