mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Change block_until_ready() to return self
rather than nothing.
This commit is contained in:
parent
73d512bdd2
commit
612ffd0687
@ -297,6 +297,7 @@ class ShardedDeviceValue(xla.DeviceValue):
|
||||
self._check_if_deleted()
|
||||
for buf in self.device_buffers:
|
||||
buf.block_host_until_ready()
|
||||
return self
|
||||
|
||||
|
||||
class ShardedDeviceArray(ShardedDeviceValue, xla.DeviceArray):
|
||||
|
@ -507,9 +507,12 @@ class DeviceValue(object):
|
||||
This method is mostly useful for timing microbenchmarks that wish to
|
||||
time how long a computation takes, without transferring the result back
|
||||
to the host.
|
||||
|
||||
Returns the buffer object (`self`).
|
||||
"""
|
||||
self._check_if_deleted()
|
||||
self.device_buffer.block_host_until_ready()
|
||||
return self
|
||||
|
||||
def _forward_method(attrname, self, fun, *args):
|
||||
return fun(getattr(self, attrname), *args)
|
||||
|
@ -699,8 +699,9 @@ class APITest(jtu.JaxTestCase):
|
||||
|
||||
def test_devicearray_block_until_ready(self):
|
||||
x = device_put(1.)
|
||||
x.block_until_ready()
|
||||
# Tests only that block_until_ready() does not produce an error.
|
||||
y = x.block_until_ready()
|
||||
# Tests mostly that block_until_ready() does not produce an error.
|
||||
self.assertTrue(y is x)
|
||||
|
||||
def test_namedtuple_transparency(self):
|
||||
# See https://github.com/google/jax/issues/446
|
||||
|
Loading…
x
Reference in New Issue
Block a user