jax.debug.callback now passes arguments as jax.Arrays

Prior to this change the behavior in eager and under jax.jit was inconsistent

    >>> (lambda *args: jax.debug.callback(print, *args))([42])
    [42]
    >>> jax.jit(lambda *args: jax.debug.callback(print, *args))([42])
    [array(42, dtype=int32)]

It was also inconsistent with other callback APIs, which cast the arguments
to jax.Arrays.

Closes #20627.

PiperOrigin-RevId: 626461904
This commit is contained in:
Sergei Lebedev 2024-04-19 13:56:28 -07:00 committed by jax authors
parent 32922f61e9
commit 6e23c14f85
4 changed files with 28 additions and 16 deletions

View File

@ -14,10 +14,11 @@ Remember to align the itemized text with the first line of an item within a list
adopted by NumPy.
* Changes
* {func}`jax.pure_callback` and {func}`jax.experimental.io_callback`
now use {class}`jax.Array` instead of {class}`np.ndarray`. You can recover
the old behavior by transforming the arguments via
`jax.tree.map(np.asarray, args)` before passing them to the callback.
* {func}`jax.pure_callback`, {func}`jax.experimental.io_callback`
and {func}`jax.debug.callback` now use {class}`jax.Array` instead
of {class}`np.ndarray`. You can recover the old behavior by transforming
the arguments via `jax.tree.map(np.asarray, args)` before passing them
to the callback.
* `complex_arr.astype(bool)` now follows the same semantics as NumPy, returning
False where `complex_arr` is equal to `0 + 0j`, and True otherwise.
* Async dispatch expensive computations on the CPU backend. This only applies

View File

@ -18,6 +18,7 @@ from __future__ import annotations
import importlib.util
from collections.abc import Sequence
import functools
import logging
import string
import sys
from typing import Any, Callable, Union
@ -25,6 +26,7 @@ import weakref
import numpy as np
import jax
import jax.numpy as jnp
from jax import lax
@ -45,6 +47,8 @@ from jax._src.lib.mlir.dialects import hlo
from jax._src.sharding import Sharding
from jax._src.sharding_impls import NamedSharding, parse_flatten_op_sharding
logger = logging.getLogger(__name__)
class DebugEffect(effects.Effect):
__str__ = lambda self: "Debug"
debug_effect = DebugEffect()
@ -73,7 +77,14 @@ map, unsafe_map = util.safe_map, map
def debug_callback_impl(*args, callback: Callable[..., Any],
effect: DebugEffect):
del effect
cpu_device, *_ = jax.local_devices(backend="cpu")
args = jax.device_put(args, cpu_device)
with jax.default_device(cpu_device):
try:
callback(*args)
except BaseException:
logger.exception("jax.debug_callback failed")
raise
return ()
@debug_callback_p.def_effectful_abstract_eval

View File

@ -110,7 +110,7 @@ class CliDebuggerTest(jtu.JaxTestCase):
return y
expected = _format_multiline(r"""
Entering jdb:
(jdb) array(2., dtype=float32)
(jdb) Array(2., dtype=float32)
(jdb) """)
f(jnp.array(2., jnp.float32))
jax.effects_barrier()
@ -126,7 +126,7 @@ class CliDebuggerTest(jtu.JaxTestCase):
return y
expected = _format_multiline(r"""
Entering jdb:
(jdb) (array(2., dtype=float32), array(3., dtype=float32))
(jdb) (Array(2., dtype=float32), Array(3., dtype=float32))
(jdb) """)
f(jnp.array(2., jnp.float32))
jax.effects_barrier()
@ -196,7 +196,7 @@ class CliDebuggerTest(jtu.JaxTestCase):
-> y = f\(x\)
return jnp\.exp\(y\)
.*
\(jdb\) array\(2\., dtype=float32\)
\(jdb\) Array\(2\., dtype=float32\)
\(jdb\) > .*debugger_test\.py\([0-9]+\)
def f\(x\):
y = jnp\.sin\(x\)
@ -225,9 +225,9 @@ class CliDebuggerTest(jtu.JaxTestCase):
return jnp.exp(y)
expected = _format_multiline(r"""
Entering jdb:
(jdb) array(3., dtype=float32)
(jdb) Array(3., dtype=float32)
(jdb) Entering jdb:
(jdb) array(6., dtype=float32)
(jdb) Array(6., dtype=float32)
(jdb) """)
g(jnp.array(2., jnp.float32))
jax.effects_barrier()
@ -249,9 +249,9 @@ class CliDebuggerTest(jtu.JaxTestCase):
return jnp.exp(y)
expected = _format_multiline(r"""
Entering jdb:
(jdb) array(1., dtype=float32)
(jdb) Array(1., dtype=float32)
(jdb) Entering jdb:
(jdb) array(2., dtype=float32)
(jdb) Array(2., dtype=float32)
(jdb) """)
g(jnp.arange(2., dtype=jnp.float32))
jax.effects_barrier()
@ -274,9 +274,9 @@ class CliDebuggerTest(jtu.JaxTestCase):
return jnp.exp(y)
expected = _format_multiline(r"""
Entering jdb:
\(jdb\) array\(.*, dtype=float32\)
\(jdb\) Array\(.*, dtype=float32\)
\(jdb\) Entering jdb:
\(jdb\) array\(.*, dtype=float32\)
\(jdb\) Array\(.*, dtype=float32\)
\(jdb\) """)
g(jnp.arange(2., dtype=jnp.float32))
jax.effects_barrier()
@ -302,7 +302,7 @@ class CliDebuggerTest(jtu.JaxTestCase):
out_shardings=jax.sharding.PartitionSpec("dev"),
)
with jax.sharding.Mesh(np.array(jax.devices()), ["dev"]):
arr = (1 + np.arange(8)).astype(np.int32)
arr = (1 + jnp.arange(8)).astype(np.int32)
expected = _format_multiline(r"""
Entering jdb:
\(jdb\) {}

View File

@ -170,7 +170,7 @@ class DebugPrintTest(jtu.JaxTestCase):
with jtu.capture_stdout() as output:
f(np.array(2, np.int32))
jax.effects_barrier()
self.assertEqual(output(), f"x: {str(dict(foo=np.array(2, np.int32)))}\n")
self.assertEqual(output(), f"x: {str(dict(foo=jnp.array(2, np.int32)))}\n")
def test_debug_print_should_use_default_layout(self):
data = np.array(