[host_callback] Deprecate the jax.experimental.host_callback module.

This commit is contained in:
George Necula 2024-03-21 08:18:57 +02:00
parent d4948d8f13
commit ca59971bef
3 changed files with 56 additions and 11 deletions

View File

@ -20,6 +20,8 @@ Remember to align the itemized text with the first line of an item within a list
* The `jax.experimental.maps` module and `jax.experimental.maps.xmap` are
deprecated. Use `jax.experimental.shard_map` or `jax.vmap` with the
`spmd_axis_name` argument for expressing SPMD device-parallel computations.
* The `jax.experimental.host_callback` module is deprecated.
Use instead the [new JAX external callbacks](https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html).
* Passing arguments to {func}`jax.numpy.array_equal` and {func}`jax.numpy.array_equiv`
that cannot be converted to a JAX array now results in an exception.
* The deprecated flag `jax_parallel_functions_output_gda` has been removed.

View File

@ -13,7 +13,10 @@
# limitations under the License.
"""Primitives for calling Python functions on the host from JAX accelerator code.
**Experimental: please give feedback, and expect changes.**
.. warning::
The host_callback APIs are deprecated as of March 20, 2024.
The functionality is subsumed by the
`new JAX external callbacks <https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html>`_
This module introduces the host callback functions :func:`call`,
:func:`id_tap`, and :func:`id_print`, that send their arguments from the device
@ -505,7 +508,7 @@ import logging
import math
import threading
import traceback
from typing import Any, Callable, Optional, cast
from typing import Any, Callable, cast
from jax._src import api
from jax._src import core
@ -589,7 +592,7 @@ XlaLocalClient = xla_client.Client
DType = Any
def id_tap(tap_func,
def _deprecated_id_tap(tap_func,
arg,
*,
result=None,
@ -598,7 +601,10 @@ def id_tap(tap_func,
**kwargs):
"""Host-callback tap primitive, like identity function with a call to ``tap_func``.
**Experimental: please give feedback, and expect changes!**
.. warning::
The host_callback APIs are deprecated as of March 20, 2024.
The functionality is subsumed by the
`new JAX external callbacks <https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html>`_
``id_tap`` behaves semantically like the identity function but has the
side-effect that a user-defined Python function is called with the runtime
@ -662,7 +668,7 @@ def id_tap(tap_func,
return call_res
def id_print(arg,
def _deprecated_id_print(arg,
*,
result=None,
tap_with_device=False,
@ -672,7 +678,10 @@ def id_print(arg,
**kwargs):
"""Like :func:`id_tap` with a printing tap function.
**Experimental: please give feedback, and expect changes!**
.. warning::
The host_callback APIs are deprecated as of March 20, 2024.
The functionality is subsumed by the
`new JAX external callbacks <https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html>`_
On each invocation of the printing tap, the ``kwargs`` if present
will be printed first (sorted by keys). Then arg will be printed,
@ -694,7 +703,7 @@ def id_print(arg,
printer = functools.partial(_print_tap_func,
output_stream=output_stream,
threshold=threshold, **kwargs)
return id_tap(
return _deprecated_id_tap(
printer,
arg,
result=result,
@ -702,13 +711,16 @@ def id_print(arg,
device_index=device_index)
def call(callback_func: Callable, arg, *,
def _deprecated_call(callback_func: Callable, arg, *,
result_shape=None,
call_with_device=False,
device_index=0):
"""Make a call to the host, and expect a result.
**Experimental: please give feedback, and expect changes!**
.. warning::
The host_callback APIs are deprecated as of March 20, 2024.
The functionality is subsumed by the
`new JAX external callbacks <https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html>`_
Args:
callback_func: The Python function to invoke on the host as
@ -1858,7 +1870,7 @@ def barrier_wait(logging_name: str | None = None):
for d_idx, d in enumerate(_callback_handler_data.devices):
logger.debug("barrier_wait[%s]: enqueueing barrier on device %s", logging_name, d)
x_on_dev = api.device_put(d_idx, device=d)
api.jit(lambda x: id_tap(barrier_tap_received, x), device=d)(x_on_dev)
api.jit(lambda x: _deprecated_id_tap(barrier_tap_received, x), device=d)(x_on_dev)
logger.debug("barrier_wait[%s]: waiting for callbacks", logging_name)
@ -1875,9 +1887,14 @@ def barrier_wait(logging_name: str | None = None):
f"Last one was: {formatted_last_exception}") from last_exception
def stop_outfeed_receiver():
def _deprecated_stop_outfeed_receiver():
"""Stops the outfeed receiver runtime.
.. warning::
The host_callback APIs are deprecated as of March 20, 2024.
The functionality is subsumed by the
`new JAX external callbacks <https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html>`_
This waits for all outfeeds from computations already running on all devices,
and then stops the outfeed receiver runtime. The runtime will be restarted
next time you use a tap function.
@ -1886,3 +1903,28 @@ def stop_outfeed_receiver():
using lax.outfeed directly after having used host callbacks.
"""
_callback_handler_data.stop()
_deprecation_msg = (
"The host_callback APIs are deprecated as of March 20, 2024. The functionality "
"is subsumed by the new JAX external callbacks. "
"See https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html.")
_deprecations = {
# Added March 20, 2024
"id_tap": (_deprecation_msg, _deprecated_id_tap),
"id_print": (_deprecation_msg, _deprecated_id_print),
"call": (_deprecation_msg, _deprecated_call),
"stop_outfeed_receiver": (_deprecation_msg, _deprecated_stop_outfeed_receiver),
}
import typing
if typing.TYPE_CHECKING:
id_tap = _deprecated_id_tap
id_print = _deprecated_id_print
call = _deprecated_call
stop_outfeed_receiver = _deprecated_stop_outfeed_receiver
else:
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
__getattr__ = _deprecation_getattr(__name__, _deprecations)
del _deprecation_getattr
del typing

View File

@ -85,6 +85,7 @@ filterwarnings = [
# end array_api_tests-related warnings
"ignore:jax.extend.mlir.dialects.mhlo is deprecated.*:DeprecationWarning",
"ignore:jax.experimental.maps and .* are deprecated.*:DeprecationWarning",
"ignore:The host_callback APIs are deprecated .*:DeprecationWarning",
]
doctest_optionflags = [
"NUMBER",