Mirror of https://github.com/ROCm/jax.git, synced 2025-04-17 20:36:05 +00:00
Reverts 734ebd570891ceaf8c7104e12256a1edfe942b14
PiperOrigin-RevId: 662942100
This commit is contained in:
parent 229cbae5ea
commit bab70dda97
@@ -1010,23 +1010,6 @@ def _get_mem_kind(s: JSharding | None) -> str | None:
   return s.memory_kind
 
 
-def _is_default_layout(curr_layout, sharding, aval):
-  if curr_layout is None or sharding is None:
-    return True
-  if isinstance(curr_layout, AutoLayout):
-    return False
-  d = sharding._device_assignment[0]
-  try:
-    return curr_layout == DeviceLocalLayout.from_pjrt_layout(
-        d.client.get_default_layout(aval.dtype, aval.shape, d))
-  except xla_extension.XlaRuntimeError as e:
-    msg, *_ = e.args
-    if type(msg) is str and msg.startswith("UNIMPLEMENTED"):
-      return True
-    else:
-      raise
-
-
 def lower_jaxpr_to_module(
     module_name: str,
     jaxpr: core.ClosedJaxpr,
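Note: the removed `_is_default_layout` helper asked the backend for its default layout and compared it against the layout requested by the user, treating an UNIMPLEMENTED error as "the layout is default". Below is a minimal standalone sketch of that query, not part of this commit, assuming a local backend that implements `get_default_layout` (others raise an UNIMPLEMENTED `XlaRuntimeError`):

# Sketch only: query a device's default layout the way _is_default_layout did.
# Assumes the backend implements get_default_layout; if it does not, an
# UNIMPLEMENTED XlaRuntimeError is raised, which the removed helper treated
# as "the current layout is the default".
import jax
import jax.numpy as jnp
from jax._src.layout import DeviceLocalLayout

x = jnp.zeros((16, 128), dtype=jnp.float32)
d = jax.devices()[0]
default = DeviceLocalLayout.from_pjrt_layout(
    d.client.get_default_layout(x.dtype, x.shape, d))
print(default)  # the backend's default device-local layout for this shape/dtype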
@@ -1081,13 +1064,6 @@ def lower_jaxpr_to_module(
         "In multi-platform lowering either all or no lowering platforms "
         f"should support donation. Lowering for {platforms} of which "
         f"only {platforms_with_donation} support donation")
-  if (in_layouts is not None and arg_shardings is not None and
-      out_layouts is not None and result_shardings is not None
-  ) and not (
-      all(map(_is_default_layout, in_layouts, arg_shardings, in_avals)) and
-      all(map(_is_default_layout, out_layouts, result_shardings, out_avals))
-  ):
-    xla_donated_args = donated_args
   if num_partitions > 1 and (
       result_shardings is None or all(s is None for s in result_shardings)):
     xla_donated_args = donated_args
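Note: the deleted block forced `xla_donated_args = donated_args`, i.e. XLA-level donation instead of runtime input/output aliasing, whenever any input or output carried a non-default layout; after the revert that fallback again applies only in the multi-partition case kept above. For orientation, a small sketch of how donation is observed from user code, not from this commit, on a platform that supports buffer donation:

# Sketch only: buffer donation as seen from user code. donate_argnums marks
# the argument as donatable; where donation is supported the input buffer is
# reused for the output and the source array is deleted.
import jax
import jax.numpy as jnp

f = jax.jit(lambda x: x + 1, donate_argnums=0)
x = jnp.zeros((8, 8))
y = f(x)
print(x.is_deleted())  # True on platforms that support donation (e.g. GPU/TPU)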
@@ -500,54 +500,6 @@ class LayoutTest(jtu.JaxTestCase):
         'Layout passed to jit does not match the layout on the respective arg'):
       g(arr)
 
-  def test_layout_donation(self):
-    mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
-    s = NamedSharding(mesh, P('x', 'y'))
-    shape = (16, 128)
-    np_inp = np.arange(math.prod(shape)).reshape(shape)
-
-    custom_dll = DLL(major_to_minor=(0, 1))
-    arr = jax.device_put(np_inp, Layout(custom_dll, s))
-
-    @partial(jax.jit, in_shardings=Layout(custom_dll, s), donate_argnums=0)
-    def f(x):
-      return x
-
-    out = f(arr)
-    self.assertTrue(arr.is_deleted())
-
-  def test_layout_donation_auto(self):
-    mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
-    s = NamedSharding(mesh, P('x', 'y'))
-    shape = (128, 16)
-    np_inp = np.arange(math.prod(shape)).reshape(shape)
-
-    arr = jax.device_put(np_inp, s)
-
-    @partial(jax.jit, out_shardings=Layout(DLL.AUTO), donate_argnums=0)
-    def f(x):
-      return x * x
-
-    out = f(arr)
-    self.assertTrue(arr.is_deleted())
-
-  def test_layout_donation_matching_in_and_out(self):
-    mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
-    s = NamedSharding(mesh, P('x', 'y'))
-    shape = (128, 16)
-    np_inp = np.arange(math.prod(shape)).reshape(shape)
-
-    custom_dll = DLL(major_to_minor=(0, 1))
-    l = Layout(custom_dll, s)
-    arr = jax.device_put(np_inp, l)
-
-    @partial(jax.jit, in_shardings=l, out_shardings=l, donate_argnums=0)
-    def f(x):
-      return x * x
-
-    out = f(arr)
-    self.assertTrue(arr.is_deleted())
-
-
 if __name__ == '__main__':
   absltest.main(testLoader=jtu.JaxTestLoader())
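Note: the deleted tests exercised donation with explicit and AUTO layouts. For reference, a sketch of the `device_put`-with-explicit-layout API those tests used, not from this commit; the public import path `jax.experimental.layout` is an assumption for this era of JAX (the tests themselves use internal aliases `Layout` and `DLL`):

# Sketch only: place an array with an explicit device-local layout.
# Requires at least 4 devices for the 2x2 mesh; import path assumed.
import math
import numpy as np
import jax
from jax.sharding import NamedSharding, PartitionSpec as P
from jax.experimental.layout import Layout, DeviceLocalLayout as DLL

mesh = jax.sharding.Mesh(
    np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y'))
s = NamedSharding(mesh, P('x', 'y'))
np_inp = np.arange(math.prod((16, 128))).reshape(16, 128)

custom_dll = DLL(major_to_minor=(0, 1))  # explicit major-to-minor order
arr = jax.device_put(np_inp, Layout(custom_dll, s))
print(arr.sharding)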