mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
jax.debug.callback now passes arguments as jax.Arrays
Prior to this change the behavior in eager and under jax.jit was inconsistent >>> (lambda *args: jax.debug.callback(print, *args))([42]) [42] >>> jax.jit(lambda *args: jax.debug.callback(print, *args))([42]) [array(42, dtype=int32)] It was also inconsistent with other callback APIs, which cast the arguments to jax.Arrays. Closes #20627. PiperOrigin-RevId: 626461904
This commit is contained in:
parent
32922f61e9
commit
6e23c14f85
@ -14,10 +14,11 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
adopted by NumPy.
|
||||
|
||||
* Changes
|
||||
* {func}`jax.pure_callback` and {func}`jax.experimental.io_callback`
|
||||
now use {class}`jax.Array` instead of {class}`np.ndarray`. You can recover
|
||||
the old behavior by transforming the arguments via
|
||||
`jax.tree.map(np.asarray, args)` before passing them to the callback.
|
||||
* {func}`jax.pure_callback`, {func}`jax.experimental.io_callback`
|
||||
and {func}`jax.debug.callback` now use {class}`jax.Array` instead
|
||||
of {class}`np.ndarray`. You can recover the old behavior by transforming
|
||||
the arguments via `jax.tree.map(np.asarray, args)` before passing them
|
||||
to the callback.
|
||||
* `complex_arr.astype(bool)` now follows the same semantics as NumPy, returning
|
||||
False where `complex_arr` is equal to `0 + 0j`, and True otherwise.
|
||||
* Async dispatch expensive computations on the CPU backend. This only applies
|
||||
|
@ -18,6 +18,7 @@ from __future__ import annotations
|
||||
import importlib.util
|
||||
from collections.abc import Sequence
|
||||
import functools
|
||||
import logging
|
||||
import string
|
||||
import sys
|
||||
from typing import Any, Callable, Union
|
||||
@ -25,6 +26,7 @@ import weakref
|
||||
|
||||
import numpy as np
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from jax import lax
|
||||
|
||||
@ -45,6 +47,8 @@ from jax._src.lib.mlir.dialects import hlo
|
||||
from jax._src.sharding import Sharding
|
||||
from jax._src.sharding_impls import NamedSharding, parse_flatten_op_sharding
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class DebugEffect(effects.Effect):
|
||||
__str__ = lambda self: "Debug"
|
||||
debug_effect = DebugEffect()
|
||||
@ -73,7 +77,14 @@ map, unsafe_map = util.safe_map, map
|
||||
def debug_callback_impl(*args, callback: Callable[..., Any],
|
||||
effect: DebugEffect):
|
||||
del effect
|
||||
callback(*args)
|
||||
cpu_device, *_ = jax.local_devices(backend="cpu")
|
||||
args = jax.device_put(args, cpu_device)
|
||||
with jax.default_device(cpu_device):
|
||||
try:
|
||||
callback(*args)
|
||||
except BaseException:
|
||||
logger.exception("jax.debug_callback failed")
|
||||
raise
|
||||
return ()
|
||||
|
||||
@debug_callback_p.def_effectful_abstract_eval
|
||||
|
@ -110,7 +110,7 @@ class CliDebuggerTest(jtu.JaxTestCase):
|
||||
return y
|
||||
expected = _format_multiline(r"""
|
||||
Entering jdb:
|
||||
(jdb) array(2., dtype=float32)
|
||||
(jdb) Array(2., dtype=float32)
|
||||
(jdb) """)
|
||||
f(jnp.array(2., jnp.float32))
|
||||
jax.effects_barrier()
|
||||
@ -126,7 +126,7 @@ class CliDebuggerTest(jtu.JaxTestCase):
|
||||
return y
|
||||
expected = _format_multiline(r"""
|
||||
Entering jdb:
|
||||
(jdb) (array(2., dtype=float32), array(3., dtype=float32))
|
||||
(jdb) (Array(2., dtype=float32), Array(3., dtype=float32))
|
||||
(jdb) """)
|
||||
f(jnp.array(2., jnp.float32))
|
||||
jax.effects_barrier()
|
||||
@ -196,7 +196,7 @@ class CliDebuggerTest(jtu.JaxTestCase):
|
||||
-> y = f\(x\)
|
||||
return jnp\.exp\(y\)
|
||||
.*
|
||||
\(jdb\) array\(2\., dtype=float32\)
|
||||
\(jdb\) Array\(2\., dtype=float32\)
|
||||
\(jdb\) > .*debugger_test\.py\([0-9]+\)
|
||||
def f\(x\):
|
||||
y = jnp\.sin\(x\)
|
||||
@ -225,9 +225,9 @@ class CliDebuggerTest(jtu.JaxTestCase):
|
||||
return jnp.exp(y)
|
||||
expected = _format_multiline(r"""
|
||||
Entering jdb:
|
||||
(jdb) array(3., dtype=float32)
|
||||
(jdb) Array(3., dtype=float32)
|
||||
(jdb) Entering jdb:
|
||||
(jdb) array(6., dtype=float32)
|
||||
(jdb) Array(6., dtype=float32)
|
||||
(jdb) """)
|
||||
g(jnp.array(2., jnp.float32))
|
||||
jax.effects_barrier()
|
||||
@ -249,9 +249,9 @@ class CliDebuggerTest(jtu.JaxTestCase):
|
||||
return jnp.exp(y)
|
||||
expected = _format_multiline(r"""
|
||||
Entering jdb:
|
||||
(jdb) array(1., dtype=float32)
|
||||
(jdb) Array(1., dtype=float32)
|
||||
(jdb) Entering jdb:
|
||||
(jdb) array(2., dtype=float32)
|
||||
(jdb) Array(2., dtype=float32)
|
||||
(jdb) """)
|
||||
g(jnp.arange(2., dtype=jnp.float32))
|
||||
jax.effects_barrier()
|
||||
@ -274,9 +274,9 @@ class CliDebuggerTest(jtu.JaxTestCase):
|
||||
return jnp.exp(y)
|
||||
expected = _format_multiline(r"""
|
||||
Entering jdb:
|
||||
\(jdb\) array\(.*, dtype=float32\)
|
||||
\(jdb\) Array\(.*, dtype=float32\)
|
||||
\(jdb\) Entering jdb:
|
||||
\(jdb\) array\(.*, dtype=float32\)
|
||||
\(jdb\) Array\(.*, dtype=float32\)
|
||||
\(jdb\) """)
|
||||
g(jnp.arange(2., dtype=jnp.float32))
|
||||
jax.effects_barrier()
|
||||
@ -302,7 +302,7 @@ class CliDebuggerTest(jtu.JaxTestCase):
|
||||
out_shardings=jax.sharding.PartitionSpec("dev"),
|
||||
)
|
||||
with jax.sharding.Mesh(np.array(jax.devices()), ["dev"]):
|
||||
arr = (1 + np.arange(8)).astype(np.int32)
|
||||
arr = (1 + jnp.arange(8)).astype(np.int32)
|
||||
expected = _format_multiline(r"""
|
||||
Entering jdb:
|
||||
\(jdb\) {}
|
||||
|
@ -170,7 +170,7 @@ class DebugPrintTest(jtu.JaxTestCase):
|
||||
with jtu.capture_stdout() as output:
|
||||
f(np.array(2, np.int32))
|
||||
jax.effects_barrier()
|
||||
self.assertEqual(output(), f"x: {str(dict(foo=np.array(2, np.int32)))}\n")
|
||||
self.assertEqual(output(), f"x: {str(dict(foo=jnp.array(2, np.int32)))}\n")
|
||||
|
||||
def test_debug_print_should_use_default_layout(self):
|
||||
data = np.array(
|
||||
|
Loading…
x
Reference in New Issue
Block a user