rocm_jax/jax/_src/error_check.py
Ayaka 9b0ace4a11 Support error checking in explicit mode
PiperOrigin-RevId: 737051146
2025-03-14 18:58:26 -07:00

188 lines
5.7 KiB
Python

# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from functools import partial
import threading
import jax
from jax._src import core
from jax._src import source_info_util
from jax._src import traceback_util
import jax._src.mesh as mesh_lib
from jax.experimental.shard_map import shard_map
import jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec as P
# Short alias for the traceback type; used in the `_error_list` annotation.
Traceback = source_info_util.Traceback

# Register this file with `traceback_util`.
# NOTE(review): presumably this hides frames from this module when tracebacks
# are filtered (see `filter_traceback` in `raise_if_error`) -- confirm against
# `traceback_util.register_exclusion`.
traceback_util.register_exclusion(__file__)
class JaxValueError(ValueError):
  """Error raised when a runtime error check registered in JAX has failed.

  Subclasses `ValueError`, so existing `except ValueError` handlers also catch
  it.
  """
#: Sentinel error code meaning "no error has been recorded".
#:
#: It is the largest `uint32` value, so any real error code (an index into
#: `_error_list`) compares smaller and a `min()` reduction over the error-code
#: array yields a recorded error whenever one exists.
_NO_ERROR = jnp.iinfo(jnp.uint32).max

# Guards appends to `_error_list` so each error is assigned a unique index.
_error_list_lock = threading.Lock()

# Registered errors as (error_message, traceback) pairs; an error code is the
# position of its pair in this list.
_error_list: list[tuple[str, Traceback]] = []
class _ErrorStorage(threading.local):
  """Thread-local holder for the error-code state.

  Because this derives from `threading.local`, every thread sees its own
  `ref` attribute, which starts out as `None`.
  """

  def __init__(self) -> None:
    # Stays `None` until `_initialize_error_code_ref()` fills it in.
    self.ref: core.MutableArray | None = None


# Module-wide handle to the per-thread error-code state.
_error_storage = _ErrorStorage()
def _initialize_error_code_ref() -> None:
  """Create a fresh error-code state for the current thread.

  The layout of the state follows the concrete mesh found in the context:

  * no mesh (single-device): a scalar `uint32` holding `_NO_ERROR`;
  * a mesh (multi-device): a `uint32` array of shape `mesh.axis_sizes`, filled
    with `_NO_ERROR` and sharded over all mesh axes.

  The result is wrapped in a mutable array and stored in `_error_storage.ref`,
  replacing whatever was there before.
  """
  # NOTE(review): `core.eval_context()` is presumably what lets the state be
  # built as concrete values even when this runs during tracing -- confirm.
  with core.eval_context():
    no_error = jnp.uint32(_NO_ERROR)
    context_mesh = mesh_lib.get_concrete_mesh()

    if context_mesh is not None:  # multi-device case.
      initial_codes = jnp.full(
          context_mesh.axis_sizes,
          no_error,
          device=NamedSharding(context_mesh, P(*context_mesh.axis_names)),
      )
    else:  # single-device case.
      initial_codes = no_error

    _error_storage.ref = core.mutable_array(initial_codes)
class error_checking_context:
  """Context manager that rebuilds the error checking state for the current mesh.

  Enter this context before starting a multi-device computation, and again
  whenever the mesh changes. On entry the current thread's error-code state is
  saved and a new one is created from the mesh in the context; on exit the
  saved state is put back.
  """

  __slots__ = ("old_ref",)

  def __init__(self) -> None:
    # State to restore on exit; captured in `__enter__`.
    self.old_ref = None

  def __enter__(self) -> "error_checking_context":
    # Remember the caller's state before replacing it.
    self.old_ref = _error_storage.ref
    _initialize_error_code_ref()
    return self

  def __exit__(self, exc_type, exc_value, traceback) -> None:
    # Restore unconditionally; exceptions from the body are not suppressed.
    _error_storage.ref = self.old_ref
def set_error_if(pred: jax.Array, /, msg: str) -> None:
  """Record `msg` as the pending error if any element of `pred` is true.

  Only the first error is kept: if an error is already recorded, this call
  leaves it untouched.

  In auto mode, this function does not work under jit.

  Args:
    pred: boolean array; the error is set where any element is true.
    msg: message of the `JaxValueError` later raised by `raise_if_error()`.

  Raises:
    NotImplementedError: in the multi-device case, if the mesh of `pred` has
      an axis of type `Auto`.
    ValueError: in the multi-device case, if `pred` and the error-code state
      live on different meshes.
  """
  if _error_storage.ref is None:
    _initialize_error_code_ref()
    assert _error_storage.ref is not None

  tb = source_info_util.current().traceback
  assert tb is not None

  # The error code is the index of the (message, traceback) pair in the list;
  # the lock keeps index computation and append atomic.
  with _error_list_lock:
    code = jnp.uint32(len(_error_list))
    _error_list.append((msg, tb))

  state_sharding = core.typeof(_error_storage.ref).sharding
  pred_sharding: NamedSharding = core.typeof(pred).sharding
  state_mesh = state_sharding.mesh

  if state_mesh.shape_tuple == ():  # single-device case.
    pred = pred.any()
  else:  # multi-device case.
    pred_mesh = pred_sharding.mesh
    if mesh_lib.AxisType.Auto in pred_mesh.axis_types:
      raise NotImplementedError(
          "Error checking in auto mode is not supported yet. Please use"
          " explicit mode."
      )
    if state_mesh != pred_mesh:
      raise ValueError(
          "The error code state and the predicate must be on the same mesh, "
          f"but got {state_mesh} and {pred_mesh} respectively. "
          "Please use `with error_checking_context()` to redefine the error "
          "code state based on the mesh."
      )
    # Reduce each device's shard of `pred` to a single flag, laid out like the
    # error-code state.
    reduce_per_device = shard_map(
        partial(jnp.any, keepdims=True),
        mesh=state_mesh,
        in_specs=pred_sharding.spec,
        out_specs=state_sharding.spec,
    )
    pred = reduce_per_device(pred)

  stored_code = _error_storage.ref[...]
  # Overwrite only slots that still hold the "no error" sentinel.
  is_unset = stored_code == jnp.uint32(_NO_ERROR)
  should_update = jnp.logical_and(pred, is_unset)
  # TODO(ayx): support vmap and shard_map.
  _error_storage.ref[...] = jnp.where(should_update, code, stored_code)
def raise_if_error() -> None:
  """Raise the recorded error, if there is one, and reset the error state.

  This function should be called after the computation is finished. It should
  not be called within a traced context, such as within a jitted function.

  If the current thread has no error-code state, or the state holds no error,
  this is a no-op.

  Raises:
    ValueError: if called within a traced context.
    JaxValueError: if an error was recorded by `set_error_if()`; it carries the
      recorded message and the filtered traceback of the recording site.
  """
  if _error_storage.ref is None:  # if not initialized, do nothing
    return

  # `_NO_ERROR` is the largest uint32, so `min()` yields a recorded error code
  # whenever any slot holds one.
  error_code = _error_storage.ref[...].min()
  if isinstance(error_code, core.Tracer):
    raise ValueError(
        "raise_if_error() should not be called within a traced context, such as"
        " within a jitted function."
    )
  if error_code == jnp.uint32(_NO_ERROR):
    return

  # Clear the error code before raising so the same error is reported once.
  _error_storage.ref[...] = jnp.full(
      _error_storage.ref.shape,
      jnp.uint32(_NO_ERROR),
      device=_error_storage.ref.sharding,
  )

  # Read the shared list under the same lock `set_error_if()` uses to append,
  # and index with a plain Python int rather than relying on the array's
  # implicit `__index__` conversion.
  with _error_list_lock:
    msg, traceback = _error_list[int(error_code)]

  exc = JaxValueError(msg)
  traceback = traceback.as_python_traceback()
  filtered_traceback = traceback_util.filter_traceback(traceback)
  raise exc.with_traceback(filtered_traceback)