From ca59971bef8bd1d1af862331a93d49df5d46b63d Mon Sep 17 00:00:00 2001 From: George Necula Date: Thu, 21 Mar 2024 08:18:57 +0200 Subject: [PATCH] [host_callback] Deprecate the jax.experimental.host_callback module. --- CHANGELOG.md | 2 + jax/experimental/host_callback.py | 64 +++++++++++++++++++++++++------ pyproject.toml | 1 + 3 files changed, 56 insertions(+), 11 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2e318e12f..640cf2abf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,6 +20,8 @@ Remember to align the itemized text with the first line of an item within a list * The `jax.experimental.maps` module and `jax.experimental.maps.xmap` are deprecated. Use `jax.experimental.shard_map` or `jax.vmap` with the `spmd_axis_name` argument for expressing SPMD device-parallel computations. + * The `jax.experimental.host_callback` module is deprecated. + Use instead the [new JAX external callbacks](https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html). * Passing arguments to {func}`jax.numpy.array_equal` and {func}`jax.numpy.array_equiv` that cannot be converted to a JAX array now results in an exception. * The deprecated flag `jax_parallel_functions_output_gda` has been removed. diff --git a/jax/experimental/host_callback.py b/jax/experimental/host_callback.py index b4ab6628a..49869d8e1 100644 --- a/jax/experimental/host_callback.py +++ b/jax/experimental/host_callback.py @@ -13,7 +13,10 @@ # limitations under the License. """Primitives for calling Python functions on the host from JAX accelerator code. -**Experimental: please give feedback, and expect changes.** +.. warning:: + The host_callback APIs are deprecated as of March 20, 2024. + The functionality is subsumed by the + `new JAX external callbacks `_ This module introduces the host callback functions :func:`call`, :func:`id_tap`, and :func:`id_print`, that send their arguments from the device @@ -505,7 +508,7 @@ import logging import math import threading import traceback -from typing import Any, Callable, Optional, cast +from typing import Any, Callable, cast from jax._src import api from jax._src import core @@ -589,7 +592,7 @@ XlaLocalClient = xla_client.Client DType = Any -def id_tap(tap_func, +def _deprecated_id_tap(tap_func, arg, *, result=None, @@ -598,7 +601,10 @@ def id_tap(tap_func, **kwargs): """Host-callback tap primitive, like identity function with a call to ``tap_func``. - **Experimental: please give feedback, and expect changes!** + .. warning:: + The host_callback APIs are deprecated as of March 20, 2024. + The functionality is subsumed by the + `new JAX external callbacks `_ ``id_tap`` behaves semantically like the identity function but has the side-effect that a user-defined Python function is called with the runtime @@ -662,7 +668,7 @@ def id_tap(tap_func, return call_res -def id_print(arg, +def _deprecated_id_print(arg, *, result=None, tap_with_device=False, @@ -672,7 +678,10 @@ def id_print(arg, **kwargs): """Like :func:`id_tap` with a printing tap function. - **Experimental: please give feedback, and expect changes!** + .. warning:: + The host_callback APIs are deprecated as of March 20, 2024. + The functionality is subsumed by the + `new JAX external callbacks `_ On each invocation of the printing tap, the ``kwargs`` if present will be printed first (sorted by keys). Then arg will be printed, @@ -694,7 +703,7 @@ def id_print(arg, printer = functools.partial(_print_tap_func, output_stream=output_stream, threshold=threshold, **kwargs) - return id_tap( + return _deprecated_id_tap( printer, arg, result=result, @@ -702,13 +711,16 @@ def id_print(arg, device_index=device_index) -def call(callback_func: Callable, arg, *, +def _deprecated_call(callback_func: Callable, arg, *, result_shape=None, call_with_device=False, device_index=0): """Make a call to the host, and expect a result. - **Experimental: please give feedback, and expect changes!** + .. warning:: + The host_callback APIs are deprecated as of March 20, 2024. + The functionality is subsumed by the + `new JAX external callbacks `_ Args: callback_func: The Python function to invoke on the host as @@ -1858,7 +1870,7 @@ def barrier_wait(logging_name: str | None = None): for d_idx, d in enumerate(_callback_handler_data.devices): logger.debug("barrier_wait[%s]: enqueueing barrier on device %s", logging_name, d) x_on_dev = api.device_put(d_idx, device=d) - api.jit(lambda x: id_tap(barrier_tap_received, x), device=d)(x_on_dev) + api.jit(lambda x: _deprecated_id_tap(barrier_tap_received, x), device=d)(x_on_dev) logger.debug("barrier_wait[%s]: waiting for callbacks", logging_name) @@ -1875,9 +1887,14 @@ def barrier_wait(logging_name: str | None = None): f"Last one was: {formatted_last_exception}") from last_exception -def stop_outfeed_receiver(): +def _deprecated_stop_outfeed_receiver(): """Stops the outfeed receiver runtime. + .. warning:: + The host_callback APIs are deprecated as of March 20, 2024. + The functionality is subsumed by the + `new JAX external callbacks `_ + This waits for all outfeeds from computations already running on all devices, and then stops the outfeed receiver runtime. The runtime will be restarted next time you use a tap function. @@ -1886,3 +1903,28 @@ def stop_outfeed_receiver(): using lax.outfeed directly after having used host callbacks. """ _callback_handler_data.stop() + +_deprecation_msg = ( + "The host_callback APIs are deprecated as of March 20, 2024. The functionality " + "is subsumed by the new JAX external callbacks. " + "See https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html.") + +_deprecations = { + # Added March 20, 2024 + "id_tap": (_deprecation_msg, _deprecated_id_tap), + "id_print": (_deprecation_msg, _deprecated_id_print), + "call": (_deprecation_msg, _deprecated_call), + "stop_outfeed_receiver": (_deprecation_msg, _deprecated_stop_outfeed_receiver), +} + +import typing +if typing.TYPE_CHECKING: + id_tap = _deprecated_id_tap + id_print = _deprecated_id_print + call = _deprecated_call + stop_outfeed_receiver = _deprecated_stop_outfeed_receiver +else: + from jax._src.deprecations import deprecation_getattr as _deprecation_getattr + __getattr__ = _deprecation_getattr(__name__, _deprecations) + del _deprecation_getattr +del typing diff --git a/pyproject.toml b/pyproject.toml index fdbcf1555..cb6743651 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -85,6 +85,7 @@ filterwarnings = [ # end array_api_tests-related warnings "ignore:jax.extend.mlir.dialects.mhlo is deprecated.*:DeprecationWarning", "ignore:jax.experimental.maps and .* are deprecated.*:DeprecationWarning", + "ignore:The host_callback APIs are deprecated .*:DeprecationWarning", ] doctest_optionflags = [ "NUMBER",