mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
[callback] Add a flag to implement host_callback in terms of io_callback.
The host_callbacks APIs are deprecated and will be removed. In order to help the transition to the new APIs, we add a flag (`JAX_HOST_CALLBACK_LEGACY`) that when set to `False` will use `io_callback` (and `pure_callback` and `jax.debug.callback`) to implement the host_callback APIs. See issue #20385 for more details. We change the tests to accomodate slightly different results when using the new callbacks. The tests that use `tap_with_device` and `call_with_device` are disabled when using the new callbacks.
This commit is contained in:
parent
2512843a56
commit
a510f03ef8
@ -50,6 +50,8 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
`spmd_axis_name` argument for expressing SPMD device-parallel computations.
|
||||
* The `jax.experimental.host_callback` module is deprecated.
|
||||
Use instead the [new JAX external callbacks](https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html).
|
||||
Added `JAX_HOST_CALLBACK_LEGACY` flag to assist in the transition to the
|
||||
new callbacks. See {jax-issue}`#20385` for a discussion.
|
||||
* Passing arguments to {func}`jax.numpy.array_equal` and {func}`jax.numpy.array_equiv`
|
||||
that cannot be converted to a JAX array now results in an exception.
|
||||
* The deprecated flag `jax_parallel_functions_output_gda` has been removed.
|
||||
@ -1451,7 +1453,7 @@ Changes:
|
||||
special autodiff handling for hcb.id_tap and id_print.
|
||||
From now on, only the primals are tapped. The old behavior can be
|
||||
obtained (for a limited time) by setting the ``JAX_HOST_CALLBACK_AD_TRANSFORMS``
|
||||
environment variable, or the ```--flax_host_callback_ad_transforms``` flag.
|
||||
environment variable, or the ```--jax_host_callback_ad_transforms``` flag.
|
||||
Additionally, added documentation for how to implement the old behavior
|
||||
using JAX custom AD APIs ({jax-issue}`#8678`).
|
||||
* Sorting now matches the behavior of NumPy for ``0.0`` and ``NaN`` regardless of the
|
||||
|
@ -997,7 +997,11 @@ pytype_library(
|
||||
|
||||
pytype_library(
|
||||
name = "experimental_host_callback",
|
||||
srcs = ["experimental/host_callback.py"],
|
||||
srcs = [
|
||||
"experimental/__init__.py", # To support JAX_HOST_CALLBACK_LEGACY=False
|
||||
"experimental/host_callback.py",
|
||||
"experimental/x64_context.py", # To support JAX_HOST_CALLBACK_LEGACY=False
|
||||
],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":jax",
|
||||
|
@ -17,6 +17,7 @@
|
||||
The host_callback APIs are deprecated as of March 20, 2024.
|
||||
The functionality is subsumed by the
|
||||
`new JAX external callbacks <https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html>`_
|
||||
See https://github.com/google/jax/issues/20385.
|
||||
|
||||
This module introduces the host callback functions :func:`call`,
|
||||
:func:`id_tap`, and :func:`id_print`, that send their arguments from the device
|
||||
@ -501,6 +502,7 @@ Still to do:
|
||||
from __future__ import annotations
|
||||
|
||||
import atexit
|
||||
import enum
|
||||
from collections.abc import Sequence
|
||||
import functools
|
||||
import itertools
|
||||
@ -510,6 +512,7 @@ import threading
|
||||
import traceback
|
||||
from typing import Any, Callable, cast
|
||||
|
||||
import jax
|
||||
from jax._src import api
|
||||
from jax._src import core
|
||||
from jax._src import config
|
||||
@ -517,6 +520,7 @@ from jax import custom_derivatives
|
||||
from jax._src import dtypes
|
||||
from jax import lax
|
||||
from jax.experimental import pjit
|
||||
from jax.experimental import io_callback
|
||||
from jax._src.interpreters import ad, batching, pxla
|
||||
from jax._src.interpreters import mlir
|
||||
from jax._src.interpreters import partial_eval as pe
|
||||
@ -560,6 +564,15 @@ _HOST_CALLBACK_OUTFEED = config.DEFINE_bool(
|
||||
'Has no effect on TPU, since only the outfeed mechanism is implemented.'
|
||||
)
|
||||
)
|
||||
_HOST_CALLBACK_LEGACY = config.DEFINE_bool(
|
||||
'jax_host_callback_legacy',
|
||||
config.bool_env('JAX_HOST_CALLBACK_LEGACY', True),
|
||||
help=(
|
||||
'Use old implementation of host_callback, documented in the module docstring.'
|
||||
'If False, use the jax.experimental.io_callback implementation. '
|
||||
'See https://github.com/google/jax/issues/20385.'
|
||||
)
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -591,6 +604,15 @@ XlaDevice = xla_client.Device
|
||||
XlaLocalClient = xla_client.Client
|
||||
DType = Any
|
||||
|
||||
class CallbackFlavor(enum.Enum):
|
||||
"""Specifies which flavor of callback to use under JAX_HOST_CALLBACK_LEGACY=False.
|
||||
|
||||
See https://github.com/google/jax/issues/20385.
|
||||
"""
|
||||
IO_CALLBACK = 1 # uses jax.experimental.io_callback
|
||||
PURE = 2 # uses jax.pure_callback
|
||||
DEBUG = 3 # uses jax.debug.callback, valid only when there are no results
|
||||
|
||||
|
||||
def _deprecated_id_tap(tap_func,
|
||||
arg,
|
||||
@ -598,6 +620,7 @@ def _deprecated_id_tap(tap_func,
|
||||
result=None,
|
||||
tap_with_device=False,
|
||||
device_index=0,
|
||||
callback_flavor=CallbackFlavor.IO_CALLBACK,
|
||||
**kwargs):
|
||||
"""Host-callback tap primitive, like identity function with a call to ``tap_func``.
|
||||
|
||||
@ -605,6 +628,7 @@ def _deprecated_id_tap(tap_func,
|
||||
The host_callback APIs are deprecated as of March 20, 2024.
|
||||
The functionality is subsumed by the
|
||||
`new JAX external callbacks <https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html>`_
|
||||
See https://github.com/google/jax/issues/20385.
|
||||
|
||||
``id_tap`` behaves semantically like the identity function but has the
|
||||
side-effect that a user-defined Python function is called with the runtime
|
||||
@ -628,6 +652,9 @@ def _deprecated_id_tap(tap_func,
|
||||
device_index: specifies from which device the tap function is invoked in a
|
||||
SPMD program. Works only when using the outfeed implementation mechanism,
|
||||
i.e., does not work on CPU unless --jax_host_callback_outfeed=True.
|
||||
callback_flavor: if running with `JAX_HOST_CALLBACK_LEGACY=False` specifies
|
||||
the flavor of callback to use.
|
||||
See https://github.com/google/jax/issues/20385.
|
||||
|
||||
Returns:
|
||||
``arg``, or ``result`` if given.
|
||||
@ -660,7 +687,8 @@ def _deprecated_id_tap(tap_func,
|
||||
call_with_device=tap_with_device,
|
||||
result_shape=None,
|
||||
identity=True,
|
||||
device_index=device_index)
|
||||
device_index=device_index,
|
||||
callback_flavor=callback_flavor)
|
||||
|
||||
if result is not None:
|
||||
return result
|
||||
@ -675,6 +703,7 @@ def _deprecated_id_print(arg,
|
||||
device_index=0,
|
||||
output_stream=None,
|
||||
threshold=None,
|
||||
callback_flavor=CallbackFlavor.IO_CALLBACK,
|
||||
**kwargs):
|
||||
"""Like :func:`id_tap` with a printing tap function.
|
||||
|
||||
@ -682,6 +711,7 @@ def _deprecated_id_print(arg,
|
||||
The host_callback APIs are deprecated as of March 20, 2024.
|
||||
The functionality is subsumed by the
|
||||
`new JAX external callbacks <https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html>`_
|
||||
See https://github.com/google/jax/issues/20385.
|
||||
|
||||
On each invocation of the printing tap, the ``kwargs`` if present
|
||||
will be printed first (sorted by keys). Then arg will be printed,
|
||||
@ -697,6 +727,9 @@ def _deprecated_id_print(arg,
|
||||
built-in ``print``. The string will be passed as
|
||||
``output_stream.write(s)``.
|
||||
* ``threshold`` is passed to ``numpy.array2string``.
|
||||
* ``callback_flavor``: if running with `JAX_HOST_CALLBACK_LEGACY=False` specifies
|
||||
the flavor of callback to use.
|
||||
See https://github.com/google/jax/issues/20385.
|
||||
|
||||
For more details see the :mod:`jax.experimental.host_callback` module documentation.
|
||||
"""
|
||||
@ -708,19 +741,22 @@ def _deprecated_id_print(arg,
|
||||
arg,
|
||||
result=result,
|
||||
tap_with_device=tap_with_device,
|
||||
device_index=device_index)
|
||||
device_index=device_index,
|
||||
callback_flavor=callback_flavor)
|
||||
|
||||
|
||||
def _deprecated_call(callback_func: Callable, arg, *,
|
||||
result_shape=None,
|
||||
call_with_device=False,
|
||||
device_index=0):
|
||||
device_index=0,
|
||||
callback_flavor=CallbackFlavor.IO_CALLBACK):
|
||||
"""Make a call to the host, and expect a result.
|
||||
|
||||
.. warning::
|
||||
The host_callback APIs are deprecated as of March 20, 2024.
|
||||
The functionality is subsumed by the
|
||||
`new JAX external callbacks <https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html>`_
|
||||
See https://github.com/google/jax/issues/20385.
|
||||
|
||||
Args:
|
||||
callback_func: The Python function to invoke on the host as
|
||||
@ -748,14 +784,26 @@ def _deprecated_call(callback_func: Callable, arg, *,
|
||||
device_index: specifies from which device the tap function is invoked in a
|
||||
SPMD program. Works only when using the outfeed implementation mechanism,
|
||||
i.e., does not work on CPU unless --jax_host_callback_outfeed=True.
|
||||
callback_flavor: if running with `JAX_HOST_CALLBACK_LEGACY=False` specifies
|
||||
the flavor of callback to use.
|
||||
See https://github.com/google/jax/issues/20385.
|
||||
|
||||
Returns:
|
||||
the result of the ``callback_func`` invocation.
|
||||
|
||||
For more details see the :mod:`jax.experimental.host_callback` module documentation.
|
||||
"""
|
||||
if (not _HOST_CALLBACK_LEGACY.value and
|
||||
callback_flavor is CallbackFlavor.DEBUG and
|
||||
result_shape is not None):
|
||||
raise NotImplementedError(
|
||||
"When using JAX_HOST_CALLBACK_LEGACY=False you can use the `DEBUG` "
|
||||
"flavor of callback only when the `result_shape` is None. "
|
||||
"See https://github.com/google/jax/issues/20385."
|
||||
)
|
||||
return _call(callback_func, arg, result_shape=result_shape,
|
||||
call_with_device=call_with_device, identity=False,
|
||||
device_index=device_index)
|
||||
device_index=device_index, callback_flavor=callback_flavor)
|
||||
|
||||
|
||||
# We need the wrapper function to have hash and equality defined since it is
|
||||
@ -766,6 +814,11 @@ class _CallbackWrapper:
|
||||
self.callback_func = callback_func
|
||||
self.identity = identity
|
||||
self.call_with_device = call_with_device
|
||||
if not _HOST_CALLBACK_LEGACY.value and call_with_device:
|
||||
raise NotImplementedError(
|
||||
"When using JAX_HOST_CALLBACK_LEGACY=False, the host_callback APIs"
|
||||
" do not support `tap_with_device` and `call_with_device`. "
|
||||
"See https://github.com/google/jax/issues/20385.")
|
||||
|
||||
def __hash__(self):
|
||||
return hash((self.callback_func, self.identity, self.call_with_device))
|
||||
@ -775,7 +828,16 @@ class _CallbackWrapper:
|
||||
self.identity == other.identity and
|
||||
self.call_with_device == other.call_with_device)
|
||||
|
||||
def __call__(self, arg, device, transforms):
|
||||
def __call__(self, *args, **kwargs):
|
||||
if _HOST_CALLBACK_LEGACY.value:
|
||||
return self._call_legacy(*args, **kwargs)
|
||||
else:
|
||||
if self.identity:
|
||||
# For id_tap, we pass empty transforms, for backwards compatibility
|
||||
return self.callback_func(args[0], ())
|
||||
return self.callback_func(*args, **kwargs)
|
||||
|
||||
def _call_legacy(self, arg, device, transforms):
|
||||
if self.identity:
|
||||
# For id_tap, we pass the transforms, for backwards compatibility
|
||||
if self.call_with_device:
|
||||
@ -797,14 +859,16 @@ def _call(callback_func: Callable,
|
||||
result_shape=None,
|
||||
call_with_device=False,
|
||||
device_index=0,
|
||||
identity=False):
|
||||
identity=False,
|
||||
callback_flavor=CallbackFlavor.IO_CALLBACK):
|
||||
if _HOST_CALLBACK_LEGACY.value:
|
||||
# Lazy initialization
|
||||
_initialize_outfeed_receiver(
|
||||
max_callback_queue_size_bytes=_HOST_CALLBACK_MAX_QUEUE_BYTE_SIZE.value)
|
||||
api.check_callable(callback_func)
|
||||
flat_args, arg_treedef = tree_util.tree_flatten(arg)
|
||||
for arg in flat_args:
|
||||
dispatch.check_arg(arg)
|
||||
for arg_ in flat_args:
|
||||
dispatch.check_arg(arg_)
|
||||
# See definition of outside_call_p for what parameters it takes
|
||||
params: dict[str, Any] = {}
|
||||
# TODO: wrap function
|
||||
@ -829,8 +893,27 @@ def _call(callback_func: Callable,
|
||||
|
||||
params["result_treedef"] = result_treedef
|
||||
params["flat_results_aval"] = tuple(flat_results_aval)
|
||||
|
||||
if _HOST_CALLBACK_LEGACY.value:
|
||||
flat_results = outside_call_p.bind(*flat_args, **params)
|
||||
return result_treedef.unflatten(flat_results) if not identity else arg_treedef.unflatten(flat_results)
|
||||
else:
|
||||
callback_device = jax.local_devices()[device_index]
|
||||
sharding = jax.sharding.SingleDeviceSharding(callback_device)
|
||||
callback_func = _CallbackWrapper(callback_func, identity,
|
||||
call_with_device)
|
||||
if callback_flavor is CallbackFlavor.DEBUG:
|
||||
assert identity
|
||||
jax.debug.callback(callback_func, arg)
|
||||
return arg
|
||||
elif callback_flavor is CallbackFlavor.PURE:
|
||||
call_res = jax.pure_callback(callback_func, result_shape, arg,
|
||||
sharding=sharding)
|
||||
else:
|
||||
call_res = io_callback(callback_func, result_shape, arg,
|
||||
sharding=sharding,
|
||||
ordered=True)
|
||||
return call_res if not identity else arg
|
||||
|
||||
|
||||
# We need the lock for when we use the CustomCall implementation of callbacks.
|
||||
@ -855,7 +938,6 @@ def _print_tap_func(
|
||||
threshold: the value of numpy.array2string threshold parameter.
|
||||
**kwargs: all other keyword args are printed before printing `arg`.
|
||||
"""
|
||||
|
||||
def emit_str(s: str):
|
||||
if output_stream is not None:
|
||||
output_stream.write(s + "\n")
|
||||
@ -1844,6 +1926,10 @@ def _deprecated_barrier_wait(logging_name: str | None = None):
|
||||
|
||||
For more details see the :mod:`jax.experimental.host_callback` module documentation.
|
||||
"""
|
||||
if not _HOST_CALLBACK_LEGACY.value:
|
||||
jax.effects_barrier()
|
||||
return
|
||||
|
||||
logging_name = logging_name or ""
|
||||
logger.debug("barrier_wait[%s]: start", logging_name)
|
||||
|
||||
@ -1907,7 +1993,7 @@ def _deprecated_stop_outfeed_receiver():
|
||||
_deprecation_msg = (
|
||||
"The host_callback APIs are deprecated as of March 20, 2024. The functionality "
|
||||
"is subsumed by the new JAX external callbacks. "
|
||||
"See https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html.")
|
||||
"See https://github.com/google/jax/issues/20385.")
|
||||
|
||||
_deprecations = {
|
||||
# Added March 20, 2024
|
||||
|
@ -91,8 +91,10 @@ testing_stream = _TestingOutputStream()
|
||||
|
||||
def fun1(a):
|
||||
"""Function used for several `id_tap` tests."""
|
||||
y = hcb.id_print(a * 2., what="a * 2", output_stream=testing_stream)
|
||||
y = hcb.id_print(y * 3., what="y * 3", output_stream=testing_stream, result=y)
|
||||
y = hcb.id_print(a * 2., what="a * 2", output_stream=testing_stream,
|
||||
callback_flavor=hcb.CallbackFlavor.DEBUG)
|
||||
y = hcb.id_print(y * 3., what="y * 3", output_stream=testing_stream, result=y,
|
||||
callback_flavor=hcb.CallbackFlavor.DEBUG)
|
||||
return y ** 2 # Some computation to make the gradient interesting
|
||||
|
||||
|
||||
@ -253,6 +255,10 @@ class HostCallbackTapTest(jtu.JaxTestCase):
|
||||
hcb.barrier_wait("HostCallbackTapTest.tearDown")
|
||||
super().tearDown()
|
||||
|
||||
def supported_only_in_legacy_mode(self):
|
||||
if not hcb._HOST_CALLBACK_LEGACY.value:
|
||||
self.skipTest("Not supported when JAX_HOST_CALLBACK_LEGACY=False")
|
||||
|
||||
def test_tap_eval(self):
|
||||
self.assertAllClose((5. * 2.) ** 2, fun1(5.))
|
||||
hcb.barrier_wait()
|
||||
@ -320,6 +326,7 @@ class HostCallbackTapTest(jtu.JaxTestCase):
|
||||
testing_stream.output)
|
||||
|
||||
def test_tap_with_device(self):
|
||||
self.supported_only_in_legacy_mode()
|
||||
def func2(x):
|
||||
x1 = hcb.id_print((x * 2., x * 3.), result=x * 4.,
|
||||
output_stream=testing_stream,
|
||||
@ -335,6 +342,7 @@ class HostCallbackTapTest(jtu.JaxTestCase):
|
||||
def test_tap_eval_exception(self):
|
||||
if not hcb._HOST_CALLBACK_OUTFEED.value:
|
||||
raise SkipTest("TODO: implement error handling for customcall")
|
||||
|
||||
# Simulate a tap error
|
||||
def tap_err(*args, **kwargs):
|
||||
raise ValueError("Some user message")
|
||||
@ -345,19 +353,30 @@ class HostCallbackTapTest(jtu.JaxTestCase):
|
||||
x3 = hcb.id_print(x2 + 1, what="x3", output_stream=testing_stream)
|
||||
return x3
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
if hcb._HOST_CALLBACK_LEGACY.value:
|
||||
ctx = self.assertRaisesRegex(
|
||||
hcb.CallbackException,
|
||||
re.compile("There were exceptions during callback processing. Last one was:.*"
|
||||
"ValueError: Some user message", re.DOTALL)):
|
||||
"ValueError: Some user message", re.DOTALL))
|
||||
else:
|
||||
ctx = self.assertRaisesRegex(Exception, "Some user message")
|
||||
|
||||
with ctx:
|
||||
func(0)
|
||||
hcb.barrier_wait()
|
||||
|
||||
if hcb._HOST_CALLBACK_LEGACY.value:
|
||||
# We should have received everything before the error
|
||||
assertMultiLineStrippedEqual(self, """
|
||||
what: x1
|
||||
1
|
||||
what: x3
|
||||
3""", testing_stream.output)
|
||||
else:
|
||||
# We should have received everything before the error
|
||||
assertMultiLineStrippedEqual(self, """
|
||||
what: x1
|
||||
1""", testing_stream.output)
|
||||
|
||||
def test_tap_empty(self):
|
||||
"""Tap empty arrays."""
|
||||
@ -488,6 +507,7 @@ class HostCallbackTapTest(jtu.JaxTestCase):
|
||||
|
||||
def test_tap_jit_devices(self):
|
||||
"""Running on multiple devices."""
|
||||
self.supported_only_in_legacy_mode()
|
||||
logging.info("%s: has devices %s", self._testMethodName, local_devices())
|
||||
|
||||
def func(x, device_id):
|
||||
@ -830,6 +850,7 @@ class HostCallbackTapTest(jtu.JaxTestCase):
|
||||
x3 = hcb.id_print(x2 + 1, what="x3", output_stream=testing_stream)
|
||||
return x3
|
||||
|
||||
if hcb._HOST_CALLBACK_LEGACY.value:
|
||||
res = jax.jit(func)(0) # No error yet
|
||||
with self.assertRaises(hcb.CallbackException):
|
||||
hcb.barrier_wait()
|
||||
@ -843,6 +864,10 @@ class HostCallbackTapTest(jtu.JaxTestCase):
|
||||
1
|
||||
what: x3
|
||||
3""", testing_stream.output)
|
||||
else:
|
||||
with self.assertRaisesRegex(Exception, "NotImplementedError"):
|
||||
res = jax.jit(func)(0)
|
||||
hcb.barrier_wait()
|
||||
|
||||
def test_tap_while(self):
|
||||
"""Executing while, even without JIT uses compiled code"""
|
||||
@ -878,7 +903,8 @@ class HostCallbackTapTest(jtu.JaxTestCase):
|
||||
# The output of id_print is not needed for backwards pass
|
||||
def func(x):
|
||||
return 2. * hcb.id_print(x * 3., what="x * 3",
|
||||
output_stream=testing_stream)
|
||||
output_stream=testing_stream,
|
||||
callback_flavor=hcb.CallbackFlavor.DEBUG)
|
||||
|
||||
grad_func = jax.grad(func)
|
||||
arg = jnp.float32(5.)
|
||||
@ -886,6 +912,7 @@ class HostCallbackTapTest(jtu.JaxTestCase):
|
||||
# making the Jaxpr does not print anything
|
||||
hcb.barrier_wait()
|
||||
|
||||
if hcb._HOST_CALLBACK_LEGACY.value:
|
||||
treedef = jax.tree.structure(arg)
|
||||
assertMultiLineStrippedEqual(
|
||||
self, f"""
|
||||
@ -914,9 +941,11 @@ class HostCallbackTapTest(jtu.JaxTestCase):
|
||||
|
||||
def test_tap_grad_simple(self):
|
||||
def func(x):
|
||||
y = hcb.id_print(x * 2., what="x * 2", output_stream=testing_stream)
|
||||
y = hcb.id_print(x * 2., what="x * 2", output_stream=testing_stream,
|
||||
callback_flavor=hcb.CallbackFlavor.DEBUG)
|
||||
return x * hcb.id_print(y * 3., what="y * 3",
|
||||
output_stream=testing_stream)
|
||||
output_stream=testing_stream,
|
||||
callback_flavor=hcb.CallbackFlavor.DEBUG)
|
||||
|
||||
grad_func = jax.grad(func)
|
||||
|
||||
@ -931,7 +960,8 @@ class HostCallbackTapTest(jtu.JaxTestCase):
|
||||
|
||||
def test_tap_grad_grad(self):
|
||||
def func(x):
|
||||
y = hcb.id_print(x * 2., what="x * 2", output_stream=testing_stream)
|
||||
y = hcb.id_print(x * 2., what="x * 2", output_stream=testing_stream,
|
||||
callback_flavor=hcb.CallbackFlavor.DEBUG)
|
||||
return x * (y * 3.)
|
||||
|
||||
grad_func = jax.grad(jax.grad(func))
|
||||
@ -952,7 +982,8 @@ class HostCallbackTapTest(jtu.JaxTestCase):
|
||||
def func(x):
|
||||
x4, x5 = hcb.id_print((x * 2., x * 3.), what="pair",
|
||||
result=(x * 4., x * 5.),
|
||||
output_stream=testing_stream)
|
||||
output_stream=testing_stream,
|
||||
callback_flavor=hcb.CallbackFlavor.DEBUG)
|
||||
return x4 + 2. * x5
|
||||
|
||||
x = jnp.float32(5.)
|
||||
@ -967,15 +998,18 @@ class HostCallbackTapTest(jtu.JaxTestCase):
|
||||
|
||||
def test_tap_jvp_float0(self):
|
||||
def f(x, yint):
|
||||
x, yint = hcb.id_tap(lambda arg, _: arg, (x, yint))
|
||||
x, yint = hcb.id_tap(lambda arg, _: arg, (x, yint),
|
||||
callback_flavor=hcb.CallbackFlavor.DEBUG)
|
||||
return x * yint
|
||||
|
||||
res = jax.jvp(f, (2., 3), (0.2, np.zeros((), dtypes.float0)))
|
||||
self.assertAllClose((6., 0.6), res)
|
||||
|
||||
def test_tap_grad_float0(self):
|
||||
|
||||
def func(x, yint):
|
||||
x, yint = hcb.id_print((x, yint), what="pair", output_stream=testing_stream)
|
||||
x, yint = hcb.id_print((x, yint), what="pair", output_stream=testing_stream,
|
||||
callback_flavor=hcb.CallbackFlavor.DEBUG)
|
||||
return x * yint.astype(x.dtype)
|
||||
|
||||
grad_func = jax.grad(func)
|
||||
@ -993,7 +1027,8 @@ class HostCallbackTapTest(jtu.JaxTestCase):
|
||||
x = (np.array([.7, .8], dtype=np.float32),
|
||||
np.array([11, 12, 13], dtype=np.int32))
|
||||
def f_jax(x):
|
||||
x = hcb.id_print(x, result=x, output_stream=testing_stream) # result= is important
|
||||
x = hcb.id_print(x, result=x, output_stream=testing_stream,
|
||||
callback_flavor=hcb.CallbackFlavor.DEBUG) # result= is important
|
||||
return (3. * x[0], x[1])
|
||||
|
||||
def f_jax_vjp(x):
|
||||
@ -1015,7 +1050,8 @@ class HostCallbackTapTest(jtu.JaxTestCase):
|
||||
x = (np.array([.7, .8], dtype=np.float32),
|
||||
np.array([11, 12, 13], dtype=np.int32))
|
||||
def f_jax(x):
|
||||
x = hcb.id_print(x, result=x, output_stream=testing_stream) # result= is important
|
||||
x = hcb.id_print(x, result=x, output_stream=testing_stream,
|
||||
callback_flavor=hcb.CallbackFlavor.DEBUG) # result= is important
|
||||
return (jnp.sin(x[0]), x[1])
|
||||
|
||||
def wrap_vjp(f, args, res_f_of_args):
|
||||
@ -1059,32 +1095,52 @@ class HostCallbackTapTest(jtu.JaxTestCase):
|
||||
vargs = jnp.array([jnp.float32(4.), jnp.float32(5.)])
|
||||
vmap_fun1(vargs)
|
||||
hcb.barrier_wait()
|
||||
if hcb._HOST_CALLBACK_LEGACY.value:
|
||||
assertMultiLineStrippedEqual(self, """
|
||||
transforms: [('batch', {'batch_dims': (0,)})] what: a * 2
|
||||
[ 8.00 10.00]
|
||||
transforms: [('batch', {'batch_dims': (0,)})] what: y * 3
|
||||
[24.00 30.00]""", testing_stream.output)
|
||||
else:
|
||||
assertMultiLineStrippedEqual(self, """
|
||||
what: a * 2
|
||||
8.00
|
||||
what: a * 2
|
||||
10.00
|
||||
what: y * 3
|
||||
24.00
|
||||
what: y * 3
|
||||
30.00
|
||||
""", testing_stream.output)
|
||||
|
||||
def test_tap_vmap_not_batched(self):
|
||||
x = 3.
|
||||
|
||||
def func(y):
|
||||
# x is not mapped, y is mapped
|
||||
_, y = hcb.id_print((x, y), output_stream=testing_stream)
|
||||
_, y = hcb.id_print((x, y), output_stream=testing_stream,
|
||||
callback_flavor=hcb.CallbackFlavor.DEBUG)
|
||||
return x + y
|
||||
|
||||
vmap_func = jax.vmap(func)
|
||||
vargs = jnp.array([jnp.float32(4.), jnp.float32(5.)])
|
||||
_ = vmap_func(vargs)
|
||||
hcb.barrier_wait()
|
||||
if hcb._HOST_CALLBACK_LEGACY.value:
|
||||
assertMultiLineStrippedEqual(self, """
|
||||
transforms: [('batch', {'batch_dims': (None, 0)})]
|
||||
( 3.00 [4.00 5.00] )""", testing_stream.output)
|
||||
else:
|
||||
assertMultiLineStrippedEqual(self, """
|
||||
( 3.00 4.00 )
|
||||
( 3.00 5.00 )
|
||||
""", testing_stream.output)
|
||||
|
||||
def test_tap_vmap_vmap(self):
|
||||
# A 2D tensor with x[i, j] = i + j using 2 vmap
|
||||
def sum(x, y):
|
||||
return hcb.id_print(x + y, output_stream=testing_stream)
|
||||
return hcb.id_print(x + y, output_stream=testing_stream,
|
||||
callback_flavor=hcb.CallbackFlavor.DEBUG)
|
||||
|
||||
def sum_rows(xv, y):
|
||||
return jax.vmap(sum, in_axes=(0, None))(xv, y)
|
||||
@ -1097,22 +1153,44 @@ class HostCallbackTapTest(jtu.JaxTestCase):
|
||||
# assertMultiLineStrippedEqual(self, "", str(jax.make_jaxpr(sum_all)(xv, yv)))
|
||||
_ = sum_all(xv, yv)
|
||||
hcb.barrier_wait()
|
||||
if hcb._HOST_CALLBACK_LEGACY.value:
|
||||
assertMultiLineStrippedEqual(self, """
|
||||
transforms: [('batch', {'batch_dims': (0,)}), ('batch', {'batch_dims': (0,)})]
|
||||
[[0 1 2 3 4]
|
||||
[1 2 3 4 5]
|
||||
[2 3 4 5 6]]""", testing_stream.output)
|
||||
else:
|
||||
assertMultiLineStrippedEqual(self, """
|
||||
0
|
||||
1
|
||||
2
|
||||
1
|
||||
2
|
||||
3
|
||||
2
|
||||
3
|
||||
4
|
||||
3
|
||||
4
|
||||
5
|
||||
4
|
||||
5
|
||||
6
|
||||
""", testing_stream.output)
|
||||
|
||||
def test_tap_vmap_while(self):
|
||||
"""Vmap of while."""
|
||||
|
||||
def func(x):
|
||||
# like max(x, 2)
|
||||
x1 = hcb.id_print(x, where="before:x", output_stream=testing_stream)
|
||||
x1 = hcb.id_print(x, where="before:x", output_stream=testing_stream,
|
||||
callback_flavor=hcb.CallbackFlavor.DEBUG)
|
||||
x2 = lax.while_loop(
|
||||
lambda x: x < 2, lambda x: hcb.id_print(
|
||||
x + 1, where="body:x+1", output_stream=testing_stream), x1)
|
||||
res = hcb.id_print(x2, where="after:x", output_stream=testing_stream)
|
||||
x + 1, where="body:x+1", output_stream=testing_stream,
|
||||
callback_flavor=hcb.CallbackFlavor.DEBUG), x1)
|
||||
res = hcb.id_print(x2, where="after:x", output_stream=testing_stream,
|
||||
callback_flavor=hcb.CallbackFlavor.DEBUG)
|
||||
return res
|
||||
|
||||
inputs = np.arange(5, dtype=np.int32)
|
||||
@ -1121,6 +1199,7 @@ class HostCallbackTapTest(jtu.JaxTestCase):
|
||||
jax.jit(jax.vmap(func))(inputs),
|
||||
check_dtypes=False)
|
||||
hcb.barrier_wait()
|
||||
if hcb._HOST_CALLBACK_LEGACY.value:
|
||||
assertMultiLineStrippedEqual(
|
||||
self, """
|
||||
transforms: [('batch', {'batch_dims': (0,)})] where: before:x
|
||||
@ -1131,25 +1210,32 @@ class HostCallbackTapTest(jtu.JaxTestCase):
|
||||
[2 3 3 4 5]
|
||||
transforms: [('batch', {'batch_dims': (0,)})] where: after:x
|
||||
[2 2 2 3 4]""", testing_stream.output)
|
||||
else:
|
||||
pass # order of vmaps is not guaranteed
|
||||
|
||||
def test_tap_vmap_while_tap_cond(self):
|
||||
"""Vmap of while, with a tap in the conditional."""
|
||||
|
||||
def func(x):
|
||||
# like max(x, 2)
|
||||
x1 = hcb.id_print(x, where="1", output_stream=testing_stream)
|
||||
x1 = hcb.id_print(x, where="1", output_stream=testing_stream,
|
||||
callback_flavor=hcb.CallbackFlavor.DEBUG)
|
||||
x2 = lax.while_loop(lambda x: hcb.id_print(x < 2, where="w_c",
|
||||
output_stream=testing_stream),
|
||||
output_stream=testing_stream,
|
||||
callback_flavor=hcb.CallbackFlavor.DEBUG),
|
||||
lambda x: hcb.id_print(x + 1, where="w_b",
|
||||
output_stream=testing_stream),
|
||||
output_stream=testing_stream,
|
||||
callback_flavor=hcb.CallbackFlavor.DEBUG),
|
||||
x1)
|
||||
res = hcb.id_print(x2, where="3", output_stream=testing_stream)
|
||||
res = hcb.id_print(x2, where="3", output_stream=testing_stream,
|
||||
callback_flavor=hcb.CallbackFlavor.DEBUG)
|
||||
return res
|
||||
|
||||
inputs = np.arange(5, dtype=np.int32)
|
||||
res = jax.jit(jax.vmap(func))(inputs)
|
||||
hcb.barrier_wait()
|
||||
self.assertAllClose(np.array([2, 2, 2, 3, 4]), res, check_dtypes=False)
|
||||
if hcb._HOST_CALLBACK_LEGACY.value:
|
||||
assertMultiLineStrippedEqual(self, """
|
||||
transforms: [('batch', {'batch_dims': (0,)})] where: 1
|
||||
[0 1 2 3 4]
|
||||
@ -1165,28 +1251,41 @@ class HostCallbackTapTest(jtu.JaxTestCase):
|
||||
[False False False False False]
|
||||
transforms: [('batch', {'batch_dims': (0,)})] where: 3
|
||||
[2 2 2 3 4]""", testing_stream.output)
|
||||
else:
|
||||
pass # order of vmap is not guaranteed
|
||||
|
||||
def test_tap_transforms_doc(self):
|
||||
# Examples from the documentation
|
||||
def power3(x):
|
||||
y = x * x
|
||||
# Print both 'x' and 'x^2'. Must pack as a tuple.
|
||||
hcb.id_print((x, y), what="x,x^2", output_stream=testing_stream)
|
||||
hcb.id_print((x, y), what="x,x^2", output_stream=testing_stream,
|
||||
callback_flavor=hcb.CallbackFlavor.DEBUG)
|
||||
return y * x
|
||||
|
||||
print(f"impl = {power3(3.)}")
|
||||
hcb.barrier_wait()
|
||||
if hcb._HOST_CALLBACK_LEGACY.value:
|
||||
expected = """
|
||||
what: x,x^2
|
||||
( 3. 9. )"""
|
||||
else:
|
||||
expected = """
|
||||
what: x,x^2
|
||||
( 3.0 9.0 )"""
|
||||
self.assertMultiLineStrippedEqual(expected, testing_stream.output)
|
||||
testing_stream.reset()
|
||||
|
||||
print(f"jvp = {jax.jvp(power3, (3.,), (0.1,))}")
|
||||
hcb.barrier_wait()
|
||||
if hcb._HOST_CALLBACK_LEGACY.value:
|
||||
expected = """
|
||||
what: x,x^2
|
||||
( 3. 9. )"""
|
||||
else:
|
||||
expected = """
|
||||
what: x,x^2
|
||||
( 3.0 9.0 )"""
|
||||
self.assertMultiLineStrippedEqual(expected, testing_stream.output)
|
||||
testing_stream.reset()
|
||||
|
||||
@ -1197,32 +1296,41 @@ class HostCallbackTapTest(jtu.JaxTestCase):
|
||||
@print_tangents.defjvp
|
||||
def print_tangents_jvp(primals, tangents):
|
||||
arg_dot, = tangents
|
||||
hcb.id_print(arg_dot, what="tangents", output_stream=testing_stream)
|
||||
hcb.id_print(arg_dot, what="tangents", output_stream=testing_stream,
|
||||
callback_flavor=hcb.CallbackFlavor.DEBUG)
|
||||
return primals, tangents
|
||||
|
||||
def power3_with_tangents(x):
|
||||
y = x * x
|
||||
# Print both 'x' and 'x^2'. Must pack as a tuple.
|
||||
hcb.id_print((x, y), what="x,x^2", output_stream=testing_stream)
|
||||
hcb.id_print((x, y), what="x,x^2", output_stream=testing_stream,
|
||||
callback_flavor=hcb.CallbackFlavor.DEBUG)
|
||||
print_tangents((x, y))
|
||||
return y * x
|
||||
|
||||
print(f"jvp = {jax.jvp(power3_with_tangents, (3.,), (0.1,))}")
|
||||
hcb.barrier_wait()
|
||||
if hcb._HOST_CALLBACK_LEGACY.value:
|
||||
expected = """
|
||||
what: x,x^2
|
||||
( 3. 9. )
|
||||
what: tangents
|
||||
( 0.1 0.6 )"""
|
||||
self.assertMultiLineStrippedEqual(expected, testing_stream.output)
|
||||
|
||||
testing_stream.reset()
|
||||
|
||||
print(f"grad = {jax.grad(power3)(3.)}")
|
||||
hcb.barrier_wait()
|
||||
# Only the primals by default
|
||||
if hcb._HOST_CALLBACK_LEGACY.value:
|
||||
expected = """
|
||||
what: x,x^2
|
||||
( 3. 9. )"""
|
||||
else:
|
||||
expected = """
|
||||
what: x,x^2
|
||||
( 3.0 9.0 )"""
|
||||
self.assertMultiLineStrippedEqual(expected, testing_stream.output)
|
||||
testing_stream.reset()
|
||||
|
||||
@ -1236,7 +1344,8 @@ class HostCallbackTapTest(jtu.JaxTestCase):
|
||||
return print_cotangents(arg), None
|
||||
# f_bwd: (residual, CT b) -> [CT a]
|
||||
def print_cotangents_bwd(residual, ct_b):
|
||||
hcb.id_print(ct_b, what="cotangents", output_stream=testing_stream)
|
||||
hcb.id_print(ct_b, what="cotangents", output_stream=testing_stream,
|
||||
callback_flavor=hcb.CallbackFlavor.DEBUG)
|
||||
return ct_b,
|
||||
|
||||
print_cotangents.defvjp(print_cotangents_fwd, print_cotangents_bwd)
|
||||
@ -1244,18 +1353,26 @@ class HostCallbackTapTest(jtu.JaxTestCase):
|
||||
def power3_with_cotangents(x):
|
||||
y = x * x
|
||||
# Print both 'x' and 'x^2'. Must pack as a tuple.
|
||||
hcb.id_print((x, y), what="x,x^2", output_stream=testing_stream)
|
||||
hcb.id_print((x, y), what="x,x^2", output_stream=testing_stream,
|
||||
callback_flavor=hcb.CallbackFlavor.DEBUG)
|
||||
# Must use the output of print_cotangents
|
||||
(x1, y1) = print_cotangents((x, y))
|
||||
return y1 * x1
|
||||
|
||||
print(f"grad = {jax.grad(power3_with_cotangents)(3.)}")
|
||||
hcb.barrier_wait()
|
||||
if hcb._HOST_CALLBACK_LEGACY.value:
|
||||
expected = """
|
||||
what: x,x^2
|
||||
( 3. 9. )
|
||||
what: cotangents
|
||||
( 9. 3. )"""
|
||||
else:
|
||||
expected = """
|
||||
what: x,x^2
|
||||
( 3.0 9.0 )
|
||||
what: cotangents
|
||||
( 9.0 3.0 )"""
|
||||
self.assertMultiLineStrippedEqual(expected, testing_stream.output)
|
||||
testing_stream.reset()
|
||||
|
||||
@ -1263,32 +1380,61 @@ class HostCallbackTapTest(jtu.JaxTestCase):
|
||||
|
||||
print(f"vmap = {jax.vmap(power3)(np.array([2., 3.]))}")
|
||||
hcb.barrier_wait()
|
||||
if hcb._HOST_CALLBACK_LEGACY.value:
|
||||
expected = """
|
||||
transforms: [('batch', {'batch_dims': (0, 0)})] what: x,x^2
|
||||
( [2. 3.] [4. 9.] )"""
|
||||
else:
|
||||
expected = """
|
||||
what: x,x^2
|
||||
( 2.0 4.0 )
|
||||
what: x,x^2
|
||||
( 3.0 9.0 )
|
||||
"""
|
||||
self.assertMultiLineStrippedEqual(expected, testing_stream.output)
|
||||
testing_stream.reset()
|
||||
|
||||
print(f"vmap o grad {jax.vmap(jax.grad(power3))(np.array([2., 3.]))}")
|
||||
hcb.barrier_wait()
|
||||
if hcb._HOST_CALLBACK_LEGACY.value:
|
||||
expected = """
|
||||
transforms: [('batch', {'batch_dims': (0, 0)})] what: x,x^2
|
||||
( [2. 3.] [4. 9.] )"""
|
||||
else:
|
||||
expected = """
|
||||
what: x,x^2
|
||||
( 2.0 4.0 )
|
||||
what: x,x^2
|
||||
( 3.0 9.0 )
|
||||
"""
|
||||
self.assertMultiLineStrippedEqual(expected, testing_stream.output)
|
||||
testing_stream.reset()
|
||||
|
||||
print(f"vmap o grad {jax.vmap(jax.grad(power3_with_cotangents))(np.array([2., 3.]))}")
|
||||
hcb.barrier_wait()
|
||||
if hcb._HOST_CALLBACK_LEGACY.value:
|
||||
expected = """
|
||||
transforms: [('batch', {'batch_dims': (0, 0)})] what: x,x^2
|
||||
( [2. 3.] [4. 9.] )
|
||||
transforms: [('batch', {'batch_dims': (0, 0)})] what: cotangents
|
||||
( [4. 9.] [2. 3.] )"""
|
||||
else:
|
||||
expected = """
|
||||
what: x,x^2
|
||||
( 2.0 4.0 )
|
||||
what: x,x^2
|
||||
( 3.0 9.0 )
|
||||
what: cotangents
|
||||
( 4.0 2.0 )
|
||||
what: cotangents
|
||||
( 9.0 3.0 )
|
||||
"""
|
||||
self.assertMultiLineStrippedEqual(expected, testing_stream.output)
|
||||
testing_stream.reset()
|
||||
|
||||
print(f"grad o remat = {jax.grad(lambda x: power3(ad_checkpoint.checkpoint(power3)(x)))(3.)}")
|
||||
hcb.barrier_wait()
|
||||
if hcb._HOST_CALLBACK_LEGACY.value:
|
||||
expected = """
|
||||
what: x,x^2
|
||||
( 3. 9. )
|
||||
@ -1296,10 +1442,20 @@ class HostCallbackTapTest(jtu.JaxTestCase):
|
||||
( 27. 729. )
|
||||
what: x,x^2
|
||||
( 3. 9. )"""
|
||||
else:
|
||||
expected = """
|
||||
what: x,x^2
|
||||
( 3.0 9.0 )
|
||||
what: x,x^2
|
||||
( 27.0 729.0 )
|
||||
what: x,x^2
|
||||
( 3.0 9.0 )
|
||||
"""
|
||||
self.assertMultiLineStrippedEqual(expected, testing_stream.output)
|
||||
testing_stream.reset()
|
||||
|
||||
def test_tap_pmap(self):
|
||||
self.supported_only_in_legacy_mode()
|
||||
if len(local_devices()) < 2:
|
||||
raise SkipTest("test requires at least 2 devices")
|
||||
|
||||
@ -1326,6 +1482,7 @@ class HostCallbackTapTest(jtu.JaxTestCase):
|
||||
( 4 16 )""")
|
||||
|
||||
def test_tap_pmap_vmap(self):
|
||||
self.supported_only_in_legacy_mode()
|
||||
# A matrix M[ij] = i * 10 + j
|
||||
nr_devices = len(local_devices())
|
||||
shape = (nr_devices, 3)
|
||||
@ -1353,6 +1510,7 @@ class HostCallbackTapTest(jtu.JaxTestCase):
|
||||
|
||||
def test_tap_pmap_pmap_vmap(self):
|
||||
# A matrix M[ijk] = i * 100 + j * 10 + k
|
||||
self.supported_only_in_legacy_mode()
|
||||
nr_devices = len(local_devices())
|
||||
if nr_devices % 2 != 0:
|
||||
raise SkipTest("test works only on even number of devices")
|
||||
@ -1386,6 +1544,7 @@ class HostCallbackTapTest(jtu.JaxTestCase):
|
||||
def test_tap_pmap_pmap_extra(self):
|
||||
"""pmap of a pmap surrounded by extra code."""
|
||||
# A matrix M[ij] = i * 10 + j
|
||||
self.supported_only_in_legacy_mode()
|
||||
nr_devices = len(local_devices())
|
||||
if nr_devices != 2:
|
||||
raise SkipTest("test works only on 2 devices")
|
||||
@ -1419,6 +1578,7 @@ class HostCallbackTapTest(jtu.JaxTestCase):
|
||||
[[203.00 205.00 207.00]]""")
|
||||
|
||||
def test_tap_jvp_pmap_vmap(self):
|
||||
self.supported_only_in_legacy_mode()
|
||||
# A matrix M[ijk] = i * 100 + j * 10 * k
|
||||
nr_devices = len(local_devices())
|
||||
shape = (nr_devices, 2, 3)
|
||||
@ -1445,6 +1605,7 @@ class HostCallbackTapTest(jtu.JaxTestCase):
|
||||
[220.00 222.00 224.00]]""")
|
||||
|
||||
def test_tap_vmap_pmap(self):
|
||||
self.supported_only_in_legacy_mode()
|
||||
# A matrix M[ijk] = i * 100 + j * 10 * k
|
||||
nr_devices = len(local_devices())
|
||||
shape = (2, nr_devices, 3)
|
||||
@ -1472,6 +1633,7 @@ class HostCallbackTapTest(jtu.JaxTestCase):
|
||||
@ignore_jit_of_pmap_warning()
|
||||
def test_tap_jit_pmap_extra(self):
|
||||
"""jit of a pmap surrounded by extra code."""
|
||||
self.supported_only_in_legacy_mode()
|
||||
# A matrix M[ij] = i * 10 + j
|
||||
nr_devices = len(local_devices())
|
||||
assert nr_devices in (1, 2)
|
||||
@ -1540,6 +1702,7 @@ class HostCallbackTapTest(jtu.JaxTestCase):
|
||||
|
||||
@jtu.sample_product(device_index=[0, 1])
|
||||
def test_tap_pjit(self, device_index=0):
|
||||
self.supported_only_in_legacy_mode()
|
||||
if (device_index != 0 and
|
||||
not hcb._HOST_CALLBACK_OUTFEED.value and
|
||||
jtu.test_device_matches(["cpu"])):
|
||||
@ -1589,7 +1752,7 @@ class HostCallbackTapTest(jtu.JaxTestCase):
|
||||
def test_tap_scan_custom_jvp(self):
|
||||
"""custom JVP, inside scan.
|
||||
This exercises the custom_jvp_call_jaxpr primitives."""
|
||||
|
||||
self.supported_only_in_legacy_mode()
|
||||
@jax.custom_jvp
|
||||
def f(x):
|
||||
return x * hcb.id_print(x, output_stream=testing_stream, what="x")
|
||||
@ -1633,7 +1796,7 @@ class HostCallbackTapTest(jtu.JaxTestCase):
|
||||
def test_tap_scan_custom_vjp(self):
|
||||
"""custom VJP, inside scan.
|
||||
This exercises the custom_vjp_call_jaxpr primitives."""
|
||||
|
||||
self.supported_only_in_legacy_mode()
|
||||
@jax.custom_vjp
|
||||
def f(x):
|
||||
return x * hcb.id_print(x, output_stream=testing_stream, what="x")
|
||||
@ -1773,7 +1936,7 @@ class HostCallbackTapTest(jtu.JaxTestCase):
|
||||
from jax.experimental.ode import odeint
|
||||
|
||||
def f(x, t, k):
|
||||
x = hcb.id_print(x)
|
||||
x = hcb.id_print(x, callback_flavor=hcb.CallbackFlavor.DEBUG)
|
||||
return -k * x
|
||||
|
||||
def loss(k=1.0):
|
||||
@ -1785,7 +1948,8 @@ class HostCallbackTapTest(jtu.JaxTestCase):
|
||||
|
||||
def test_tap_remat_0(self):
|
||||
def f(i, k):
|
||||
x = hcb.id_print(k + i, output_stream=testing_stream)
|
||||
x = hcb.id_print(k + i, output_stream=testing_stream,
|
||||
callback_flavor=hcb.CallbackFlavor.DEBUG)
|
||||
return k * x
|
||||
|
||||
def loss(k):
|
||||
@ -1804,6 +1968,7 @@ class HostCallbackTapTest(jtu.JaxTestCase):
|
||||
use_remat=["old", "new", "none"],
|
||||
)
|
||||
def test_tap_remat(self, use_result=False, grad_func="grad", use_remat="new"):
|
||||
self.supported_only_in_legacy_mode()
|
||||
if use_remat == "old": raise SkipTest()
|
||||
|
||||
def f(x):
|
||||
@ -1880,6 +2045,10 @@ class HostCallbackCallTest(jtu.JaxTestCase):
|
||||
hcb.barrier_wait("HostCallbackCallTest.tearDown")
|
||||
super().tearDown()
|
||||
|
||||
def supported_only_in_legacy_mode(self):
|
||||
if not hcb._HOST_CALLBACK_LEGACY.value:
|
||||
self.skipTest("Not supported when JAX_HOST_CALLBACK_LEGACY=False")
|
||||
|
||||
def call_log_testing_stream(self, func, arg, *, result_shape, name=""):
|
||||
"""Call `func` and log inputs and outputs to the testing stream"""
|
||||
|
||||
@ -1916,6 +2085,7 @@ class HostCallbackCallTest(jtu.JaxTestCase):
|
||||
with jtu.count_primitive_compiles() as count:
|
||||
for _ in range(3):
|
||||
self.assertAllClose(2 * arg, fun(arg))
|
||||
r = jax.make_jaxpr(fun)(arg)
|
||||
self.assertEqual(count[0], 1)
|
||||
|
||||
@jtu.sample_product(
|
||||
@ -2124,6 +2294,7 @@ class HostCallbackCallTest(jtu.JaxTestCase):
|
||||
helper_print_optimized_hlo(fun2, m)
|
||||
|
||||
def test_call_with_device(self):
|
||||
self.supported_only_in_legacy_mode()
|
||||
def callback_func(x, device=None):
|
||||
testing_stream.write(f"device: {device}\n Called with {x}")
|
||||
return x
|
||||
@ -2139,6 +2310,7 @@ class HostCallbackCallTest(jtu.JaxTestCase):
|
||||
Called with 3.00""")
|
||||
|
||||
def test_call_pmap(self):
|
||||
self.supported_only_in_legacy_mode()
|
||||
# Works for 1 or 2 devices
|
||||
def callback_func(x, device=None):
|
||||
testing_stream.write(f"device: {device}\n Called with {x}")
|
||||
@ -2163,11 +2335,15 @@ class HostCallbackCallTest(jtu.JaxTestCase):
|
||||
def f_outside(x): return x
|
||||
|
||||
def fun(x):
|
||||
return hcb.call(f_outside, x, result_shape=x)
|
||||
return hcb.call(f_outside, x, result_shape=x,
|
||||
callback_flavor=hcb.CallbackFlavor.PURE)
|
||||
|
||||
if hcb._HOST_CALLBACK_LEGACY.value:
|
||||
with self.assertRaisesRegex(NotImplementedError,
|
||||
"batching rules are implemented only for id_tap, not for call"):
|
||||
jax.vmap(fun)(np.ones((2, 3)))
|
||||
else:
|
||||
jax.vmap(fun)(np.ones((2, 3)))
|
||||
|
||||
@jtu.sample_product(device_index=[0, 1])
|
||||
@jtu.skip_on_devices("cpu") # TODO: RET_CHECK failure
|
||||
@ -2256,6 +2432,7 @@ class HostCallbackCallTest(jtu.JaxTestCase):
|
||||
hcb.barrier_wait("Waiting for error")
|
||||
|
||||
def test_call_error_callback_throws_exception(self):
|
||||
self.supported_only_in_legacy_mode()
|
||||
def f_outside(x):
|
||||
raise ValueError("user exception")
|
||||
def fun(x):
|
||||
@ -2265,6 +2442,7 @@ class HostCallbackCallTest(jtu.JaxTestCase):
|
||||
"ValueError: user exception")
|
||||
|
||||
def test_call_error_callback_returns_unexpected_shape(self):
|
||||
self.supported_only_in_legacy_mode()
|
||||
def fun(x):
|
||||
return hcb.call(lambda x: (x, x), x, result_shape=x)
|
||||
|
||||
@ -2272,6 +2450,7 @@ class HostCallbackCallTest(jtu.JaxTestCase):
|
||||
"Callback func .* should have returned a result with pytree")
|
||||
|
||||
def test_call_error_then_compute(self):
|
||||
self.supported_only_in_legacy_mode()
|
||||
# Continue computation on device after error
|
||||
def f_outside(x):
|
||||
raise ValueError("user exception")
|
||||
@ -2283,7 +2462,9 @@ class HostCallbackCallTest(jtu.JaxTestCase):
|
||||
"ValueError: user exception")
|
||||
|
||||
|
||||
def call_jax_other_device(jax_outside_fun, arg, *, device):
|
||||
def call_jax_other_device(
|
||||
jax_outside_fun, arg, *, device,
|
||||
callback_flavor: hcb.CallbackFlavor = hcb.CallbackFlavor.IO_CALLBACK):
|
||||
"""Calls a JAX function on a specific device with simple support for reverse AD.
|
||||
|
||||
Functions whose name starts with "jax_outside" are called on another device,
|
||||
@ -2296,7 +2477,8 @@ def call_jax_other_device(jax_outside_fun, arg, *, device):
|
||||
@jax.custom_vjp
|
||||
def make_call(arg):
|
||||
return hcb.call(run_jax_outside_fun, arg,
|
||||
result_shape=jax.eval_shape(jax_outside_fun, arg))
|
||||
result_shape=jax.eval_shape(jax_outside_fun, arg),
|
||||
callback_flavor=callback_flavor)
|
||||
|
||||
# Define the fwd and bwd custom_vjp functions
|
||||
def make_call_vjp_fwd(arg):
|
||||
@ -2323,6 +2505,8 @@ class CallJaxTest(jtu.JaxTestCase):
|
||||
"""Tests using `call_jax_other_device`."""
|
||||
|
||||
def setUp(self):
|
||||
if not hcb._HOST_CALLBACK_LEGACY.value:
|
||||
self.skipTest("Not supported when JAX_HOST_CALLBACK_LEGACY=False")
|
||||
if jtu.test_device_matches(["gpu"]) and jax.device_count() > 1:
|
||||
raise SkipTest("host_callback broken on multi-GPU platforms (#6447)")
|
||||
if xla_bridge.using_pjrt_c_api():
|
||||
@ -2337,6 +2521,7 @@ class CallJaxTest(jtu.JaxTestCase):
|
||||
self.outside_device = jax.devices("cpu")[1]
|
||||
super().setUp()
|
||||
|
||||
|
||||
def test_jax_impl(self):
|
||||
def f_jax(x):
|
||||
return jnp.sin(x)
|
||||
@ -2404,6 +2589,10 @@ class OutfeedRewriterTest(jtu.JaxTestCase):
|
||||
raise SkipTest("host_callback not implemented in PJRT C API")
|
||||
super().setUp()
|
||||
|
||||
def supported_only_in_legacy_mode(self):
|
||||
if not hcb._HOST_CALLBACK_LEGACY.value:
|
||||
self.skipTest("Not supported when JAX_HOST_CALLBACK_LEGACY=False")
|
||||
|
||||
def assertRewrite(self, expected: str, func: Callable, args: Sequence,
|
||||
has_input_token=True, has_output_token=True):
|
||||
"""Check that the rewrite of func(*args) matches expected."""
|
||||
@ -2624,7 +2813,7 @@ class OutfeedRewriterTest(jtu.JaxTestCase):
|
||||
def test_scan_custom_jvp(self):
|
||||
"""custom JVP, inside scan.
|
||||
This exercises the custom_jvp_call_jaxpr primitives."""
|
||||
|
||||
self.supported_only_in_legacy_mode()
|
||||
@jax.custom_jvp
|
||||
def f(x):
|
||||
return x * hcb.id_print(x)
|
||||
@ -2706,7 +2895,7 @@ class OutfeedRewriterTest(jtu.JaxTestCase):
|
||||
def test_scan_custom_vjp(self):
|
||||
"""custom VJP, inside scan.
|
||||
This exercises the custom_vjp_call_jaxpr primitives."""
|
||||
|
||||
self.supported_only_in_legacy_mode()
|
||||
@jax.custom_vjp
|
||||
def f(x):
|
||||
return x * hcb.id_print(x)
|
||||
@ -2849,6 +3038,7 @@ class OutfeedRewriterTest(jtu.JaxTestCase):
|
||||
in (c, d, e) }""", tap_scalar, [np.int32(3)])
|
||||
|
||||
def test_pmap(self):
|
||||
self.supported_only_in_legacy_mode()
|
||||
def f(xv):
|
||||
jax.pmap(lambda x: jnp.sin(hcb.id_print(x, tap_with_device=True)),
|
||||
axis_name="i")(xv)
|
||||
|
@ -25,8 +25,8 @@ from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
|
||||
import jax
|
||||
from jax import config
|
||||
from jax import numpy as jnp
|
||||
from jax._src import config
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src import xla_bridge
|
||||
from jax.experimental import host_callback as hcb
|
||||
@ -53,7 +53,8 @@ def call_tf_no_ad(tf_fun: Callable, arg, *, result_shape):
|
||||
|
||||
return hcb.call(lambda arg: tf.nest.map_structure(tf_to_numpy,
|
||||
tf_fun(arg)),
|
||||
arg, result_shape=result_shape)
|
||||
arg, result_shape=result_shape,
|
||||
callback_flavor=hcb.CallbackFlavor.DEBUG)
|
||||
|
||||
|
||||
def call_tf_simple_ad(tf_fun: Callable, arg, *, result_shape):
|
||||
@ -166,12 +167,17 @@ class CallToTFTest(jtu.JaxTestCase):
|
||||
raise unittest.SkipTest("host_callback not implemented in PJRT C API")
|
||||
super().setUp()
|
||||
|
||||
def supported_only_in_legacy_mode(self):
|
||||
if not hcb._HOST_CALLBACK_LEGACY.value:
|
||||
self.skipTest("Not supported when JAX_HOST_CALLBACK_LEGACY=False")
|
||||
|
||||
@parameterized.named_parameters(
|
||||
dict(
|
||||
testcase_name=f"_{ad=}",
|
||||
ad=ad)
|
||||
for ad in CALL_TF_IMPLEMENTATIONS.keys())
|
||||
def test_impl(self, ad="simple"):
|
||||
self.supported_only_in_legacy_mode()
|
||||
call_tf = CALL_TF_IMPLEMENTATIONS[ad]
|
||||
|
||||
def f_jax(x):
|
||||
@ -192,21 +198,27 @@ class CallToTFTest(jtu.JaxTestCase):
|
||||
for ad in CALL_TF_IMPLEMENTATIONS.keys()
|
||||
if ad != "none")
|
||||
def test_grad(self, ad="simple"):
|
||||
self.supported_only_in_legacy_mode()
|
||||
call_tf = CALL_TF_IMPLEMENTATIONS[ad]
|
||||
|
||||
def f_jax(x):
|
||||
return 3. * jnp.sin(2. * x)
|
||||
|
||||
def f_outside(x):
|
||||
return 3. * call_tf(tf.math.sin, 2. * x, result_shape=x)
|
||||
return 3. * call_tf(
|
||||
lambda x: tf.cast(tf.math.sin(x), tf.float32), 2. * x,
|
||||
result_shape=jax.ShapeDtypeStruct((), np.float32))
|
||||
|
||||
x = 4.
|
||||
self.assertAllClose(f_jax(x), f_outside(x))
|
||||
x = np.float32(4.)
|
||||
self.assertAllClose(f_jax(x), f_outside(x),
|
||||
check_dtypes=False)
|
||||
|
||||
grad_f = jax.grad(f_outside)(x)
|
||||
self.assertAllClose(jax.grad(f_jax)(x), grad_f)
|
||||
self.assertAllClose(jax.grad(f_jax)(x), grad_f,
|
||||
check_dtypes=False)
|
||||
|
||||
def test_grad_pytree(self):
|
||||
self.supported_only_in_legacy_mode()
|
||||
call_tf = call_tf_full_ad
|
||||
|
||||
def f_jax(xy):
|
||||
@ -215,15 +227,19 @@ class CallToTFTest(jtu.JaxTestCase):
|
||||
|
||||
def f_outside(xy):
|
||||
dict_ab = call_tf(
|
||||
lambda xy: dict(a=2. * xy[0], b=xy[0] * xy[1]),
|
||||
lambda xy: dict(a=tf.cast(2. * xy[0], np.float32),
|
||||
b=tf.cast(xy[0] * xy[1], np.float32)),
|
||||
xy,
|
||||
result_shape=dict(a=xy[0], b=xy[1]))
|
||||
result_shape=dict(a=jax.ShapeDtypeStruct((), np.float32),
|
||||
b=jax.ShapeDtypeStruct((), np.float32)))
|
||||
return 3. * dict_ab["a"] + 4. * dict_ab["b"]
|
||||
|
||||
xy = (5., 6.)
|
||||
self.assertAllClose(f_jax(xy), f_outside(xy))
|
||||
self.assertAllClose(f_jax(xy), f_outside(xy),
|
||||
check_dtypes=False)
|
||||
res_jax = jax.grad(f_jax)(xy)
|
||||
self.assertAllClose(res_jax, jax.grad(f_outside)(xy))
|
||||
self.assertAllClose(res_jax, jax.grad(f_outside)(xy),
|
||||
check_dtypes=False)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
dict(
|
||||
@ -231,6 +247,7 @@ class CallToTFTest(jtu.JaxTestCase):
|
||||
degree=degree)
|
||||
for degree in [1, 2, 3, 4])
|
||||
def test_higher_order_grad(self, degree=4):
|
||||
self.supported_only_in_legacy_mode()
|
||||
call_tf = call_tf_full_ad
|
||||
|
||||
def f_jax(x):
|
||||
|
@ -247,7 +247,7 @@ class PythonCallbackTest(jtu.JaxTestCase):
|
||||
jax.effects_barrier()
|
||||
|
||||
@with_pure_and_io_callbacks
|
||||
def test_callback_with_wrong_dtype_outputs(self, *, callback=io_callback_ordered):
|
||||
def test_callback_with_wrong_dtype_outputs(self, *, callback):
|
||||
|
||||
def _cb():
|
||||
return np.array([1], np.float64)
|
||||
|
Loading…
x
Reference in New Issue
Block a user