jax.debug.callback now passes arguments as jax.Arrays

Prior to this change the behavior in eager and under jax.jit was inconsistent

    >>> (lambda *args: jax.debug.callback(print, *args))([42])
    [42]
    >>> jax.jit(lambda *args: jax.debug.callback(print, *args))([42])
    [array(42, dtype=int32)]

It was also inconsistent with other callback APIs, which cast the arguments
to jax.Arrays.

Closes #20627.

PiperOrigin-RevId: 626461904
This commit is contained in:
Sergei Lebedev 2024-04-19 13:56:28 -07:00 committed by jax authors
parent 32922f61e9
commit 6e23c14f85
4 changed files with 28 additions and 16 deletions

View File

@ -14,10 +14,11 @@ Remember to align the itemized text with the first line of an item within a list
adopted by NumPy. adopted by NumPy.
* Changes * Changes
* {func}`jax.pure_callback` and {func}`jax.experimental.io_callback` * {func}`jax.pure_callback`, {func}`jax.experimental.io_callback`
now use {class}`jax.Array` instead of {class}`np.ndarray`. You can recover and {func}`jax.debug.callback` now use {class}`jax.Array` instead
the old behavior by transforming the arguments via of {class}`np.ndarray`. You can recover the old behavior by transforming
`jax.tree.map(np.asarray, args)` before passing them to the callback. the arguments via `jax.tree.map(np.asarray, args)` before passing them
to the callback.
* `complex_arr.astype(bool)` now follows the same semantics as NumPy, returning * `complex_arr.astype(bool)` now follows the same semantics as NumPy, returning
False where `complex_arr` is equal to `0 + 0j`, and True otherwise. False where `complex_arr` is equal to `0 + 0j`, and True otherwise.
* Async dispatch expensive computations on the CPU backend. This only applies * Async dispatch expensive computations on the CPU backend. This only applies

View File

@ -18,6 +18,7 @@ from __future__ import annotations
import importlib.util import importlib.util
from collections.abc import Sequence from collections.abc import Sequence
import functools import functools
import logging
import string import string
import sys import sys
from typing import Any, Callable, Union from typing import Any, Callable, Union
@ -25,6 +26,7 @@ import weakref
import numpy as np import numpy as np
import jax
import jax.numpy as jnp import jax.numpy as jnp
from jax import lax from jax import lax
@ -45,6 +47,8 @@ from jax._src.lib.mlir.dialects import hlo
from jax._src.sharding import Sharding from jax._src.sharding import Sharding
from jax._src.sharding_impls import NamedSharding, parse_flatten_op_sharding from jax._src.sharding_impls import NamedSharding, parse_flatten_op_sharding
logger = logging.getLogger(__name__)
class DebugEffect(effects.Effect): class DebugEffect(effects.Effect):
__str__ = lambda self: "Debug" __str__ = lambda self: "Debug"
debug_effect = DebugEffect() debug_effect = DebugEffect()
@ -73,7 +77,14 @@ map, unsafe_map = util.safe_map, map
def debug_callback_impl(*args, callback: Callable[..., Any], def debug_callback_impl(*args, callback: Callable[..., Any],
effect: DebugEffect): effect: DebugEffect):
del effect del effect
callback(*args) cpu_device, *_ = jax.local_devices(backend="cpu")
args = jax.device_put(args, cpu_device)
with jax.default_device(cpu_device):
try:
callback(*args)
except BaseException:
logger.exception("jax.debug_callback failed")
raise
return () return ()
@debug_callback_p.def_effectful_abstract_eval @debug_callback_p.def_effectful_abstract_eval

View File

@ -110,7 +110,7 @@ class CliDebuggerTest(jtu.JaxTestCase):
return y return y
expected = _format_multiline(r""" expected = _format_multiline(r"""
Entering jdb: Entering jdb:
(jdb) array(2., dtype=float32) (jdb) Array(2., dtype=float32)
(jdb) """) (jdb) """)
f(jnp.array(2., jnp.float32)) f(jnp.array(2., jnp.float32))
jax.effects_barrier() jax.effects_barrier()
@ -126,7 +126,7 @@ class CliDebuggerTest(jtu.JaxTestCase):
return y return y
expected = _format_multiline(r""" expected = _format_multiline(r"""
Entering jdb: Entering jdb:
(jdb) (array(2., dtype=float32), array(3., dtype=float32)) (jdb) (Array(2., dtype=float32), Array(3., dtype=float32))
(jdb) """) (jdb) """)
f(jnp.array(2., jnp.float32)) f(jnp.array(2., jnp.float32))
jax.effects_barrier() jax.effects_barrier()
@ -196,7 +196,7 @@ class CliDebuggerTest(jtu.JaxTestCase):
-> y = f\(x\) -> y = f\(x\)
return jnp\.exp\(y\) return jnp\.exp\(y\)
.* .*
\(jdb\) array\(2\., dtype=float32\) \(jdb\) Array\(2\., dtype=float32\)
\(jdb\) > .*debugger_test\.py\([0-9]+\) \(jdb\) > .*debugger_test\.py\([0-9]+\)
def f\(x\): def f\(x\):
y = jnp\.sin\(x\) y = jnp\.sin\(x\)
@ -225,9 +225,9 @@ class CliDebuggerTest(jtu.JaxTestCase):
return jnp.exp(y) return jnp.exp(y)
expected = _format_multiline(r""" expected = _format_multiline(r"""
Entering jdb: Entering jdb:
(jdb) array(3., dtype=float32) (jdb) Array(3., dtype=float32)
(jdb) Entering jdb: (jdb) Entering jdb:
(jdb) array(6., dtype=float32) (jdb) Array(6., dtype=float32)
(jdb) """) (jdb) """)
g(jnp.array(2., jnp.float32)) g(jnp.array(2., jnp.float32))
jax.effects_barrier() jax.effects_barrier()
@ -249,9 +249,9 @@ class CliDebuggerTest(jtu.JaxTestCase):
return jnp.exp(y) return jnp.exp(y)
expected = _format_multiline(r""" expected = _format_multiline(r"""
Entering jdb: Entering jdb:
(jdb) array(1., dtype=float32) (jdb) Array(1., dtype=float32)
(jdb) Entering jdb: (jdb) Entering jdb:
(jdb) array(2., dtype=float32) (jdb) Array(2., dtype=float32)
(jdb) """) (jdb) """)
g(jnp.arange(2., dtype=jnp.float32)) g(jnp.arange(2., dtype=jnp.float32))
jax.effects_barrier() jax.effects_barrier()
@ -274,9 +274,9 @@ class CliDebuggerTest(jtu.JaxTestCase):
return jnp.exp(y) return jnp.exp(y)
expected = _format_multiline(r""" expected = _format_multiline(r"""
Entering jdb: Entering jdb:
\(jdb\) array\(.*, dtype=float32\) \(jdb\) Array\(.*, dtype=float32\)
\(jdb\) Entering jdb: \(jdb\) Entering jdb:
\(jdb\) array\(.*, dtype=float32\) \(jdb\) Array\(.*, dtype=float32\)
\(jdb\) """) \(jdb\) """)
g(jnp.arange(2., dtype=jnp.float32)) g(jnp.arange(2., dtype=jnp.float32))
jax.effects_barrier() jax.effects_barrier()
@ -302,7 +302,7 @@ class CliDebuggerTest(jtu.JaxTestCase):
out_shardings=jax.sharding.PartitionSpec("dev"), out_shardings=jax.sharding.PartitionSpec("dev"),
) )
with jax.sharding.Mesh(np.array(jax.devices()), ["dev"]): with jax.sharding.Mesh(np.array(jax.devices()), ["dev"]):
arr = (1 + np.arange(8)).astype(np.int32) arr = (1 + jnp.arange(8)).astype(np.int32)
expected = _format_multiline(r""" expected = _format_multiline(r"""
Entering jdb: Entering jdb:
\(jdb\) {} \(jdb\) {}

View File

@ -170,7 +170,7 @@ class DebugPrintTest(jtu.JaxTestCase):
with jtu.capture_stdout() as output: with jtu.capture_stdout() as output:
f(np.array(2, np.int32)) f(np.array(2, np.int32))
jax.effects_barrier() jax.effects_barrier()
self.assertEqual(output(), f"x: {str(dict(foo=np.array(2, np.int32)))}\n") self.assertEqual(output(), f"x: {str(dict(foo=jnp.array(2, np.int32)))}\n")
def test_debug_print_should_use_default_layout(self): def test_debug_print_should_use_default_layout(self):
data = np.array( data = np.array(