mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 21:06:06 +00:00
Add tests to check if pjit handles deleted array inputs gracefully and consistently
pjit dispatch paths should check deleted array inputs when attempting to use them. These new tests ensure that various pjit dispatch paths detect and handle them gracefully and consistently. Add a check to the PyArray argument handling to make the tests pass. PiperOrigin-RevId: 492605524
This commit is contained in:
parent
693047a14b
commit
02fab525a7
@ -2997,6 +2997,49 @@ class PJitErrorTest(jtu.JaxTestCase):
|
||||
"Changing the physical mesh is not allowed.*"):
|
||||
f(x)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
("committed", True),
|
||||
("uncommitted", False),
|
||||
)
|
||||
@jax_array(True)
|
||||
def test_pjit_with_deleted_input_at_first_call(self, committed):
|
||||
if xla_extension_version < 109:
|
||||
self.skipTest('Does not work for xla_extension_version < 109')
|
||||
shape = (8,)
|
||||
mesh = jtu.create_global_mesh((1,), ('x',))
|
||||
inp_data = np.arange(prod(shape)).reshape(shape)
|
||||
if committed:
|
||||
s = NamedSharding(mesh, P('x',))
|
||||
x = jax.device_put(inp_data, s)
|
||||
else:
|
||||
x = jax.device_put(inp_data)
|
||||
f = pjit(lambda x: x + 1)
|
||||
with self.assertRaisesRegex(RuntimeError, 'Array has been deleted.'):
|
||||
x.delete()
|
||||
_ = f(x)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
("committed", True),
|
||||
("uncommitted", False),
|
||||
)
|
||||
@jax_array(True)
|
||||
def test_pjit_with_deleted_input_at_subsequent_call(self, committed):
|
||||
if xla_extension_version < 109:
|
||||
self.skipTest('Does not work for xla_extension_version < 109')
|
||||
shape = (8,)
|
||||
mesh = jtu.create_global_mesh((1,), ('x',))
|
||||
inp_data = np.arange(prod(shape)).reshape(shape)
|
||||
if committed:
|
||||
s = NamedSharding(mesh, P('x',))
|
||||
x = jax.device_put(inp_data, s)
|
||||
else:
|
||||
x = jax.device_put(inp_data)
|
||||
f = pjit(lambda x: x + 1)
|
||||
_ = f(x)
|
||||
with self.assertRaisesRegex(RuntimeError, 'Array has been deleted.'):
|
||||
x.delete()
|
||||
_ = f(x)
|
||||
|
||||
|
||||
class UtilTest(jtu.JaxTestCase):
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user