[callback] Add a flag to implement host_callback in terms of io_callback.

The host_callbacks APIs are deprecated and will be removed. In order to
help the transition to the new APIs, we add a flag (`JAX_HOST_CALLBACK_LEGACY`)
that when set to `False` will use `io_callback` (and `pure_callback` and
`jax.debug.callback`) to implement the host_callback APIs.

See issue #20385 for more details.

We change the tests to accomodate slightly different results when using
the new callbacks. The tests that use `tap_with_device` and `call_with_device`
are disabled when using the new callbacks.
This commit is contained in:
George Necula 2024-03-29 13:36:20 +02:00
parent 2512843a56
commit a510f03ef8
6 changed files with 476 additions and 177 deletions

View File

@ -50,6 +50,8 @@ Remember to align the itemized text with the first line of an item within a list
`spmd_axis_name` argument for expressing SPMD device-parallel computations. `spmd_axis_name` argument for expressing SPMD device-parallel computations.
* The `jax.experimental.host_callback` module is deprecated. * The `jax.experimental.host_callback` module is deprecated.
Use instead the [new JAX external callbacks](https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html). Use instead the [new JAX external callbacks](https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html).
Added `JAX_HOST_CALLBACK_LEGACY` flag to assist in the transition to the
new callbacks. See {jax-issue}`#20385` for a discussion.
* Passing arguments to {func}`jax.numpy.array_equal` and {func}`jax.numpy.array_equiv` * Passing arguments to {func}`jax.numpy.array_equal` and {func}`jax.numpy.array_equiv`
that cannot be converted to a JAX array now results in an exception. that cannot be converted to a JAX array now results in an exception.
* The deprecated flag `jax_parallel_functions_output_gda` has been removed. * The deprecated flag `jax_parallel_functions_output_gda` has been removed.
@ -1451,7 +1453,7 @@ Changes:
special autodiff handling for hcb.id_tap and id_print. special autodiff handling for hcb.id_tap and id_print.
From now on, only the primals are tapped. The old behavior can be From now on, only the primals are tapped. The old behavior can be
obtained (for a limited time) by setting the ``JAX_HOST_CALLBACK_AD_TRANSFORMS`` obtained (for a limited time) by setting the ``JAX_HOST_CALLBACK_AD_TRANSFORMS``
environment variable, or the ```--flax_host_callback_ad_transforms``` flag. environment variable, or the ```--jax_host_callback_ad_transforms``` flag.
Additionally, added documentation for how to implement the old behavior Additionally, added documentation for how to implement the old behavior
using JAX custom AD APIs ({jax-issue}`#8678`). using JAX custom AD APIs ({jax-issue}`#8678`).
* Sorting now matches the behavior of NumPy for ``0.0`` and ``NaN`` regardless of the * Sorting now matches the behavior of NumPy for ``0.0`` and ``NaN`` regardless of the

View File

@ -997,7 +997,11 @@ pytype_library(
pytype_library( pytype_library(
name = "experimental_host_callback", name = "experimental_host_callback",
srcs = ["experimental/host_callback.py"], srcs = [
"experimental/__init__.py", # To support JAX_HOST_CALLBACK_LEGACY=False
"experimental/host_callback.py",
"experimental/x64_context.py", # To support JAX_HOST_CALLBACK_LEGACY=False
],
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
deps = [ deps = [
":jax", ":jax",

View File

@ -17,6 +17,7 @@
The host_callback APIs are deprecated as of March 20, 2024. The host_callback APIs are deprecated as of March 20, 2024.
The functionality is subsumed by the The functionality is subsumed by the
`new JAX external callbacks <https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html>`_ `new JAX external callbacks <https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html>`_
See https://github.com/google/jax/issues/20385.
This module introduces the host callback functions :func:`call`, This module introduces the host callback functions :func:`call`,
:func:`id_tap`, and :func:`id_print`, that send their arguments from the device :func:`id_tap`, and :func:`id_print`, that send their arguments from the device
@ -501,6 +502,7 @@ Still to do:
from __future__ import annotations from __future__ import annotations
import atexit import atexit
import enum
from collections.abc import Sequence from collections.abc import Sequence
import functools import functools
import itertools import itertools
@ -510,6 +512,7 @@ import threading
import traceback import traceback
from typing import Any, Callable, cast from typing import Any, Callable, cast
import jax
from jax._src import api from jax._src import api
from jax._src import core from jax._src import core
from jax._src import config from jax._src import config
@ -517,6 +520,7 @@ from jax import custom_derivatives
from jax._src import dtypes from jax._src import dtypes
from jax import lax from jax import lax
from jax.experimental import pjit from jax.experimental import pjit
from jax.experimental import io_callback
from jax._src.interpreters import ad, batching, pxla from jax._src.interpreters import ad, batching, pxla
from jax._src.interpreters import mlir from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe from jax._src.interpreters import partial_eval as pe
@ -560,6 +564,15 @@ _HOST_CALLBACK_OUTFEED = config.DEFINE_bool(
'Has no effect on TPU, since only the outfeed mechanism is implemented.' 'Has no effect on TPU, since only the outfeed mechanism is implemented.'
) )
) )
_HOST_CALLBACK_LEGACY = config.DEFINE_bool(
'jax_host_callback_legacy',
config.bool_env('JAX_HOST_CALLBACK_LEGACY', True),
help=(
'Use old implementation of host_callback, documented in the module docstring.'
'If False, use the jax.experimental.io_callback implementation. '
'See https://github.com/google/jax/issues/20385.'
)
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -591,6 +604,15 @@ XlaDevice = xla_client.Device
XlaLocalClient = xla_client.Client XlaLocalClient = xla_client.Client
DType = Any DType = Any
class CallbackFlavor(enum.Enum):
"""Specifies which flavor of callback to use under JAX_HOST_CALLBACK_LEGACY=False.
See https://github.com/google/jax/issues/20385.
"""
IO_CALLBACK = 1 # uses jax.experimental.io_callback
PURE = 2 # uses jax.pure_callback
DEBUG = 3 # uses jax.debug.callback, valid only when there are no results
def _deprecated_id_tap(tap_func, def _deprecated_id_tap(tap_func,
arg, arg,
@ -598,6 +620,7 @@ def _deprecated_id_tap(tap_func,
result=None, result=None,
tap_with_device=False, tap_with_device=False,
device_index=0, device_index=0,
callback_flavor=CallbackFlavor.IO_CALLBACK,
**kwargs): **kwargs):
"""Host-callback tap primitive, like identity function with a call to ``tap_func``. """Host-callback tap primitive, like identity function with a call to ``tap_func``.
@ -605,6 +628,7 @@ def _deprecated_id_tap(tap_func,
The host_callback APIs are deprecated as of March 20, 2024. The host_callback APIs are deprecated as of March 20, 2024.
The functionality is subsumed by the The functionality is subsumed by the
`new JAX external callbacks <https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html>`_ `new JAX external callbacks <https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html>`_
See https://github.com/google/jax/issues/20385.
``id_tap`` behaves semantically like the identity function but has the ``id_tap`` behaves semantically like the identity function but has the
side-effect that a user-defined Python function is called with the runtime side-effect that a user-defined Python function is called with the runtime
@ -628,6 +652,9 @@ def _deprecated_id_tap(tap_func,
device_index: specifies from which device the tap function is invoked in a device_index: specifies from which device the tap function is invoked in a
SPMD program. Works only when using the outfeed implementation mechanism, SPMD program. Works only when using the outfeed implementation mechanism,
i.e., does not work on CPU unless --jax_host_callback_outfeed=True. i.e., does not work on CPU unless --jax_host_callback_outfeed=True.
callback_flavor: if running with `JAX_HOST_CALLBACK_LEGACY=False` specifies
the flavor of callback to use.
See https://github.com/google/jax/issues/20385.
Returns: Returns:
``arg``, or ``result`` if given. ``arg``, or ``result`` if given.
@ -660,7 +687,8 @@ def _deprecated_id_tap(tap_func,
call_with_device=tap_with_device, call_with_device=tap_with_device,
result_shape=None, result_shape=None,
identity=True, identity=True,
device_index=device_index) device_index=device_index,
callback_flavor=callback_flavor)
if result is not None: if result is not None:
return result return result
@ -675,6 +703,7 @@ def _deprecated_id_print(arg,
device_index=0, device_index=0,
output_stream=None, output_stream=None,
threshold=None, threshold=None,
callback_flavor=CallbackFlavor.IO_CALLBACK,
**kwargs): **kwargs):
"""Like :func:`id_tap` with a printing tap function. """Like :func:`id_tap` with a printing tap function.
@ -682,6 +711,7 @@ def _deprecated_id_print(arg,
The host_callback APIs are deprecated as of March 20, 2024. The host_callback APIs are deprecated as of March 20, 2024.
The functionality is subsumed by the The functionality is subsumed by the
`new JAX external callbacks <https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html>`_ `new JAX external callbacks <https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html>`_
See https://github.com/google/jax/issues/20385.
On each invocation of the printing tap, the ``kwargs`` if present On each invocation of the printing tap, the ``kwargs`` if present
will be printed first (sorted by keys). Then arg will be printed, will be printed first (sorted by keys). Then arg will be printed,
@ -697,6 +727,9 @@ def _deprecated_id_print(arg,
built-in ``print``. The string will be passed as built-in ``print``. The string will be passed as
``output_stream.write(s)``. ``output_stream.write(s)``.
* ``threshold`` is passed to ``numpy.array2string``. * ``threshold`` is passed to ``numpy.array2string``.
* ``callback_flavor``: if running with `JAX_HOST_CALLBACK_LEGACY=False` specifies
the flavor of callback to use.
See https://github.com/google/jax/issues/20385.
For more details see the :mod:`jax.experimental.host_callback` module documentation. For more details see the :mod:`jax.experimental.host_callback` module documentation.
""" """
@ -708,19 +741,22 @@ def _deprecated_id_print(arg,
arg, arg,
result=result, result=result,
tap_with_device=tap_with_device, tap_with_device=tap_with_device,
device_index=device_index) device_index=device_index,
callback_flavor=callback_flavor)
def _deprecated_call(callback_func: Callable, arg, *, def _deprecated_call(callback_func: Callable, arg, *,
result_shape=None, result_shape=None,
call_with_device=False, call_with_device=False,
device_index=0): device_index=0,
callback_flavor=CallbackFlavor.IO_CALLBACK):
"""Make a call to the host, and expect a result. """Make a call to the host, and expect a result.
.. warning:: .. warning::
The host_callback APIs are deprecated as of March 20, 2024. The host_callback APIs are deprecated as of March 20, 2024.
The functionality is subsumed by the The functionality is subsumed by the
`new JAX external callbacks <https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html>`_ `new JAX external callbacks <https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html>`_
See https://github.com/google/jax/issues/20385.
Args: Args:
callback_func: The Python function to invoke on the host as callback_func: The Python function to invoke on the host as
@ -748,14 +784,26 @@ def _deprecated_call(callback_func: Callable, arg, *,
device_index: specifies from which device the tap function is invoked in a device_index: specifies from which device the tap function is invoked in a
SPMD program. Works only when using the outfeed implementation mechanism, SPMD program. Works only when using the outfeed implementation mechanism,
i.e., does not work on CPU unless --jax_host_callback_outfeed=True. i.e., does not work on CPU unless --jax_host_callback_outfeed=True.
callback_flavor: if running with `JAX_HOST_CALLBACK_LEGACY=False` specifies
the flavor of callback to use.
See https://github.com/google/jax/issues/20385.
Returns: Returns:
the result of the ``callback_func`` invocation. the result of the ``callback_func`` invocation.
For more details see the :mod:`jax.experimental.host_callback` module documentation. For more details see the :mod:`jax.experimental.host_callback` module documentation.
""" """
if (not _HOST_CALLBACK_LEGACY.value and
callback_flavor is CallbackFlavor.DEBUG and
result_shape is not None):
raise NotImplementedError(
"When using JAX_HOST_CALLBACK_LEGACY=False you can use the `DEBUG` "
"flavor of callback only when the `result_shape` is None. "
"See https://github.com/google/jax/issues/20385."
)
return _call(callback_func, arg, result_shape=result_shape, return _call(callback_func, arg, result_shape=result_shape,
call_with_device=call_with_device, identity=False, call_with_device=call_with_device, identity=False,
device_index=device_index) device_index=device_index, callback_flavor=callback_flavor)
# We need the wrapper function to have hash and equality defined since it is # We need the wrapper function to have hash and equality defined since it is
@ -766,6 +814,11 @@ class _CallbackWrapper:
self.callback_func = callback_func self.callback_func = callback_func
self.identity = identity self.identity = identity
self.call_with_device = call_with_device self.call_with_device = call_with_device
if not _HOST_CALLBACK_LEGACY.value and call_with_device:
raise NotImplementedError(
"When using JAX_HOST_CALLBACK_LEGACY=False, the host_callback APIs"
" do not support `tap_with_device` and `call_with_device`. "
"See https://github.com/google/jax/issues/20385.")
def __hash__(self): def __hash__(self):
return hash((self.callback_func, self.identity, self.call_with_device)) return hash((self.callback_func, self.identity, self.call_with_device))
@ -775,7 +828,16 @@ class _CallbackWrapper:
self.identity == other.identity and self.identity == other.identity and
self.call_with_device == other.call_with_device) self.call_with_device == other.call_with_device)
def __call__(self, arg, device, transforms): def __call__(self, *args, **kwargs):
if _HOST_CALLBACK_LEGACY.value:
return self._call_legacy(*args, **kwargs)
else:
if self.identity:
# For id_tap, we pass empty transforms, for backwards compatibility
return self.callback_func(args[0], ())
return self.callback_func(*args, **kwargs)
def _call_legacy(self, arg, device, transforms):
if self.identity: if self.identity:
# For id_tap, we pass the transforms, for backwards compatibility # For id_tap, we pass the transforms, for backwards compatibility
if self.call_with_device: if self.call_with_device:
@ -797,14 +859,16 @@ def _call(callback_func: Callable,
result_shape=None, result_shape=None,
call_with_device=False, call_with_device=False,
device_index=0, device_index=0,
identity=False): identity=False,
# Lazy initialization callback_flavor=CallbackFlavor.IO_CALLBACK):
_initialize_outfeed_receiver( if _HOST_CALLBACK_LEGACY.value:
max_callback_queue_size_bytes=_HOST_CALLBACK_MAX_QUEUE_BYTE_SIZE.value) # Lazy initialization
_initialize_outfeed_receiver(
max_callback_queue_size_bytes=_HOST_CALLBACK_MAX_QUEUE_BYTE_SIZE.value)
api.check_callable(callback_func) api.check_callable(callback_func)
flat_args, arg_treedef = tree_util.tree_flatten(arg) flat_args, arg_treedef = tree_util.tree_flatten(arg)
for arg in flat_args: for arg_ in flat_args:
dispatch.check_arg(arg) dispatch.check_arg(arg_)
# See definition of outside_call_p for what parameters it takes # See definition of outside_call_p for what parameters it takes
params: dict[str, Any] = {} params: dict[str, Any] = {}
# TODO: wrap function # TODO: wrap function
@ -829,8 +893,27 @@ def _call(callback_func: Callable,
params["result_treedef"] = result_treedef params["result_treedef"] = result_treedef
params["flat_results_aval"] = tuple(flat_results_aval) params["flat_results_aval"] = tuple(flat_results_aval)
flat_results = outside_call_p.bind(*flat_args, **params)
return result_treedef.unflatten(flat_results) if not identity else arg_treedef.unflatten(flat_results) if _HOST_CALLBACK_LEGACY.value:
flat_results = outside_call_p.bind(*flat_args, **params)
return result_treedef.unflatten(flat_results) if not identity else arg_treedef.unflatten(flat_results)
else:
callback_device = jax.local_devices()[device_index]
sharding = jax.sharding.SingleDeviceSharding(callback_device)
callback_func = _CallbackWrapper(callback_func, identity,
call_with_device)
if callback_flavor is CallbackFlavor.DEBUG:
assert identity
jax.debug.callback(callback_func, arg)
return arg
elif callback_flavor is CallbackFlavor.PURE:
call_res = jax.pure_callback(callback_func, result_shape, arg,
sharding=sharding)
else:
call_res = io_callback(callback_func, result_shape, arg,
sharding=sharding,
ordered=True)
return call_res if not identity else arg
# We need the lock for when we use the CustomCall implementation of callbacks. # We need the lock for when we use the CustomCall implementation of callbacks.
@ -855,7 +938,6 @@ def _print_tap_func(
threshold: the value of numpy.array2string threshold parameter. threshold: the value of numpy.array2string threshold parameter.
**kwargs: all other keyword args are printed before printing `arg`. **kwargs: all other keyword args are printed before printing `arg`.
""" """
def emit_str(s: str): def emit_str(s: str):
if output_stream is not None: if output_stream is not None:
output_stream.write(s + "\n") output_stream.write(s + "\n")
@ -1844,6 +1926,10 @@ def _deprecated_barrier_wait(logging_name: str | None = None):
For more details see the :mod:`jax.experimental.host_callback` module documentation. For more details see the :mod:`jax.experimental.host_callback` module documentation.
""" """
if not _HOST_CALLBACK_LEGACY.value:
jax.effects_barrier()
return
logging_name = logging_name or "" logging_name = logging_name or ""
logger.debug("barrier_wait[%s]: start", logging_name) logger.debug("barrier_wait[%s]: start", logging_name)
@ -1907,7 +1993,7 @@ def _deprecated_stop_outfeed_receiver():
_deprecation_msg = ( _deprecation_msg = (
"The host_callback APIs are deprecated as of March 20, 2024. The functionality " "The host_callback APIs are deprecated as of March 20, 2024. The functionality "
"is subsumed by the new JAX external callbacks. " "is subsumed by the new JAX external callbacks. "
"See https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html.") "See https://github.com/google/jax/issues/20385.")
_deprecations = { _deprecations = {
# Added March 20, 2024 # Added March 20, 2024

View File

@ -91,8 +91,10 @@ testing_stream = _TestingOutputStream()
def fun1(a): def fun1(a):
"""Function used for several `id_tap` tests.""" """Function used for several `id_tap` tests."""
y = hcb.id_print(a * 2., what="a * 2", output_stream=testing_stream) y = hcb.id_print(a * 2., what="a * 2", output_stream=testing_stream,
y = hcb.id_print(y * 3., what="y * 3", output_stream=testing_stream, result=y) callback_flavor=hcb.CallbackFlavor.DEBUG)
y = hcb.id_print(y * 3., what="y * 3", output_stream=testing_stream, result=y,
callback_flavor=hcb.CallbackFlavor.DEBUG)
return y ** 2 # Some computation to make the gradient interesting return y ** 2 # Some computation to make the gradient interesting
@ -253,6 +255,10 @@ class HostCallbackTapTest(jtu.JaxTestCase):
hcb.barrier_wait("HostCallbackTapTest.tearDown") hcb.barrier_wait("HostCallbackTapTest.tearDown")
super().tearDown() super().tearDown()
def supported_only_in_legacy_mode(self):
if not hcb._HOST_CALLBACK_LEGACY.value:
self.skipTest("Not supported when JAX_HOST_CALLBACK_LEGACY=False")
def test_tap_eval(self): def test_tap_eval(self):
self.assertAllClose((5. * 2.) ** 2, fun1(5.)) self.assertAllClose((5. * 2.) ** 2, fun1(5.))
hcb.barrier_wait() hcb.barrier_wait()
@ -320,6 +326,7 @@ class HostCallbackTapTest(jtu.JaxTestCase):
testing_stream.output) testing_stream.output)
def test_tap_with_device(self): def test_tap_with_device(self):
self.supported_only_in_legacy_mode()
def func2(x): def func2(x):
x1 = hcb.id_print((x * 2., x * 3.), result=x * 4., x1 = hcb.id_print((x * 2., x * 3.), result=x * 4.,
output_stream=testing_stream, output_stream=testing_stream,
@ -335,6 +342,7 @@ class HostCallbackTapTest(jtu.JaxTestCase):
def test_tap_eval_exception(self): def test_tap_eval_exception(self):
if not hcb._HOST_CALLBACK_OUTFEED.value: if not hcb._HOST_CALLBACK_OUTFEED.value:
raise SkipTest("TODO: implement error handling for customcall") raise SkipTest("TODO: implement error handling for customcall")
# Simulate a tap error # Simulate a tap error
def tap_err(*args, **kwargs): def tap_err(*args, **kwargs):
raise ValueError("Some user message") raise ValueError("Some user message")
@ -345,19 +353,30 @@ class HostCallbackTapTest(jtu.JaxTestCase):
x3 = hcb.id_print(x2 + 1, what="x3", output_stream=testing_stream) x3 = hcb.id_print(x2 + 1, what="x3", output_stream=testing_stream)
return x3 return x3
with self.assertRaisesRegex( if hcb._HOST_CALLBACK_LEGACY.value:
hcb.CallbackException, ctx = self.assertRaisesRegex(
re.compile("There were exceptions during callback processing. Last one was:.*" hcb.CallbackException,
"ValueError: Some user message", re.DOTALL)): re.compile("There were exceptions during callback processing. Last one was:.*"
"ValueError: Some user message", re.DOTALL))
else:
ctx = self.assertRaisesRegex(Exception, "Some user message")
with ctx:
func(0) func(0)
hcb.barrier_wait() hcb.barrier_wait()
# We should have received everything before the error if hcb._HOST_CALLBACK_LEGACY.value:
assertMultiLineStrippedEqual(self, """ # We should have received everything before the error
what: x1 assertMultiLineStrippedEqual(self, """
1 what: x1
what: x3 1
3""", testing_stream.output) what: x3
3""", testing_stream.output)
else:
# We should have received everything before the error
assertMultiLineStrippedEqual(self, """
what: x1
1""", testing_stream.output)
def test_tap_empty(self): def test_tap_empty(self):
"""Tap empty arrays.""" """Tap empty arrays."""
@ -488,6 +507,7 @@ class HostCallbackTapTest(jtu.JaxTestCase):
def test_tap_jit_devices(self): def test_tap_jit_devices(self):
"""Running on multiple devices.""" """Running on multiple devices."""
self.supported_only_in_legacy_mode()
logging.info("%s: has devices %s", self._testMethodName, local_devices()) logging.info("%s: has devices %s", self._testMethodName, local_devices())
def func(x, device_id): def func(x, device_id):
@ -830,19 +850,24 @@ class HostCallbackTapTest(jtu.JaxTestCase):
x3 = hcb.id_print(x2 + 1, what="x3", output_stream=testing_stream) x3 = hcb.id_print(x2 + 1, what="x3", output_stream=testing_stream)
return x3 return x3
res = jax.jit(func)(0) # No error yet if hcb._HOST_CALLBACK_LEGACY.value:
with self.assertRaises(hcb.CallbackException): res = jax.jit(func)(0) # No error yet
hcb.barrier_wait() with self.assertRaises(hcb.CallbackException):
hcb.barrier_wait()
# Even though the receiver thread raised, the main thread should still # Even though the receiver thread raised, the main thread should still
# return 3. # return 3.
self.assertEqual(3, res) self.assertEqual(3, res)
# We should have received all others # We should have received all others
assertMultiLineStrippedEqual(self, """ assertMultiLineStrippedEqual(self, """
what: x1 what: x1
1 1
what: x3 what: x3
3""", testing_stream.output) 3""", testing_stream.output)
else:
with self.assertRaisesRegex(Exception, "NotImplementedError"):
res = jax.jit(func)(0)
hcb.barrier_wait()
def test_tap_while(self): def test_tap_while(self):
"""Executing while, even without JIT uses compiled code""" """Executing while, even without JIT uses compiled code"""
@ -878,7 +903,8 @@ class HostCallbackTapTest(jtu.JaxTestCase):
# The output of id_print is not needed for backwards pass # The output of id_print is not needed for backwards pass
def func(x): def func(x):
return 2. * hcb.id_print(x * 3., what="x * 3", return 2. * hcb.id_print(x * 3., what="x * 3",
output_stream=testing_stream) output_stream=testing_stream,
callback_flavor=hcb.CallbackFlavor.DEBUG)
grad_func = jax.grad(func) grad_func = jax.grad(func)
arg = jnp.float32(5.) arg = jnp.float32(5.)
@ -886,21 +912,22 @@ class HostCallbackTapTest(jtu.JaxTestCase):
# making the Jaxpr does not print anything # making the Jaxpr does not print anything
hcb.barrier_wait() hcb.barrier_wait()
treedef = jax.tree.structure(arg) if hcb._HOST_CALLBACK_LEGACY.value:
assertMultiLineStrippedEqual( treedef = jax.tree.structure(arg)
self, f""" assertMultiLineStrippedEqual(
{{ lambda ; a:f32[]. let self, f"""
b:f32[] = mul a 3.00 {{ lambda ; a:f32[]. let
c:f32[] = outside_call[ b:f32[] = mul a 3.00
arg_treedef={treedef} c:f32[] = outside_call[
callback=... arg_treedef={treedef}
device_index=0 callback=...
identity=True device_index=0
] b identity=True
_:f32[] = mul 2.00 c ] b
d:f32[] = mul 2.00 1.00 _:f32[] = mul 2.00 c
e:f32[] = mul d 3.00 d:f32[] = mul 2.00 1.00
in (e,) }}""", jaxpr) e:f32[] = mul d 3.00
in (e,) }}""", jaxpr)
assertMultiLineStrippedEqual(self, "", testing_stream.output) assertMultiLineStrippedEqual(self, "", testing_stream.output)
testing_stream.reset() testing_stream.reset()
@ -914,9 +941,11 @@ class HostCallbackTapTest(jtu.JaxTestCase):
def test_tap_grad_simple(self): def test_tap_grad_simple(self):
def func(x): def func(x):
y = hcb.id_print(x * 2., what="x * 2", output_stream=testing_stream) y = hcb.id_print(x * 2., what="x * 2", output_stream=testing_stream,
callback_flavor=hcb.CallbackFlavor.DEBUG)
return x * hcb.id_print(y * 3., what="y * 3", return x * hcb.id_print(y * 3., what="y * 3",
output_stream=testing_stream) output_stream=testing_stream,
callback_flavor=hcb.CallbackFlavor.DEBUG)
grad_func = jax.grad(func) grad_func = jax.grad(func)
@ -931,7 +960,8 @@ class HostCallbackTapTest(jtu.JaxTestCase):
def test_tap_grad_grad(self): def test_tap_grad_grad(self):
def func(x): def func(x):
y = hcb.id_print(x * 2., what="x * 2", output_stream=testing_stream) y = hcb.id_print(x * 2., what="x * 2", output_stream=testing_stream,
callback_flavor=hcb.CallbackFlavor.DEBUG)
return x * (y * 3.) return x * (y * 3.)
grad_func = jax.grad(jax.grad(func)) grad_func = jax.grad(jax.grad(func))
@ -952,7 +982,8 @@ class HostCallbackTapTest(jtu.JaxTestCase):
def func(x): def func(x):
x4, x5 = hcb.id_print((x * 2., x * 3.), what="pair", x4, x5 = hcb.id_print((x * 2., x * 3.), what="pair",
result=(x * 4., x * 5.), result=(x * 4., x * 5.),
output_stream=testing_stream) output_stream=testing_stream,
callback_flavor=hcb.CallbackFlavor.DEBUG)
return x4 + 2. * x5 return x4 + 2. * x5
x = jnp.float32(5.) x = jnp.float32(5.)
@ -967,15 +998,18 @@ class HostCallbackTapTest(jtu.JaxTestCase):
def test_tap_jvp_float0(self): def test_tap_jvp_float0(self):
def f(x, yint): def f(x, yint):
x, yint = hcb.id_tap(lambda arg, _: arg, (x, yint)) x, yint = hcb.id_tap(lambda arg, _: arg, (x, yint),
callback_flavor=hcb.CallbackFlavor.DEBUG)
return x * yint return x * yint
res = jax.jvp(f, (2., 3), (0.2, np.zeros((), dtypes.float0))) res = jax.jvp(f, (2., 3), (0.2, np.zeros((), dtypes.float0)))
self.assertAllClose((6., 0.6), res) self.assertAllClose((6., 0.6), res)
def test_tap_grad_float0(self): def test_tap_grad_float0(self):
def func(x, yint): def func(x, yint):
x, yint = hcb.id_print((x, yint), what="pair", output_stream=testing_stream) x, yint = hcb.id_print((x, yint), what="pair", output_stream=testing_stream,
callback_flavor=hcb.CallbackFlavor.DEBUG)
return x * yint.astype(x.dtype) return x * yint.astype(x.dtype)
grad_func = jax.grad(func) grad_func = jax.grad(func)
@ -993,7 +1027,8 @@ class HostCallbackTapTest(jtu.JaxTestCase):
x = (np.array([.7, .8], dtype=np.float32), x = (np.array([.7, .8], dtype=np.float32),
np.array([11, 12, 13], dtype=np.int32)) np.array([11, 12, 13], dtype=np.int32))
def f_jax(x): def f_jax(x):
x = hcb.id_print(x, result=x, output_stream=testing_stream) # result= is important x = hcb.id_print(x, result=x, output_stream=testing_stream,
callback_flavor=hcb.CallbackFlavor.DEBUG) # result= is important
return (3. * x[0], x[1]) return (3. * x[0], x[1])
def f_jax_vjp(x): def f_jax_vjp(x):
@ -1015,7 +1050,8 @@ class HostCallbackTapTest(jtu.JaxTestCase):
x = (np.array([.7, .8], dtype=np.float32), x = (np.array([.7, .8], dtype=np.float32),
np.array([11, 12, 13], dtype=np.int32)) np.array([11, 12, 13], dtype=np.int32))
def f_jax(x): def f_jax(x):
x = hcb.id_print(x, result=x, output_stream=testing_stream) # result= is important x = hcb.id_print(x, result=x, output_stream=testing_stream,
callback_flavor=hcb.CallbackFlavor.DEBUG) # result= is important
return (jnp.sin(x[0]), x[1]) return (jnp.sin(x[0]), x[1])
def wrap_vjp(f, args, res_f_of_args): def wrap_vjp(f, args, res_f_of_args):
@ -1059,32 +1095,52 @@ class HostCallbackTapTest(jtu.JaxTestCase):
vargs = jnp.array([jnp.float32(4.), jnp.float32(5.)]) vargs = jnp.array([jnp.float32(4.), jnp.float32(5.)])
vmap_fun1(vargs) vmap_fun1(vargs)
hcb.barrier_wait() hcb.barrier_wait()
assertMultiLineStrippedEqual(self, """ if hcb._HOST_CALLBACK_LEGACY.value:
transforms: [('batch', {'batch_dims': (0,)})] what: a * 2 assertMultiLineStrippedEqual(self, """
[ 8.00 10.00] transforms: [('batch', {'batch_dims': (0,)})] what: a * 2
transforms: [('batch', {'batch_dims': (0,)})] what: y * 3 [ 8.00 10.00]
[24.00 30.00]""", testing_stream.output) transforms: [('batch', {'batch_dims': (0,)})] what: y * 3
[24.00 30.00]""", testing_stream.output)
else:
assertMultiLineStrippedEqual(self, """
what: a * 2
8.00
what: a * 2
10.00
what: y * 3
24.00
what: y * 3
30.00
""", testing_stream.output)
def test_tap_vmap_not_batched(self): def test_tap_vmap_not_batched(self):
x = 3. x = 3.
def func(y): def func(y):
# x is not mapped, y is mapped # x is not mapped, y is mapped
_, y = hcb.id_print((x, y), output_stream=testing_stream) _, y = hcb.id_print((x, y), output_stream=testing_stream,
callback_flavor=hcb.CallbackFlavor.DEBUG)
return x + y return x + y
vmap_func = jax.vmap(func) vmap_func = jax.vmap(func)
vargs = jnp.array([jnp.float32(4.), jnp.float32(5.)]) vargs = jnp.array([jnp.float32(4.), jnp.float32(5.)])
_ = vmap_func(vargs) _ = vmap_func(vargs)
hcb.barrier_wait() hcb.barrier_wait()
assertMultiLineStrippedEqual(self, """ if hcb._HOST_CALLBACK_LEGACY.value:
transforms: [('batch', {'batch_dims': (None, 0)})] assertMultiLineStrippedEqual(self, """
( 3.00 [4.00 5.00] )""", testing_stream.output) transforms: [('batch', {'batch_dims': (None, 0)})]
( 3.00 [4.00 5.00] )""", testing_stream.output)
else:
assertMultiLineStrippedEqual(self, """
( 3.00 4.00 )
( 3.00 5.00 )
""", testing_stream.output)
def test_tap_vmap_vmap(self): def test_tap_vmap_vmap(self):
# A 2D tensor with x[i, j] = i + j using 2 vmap # A 2D tensor with x[i, j] = i + j using 2 vmap
def sum(x, y): def sum(x, y):
return hcb.id_print(x + y, output_stream=testing_stream) return hcb.id_print(x + y, output_stream=testing_stream,
callback_flavor=hcb.CallbackFlavor.DEBUG)
def sum_rows(xv, y): def sum_rows(xv, y):
return jax.vmap(sum, in_axes=(0, None))(xv, y) return jax.vmap(sum, in_axes=(0, None))(xv, y)
@ -1097,22 +1153,44 @@ class HostCallbackTapTest(jtu.JaxTestCase):
# assertMultiLineStrippedEqual(self, "", str(jax.make_jaxpr(sum_all)(xv, yv))) # assertMultiLineStrippedEqual(self, "", str(jax.make_jaxpr(sum_all)(xv, yv)))
_ = sum_all(xv, yv) _ = sum_all(xv, yv)
hcb.barrier_wait() hcb.barrier_wait()
assertMultiLineStrippedEqual(self, """ if hcb._HOST_CALLBACK_LEGACY.value:
transforms: [('batch', {'batch_dims': (0,)}), ('batch', {'batch_dims': (0,)})] assertMultiLineStrippedEqual(self, """
[[0 1 2 3 4] transforms: [('batch', {'batch_dims': (0,)}), ('batch', {'batch_dims': (0,)})]
[1 2 3 4 5] [[0 1 2 3 4]
[2 3 4 5 6]]""", testing_stream.output) [1 2 3 4 5]
[2 3 4 5 6]]""", testing_stream.output)
else:
assertMultiLineStrippedEqual(self, """
0
1
2
1
2
3
2
3
4
3
4
5
4
5
6
""", testing_stream.output)
def test_tap_vmap_while(self): def test_tap_vmap_while(self):
"""Vmap of while.""" """Vmap of while."""
def func(x): def func(x):
# like max(x, 2) # like max(x, 2)
x1 = hcb.id_print(x, where="before:x", output_stream=testing_stream) x1 = hcb.id_print(x, where="before:x", output_stream=testing_stream,
callback_flavor=hcb.CallbackFlavor.DEBUG)
x2 = lax.while_loop( x2 = lax.while_loop(
lambda x: x < 2, lambda x: hcb.id_print( lambda x: x < 2, lambda x: hcb.id_print(
x + 1, where="body:x+1", output_stream=testing_stream), x1) x + 1, where="body:x+1", output_stream=testing_stream,
res = hcb.id_print(x2, where="after:x", output_stream=testing_stream) callback_flavor=hcb.CallbackFlavor.DEBUG), x1)
res = hcb.id_print(x2, where="after:x", output_stream=testing_stream,
callback_flavor=hcb.CallbackFlavor.DEBUG)
return res return res
inputs = np.arange(5, dtype=np.int32) inputs = np.arange(5, dtype=np.int32)
@ -1121,72 +1199,93 @@ class HostCallbackTapTest(jtu.JaxTestCase):
jax.jit(jax.vmap(func))(inputs), jax.jit(jax.vmap(func))(inputs),
check_dtypes=False) check_dtypes=False)
hcb.barrier_wait() hcb.barrier_wait()
assertMultiLineStrippedEqual( if hcb._HOST_CALLBACK_LEGACY.value:
self, """ assertMultiLineStrippedEqual(
transforms: [('batch', {'batch_dims': (0,)})] where: before:x self, """
[0 1 2 3 4] transforms: [('batch', {'batch_dims': (0,)})] where: before:x
transforms: [('batch', {'batch_dims': (0,)})] where: body:x+1 [0 1 2 3 4]
[1 2 3 4 5] transforms: [('batch', {'batch_dims': (0,)})] where: body:x+1
transforms: [('batch', {'batch_dims': (0,)})] where: body:x+1 [1 2 3 4 5]
[2 3 3 4 5] transforms: [('batch', {'batch_dims': (0,)})] where: body:x+1
transforms: [('batch', {'batch_dims': (0,)})] where: after:x [2 3 3 4 5]
[2 2 2 3 4]""", testing_stream.output) transforms: [('batch', {'batch_dims': (0,)})] where: after:x
[2 2 2 3 4]""", testing_stream.output)
else:
pass # order of vmaps is not guaranteed
def test_tap_vmap_while_tap_cond(self): def test_tap_vmap_while_tap_cond(self):
"""Vmap of while, with a tap in the conditional.""" """Vmap of while, with a tap in the conditional."""
def func(x): def func(x):
# like max(x, 2) # like max(x, 2)
x1 = hcb.id_print(x, where="1", output_stream=testing_stream) x1 = hcb.id_print(x, where="1", output_stream=testing_stream,
callback_flavor=hcb.CallbackFlavor.DEBUG)
x2 = lax.while_loop(lambda x: hcb.id_print(x < 2, where="w_c", x2 = lax.while_loop(lambda x: hcb.id_print(x < 2, where="w_c",
output_stream=testing_stream), output_stream=testing_stream,
callback_flavor=hcb.CallbackFlavor.DEBUG),
lambda x: hcb.id_print(x + 1, where="w_b", lambda x: hcb.id_print(x + 1, where="w_b",
output_stream=testing_stream), output_stream=testing_stream,
callback_flavor=hcb.CallbackFlavor.DEBUG),
x1) x1)
res = hcb.id_print(x2, where="3", output_stream=testing_stream) res = hcb.id_print(x2, where="3", output_stream=testing_stream,
callback_flavor=hcb.CallbackFlavor.DEBUG)
return res return res
inputs = np.arange(5, dtype=np.int32) inputs = np.arange(5, dtype=np.int32)
res = jax.jit(jax.vmap(func))(inputs) res = jax.jit(jax.vmap(func))(inputs)
hcb.barrier_wait() hcb.barrier_wait()
self.assertAllClose(np.array([2, 2, 2, 3, 4]), res, check_dtypes=False) self.assertAllClose(np.array([2, 2, 2, 3, 4]), res, check_dtypes=False)
assertMultiLineStrippedEqual(self, """ if hcb._HOST_CALLBACK_LEGACY.value:
transforms: [('batch', {'batch_dims': (0,)})] where: 1 assertMultiLineStrippedEqual(self, """
[0 1 2 3 4] transforms: [('batch', {'batch_dims': (0,)})] where: 1
transforms: [('batch', {'batch_dims': (0,)})] where: w_c [0 1 2 3 4]
[ True True False False False] transforms: [('batch', {'batch_dims': (0,)})] where: w_c
transforms: [('batch', {'batch_dims': (0,)})] where: w_b [ True True False False False]
[1 2 3 4 5] transforms: [('batch', {'batch_dims': (0,)})] where: w_b
transforms: [('batch', {'batch_dims': (0,)})] where: w_c [1 2 3 4 5]
[ True False False False False] transforms: [('batch', {'batch_dims': (0,)})] where: w_c
transforms: [('batch', {'batch_dims': (0,)})] where: w_b [ True False False False False]
[2 3 3 4 5] transforms: [('batch', {'batch_dims': (0,)})] where: w_b
transforms: [('batch', {'batch_dims': (0,)})] where: w_c [2 3 3 4 5]
[False False False False False] transforms: [('batch', {'batch_dims': (0,)})] where: w_c
transforms: [('batch', {'batch_dims': (0,)})] where: 3 [False False False False False]
[2 2 2 3 4]""", testing_stream.output) transforms: [('batch', {'batch_dims': (0,)})] where: 3
[2 2 2 3 4]""", testing_stream.output)
else:
pass # order of vmap is not guaranteed
def test_tap_transforms_doc(self): def test_tap_transforms_doc(self):
# Examples from the documentation # Examples from the documentation
def power3(x): def power3(x):
y = x * x y = x * x
# Print both 'x' and 'x^2'. Must pack as a tuple. # Print both 'x' and 'x^2'. Must pack as a tuple.
hcb.id_print((x, y), what="x,x^2", output_stream=testing_stream) hcb.id_print((x, y), what="x,x^2", output_stream=testing_stream,
callback_flavor=hcb.CallbackFlavor.DEBUG)
return y * x return y * x
print(f"impl = {power3(3.)}") print(f"impl = {power3(3.)}")
hcb.barrier_wait() hcb.barrier_wait()
expected = """ if hcb._HOST_CALLBACK_LEGACY.value:
what: x,x^2 expected = """
( 3. 9. )""" what: x,x^2
( 3. 9. )"""
else:
expected = """
what: x,x^2
( 3.0 9.0 )"""
self.assertMultiLineStrippedEqual(expected, testing_stream.output) self.assertMultiLineStrippedEqual(expected, testing_stream.output)
testing_stream.reset() testing_stream.reset()
print(f"jvp = {jax.jvp(power3, (3.,), (0.1,))}") print(f"jvp = {jax.jvp(power3, (3.,), (0.1,))}")
hcb.barrier_wait() hcb.barrier_wait()
expected = """ if hcb._HOST_CALLBACK_LEGACY.value:
what: x,x^2 expected = """
( 3. 9. )""" what: x,x^2
( 3. 9. )"""
else:
expected = """
what: x,x^2
( 3.0 9.0 )"""
self.assertMultiLineStrippedEqual(expected, testing_stream.output) self.assertMultiLineStrippedEqual(expected, testing_stream.output)
testing_stream.reset() testing_stream.reset()
@ -1197,32 +1296,41 @@ class HostCallbackTapTest(jtu.JaxTestCase):
@print_tangents.defjvp @print_tangents.defjvp
def print_tangents_jvp(primals, tangents): def print_tangents_jvp(primals, tangents):
arg_dot, = tangents arg_dot, = tangents
hcb.id_print(arg_dot, what="tangents", output_stream=testing_stream) hcb.id_print(arg_dot, what="tangents", output_stream=testing_stream,
callback_flavor=hcb.CallbackFlavor.DEBUG)
return primals, tangents return primals, tangents
def power3_with_tangents(x): def power3_with_tangents(x):
y = x * x y = x * x
# Print both 'x' and 'x^2'. Must pack as a tuple. # Print both 'x' and 'x^2'. Must pack as a tuple.
hcb.id_print((x, y), what="x,x^2", output_stream=testing_stream) hcb.id_print((x, y), what="x,x^2", output_stream=testing_stream,
callback_flavor=hcb.CallbackFlavor.DEBUG)
print_tangents((x, y)) print_tangents((x, y))
return y * x return y * x
print(f"jvp = {jax.jvp(power3_with_tangents, (3.,), (0.1,))}") print(f"jvp = {jax.jvp(power3_with_tangents, (3.,), (0.1,))}")
hcb.barrier_wait() hcb.barrier_wait()
expected = """ if hcb._HOST_CALLBACK_LEGACY.value:
what: x,x^2 expected = """
( 3. 9. ) what: x,x^2
what: tangents ( 3. 9. )
( 0.1 0.6 )""" what: tangents
self.assertMultiLineStrippedEqual(expected, testing_stream.output) ( 0.1 0.6 )"""
self.assertMultiLineStrippedEqual(expected, testing_stream.output)
testing_stream.reset() testing_stream.reset()
print(f"grad = {jax.grad(power3)(3.)}") print(f"grad = {jax.grad(power3)(3.)}")
hcb.barrier_wait() hcb.barrier_wait()
# Only the primals by default # Only the primals by default
expected = """ if hcb._HOST_CALLBACK_LEGACY.value:
what: x,x^2 expected = """
( 3. 9. )""" what: x,x^2
( 3. 9. )"""
else:
expected = """
what: x,x^2
( 3.0 9.0 )"""
self.assertMultiLineStrippedEqual(expected, testing_stream.output) self.assertMultiLineStrippedEqual(expected, testing_stream.output)
testing_stream.reset() testing_stream.reset()
@ -1236,7 +1344,8 @@ class HostCallbackTapTest(jtu.JaxTestCase):
return print_cotangents(arg), None return print_cotangents(arg), None
# f_bwd: (residual, CT b) -> [CT a] # f_bwd: (residual, CT b) -> [CT a]
def print_cotangents_bwd(residual, ct_b): def print_cotangents_bwd(residual, ct_b):
hcb.id_print(ct_b, what="cotangents", output_stream=testing_stream) hcb.id_print(ct_b, what="cotangents", output_stream=testing_stream,
callback_flavor=hcb.CallbackFlavor.DEBUG)
return ct_b, return ct_b,
print_cotangents.defvjp(print_cotangents_fwd, print_cotangents_bwd) print_cotangents.defvjp(print_cotangents_fwd, print_cotangents_bwd)
@ -1244,18 +1353,26 @@ class HostCallbackTapTest(jtu.JaxTestCase):
def power3_with_cotangents(x): def power3_with_cotangents(x):
y = x * x y = x * x
# Print both 'x' and 'x^2'. Must pack as a tuple. # Print both 'x' and 'x^2'. Must pack as a tuple.
hcb.id_print((x, y), what="x,x^2", output_stream=testing_stream) hcb.id_print((x, y), what="x,x^2", output_stream=testing_stream,
callback_flavor=hcb.CallbackFlavor.DEBUG)
# Must use the output of print_cotangents # Must use the output of print_cotangents
(x1, y1) = print_cotangents((x, y)) (x1, y1) = print_cotangents((x, y))
return y1 * x1 return y1 * x1
print(f"grad = {jax.grad(power3_with_cotangents)(3.)}") print(f"grad = {jax.grad(power3_with_cotangents)(3.)}")
hcb.barrier_wait() hcb.barrier_wait()
expected = """ if hcb._HOST_CALLBACK_LEGACY.value:
what: x,x^2 expected = """
( 3. 9. ) what: x,x^2
what: cotangents ( 3. 9. )
( 9. 3. )""" what: cotangents
( 9. 3. )"""
else:
expected = """
what: x,x^2
( 3.0 9.0 )
what: cotangents
( 9.0 3.0 )"""
self.assertMultiLineStrippedEqual(expected, testing_stream.output) self.assertMultiLineStrippedEqual(expected, testing_stream.output)
testing_stream.reset() testing_stream.reset()
@ -1263,43 +1380,82 @@ class HostCallbackTapTest(jtu.JaxTestCase):
print(f"vmap = {jax.vmap(power3)(np.array([2., 3.]))}") print(f"vmap = {jax.vmap(power3)(np.array([2., 3.]))}")
hcb.barrier_wait() hcb.barrier_wait()
expected = """ if hcb._HOST_CALLBACK_LEGACY.value:
transforms: [('batch', {'batch_dims': (0, 0)})] what: x,x^2 expected = """
( [2. 3.] [4. 9.] )""" transforms: [('batch', {'batch_dims': (0, 0)})] what: x,x^2
( [2. 3.] [4. 9.] )"""
else:
expected = """
what: x,x^2
( 2.0 4.0 )
what: x,x^2
( 3.0 9.0 )
"""
self.assertMultiLineStrippedEqual(expected, testing_stream.output) self.assertMultiLineStrippedEqual(expected, testing_stream.output)
testing_stream.reset() testing_stream.reset()
print(f"vmap o grad {jax.vmap(jax.grad(power3))(np.array([2., 3.]))}") print(f"vmap o grad {jax.vmap(jax.grad(power3))(np.array([2., 3.]))}")
hcb.barrier_wait() hcb.barrier_wait()
expected = """ if hcb._HOST_CALLBACK_LEGACY.value:
transforms: [('batch', {'batch_dims': (0, 0)})] what: x,x^2 expected = """
( [2. 3.] [4. 9.] )""" transforms: [('batch', {'batch_dims': (0, 0)})] what: x,x^2
( [2. 3.] [4. 9.] )"""
else:
expected = """
what: x,x^2
( 2.0 4.0 )
what: x,x^2
( 3.0 9.0 )
"""
self.assertMultiLineStrippedEqual(expected, testing_stream.output) self.assertMultiLineStrippedEqual(expected, testing_stream.output)
testing_stream.reset() testing_stream.reset()
print(f"vmap o grad {jax.vmap(jax.grad(power3_with_cotangents))(np.array([2., 3.]))}") print(f"vmap o grad {jax.vmap(jax.grad(power3_with_cotangents))(np.array([2., 3.]))}")
hcb.barrier_wait() hcb.barrier_wait()
expected = """ if hcb._HOST_CALLBACK_LEGACY.value:
transforms: [('batch', {'batch_dims': (0, 0)})] what: x,x^2 expected = """
( [2. 3.] [4. 9.] ) transforms: [('batch', {'batch_dims': (0, 0)})] what: x,x^2
transforms: [('batch', {'batch_dims': (0, 0)})] what: cotangents ( [2. 3.] [4. 9.] )
( [4. 9.] [2. 3.] )""" transforms: [('batch', {'batch_dims': (0, 0)})] what: cotangents
( [4. 9.] [2. 3.] )"""
else:
expected = """
what: x,x^2
( 2.0 4.0 )
what: x,x^2
( 3.0 9.0 )
what: cotangents
( 4.0 2.0 )
what: cotangents
( 9.0 3.0 )
"""
self.assertMultiLineStrippedEqual(expected, testing_stream.output) self.assertMultiLineStrippedEqual(expected, testing_stream.output)
testing_stream.reset() testing_stream.reset()
print(f"grad o remat = {jax.grad(lambda x: power3(ad_checkpoint.checkpoint(power3)(x)))(3.)}") print(f"grad o remat = {jax.grad(lambda x: power3(ad_checkpoint.checkpoint(power3)(x)))(3.)}")
hcb.barrier_wait() hcb.barrier_wait()
expected = """ if hcb._HOST_CALLBACK_LEGACY.value:
what: x,x^2 expected = """
( 3. 9. ) what: x,x^2
what: x,x^2 ( 3. 9. )
( 27. 729. ) what: x,x^2
what: x,x^2 ( 27. 729. )
( 3. 9. )""" what: x,x^2
( 3. 9. )"""
else:
expected = """
what: x,x^2
( 3.0 9.0 )
what: x,x^2
( 27.0 729.0 )
what: x,x^2
( 3.0 9.0 )
"""
self.assertMultiLineStrippedEqual(expected, testing_stream.output) self.assertMultiLineStrippedEqual(expected, testing_stream.output)
testing_stream.reset() testing_stream.reset()
def test_tap_pmap(self): def test_tap_pmap(self):
self.supported_only_in_legacy_mode()
if len(local_devices()) < 2: if len(local_devices()) < 2:
raise SkipTest("test requires at least 2 devices") raise SkipTest("test requires at least 2 devices")
@ -1326,6 +1482,7 @@ class HostCallbackTapTest(jtu.JaxTestCase):
( 4 16 )""") ( 4 16 )""")
def test_tap_pmap_vmap(self): def test_tap_pmap_vmap(self):
self.supported_only_in_legacy_mode()
# A matrix M[ij] = i * 10 + j # A matrix M[ij] = i * 10 + j
nr_devices = len(local_devices()) nr_devices = len(local_devices())
shape = (nr_devices, 3) shape = (nr_devices, 3)
@ -1353,6 +1510,7 @@ class HostCallbackTapTest(jtu.JaxTestCase):
def test_tap_pmap_pmap_vmap(self): def test_tap_pmap_pmap_vmap(self):
# A matrix M[ijk] = i * 100 + j * 10 + k # A matrix M[ijk] = i * 100 + j * 10 + k
self.supported_only_in_legacy_mode()
nr_devices = len(local_devices()) nr_devices = len(local_devices())
if nr_devices % 2 != 0: if nr_devices % 2 != 0:
raise SkipTest("test works only on even number of devices") raise SkipTest("test works only on even number of devices")
@ -1386,6 +1544,7 @@ class HostCallbackTapTest(jtu.JaxTestCase):
def test_tap_pmap_pmap_extra(self): def test_tap_pmap_pmap_extra(self):
"""pmap of a pmap surrounded by extra code.""" """pmap of a pmap surrounded by extra code."""
# A matrix M[ij] = i * 10 + j # A matrix M[ij] = i * 10 + j
self.supported_only_in_legacy_mode()
nr_devices = len(local_devices()) nr_devices = len(local_devices())
if nr_devices != 2: if nr_devices != 2:
raise SkipTest("test works only on 2 devices") raise SkipTest("test works only on 2 devices")
@ -1419,6 +1578,7 @@ class HostCallbackTapTest(jtu.JaxTestCase):
[[203.00 205.00 207.00]]""") [[203.00 205.00 207.00]]""")
def test_tap_jvp_pmap_vmap(self): def test_tap_jvp_pmap_vmap(self):
self.supported_only_in_legacy_mode()
# A matrix M[ijk] = i * 100 + j * 10 * k # A matrix M[ijk] = i * 100 + j * 10 * k
nr_devices = len(local_devices()) nr_devices = len(local_devices())
shape = (nr_devices, 2, 3) shape = (nr_devices, 2, 3)
@ -1445,6 +1605,7 @@ class HostCallbackTapTest(jtu.JaxTestCase):
[220.00 222.00 224.00]]""") [220.00 222.00 224.00]]""")
def test_tap_vmap_pmap(self): def test_tap_vmap_pmap(self):
self.supported_only_in_legacy_mode()
# A matrix M[ijk] = i * 100 + j * 10 * k # A matrix M[ijk] = i * 100 + j * 10 * k
nr_devices = len(local_devices()) nr_devices = len(local_devices())
shape = (2, nr_devices, 3) shape = (2, nr_devices, 3)
@ -1472,6 +1633,7 @@ class HostCallbackTapTest(jtu.JaxTestCase):
@ignore_jit_of_pmap_warning() @ignore_jit_of_pmap_warning()
def test_tap_jit_pmap_extra(self): def test_tap_jit_pmap_extra(self):
"""jit of a pmap surrounded by extra code.""" """jit of a pmap surrounded by extra code."""
self.supported_only_in_legacy_mode()
# A matrix M[ij] = i * 10 + j # A matrix M[ij] = i * 10 + j
nr_devices = len(local_devices()) nr_devices = len(local_devices())
assert nr_devices in (1, 2) assert nr_devices in (1, 2)
@ -1540,6 +1702,7 @@ class HostCallbackTapTest(jtu.JaxTestCase):
@jtu.sample_product(device_index=[0, 1]) @jtu.sample_product(device_index=[0, 1])
def test_tap_pjit(self, device_index=0): def test_tap_pjit(self, device_index=0):
self.supported_only_in_legacy_mode()
if (device_index != 0 and if (device_index != 0 and
not hcb._HOST_CALLBACK_OUTFEED.value and not hcb._HOST_CALLBACK_OUTFEED.value and
jtu.test_device_matches(["cpu"])): jtu.test_device_matches(["cpu"])):
@ -1589,7 +1752,7 @@ class HostCallbackTapTest(jtu.JaxTestCase):
def test_tap_scan_custom_jvp(self): def test_tap_scan_custom_jvp(self):
"""custom JVP, inside scan. """custom JVP, inside scan.
This exercises the custom_jvp_call_jaxpr primitives.""" This exercises the custom_jvp_call_jaxpr primitives."""
self.supported_only_in_legacy_mode()
@jax.custom_jvp @jax.custom_jvp
def f(x): def f(x):
return x * hcb.id_print(x, output_stream=testing_stream, what="x") return x * hcb.id_print(x, output_stream=testing_stream, what="x")
@ -1633,7 +1796,7 @@ class HostCallbackTapTest(jtu.JaxTestCase):
def test_tap_scan_custom_vjp(self): def test_tap_scan_custom_vjp(self):
"""custom VJP, inside scan. """custom VJP, inside scan.
This exercises the custom_vjp_call_jaxpr primitives.""" This exercises the custom_vjp_call_jaxpr primitives."""
self.supported_only_in_legacy_mode()
@jax.custom_vjp @jax.custom_vjp
def f(x): def f(x):
return x * hcb.id_print(x, output_stream=testing_stream, what="x") return x * hcb.id_print(x, output_stream=testing_stream, what="x")
@ -1773,7 +1936,7 @@ class HostCallbackTapTest(jtu.JaxTestCase):
from jax.experimental.ode import odeint from jax.experimental.ode import odeint
def f(x, t, k): def f(x, t, k):
x = hcb.id_print(x) x = hcb.id_print(x, callback_flavor=hcb.CallbackFlavor.DEBUG)
return -k * x return -k * x
def loss(k=1.0): def loss(k=1.0):
@ -1785,7 +1948,8 @@ class HostCallbackTapTest(jtu.JaxTestCase):
def test_tap_remat_0(self): def test_tap_remat_0(self):
def f(i, k): def f(i, k):
x = hcb.id_print(k + i, output_stream=testing_stream) x = hcb.id_print(k + i, output_stream=testing_stream,
callback_flavor=hcb.CallbackFlavor.DEBUG)
return k * x return k * x
def loss(k): def loss(k):
@ -1804,6 +1968,7 @@ class HostCallbackTapTest(jtu.JaxTestCase):
use_remat=["old", "new", "none"], use_remat=["old", "new", "none"],
) )
def test_tap_remat(self, use_result=False, grad_func="grad", use_remat="new"): def test_tap_remat(self, use_result=False, grad_func="grad", use_remat="new"):
self.supported_only_in_legacy_mode()
if use_remat == "old": raise SkipTest() if use_remat == "old": raise SkipTest()
def f(x): def f(x):
@ -1880,6 +2045,10 @@ class HostCallbackCallTest(jtu.JaxTestCase):
hcb.barrier_wait("HostCallbackCallTest.tearDown") hcb.barrier_wait("HostCallbackCallTest.tearDown")
super().tearDown() super().tearDown()
def supported_only_in_legacy_mode(self):
if not hcb._HOST_CALLBACK_LEGACY.value:
self.skipTest("Not supported when JAX_HOST_CALLBACK_LEGACY=False")
def call_log_testing_stream(self, func, arg, *, result_shape, name=""): def call_log_testing_stream(self, func, arg, *, result_shape, name=""):
"""Call `func` and log inputs and outputs to the testing stream""" """Call `func` and log inputs and outputs to the testing stream"""
@ -1916,6 +2085,7 @@ class HostCallbackCallTest(jtu.JaxTestCase):
with jtu.count_primitive_compiles() as count: with jtu.count_primitive_compiles() as count:
for _ in range(3): for _ in range(3):
self.assertAllClose(2 * arg, fun(arg)) self.assertAllClose(2 * arg, fun(arg))
r = jax.make_jaxpr(fun)(arg)
self.assertEqual(count[0], 1) self.assertEqual(count[0], 1)
@jtu.sample_product( @jtu.sample_product(
@ -2124,6 +2294,7 @@ class HostCallbackCallTest(jtu.JaxTestCase):
helper_print_optimized_hlo(fun2, m) helper_print_optimized_hlo(fun2, m)
def test_call_with_device(self): def test_call_with_device(self):
self.supported_only_in_legacy_mode()
def callback_func(x, device=None): def callback_func(x, device=None):
testing_stream.write(f"device: {device}\n Called with {x}") testing_stream.write(f"device: {device}\n Called with {x}")
return x return x
@ -2139,6 +2310,7 @@ class HostCallbackCallTest(jtu.JaxTestCase):
Called with 3.00""") Called with 3.00""")
def test_call_pmap(self): def test_call_pmap(self):
self.supported_only_in_legacy_mode()
# Works for 1 or 2 devices # Works for 1 or 2 devices
def callback_func(x, device=None): def callback_func(x, device=None):
testing_stream.write(f"device: {device}\n Called with {x}") testing_stream.write(f"device: {device}\n Called with {x}")
@ -2163,10 +2335,14 @@ class HostCallbackCallTest(jtu.JaxTestCase):
def f_outside(x): return x def f_outside(x): return x
def fun(x): def fun(x):
return hcb.call(f_outside, x, result_shape=x) return hcb.call(f_outside, x, result_shape=x,
callback_flavor=hcb.CallbackFlavor.PURE)
with self.assertRaisesRegex(NotImplementedError, if hcb._HOST_CALLBACK_LEGACY.value:
"batching rules are implemented only for id_tap, not for call"): with self.assertRaisesRegex(NotImplementedError,
"batching rules are implemented only for id_tap, not for call"):
jax.vmap(fun)(np.ones((2, 3)))
else:
jax.vmap(fun)(np.ones((2, 3))) jax.vmap(fun)(np.ones((2, 3)))
@jtu.sample_product(device_index=[0, 1]) @jtu.sample_product(device_index=[0, 1])
@ -2256,6 +2432,7 @@ class HostCallbackCallTest(jtu.JaxTestCase):
hcb.barrier_wait("Waiting for error") hcb.barrier_wait("Waiting for error")
def test_call_error_callback_throws_exception(self): def test_call_error_callback_throws_exception(self):
self.supported_only_in_legacy_mode()
def f_outside(x): def f_outside(x):
raise ValueError("user exception") raise ValueError("user exception")
def fun(x): def fun(x):
@ -2265,6 +2442,7 @@ class HostCallbackCallTest(jtu.JaxTestCase):
"ValueError: user exception") "ValueError: user exception")
def test_call_error_callback_returns_unexpected_shape(self): def test_call_error_callback_returns_unexpected_shape(self):
self.supported_only_in_legacy_mode()
def fun(x): def fun(x):
return hcb.call(lambda x: (x, x), x, result_shape=x) return hcb.call(lambda x: (x, x), x, result_shape=x)
@ -2272,6 +2450,7 @@ class HostCallbackCallTest(jtu.JaxTestCase):
"Callback func .* should have returned a result with pytree") "Callback func .* should have returned a result with pytree")
def test_call_error_then_compute(self): def test_call_error_then_compute(self):
self.supported_only_in_legacy_mode()
# Continue computation on device after error # Continue computation on device after error
def f_outside(x): def f_outside(x):
raise ValueError("user exception") raise ValueError("user exception")
@ -2283,7 +2462,9 @@ class HostCallbackCallTest(jtu.JaxTestCase):
"ValueError: user exception") "ValueError: user exception")
def call_jax_other_device(jax_outside_fun, arg, *, device): def call_jax_other_device(
jax_outside_fun, arg, *, device,
callback_flavor: hcb.CallbackFlavor = hcb.CallbackFlavor.IO_CALLBACK):
"""Calls a JAX function on a specific device with simple support for reverse AD. """Calls a JAX function on a specific device with simple support for reverse AD.
Functions whose name starts with "jax_outside" are called on another device, Functions whose name starts with "jax_outside" are called on another device,
@ -2296,7 +2477,8 @@ def call_jax_other_device(jax_outside_fun, arg, *, device):
@jax.custom_vjp @jax.custom_vjp
def make_call(arg): def make_call(arg):
return hcb.call(run_jax_outside_fun, arg, return hcb.call(run_jax_outside_fun, arg,
result_shape=jax.eval_shape(jax_outside_fun, arg)) result_shape=jax.eval_shape(jax_outside_fun, arg),
callback_flavor=callback_flavor)
# Define the fwd and bwd custom_vjp functions # Define the fwd and bwd custom_vjp functions
def make_call_vjp_fwd(arg): def make_call_vjp_fwd(arg):
@ -2323,6 +2505,8 @@ class CallJaxTest(jtu.JaxTestCase):
"""Tests using `call_jax_other_device`.""" """Tests using `call_jax_other_device`."""
def setUp(self): def setUp(self):
if not hcb._HOST_CALLBACK_LEGACY.value:
self.skipTest("Not supported when JAX_HOST_CALLBACK_LEGACY=False")
if jtu.test_device_matches(["gpu"]) and jax.device_count() > 1: if jtu.test_device_matches(["gpu"]) and jax.device_count() > 1:
raise SkipTest("host_callback broken on multi-GPU platforms (#6447)") raise SkipTest("host_callback broken on multi-GPU platforms (#6447)")
if xla_bridge.using_pjrt_c_api(): if xla_bridge.using_pjrt_c_api():
@ -2337,6 +2521,7 @@ class CallJaxTest(jtu.JaxTestCase):
self.outside_device = jax.devices("cpu")[1] self.outside_device = jax.devices("cpu")[1]
super().setUp() super().setUp()
def test_jax_impl(self): def test_jax_impl(self):
def f_jax(x): def f_jax(x):
return jnp.sin(x) return jnp.sin(x)
@ -2404,6 +2589,10 @@ class OutfeedRewriterTest(jtu.JaxTestCase):
raise SkipTest("host_callback not implemented in PJRT C API") raise SkipTest("host_callback not implemented in PJRT C API")
super().setUp() super().setUp()
def supported_only_in_legacy_mode(self):
if not hcb._HOST_CALLBACK_LEGACY.value:
self.skipTest("Not supported when JAX_HOST_CALLBACK_LEGACY=False")
def assertRewrite(self, expected: str, func: Callable, args: Sequence, def assertRewrite(self, expected: str, func: Callable, args: Sequence,
has_input_token=True, has_output_token=True): has_input_token=True, has_output_token=True):
"""Check that the rewrite of func(*args) matches expected.""" """Check that the rewrite of func(*args) matches expected."""
@ -2624,7 +2813,7 @@ class OutfeedRewriterTest(jtu.JaxTestCase):
def test_scan_custom_jvp(self): def test_scan_custom_jvp(self):
"""custom JVP, inside scan. """custom JVP, inside scan.
This exercises the custom_jvp_call_jaxpr primitives.""" This exercises the custom_jvp_call_jaxpr primitives."""
self.supported_only_in_legacy_mode()
@jax.custom_jvp @jax.custom_jvp
def f(x): def f(x):
return x * hcb.id_print(x) return x * hcb.id_print(x)
@ -2706,7 +2895,7 @@ class OutfeedRewriterTest(jtu.JaxTestCase):
def test_scan_custom_vjp(self): def test_scan_custom_vjp(self):
"""custom VJP, inside scan. """custom VJP, inside scan.
This exercises the custom_vjp_call_jaxpr primitives.""" This exercises the custom_vjp_call_jaxpr primitives."""
self.supported_only_in_legacy_mode()
@jax.custom_vjp @jax.custom_vjp
def f(x): def f(x):
return x * hcb.id_print(x) return x * hcb.id_print(x)
@ -2849,6 +3038,7 @@ class OutfeedRewriterTest(jtu.JaxTestCase):
in (c, d, e) }""", tap_scalar, [np.int32(3)]) in (c, d, e) }""", tap_scalar, [np.int32(3)])
def test_pmap(self): def test_pmap(self):
self.supported_only_in_legacy_mode()
def f(xv): def f(xv):
jax.pmap(lambda x: jnp.sin(hcb.id_print(x, tap_with_device=True)), jax.pmap(lambda x: jnp.sin(hcb.id_print(x, tap_with_device=True)),
axis_name="i")(xv) axis_name="i")(xv)

View File

@ -25,8 +25,8 @@ from absl.testing import absltest
from absl.testing import parameterized from absl.testing import parameterized
import jax import jax
from jax import config
from jax import numpy as jnp from jax import numpy as jnp
from jax._src import config
from jax._src import test_util as jtu from jax._src import test_util as jtu
from jax._src import xla_bridge from jax._src import xla_bridge
from jax.experimental import host_callback as hcb from jax.experimental import host_callback as hcb
@ -53,7 +53,8 @@ def call_tf_no_ad(tf_fun: Callable, arg, *, result_shape):
return hcb.call(lambda arg: tf.nest.map_structure(tf_to_numpy, return hcb.call(lambda arg: tf.nest.map_structure(tf_to_numpy,
tf_fun(arg)), tf_fun(arg)),
arg, result_shape=result_shape) arg, result_shape=result_shape,
callback_flavor=hcb.CallbackFlavor.DEBUG)
def call_tf_simple_ad(tf_fun: Callable, arg, *, result_shape): def call_tf_simple_ad(tf_fun: Callable, arg, *, result_shape):
@ -166,12 +167,17 @@ class CallToTFTest(jtu.JaxTestCase):
raise unittest.SkipTest("host_callback not implemented in PJRT C API") raise unittest.SkipTest("host_callback not implemented in PJRT C API")
super().setUp() super().setUp()
def supported_only_in_legacy_mode(self):
if not hcb._HOST_CALLBACK_LEGACY.value:
self.skipTest("Not supported when JAX_HOST_CALLBACK_LEGACY=False")
@parameterized.named_parameters( @parameterized.named_parameters(
dict( dict(
testcase_name=f"_{ad=}", testcase_name=f"_{ad=}",
ad=ad) ad=ad)
for ad in CALL_TF_IMPLEMENTATIONS.keys()) for ad in CALL_TF_IMPLEMENTATIONS.keys())
def test_impl(self, ad="simple"): def test_impl(self, ad="simple"):
self.supported_only_in_legacy_mode()
call_tf = CALL_TF_IMPLEMENTATIONS[ad] call_tf = CALL_TF_IMPLEMENTATIONS[ad]
def f_jax(x): def f_jax(x):
@ -192,21 +198,27 @@ class CallToTFTest(jtu.JaxTestCase):
for ad in CALL_TF_IMPLEMENTATIONS.keys() for ad in CALL_TF_IMPLEMENTATIONS.keys()
if ad != "none") if ad != "none")
def test_grad(self, ad="simple"): def test_grad(self, ad="simple"):
self.supported_only_in_legacy_mode()
call_tf = CALL_TF_IMPLEMENTATIONS[ad] call_tf = CALL_TF_IMPLEMENTATIONS[ad]
def f_jax(x): def f_jax(x):
return 3. * jnp.sin(2. * x) return 3. * jnp.sin(2. * x)
def f_outside(x): def f_outside(x):
return 3. * call_tf(tf.math.sin, 2. * x, result_shape=x) return 3. * call_tf(
lambda x: tf.cast(tf.math.sin(x), tf.float32), 2. * x,
result_shape=jax.ShapeDtypeStruct((), np.float32))
x = 4. x = np.float32(4.)
self.assertAllClose(f_jax(x), f_outside(x)) self.assertAllClose(f_jax(x), f_outside(x),
check_dtypes=False)
grad_f = jax.grad(f_outside)(x) grad_f = jax.grad(f_outside)(x)
self.assertAllClose(jax.grad(f_jax)(x), grad_f) self.assertAllClose(jax.grad(f_jax)(x), grad_f,
check_dtypes=False)
def test_grad_pytree(self): def test_grad_pytree(self):
self.supported_only_in_legacy_mode()
call_tf = call_tf_full_ad call_tf = call_tf_full_ad
def f_jax(xy): def f_jax(xy):
@ -215,15 +227,19 @@ class CallToTFTest(jtu.JaxTestCase):
def f_outside(xy): def f_outside(xy):
dict_ab = call_tf( dict_ab = call_tf(
lambda xy: dict(a=2. * xy[0], b=xy[0] * xy[1]), lambda xy: dict(a=tf.cast(2. * xy[0], np.float32),
b=tf.cast(xy[0] * xy[1], np.float32)),
xy, xy,
result_shape=dict(a=xy[0], b=xy[1])) result_shape=dict(a=jax.ShapeDtypeStruct((), np.float32),
b=jax.ShapeDtypeStruct((), np.float32)))
return 3. * dict_ab["a"] + 4. * dict_ab["b"] return 3. * dict_ab["a"] + 4. * dict_ab["b"]
xy = (5., 6.) xy = (5., 6.)
self.assertAllClose(f_jax(xy), f_outside(xy)) self.assertAllClose(f_jax(xy), f_outside(xy),
check_dtypes=False)
res_jax = jax.grad(f_jax)(xy) res_jax = jax.grad(f_jax)(xy)
self.assertAllClose(res_jax, jax.grad(f_outside)(xy)) self.assertAllClose(res_jax, jax.grad(f_outside)(xy),
check_dtypes=False)
@parameterized.named_parameters( @parameterized.named_parameters(
dict( dict(
@ -231,6 +247,7 @@ class CallToTFTest(jtu.JaxTestCase):
degree=degree) degree=degree)
for degree in [1, 2, 3, 4]) for degree in [1, 2, 3, 4])
def test_higher_order_grad(self, degree=4): def test_higher_order_grad(self, degree=4):
self.supported_only_in_legacy_mode()
call_tf = call_tf_full_ad call_tf = call_tf_full_ad
def f_jax(x): def f_jax(x):

View File

@ -247,7 +247,7 @@ class PythonCallbackTest(jtu.JaxTestCase):
jax.effects_barrier() jax.effects_barrier()
@with_pure_and_io_callbacks @with_pure_and_io_callbacks
def test_callback_with_wrong_dtype_outputs(self, *, callback=io_callback_ordered): def test_callback_with_wrong_dtype_outputs(self, *, callback):
def _cb(): def _cb():
return np.array([1], np.float64) return np.array([1], np.float64)