Merge pull request #22465 from ROCm:ci_multiprocess_gpu_test

PiperOrigin-RevId: 652602616
This commit is contained in:
jax authors 2024-07-15 14:38:58 -07:00
commit 21e9dabdad

View File

@ -115,7 +115,9 @@ class MultiDeviceTest(jtu.JaxTestCase):
# arguments are first device_put to the specified device. The result
# will be committed on the specified.
# The `device` parameter is experimental, and subject to change.
jit_add_on4 = jax.jit(lambda a, b: a + b, device=devices[4])
with jtu.ignore_warning(category=DeprecationWarning,
message="backend and device argument"):
jit_add_on4 = jax.jit(lambda a, b: a + b, device=devices[4])
self.assert_committed_to_device(jit_add_on4(1, 2), devices[4])
self.assert_committed_to_device(jit_add_on4(1, jax.device_put(2, devices[2])),
devices[4])