Support error checking in explicit mode

PiperOrigin-RevId: 737051146
This commit is contained in:
Ayaka 2025-03-14 18:57:39 -07:00 committed by jax authors
parent d07d642d6f
commit 9b0ace4a11
3 changed files with 103 additions and 7 deletions

View File

@ -14,13 +14,17 @@
from __future__ import annotations
from functools import partial
import threading
import jax
from jax._src import core
from jax._src import source_info_util
from jax._src import traceback_util
import jax._src.mesh as mesh_lib
from jax.experimental.shard_map import shard_map
import jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec as P
Traceback = source_info_util.Traceback
@ -54,17 +58,61 @@ _error_storage = _ErrorStorage()
def _initialize_error_code_ref() -> None:
"""Initialize error_code_ref in the current thread."""
"""Initialize error_code_ref in the current thread.
The size of the error code array is determined by the mesh in the context. In
single-device environment, the array is a scalar. In multi-device
environment, the array has the same shape as the mesh.
"""
with core.eval_context():
error_code = jnp.uint32(_NO_ERROR)
# Get mesh from the context.
mesh = mesh_lib.get_concrete_mesh()
if mesh is None: # single-device case.
error_code = jnp.uint32(_NO_ERROR)
else: # multi-device case.
sharding = NamedSharding(mesh, P(*mesh.axis_names))
error_code = jnp.full(
mesh.axis_sizes,
jnp.uint32(_NO_ERROR),
device=sharding,
)
_error_storage.ref = core.mutable_array(error_code)
def set_error_if(pred: jax.Array, msg: str) -> None:
class error_checking_context:
"""Redefine the error checking state based on the mesh in the context.
This context manager should be used when starting a multi-device
computation, and whenever the mesh is changed.
When exiting the context, the error checking state will be reset to the
original state.
"""
__slots__ = ("old_ref",)
def __init__(self):
self.old_ref = None
def __enter__(self):
self.old_ref = _error_storage.ref
_initialize_error_code_ref()
return self
def __exit__(self, exc_type, exc_value, traceback):
_error_storage.ref = self.old_ref
def set_error_if(pred: jax.Array, /, msg: str) -> None:
"""Set error if any element of pred is true.
If the error is already set, the new error will be ignored. It will not
override the existing error.
In auto mode, this function does not work under jit.
"""
if _error_storage.ref is None:
_initialize_error_code_ref()
@ -76,7 +124,32 @@ def set_error_if(pred: jax.Array, msg: str) -> None:
new_error_code = jnp.uint32(len(_error_list))
_error_list.append((msg, traceback))
pred = pred.any()
out_sharding = core.typeof(_error_storage.ref).sharding
in_sharding: NamedSharding = core.typeof(pred).sharding
if out_sharding.mesh.shape_tuple == (): # single-device case.
pred = pred.any()
else: # multi-device case.
has_auto_axes = mesh_lib.AxisType.Auto in in_sharding.mesh.axis_types
if has_auto_axes:
raise NotImplementedError(
"Error checking in auto mode is not supported yet. Please use"
" explicit mode."
)
if out_sharding.mesh != in_sharding.mesh:
raise ValueError(
"The error code state and the predicate must be on the same mesh, "
f"but got {out_sharding.mesh} and {in_sharding.mesh} respectively. "
"Please use `with error_checking_context()` to redefine the error "
"code state based on the mesh."
)
pred = shard_map(
partial(jnp.any, keepdims=True),
mesh=out_sharding.mesh,
in_specs=in_sharding.spec,
out_specs=out_sharding.spec,
)(pred) # perform per-device reduction
error_code = _error_storage.ref[...]
should_update = jnp.logical_and(pred, error_code == jnp.uint32(_NO_ERROR))
error_code = jnp.where(should_update, new_error_code, error_code)
@ -93,7 +166,7 @@ def raise_if_error() -> None:
if _error_storage.ref is None: # if not initialized, do nothing
return
error_code = _error_storage.ref[...]
error_code = _error_storage.ref[...].min() # reduce to a single error code
if isinstance(error_code, core.Tracer):
raise ValueError(
"raise_if_error() should not be called within a traced context, such as"
@ -101,7 +174,11 @@ def raise_if_error() -> None:
)
if error_code == jnp.uint32(_NO_ERROR):
return
_error_storage.ref[...] = jnp.uint32(_NO_ERROR)
_error_storage.ref[...] = jnp.full(
_error_storage.ref.shape,
jnp.uint32(_NO_ERROR),
device=_error_storage.ref.sharding,
) # clear the error code
msg, traceback = _error_list[error_code]
exc = JaxValueError(msg)

View File

@ -565,5 +565,5 @@ def use_concrete_mesh(mesh: Mesh | None):
finally:
jax_config.device_context.set_local(prev_val)
def get_concrete_mesh():
def get_concrete_mesh() -> Mesh | None:
return jax_config.device_context.value

View File

@ -20,12 +20,14 @@ from jax._src import config
from jax._src import error_check
from jax._src import test_util as jtu
import jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec as P
JaxValueError = error_check.JaxValueError
config.parse_flags_with_absl()
jtu.request_cpu_devices(4)
@jtu.with_config(jax_check_tracer_leaks=True)
@ -190,6 +192,23 @@ class ErrorCheckTests(jtu.JaxTestCase):
):
jax.jit(error_check.raise_if_error)()
@parameterized.product(jit=[True, False])
@jtu.with_user_mesh((2, 2), ("x", "y"))
def test_error_check_explicit_mode(self, mesh, jit):
def f(x):
error_check.set_error_if(x <= 0, "x must be greater than 0")
return x + 1
if jit:
f = jax.jit(f)
sharding = NamedSharding(mesh, P("x", "y"))
x = jnp.full((4, 4), -1, dtype=jnp.int32, device=sharding)
with error_check.error_checking_context():
f(x)
with self.assertRaisesRegex(JaxValueError, "x must be greater than 0"):
error_check.raise_if_error()
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())