Add JAX error checking support

In this PR, only `jit` and control-flow primitives (`cond`, `while_loop`, `scan`) are supported. Support for `vmap` and multi-device environments will be added in subsequent PRs.

PiperOrigin-RevId: 726920440
This commit is contained in:
Ayaka 2025-02-14 07:27:38 -08:00 committed by jax authors
parent 902ebe1bfe
commit 6addf02add
4 changed files with 271 additions and 0 deletions

View File

@ -209,6 +209,7 @@ py_library_providing_imports_info(
"_src/dispatch.py",
"_src/dlpack.py",
"_src/earray.py",
"_src/error_check.py",
"_src/ffi.py",
"_src/flatten_util.py",
"_src/interpreters/__init__.py",

95
jax/_src/error_check.py Normal file
View File

@ -0,0 +1,95 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import threading
import jax
from jax._src import core
from jax._src import source_info_util
from jax._src import traceback_util
import jax.numpy as jnp
Traceback = source_info_util.Traceback
traceback_util.register_exclusion(__file__)
class JaxValueError(ValueError):
  """Exception raised for failed runtime error checks in JAX.

  `raise_if_error()` raises this exception, carrying the message that was
  passed to the `set_error_if()` call which recorded the error.
  """
_NO_ERROR = jnp.iinfo(jnp.uint32).max
"""The default error code for no error.
We choose this value because when performing reductions, we can use `min` to
obtain the smallest error code.
"""
_error_code_ref: core.MutableArray | None = None
_error_list_lock = threading.Lock()
_error_list: list[tuple[str, Traceback]] = [] # (error_message, traceback) pair
def _initialize_error_code_ref() -> None:
  """Create the global error-code ref, initialized to the "no error" code.

  Rebinds the module-level `_error_code_ref` to a fresh mutable array holding
  `_NO_ERROR` as a uint32 scalar.
  """
  global _error_code_ref
  # NOTE(review): presumably `eval_context` is used so that the ref is created
  # eagerly even when the first call happens during tracing -- confirm.
  with core.eval_context():
    _error_code_ref = core.mutable_array(jnp.uint32(_NO_ERROR))
def set_error_if(pred: jax.Array, msg: str) -> None:
  """Set error if any element of `pred` is true.

  If the error is already set, the new error will be ignored. It will not
  override the existing error. This function only records the error; call
  `raise_if_error()` to raise it.

  Args:
    pred: Boolean condition. It is reduced with `any`, so the error is set when
      at least one element is true. Array-likes accepted by `jnp.any` (for
      example a Python `bool`) are accepted as well.
    msg: Message of the `JaxValueError` raised by `raise_if_error()`.
  """
  # Reduce first: if `pred` is not something `jnp.any` accepts, fail before any
  # global state (the error list) is modified.
  any_true = jnp.any(pred)

  if _error_code_ref is None:
    _initialize_error_code_ref()
    assert _error_code_ref is not None

  traceback = source_info_util.current().traceback
  assert traceback is not None

  # The error code is the index of the (message, traceback) pair in
  # `_error_list`; the lock keeps reading the index and appending consistent.
  with _error_list_lock:
    new_error_code = len(_error_list)
    _error_list.append((msg, traceback))

  error_code = _error_code_ref[...]
  # Keep only the first error: the new code is written only while the ref still
  # holds the "no error" value.
  should_update = jnp.logical_and(any_true, error_code == jnp.uint32(_NO_ERROR))
  error_code = jnp.where(should_update, new_error_code, error_code)
  # TODO(ayx): support vmap and shard_map.
  _error_code_ref[...] = error_code  # pytype: disable=unsupported-operands
def raise_if_error() -> None:
  """Raise error if an error is set, then clear it.

  Does nothing if no error check has run yet or if no error is set. Otherwise
  the recorded message is raised with the traceback captured by the
  `set_error_if()` call that recorded it, and the error code is reset so that
  a subsequent call does not raise again.

  Raises:
    JaxValueError: if an error was recorded by `set_error_if()`.
    ValueError: if called within a traced context, where the error code is not
      a concrete value.
  """
  if _error_code_ref is None:  # if not initialized, do nothing
    return

  error_code = _error_code_ref[...]
  # Under tracing the error code is a tracer and cannot be inspected; fail with
  # an explicit message rather than an opaque tracer-conversion error.
  if isinstance(error_code, core.Tracer):
    raise ValueError(
        "raise_if_error() should not be called within a traced context, such"
        " as within a jitted function."
    )
  if error_code == jnp.uint32(_NO_ERROR):
    return

  try:
    msg, recorded_traceback = _error_list[error_code]
    exc = JaxValueError(msg)
    python_traceback = recorded_traceback.as_python_traceback()
    filtered_traceback = traceback_util.filter_traceback(python_traceback)
    raise exc.with_traceback(filtered_traceback)
  finally:
    # Always clear the error, even though we are raising.
    _error_code_ref[...] = jnp.uint32(_NO_ERROR)

View File

@ -1117,6 +1117,11 @@ jax_multiplatform_test(
},
)
# Tests for the runtime error checking in jax/_src/error_check.py.
jax_multiplatform_test(
    name = "error_check_test",
    srcs = ["error_check_test.py"],
)
jax_multiplatform_test(
name = "stax_test",
srcs = ["stax_test.py"],

170
tests/error_check_test.py Normal file
View File

@ -0,0 +1,170 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from absl.testing import absltest
from absl.testing import parameterized
import jax
from jax._src import config
from jax._src import error_check
from jax._src import test_util as jtu
import jax.numpy as jnp
JaxValueError = error_check.JaxValueError
config.parse_flags_with_absl()
@jtu.with_config(jax_check_tracer_leaks=True)
class ErrorCheckTests(jtu.JaxTestCase):
  """Tests for `error_check.set_error_if` and `error_check.raise_if_error`.

  Every test is run both eagerly and under `jax.jit`.
  """

  @parameterized.product(jit=[True, False])
  def test_error_check(self, jit):
    """A failing check is reported by `raise_if_error`."""
    def f(x):
      error_check.set_error_if(x <= 0, "x must be greater than 0")
      return x + 1

    if jit:
      f = jax.jit(f)

    x = jnp.full((4,), -1, dtype=jnp.int32)
    f(x)
    with self.assertRaisesRegex(JaxValueError, "x must be greater than 0"):
      error_check.raise_if_error()

  @parameterized.product(jit=[True, False])
  def test_error_check_no_error(self, jit):
    """A passing check leaves no error to raise."""
    def f(x):
      error_check.set_error_if(x <= 0, "x must be greater than 0")
      return x + 1

    if jit:
      f = jax.jit(f)

    x = jnp.full((4,), 1, dtype=jnp.int32)
    f(x)
    error_check.raise_if_error()  # should not raise error

  @parameterized.product(jit=[True, False])
  def test_error_check_should_report_the_first_error(self, jit):
    """A later failing check does not override an earlier recorded error."""
    def f(x):
      error_check.set_error_if(x >= 1, "x must be less than 1 in f")
      return x + 1

    def g(x):
      error_check.set_error_if(x >= 1, "x must be less than 1 in g")
      return x + 1

    if jit:
      f = jax.jit(f)
      g = jax.jit(g)

    x = jnp.full((4,), 0, dtype=jnp.int32)

    x = f(x)  # check passes, so it should not set error
    x = g(x)  # check fails. so it should set error
    _ = f(x)  # check fails, but should not override the error
    with self.assertRaisesRegex(JaxValueError, "x must be less than 1 in g"):
      error_check.raise_if_error()

  @parameterized.product(jit=[True, False])
  def test_raise_if_error_clears_error(self, jit):
    """After raising, the error is cleared and a new one can be recorded."""
    def f(x):
      error_check.set_error_if(x <= 0, "x must be greater than 0 in f")
      return x + 1

    def g(x):
      error_check.set_error_if(x <= 0, "x must be greater than 0 in g")
      return x + 1

    if jit:
      f = jax.jit(f)
      g = jax.jit(g)

    x = jnp.full((4,), -1, dtype=jnp.int32)
    f(x)
    with self.assertRaisesRegex(JaxValueError, "x must be greater than 0 in f"):
      error_check.raise_if_error()

    error_check.raise_if_error()  # should not raise error

    g(x)
    with self.assertRaisesRegex(JaxValueError, "x must be greater than 0 in g"):
      error_check.raise_if_error()

  @parameterized.product(jit=[True, False])
  def test_error_check_works_with_cond(self, jit):
    """Only the check in the taken `lax.cond` branch is reported."""
    def f(x):
      error_check.set_error_if(x == 0, "x must be non-zero in f")
      return x + 1

    def g(x):
      error_check.set_error_if(x == 0, "x must be non-zero in g")
      return x + 1

    def body(pred, x):
      return jax.lax.cond(pred, f, g, x)

    if jit:
      body = jax.jit(body)

    x = jnp.zeros((4,), dtype=jnp.int32)

    _ = body(jnp.bool_(True), x)
    with self.assertRaisesRegex(JaxValueError, "x must be non-zero in f"):
      error_check.raise_if_error()

    _ = body(jnp.bool_(False), x)
    with self.assertRaisesRegex(JaxValueError, "x must be non-zero in g"):
      error_check.raise_if_error()

  @parameterized.product(jit=[True, False])
  def test_error_check_works_with_while_loop(self, jit):
    """A check failing inside a `lax.while_loop` body is reported."""
    def f(x):
      error_check.set_error_if(x >= 10, "x must be less than 10")
      return x + 1

    def body(x):
      return jax.lax.while_loop(lambda x: (x < 10).any(), f, x)

    if jit:
      body = jax.jit(body)

    x = jnp.arange(4, dtype=jnp.int32)
    _ = body(x)
    with self.assertRaisesRegex(JaxValueError, "x must be less than 10"):
      error_check.raise_if_error()

  # Parameterized over `jit` like the other control-flow tests in this class.
  @parameterized.product(jit=[True, False])
  def test_error_check_works_with_scan(self, jit):
    """A check failing inside a `lax.scan` body is reported."""
    def f(carry, x):
      error_check.set_error_if(x >= 4, "x must be less than 4")
      return carry + x, x + 1

    def body(init, xs):
      return jax.lax.scan(f, init=init, xs=xs)

    if jit:
      body = jax.jit(body)

    init = jnp.int32(0)
    xs = jnp.arange(5, dtype=jnp.int32)
    _ = body(init, xs)
    with self.assertRaisesRegex(JaxValueError, "x must be less than 4"):
      error_check.raise_if_error()

    xs = jnp.arange(4, dtype=jnp.int32)
    _ = body(init, xs)
    error_check.raise_if_error()  # should not raise error
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())