[callback] Add a flag to implement host_callback in terms of io_callback.

The host_callbacks APIs are deprecated and will be removed. In order to
help the transition to the new APIs, we add a flag (`JAX_HOST_CALLBACK_LEGACY`)
that when set to `False` will use `io_callback` (and `pure_callback` and
`jax.debug.callback`) to implement the host_callback APIs.

See issue #20385 for more details.

We change the tests to accomodate slightly different results when using
the new callbacks. The tests that use `tap_with_device` and `call_with_device`
are disabled when using the new callbacks.
This commit is contained in:
George Necula 2024-03-29 13:36:20 +02:00
parent 2512843a56
commit a510f03ef8
6 changed files with 476 additions and 177 deletions

View File

@ -50,6 +50,8 @@ Remember to align the itemized text with the first line of an item within a list
`spmd_axis_name` argument for expressing SPMD device-parallel computations.
* The `jax.experimental.host_callback` module is deprecated.
Use instead the [new JAX external callbacks](https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html).
Added `JAX_HOST_CALLBACK_LEGACY` flag to assist in the transition to the
new callbacks. See {jax-issue}`#20385` for a discussion.
* Passing arguments to {func}`jax.numpy.array_equal` and {func}`jax.numpy.array_equiv`
that cannot be converted to a JAX array now results in an exception.
* The deprecated flag `jax_parallel_functions_output_gda` has been removed.
@ -1451,7 +1453,7 @@ Changes:
special autodiff handling for hcb.id_tap and id_print.
From now on, only the primals are tapped. The old behavior can be
obtained (for a limited time) by setting the ``JAX_HOST_CALLBACK_AD_TRANSFORMS``
environment variable, or the ```--flax_host_callback_ad_transforms``` flag.
environment variable, or the ```--jax_host_callback_ad_transforms``` flag.
Additionally, added documentation for how to implement the old behavior
using JAX custom AD APIs ({jax-issue}`#8678`).
* Sorting now matches the behavior of NumPy for ``0.0`` and ``NaN`` regardless of the

View File

@ -997,7 +997,11 @@ pytype_library(
pytype_library(
name = "experimental_host_callback",
srcs = ["experimental/host_callback.py"],
srcs = [
"experimental/__init__.py", # To support JAX_HOST_CALLBACK_LEGACY=False
"experimental/host_callback.py",
"experimental/x64_context.py", # To support JAX_HOST_CALLBACK_LEGACY=False
],
visibility = ["//visibility:public"],
deps = [
":jax",

View File

@ -17,6 +17,7 @@
The host_callback APIs are deprecated as of March 20, 2024.
The functionality is subsumed by the
`new JAX external callbacks <https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html>`_
See https://github.com/google/jax/issues/20385.
This module introduces the host callback functions :func:`call`,
:func:`id_tap`, and :func:`id_print`, that send their arguments from the device
@ -501,6 +502,7 @@ Still to do:
from __future__ import annotations
import atexit
import enum
from collections.abc import Sequence
import functools
import itertools
@ -510,6 +512,7 @@ import threading
import traceback
from typing import Any, Callable, cast
import jax
from jax._src import api
from jax._src import core
from jax._src import config
@ -517,6 +520,7 @@ from jax import custom_derivatives
from jax._src import dtypes
from jax import lax
from jax.experimental import pjit
from jax.experimental import io_callback
from jax._src.interpreters import ad, batching, pxla
from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
@ -560,6 +564,15 @@ _HOST_CALLBACK_OUTFEED = config.DEFINE_bool(
'Has no effect on TPU, since only the outfeed mechanism is implemented.'
)
)
_HOST_CALLBACK_LEGACY = config.DEFINE_bool(
'jax_host_callback_legacy',
config.bool_env('JAX_HOST_CALLBACK_LEGACY', True),
help=(
'Use old implementation of host_callback, documented in the module docstring.'
'If False, use the jax.experimental.io_callback implementation. '
'See https://github.com/google/jax/issues/20385.'
)
)
logger = logging.getLogger(__name__)
@ -591,6 +604,15 @@ XlaDevice = xla_client.Device
XlaLocalClient = xla_client.Client
DType = Any
class CallbackFlavor(enum.Enum):
"""Specifies which flavor of callback to use under JAX_HOST_CALLBACK_LEGACY=False.
See https://github.com/google/jax/issues/20385.
"""
IO_CALLBACK = 1 # uses jax.experimental.io_callback
PURE = 2 # uses jax.pure_callback
DEBUG = 3 # uses jax.debug.callback, valid only when there are no results
def _deprecated_id_tap(tap_func,
arg,
@ -598,6 +620,7 @@ def _deprecated_id_tap(tap_func,
result=None,
tap_with_device=False,
device_index=0,
callback_flavor=CallbackFlavor.IO_CALLBACK,
**kwargs):
"""Host-callback tap primitive, like identity function with a call to ``tap_func``.
@ -605,6 +628,7 @@ def _deprecated_id_tap(tap_func,
The host_callback APIs are deprecated as of March 20, 2024.
The functionality is subsumed by the
`new JAX external callbacks <https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html>`_
See https://github.com/google/jax/issues/20385.
``id_tap`` behaves semantically like the identity function but has the
side-effect that a user-defined Python function is called with the runtime
@ -628,6 +652,9 @@ def _deprecated_id_tap(tap_func,
device_index: specifies from which device the tap function is invoked in a
SPMD program. Works only when using the outfeed implementation mechanism,
i.e., does not work on CPU unless --jax_host_callback_outfeed=True.
callback_flavor: if running with `JAX_HOST_CALLBACK_LEGACY=False` specifies
the flavor of callback to use.
See https://github.com/google/jax/issues/20385.
Returns:
``arg``, or ``result`` if given.
@ -660,7 +687,8 @@ def _deprecated_id_tap(tap_func,
call_with_device=tap_with_device,
result_shape=None,
identity=True,
device_index=device_index)
device_index=device_index,
callback_flavor=callback_flavor)
if result is not None:
return result
@ -675,6 +703,7 @@ def _deprecated_id_print(arg,
device_index=0,
output_stream=None,
threshold=None,
callback_flavor=CallbackFlavor.IO_CALLBACK,
**kwargs):
"""Like :func:`id_tap` with a printing tap function.
@ -682,6 +711,7 @@ def _deprecated_id_print(arg,
The host_callback APIs are deprecated as of March 20, 2024.
The functionality is subsumed by the
`new JAX external callbacks <https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html>`_
See https://github.com/google/jax/issues/20385.
On each invocation of the printing tap, the ``kwargs`` if present
will be printed first (sorted by keys). Then arg will be printed,
@ -697,6 +727,9 @@ def _deprecated_id_print(arg,
built-in ``print``. The string will be passed as
``output_stream.write(s)``.
* ``threshold`` is passed to ``numpy.array2string``.
* ``callback_flavor``: if running with `JAX_HOST_CALLBACK_LEGACY=False` specifies
the flavor of callback to use.
See https://github.com/google/jax/issues/20385.
For more details see the :mod:`jax.experimental.host_callback` module documentation.
"""
@ -708,19 +741,22 @@ def _deprecated_id_print(arg,
arg,
result=result,
tap_with_device=tap_with_device,
device_index=device_index)
device_index=device_index,
callback_flavor=callback_flavor)
def _deprecated_call(callback_func: Callable, arg, *,
result_shape=None,
call_with_device=False,
device_index=0):
device_index=0,
callback_flavor=CallbackFlavor.IO_CALLBACK):
"""Make a call to the host, and expect a result.
.. warning::
The host_callback APIs are deprecated as of March 20, 2024.
The functionality is subsumed by the
`new JAX external callbacks <https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html>`_
See https://github.com/google/jax/issues/20385.
Args:
callback_func: The Python function to invoke on the host as
@ -748,14 +784,26 @@ def _deprecated_call(callback_func: Callable, arg, *,
device_index: specifies from which device the tap function is invoked in a
SPMD program. Works only when using the outfeed implementation mechanism,
i.e., does not work on CPU unless --jax_host_callback_outfeed=True.
callback_flavor: if running with `JAX_HOST_CALLBACK_LEGACY=False` specifies
the flavor of callback to use.
See https://github.com/google/jax/issues/20385.
Returns:
the result of the ``callback_func`` invocation.
For more details see the :mod:`jax.experimental.host_callback` module documentation.
"""
if (not _HOST_CALLBACK_LEGACY.value and
callback_flavor is CallbackFlavor.DEBUG and
result_shape is not None):
raise NotImplementedError(
"When using JAX_HOST_CALLBACK_LEGACY=False you can use the `DEBUG` "
"flavor of callback only when the `result_shape` is None. "
"See https://github.com/google/jax/issues/20385."
)
return _call(callback_func, arg, result_shape=result_shape,
call_with_device=call_with_device, identity=False,
device_index=device_index)
device_index=device_index, callback_flavor=callback_flavor)
# We need the wrapper function to have hash and equality defined since it is
@ -766,6 +814,11 @@ class _CallbackWrapper:
self.callback_func = callback_func
self.identity = identity
self.call_with_device = call_with_device
if not _HOST_CALLBACK_LEGACY.value and call_with_device:
raise NotImplementedError(
"When using JAX_HOST_CALLBACK_LEGACY=False, the host_callback APIs"
" do not support `tap_with_device` and `call_with_device`. "
"See https://github.com/google/jax/issues/20385.")
def __hash__(self):
return hash((self.callback_func, self.identity, self.call_with_device))
@ -775,7 +828,16 @@ class _CallbackWrapper:
self.identity == other.identity and
self.call_with_device == other.call_with_device)
def __call__(self, arg, device, transforms):
def __call__(self, *args, **kwargs):
if _HOST_CALLBACK_LEGACY.value:
return self._call_legacy(*args, **kwargs)
else:
if self.identity:
# For id_tap, we pass empty transforms, for backwards compatibility
return self.callback_func(args[0], ())
return self.callback_func(*args, **kwargs)
def _call_legacy(self, arg, device, transforms):
if self.identity:
# For id_tap, we pass the transforms, for backwards compatibility
if self.call_with_device:
@ -797,14 +859,16 @@ def _call(callback_func: Callable,
result_shape=None,
call_with_device=False,
device_index=0,
identity=False):
# Lazy initialization
_initialize_outfeed_receiver(
max_callback_queue_size_bytes=_HOST_CALLBACK_MAX_QUEUE_BYTE_SIZE.value)
identity=False,
callback_flavor=CallbackFlavor.IO_CALLBACK):
if _HOST_CALLBACK_LEGACY.value:
# Lazy initialization
_initialize_outfeed_receiver(
max_callback_queue_size_bytes=_HOST_CALLBACK_MAX_QUEUE_BYTE_SIZE.value)
api.check_callable(callback_func)
flat_args, arg_treedef = tree_util.tree_flatten(arg)
for arg in flat_args:
dispatch.check_arg(arg)
for arg_ in flat_args:
dispatch.check_arg(arg_)
# See definition of outside_call_p for what parameters it takes
params: dict[str, Any] = {}
# TODO: wrap function
@ -829,8 +893,27 @@ def _call(callback_func: Callable,
params["result_treedef"] = result_treedef
params["flat_results_aval"] = tuple(flat_results_aval)
flat_results = outside_call_p.bind(*flat_args, **params)
return result_treedef.unflatten(flat_results) if not identity else arg_treedef.unflatten(flat_results)
if _HOST_CALLBACK_LEGACY.value:
flat_results = outside_call_p.bind(*flat_args, **params)
return result_treedef.unflatten(flat_results) if not identity else arg_treedef.unflatten(flat_results)
else:
callback_device = jax.local_devices()[device_index]
sharding = jax.sharding.SingleDeviceSharding(callback_device)
callback_func = _CallbackWrapper(callback_func, identity,
call_with_device)
if callback_flavor is CallbackFlavor.DEBUG:
assert identity
jax.debug.callback(callback_func, arg)
return arg
elif callback_flavor is CallbackFlavor.PURE:
call_res = jax.pure_callback(callback_func, result_shape, arg,
sharding=sharding)
else:
call_res = io_callback(callback_func, result_shape, arg,
sharding=sharding,
ordered=True)
return call_res if not identity else arg
# We need the lock for when we use the CustomCall implementation of callbacks.
@ -855,7 +938,6 @@ def _print_tap_func(
threshold: the value of numpy.array2string threshold parameter.
**kwargs: all other keyword args are printed before printing `arg`.
"""
def emit_str(s: str):
if output_stream is not None:
output_stream.write(s + "\n")
@ -1844,6 +1926,10 @@ def _deprecated_barrier_wait(logging_name: str | None = None):
For more details see the :mod:`jax.experimental.host_callback` module documentation.
"""
if not _HOST_CALLBACK_LEGACY.value:
jax.effects_barrier()
return
logging_name = logging_name or ""
logger.debug("barrier_wait[%s]: start", logging_name)
@ -1907,7 +1993,7 @@ def _deprecated_stop_outfeed_receiver():
_deprecation_msg = (
"The host_callback APIs are deprecated as of March 20, 2024. The functionality "
"is subsumed by the new JAX external callbacks. "
"See https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html.")
"See https://github.com/google/jax/issues/20385.")
_deprecations = {
# Added March 20, 2024

View File

@ -91,8 +91,10 @@ testing_stream = _TestingOutputStream()
def fun1(a):
"""Function used for several `id_tap` tests."""
y = hcb.id_print(a * 2., what="a * 2", output_stream=testing_stream)
y = hcb.id_print(y * 3., what="y * 3", output_stream=testing_stream, result=y)
y = hcb.id_print(a * 2., what="a * 2", output_stream=testing_stream,
callback_flavor=hcb.CallbackFlavor.DEBUG)
y = hcb.id_print(y * 3., what="y * 3", output_stream=testing_stream, result=y,
callback_flavor=hcb.CallbackFlavor.DEBUG)
return y ** 2 # Some computation to make the gradient interesting
@ -253,6 +255,10 @@ class HostCallbackTapTest(jtu.JaxTestCase):
hcb.barrier_wait("HostCallbackTapTest.tearDown")
super().tearDown()
def supported_only_in_legacy_mode(self):
if not hcb._HOST_CALLBACK_LEGACY.value:
self.skipTest("Not supported when JAX_HOST_CALLBACK_LEGACY=False")
def test_tap_eval(self):
self.assertAllClose((5. * 2.) ** 2, fun1(5.))
hcb.barrier_wait()
@ -320,6 +326,7 @@ class HostCallbackTapTest(jtu.JaxTestCase):
testing_stream.output)
def test_tap_with_device(self):
self.supported_only_in_legacy_mode()
def func2(x):
x1 = hcb.id_print((x * 2., x * 3.), result=x * 4.,
output_stream=testing_stream,
@ -335,6 +342,7 @@ class HostCallbackTapTest(jtu.JaxTestCase):
def test_tap_eval_exception(self):
if not hcb._HOST_CALLBACK_OUTFEED.value:
raise SkipTest("TODO: implement error handling for customcall")
# Simulate a tap error
def tap_err(*args, **kwargs):
raise ValueError("Some user message")
@ -345,19 +353,30 @@ class HostCallbackTapTest(jtu.JaxTestCase):
x3 = hcb.id_print(x2 + 1, what="x3", output_stream=testing_stream)
return x3
with self.assertRaisesRegex(
hcb.CallbackException,
re.compile("There were exceptions during callback processing. Last one was:.*"
"ValueError: Some user message", re.DOTALL)):
if hcb._HOST_CALLBACK_LEGACY.value:
ctx = self.assertRaisesRegex(
hcb.CallbackException,
re.compile("There were exceptions during callback processing. Last one was:.*"
"ValueError: Some user message", re.DOTALL))
else:
ctx = self.assertRaisesRegex(Exception, "Some user message")
with ctx:
func(0)
hcb.barrier_wait()
# We should have received everything before the error
assertMultiLineStrippedEqual(self, """
what: x1
1
what: x3
3""", testing_stream.output)
if hcb._HOST_CALLBACK_LEGACY.value:
# We should have received everything before the error
assertMultiLineStrippedEqual(self, """
what: x1
1
what: x3
3""", testing_stream.output)
else:
# We should have received everything before the error
assertMultiLineStrippedEqual(self, """
what: x1
1""", testing_stream.output)
def test_tap_empty(self):
"""Tap empty arrays."""
@ -488,6 +507,7 @@ class HostCallbackTapTest(jtu.JaxTestCase):
def test_tap_jit_devices(self):
"""Running on multiple devices."""
self.supported_only_in_legacy_mode()
logging.info("%s: has devices %s", self._testMethodName, local_devices())
def func(x, device_id):
@ -830,19 +850,24 @@ class HostCallbackTapTest(jtu.JaxTestCase):
x3 = hcb.id_print(x2 + 1, what="x3", output_stream=testing_stream)
return x3
res = jax.jit(func)(0) # No error yet
with self.assertRaises(hcb.CallbackException):
hcb.barrier_wait()
if hcb._HOST_CALLBACK_LEGACY.value:
res = jax.jit(func)(0) # No error yet
with self.assertRaises(hcb.CallbackException):
hcb.barrier_wait()
# Even though the receiver thread raised, the main thread should still
# return 3.
self.assertEqual(3, res)
# We should have received all others
assertMultiLineStrippedEqual(self, """
what: x1
1
what: x3
3""", testing_stream.output)
# Even though the receiver thread raised, the main thread should still
# return 3.
self.assertEqual(3, res)
# We should have received all others
assertMultiLineStrippedEqual(self, """
what: x1
1
what: x3
3""", testing_stream.output)
else:
with self.assertRaisesRegex(Exception, "NotImplementedError"):
res = jax.jit(func)(0)
hcb.barrier_wait()
def test_tap_while(self):
"""Executing while, even without JIT uses compiled code"""
@ -878,7 +903,8 @@ class HostCallbackTapTest(jtu.JaxTestCase):
# The output of id_print is not needed for backwards pass
def func(x):
return 2. * hcb.id_print(x * 3., what="x * 3",
output_stream=testing_stream)
output_stream=testing_stream,
callback_flavor=hcb.CallbackFlavor.DEBUG)
grad_func = jax.grad(func)
arg = jnp.float32(5.)
@ -886,21 +912,22 @@ class HostCallbackTapTest(jtu.JaxTestCase):
# making the Jaxpr does not print anything
hcb.barrier_wait()
treedef = jax.tree.structure(arg)
assertMultiLineStrippedEqual(
self, f"""
{{ lambda ; a:f32[]. let
b:f32[] = mul a 3.00
c:f32[] = outside_call[
arg_treedef={treedef}
callback=...
device_index=0
identity=True
] b
_:f32[] = mul 2.00 c
d:f32[] = mul 2.00 1.00
e:f32[] = mul d 3.00
in (e,) }}""", jaxpr)
if hcb._HOST_CALLBACK_LEGACY.value:
treedef = jax.tree.structure(arg)
assertMultiLineStrippedEqual(
self, f"""
{{ lambda ; a:f32[]. let
b:f32[] = mul a 3.00
c:f32[] = outside_call[
arg_treedef={treedef}
callback=...
device_index=0
identity=True
] b
_:f32[] = mul 2.00 c
d:f32[] = mul 2.00 1.00
e:f32[] = mul d 3.00
in (e,) }}""", jaxpr)
assertMultiLineStrippedEqual(self, "", testing_stream.output)
testing_stream.reset()
@ -914,9 +941,11 @@ class HostCallbackTapTest(jtu.JaxTestCase):
def test_tap_grad_simple(self):
def func(x):
y = hcb.id_print(x * 2., what="x * 2", output_stream=testing_stream)
y = hcb.id_print(x * 2., what="x * 2", output_stream=testing_stream,
callback_flavor=hcb.CallbackFlavor.DEBUG)
return x * hcb.id_print(y * 3., what="y * 3",
output_stream=testing_stream)
output_stream=testing_stream,
callback_flavor=hcb.CallbackFlavor.DEBUG)
grad_func = jax.grad(func)
@ -931,7 +960,8 @@ class HostCallbackTapTest(jtu.JaxTestCase):
def test_tap_grad_grad(self):
def func(x):
y = hcb.id_print(x * 2., what="x * 2", output_stream=testing_stream)
y = hcb.id_print(x * 2., what="x * 2", output_stream=testing_stream,
callback_flavor=hcb.CallbackFlavor.DEBUG)
return x * (y * 3.)
grad_func = jax.grad(jax.grad(func))
@ -952,7 +982,8 @@ class HostCallbackTapTest(jtu.JaxTestCase):
def func(x):
x4, x5 = hcb.id_print((x * 2., x * 3.), what="pair",
result=(x * 4., x * 5.),
output_stream=testing_stream)
output_stream=testing_stream,
callback_flavor=hcb.CallbackFlavor.DEBUG)
return x4 + 2. * x5
x = jnp.float32(5.)
@ -967,15 +998,18 @@ class HostCallbackTapTest(jtu.JaxTestCase):
def test_tap_jvp_float0(self):
def f(x, yint):
x, yint = hcb.id_tap(lambda arg, _: arg, (x, yint))
x, yint = hcb.id_tap(lambda arg, _: arg, (x, yint),
callback_flavor=hcb.CallbackFlavor.DEBUG)
return x * yint
res = jax.jvp(f, (2., 3), (0.2, np.zeros((), dtypes.float0)))
self.assertAllClose((6., 0.6), res)
def test_tap_grad_float0(self):
def func(x, yint):
x, yint = hcb.id_print((x, yint), what="pair", output_stream=testing_stream)
x, yint = hcb.id_print((x, yint), what="pair", output_stream=testing_stream,
callback_flavor=hcb.CallbackFlavor.DEBUG)
return x * yint.astype(x.dtype)
grad_func = jax.grad(func)
@ -993,7 +1027,8 @@ class HostCallbackTapTest(jtu.JaxTestCase):
x = (np.array([.7, .8], dtype=np.float32),
np.array([11, 12, 13], dtype=np.int32))
def f_jax(x):
x = hcb.id_print(x, result=x, output_stream=testing_stream) # result= is important
x = hcb.id_print(x, result=x, output_stream=testing_stream,
callback_flavor=hcb.CallbackFlavor.DEBUG) # result= is important
return (3. * x[0], x[1])
def f_jax_vjp(x):
@ -1015,7 +1050,8 @@ class HostCallbackTapTest(jtu.JaxTestCase):
x = (np.array([.7, .8], dtype=np.float32),
np.array([11, 12, 13], dtype=np.int32))
def f_jax(x):
x = hcb.id_print(x, result=x, output_stream=testing_stream) # result= is important
x = hcb.id_print(x, result=x, output_stream=testing_stream,
callback_flavor=hcb.CallbackFlavor.DEBUG) # result= is important
return (jnp.sin(x[0]), x[1])
def wrap_vjp(f, args, res_f_of_args):
@ -1059,32 +1095,52 @@ class HostCallbackTapTest(jtu.JaxTestCase):
vargs = jnp.array([jnp.float32(4.), jnp.float32(5.)])
vmap_fun1(vargs)
hcb.barrier_wait()
assertMultiLineStrippedEqual(self, """
transforms: [('batch', {'batch_dims': (0,)})] what: a * 2
[ 8.00 10.00]
transforms: [('batch', {'batch_dims': (0,)})] what: y * 3
[24.00 30.00]""", testing_stream.output)
if hcb._HOST_CALLBACK_LEGACY.value:
assertMultiLineStrippedEqual(self, """
transforms: [('batch', {'batch_dims': (0,)})] what: a * 2
[ 8.00 10.00]
transforms: [('batch', {'batch_dims': (0,)})] what: y * 3
[24.00 30.00]""", testing_stream.output)
else:
assertMultiLineStrippedEqual(self, """
what: a * 2
8.00
what: a * 2
10.00
what: y * 3
24.00
what: y * 3
30.00
""", testing_stream.output)
def test_tap_vmap_not_batched(self):
x = 3.
def func(y):
# x is not mapped, y is mapped
_, y = hcb.id_print((x, y), output_stream=testing_stream)
_, y = hcb.id_print((x, y), output_stream=testing_stream,
callback_flavor=hcb.CallbackFlavor.DEBUG)
return x + y
vmap_func = jax.vmap(func)
vargs = jnp.array([jnp.float32(4.), jnp.float32(5.)])
_ = vmap_func(vargs)
hcb.barrier_wait()
assertMultiLineStrippedEqual(self, """
transforms: [('batch', {'batch_dims': (None, 0)})]
( 3.00 [4.00 5.00] )""", testing_stream.output)
if hcb._HOST_CALLBACK_LEGACY.value:
assertMultiLineStrippedEqual(self, """
transforms: [('batch', {'batch_dims': (None, 0)})]
( 3.00 [4.00 5.00] )""", testing_stream.output)
else:
assertMultiLineStrippedEqual(self, """
( 3.00 4.00 )
( 3.00 5.00 )
""", testing_stream.output)
def test_tap_vmap_vmap(self):
# A 2D tensor with x[i, j] = i + j using 2 vmap
def sum(x, y):
return hcb.id_print(x + y, output_stream=testing_stream)
return hcb.id_print(x + y, output_stream=testing_stream,
callback_flavor=hcb.CallbackFlavor.DEBUG)
def sum_rows(xv, y):
return jax.vmap(sum, in_axes=(0, None))(xv, y)
@ -1097,22 +1153,44 @@ class HostCallbackTapTest(jtu.JaxTestCase):
# assertMultiLineStrippedEqual(self, "", str(jax.make_jaxpr(sum_all)(xv, yv)))
_ = sum_all(xv, yv)
hcb.barrier_wait()
assertMultiLineStrippedEqual(self, """
transforms: [('batch', {'batch_dims': (0,)}), ('batch', {'batch_dims': (0,)})]
[[0 1 2 3 4]
[1 2 3 4 5]
[2 3 4 5 6]]""", testing_stream.output)
if hcb._HOST_CALLBACK_LEGACY.value:
assertMultiLineStrippedEqual(self, """
transforms: [('batch', {'batch_dims': (0,)}), ('batch', {'batch_dims': (0,)})]
[[0 1 2 3 4]
[1 2 3 4 5]
[2 3 4 5 6]]""", testing_stream.output)
else:
assertMultiLineStrippedEqual(self, """
0
1
2
1
2
3
2
3
4
3
4
5
4
5
6
""", testing_stream.output)
def test_tap_vmap_while(self):
"""Vmap of while."""
def func(x):
# like max(x, 2)
x1 = hcb.id_print(x, where="before:x", output_stream=testing_stream)
x1 = hcb.id_print(x, where="before:x", output_stream=testing_stream,
callback_flavor=hcb.CallbackFlavor.DEBUG)
x2 = lax.while_loop(
lambda x: x < 2, lambda x: hcb.id_print(
x + 1, where="body:x+1", output_stream=testing_stream), x1)
res = hcb.id_print(x2, where="after:x", output_stream=testing_stream)
x + 1, where="body:x+1", output_stream=testing_stream,
callback_flavor=hcb.CallbackFlavor.DEBUG), x1)
res = hcb.id_print(x2, where="after:x", output_stream=testing_stream,
callback_flavor=hcb.CallbackFlavor.DEBUG)
return res
inputs = np.arange(5, dtype=np.int32)
@ -1121,72 +1199,93 @@ class HostCallbackTapTest(jtu.JaxTestCase):
jax.jit(jax.vmap(func))(inputs),
check_dtypes=False)
hcb.barrier_wait()
assertMultiLineStrippedEqual(
self, """
transforms: [('batch', {'batch_dims': (0,)})] where: before:x
[0 1 2 3 4]
transforms: [('batch', {'batch_dims': (0,)})] where: body:x+1
[1 2 3 4 5]
transforms: [('batch', {'batch_dims': (0,)})] where: body:x+1
[2 3 3 4 5]
transforms: [('batch', {'batch_dims': (0,)})] where: after:x
[2 2 2 3 4]""", testing_stream.output)
if hcb._HOST_CALLBACK_LEGACY.value:
assertMultiLineStrippedEqual(
self, """
transforms: [('batch', {'batch_dims': (0,)})] where: before:x
[0 1 2 3 4]
transforms: [('batch', {'batch_dims': (0,)})] where: body:x+1
[1 2 3 4 5]
transforms: [('batch', {'batch_dims': (0,)})] where: body:x+1
[2 3 3 4 5]
transforms: [('batch', {'batch_dims': (0,)})] where: after:x
[2 2 2 3 4]""", testing_stream.output)
else:
pass # order of vmaps is not guaranteed
def test_tap_vmap_while_tap_cond(self):
"""Vmap of while, with a tap in the conditional."""
def func(x):
# like max(x, 2)
x1 = hcb.id_print(x, where="1", output_stream=testing_stream)
x1 = hcb.id_print(x, where="1", output_stream=testing_stream,
callback_flavor=hcb.CallbackFlavor.DEBUG)
x2 = lax.while_loop(lambda x: hcb.id_print(x < 2, where="w_c",
output_stream=testing_stream),
output_stream=testing_stream,
callback_flavor=hcb.CallbackFlavor.DEBUG),
lambda x: hcb.id_print(x + 1, where="w_b",
output_stream=testing_stream),
output_stream=testing_stream,
callback_flavor=hcb.CallbackFlavor.DEBUG),
x1)
res = hcb.id_print(x2, where="3", output_stream=testing_stream)
res = hcb.id_print(x2, where="3", output_stream=testing_stream,
callback_flavor=hcb.CallbackFlavor.DEBUG)
return res
inputs = np.arange(5, dtype=np.int32)
res = jax.jit(jax.vmap(func))(inputs)
hcb.barrier_wait()
self.assertAllClose(np.array([2, 2, 2, 3, 4]), res, check_dtypes=False)
assertMultiLineStrippedEqual(self, """
transforms: [('batch', {'batch_dims': (0,)})] where: 1
[0 1 2 3 4]
transforms: [('batch', {'batch_dims': (0,)})] where: w_c
[ True True False False False]
transforms: [('batch', {'batch_dims': (0,)})] where: w_b
[1 2 3 4 5]
transforms: [('batch', {'batch_dims': (0,)})] where: w_c
[ True False False False False]
transforms: [('batch', {'batch_dims': (0,)})] where: w_b
[2 3 3 4 5]
transforms: [('batch', {'batch_dims': (0,)})] where: w_c
[False False False False False]
transforms: [('batch', {'batch_dims': (0,)})] where: 3
[2 2 2 3 4]""", testing_stream.output)
if hcb._HOST_CALLBACK_LEGACY.value:
assertMultiLineStrippedEqual(self, """
transforms: [('batch', {'batch_dims': (0,)})] where: 1
[0 1 2 3 4]
transforms: [('batch', {'batch_dims': (0,)})] where: w_c
[ True True False False False]
transforms: [('batch', {'batch_dims': (0,)})] where: w_b
[1 2 3 4 5]
transforms: [('batch', {'batch_dims': (0,)})] where: w_c
[ True False False False False]
transforms: [('batch', {'batch_dims': (0,)})] where: w_b
[2 3 3 4 5]
transforms: [('batch', {'batch_dims': (0,)})] where: w_c
[False False False False False]
transforms: [('batch', {'batch_dims': (0,)})] where: 3
[2 2 2 3 4]""", testing_stream.output)
else:
pass # order of vmap is not guaranteed
def test_tap_transforms_doc(self):
# Examples from the documentation
def power3(x):
y = x * x
# Print both 'x' and 'x^2'. Must pack as a tuple.
hcb.id_print((x, y), what="x,x^2", output_stream=testing_stream)
hcb.id_print((x, y), what="x,x^2", output_stream=testing_stream,
callback_flavor=hcb.CallbackFlavor.DEBUG)
return y * x
print(f"impl = {power3(3.)}")
hcb.barrier_wait()
expected = """
what: x,x^2
( 3. 9. )"""
if hcb._HOST_CALLBACK_LEGACY.value:
expected = """
what: x,x^2
( 3. 9. )"""
else:
expected = """
what: x,x^2
( 3.0 9.0 )"""
self.assertMultiLineStrippedEqual(expected, testing_stream.output)
testing_stream.reset()
print(f"jvp = {jax.jvp(power3, (3.,), (0.1,))}")
hcb.barrier_wait()
expected = """
what: x,x^2
( 3. 9. )"""
if hcb._HOST_CALLBACK_LEGACY.value:
expected = """
what: x,x^2
( 3. 9. )"""
else:
expected = """
what: x,x^2
( 3.0 9.0 )"""
self.assertMultiLineStrippedEqual(expected, testing_stream.output)
testing_stream.reset()
@ -1197,32 +1296,41 @@ class HostCallbackTapTest(jtu.JaxTestCase):
@print_tangents.defjvp
def print_tangents_jvp(primals, tangents):
arg_dot, = tangents
hcb.id_print(arg_dot, what="tangents", output_stream=testing_stream)
hcb.id_print(arg_dot, what="tangents", output_stream=testing_stream,
callback_flavor=hcb.CallbackFlavor.DEBUG)
return primals, tangents
def power3_with_tangents(x):
y = x * x
# Print both 'x' and 'x^2'. Must pack as a tuple.
hcb.id_print((x, y), what="x,x^2", output_stream=testing_stream)
hcb.id_print((x, y), what="x,x^2", output_stream=testing_stream,
callback_flavor=hcb.CallbackFlavor.DEBUG)
print_tangents((x, y))
return y * x
print(f"jvp = {jax.jvp(power3_with_tangents, (3.,), (0.1,))}")
hcb.barrier_wait()
expected = """
what: x,x^2
( 3. 9. )
what: tangents
( 0.1 0.6 )"""
self.assertMultiLineStrippedEqual(expected, testing_stream.output)
if hcb._HOST_CALLBACK_LEGACY.value:
expected = """
what: x,x^2
( 3. 9. )
what: tangents
( 0.1 0.6 )"""
self.assertMultiLineStrippedEqual(expected, testing_stream.output)
testing_stream.reset()
print(f"grad = {jax.grad(power3)(3.)}")
hcb.barrier_wait()
# Only the primals by default
expected = """
what: x,x^2
( 3. 9. )"""
if hcb._HOST_CALLBACK_LEGACY.value:
expected = """
what: x,x^2
( 3. 9. )"""
else:
expected = """
what: x,x^2
( 3.0 9.0 )"""
self.assertMultiLineStrippedEqual(expected, testing_stream.output)
testing_stream.reset()
@ -1236,7 +1344,8 @@ class HostCallbackTapTest(jtu.JaxTestCase):
return print_cotangents(arg), None
# f_bwd: (residual, CT b) -> [CT a]
def print_cotangents_bwd(residual, ct_b):
hcb.id_print(ct_b, what="cotangents", output_stream=testing_stream)
hcb.id_print(ct_b, what="cotangents", output_stream=testing_stream,
callback_flavor=hcb.CallbackFlavor.DEBUG)
return ct_b,
print_cotangents.defvjp(print_cotangents_fwd, print_cotangents_bwd)
@ -1244,18 +1353,26 @@ class HostCallbackTapTest(jtu.JaxTestCase):
def power3_with_cotangents(x):
y = x * x
# Print both 'x' and 'x^2'. Must pack as a tuple.
hcb.id_print((x, y), what="x,x^2", output_stream=testing_stream)
hcb.id_print((x, y), what="x,x^2", output_stream=testing_stream,
callback_flavor=hcb.CallbackFlavor.DEBUG)
# Must use the output of print_cotangents
(x1, y1) = print_cotangents((x, y))
return y1 * x1
print(f"grad = {jax.grad(power3_with_cotangents)(3.)}")
hcb.barrier_wait()
expected = """
what: x,x^2
( 3. 9. )
what: cotangents
( 9. 3. )"""
if hcb._HOST_CALLBACK_LEGACY.value:
expected = """
what: x,x^2
( 3. 9. )
what: cotangents
( 9. 3. )"""
else:
expected = """
what: x,x^2
( 3.0 9.0 )
what: cotangents
( 9.0 3.0 )"""
self.assertMultiLineStrippedEqual(expected, testing_stream.output)
testing_stream.reset()
@ -1263,43 +1380,82 @@ class HostCallbackTapTest(jtu.JaxTestCase):
print(f"vmap = {jax.vmap(power3)(np.array([2., 3.]))}")
hcb.barrier_wait()
expected = """
transforms: [('batch', {'batch_dims': (0, 0)})] what: x,x^2
( [2. 3.] [4. 9.] )"""
if hcb._HOST_CALLBACK_LEGACY.value:
expected = """
transforms: [('batch', {'batch_dims': (0, 0)})] what: x,x^2
( [2. 3.] [4. 9.] )"""
else:
expected = """
what: x,x^2
( 2.0 4.0 )
what: x,x^2
( 3.0 9.0 )
"""
self.assertMultiLineStrippedEqual(expected, testing_stream.output)
testing_stream.reset()
print(f"vmap o grad {jax.vmap(jax.grad(power3))(np.array([2., 3.]))}")
hcb.barrier_wait()
expected = """
transforms: [('batch', {'batch_dims': (0, 0)})] what: x,x^2
( [2. 3.] [4. 9.] )"""
if hcb._HOST_CALLBACK_LEGACY.value:
expected = """
transforms: [('batch', {'batch_dims': (0, 0)})] what: x,x^2
( [2. 3.] [4. 9.] )"""
else:
expected = """
what: x,x^2
( 2.0 4.0 )
what: x,x^2
( 3.0 9.0 )
"""
self.assertMultiLineStrippedEqual(expected, testing_stream.output)
testing_stream.reset()
print(f"vmap o grad {jax.vmap(jax.grad(power3_with_cotangents))(np.array([2., 3.]))}")
hcb.barrier_wait()
expected = """
transforms: [('batch', {'batch_dims': (0, 0)})] what: x,x^2
( [2. 3.] [4. 9.] )
transforms: [('batch', {'batch_dims': (0, 0)})] what: cotangents
( [4. 9.] [2. 3.] )"""
if hcb._HOST_CALLBACK_LEGACY.value:
expected = """
transforms: [('batch', {'batch_dims': (0, 0)})] what: x,x^2
( [2. 3.] [4. 9.] )
transforms: [('batch', {'batch_dims': (0, 0)})] what: cotangents
( [4. 9.] [2. 3.] )"""
else:
expected = """
what: x,x^2
( 2.0 4.0 )
what: x,x^2
( 3.0 9.0 )
what: cotangents
( 4.0 2.0 )
what: cotangents
( 9.0 3.0 )
"""
self.assertMultiLineStrippedEqual(expected, testing_stream.output)
testing_stream.reset()
print(f"grad o remat = {jax.grad(lambda x: power3(ad_checkpoint.checkpoint(power3)(x)))(3.)}")
hcb.barrier_wait()
expected = """
what: x,x^2
( 3. 9. )
what: x,x^2
( 27. 729. )
what: x,x^2
( 3. 9. )"""
if hcb._HOST_CALLBACK_LEGACY.value:
expected = """
what: x,x^2
( 3. 9. )
what: x,x^2
( 27. 729. )
what: x,x^2
( 3. 9. )"""
else:
expected = """
what: x,x^2
( 3.0 9.0 )
what: x,x^2
( 27.0 729.0 )
what: x,x^2
( 3.0 9.0 )
"""
self.assertMultiLineStrippedEqual(expected, testing_stream.output)
testing_stream.reset()
def test_tap_pmap(self):
self.supported_only_in_legacy_mode()
if len(local_devices()) < 2:
raise SkipTest("test requires at least 2 devices")
@ -1326,6 +1482,7 @@ class HostCallbackTapTest(jtu.JaxTestCase):
( 4 16 )""")
def test_tap_pmap_vmap(self):
self.supported_only_in_legacy_mode()
# A matrix M[ij] = i * 10 + j
nr_devices = len(local_devices())
shape = (nr_devices, 3)
@ -1353,6 +1510,7 @@ class HostCallbackTapTest(jtu.JaxTestCase):
def test_tap_pmap_pmap_vmap(self):
# A matrix M[ijk] = i * 100 + j * 10 + k
self.supported_only_in_legacy_mode()
nr_devices = len(local_devices())
if nr_devices % 2 != 0:
raise SkipTest("test works only on even number of devices")
@ -1386,6 +1544,7 @@ class HostCallbackTapTest(jtu.JaxTestCase):
def test_tap_pmap_pmap_extra(self):
"""pmap of a pmap surrounded by extra code."""
# A matrix M[ij] = i * 10 + j
self.supported_only_in_legacy_mode()
nr_devices = len(local_devices())
if nr_devices != 2:
raise SkipTest("test works only on 2 devices")
@ -1419,6 +1578,7 @@ class HostCallbackTapTest(jtu.JaxTestCase):
[[203.00 205.00 207.00]]""")
def test_tap_jvp_pmap_vmap(self):
self.supported_only_in_legacy_mode()
# A matrix M[ijk] = i * 100 + j * 10 * k
nr_devices = len(local_devices())
shape = (nr_devices, 2, 3)
@ -1445,6 +1605,7 @@ class HostCallbackTapTest(jtu.JaxTestCase):
[220.00 222.00 224.00]]""")
def test_tap_vmap_pmap(self):
self.supported_only_in_legacy_mode()
# A matrix M[ijk] = i * 100 + j * 10 * k
nr_devices = len(local_devices())
shape = (2, nr_devices, 3)
@ -1472,6 +1633,7 @@ class HostCallbackTapTest(jtu.JaxTestCase):
@ignore_jit_of_pmap_warning()
def test_tap_jit_pmap_extra(self):
"""jit of a pmap surrounded by extra code."""
self.supported_only_in_legacy_mode()
# A matrix M[ij] = i * 10 + j
nr_devices = len(local_devices())
assert nr_devices in (1, 2)
@ -1540,6 +1702,7 @@ class HostCallbackTapTest(jtu.JaxTestCase):
@jtu.sample_product(device_index=[0, 1])
def test_tap_pjit(self, device_index=0):
self.supported_only_in_legacy_mode()
if (device_index != 0 and
not hcb._HOST_CALLBACK_OUTFEED.value and
jtu.test_device_matches(["cpu"])):
@ -1589,7 +1752,7 @@ class HostCallbackTapTest(jtu.JaxTestCase):
def test_tap_scan_custom_jvp(self):
"""custom JVP, inside scan.
This exercises the custom_jvp_call_jaxpr primitives."""
self.supported_only_in_legacy_mode()
@jax.custom_jvp
def f(x):
return x * hcb.id_print(x, output_stream=testing_stream, what="x")
@ -1633,7 +1796,7 @@ class HostCallbackTapTest(jtu.JaxTestCase):
def test_tap_scan_custom_vjp(self):
"""custom VJP, inside scan.
This exercises the custom_vjp_call_jaxpr primitives."""
self.supported_only_in_legacy_mode()
@jax.custom_vjp
def f(x):
return x * hcb.id_print(x, output_stream=testing_stream, what="x")
@ -1773,7 +1936,7 @@ class HostCallbackTapTest(jtu.JaxTestCase):
from jax.experimental.ode import odeint
def f(x, t, k):
x = hcb.id_print(x)
x = hcb.id_print(x, callback_flavor=hcb.CallbackFlavor.DEBUG)
return -k * x
def loss(k=1.0):
@ -1785,7 +1948,8 @@ class HostCallbackTapTest(jtu.JaxTestCase):
def test_tap_remat_0(self):
def f(i, k):
x = hcb.id_print(k + i, output_stream=testing_stream)
x = hcb.id_print(k + i, output_stream=testing_stream,
callback_flavor=hcb.CallbackFlavor.DEBUG)
return k * x
def loss(k):
@ -1804,6 +1968,7 @@ class HostCallbackTapTest(jtu.JaxTestCase):
use_remat=["old", "new", "none"],
)
def test_tap_remat(self, use_result=False, grad_func="grad", use_remat="new"):
self.supported_only_in_legacy_mode()
if use_remat == "old": raise SkipTest()
def f(x):
@ -1880,6 +2045,10 @@ class HostCallbackCallTest(jtu.JaxTestCase):
hcb.barrier_wait("HostCallbackCallTest.tearDown")
super().tearDown()
def supported_only_in_legacy_mode(self):
if not hcb._HOST_CALLBACK_LEGACY.value:
self.skipTest("Not supported when JAX_HOST_CALLBACK_LEGACY=False")
def call_log_testing_stream(self, func, arg, *, result_shape, name=""):
"""Call `func` and log inputs and outputs to the testing stream"""
@ -1916,6 +2085,7 @@ class HostCallbackCallTest(jtu.JaxTestCase):
with jtu.count_primitive_compiles() as count:
for _ in range(3):
self.assertAllClose(2 * arg, fun(arg))
r = jax.make_jaxpr(fun)(arg)
self.assertEqual(count[0], 1)
@jtu.sample_product(
@ -2124,6 +2294,7 @@ class HostCallbackCallTest(jtu.JaxTestCase):
helper_print_optimized_hlo(fun2, m)
def test_call_with_device(self):
self.supported_only_in_legacy_mode()
def callback_func(x, device=None):
testing_stream.write(f"device: {device}\n Called with {x}")
return x
@ -2139,6 +2310,7 @@ class HostCallbackCallTest(jtu.JaxTestCase):
Called with 3.00""")
def test_call_pmap(self):
self.supported_only_in_legacy_mode()
# Works for 1 or 2 devices
def callback_func(x, device=None):
testing_stream.write(f"device: {device}\n Called with {x}")
@ -2163,10 +2335,14 @@ class HostCallbackCallTest(jtu.JaxTestCase):
def f_outside(x): return x
def fun(x):
return hcb.call(f_outside, x, result_shape=x)
return hcb.call(f_outside, x, result_shape=x,
callback_flavor=hcb.CallbackFlavor.PURE)
with self.assertRaisesRegex(NotImplementedError,
"batching rules are implemented only for id_tap, not for call"):
if hcb._HOST_CALLBACK_LEGACY.value:
with self.assertRaisesRegex(NotImplementedError,
"batching rules are implemented only for id_tap, not for call"):
jax.vmap(fun)(np.ones((2, 3)))
else:
jax.vmap(fun)(np.ones((2, 3)))
@jtu.sample_product(device_index=[0, 1])
@ -2256,6 +2432,7 @@ class HostCallbackCallTest(jtu.JaxTestCase):
hcb.barrier_wait("Waiting for error")
def test_call_error_callback_throws_exception(self):
self.supported_only_in_legacy_mode()
def f_outside(x):
raise ValueError("user exception")
def fun(x):
@ -2265,6 +2442,7 @@ class HostCallbackCallTest(jtu.JaxTestCase):
"ValueError: user exception")
def test_call_error_callback_returns_unexpected_shape(self):
self.supported_only_in_legacy_mode()
def fun(x):
return hcb.call(lambda x: (x, x), x, result_shape=x)
@ -2272,6 +2450,7 @@ class HostCallbackCallTest(jtu.JaxTestCase):
"Callback func .* should have returned a result with pytree")
def test_call_error_then_compute(self):
self.supported_only_in_legacy_mode()
# Continue computation on device after error
def f_outside(x):
raise ValueError("user exception")
@ -2283,7 +2462,9 @@ class HostCallbackCallTest(jtu.JaxTestCase):
"ValueError: user exception")
def call_jax_other_device(jax_outside_fun, arg, *, device):
def call_jax_other_device(
jax_outside_fun, arg, *, device,
callback_flavor: hcb.CallbackFlavor = hcb.CallbackFlavor.IO_CALLBACK):
"""Calls a JAX function on a specific device with simple support for reverse AD.
Functions whose name starts with "jax_outside" are called on another device,
@ -2296,7 +2477,8 @@ def call_jax_other_device(jax_outside_fun, arg, *, device):
@jax.custom_vjp
def make_call(arg):
return hcb.call(run_jax_outside_fun, arg,
result_shape=jax.eval_shape(jax_outside_fun, arg))
result_shape=jax.eval_shape(jax_outside_fun, arg),
callback_flavor=callback_flavor)
# Define the fwd and bwd custom_vjp functions
def make_call_vjp_fwd(arg):
@ -2323,6 +2505,8 @@ class CallJaxTest(jtu.JaxTestCase):
"""Tests using `call_jax_other_device`."""
def setUp(self):
if not hcb._HOST_CALLBACK_LEGACY.value:
self.skipTest("Not supported when JAX_HOST_CALLBACK_LEGACY=False")
if jtu.test_device_matches(["gpu"]) and jax.device_count() > 1:
raise SkipTest("host_callback broken on multi-GPU platforms (#6447)")
if xla_bridge.using_pjrt_c_api():
@ -2337,6 +2521,7 @@ class CallJaxTest(jtu.JaxTestCase):
self.outside_device = jax.devices("cpu")[1]
super().setUp()
def test_jax_impl(self):
def f_jax(x):
return jnp.sin(x)
@ -2404,6 +2589,10 @@ class OutfeedRewriterTest(jtu.JaxTestCase):
raise SkipTest("host_callback not implemented in PJRT C API")
super().setUp()
def supported_only_in_legacy_mode(self):
if not hcb._HOST_CALLBACK_LEGACY.value:
self.skipTest("Not supported when JAX_HOST_CALLBACK_LEGACY=False")
def assertRewrite(self, expected: str, func: Callable, args: Sequence,
has_input_token=True, has_output_token=True):
"""Check that the rewrite of func(*args) matches expected."""
@ -2624,7 +2813,7 @@ class OutfeedRewriterTest(jtu.JaxTestCase):
def test_scan_custom_jvp(self):
"""custom JVP, inside scan.
This exercises the custom_jvp_call_jaxpr primitives."""
self.supported_only_in_legacy_mode()
@jax.custom_jvp
def f(x):
return x * hcb.id_print(x)
@ -2706,7 +2895,7 @@ class OutfeedRewriterTest(jtu.JaxTestCase):
def test_scan_custom_vjp(self):
"""custom VJP, inside scan.
This exercises the custom_vjp_call_jaxpr primitives."""
self.supported_only_in_legacy_mode()
@jax.custom_vjp
def f(x):
return x * hcb.id_print(x)
@ -2849,6 +3038,7 @@ class OutfeedRewriterTest(jtu.JaxTestCase):
in (c, d, e) }""", tap_scalar, [np.int32(3)])
def test_pmap(self):
self.supported_only_in_legacy_mode()
def f(xv):
jax.pmap(lambda x: jnp.sin(hcb.id_print(x, tap_with_device=True)),
axis_name="i")(xv)

View File

@ -25,8 +25,8 @@ from absl.testing import absltest
from absl.testing import parameterized
import jax
from jax import config
from jax import numpy as jnp
from jax._src import config
from jax._src import test_util as jtu
from jax._src import xla_bridge
from jax.experimental import host_callback as hcb
@ -53,7 +53,8 @@ def call_tf_no_ad(tf_fun: Callable, arg, *, result_shape):
return hcb.call(lambda arg: tf.nest.map_structure(tf_to_numpy,
tf_fun(arg)),
arg, result_shape=result_shape)
arg, result_shape=result_shape,
callback_flavor=hcb.CallbackFlavor.DEBUG)
def call_tf_simple_ad(tf_fun: Callable, arg, *, result_shape):
@ -166,12 +167,17 @@ class CallToTFTest(jtu.JaxTestCase):
raise unittest.SkipTest("host_callback not implemented in PJRT C API")
super().setUp()
def supported_only_in_legacy_mode(self):
if not hcb._HOST_CALLBACK_LEGACY.value:
self.skipTest("Not supported when JAX_HOST_CALLBACK_LEGACY=False")
@parameterized.named_parameters(
dict(
testcase_name=f"_{ad=}",
ad=ad)
for ad in CALL_TF_IMPLEMENTATIONS.keys())
def test_impl(self, ad="simple"):
self.supported_only_in_legacy_mode()
call_tf = CALL_TF_IMPLEMENTATIONS[ad]
def f_jax(x):
@ -192,21 +198,27 @@ class CallToTFTest(jtu.JaxTestCase):
for ad in CALL_TF_IMPLEMENTATIONS.keys()
if ad != "none")
def test_grad(self, ad="simple"):
self.supported_only_in_legacy_mode()
call_tf = CALL_TF_IMPLEMENTATIONS[ad]
def f_jax(x):
return 3. * jnp.sin(2. * x)
def f_outside(x):
return 3. * call_tf(tf.math.sin, 2. * x, result_shape=x)
return 3. * call_tf(
lambda x: tf.cast(tf.math.sin(x), tf.float32), 2. * x,
result_shape=jax.ShapeDtypeStruct((), np.float32))
x = 4.
self.assertAllClose(f_jax(x), f_outside(x))
x = np.float32(4.)
self.assertAllClose(f_jax(x), f_outside(x),
check_dtypes=False)
grad_f = jax.grad(f_outside)(x)
self.assertAllClose(jax.grad(f_jax)(x), grad_f)
self.assertAllClose(jax.grad(f_jax)(x), grad_f,
check_dtypes=False)
def test_grad_pytree(self):
self.supported_only_in_legacy_mode()
call_tf = call_tf_full_ad
def f_jax(xy):
@ -215,15 +227,19 @@ class CallToTFTest(jtu.JaxTestCase):
def f_outside(xy):
dict_ab = call_tf(
lambda xy: dict(a=2. * xy[0], b=xy[0] * xy[1]),
lambda xy: dict(a=tf.cast(2. * xy[0], np.float32),
b=tf.cast(xy[0] * xy[1], np.float32)),
xy,
result_shape=dict(a=xy[0], b=xy[1]))
result_shape=dict(a=jax.ShapeDtypeStruct((), np.float32),
b=jax.ShapeDtypeStruct((), np.float32)))
return 3. * dict_ab["a"] + 4. * dict_ab["b"]
xy = (5., 6.)
self.assertAllClose(f_jax(xy), f_outside(xy))
self.assertAllClose(f_jax(xy), f_outside(xy),
check_dtypes=False)
res_jax = jax.grad(f_jax)(xy)
self.assertAllClose(res_jax, jax.grad(f_outside)(xy))
self.assertAllClose(res_jax, jax.grad(f_outside)(xy),
check_dtypes=False)
@parameterized.named_parameters(
dict(
@ -231,6 +247,7 @@ class CallToTFTest(jtu.JaxTestCase):
degree=degree)
for degree in [1, 2, 3, 4])
def test_higher_order_grad(self, degree=4):
self.supported_only_in_legacy_mode()
call_tf = call_tf_full_ad
def f_jax(x):

View File

@ -247,7 +247,7 @@ class PythonCallbackTest(jtu.JaxTestCase):
jax.effects_barrier()
@with_pure_and_io_callbacks
def test_callback_with_wrong_dtype_outputs(self, *, callback=io_callback_ordered):
def test_callback_with_wrong_dtype_outputs(self, *, callback):
def _cb():
return np.array([1], np.float64)