mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Check if the input which is donated is actually deleted along with the AOT check.
PiperOrigin-RevId: 565098239
This commit is contained in:
parent
6abefa1977
commit
8340149336
@ -56,7 +56,7 @@ def _create_inputs(shape, pspec, mem_kind=None):
|
||||
# * nested jit
|
||||
|
||||
|
||||
class MemoriesTest(jtu.JaxTestCase):
|
||||
class MemoriesTest(jtu.BufferDonationTestCase):
|
||||
|
||||
def setUp(self):
|
||||
if jtu.device_under_test() in ("cpu", "gpu"):
|
||||
@ -850,6 +850,7 @@ class MemoriesTest(jtu.JaxTestCase):
|
||||
|
||||
lowered_text = f.lower(inp).as_text("hlo")
|
||||
self.assertNotIn("input_output_alias", lowered_text)
|
||||
self.assertNotDeleted(inp)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
("hbm_to_host", "tpu_hbm", "unpinned_host"),
|
||||
@ -960,7 +961,7 @@ class MemoriesTest(jtu.JaxTestCase):
|
||||
" inside jax.jit"):
|
||||
jax.device_put(np.arange(16), TransferToMemoryKind("tpu_hbm"))
|
||||
|
||||
def test_single_mem_kind_donation(self):
|
||||
def test_single_mem_kind_donation_default_mem_kind(self):
|
||||
mesh = jtu.create_global_mesh((2,), "x")
|
||||
|
||||
@functools.partial(jax.jit, donate_argnums=0)
|
||||
@ -973,6 +974,25 @@ class MemoriesTest(jtu.JaxTestCase):
|
||||
|
||||
lowered_text = f.lower(x).as_text("hlo")
|
||||
self.assertIn("input_output_alias", lowered_text)
|
||||
self.assertDeleted(x)
|
||||
|
||||
def test_single_mem_kind_donation_host(self):
|
||||
mesh = jtu.create_global_mesh((2,), "x")
|
||||
|
||||
@functools.partial(jax.jit, donate_argnums=0)
|
||||
def f(inp1):
|
||||
return inp1 * 2
|
||||
|
||||
s_host = NamedSharding(mesh, P(), memory_kind="unpinned_host")
|
||||
x = jax.device_put(np.arange(16).reshape(8, 2), s_host)
|
||||
|
||||
f(x)
|
||||
|
||||
lowered_text = f.lower(x).as_text("hlo")
|
||||
self.assertIn("input_output_alias", lowered_text)
|
||||
# TODO(yashkatariya): Donation does not work on host memory yet. Uncomment
|
||||
# this after it is fixed.
|
||||
# self.assertDeleted(x)
|
||||
|
||||
def test_remat_jaxpr_offloadable(self):
|
||||
def policy(prim, *avals, **params):
|
||||
|
Loading…
x
Reference in New Issue
Block a user