Change block_until_ready() to return self rather than nothing.

This commit is contained in:
Peter Hawkins 2019-09-05 10:16:20 -04:00
parent 73d512bdd2
commit 612ffd0687
3 changed files with 7 additions and 2 deletions

View File

@ -297,6 +297,7 @@ class ShardedDeviceValue(xla.DeviceValue):
self._check_if_deleted()
for buf in self.device_buffers:
buf.block_host_until_ready()
return self
class ShardedDeviceArray(ShardedDeviceValue, xla.DeviceArray):

View File

@ -507,9 +507,12 @@ class DeviceValue(object):
This method is mostly useful for timing microbenchmarks that wish to
time how long a computation takes, without transferring the result back
to the host.
Returns the buffer object (`self`).
"""
self._check_if_deleted()
self.device_buffer.block_host_until_ready()
return self
def _forward_method(attrname, self, fun, *args):
return fun(getattr(self, attrname), *args)

View File

@ -699,8 +699,9 @@ class APITest(jtu.JaxTestCase):
def test_devicearray_block_until_ready(self):
x = device_put(1.)
x.block_until_ready()
# Tests only that block_until_ready() does not produce an error.
y = x.block_until_ready()
# Tests mostly that block_until_ready() does not produce an error.
self.assertTrue(y is x)
def test_namedtuple_transparency(self):
# See https://github.com/google/jax/issues/446