mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
jax.debug.callback now passes arguments as jax.Arrays
Prior to this change the behavior in eager and under jax.jit was inconsistent >>> (lambda *args: jax.debug.callback(print, *args))([42]) [42] >>> jax.jit(lambda *args: jax.debug.callback(print, *args))([42]) [array(42, dtype=int32)] It was also inconsistent with other callback APIs, which cast the arguments to jax.Arrays. Closes #20627. PiperOrigin-RevId: 626461904
This commit is contained in:
parent
32922f61e9
commit
6e23c14f85
@ -14,10 +14,11 @@ Remember to align the itemized text with the first line of an item within a list
|
|||||||
adopted by NumPy.
|
adopted by NumPy.
|
||||||
|
|
||||||
* Changes
|
* Changes
|
||||||
* {func}`jax.pure_callback` and {func}`jax.experimental.io_callback`
|
* {func}`jax.pure_callback`, {func}`jax.experimental.io_callback`
|
||||||
now use {class}`jax.Array` instead of {class}`np.ndarray`. You can recover
|
and {func}`jax.debug.callback` now use {class}`jax.Array` instead
|
||||||
the old behavior by transforming the arguments via
|
of {class}`np.ndarray`. You can recover the old behavior by transforming
|
||||||
`jax.tree.map(np.asarray, args)` before passing them to the callback.
|
the arguments via `jax.tree.map(np.asarray, args)` before passing them
|
||||||
|
to the callback.
|
||||||
* `complex_arr.astype(bool)` now follows the same semantics as NumPy, returning
|
* `complex_arr.astype(bool)` now follows the same semantics as NumPy, returning
|
||||||
False where `complex_arr` is equal to `0 + 0j`, and True otherwise.
|
False where `complex_arr` is equal to `0 + 0j`, and True otherwise.
|
||||||
* Async dispatch expensive computations on the CPU backend. This only applies
|
* Async dispatch expensive computations on the CPU backend. This only applies
|
||||||
|
@ -18,6 +18,7 @@ from __future__ import annotations
|
|||||||
import importlib.util
|
import importlib.util
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
import functools
|
import functools
|
||||||
|
import logging
|
||||||
import string
|
import string
|
||||||
import sys
|
import sys
|
||||||
from typing import Any, Callable, Union
|
from typing import Any, Callable, Union
|
||||||
@ -25,6 +26,7 @@ import weakref
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
from jax import lax
|
from jax import lax
|
||||||
|
|
||||||
@ -45,6 +47,8 @@ from jax._src.lib.mlir.dialects import hlo
|
|||||||
from jax._src.sharding import Sharding
|
from jax._src.sharding import Sharding
|
||||||
from jax._src.sharding_impls import NamedSharding, parse_flatten_op_sharding
|
from jax._src.sharding_impls import NamedSharding, parse_flatten_op_sharding
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
class DebugEffect(effects.Effect):
|
class DebugEffect(effects.Effect):
|
||||||
__str__ = lambda self: "Debug"
|
__str__ = lambda self: "Debug"
|
||||||
debug_effect = DebugEffect()
|
debug_effect = DebugEffect()
|
||||||
@ -73,7 +77,14 @@ map, unsafe_map = util.safe_map, map
|
|||||||
def debug_callback_impl(*args, callback: Callable[..., Any],
|
def debug_callback_impl(*args, callback: Callable[..., Any],
|
||||||
effect: DebugEffect):
|
effect: DebugEffect):
|
||||||
del effect
|
del effect
|
||||||
callback(*args)
|
cpu_device, *_ = jax.local_devices(backend="cpu")
|
||||||
|
args = jax.device_put(args, cpu_device)
|
||||||
|
with jax.default_device(cpu_device):
|
||||||
|
try:
|
||||||
|
callback(*args)
|
||||||
|
except BaseException:
|
||||||
|
logger.exception("jax.debug_callback failed")
|
||||||
|
raise
|
||||||
return ()
|
return ()
|
||||||
|
|
||||||
@debug_callback_p.def_effectful_abstract_eval
|
@debug_callback_p.def_effectful_abstract_eval
|
||||||
|
@ -110,7 +110,7 @@ class CliDebuggerTest(jtu.JaxTestCase):
|
|||||||
return y
|
return y
|
||||||
expected = _format_multiline(r"""
|
expected = _format_multiline(r"""
|
||||||
Entering jdb:
|
Entering jdb:
|
||||||
(jdb) array(2., dtype=float32)
|
(jdb) Array(2., dtype=float32)
|
||||||
(jdb) """)
|
(jdb) """)
|
||||||
f(jnp.array(2., jnp.float32))
|
f(jnp.array(2., jnp.float32))
|
||||||
jax.effects_barrier()
|
jax.effects_barrier()
|
||||||
@ -126,7 +126,7 @@ class CliDebuggerTest(jtu.JaxTestCase):
|
|||||||
return y
|
return y
|
||||||
expected = _format_multiline(r"""
|
expected = _format_multiline(r"""
|
||||||
Entering jdb:
|
Entering jdb:
|
||||||
(jdb) (array(2., dtype=float32), array(3., dtype=float32))
|
(jdb) (Array(2., dtype=float32), Array(3., dtype=float32))
|
||||||
(jdb) """)
|
(jdb) """)
|
||||||
f(jnp.array(2., jnp.float32))
|
f(jnp.array(2., jnp.float32))
|
||||||
jax.effects_barrier()
|
jax.effects_barrier()
|
||||||
@ -196,7 +196,7 @@ class CliDebuggerTest(jtu.JaxTestCase):
|
|||||||
-> y = f\(x\)
|
-> y = f\(x\)
|
||||||
return jnp\.exp\(y\)
|
return jnp\.exp\(y\)
|
||||||
.*
|
.*
|
||||||
\(jdb\) array\(2\., dtype=float32\)
|
\(jdb\) Array\(2\., dtype=float32\)
|
||||||
\(jdb\) > .*debugger_test\.py\([0-9]+\)
|
\(jdb\) > .*debugger_test\.py\([0-9]+\)
|
||||||
def f\(x\):
|
def f\(x\):
|
||||||
y = jnp\.sin\(x\)
|
y = jnp\.sin\(x\)
|
||||||
@ -225,9 +225,9 @@ class CliDebuggerTest(jtu.JaxTestCase):
|
|||||||
return jnp.exp(y)
|
return jnp.exp(y)
|
||||||
expected = _format_multiline(r"""
|
expected = _format_multiline(r"""
|
||||||
Entering jdb:
|
Entering jdb:
|
||||||
(jdb) array(3., dtype=float32)
|
(jdb) Array(3., dtype=float32)
|
||||||
(jdb) Entering jdb:
|
(jdb) Entering jdb:
|
||||||
(jdb) array(6., dtype=float32)
|
(jdb) Array(6., dtype=float32)
|
||||||
(jdb) """)
|
(jdb) """)
|
||||||
g(jnp.array(2., jnp.float32))
|
g(jnp.array(2., jnp.float32))
|
||||||
jax.effects_barrier()
|
jax.effects_barrier()
|
||||||
@ -249,9 +249,9 @@ class CliDebuggerTest(jtu.JaxTestCase):
|
|||||||
return jnp.exp(y)
|
return jnp.exp(y)
|
||||||
expected = _format_multiline(r"""
|
expected = _format_multiline(r"""
|
||||||
Entering jdb:
|
Entering jdb:
|
||||||
(jdb) array(1., dtype=float32)
|
(jdb) Array(1., dtype=float32)
|
||||||
(jdb) Entering jdb:
|
(jdb) Entering jdb:
|
||||||
(jdb) array(2., dtype=float32)
|
(jdb) Array(2., dtype=float32)
|
||||||
(jdb) """)
|
(jdb) """)
|
||||||
g(jnp.arange(2., dtype=jnp.float32))
|
g(jnp.arange(2., dtype=jnp.float32))
|
||||||
jax.effects_barrier()
|
jax.effects_barrier()
|
||||||
@ -274,9 +274,9 @@ class CliDebuggerTest(jtu.JaxTestCase):
|
|||||||
return jnp.exp(y)
|
return jnp.exp(y)
|
||||||
expected = _format_multiline(r"""
|
expected = _format_multiline(r"""
|
||||||
Entering jdb:
|
Entering jdb:
|
||||||
\(jdb\) array\(.*, dtype=float32\)
|
\(jdb\) Array\(.*, dtype=float32\)
|
||||||
\(jdb\) Entering jdb:
|
\(jdb\) Entering jdb:
|
||||||
\(jdb\) array\(.*, dtype=float32\)
|
\(jdb\) Array\(.*, dtype=float32\)
|
||||||
\(jdb\) """)
|
\(jdb\) """)
|
||||||
g(jnp.arange(2., dtype=jnp.float32))
|
g(jnp.arange(2., dtype=jnp.float32))
|
||||||
jax.effects_barrier()
|
jax.effects_barrier()
|
||||||
@ -302,7 +302,7 @@ class CliDebuggerTest(jtu.JaxTestCase):
|
|||||||
out_shardings=jax.sharding.PartitionSpec("dev"),
|
out_shardings=jax.sharding.PartitionSpec("dev"),
|
||||||
)
|
)
|
||||||
with jax.sharding.Mesh(np.array(jax.devices()), ["dev"]):
|
with jax.sharding.Mesh(np.array(jax.devices()), ["dev"]):
|
||||||
arr = (1 + np.arange(8)).astype(np.int32)
|
arr = (1 + jnp.arange(8)).astype(np.int32)
|
||||||
expected = _format_multiline(r"""
|
expected = _format_multiline(r"""
|
||||||
Entering jdb:
|
Entering jdb:
|
||||||
\(jdb\) {}
|
\(jdb\) {}
|
||||||
|
@ -170,7 +170,7 @@ class DebugPrintTest(jtu.JaxTestCase):
|
|||||||
with jtu.capture_stdout() as output:
|
with jtu.capture_stdout() as output:
|
||||||
f(np.array(2, np.int32))
|
f(np.array(2, np.int32))
|
||||||
jax.effects_barrier()
|
jax.effects_barrier()
|
||||||
self.assertEqual(output(), f"x: {str(dict(foo=np.array(2, np.int32)))}\n")
|
self.assertEqual(output(), f"x: {str(dict(foo=jnp.array(2, np.int32)))}\n")
|
||||||
|
|
||||||
def test_debug_print_should_use_default_layout(self):
|
def test_debug_print_should_use_default_layout(self):
|
||||||
data = np.array(
|
data = np.array(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user