mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
[host_callback] Deprecate the jax.experimental.host_callback module.
This commit is contained in:
parent
d4948d8f13
commit
ca59971bef
@ -20,6 +20,8 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
* The `jax.experimental.maps` module and `jax.experimental.maps.xmap` are
|
||||
deprecated. Use `jax.experimental.shard_map` or `jax.vmap` with the
|
||||
`spmd_axis_name` argument for expressing SPMD device-parallel computations.
|
||||
* The `jax.experimental.host_callback` module is deprecated.
|
||||
Use instead the [new JAX external callbacks](https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html).
|
||||
* Passing arguments to {func}`jax.numpy.array_equal` and {func}`jax.numpy.array_equiv`
|
||||
that cannot be converted to a JAX array now results in an exception.
|
||||
* The deprecated flag `jax_parallel_functions_output_gda` has been removed.
|
||||
|
@ -13,7 +13,10 @@
|
||||
# limitations under the License.
|
||||
"""Primitives for calling Python functions on the host from JAX accelerator code.
|
||||
|
||||
**Experimental: please give feedback, and expect changes.**
|
||||
.. warning::
|
||||
The host_callback APIs are deprecated as of March 20, 2024.
|
||||
The functionality is subsumed by the
|
||||
`new JAX external callbacks <https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html>`_
|
||||
|
||||
This module introduces the host callback functions :func:`call`,
|
||||
:func:`id_tap`, and :func:`id_print`, that send their arguments from the device
|
||||
@ -505,7 +508,7 @@ import logging
|
||||
import math
|
||||
import threading
|
||||
import traceback
|
||||
from typing import Any, Callable, Optional, cast
|
||||
from typing import Any, Callable, cast
|
||||
|
||||
from jax._src import api
|
||||
from jax._src import core
|
||||
@ -589,7 +592,7 @@ XlaLocalClient = xla_client.Client
|
||||
DType = Any
|
||||
|
||||
|
||||
def id_tap(tap_func,
|
||||
def _deprecated_id_tap(tap_func,
|
||||
arg,
|
||||
*,
|
||||
result=None,
|
||||
@ -598,7 +601,10 @@ def id_tap(tap_func,
|
||||
**kwargs):
|
||||
"""Host-callback tap primitive, like identity function with a call to ``tap_func``.
|
||||
|
||||
**Experimental: please give feedback, and expect changes!**
|
||||
.. warning::
|
||||
The host_callback APIs are deprecated as of March 20, 2024.
|
||||
The functionality is subsumed by the
|
||||
`new JAX external callbacks <https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html>`_
|
||||
|
||||
``id_tap`` behaves semantically like the identity function but has the
|
||||
side-effect that a user-defined Python function is called with the runtime
|
||||
@ -662,7 +668,7 @@ def id_tap(tap_func,
|
||||
return call_res
|
||||
|
||||
|
||||
def id_print(arg,
|
||||
def _deprecated_id_print(arg,
|
||||
*,
|
||||
result=None,
|
||||
tap_with_device=False,
|
||||
@ -672,7 +678,10 @@ def id_print(arg,
|
||||
**kwargs):
|
||||
"""Like :func:`id_tap` with a printing tap function.
|
||||
|
||||
**Experimental: please give feedback, and expect changes!**
|
||||
.. warning::
|
||||
The host_callback APIs are deprecated as of March 20, 2024.
|
||||
The functionality is subsumed by the
|
||||
`new JAX external callbacks <https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html>`_
|
||||
|
||||
On each invocation of the printing tap, the ``kwargs`` if present
|
||||
will be printed first (sorted by keys). Then arg will be printed,
|
||||
@ -694,7 +703,7 @@ def id_print(arg,
|
||||
printer = functools.partial(_print_tap_func,
|
||||
output_stream=output_stream,
|
||||
threshold=threshold, **kwargs)
|
||||
return id_tap(
|
||||
return _deprecated_id_tap(
|
||||
printer,
|
||||
arg,
|
||||
result=result,
|
||||
@ -702,13 +711,16 @@ def id_print(arg,
|
||||
device_index=device_index)
|
||||
|
||||
|
||||
def call(callback_func: Callable, arg, *,
|
||||
def _deprecated_call(callback_func: Callable, arg, *,
|
||||
result_shape=None,
|
||||
call_with_device=False,
|
||||
device_index=0):
|
||||
"""Make a call to the host, and expect a result.
|
||||
|
||||
**Experimental: please give feedback, and expect changes!**
|
||||
.. warning::
|
||||
The host_callback APIs are deprecated as of March 20, 2024.
|
||||
The functionality is subsumed by the
|
||||
`new JAX external callbacks <https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html>`_
|
||||
|
||||
Args:
|
||||
callback_func: The Python function to invoke on the host as
|
||||
@ -1858,7 +1870,7 @@ def barrier_wait(logging_name: str | None = None):
|
||||
for d_idx, d in enumerate(_callback_handler_data.devices):
|
||||
logger.debug("barrier_wait[%s]: enqueueing barrier on device %s", logging_name, d)
|
||||
x_on_dev = api.device_put(d_idx, device=d)
|
||||
api.jit(lambda x: id_tap(barrier_tap_received, x), device=d)(x_on_dev)
|
||||
api.jit(lambda x: _deprecated_id_tap(barrier_tap_received, x), device=d)(x_on_dev)
|
||||
|
||||
logger.debug("barrier_wait[%s]: waiting for callbacks", logging_name)
|
||||
|
||||
@ -1875,9 +1887,14 @@ def barrier_wait(logging_name: str | None = None):
|
||||
f"Last one was: {formatted_last_exception}") from last_exception
|
||||
|
||||
|
||||
def stop_outfeed_receiver():
|
||||
def _deprecated_stop_outfeed_receiver():
|
||||
"""Stops the outfeed receiver runtime.
|
||||
|
||||
.. warning::
|
||||
The host_callback APIs are deprecated as of March 20, 2024.
|
||||
The functionality is subsumed by the
|
||||
`new JAX external callbacks <https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html>`_
|
||||
|
||||
This waits for all outfeeds from computations already running on all devices,
|
||||
and then stops the outfeed receiver runtime. The runtime will be restarted
|
||||
next time you use a tap function.
|
||||
@ -1886,3 +1903,28 @@ def stop_outfeed_receiver():
|
||||
using lax.outfeed directly after having used host callbacks.
|
||||
"""
|
||||
_callback_handler_data.stop()
|
||||
|
||||
_deprecation_msg = (
|
||||
"The host_callback APIs are deprecated as of March 20, 2024. The functionality "
|
||||
"is subsumed by the new JAX external callbacks. "
|
||||
"See https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html.")
|
||||
|
||||
_deprecations = {
|
||||
# Added March 20, 2024
|
||||
"id_tap": (_deprecation_msg, _deprecated_id_tap),
|
||||
"id_print": (_deprecation_msg, _deprecated_id_print),
|
||||
"call": (_deprecation_msg, _deprecated_call),
|
||||
"stop_outfeed_receiver": (_deprecation_msg, _deprecated_stop_outfeed_receiver),
|
||||
}
|
||||
|
||||
import typing
|
||||
if typing.TYPE_CHECKING:
|
||||
id_tap = _deprecated_id_tap
|
||||
id_print = _deprecated_id_print
|
||||
call = _deprecated_call
|
||||
stop_outfeed_receiver = _deprecated_stop_outfeed_receiver
|
||||
else:
|
||||
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
|
||||
__getattr__ = _deprecation_getattr(__name__, _deprecations)
|
||||
del _deprecation_getattr
|
||||
del typing
|
||||
|
@ -85,6 +85,7 @@ filterwarnings = [
|
||||
# end array_api_tests-related warnings
|
||||
"ignore:jax.extend.mlir.dialects.mhlo is deprecated.*:DeprecationWarning",
|
||||
"ignore:jax.experimental.maps and .* are deprecated.*:DeprecationWarning",
|
||||
"ignore:The host_callback APIs are deprecated .*:DeprecationWarning",
|
||||
]
|
||||
doctest_optionflags = [
|
||||
"NUMBER",
|
||||
|
Loading…
x
Reference in New Issue
Block a user