diff --git a/CHANGELOG.md b/CHANGELOG.md index 24860825a..1b287de3d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,12 @@ Remember to align the itemized text with the first line of an item within a list ## jax 0.4.27 +* Changes + * {func}`jax.pure_callback` and {func}`jax.experimental.io_callback` + now use {class}`jax.Array` instead of {class}`np.ndarray`. You can recover + the old behavior by transforming the arguments via + `jax.tree.map(np.asarray, args)` before passing them to the callback. + * Deprecations & Removals * Pallas now exclusively uses XLA for compiling kernels on GPU. The old lowering pass via Triton Python APIs has been removed and the @@ -66,7 +72,6 @@ Remember to align the itemized text with the first line of an item within a list This change could break clients that set a specific JAX serialization version lower than 9. - ## jaxlib 0.4.26 (April 3, 2024) * Changes diff --git a/jax/_src/callback.py b/jax/_src/callback.py index 0942d04ab..f5ab19854 100644 --- a/jax/_src/callback.py +++ b/jax/_src/callback.py @@ -14,14 +14,12 @@ """Module for JAX callbacks.""" from __future__ import annotations -import dataclasses from collections.abc import Sequence -import logging +import dataclasses import functools +import logging from typing import Any, Callable -import numpy as np - import jax from jax._src import core from jax._src import dispatch @@ -33,9 +31,10 @@ from jax._src import util from jax._src.interpreters import ad from jax._src.interpreters import batching from jax._src.interpreters import mlir -from jax._src.lib import xla_client as xc from jax._src.lax.control_flow.loops import map as lax_map +from jax._src.lib import xla_client as xc from jax._src.sharding_impls import SingleDeviceSharding +import numpy as np logger = logging.getLogger(__name__) @@ -73,11 +72,14 @@ def pure_callback_impl( vectorized: bool, ): del sharding, vectorized, result_avals - try: - return callback(*args) - except BaseException: - logger.exception("jax.pure_callback failed") - raise + cpu_device, *_ = jax.local_devices(backend="cpu") + args = tree_util.tree_map(lambda arg: jax.device_put(arg, cpu_device), args) + with jax.default_device(cpu_device): + try: + return tree_util.tree_map(np.asarray, callback(*args)) + except BaseException: + logger.exception("jax.pure_callback failed") + raise pure_callback_p.def_impl(functools.partial(dispatch.apply_primitive, @@ -398,11 +400,14 @@ def io_callback_impl( ordered: bool, ): del result_avals, sharding, ordered - try: - return callback(*args) - except BaseException: - logger.exception("jax.io_callback failed") - raise + cpu_device, *_ = jax.local_devices(backend="cpu") + args = tree_util.tree_map(lambda arg: jax.device_put(arg, cpu_device), args) + with jax.default_device(cpu_device): + try: + return tree_util.tree_map(np.asarray, callback(*args)) + except BaseException: + logger.exception("jax.io_callback failed") + raise io_callback_p.def_impl(functools.partial(dispatch.apply_primitive, diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 48b5830ea..ff588ebf3 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -428,7 +428,7 @@ class PJitTest(jtu.BufferDonationTestCase): s = NamedSharding(mesh, P('x')) def _callback(x): - self.assertIs(type(x), np.ndarray) + self.assertIsInstance(x, jax.Array) @partial(pjit, donate_argnames=('x')) def f(x): diff --git a/tests/python_callback_test.py b/tests/python_callback_test.py index ec5945be6..761975512 100644 --- a/tests/python_callback_test.py +++ b/tests/python_callback_test.py @@ -89,8 +89,7 @@ class PythonCallbackTest(jtu.JaxTestCase): def test_callback_with_scalar_values(self, *, callback): @jax.jit def f(x): - return callback(lambda x: x + np.float32(1.), - core.ShapedArray(x.shape, x.dtype), x) + return callback(lambda x: x + 1.0, core.ShapedArray(x.shape, x.dtype), x) out = f(0.) self.assertEqual(out, 1.) @@ -617,10 +616,10 @@ class PureCallbackTest(jtu.JaxTestCase): super().tearDown() dispatch.runtime_tokens.clear() - def test_pure_callback_passes_ndarrays_without_jit(self): + def test_pure_callback_passes_jax_arrays_without_jit(self): def cb(x): - self.assertIs(type(x), np.ndarray) + self.assertIsInstance(x, jax.Array) return x def f(x):