Update a regression test to test size-zero device to device transfers. (#2411)

This commit is contained in:
Peter Hawkins 2020-03-13 13:35:18 -04:00 committed by GitHub
parent 7f0463e2c9
commit 271041b499
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -308,14 +308,17 @@ class APITest(jtu.JaxTestCase):
self.assertIsInstance(y2[1][1], onp.ndarray)
assert onp.all(y2[1][1] == 3 * x)
def test_device_put_across_devices(self):
if xb.device_count() == 1:
@parameterized.parameters([(3,)], [(2, 0)])
def test_device_put_across_devices(self, shape):
if len(api.local_devices()) < 2:
raise unittest.SkipTest("this test requires multiple devices")
d1, d2 = xb.local_devices()[:2]
x = api.device_put(onp.array([1,2,3]), device=d1)
d1, d2 = api.local_devices()[:2]
data = onp.random.randn(*shape).astype(onp.float32)
x = api.device_put(data, device=d1)
self.assertEqual(x.device_buffer.device(), d1)
y = api.device_put(x, device=d2)
self.assertEqual(y.device_buffer.device(), d2)
onp.testing.assert_array_equal(data, onp.array(y))
# Make sure these don't crash
api.device_put(x)
api.device_put(y)