mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
jax.pure_callback and jax.experimental.io_callback now use jax.Arrays
The motivation for this change is two-fold * JAX APIs should use jax.Arrays. * Using jax.Arrays potentially allows keeping the data on device, instead of always copying it to the host. Note that the version here still always copies to the host. If this change breaks you, you can recover the old behavior by changing jax.pure_callback( f, result_shape_dtypes, *args, **kwargs, ) to jax.pure_callback( lambda *args: f(*jax.tree.map(np.asarray, args)), result_shape_dtypes, *args, **kwargs, ) so that the callback function is called with NumPy arrays as before. I will update the "External callbacks" tutorial in a follow up. PiperOrigin-RevId: 622457378
This commit is contained in:
parent
63aee94792
commit
9616900cc9
@ -8,6 +8,12 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
|
||||
## jax 0.4.27
|
||||
|
||||
* Changes
|
||||
* {func}`jax.pure_callback` and {func}`jax.experimental.io_callback`
|
||||
now use {class}`jax.Array` instead of {class}`np.ndarray`. You can recover
|
||||
the old behavior by transforming the arguments via
|
||||
`jax.tree.map(np.asarray, args)` before passing them to the callback.
|
||||
|
||||
* Deprecations & Removals
|
||||
* Pallas now exclusively uses XLA for compiling kernels on GPU. The old
|
||||
lowering pass via Triton Python APIs has been removed and the
|
||||
@ -66,7 +72,6 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
This change could break clients that set a specific
|
||||
JAX serialization version lower than 9.
|
||||
|
||||
|
||||
## jaxlib 0.4.26 (April 3, 2024)
|
||||
|
||||
* Changes
|
||||
|
@ -14,14 +14,12 @@
|
||||
"""Module for JAX callbacks."""
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
from collections.abc import Sequence
|
||||
import logging
|
||||
import dataclasses
|
||||
import functools
|
||||
import logging
|
||||
from typing import Any, Callable
|
||||
|
||||
import numpy as np
|
||||
|
||||
import jax
|
||||
from jax._src import core
|
||||
from jax._src import dispatch
|
||||
@ -33,9 +31,10 @@ from jax._src import util
|
||||
from jax._src.interpreters import ad
|
||||
from jax._src.interpreters import batching
|
||||
from jax._src.interpreters import mlir
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.lax.control_flow.loops import map as lax_map
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.sharding_impls import SingleDeviceSharding
|
||||
import numpy as np
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -73,11 +72,14 @@ def pure_callback_impl(
|
||||
vectorized: bool,
|
||||
):
|
||||
del sharding, vectorized, result_avals
|
||||
try:
|
||||
return callback(*args)
|
||||
except BaseException:
|
||||
logger.exception("jax.pure_callback failed")
|
||||
raise
|
||||
cpu_device, *_ = jax.local_devices(backend="cpu")
|
||||
args = tree_util.tree_map(lambda arg: jax.device_put(arg, cpu_device), args)
|
||||
with jax.default_device(cpu_device):
|
||||
try:
|
||||
return tree_util.tree_map(np.asarray, callback(*args))
|
||||
except BaseException:
|
||||
logger.exception("jax.pure_callback failed")
|
||||
raise
|
||||
|
||||
|
||||
pure_callback_p.def_impl(functools.partial(dispatch.apply_primitive,
|
||||
@ -398,11 +400,14 @@ def io_callback_impl(
|
||||
ordered: bool,
|
||||
):
|
||||
del result_avals, sharding, ordered
|
||||
try:
|
||||
return callback(*args)
|
||||
except BaseException:
|
||||
logger.exception("jax.io_callback failed")
|
||||
raise
|
||||
cpu_device, *_ = jax.local_devices(backend="cpu")
|
||||
args = tree_util.tree_map(lambda arg: jax.device_put(arg, cpu_device), args)
|
||||
with jax.default_device(cpu_device):
|
||||
try:
|
||||
return tree_util.tree_map(np.asarray, callback(*args))
|
||||
except BaseException:
|
||||
logger.exception("jax.io_callback failed")
|
||||
raise
|
||||
|
||||
|
||||
io_callback_p.def_impl(functools.partial(dispatch.apply_primitive,
|
||||
|
@ -428,7 +428,7 @@ class PJitTest(jtu.BufferDonationTestCase):
|
||||
s = NamedSharding(mesh, P('x'))
|
||||
|
||||
def _callback(x):
|
||||
self.assertIs(type(x), np.ndarray)
|
||||
self.assertIsInstance(x, jax.Array)
|
||||
|
||||
@partial(pjit, donate_argnames=('x'))
|
||||
def f(x):
|
||||
|
@ -89,8 +89,7 @@ class PythonCallbackTest(jtu.JaxTestCase):
|
||||
def test_callback_with_scalar_values(self, *, callback):
|
||||
@jax.jit
|
||||
def f(x):
|
||||
return callback(lambda x: x + np.float32(1.),
|
||||
core.ShapedArray(x.shape, x.dtype), x)
|
||||
return callback(lambda x: x + 1.0, core.ShapedArray(x.shape, x.dtype), x)
|
||||
|
||||
out = f(0.)
|
||||
self.assertEqual(out, 1.)
|
||||
@ -617,10 +616,10 @@ class PureCallbackTest(jtu.JaxTestCase):
|
||||
super().tearDown()
|
||||
dispatch.runtime_tokens.clear()
|
||||
|
||||
def test_pure_callback_passes_ndarrays_without_jit(self):
|
||||
def test_pure_callback_passes_jax_arrays_without_jit(self):
|
||||
|
||||
def cb(x):
|
||||
self.assertIs(type(x), np.ndarray)
|
||||
self.assertIsInstance(x, jax.Array)
|
||||
return x
|
||||
|
||||
def f(x):
|
||||
|
Loading…
x
Reference in New Issue
Block a user