diff --git a/CHANGELOG.md b/CHANGELOG.md index aba51ab97..177c353cc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -24,6 +24,9 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK. ## jaxlib 0.3.1 (Unreleased) * Changes +* Deprecations: + * The ``.block_host_until_ready()`` method on JAX arrays has been deprecated. + Use ``.block_until_ready()`` instead. ## jax 0.3.1 (Feb 18, 2022) * [GitHub diff --git a/jax/_src/device_array.py b/jax/_src/device_array.py index f70a6ea2f..dfb765d11 100644 --- a/jax/_src/device_array.py +++ b/jax/_src/device_array.py @@ -138,7 +138,7 @@ class _DeviceArray(DeviceArray): # type: ignore Returns the buffer object (`self`). """ self._check_if_deleted() - self.device_buffer.block_host_until_ready() # pytype: disable=attribute-error + self.device_buffer.block_until_ready() return self @property diff --git a/jax/interpreters/pxla.py b/jax/interpreters/pxla.py index 914f9d7c9..fd9a25440 100644 --- a/jax/interpreters/pxla.py +++ b/jax/interpreters/pxla.py @@ -653,7 +653,7 @@ def _sda_check_if_deleted(self): def _sda_block_until_ready(self): self._check_if_deleted() for buf in self.device_buffers: - buf.block_host_until_ready() + buf.block_until_ready() return self