Check if the input which is donated is actually deleted along with the AOT check.

PiperOrigin-RevId: 565098239
This commit is contained in:
Yash Katariya 2023-09-13 10:49:33 -07:00 committed by jax authors
parent 6abefa1977
commit 8340149336

View File

@ -56,7 +56,7 @@ def _create_inputs(shape, pspec, mem_kind=None):
# * nested jit
class MemoriesTest(jtu.JaxTestCase):
class MemoriesTest(jtu.BufferDonationTestCase):
def setUp(self):
if jtu.device_under_test() in ("cpu", "gpu"):
@ -850,6 +850,7 @@ class MemoriesTest(jtu.JaxTestCase):
lowered_text = f.lower(inp).as_text("hlo")
self.assertNotIn("input_output_alias", lowered_text)
self.assertNotDeleted(inp)
@parameterized.named_parameters(
("hbm_to_host", "tpu_hbm", "unpinned_host"),
@ -960,7 +961,7 @@ class MemoriesTest(jtu.JaxTestCase):
" inside jax.jit"):
jax.device_put(np.arange(16), TransferToMemoryKind("tpu_hbm"))
def test_single_mem_kind_donation(self):
def test_single_mem_kind_donation_default_mem_kind(self):
mesh = jtu.create_global_mesh((2,), "x")
@functools.partial(jax.jit, donate_argnums=0)
@ -973,6 +974,25 @@ class MemoriesTest(jtu.JaxTestCase):
lowered_text = f.lower(x).as_text("hlo")
self.assertIn("input_output_alias", lowered_text)
self.assertDeleted(x)
def test_single_mem_kind_donation_host(self):
mesh = jtu.create_global_mesh((2,), "x")
@functools.partial(jax.jit, donate_argnums=0)
def f(inp1):
return inp1 * 2
s_host = NamedSharding(mesh, P(), memory_kind="unpinned_host")
x = jax.device_put(np.arange(16).reshape(8, 2), s_host)
f(x)
lowered_text = f.lower(x).as_text("hlo")
self.assertIn("input_output_alias", lowered_text)
# TODO(yashkatariya): Donation does not work on host memory yet. Uncomment
# this after it is fixed.
# self.assertDeleted(x)
def test_remat_jaxpr_offloadable(self):
def policy(prim, *avals, **params):