mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #22465 from ROCm:ci_multiprocess_gpu_test
PiperOrigin-RevId: 652602616
This commit is contained in:
commit
21e9dabdad
@ -115,7 +115,9 @@ class MultiDeviceTest(jtu.JaxTestCase):
|
||||
# arguments are first device_put to the specified device. The result
|
||||
# will be committed on the specified.
|
||||
# The `device` parameter is experimental, and subject to change.
|
||||
jit_add_on4 = jax.jit(lambda a, b: a + b, device=devices[4])
|
||||
with jtu.ignore_warning(category=DeprecationWarning,
|
||||
message="backend and device argument"):
|
||||
jit_add_on4 = jax.jit(lambda a, b: a + b, device=devices[4])
|
||||
self.assert_committed_to_device(jit_add_on4(1, 2), devices[4])
|
||||
self.assert_committed_to_device(jit_add_on4(1, jax.device_put(2, devices[2])),
|
||||
devices[4])
|
||||
|
Loading…
x
Reference in New Issue
Block a user