jax.pure_callback and jax.experimental.io_callback now use jax.Arrays

The motivation for this change is twofold:

* JAX APIs should use jax.Arrays.
* Using jax.Arrays makes it possible to keep the data on device instead
  of always copying it to the host. Note that the implementation in this
  change still always copies to the host.

If this change breaks you, you can recover the old behavior by changing

    jax.pure_callback(
        f,
        result_shape_dtypes,
        *args,
        **kwargs,
    )

to

    jax.pure_callback(
        lambda *args: f(*jax.tree.map(np.asarray, args)),
        result_shape_dtypes,
        *args,
        **kwargs,
    )

so that the callback function is called with NumPy arrays as before.
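
As a concrete illustration, here is a minimal, self-contained sketch of that
wrapper (the host function `host_log1p` and the surrounding setup are
illustrative, not part of this change):

    import jax
    import jax.numpy as jnp
    import numpy as np

    def host_log1p(x):
        # Host-side function written against NumPy arrays.
        return np.log1p(x)

    @jax.jit
    def f(x):
        result_shape = jax.ShapeDtypeStruct(x.shape, x.dtype)
        # Convert the jax.Array arguments back to NumPy before calling the
        # host function, mirroring the pre-change behavior.
        return jax.pure_callback(
            lambda *args: host_log1p(*jax.tree.map(np.asarray, args)),
            result_shape,
            x,
        )

    f(jnp.arange(4.0))  # same result as before this change

The lambda converts every jax.Array argument back to a NumPy array before the
host function runs, which is what the previous implementation did implicitly.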

I will update the "External callbacks" tutorial in a follow-up.

PiperOrigin-RevId: 622457378
Sergei Lebedev 2024-04-06 09:29:16 -07:00 committed by jax authors
parent 63aee94792
commit 9616900cc9
4 changed files with 30 additions and 21 deletions


@@ -8,6 +8,12 @@ Remember to align the itemized text with the first line of an item within a list
 ## jax 0.4.27

+* Changes
+  * {func}`jax.pure_callback` and {func}`jax.experimental.io_callback`
+    now use {class}`jax.Array` instead of {class}`np.ndarray`. You can recover
+    the old behavior by transforming the arguments via
+    `jax.tree.map(np.asarray, args)` before passing them to the callback.
+
 * Deprecations & Removals
   * Pallas now exclusively uses XLA for compiling kernels on GPU. The old
     lowering pass via Triton Python APIs has been removed and the
@@ -66,7 +72,6 @@ Remember to align the itemized text with the first line of an item within a list
This change could break clients that set a specific
JAX serialization version lower than 9.
## jaxlib 0.4.26 (April 3, 2024)
* Changes


@@ -14,14 +14,12 @@
"""Module for JAX callbacks."""
from __future__ import annotations
import dataclasses
from collections.abc import Sequence
import logging
import dataclasses
import functools
import logging
from typing import Any, Callable
import numpy as np
import jax
from jax._src import core
from jax._src import dispatch
@@ -33,9 +31,10 @@ from jax._src import util
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
from jax._src.lib import xla_client as xc
from jax._src.lax.control_flow.loops import map as lax_map
from jax._src.lib import xla_client as xc
from jax._src.sharding_impls import SingleDeviceSharding
import numpy as np
logger = logging.getLogger(__name__)
@@ -73,11 +72,14 @@ def pure_callback_impl(
     vectorized: bool,
 ):
   del sharding, vectorized, result_avals
-  try:
-    return callback(*args)
-  except BaseException:
-    logger.exception("jax.pure_callback failed")
-    raise
+  cpu_device, *_ = jax.local_devices(backend="cpu")
+  args = tree_util.tree_map(lambda arg: jax.device_put(arg, cpu_device), args)
+  with jax.default_device(cpu_device):
+    try:
+      return tree_util.tree_map(np.asarray, callback(*args))
+    except BaseException:
+      logger.exception("jax.pure_callback failed")
+      raise


 pure_callback_p.def_impl(functools.partial(dispatch.apply_primitive,
@@ -398,11 +400,14 @@ def io_callback_impl(
     ordered: bool,
 ):
   del result_avals, sharding, ordered
-  try:
-    return callback(*args)
-  except BaseException:
-    logger.exception("jax.io_callback failed")
-    raise
+  cpu_device, *_ = jax.local_devices(backend="cpu")
+  args = tree_util.tree_map(lambda arg: jax.device_put(arg, cpu_device), args)
+  with jax.default_device(cpu_device):
+    try:
+      return tree_util.tree_map(np.asarray, callback(*args))
+    except BaseException:
+      logger.exception("jax.io_callback failed")
+      raise


 io_callback_p.def_impl(functools.partial(dispatch.apply_primitive,


@@ -428,7 +428,7 @@ class PJitTest(jtu.BufferDonationTestCase):
     s = NamedSharding(mesh, P('x'))

     def _callback(x):
-      self.assertIs(type(x), np.ndarray)
+      self.assertIsInstance(x, jax.Array)

     @partial(pjit, donate_argnames=('x'))
     def f(x):


@@ -89,8 +89,7 @@ class PythonCallbackTest(jtu.JaxTestCase):
   def test_callback_with_scalar_values(self, *, callback):
     @jax.jit
     def f(x):
-      return callback(lambda x: x + np.float32(1.),
-                      core.ShapedArray(x.shape, x.dtype), x)
+      return callback(lambda x: x + 1.0, core.ShapedArray(x.shape, x.dtype), x)

     out = f(0.)
     self.assertEqual(out, 1.)
@@ -617,10 +616,10 @@ class PureCallbackTest(jtu.JaxTestCase):
     super().tearDown()
     dispatch.runtime_tokens.clear()

-  def test_pure_callback_passes_ndarrays_without_jit(self):
+  def test_pure_callback_passes_jax_arrays_without_jit(self):
     def cb(x):
-      self.assertIs(type(x), np.ndarray)
+      self.assertIsInstance(x, jax.Array)
       return x

     def f(x):