mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
[JAX] Deprecate .block_host_until_ready() in favor of .block_until_ready().
JAX kept an older name around (.block_host_until_ready()) in parallel with the new name (.block_until_ready()) to avoid breaking users. Deprecate it so we only have one name. PiperOrigin-RevId: 433228545
This commit is contained in:
parent
809156c1fc
commit
08fbd77d90
@ -24,6 +24,9 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
|
||||
|
||||
## jaxlib 0.3.1 (Unreleased)
|
||||
* Changes
|
||||
* Deprecations:
|
||||
* The ``.block_host_until_ready()`` method on JAX arrays has been deprecated.
|
||||
Use ``.block_until_ready()`` instead.
|
||||
|
||||
## jax 0.3.1 (Feb 18, 2022)
|
||||
* [GitHub
|
||||
|
@ -138,7 +138,7 @@ class _DeviceArray(DeviceArray): # type: ignore
|
||||
Returns the buffer object (`self`).
|
||||
"""
|
||||
self._check_if_deleted()
|
||||
self.device_buffer.block_host_until_ready() # pytype: disable=attribute-error
|
||||
self.device_buffer.block_until_ready()
|
||||
return self
|
||||
|
||||
@property
|
||||
|
@ -653,7 +653,7 @@ def _sda_check_if_deleted(self):
|
||||
def _sda_block_until_ready(self):
|
||||
self._check_if_deleted()
|
||||
for buf in self.device_buffers:
|
||||
buf.block_host_until_ready()
|
||||
buf.block_until_ready()
|
||||
return self
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user