mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
Support error checking in explicit mode
PiperOrigin-RevId: 737051146
This commit is contained in:
parent
d07d642d6f
commit
9b0ace4a11
@ -14,13 +14,17 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import partial
|
||||
import threading
|
||||
|
||||
import jax
|
||||
from jax._src import core
|
||||
from jax._src import source_info_util
|
||||
from jax._src import traceback_util
|
||||
import jax._src.mesh as mesh_lib
|
||||
from jax.experimental.shard_map import shard_map
|
||||
import jax.numpy as jnp
|
||||
from jax.sharding import NamedSharding, PartitionSpec as P
|
||||
|
||||
|
||||
Traceback = source_info_util.Traceback
|
||||
@ -54,17 +58,61 @@ _error_storage = _ErrorStorage()
|
||||
|
||||
|
||||
def _initialize_error_code_ref() -> None:
|
||||
"""Initialize error_code_ref in the current thread."""
|
||||
"""Initialize error_code_ref in the current thread.
|
||||
|
||||
The size of the error code array is determined by the mesh in the context. In
|
||||
single-device environment, the array is a scalar. In multi-device
|
||||
environment, the array has the same shape as the mesh.
|
||||
"""
|
||||
with core.eval_context():
|
||||
error_code = jnp.uint32(_NO_ERROR)
|
||||
# Get mesh from the context.
|
||||
mesh = mesh_lib.get_concrete_mesh()
|
||||
|
||||
if mesh is None: # single-device case.
|
||||
error_code = jnp.uint32(_NO_ERROR)
|
||||
|
||||
else: # multi-device case.
|
||||
sharding = NamedSharding(mesh, P(*mesh.axis_names))
|
||||
error_code = jnp.full(
|
||||
mesh.axis_sizes,
|
||||
jnp.uint32(_NO_ERROR),
|
||||
device=sharding,
|
||||
)
|
||||
|
||||
_error_storage.ref = core.mutable_array(error_code)
|
||||
|
||||
|
||||
def set_error_if(pred: jax.Array, msg: str) -> None:
|
||||
class error_checking_context:
|
||||
"""Redefine the error checking state based on the mesh in the context.
|
||||
|
||||
This context manager should be used when starting a multi-device
|
||||
computation, and whenever the mesh is changed.
|
||||
|
||||
When exiting the context, the error checking state will be reset to the
|
||||
original state.
|
||||
"""
|
||||
|
||||
__slots__ = ("old_ref",)
|
||||
|
||||
def __init__(self):
|
||||
self.old_ref = None
|
||||
|
||||
def __enter__(self):
|
||||
self.old_ref = _error_storage.ref
|
||||
_initialize_error_code_ref()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
_error_storage.ref = self.old_ref
|
||||
|
||||
|
||||
def set_error_if(pred: jax.Array, /, msg: str) -> None:
|
||||
"""Set error if any element of pred is true.
|
||||
|
||||
If the error is already set, the new error will be ignored. It will not
|
||||
override the existing error.
|
||||
|
||||
In auto mode, this function does not work under jit.
|
||||
"""
|
||||
if _error_storage.ref is None:
|
||||
_initialize_error_code_ref()
|
||||
@ -76,7 +124,32 @@ def set_error_if(pred: jax.Array, msg: str) -> None:
|
||||
new_error_code = jnp.uint32(len(_error_list))
|
||||
_error_list.append((msg, traceback))
|
||||
|
||||
pred = pred.any()
|
||||
out_sharding = core.typeof(_error_storage.ref).sharding
|
||||
in_sharding: NamedSharding = core.typeof(pred).sharding
|
||||
|
||||
if out_sharding.mesh.shape_tuple == (): # single-device case.
|
||||
pred = pred.any()
|
||||
else: # multi-device case.
|
||||
has_auto_axes = mesh_lib.AxisType.Auto in in_sharding.mesh.axis_types
|
||||
if has_auto_axes:
|
||||
raise NotImplementedError(
|
||||
"Error checking in auto mode is not supported yet. Please use"
|
||||
" explicit mode."
|
||||
)
|
||||
if out_sharding.mesh != in_sharding.mesh:
|
||||
raise ValueError(
|
||||
"The error code state and the predicate must be on the same mesh, "
|
||||
f"but got {out_sharding.mesh} and {in_sharding.mesh} respectively. "
|
||||
"Please use `with error_checking_context()` to redefine the error "
|
||||
"code state based on the mesh."
|
||||
)
|
||||
pred = shard_map(
|
||||
partial(jnp.any, keepdims=True),
|
||||
mesh=out_sharding.mesh,
|
||||
in_specs=in_sharding.spec,
|
||||
out_specs=out_sharding.spec,
|
||||
)(pred) # perform per-device reduction
|
||||
|
||||
error_code = _error_storage.ref[...]
|
||||
should_update = jnp.logical_and(pred, error_code == jnp.uint32(_NO_ERROR))
|
||||
error_code = jnp.where(should_update, new_error_code, error_code)
|
||||
@ -93,7 +166,7 @@ def raise_if_error() -> None:
|
||||
if _error_storage.ref is None: # if not initialized, do nothing
|
||||
return
|
||||
|
||||
error_code = _error_storage.ref[...]
|
||||
error_code = _error_storage.ref[...].min() # reduce to a single error code
|
||||
if isinstance(error_code, core.Tracer):
|
||||
raise ValueError(
|
||||
"raise_if_error() should not be called within a traced context, such as"
|
||||
@ -101,7 +174,11 @@ def raise_if_error() -> None:
|
||||
)
|
||||
if error_code == jnp.uint32(_NO_ERROR):
|
||||
return
|
||||
_error_storage.ref[...] = jnp.uint32(_NO_ERROR)
|
||||
_error_storage.ref[...] = jnp.full(
|
||||
_error_storage.ref.shape,
|
||||
jnp.uint32(_NO_ERROR),
|
||||
device=_error_storage.ref.sharding,
|
||||
) # clear the error code
|
||||
|
||||
msg, traceback = _error_list[error_code]
|
||||
exc = JaxValueError(msg)
|
||||
|
@ -565,5 +565,5 @@ def use_concrete_mesh(mesh: Mesh | None):
|
||||
finally:
|
||||
jax_config.device_context.set_local(prev_val)
|
||||
|
||||
def get_concrete_mesh():
|
||||
def get_concrete_mesh() -> Mesh | None:
|
||||
return jax_config.device_context.value
|
||||
|
@ -20,12 +20,14 @@ from jax._src import config
|
||||
from jax._src import error_check
|
||||
from jax._src import test_util as jtu
|
||||
import jax.numpy as jnp
|
||||
from jax.sharding import NamedSharding, PartitionSpec as P
|
||||
|
||||
|
||||
JaxValueError = error_check.JaxValueError
|
||||
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
jtu.request_cpu_devices(4)
|
||||
|
||||
|
||||
@jtu.with_config(jax_check_tracer_leaks=True)
|
||||
@ -190,6 +192,23 @@ class ErrorCheckTests(jtu.JaxTestCase):
|
||||
):
|
||||
jax.jit(error_check.raise_if_error)()
|
||||
|
||||
@parameterized.product(jit=[True, False])
|
||||
@jtu.with_user_mesh((2, 2), ("x", "y"))
|
||||
def test_error_check_explicit_mode(self, mesh, jit):
|
||||
def f(x):
|
||||
error_check.set_error_if(x <= 0, "x must be greater than 0")
|
||||
return x + 1
|
||||
|
||||
if jit:
|
||||
f = jax.jit(f)
|
||||
|
||||
sharding = NamedSharding(mesh, P("x", "y"))
|
||||
x = jnp.full((4, 4), -1, dtype=jnp.int32, device=sharding)
|
||||
with error_check.error_checking_context():
|
||||
f(x)
|
||||
with self.assertRaisesRegex(JaxValueError, "x must be greater than 0"):
|
||||
error_check.raise_if_error()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user