mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Add JAX error checking support
In this PR, only jit and control flows are supported. Support for vmap and multi-device environments will be added in subsequent PRs. PiperOrigin-RevId: 726920440
This commit is contained in:
parent
902ebe1bfe
commit
6addf02add
@ -209,6 +209,7 @@ py_library_providing_imports_info(
|
||||
"_src/dispatch.py",
|
||||
"_src/dlpack.py",
|
||||
"_src/earray.py",
|
||||
"_src/error_check.py",
|
||||
"_src/ffi.py",
|
||||
"_src/flatten_util.py",
|
||||
"_src/interpreters/__init__.py",
|
||||
|
95
jax/_src/error_check.py
Normal file
95
jax/_src/error_check.py
Normal file
@ -0,0 +1,95 @@
|
||||
# Copyright 2025 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import threading
|
||||
|
||||
import jax
|
||||
from jax._src import core
|
||||
from jax._src import source_info_util
|
||||
from jax._src import traceback_util
|
||||
import jax.numpy as jnp
|
||||
|
||||
|
||||
# Alias for the traceback type stored next to each registered error message.
Traceback = source_info_util.Traceback

# NOTE(review): presumably marks this file so that its frames are dropped when
# tracebacks are filtered for users -- confirm against traceback_util.
traceback_util.register_exclusion(__file__)
|
||||
|
||||
|
||||
class JaxValueError(ValueError):
  """Exception raised for failed runtime error checks in JAX.

  Raised by `raise_if_error` when a check registered through `set_error_if`
  has failed; the exception message is the `msg` passed to `set_error_if`.
  """
|
||||
|
||||
|
||||
# Sentinel error code meaning "no error pending": the largest value that a
# uint32 can represent.
_NO_ERROR = jnp.iinfo(jnp.uint32).max
"""The default error code for no error.

We choose this value because when performing reductions, we can use `min` to
obtain the smallest error code.
"""

# Mutable array holding the current error code. It stays `None` until
# `_initialize_error_code_ref` creates it on the first `set_error_if` call.
_error_code_ref: core.MutableArray | None = None
# Held while appending to `_error_list`, so that the index computed for a new
# entry matches the position it is stored at.
_error_list_lock = threading.Lock()
# One entry per registered check; an error code is an index into this list.
_error_list: list[tuple[str, Traceback]] = []  # (error_message, traceback) pair
|
||||
|
||||
|
||||
def _initialize_error_code_ref() -> None:
  """Create the module-level error-code ref, initialised to "no error".

  The ref is a uint32 mutable array holding `_NO_ERROR`.
  """
  global _error_code_ref
  # NOTE(review): `core.eval_context()` presumably makes the ref a concrete
  # value even if the first call happens while tracing -- confirm.
  with core.eval_context():
    _error_code_ref = core.mutable_array(jnp.uint32(_NO_ERROR))
|
||||
|
||||
|
||||
def set_error_if(pred: jax.Array, msg: str) -> None:
  """Record an error with message `msg` if any element of `pred` is true.

  The first recorded error wins: when an error is already pending, this call
  leaves it in place and the new one is dropped.

  Args:
    pred: boolean array; the check fails if `pred.any()` is true.
    msg: message carried by the `JaxValueError` that `raise_if_error` raises.
  """
  if _error_code_ref is None:
    _initialize_error_code_ref()
  assert _error_code_ref is not None

  call_site = source_info_util.current().traceback
  assert call_site is not None

  # Register the (message, traceback) pair; its index in the list is the code
  # written into the ref if this check fails.
  with _error_list_lock:
    this_code = len(_error_list)
    _error_list.append((msg, call_site))

  any_failed = pred.any()
  current_code = _error_code_ref[...]
  no_error_pending = current_code == jnp.uint32(_NO_ERROR)
  overwrite = jnp.logical_and(any_failed, no_error_pending)
  # TODO(ayx): support vmap and shard_map.
  _error_code_ref[...] = jnp.where(overwrite, this_code, current_code)  # pytype: disable=unsupported-operands
|
||||
|
||||
|
||||
def raise_if_error() -> None:
  """Raise the pending error, if any, and clear it.

  Does nothing when no check has ever been registered or when no error is
  currently set. Otherwise the message and traceback recorded by the failing
  `set_error_if` call are used to build the exception, and the stored error
  code is reset so that a following call does not raise again.

  Raises:
    ValueError: if called while tracing (for example inside a jitted
      function), where the error code has no concrete value to inspect.
    JaxValueError: if an error has been set by `set_error_if`.
  """
  if _error_code_ref is None:  # if not initialized, do nothing
    return

  error_code = _error_code_ref[...]
  # Under tracing the value read from the ref is a tracer; using it in the
  # `if` below would fail with an opaque tracer-to-bool conversion error, so
  # report the misuse explicitly instead.
  if isinstance(error_code, core.Tracer):
    raise ValueError(
        "raise_if_error() should not be called within a traced context, such as"
        " within a jitted function."
    )
  if error_code == jnp.uint32(_NO_ERROR):
    return

  try:
    # The error code is the index of the failing check in `_error_list`.
    msg, traceback = _error_list[int(error_code)]
    exc = JaxValueError(msg)
    traceback = traceback.as_python_traceback()
    filtered_traceback = traceback_util.filter_traceback(traceback)
    raise exc.with_traceback(filtered_traceback)
  finally:
    # Always clear the pending error, even if building the exception failed.
    _error_code_ref[...] = jnp.uint32(_NO_ERROR)
|
@ -1117,6 +1117,11 @@ jax_multiplatform_test(
|
||||
},
|
||||
)
|
||||
|
||||
jax_multiplatform_test(
|
||||
name = "error_check_test",
|
||||
srcs = ["error_check_test.py"],
|
||||
)
|
||||
|
||||
jax_multiplatform_test(
|
||||
name = "stax_test",
|
||||
srcs = ["stax_test.py"],
|
||||
|
170
tests/error_check_test.py
Normal file
170
tests/error_check_test.py
Normal file
@ -0,0 +1,170 @@
|
||||
# Copyright 2025 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
import jax
|
||||
from jax._src import config
|
||||
from jax._src import error_check
|
||||
from jax._src import test_util as jtu
|
||||
import jax.numpy as jnp
|
||||
|
||||
|
||||
# Short alias for the exception type the tests expect `raise_if_error` to raise.
JaxValueError = error_check.JaxValueError

# NOTE(review): presumably parses JAX config flags from the absl command line
# before the tests run -- confirm against jax._src.config.
config.parse_flags_with_absl()
|
||||
|
||||
|
||||
@jtu.with_config(jax_check_tracer_leaks=True)
class ErrorCheckTests(jtu.JaxTestCase):
  """Tests for `error_check.set_error_if` and `error_check.raise_if_error`.

  Every test is run both eagerly and under `jax.jit`.
  """

  @parameterized.product(jit=[True, False])
  def test_error_check(self, jit):
    """A failing check is reported by `raise_if_error`."""
    def f(x):
      error_check.set_error_if(x <= 0, "x must be greater than 0")
      return x + 1

    if jit:
      f = jax.jit(f)

    x = jnp.full((4,), -1, dtype=jnp.int32)
    f(x)
    with self.assertRaisesRegex(JaxValueError, "x must be greater than 0"):
      error_check.raise_if_error()

  @parameterized.product(jit=[True, False])
  def test_error_check_no_error(self, jit):
    """A passing check leaves no error behind."""
    def f(x):
      error_check.set_error_if(x <= 0, "x must be greater than 0")
      return x + 1

    if jit:
      f = jax.jit(f)

    x = jnp.full((4,), 1, dtype=jnp.int32)
    f(x)
    error_check.raise_if_error()  # should not raise error

  @parameterized.product(jit=[True, False])
  def test_error_check_should_report_the_first_error(self, jit):
    """Once an error is set, later failing checks do not override it."""
    def f(x):
      error_check.set_error_if(x >= 1, "x must be less than 1 in f")
      return x + 1

    def g(x):
      error_check.set_error_if(x >= 1, "x must be less than 1 in g")
      return x + 1

    if jit:
      f = jax.jit(f)
      g = jax.jit(g)

    x = jnp.full((4,), 0, dtype=jnp.int32)

    x = f(x)  # check passes, so it should not set error
    x = g(x)  # check fails. so it should set error
    _ = f(x)  # check fails, but should not override the error
    with self.assertRaisesRegex(JaxValueError, "x must be less than 1 in g"):
      error_check.raise_if_error()

  @parameterized.product(jit=[True, False])
  def test_raise_if_error_clears_error(self, jit):
    """`raise_if_error` resets the state so a new error can be recorded."""
    def f(x):
      error_check.set_error_if(x <= 0, "x must be greater than 0 in f")
      return x + 1

    def g(x):
      error_check.set_error_if(x <= 0, "x must be greater than 0 in g")
      return x + 1

    if jit:
      f = jax.jit(f)
      g = jax.jit(g)

    x = jnp.full((4,), -1, dtype=jnp.int32)
    f(x)
    with self.assertRaisesRegex(JaxValueError, "x must be greater than 0 in f"):
      error_check.raise_if_error()

    error_check.raise_if_error()  # should not raise error

    g(x)
    with self.assertRaisesRegex(JaxValueError, "x must be greater than 0 in g"):
      error_check.raise_if_error()

  @parameterized.product(jit=[True, False])
  def test_error_check_works_with_cond(self, jit):
    """Only the check in the branch selected by `lax.cond` is reported."""
    def f(x):
      error_check.set_error_if(x == 0, "x must be non-zero in f")
      return x + 1

    def g(x):
      error_check.set_error_if(x == 0, "x must be non-zero in g")
      return x + 1

    def body(pred, x):
      return jax.lax.cond(pred, f, g, x)

    if jit:
      body = jax.jit(body)

    x = jnp.zeros((4,), dtype=jnp.int32)

    _ = body(jnp.bool_(True), x)
    with self.assertRaisesRegex(JaxValueError, "x must be non-zero in f"):
      error_check.raise_if_error()

    _ = body(jnp.bool_(False), x)
    with self.assertRaisesRegex(JaxValueError, "x must be non-zero in g"):
      error_check.raise_if_error()

  @parameterized.product(jit=[True, False])
  def test_error_check_works_with_while_loop(self, jit):
    """A check that fails in some `lax.while_loop` iteration is reported."""
    def f(x):
      error_check.set_error_if(x >= 10, "x must be less than 10")
      return x + 1

    def body(x):
      return jax.lax.while_loop(lambda x: (x < 10).any(), f, x)

    if jit:
      body = jax.jit(body)

    x = jnp.arange(4, dtype=jnp.int32)
    _ = body(x)
    with self.assertRaisesRegex(JaxValueError, "x must be less than 10"):
      error_check.raise_if_error()

  # Parameterized over `jit` like the other control-flow tests so that scan is
  # also exercised inside a jitted function.
  @parameterized.product(jit=[True, False])
  def test_error_check_works_with_scan(self, jit):
    """A check inside a `lax.scan` body fires only when some step fails."""
    def f(carry, x):
      error_check.set_error_if(x >= 4, "x must be less than 4")
      return carry + x, x + 1

    def body(init, xs):
      return jax.lax.scan(f, init=init, xs=xs)

    if jit:
      body = jax.jit(body)

    init = jnp.int32(0)
    xs = jnp.arange(5, dtype=jnp.int32)
    _ = body(init, xs)
    with self.assertRaisesRegex(JaxValueError, "x must be less than 4"):
      error_check.raise_if_error()

    xs = jnp.arange(4, dtype=jnp.int32)
    _ = body(init, xs)
    error_check.raise_if_error()  # should not raise error
|
||||
|
||||
# When run as a script, hand control to absltest using the loader from
# jax._src.test_util.
if __name__ == "__main__":
  absltest.main(testLoader=jtu.JaxTestLoader())
|
Loading…
x
Reference in New Issue
Block a user