Add tests to check if pjit handles deleted array inputs gracefully and consistently

pjit dispatch paths should check deleted array inputs when attempting to use
them. These new tests ensure that various pjit dispatch paths detect and handle
them gracefully and consistently.

Add a check to the PyArray argument handling to make the tests pass.

PiperOrigin-RevId: 492605524
This commit is contained in:
Hyeontaek Lim 2022-12-02 18:40:59 -08:00 committed by jax authors
parent 693047a14b
commit 02fab525a7

View File

@ -2997,6 +2997,49 @@ class PJitErrorTest(jtu.JaxTestCase):
"Changing the physical mesh is not allowed.*"):
f(x)
@parameterized.named_parameters(
("committed", True),
("uncommitted", False),
)
@jax_array(True)
def test_pjit_with_deleted_input_at_first_call(self, committed):
if xla_extension_version < 109:
self.skipTest('Does not work for xla_extension_version < 109')
shape = (8,)
mesh = jtu.create_global_mesh((1,), ('x',))
inp_data = np.arange(prod(shape)).reshape(shape)
if committed:
s = NamedSharding(mesh, P('x',))
x = jax.device_put(inp_data, s)
else:
x = jax.device_put(inp_data)
f = pjit(lambda x: x + 1)
with self.assertRaisesRegex(RuntimeError, 'Array has been deleted.'):
x.delete()
_ = f(x)
@parameterized.named_parameters(
("committed", True),
("uncommitted", False),
)
@jax_array(True)
def test_pjit_with_deleted_input_at_subsequent_call(self, committed):
if xla_extension_version < 109:
self.skipTest('Does not work for xla_extension_version < 109')
shape = (8,)
mesh = jtu.create_global_mesh((1,), ('x',))
inp_data = np.arange(prod(shape)).reshape(shape)
if committed:
s = NamedSharding(mesh, P('x',))
x = jax.device_put(inp_data, s)
else:
x = jax.device_put(inp_data)
f = pjit(lambda x: x + 1)
_ = f(x)
with self.assertRaisesRegex(RuntimeError, 'Array has been deleted.'):
x.delete()
_ = f(x)
class UtilTest(jtu.JaxTestCase):