mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Update a regression test to test size-zero device to device transfers. (#2411)
This commit is contained in:
parent
7f0463e2c9
commit
271041b499
@ -308,14 +308,17 @@ class APITest(jtu.JaxTestCase):
|
||||
self.assertIsInstance(y2[1][1], onp.ndarray)
|
||||
assert onp.all(y2[1][1] == 3 * x)
|
||||
|
||||
def test_device_put_across_devices(self):
|
||||
if xb.device_count() == 1:
|
||||
@parameterized.parameters([(3,)], [(2, 0)])
|
||||
def test_device_put_across_devices(self, shape):
|
||||
if len(api.local_devices()) < 2:
|
||||
raise unittest.SkipTest("this test requires multiple devices")
|
||||
d1, d2 = xb.local_devices()[:2]
|
||||
x = api.device_put(onp.array([1,2,3]), device=d1)
|
||||
d1, d2 = api.local_devices()[:2]
|
||||
data = onp.random.randn(*shape).astype(onp.float32)
|
||||
x = api.device_put(data, device=d1)
|
||||
self.assertEqual(x.device_buffer.device(), d1)
|
||||
y = api.device_put(x, device=d2)
|
||||
self.assertEqual(y.device_buffer.device(), d2)
|
||||
onp.testing.assert_array_equal(data, onp.array(y))
|
||||
# Make sure these don't crash
|
||||
api.device_put(x)
|
||||
api.device_put(y)
|
||||
|
Loading…
x
Reference in New Issue
Block a user