diff --git a/docs/jax.experimental.checkify.rst b/docs/jax.experimental.checkify.rst index dec65c507..ecf253aa6 100644 --- a/docs/jax.experimental.checkify.rst +++ b/docs/jax.experimental.checkify.rst @@ -14,7 +14,7 @@ API check check_error Error - ErrorCategory + JaxRuntimeError user_checks nan_checks index_checks diff --git a/jax/_src/checkify.py b/jax/_src/checkify.py index 47e05f6df..67b534294 100644 --- a/jax/_src/checkify.py +++ b/jax/_src/checkify.py @@ -12,35 +12,39 @@ # See the License for the specific language governing permissions and # limitations under the License. -import enum -from dataclasses import dataclass -from functools import partial +import dataclasses +import functools import itertools as it -from typing import Union, Optional, Callable, Dict, Tuple, TypeVar, FrozenSet, Iterable - -import numpy as np - -import jax.numpy as jnp +import types +from typing import Union, Optional, Callable, Dict, Tuple, TypeVar, FrozenSet, Iterable, Type, Set, List +import jax from jax import core +from jax import lax from jax import linear_util as lu +from jax._src import prng +from jax._src import source_info_util +from jax._src import traceback_util +from jax._src.config import config +from jax._src.lax import control_flow as cf +from jax._src.sharding import OpShardingSharding +from jax._src.typing import Array +from jax._src.util import (as_hashable_function, unzip2, split_list, safe_map, + safe_zip) from jax.api_util import flatten_fun -from jax.experimental import pjit +from jax.api_util import flatten_fun_nokwargs from jax.experimental import maps +from jax.experimental import pjit from jax.interpreters import ad from jax.interpreters import batching from jax.interpreters import mlir from jax.interpreters import partial_eval as pe -from jax._src.sharding import OpShardingSharding -from jax.tree_util import tree_flatten, tree_unflatten, register_pytree_node -from jax._src import source_info_util, traceback_util -from jax._src.lax import control_flow as cf -from jax._src.config import config -from jax._src import prng -from jax import lax -from jax._src.typing import Array -from jax._src.util import (as_hashable_function, unzip2, split_list, safe_map, - safe_zip) +from jax.tree_util import tree_flatten +from jax.tree_util import tree_map +from jax.tree_util import tree_unflatten +import jax.numpy as jnp +import jax.tree_util as jtu +import numpy as np source_info_util.register_exclusion(__file__) traceback_util.register_exclusion(__file__) @@ -48,6 +52,11 @@ traceback_util.register_exclusion(__file__) map, unsafe_map = safe_map, map zip, unsafe_zip = safe_zip, zip +Bool = Union[bool, Array] +Int = Union[int, Array] +ErrorCategory = Type['JaxException'] +Payload = List[Union[np.ndarray, Array]] +PyTreeDef = jtu.PyTreeDef ## Utils @@ -61,54 +70,176 @@ def setnewattr(obj, name, val): assert getattr(obj, name, sentinel) is sentinel setattr(obj, name, val) +# Concrete errors -## Error value data type and functional assert. +class JaxException(Exception): + """Python exception which can contain an error message with JAX run-time info.""" -Bool = Union[bool, Array] -Int = Union[int, Array] -Payload = Union[np.ndarray, Array] + def __init__(self, traceback_info): + self.traceback_info = traceback_info + # TODO(lenamartens): re-enable tracebacks when they don't leak tracers. + # self.with_traceback(self.traceback_info) -# For now, the payload needs to be a fixed-size array: 3 int32s, used for the -# OOB message. -# TODO(lenamartens): Relax this fixed-size constraint. -init_payload = np.ones((3,), np.int32) + def __init_subclass__(cls): + jtu.register_pytree_node_class(cls) + + def tree_flatten(self): + return ([], self.traceback_info) + + @classmethod + def tree_unflatten(cls, metadata, payload): + del payload + return cls(metadata) + + def get_effect_type(self) -> core.Effect: + pass -def _format_msg(msg, payloads): - payload_mapping = {} - for i, pl in enumerate(payloads): - payload_mapping[f'payload{i}'] = pl - return msg.format(**payload_mapping) +@functools.total_ordering +@dataclasses.dataclass(eq=True, frozen=True) +class ErrorEffect: + error_type: Type[JaxException] + shape_dtypes: Tuple[jax.ShapeDtypeStruct, ...] + + def __post_init__(self): + cf.allowed_effects.add(self) + mlir.lowerable_effects.add(self) + + def __lt__(self, other: 'ErrorEffect'): + shape_dtypes = lambda x: tuple((sd.shape, str(sd.dtype)) # dtype is not comparable + for sd in x.shape_dtypes) + unpack = lambda x: (str(x.error_type), shape_dtypes(x)) + return (unpack(self) < unpack(other)) -@dataclass(frozen=True) +class DivisionByZeroError(JaxException): + + def __str__(self): + return f'division by zero at {self.traceback_info}' + + def get_effect_type(self): + return ErrorEffect(DivisionByZeroError, ()) + +class NaNError(JaxException): + + def __init__(self, traceback_info, primitive_name): + super().__init__(traceback_info) + self.prim = primitive_name + + def tree_flatten(self): + return ([], (self.traceback_info, self.prim)) + + @classmethod + def tree_unflatten(cls, metadata, _): + return cls(*metadata) + + def get_effect_type(self): + return ErrorEffect(NaNError, ()) + + def __str__(self): + return f'nan generated by primitive: {self.prim} at {self.traceback_info}' + +class OOBError(JaxException): + + def __init__(self, traceback_info, primitive_name, operand_shape, payload): + super().__init__(traceback_info) + self.prim = primitive_name + self.operand_shape = operand_shape + self._payload = payload + + def tree_flatten(self): + return ([self._payload], (self.traceback_info, self.prim, self.operand_shape)) + + @classmethod + def tree_unflatten(cls, metadata, payload): + return cls(*metadata, payload[0]) + + def __str__(self): + return (f'out-of-bounds indexing for array of ' + f'shape {self.operand_shape}: ' + f'index {self._payload[0]} is out of bounds for axis ' + f'{self._payload[1]} with size {self._payload[2]}. ' + f'Failed at {self.traceback_info}') + + def get_effect_type(self): + return ErrorEffect(OOBError, (jax.ShapeDtypeStruct((3,), jnp.int32),)) + +class FailedCheckError(JaxException): + + def __init__(self, traceback_info, fmt_string, *a, **k): + super().__init__(traceback_info) + self.fmt_string = fmt_string + self.args = a + self.kwargs = k + + def tree_flatten(self): + return ((jnp.array([], jnp.int32), self.args, self.kwargs), + (self.traceback_info, self.fmt_string)) + + @classmethod + def tree_unflatten(cls, metadata, payload): + _, args, kwargs = payload + return cls(*metadata, *args, **kwargs) + + def __str__(self): + return (self.fmt_string.format(*self.args, **self.kwargs) + + f' (check failed at {self.traceback_info})') + + def get_effect_type(self): + vals = jtu.tree_leaves((self.args, self.kwargs)) + return ErrorEffect( + FailedCheckError, + # Need a 0-size array here for data-dependence. + (jax.ShapeDtypeStruct((0,), jnp.int32), + *tuple(jax.ShapeDtypeStruct(x.shape, x.dtype) for x in vals))) + +@dataclasses.dataclass +class BatchedError(JaxException): + error_mapping: Dict[Tuple[int, ...], JaxException] + + def __post_init__(self): + traceback_info = list(self.error_mapping.values())[0].traceback_info + super().__init__(traceback_info) + + + def __str__(self): + return '\n'.join(f'at mapped index {", ".join(map(str, idx))}: {e}' + for idx, e in self.error_mapping.items()) + + +# Error Value + +@jtu.register_pytree_node_class +@dataclasses.dataclass(frozen=True) class Error: - err: Bool - code: Int - msgs: Dict[int, str] - # There might be many msgs with a {payload}, but only one msg will - # ever be active for an Error instance, so only one Payload is tracked. - payload: Payload - - def __init__(self, err: Bool, code: Int, msgs: Dict[int, str], payload: Optional[Payload] = None): - # We can't directly assign to members of a frozen dataclass, even in __init__. - object.__setattr__(self, "err", err) - object.__setattr__(self, "code", code) - object.__setattr__(self, "msgs", msgs) - object.__setattr__(self, "payload", - init_payload if payload is None else payload) + _pred: Dict[ErrorEffect, Bool] + _code: Dict[ErrorEffect, Int] + _metadata: Dict[Int, PyTreeDef] # mapping of code to JaxException treedef. + _payload: Dict[ErrorEffect, Payload] def get(self) -> Optional[str]: - """Returns error message is error happened, None if no error happened.""" - assert np.shape(self.err) == np.shape(self.code) - if np.size(self.err) == 1: - if self.err: - return _format_msg(self.msgs[int(self.code)], self.payload) + """Returns error message if error happened, None if no error happened.""" + exp = self.get_exception() + if exp is not None: + return str(exp) + return None + + def get_exception(self) -> Optional[JaxException]: + """Returns Python exception if error happened, None if no error happened.""" + if any(map(np.shape, self._pred.values())): + return self._get_batched_exception() else: - return '\n'.join( - f'at mapped index {", ".join(map(str, idx))}: ' # type: ignore - f'{_format_msg(self.msgs[int(self.code[idx])], self.payload[idx])}' # type: ignore - for idx, e in np.ndenumerate(self.err) if e) or None + min_code = None + cur_effect = None + for error_effect, code in self._code.items(): + if self._pred[error_effect]: + if min_code is None or code < min_code: + min_code = code + cur_effect = error_effect + + if cur_effect is not None: + return tree_unflatten(self._metadata[int(min_code)], # type: ignore + self._payload[cur_effect]) return None def throw(self): @@ -117,31 +248,80 @@ class Error: def __str__(self): return f'Error({self.get()})' + # Internal helpers -def raise_error(error): - err = error.get() - if err: - raise ValueError(err) + def _get_batched_exception(self): + shape = np.shape(list(self._pred.values())[0]) + error_mapping = {} + for idx in np.ndindex(*shape): + min_code = None + cur_effect = None + for error_effect, code in self._code.items(): + if self._pred[error_effect][idx]: # type: ignore + if min_code is None or code[idx] < min_code: + min_code = code[idx] # type: ignore + cur_effect = error_effect + if cur_effect is not None: + payload = tree_map(lambda x, i=idx: x[i], self._payload[cur_effect]) + jax_error = tree_unflatten(self._metadata[int(min_code)], payload) # type: ignore + error_mapping[idx] = jax_error + return BatchedError(error_mapping) -register_pytree_node(Error, - lambda e: ((e.err, e.code, e.payload), - tuple(sorted(e.msgs.items()))), - lambda msgs, data: Error(data[0], data[1], # type: ignore - dict(msgs), data[2])) # type: ignore + def _update(self, effect_type: ErrorEffect, pred, code, metadata, payload): + new_errs = {**self._pred, **{effect_type: pred}} # type: ignore + new_codes = {**self._code, **{effect_type: code}} # type: ignore + new_payload = {**self._payload, **{effect_type: payload}} # type: ignore + new_metadata = {**self._metadata, **metadata} + return Error(new_errs, new_codes, new_metadata, new_payload) -init_error = Error(False, 0, {}) + def _add_placeholder_effects(self, effects: Set[ErrorEffect]): + """Fill out Error with `effects` and np.ones arrays of their payloads.""" + new_err = self._pred.copy() + new_code = self._code.copy() + new_payload = self._payload.copy() + for effect in effects: + if effect not in self._pred.keys(): + new_err[effect] = False + new_payload[effect] = list( + tree_map(lambda a: jnp.ones(a.shape, a.dtype), effect.shape_dtypes)) + # The error value associated with this effect will never become True, so + # we don't need to set a meaningful code. + new_code[effect] = -1 + return Error(new_err, new_code, self._metadata, new_payload) + + def _replace(self, *args, **kwargs): + return dataclasses.replace(self, *args, **kwargs) + + # PyTree methods + + def tree_flatten(self): + return ((self._pred, self._code, self._payload), (self._metadata)) + + @classmethod + def tree_unflatten(cls, metadata, data): + pred, code, payload = data + return cls(pred, code, metadata, payload) + +init_error = Error({}, {}, {}, {}) # value used as initial (empty) error. next_code = it.count(1).__next__ # globally unique ids, could be uuid4 - -def assert_func(error: Error, err: Bool, msg: str, - payload: Optional[Payload]) -> Error: +def assert_func(error: Error, pred: Bool, new_error: JaxException) -> Error: code = next_code() - payload = init_payload if payload is None else payload - out_err = error.err | err - out_code = lax.select(error.err, error.code, code) - out_payload = lax.select(error.err, error.payload, payload) - return Error(out_err, out_code, {code: msg, **error.msgs}, out_payload) + effect_type = new_error.get_effect_type() + new_payload, new_metadata = tree_flatten(new_error) + return update_error(error, pred, code, {code: new_metadata}, new_payload, effect_type) + +def update_error(error, pred, code, metadata, payload, effect_type): + err_of_type = error._pred.get(effect_type, False) + out_err = err_of_type | pred + out_code = lax.select(err_of_type, error._code.get(effect_type, -1), code) + cur_payload = error._payload.get(effect_type, None) + if cur_payload is not None: + out_payload = tree_map(functools.partial(lax.select, err_of_type), cur_payload, payload) + else: + out_payload = payload + return error._update(effect_type, out_err, out_code, metadata, out_payload) ## Checkify transformation for plumbing functional error values. @@ -182,142 +362,179 @@ class CheckifyTrace(core.Trace): def process_call(self, primitive, f, tracers, params): in_vals = [t.val for t in tracers] e = popattr(self.main, 'error') - f, msgs = checkify_subtrace(f, self.main, tuple(e.msgs.items())) + flat_vals, in_tree = tree_flatten((e, *in_vals)) + f = checkify_subtrace(f, self.main) + f, out_tree = flatten_fun_nokwargs(f, in_tree) if 'donated_invars' in params: - params = dict(params, donated_invars=(False, False, False, + params = dict(params, donated_invars=(*[False]*len(jtu.tree_leaves(e)), *params['donated_invars'])) - err, code, payload, *out_vals = primitive.bind(f, e.err, e.code, e.payload, - *in_vals, **params) - setnewattr(self.main, 'error', Error(err, code, msgs(), payload)) + all_vals = primitive.bind(f, *flat_vals, **params) + error, *out_vals = tree_unflatten(out_tree(), all_vals) + setnewattr(self.main, 'error', error) return [CheckifyTracer(self, x) for x in out_vals] def process_map(self, primitive, f, tracers, params): in_vals = [t.val for t in tracers] e = popattr(self.main, 'error') - f, msgs = checkify_subtrace(f, self.main, tuple(e.msgs.items())) + flat_vals, in_tree = tree_flatten((e, *in_vals)) + num_error_vals = len(jtu.tree_leaves(e)) + f = checkify_subtrace(f, self.main) + f, out_tree = flatten_fun_nokwargs(f, in_tree) @as_hashable_function(closure=params['out_axes_thunk']) def new_out_axes_thunk(): - return (0, 0, 0, *params['out_axes_thunk']()) + out_val_axes = params['out_axes_thunk']() + out_err_num = out_tree().num_leaves - len(out_val_axes) + return (*(0,)*out_err_num, *out_val_axes) - params_ = dict(params, in_axes=(None, None, None, *params['in_axes']), + params_ = dict(params, in_axes=(*(None,)*num_error_vals, *params['in_axes']), out_axes_thunk=new_out_axes_thunk, - donated_invars=(False, False, False, *params['donated_invars'])) - errs, codes, payloads, *outs = primitive.bind(f, e.err, e.code, e.payload, - *in_vals, **params_) - err, code, payload = _reduce_any_error(errs, codes, payloads) - setnewattr(self.main, 'error', Error(err, code, msgs(), payload)) - return [CheckifyTracer(self, x) for x in outs] + donated_invars=(*(False,)*num_error_vals, *params['donated_invars'])) + all_vals = primitive.bind(f, *flat_vals, **params_) + error, *out_vals = tree_unflatten(out_tree(), all_vals) + error = _reduce_any_error(error) + setnewattr(self.main, 'error', error) + return [CheckifyTracer(self, x) for x in out_vals] def post_process_call(self, primitive, tracers, params): vals = [t.val for t in tracers] main = self.main e = popattr(main, 'error') - err, code, payload, main.msgs = e.err, e.code, e.payload, e.msgs + err_leaves, err_tree = tree_flatten(e) + setnewattr(main, 'err_tree', err_tree) def todo(vals): - err, code, payload, *vals = vals - setnewattr(main, 'error', Error(err, code, popattr(main, 'msgs'), payload)) + err_tree = popattr(main, 'err_tree') + err_vals, vals = split_list(vals, [err_tree.num_leaves]) + setnewattr(main, 'error', tree_unflatten(err_tree, err_vals)) trace = main.with_cur_sublevel() return [CheckifyTracer(trace, x) for x in vals] - return (err, code, payload, *vals), todo + return (*err_leaves, *vals), todo def post_process_map(self, primitive, tracers, params): vals = [t.val for t in tracers] main = self.main e = popattr(main, 'error') - err, code, payload, main.msgs = e.err, e.code, e.payload, e.msgs + err_leaves, err_tree = tree_flatten(e) + num_err_leaves = len(err_leaves) + setnewattr(main, 'err_tree', err_tree) def todo(vals): - errs, codes, payloads, *vals = vals - err, code, payload = _reduce_any_error(errs, codes, payloads) - setnewattr(main, 'error', Error(err, code, popattr(main, 'msgs'), payload)) + err_tree = popattr(main, 'err_tree') + err_vals, vals = split_list(vals, [err_tree.num_leaves]) + error = tree_unflatten(err_tree, err_vals) + error = _reduce_any_error(error) + setnewattr(main, 'error', error) trace = main.with_cur_sublevel() return [CheckifyTracer(trace, x) for x in vals] def out_axes_transform(out_axes): - return (0, 0, 0, *out_axes) - return (err, code, payload, *vals), (todo, out_axes_transform) + return (*(0,)*num_err_leaves, *out_axes) + return (*err_leaves, *vals), (todo, out_axes_transform) - def process_custom_jvp_call(self, prim, fun, jvp, tracers): + def process_custom_jvp_call(self, prim, f, jvp, tracers): in_vals = [t.val for t in tracers] e = popattr(self.main, 'error') - msgs = tuple(e.msgs.items()) - fun, msgs1 = checkify_subtrace(fun, self.main, msgs) - jvp, msgs2 = checkify_custom_jvp_subtrace(jvp, self.main, msgs) - err, code, payload, *out_vals = prim.bind(fun, jvp, e.err, e.code, - e.payload, *in_vals) - fst, out_msgs = lu.merge_linear_aux(msgs1, msgs2) - setattr(self.main, 'error', Error(err, code, out_msgs, payload)) + err_vals, err_tree = tree_flatten(e) + flat_vals, in_tree = tree_flatten((e, *in_vals)) + num_error_vals = len(err_vals) + f = checkify_subtrace(f, self.main) + f, f_out_tree = flatten_fun_nokwargs(f, in_tree) + jvp, jvp_err_tree = checkify_custom_jvp_subtrace(jvp, self.main, + num_error_vals, err_tree) + all_outs = prim.bind(f, jvp, *flat_vals) + fst, out_tree = lu.merge_linear_aux(f_out_tree, jvp_err_tree) + if fst: + out_err, *out_vals = tree_unflatten(out_tree, all_outs) + else: + err_vals, out_vals = split_list(all_outs, [num_error_vals]) + # forward input error values to output + out_err = tree_unflatten(out_tree, err_vals) + setattr(self.main, 'error', out_err) return [CheckifyTracer(self, x) for x in out_vals] - def post_process_custom_jvp_call(self, out_tracers, jvp_was_run): + def post_process_custom_jvp_call(self, tracers, jvp_was_run): if jvp_was_run: - msg = ("support for custom_jvp rules which close over checkify values is " - "not implemented. If you see this, open an issue at " - "https://github.com/google/jax/issues!") + msg = ('support for custom_jvp rules which close over checkify values is ' + 'not implemented. If you see this, open an issue at ' + 'https://github.com/google/jax/issues!') raise NotImplementedError(msg) - vals = [t.val for t in out_tracers] + vals = [t.val for t in tracers] main = self.main e = popattr(main, 'error') - err, code, payload, main.msgs = e.err, e.code, e.payload, e.msgs + err_leaves, err_tree = tree_flatten(e) def todo(vals): - err, code, payload, *vals = vals - setnewattr(main, 'error', Error(err, code, popattr(main, 'msgs'), payload)) + err_vals, vals = split_list(vals, [len(err_leaves)]) + setnewattr(main, 'error', tree_unflatten(err_tree, err_vals)) trace = main.with_cur_sublevel() return [CheckifyTracer(trace, x) for x in vals] - return (err, code, payload, *vals), todo + return (*err_leaves, *vals), todo def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees): in_vals = [t.val for t in tracers] e = popattr(self.main, 'error') - msgs = tuple(e.msgs.items()) - fun, msgs1 = checkify_subtrace(fun, self.main, msgs) - fwd, msgs2 = checkify_custom_vjp_subtrace(fwd, self.main, msgs) - out = prim.bind(fun, fwd, bwd, e.err, e.code, e.payload, - *in_vals, out_trees=out_trees) - fst, out_msgs = lu.merge_linear_aux(msgs1, msgs2) + err_vals, err_tree = tree_flatten(e) + flat_vals, in_tree = tree_flatten((e, *in_vals)) + num_error_vals = len(err_vals) + + fun = checkify_subtrace(fun, self.main) + fun, fun_out_tree = flatten_fun_nokwargs(fun, in_tree) + fwd, fwd_err_tree = checkify_custom_vjp_subtrace(fwd, self.main, + err_tree, num_error_vals) + + all_out_vals = prim.bind(fun, fwd, bwd, *flat_vals, out_trees=out_trees) + fst, out_tree = lu.merge_linear_aux(fun_out_tree, fwd_err_tree) if fst: - err, code, payload, *out = out + error, *out = tree_unflatten(out_tree, all_out_vals) else: - err, code, payload = e.err, e.code, e.payload # forward input error values to output - setattr(self.main, 'error', Error(err, code, out_msgs, payload)) + _, out = split_list(all_out_vals, [num_error_vals]) + # forward input error values to output + error = tree_unflatten(err_tree, err_vals) + setattr(self.main, 'error', error) return [CheckifyTracer(self, x) for x in out] -def _reduce_any_error(errs, codes, payloads): - reduced_idx = jnp.argsort(errs)[-1] - return errs[reduced_idx], codes[reduced_idx], payloads[reduced_idx] +def _reduce_any_error(error: Error): + out_error = init_error + for error_effect in error._pred.keys(): + errs, codes, payloads = (error._pred[error_effect], + error._code[error_effect], + error._payload[error_effect]) + reduced_idx = jnp.argsort(errs)[-1] + pred, code, payload = tree_map(lambda x, idx=reduced_idx: x[idx], + (errs, codes, payloads)) + out_error = out_error._update(error_effect, pred, code, {}, payload) + + out_error = out_error._replace(_metadata=error._metadata) + return out_error ErrorCheckRule = Callable # (Error, FrozenSet[ErrorCategory], *in_vals, **params) -> (Any, Error) error_checks: Dict[core.Primitive, ErrorCheckRule] = {} def checkify_flat(fun: lu.WrappedFun, enabled_errors: FrozenSet['ErrorCategory'], *args): - fun, msgs = checkify_subtrace(fun) - fun = checkify_traceable(fun, tuple(init_error.msgs.items()), enabled_errors) - err, code, payload, *outvals = fun.call_wrapped(init_error.err, - init_error.code, - init_error.payload, *args) - return (err, code, payload, outvals), msgs() + fun = checkify_subtrace(fun) + fun = checkify_traceable(fun, enabled_errors) + error, *outvals = fun.call_wrapped(init_error, *args) + return error, outvals @lu.transformation -def checkify_traceable(msgs, enabled_errors, err, code, payload, *args): +def checkify_traceable(enabled_errors, error, *args): with core.new_main(CheckifyTrace, enabled_errors=enabled_errors) as main: - outs = yield (main, msgs, err, code, payload, *args), {} + outs = yield (main, error, *args), {} del main yield outs -@lu.transformation_with_aux -def checkify_subtrace(main, msgs, err, code, payload, *args): - setnewattr(main, 'error', Error(err, code, dict(msgs), payload)) +@lu.transformation +def checkify_subtrace(main, error, *args): + setnewattr(main, 'error', error) trace = main.with_cur_sublevel() in_tracers = [CheckifyTracer(trace, x) for x in args] out = yield in_tracers, {} out_tracers = map(trace.full_raise, out) out_vals = [t.val for t in out_tracers] - err, code, payload, msgs = main.error.err, main.error.code, main.error.payload, main.error.msgs + error = main.error del main.error - yield (err, code, payload, *out_vals), msgs + yield (error, *out_vals) @lu.transformation_with_aux -def checkify_custom_jvp_subtrace(main, msgs, *args): +def checkify_custom_jvp_subtrace(main, num_error_vals, out_tree, *args): # Like checkify_subtrace, but used specifically on the custom JVP rules # associated with a custom_jvp. This code is called in the context of a # jvp-of-checkify-of-custom_jvp. It takes both primal and tangent inputs, @@ -336,41 +553,52 @@ def checkify_custom_jvp_subtrace(main, msgs, *args): del main n, ragged = divmod(len(args), 2) assert not ragged - (err,), (code,), (payload,), primals = split_list(args[:n], [1, 1, 1]) - (err_dot,), (code_dot,), (pl_dot,), tangents = split_list(args[n:], [1, 1, 1]) + err_primals, primals = split_list(args[:n], [num_error_vals]) + err_tangents, tangents = split_list(args[n:], [num_error_vals]) outs = yield (*primals, *tangents), {} m, ragged = divmod(len(outs), 2) assert not ragged out_primals, out_tangents = outs[:m], outs[m:] - yield (err, code, payload, *out_primals, - err_dot, code_dot, pl_dot, *out_tangents), dict(msgs) + yield (*err_primals, *out_primals, *err_tangents, *out_tangents), out_tree @lu.transformation_with_aux -def checkify_custom_vjp_subtrace(main, msgs, err, code, payload, *args): +def checkify_custom_vjp_subtrace(main, err_tree, num_error_vals, *args): + del main # We don't add any checks; just drop input error values. - del main, err, code, payload + _, args = split_list(args, [num_error_vals]) outs = yield args, {} - yield outs, dict(msgs) + yield outs, err_tree -# TODO take (error_aval, code_aval) instead of error here? -def checkify_jaxpr(jaxpr, error, enabled_errors): +@lu.transformation_with_aux +def query_error_effects(*args): + (error, *outs) = yield args, {} + yield (error, *outs), set(error._pred.keys()) + +def checkify_jaxpr(jaxpr, error, + enabled_errors) -> Tuple[core.ClosedJaxpr, + Tuple[PyTreeDef, + FrozenSet[ErrorEffect]]]: f = lu.wrap_init(core.jaxpr_as_fun(jaxpr)) return checkify_fun_to_jaxpr(f, error, enabled_errors, jaxpr.in_avals) -def checkify_fun_to_jaxpr(f, error, enabled_errors, in_avals): - f, msgs = checkify_subtrace(f) - f = checkify_traceable(f, tuple(error.msgs.items()), enabled_errors) - err_aval = core.raise_to_shaped(core.get_aval(error.err)) - code_aval = core.raise_to_shaped(core.get_aval(error.code)) - payload_aval = core.raise_to_shaped(core.get_aval(error.payload)) - avals_in = [err_aval, code_aval, payload_aval, *in_avals] +def checkify_fun_to_jaxpr( + f, error, enabled_errors, + in_avals) -> Tuple[core.ClosedJaxpr, Tuple[PyTreeDef, FrozenSet[ErrorEffect]]]: + flat_error_vals, in_tree = tree_flatten(error) + f = checkify_subtrace(f) + f = checkify_traceable(f, enabled_errors) + f, error_effect = query_error_effects(f) + in_tree = jtu.tree_structure((error, *in_avals)) + f, out_tree = flatten_fun_nokwargs(f, in_tree) + err_vals = map(lambda x: core.raise_to_shaped(core.get_aval(x)), + flat_error_vals) + avals_in = [*err_vals, *in_avals] jaxpr_out, _, literals_out = pe.trace_to_jaxpr_dynamic(f, avals_in) - return core.ClosedJaxpr(jaxpr_out, literals_out), msgs() + return (core.ClosedJaxpr(jaxpr_out, literals_out), (out_tree(), error_effect())) -## assert primitive -def check(pred: Bool, msg: str) -> None: +def check(pred: Bool, msg: str, *args, **kwargs) -> None: """Check a predicate, add an error with msg if predicate is False. This is an effectful operation, and can't be staged (jitted/scanned/...). @@ -393,14 +621,14 @@ def check(pred: Bool, msg: str) -> None: >>> err.throw() # doctest: +IGNORE_EXCEPTION_DETAIL Traceback (most recent call last): ... - ValueError: cannot be zero! (check failed at ...) + jax._src.checkify.JaxRuntimeError: cannot be zero! """ if not is_scalar_pred(pred): raise TypeError(f'check takes a scalar pred as argument, got {pred}') - code = next_code() - msg += f' (check failed at {summary()})' - return check_error(Error(jnp.logical_not(pred), code, {code: msg})) + new_error = FailedCheckError(summary(), msg, *args, **kwargs) + error = assert_func(init_error, jnp.logical_not(pred), new_error) + return check_error(error) def is_scalar_pred(pred) -> bool: return (isinstance(pred, bool) or @@ -464,7 +692,7 @@ def check_error(error: Error) -> None: >>> with_inner_jit(-1) # doctest: +IGNORE_EXCEPTION_DETAIL Traceback (most recent call last): ... - ValueError: must be positive! + jax._src.JaxRuntimeError: must be positive! >>> # can re-checkify >>> error, _ = checkify.checkify(with_inner_jit)(-1) """ @@ -472,27 +700,33 @@ def check_error(error: Error) -> None: raise ValueError('check_error takes an Error as argument, ' f'got type {type(error)} instead.') - if np.shape(error.err): - err, code, payload = _reduce_any_error(error.err, error.code, error.payload) - else: - err, code, payload = error.err, error.code, error.payload + error = tree_map(core.raise_as_much_as_possible, error) + if any(map(np.shape, error._pred.values())): + error = _reduce_any_error(error) + err_args, tree_def = tree_flatten(error) - err = core.raise_as_much_as_possible(err) - return assert_p.bind(err, code, payload, msgs=error.msgs) + return check_p.bind(*err_args, err_tree=tree_def) -assert_p = core.Primitive('assert') # TODO: rename to check? -assert_p.multiple_results = True # zero results +## check primitive -@assert_p.def_impl -def assert_impl(err, code, payload, *, msgs): - raise_error(Error(err, code, msgs, payload)) +check_p = core.Primitive('check') +check_p.multiple_results = True # zero results + +# TODO(lenamartens): inherit from Exception instead of ValueError. +class JaxRuntimeError(ValueError): + pass + +@check_p.def_impl +def check_impl(*args, err_tree): + error = tree_unflatten(err_tree, args) + exc = error.get_exception() + if exc: + raise JaxRuntimeError(str(exc)) from exc return [] -CheckEffect = object() - -@assert_p.def_effectful_abstract_eval -def assert_abstract_eval(err, code, payload, *, msgs): - return [], {CheckEffect} +@check_p.def_effectful_abstract_eval +def check_abstract_eval(*args, err_tree): + return [], set(tree_unflatten(err_tree, args)._pred.keys()) # TODO(lenamartens) add in-depth error explanation to link to in module docs. functionalization_error = ValueError( @@ -502,68 +736,80 @@ functionalization_error = ValueError( ' through `checkify.checkify`.' ) -def python_err(msgs, err, code, payload): - error = Error(err, code, msgs, payload) - check_error(error) - return [] - -def assert_lowering_rule(ctx, err, code, payload, *, msgs): +def check_lowering_rule(ctx, *args, err_tree): if not config.jax_experimental_unsafe_xla_runtime_errors: raise functionalization_error - out_op, token_out, keep_alive = mlir.emit_python_callback( - ctx, callback=lambda *a: python_err(msgs, *a), - token=ctx.tokens_in.get(CheckEffect)[0], - operands=[err, code, payload], + out_op, _, keep_alive = mlir.emit_python_callback( + ctx, callback=functools.partial(python_err, err_tree), + token=None, + operands=args, operand_avals=list(ctx.avals_in), result_avals=list(ctx.avals_out), has_side_effect=True) - ctx.set_tokens_out(ctx.tokens_in.update_tokens( - mlir.TokenSet({CheckEffect: token_out}))) ctx.module_context.add_keepalive(keep_alive) return out_op -def assert_lowering_rule_unsupported(*a, **k): +def check_lowering_rule_unsupported(*a, **k): raise functionalization_error -mlir.register_lowering(assert_p, assert_lowering_rule_unsupported, +def python_err(err_tree, *args): + error = tree_unflatten(err_tree, args) + check_error(error) + return [] + +mlir.register_lowering(check_p, check_lowering_rule_unsupported, platform='tpu') -mlir.register_lowering(assert_p, assert_lowering_rule, +mlir.register_lowering(check_p, check_lowering_rule, platform='cpu') -mlir.register_lowering(assert_p, assert_lowering_rule, +mlir.register_lowering(check_p, check_lowering_rule, platform='gpu') -mlir.lowerable_effects.add(CheckEffect) -cf.allowed_effects.add(CheckEffect) -core.ordered_effects.add(CheckEffect) - -def assert_batching_rule(batched_args, batch_dims, *, msgs): +def check_batching_rule(batched_args, batch_dims, *, err_tree): size = next(x.shape[dim] for x, dim in zip(batched_args, batch_dims) if dim is not batching.not_mapped) - err, code, payload = (batching.bdim_at_front(a, d, size) - for a, d in zip(batched_args, batch_dims)) - err = Error(err, code, msgs, payload) + batched_args = (batching.bdim_at_front(a, d, size) + for a, d in zip(batched_args, batch_dims)) + err = tree_unflatten(err_tree, batched_args) check_error(err) return [], [] -batching.primitive_batchers[assert_p] = assert_batching_rule +batching.primitive_batchers[check_p] = check_batching_rule -def assert_jvp_rule(primals, _, *, msgs): +def check_jvp_rule(primals, _, *, err_tree): # Check primals, discard tangents. - assert_p.bind(*primals, msgs=msgs) + check_p.bind(*primals, err_tree=err_tree) return [], [] -ad.primitive_jvps[assert_p] = assert_jvp_rule +ad.primitive_jvps[check_p] = check_jvp_rule ## checkify rules +def _get_current_traceback(skip_frames = 0) -> Optional[types.TracebackType]: + # TODO(lenamartens): use c++ version from XLA? + tb = None + import inspect + for frame_info in inspect.stack(): + frame = frame_info.frame + if skip_frames: + skip_frames -= 1 + elif not traceback_util.include_frame(frame): + continue + else: + tb = types.TracebackType(tb, frame, frame.f_lasti, frame.f_lineno) + return tb + def summary() -> str: return str(source_info_util.summarize(source_info_util.current())) def nan_error_check(prim, error, enabled_errors, *in_vals, **params): out = prim.bind(*in_vals, **params) - if ErrorCategory.NAN not in enabled_errors: - return out, error + err = check_nans(prim, error, enabled_errors, out) + return out, err + +def check_nans(prim, error, enabled_errors, out): + if NaNError not in enabled_errors: + return error def isnan(x): if isinstance(x, prng.PRNGKeyArray): @@ -572,8 +818,7 @@ def nan_error_check(prim, error, enabled_errors, *in_vals, **params): any_nans = (jnp.any(jnp.array([isnan(x) for x in out])) if prim.multiple_results else isnan(out)) - msg = f'nan generated by primitive {prim.name} at {summary()}' - return out, assert_func(error, any_nans, msg, None) + return assert_func(error, any_nans, NaNError(summary(), prim.name)) # All primitives which can generate a NaN. @@ -594,7 +839,7 @@ nan_primitives = [lax.acos_p, lax.acosh_p, lax.add_p, lax.asin_p, lax.asinh_p, lax.sinh_p, lax.sqrt_p, lax.sub_p, lax.tan_p, lax.tanh_p] for prim in nan_primitives: - error_checks[prim] = partial(nan_error_check, prim) + error_checks[prim] = functools.partial(nan_error_check, prim) def gather_error_check(error, enabled_errors, operand, start_indices, *, @@ -605,7 +850,7 @@ def gather_error_check(error, enabled_errors, operand, start_indices, *, slice_sizes=slice_sizes, unique_indices=unique_indices, indices_are_sorted=indices_are_sorted, mode=mode, fill_value=fill_value) - if ErrorCategory.OOB not in enabled_errors: + if OOBError not in enabled_errors: return out, error # compare to OOB masking logic in lax._gather_translation_rule @@ -616,33 +861,30 @@ def gather_error_check(error, enabled_errors, operand, start_indices, *, upper_bound = operand_dims[np.array(dnums.start_index_map)] upper_bound -= np.array(slice_sizes)[np.array(dnums.start_index_map)] upper_bound = jnp.expand_dims(upper_bound, axis=tuple(range(num_batch_dims))) - out_of_bounds = (start_indices < 0) | (start_indices > upper_bound.astype(start_indices.dtype)) + oob_mask = (start_indices < 0) | (start_indices > upper_bound.astype(start_indices.dtype)) - # Get first OOB index, axis and axis size so it can be added to the error msg. - flat_idx = jnp.argmin(jnp.logical_not(out_of_bounds)) - multi_idx = jnp.unravel_index(flat_idx, start_indices.shape) - oob_axis = jnp.array(dnums.start_index_map)[multi_idx[-1]] - oob_axis_size = jnp.array(operand.shape)[oob_axis] - oob_index = jnp.ravel(start_indices)[flat_idx] - payload = jnp.array([oob_index, oob_axis, oob_axis_size], dtype=jnp.int32) - - msg = (f'out-of-bounds indexing at {summary()} for array of ' - f'shape {operand.shape}: ' - 'index {payload0} is out of bounds for axis {payload1} ' - 'with size {payload2}.') - - return out, assert_func(error, jnp.any(out_of_bounds), msg, payload) + payload = oob_payload(oob_mask, start_indices, dnums.start_index_map, operand.shape) + return out, assert_func(error, jnp.any(oob_mask), OOBError(summary(), "gather", operand.shape, payload)) error_checks[lax.gather_p] = gather_error_check def div_error_check(error, enabled_errors, x, y): """Checks for division by zero and NaN.""" - if ErrorCategory.DIV in enabled_errors: + if DivisionByZeroError in enabled_errors: any_zero = jnp.any(jnp.equal(y, 0)) - msg = f'division by zero at {summary()}' - error = assert_func(error, any_zero, msg, None) + error = assert_func(error, any_zero, DivisionByZeroError(summary())) return nan_error_check(lax.div_p, error, enabled_errors, x, y) error_checks[lax.div_p] = div_error_check +def oob_payload(oob_mask, indices, dims_map, operand_shape): + # Get first OOB index, axis and axis size so it can be added to the error msg. + flat_idx = jnp.argmin(jnp.logical_not(oob_mask)) + multi_idx = jnp.unravel_index(flat_idx, indices.shape) + oob_axis = jnp.array(dims_map)[multi_idx[-1]] + oob_axis_size = jnp.array(operand_shape)[oob_axis] + oob_index = jnp.ravel(indices)[flat_idx] + payload = jnp.array([oob_index, oob_axis, oob_axis_size], dtype=jnp.int32) + return payload + def scatter_oob(operand, indices, updates, dnums): # Ref: see clamping code used in scatter_translation_rule slice_sizes = [] @@ -661,9 +903,12 @@ def scatter_oob(operand, indices, updates, dnums): upper_bound = lax.broadcast_in_dim(upper_bound, indices.shape, (len(indices.shape) - 1,)) - lower_oob = jnp.any(jnp.less(indices, 0)) - upper_oob = jnp.any(jnp.greater(indices, upper_bound.astype(indices.dtype))) - return jnp.logical_or(lower_oob, upper_oob) + lower_oob = jnp.less(indices, 0) + upper_oob = jnp.greater(indices, upper_bound.astype(indices.dtype)) + oob_mask = jnp.logical_or(lower_oob, upper_oob) + payload = oob_payload(oob_mask, indices, + dnums.scatter_dims_to_operand_dims, operand.shape) + return jnp.any(oob_mask), payload def scatter_error_check(prim, error, enabled_errors, operand, indices, updates, *, update_jaxpr, update_consts, dimension_numbers, @@ -675,47 +920,71 @@ def scatter_error_check(prim, error, enabled_errors, operand, indices, updates, indices_are_sorted=indices_are_sorted, unique_indices=unique_indices, mode=mode) - if ErrorCategory.OOB not in enabled_errors: + if OOBError not in enabled_errors: return out, error - out_of_bounds = scatter_oob(operand, indices, updates, dimension_numbers) - oob_msg = f'out-of-bounds indexing while updating at {summary()}' - oob_error = assert_func(error, out_of_bounds, oob_msg, None) - - any_nans = jnp.any(jnp.isnan(out)) - nan_msg = f'nan generated by primitive {prim.name} at {summary()}' - return out, assert_func(oob_error, any_nans, nan_msg, None) -error_checks[lax.scatter_p] = partial(scatter_error_check, lax.scatter_p) -error_checks[lax.scatter_add_p] = partial(scatter_error_check, lax.scatter_add_p) -error_checks[lax.scatter_mul_p] = partial(scatter_error_check, lax.scatter_mul_p) -error_checks[lax.scatter_min_p] = partial(scatter_error_check, lax.scatter_min_p) -error_checks[lax.scatter_max_p] = partial(scatter_error_check, lax.scatter_max_p) + out_of_bounds, payload = scatter_oob(operand, indices, updates, dimension_numbers) + oob_error = OOBError(summary(), prim.name, operand.shape, payload) + error = assert_func(error, out_of_bounds, oob_error) + return out, check_nans(prim, error, enabled_errors, out) +error_checks[lax.scatter_p] = functools.partial(scatter_error_check, lax.scatter_p) +error_checks[lax.scatter_add_p] = functools.partial(scatter_error_check, + lax.scatter_add_p) +error_checks[lax.scatter_mul_p] = functools.partial(scatter_error_check, + lax.scatter_mul_p) +error_checks[lax.scatter_min_p] = functools.partial(scatter_error_check, + lax.scatter_min_p) +error_checks[lax.scatter_max_p] = functools.partial(scatter_error_check, + lax.scatter_max_p) def cond_error_check(error, enabled_errors, index, *ops, branches, linear): - new_branches, msgs_ = unzip2(checkify_jaxpr(jxpr, error, enabled_errors) - for jxpr in branches) - new_linear = (False, False, False, *linear) - err, code, payload, *outs = lax.cond_p.bind( - index, error.err, error.code, error.payload, *ops, + _, out_trees_and_effects = unzip2(checkify_jaxpr(jxpr, error, + enabled_errors) + for jxpr in branches) + _, effects = unzip2(out_trees_and_effects) + + merged_error = error._add_placeholder_effects(set().union(*effects)) + new_branches, out_trees_and_effects = unzip2(checkify_jaxpr(jxpr, merged_error, + enabled_errors) + for jxpr in branches) + out_trees, _ = unzip2(out_trees_and_effects) + + flat_error, _ = tree_flatten(merged_error) + new_linear = (*[False] * len(flat_error), *linear) + err_and_outs = lax.cond_p.bind( + index, *flat_error, *ops, branches=tuple(new_branches), linear=new_linear) - new_msgs = {k:v for d in it.chain([error.msgs], msgs_) for k, v in d.items()} - return outs, Error(err, code, new_msgs, payload) + + # we need to merge metadata across out_trees (a tuple) + # maybe there's a better way to do this, but we can use the outs + # to unflatten all trees. + err0, *out = tree_unflatten(out_trees[0], err_and_outs) + merged_metadata = err0._metadata + for tr in out_trees[1:]: + err, *_ = tree_unflatten(tr, err_and_outs) + merged_metadata = {**merged_metadata, **err._metadata} + return out, err0._replace(_metadata=merged_metadata) error_checks[lax.cond_p] = cond_error_check def scan_error_check(error, enabled_errors, *in_flat, reverse, length, jaxpr, num_consts, num_carry, linear, unroll): consts, carry, xs = split_list(in_flat, [num_consts, num_carry]) - checked_jaxpr_, msgs_ = checkify_jaxpr(jaxpr, error, enabled_errors) - tomove = [False] * 3 + [True] * len(consts) + [False] * (len(carry) + len(xs)) + _, (_, effects) = checkify_jaxpr(jaxpr, error, enabled_errors) + merged_error = error._add_placeholder_effects(effects) + checked_jaxpr_, (out_tree, _) = checkify_jaxpr(jaxpr, merged_error, enabled_errors) + + flat_error_vals, _ = tree_flatten(merged_error) + tomove = [False] * len(flat_error_vals) + [True] * len(consts) + [False] * (len(carry) + len(xs)) checked_jaxpr = pe.move_binders_to_front(checked_jaxpr_, tomove) - new_linear = (False, False, False, *linear) - new_in_flat = [*consts, error.err, error.code, error.payload, *carry, *xs] - err, code, payload, *outs = lax.scan_p.bind( + new_linear = (*[False] * len(flat_error_vals), *linear) + new_in_flat = [*consts, *flat_error_vals, *carry, *xs] + err_and_out = lax.scan_p.bind( *new_in_flat, reverse=reverse, length=length, jaxpr=checked_jaxpr, - num_consts=len(consts), num_carry=len(carry)+3, + num_consts=len(consts), num_carry=len(carry)+len(flat_error_vals), linear=new_linear, unroll=unroll) - new_msgs = {**error.msgs, **msgs_} - return outs, Error(err, code, new_msgs, payload) + err, *out = tree_unflatten(out_tree, err_and_out) + return out, err + error_checks[lax.scan_p] = scan_error_check def checkify_while_body_jaxpr(cond_jaxpr, body_jaxpr, error, enabled_errors, c_consts): @@ -729,11 +998,11 @@ def checkify_while_body_jaxpr(cond_jaxpr, body_jaxpr, error, enabled_errors, c_c return checkify_fun_to_jaxpr(lu.wrap_init(new_body_f), error, enabled_errors, body_jaxpr.in_avals) -def ignore_error_output_jaxpr(jaxpr): +def ignore_error_output_jaxpr(jaxpr, num_error_vals): """Constructs a checked jaxpr which does not output its error value.""" consts = jaxpr.consts jaxpr = jaxpr.jaxpr - new_jaxpr = jaxpr.replace(outvars=jaxpr.outvars[3:]) + new_jaxpr = jaxpr.replace(outvars=jaxpr.outvars[num_error_vals:]) return core.ClosedJaxpr(new_jaxpr, consts) def while_loop_error_check(error, enabled_errors, *in_flat, cond_nconsts, @@ -744,32 +1013,40 @@ def while_loop_error_check(error, enabled_errors, *in_flat, cond_nconsts, '(checkify-of-vmap-of-while). \nHint: if possible, move ' 'the vmap to the outer level to get ' 'vmap-of-checkify-of-while.') - err_args = [error.err, error.code, error.payload] + err_vals, _ = tree_flatten(error) c_consts, b_consts, carry = split_list(in_flat, [cond_nconsts, body_nconsts]) # Check if the first cond application will error. - checked_cond_jaxpr, msgs_cond = checkify_jaxpr(cond_jaxpr, error, - enabled_errors) - cond_err, cond_code, cond_payload, _ = core.jaxpr_as_fun(checked_cond_jaxpr)( - *err_args, *c_consts, *carry) + checked_cond_jaxpr, (cond_out_tree, _) = checkify_jaxpr( + cond_jaxpr, error, enabled_errors) + outs = core.jaxpr_as_fun(checked_cond_jaxpr)(*err_vals, *c_consts, *carry) + error, _ = tree_unflatten(cond_out_tree, outs) - checked_body_jaxpr_, msgs_body = checkify_while_body_jaxpr( + checked_body_jaxpr_, (_, error_effects) = checkify_while_body_jaxpr( cond_jaxpr, body_jaxpr, error, enabled_errors, c_consts) - to_move = [False] * 3 + [True] * body_nconsts + [False] * len(carry) + # merged error! + error = error._add_placeholder_effects(error_effects) + checked_body_jaxpr_, (body_out_tree, _) = checkify_while_body_jaxpr( + cond_jaxpr, body_jaxpr, error, enabled_errors, c_consts) + err_vals = jtu.tree_leaves(error) + num_error_vals = len(err_vals) + to_move = [False] * num_error_vals + [True] * body_nconsts + [False] * len(carry) checked_body_jaxpr = pe.move_binders_to_front(checked_body_jaxpr_, to_move) - compat_cond_jaxpr_ = ignore_error_output_jaxpr(checked_cond_jaxpr) - to_move = [False] * 3 + [True] * cond_nconsts + [False] * len(carry) + checked_cond_jaxpr, _ = checkify_jaxpr(cond_jaxpr, error, enabled_errors) + compat_cond_jaxpr_ = ignore_error_output_jaxpr(checked_cond_jaxpr, num_error_vals) + to_move = [False] * num_error_vals + [True] * cond_nconsts + [False] * len(carry) compat_cond_jaxpr = pe.move_binders_to_front(compat_cond_jaxpr_, to_move) - new_in_flat = [*c_consts, *b_consts, cond_err, cond_code, cond_payload, *carry] + new_in_flat = [*c_consts, *b_consts, *err_vals, *carry] - err, code, payload, *out = lax.while_p.bind( + all_out_vals = lax.while_p.bind( *new_in_flat, cond_nconsts=cond_nconsts, cond_jaxpr=compat_cond_jaxpr, body_nconsts=body_nconsts, body_jaxpr=checked_body_jaxpr) - new_msgs = {**error.msgs, **msgs_body, **msgs_cond} - - return out, Error(err, code, new_msgs, payload) + # body_out_tree will have all the metadata of cond because it executes a cond! + # only need to merge metadata on the input error. + error, *out = tree_unflatten(body_out_tree, all_out_vals) + return out, error error_checks[lax.while_p] = while_loop_error_check @@ -777,13 +1054,19 @@ def pjit_error_check(error, enabled_errors, *vals_in, jaxpr, in_shardings, out_shardings, resource_env, donated_invars, name, in_positional_semantics, out_positional_semantics): - checked_jaxpr, msgs = checkify_jaxpr(jaxpr, error, enabled_errors) - new_vals_in = [error.err, error.code, error.payload, *vals_in] + checked_jaxpr, (out_tree, effects) = checkify_jaxpr(jaxpr, error, + enabled_errors) + out_error = error._add_placeholder_effects(effects) + + flat_error_vals = jtu.tree_leaves(error) + num_error_vals = len(flat_error_vals) + new_vals_in = [*flat_error_vals, *vals_in] sharding = OpShardingSharding.get_replicated( list(resource_env.physical_mesh.devices.flat)) - new_in_shardings = (*[sharding] * 3, *in_shardings) - new_out_shardings = (*[sharding] * 3, *out_shardings) + new_in_shardings = (*[sharding] * num_error_vals, *in_shardings) + new_out_shardings = (*[sharding] * len(jtu.tree_leaves(out_error)), + *out_shardings) if config.jax_array: pos_sem = maps._PositionalSemantics.GLOBAL @@ -794,11 +1077,13 @@ def pjit_error_check(error, enabled_errors, *vals_in, jaxpr, in_positional_semantics = (in_positional_semantics,) if not isinstance(out_positional_semantics, Iterable): out_positional_semantics = (out_positional_semantics,) - new_positional_sems_in = (*[pos_sem] * 3, *in_positional_semantics) - new_positional_sems_out = (*[pos_sem] * 3, *out_positional_semantics) - new_donated_invars = (*[False] * 3, *donated_invars) + new_positional_sems_in = (*[pos_sem] * num_error_vals, + *in_positional_semantics) + new_positional_sems_out = (*[pos_sem] * num_error_vals, + *out_positional_semantics) + new_donated_invars = (*[False] * num_error_vals, *donated_invars) - err, code, payload, *vals_out = pjit.pjit_p.bind( + err_and_out = pjit.pjit_p.bind( *new_vals_in, jaxpr=checked_jaxpr, in_shardings=new_in_shardings, @@ -808,29 +1093,45 @@ def pjit_error_check(error, enabled_errors, *vals_in, jaxpr, name=name, in_positional_semantics=new_positional_sems_in, out_positional_semantics=new_positional_sems_out) - return vals_out, Error(err, code, msgs, payload) + err, *out = tree_unflatten(out_tree, err_and_out) + return out, err error_checks[pjit.pjit_p] = pjit_error_check +def check_discharge_rule(error, enabled_errors, *args, err_tree): + new_error = tree_unflatten(err_tree, args) + # Split up new_error into error to be functionalized if it's included in + # enabled_errors (=discharged_error) and an error to be defunctionalized if + # it's not included (=recharged_error) + discharged_error = error + recharged_error = init_error + for error_effect in new_error._pred.keys(): + pred = new_error._pred[error_effect] + code = new_error._code[error_effect] + payload = new_error._payload[error_effect] + if error_effect.error_type in enabled_errors: + discharged_error = update_error(discharged_error, pred, code, {}, payload, + error_effect) + else: + recharged_error = update_error(recharged_error, pred, code, {}, payload, + error_effect) -def assert_discharge_rule(error, enabled_errors, err, code, payload, *, msgs): - if ErrorCategory.USER_CHECK not in enabled_errors: - return [], error - - out_err = error.err | err - out_code = lax.select(error.err, error.code, code) - return [], Error(out_err, out_code, {**error.msgs, **msgs}, payload) -error_checks[assert_p] = assert_discharge_rule + discharged_error = discharged_error._replace( + _metadata={**new_error._metadata, **discharged_error._metadata}) + recharged_error = recharged_error._replace(_metadata=new_error._metadata) + # TODO(lenamartens): we actually need to recharge, but this would be a + # breaking API change so leaving for a follow-up. + # check_error(recharged_error) + return [], discharged_error +error_checks[check_p] = check_discharge_rule ## checkify api -ErrorCategory = enum.Enum('ErrorCategory', ['NAN', 'OOB', 'DIV', 'USER_CHECK']) - -user_checks = frozenset({ErrorCategory.USER_CHECK}) -nan_checks = frozenset({ErrorCategory.NAN}) -index_checks = frozenset({ErrorCategory.OOB}) -div_checks = frozenset({ErrorCategory.DIV}) +user_checks = frozenset({FailedCheckError}) +nan_checks = frozenset({NaNError}) +index_checks = frozenset({OOBError}) +div_checks = frozenset({DivisionByZeroError}) float_checks = nan_checks | div_checks automatic_checks = float_checks | index_checks all_checks = automatic_checks | user_checks @@ -863,9 +1164,9 @@ def checkify(fun: Callable[..., Out], - ``div_checks``: a division by zero. - ``index_checks``: an index was out-of-bounds. - Multiple categories can be enabled together by creating a `Set` (eg. - ``errors={ErrorCategory.NAN, ErrorCategory.OOB}``). Multiple sets can be - re-combined (eg. ``errors=float_checks|user_checks``) + Multiple categories can be enabled together by passing in an error `Set` (eg. + ``errors=nan_checks``). Multiple sets can be re-combined (eg. + ``errors=float_checks|user_checks``) Args: fun: Callable which can contain user checks (see :func:`~check`). @@ -878,7 +1179,8 @@ def checkify(fun: Callable[..., Out], Returns: A function which accepts the same arguments as ``fun`` and returns as output a pair where the first element is an ``Error`` value, representing the first - failed :func:`~check`, and the second element is the original output of ``fun``. + failed :func:`~check`, and the second element is the original output of + ``fun``. For example: @@ -894,14 +1196,13 @@ def checkify(fun: Callable[..., Out], >>> err.throw() # doctest: +IGNORE_EXCEPTION_DETAIL Traceback (most recent call last): ... - ValueError: nan generated by primitive sin - + jax._src.checkify.JaxRuntimeError: nan generated by primitive: sin """ @traceback_util.api_boundary def checked_fun(*args, **kwargs): args_flat, in_tree = tree_flatten((args, kwargs)) f, out_tree = flatten_fun(lu.wrap_init(fun), in_tree) - (err, code, payload, out_flat), msgs = checkify_flat(f, errors, *args_flat) + error, out_flat = checkify_flat(f, errors, *args_flat) out = tree_unflatten(out_tree(), out_flat) - return Error(err, code, msgs, payload), out + return error, out return checked_fun diff --git a/jax/_src/lax/control_flow/conditionals.py b/jax/_src/lax/control_flow/conditionals.py index 826baf053..4cf1c290a 100644 --- a/jax/_src/lax/control_flow/conditionals.py +++ b/jax/_src/lax/control_flow/conditionals.py @@ -246,6 +246,10 @@ def _cond(pred, true_fun: Callable, false_fun: Callable, *operands, # Raise index in case of effects to allow data-dependence-based discharging # of those effects (even if they don't have an explicit data dependence). index = core.raise_as_much_as_possible(index) + false_jaxpr = false_jaxpr.replace( + jaxpr=false_jaxpr.jaxpr.replace(effects=joined_effects)) + true_jaxpr = true_jaxpr.replace( + jaxpr=true_jaxpr.jaxpr.replace(effects=joined_effects)) linear = [False] * len(consts) + linear_ops out = cond_p.bind( diff --git a/jax/experimental/checkify.py b/jax/experimental/checkify.py index c72e032ec..83993025a 100644 --- a/jax/experimental/checkify.py +++ b/jax/experimental/checkify.py @@ -15,6 +15,7 @@ from jax._src.checkify import ( Error as Error, ErrorCategory as ErrorCategory, + JaxRuntimeError as JaxRuntimeError, all_checks as all_checks, automatic_checks as automatic_checks, check as check, diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index 5fc9880c3..f19abb6fc 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -1219,7 +1219,7 @@ tf_not_yet_impl = [ # Not high priority? "after_all", "all_to_all", - "assert", + "check", "create_token", "custom_transpose_call", "custom_vmap_call", diff --git a/tests/checkify_test.py b/tests/checkify_test.py index 733a63661..921900b45 100644 --- a/tests/checkify_test.py +++ b/tests/checkify_test.py @@ -29,12 +29,13 @@ from jax.experimental import pjit from jax.experimental import maps from jax._src.sharding import NamedSharding from jax._src import array -from jax._src.checkify import CheckEffect +from jax._src.checkify import JaxRuntimeError, FailedCheckError, ErrorEffect, OOBError import jax.numpy as jnp config.parse_flags_with_absl() +@jtu.with_config(jax_check_tracer_leaks=True) class CheckifyTransformTests(jtu.JaxTestCase): @jtu.sample_product(jit=[False, True]) @@ -49,11 +50,11 @@ class CheckifyTransformTests(jtu.JaxTestCase): checked_f = checkify.checkify(f, errors=checkify.float_checks) err, _ = checked_f(3., 4.) - self.assertIs(err.get(), None) + self.assertIsNone(err.get()) err, _ = checked_f(3., jnp.inf) self.assertIsNotNone(err.get()) - self.assertStartsWith(err.get(), "nan generated by primitive sin") + self.assertStartsWith(err.get(), "nan generated by primitive: sin") @jtu.sample_product(jit=[False, True]) def test_jit_oob(self, jit): @@ -67,7 +68,7 @@ class CheckifyTransformTests(jtu.JaxTestCase): checked_f = checkify.checkify(f, errors=checkify.index_checks) err, _ = checked_f(jnp.arange(3), 2) - self.assertIs(err.get(), None) + self.assertIsNone(err.get()) err, _ = checked_f(jnp.arange(3), 5) self.assertIsNotNone(err.get()) @@ -83,7 +84,7 @@ class CheckifyTransformTests(jtu.JaxTestCase): checked_f = checkify.checkify(f, errors=checkify.index_checks) err, _ = checked_f(jnp.arange(3), 2) - self.assertIs(err.get(), None) + self.assertIsNone(err.get()) err, _ = checked_f(jnp.arange(3), 3) self.assertIsNotNone(err.get()) @@ -98,15 +99,14 @@ class CheckifyTransformTests(jtu.JaxTestCase): checked_f = checkify.checkify(f, errors=checkify.float_checks) err, _ = checked_f(jnp.ones((3,)), jnp.ones((3,))) - self.assertIs(err.get(), None) + self.assertIsNone(err.get()) err, _ = checked_f(jnp.ones((3,)), jnp.array([1., 0., 1.])) self.assertIsNotNone(err.get()) - self.assertStartsWith(err.get(), "division by zero") err, _ = checked_f(jnp.array([1, jnp.inf, 1]), jnp.array([1, jnp.inf, 1])) self.assertIsNotNone(err.get()) - self.assertStartsWith(err.get(), "nan generated by primitive div") + self.assertStartsWith(err.get(), "nan generated by primitive: div") @jtu.sample_product(jit=[False, True]) @jtu.skip_on_devices("tpu") @@ -121,7 +121,7 @@ class CheckifyTransformTests(jtu.JaxTestCase): # no error err, _ = checked_f(jnp.array([0., jnp.inf, 2.]), 2) - self.assertIs(err.get(), None) + self.assertIsNone(err.get()) # oob error err, _ = checked_f(jnp.array([0., 1., 2.]), 5) @@ -131,11 +131,20 @@ class CheckifyTransformTests(jtu.JaxTestCase): # nan error err, _ = checked_f(jnp.array([0., 1., jnp.inf]), 2) self.assertIsNotNone(err.get()) - self.assertStartsWith(err.get(), "nan generated by primitive cos") + self.assertStartsWith(err.get(), "nan generated by primitive: cos") - def test_numpy_indexing_oobs(self): + @parameterized.named_parameters( + ("gather", lambda x: x.get()), + ("scatter_add", lambda x: x.add(1.)), + ("scatter_mul", lambda x: x.multiply(1.)), + ("scatter_div", lambda x: x.divide(1.)), + ("scatter_pow", lambda x: x.power(1.)), + ("scatter_min", lambda x: x.min(1.)), + ("scatter_max", lambda x: x.max(1.)), + ) + def test_numpy_indexing_oobs(self, update_op): def raises_oob(fn, idx, *expected_strs): - err, _ = checkify.checkify(fn, errors=checkify.index_checks)(x, idx) + err, _ = checkify.checkify(jax.jit(fn), errors=checkify.index_checks)(x, idx) error_txt = err.get() self.assertIsNotNone(error_txt) self.assertStartsWith(error_txt, "out-of-bounds indexing") @@ -147,7 +156,7 @@ class CheckifyTransformTests(jtu.JaxTestCase): axis1_msg = "axis 1 with size 3" axis2_msg = "axis 2 with size 7" - single_idx = lambda x, i: x[i] + single_idx = lambda x, i: update_op(x.at[i]) raises_oob(single_idx, 5, "index 5", axis0_msg) raises_oob(single_idx, -5, "index -3", axis0_msg) raises_oob(single_idx, (0, 100), "index 100", axis1_msg) @@ -158,7 +167,7 @@ class CheckifyTransformTests(jtu.JaxTestCase): raises_oob(single_idx, (((1, 1), (1, 20)), 3), "index 3", axis1_msg) raises_oob(single_idx, (((1, 1), (1, 20)), 0), "index 20", axis0_msg) - multi_idx = lambda x, i: x[i[0], :, i[1]] + multi_idx = lambda x, i: update_op(x.at[i[0], :, i[1]]) raises_oob(multi_idx, (0, 9), "index 9", axis2_msg) # TODO(lenamartens): numpy reports index -5 here, need to normalize? raises_oob(multi_idx, (-5, 9), "index -3", axis0_msg) @@ -194,32 +203,77 @@ class CheckifyTransformTests(jtu.JaxTestCase): xs = jnp.array([0., 2.]) err, _ = checked_f(xs, xs) - self.assertIs(err.get(), None) + self.assertIsNone(err.get()) ys = jnp.array([3., jnp.inf]) err, _ = checked_f(xs, ys) self.assertIsNotNone(err.get()) - self.assertStartsWith(err.get(), "nan generated by primitive sin") + self.assertStartsWith(err.get(), "nan generated by primitive: sin") @jtu.skip_on_devices("tpu") def test_cond_basic(self): @jax.jit def f(x): - return lax.cond(x > 0, - lambda: jnp.sin(x), - lambda: x) + def true_fun(x): + return jnp.sin(x) + def false_fun(x): + checkify.check(x > -1, "oh no") + return x / 0. + return lax.cond(x > 0, true_fun, false_fun, x) - checked_f = checkify.checkify(f, errors=checkify.float_checks) + checked_f = checkify.checkify(f, errors=checkify.all_checks) err, _ = checked_f(3.) - self.assertIs(err.get(), None) + self.assertIsNone(err.get()) err, _ = checked_f(jnp.inf) self.assertIsNotNone(err.get()) - self.assertStartsWith(err.get(), "nan generated by primitive sin") + self.assertStartsWith(err.get(), "nan generated by primitive: sin") err, _ = checked_f(-jnp.inf) - self.assertIs(err.get(), None) + self.assertStartsWith(err.get(), "oh no") + + err, _ = checked_f(0.) + self.assertIsNotNone(err.get()) + self.assertStartsWith(err.get(), "division by zero") + + def test_cond_different_payloads(self): + @jax.jit + def f(x): + def true_fun(x): + checkify.check(~x, "{one}", one=x) + def false_fun(x): + checkify.check(x, "{one} and {two}", one=x, two=x) + return lax.cond(x, true_fun, false_fun, x) + + checked_f = checkify.checkify(f) + + err, _ = checked_f(True) + self.assertIsNotNone(err.get()) + self.assertStartsWith(err.get(), "True") + + err, _ = checked_f(False) + self.assertIsNotNone(err.get()) + self.assertStartsWith(err.get(), "False and False") + + def test_cond_nd_payloads(self): + @jax.jit + def f(x): + def true_fun(x): + checkify.check(jnp.all(x > 0), "{one}", one=x) + def false_fun(x): + checkify.check(jnp.all(x < 0), "{one} and {two}", one=x, two=x) + return lax.cond(jnp.all(x < 0), true_fun, false_fun, x) + + checked_f = checkify.checkify(f) + + err, _ = checked_f(jnp.arange(0, 4)) + self.assertIsNotNone(err.get()) + self.assertStartsWith(err.get(), "[0 1 2 3] and [0 1 2 3]") + + err, _ = checked_f(jnp.arange(-4, -1)) + self.assertIsNotNone(err.get()) + self.assertStartsWith(err.get(), "[-4 -3 -2]") @jtu.skip_on_devices("tpu") def test_scan_map(self): @@ -235,14 +289,14 @@ class CheckifyTransformTests(jtu.JaxTestCase): xs = jnp.array([0., 2.]) err, (_, ch_outs) = checked_f(xs) _, outs = f(xs) - self.assertIs(err.get(), None) + self.assertIsNone(err.get()) self.assertArraysEqual(ch_outs, outs) xs = jnp.array([3., jnp.inf]) err, (_, ch_outs) = checked_f(xs) _, outs = f(xs) self.assertIsNotNone(err.get()) - self.assertStartsWith(err.get(), "nan generated by primitive sin") + self.assertStartsWith(err.get(), "nan generated by primitive: sin") self.assertArraysEqual(ch_outs, outs) @jtu.skip_on_devices("tpu") @@ -261,7 +315,7 @@ class CheckifyTransformTests(jtu.JaxTestCase): carry, xs = 3., jnp.ones((2,)) err, (ch_out_carry, ch_outs) = checked_f(carry, xs) out_carry, outs = f(carry, xs) - self.assertIs(err.get(), None) + self.assertIsNone(err.get()) self.assertArraysEqual(ch_outs, outs) self.assertArraysEqual(ch_out_carry, out_carry) @@ -303,7 +357,7 @@ class CheckifyTransformTests(jtu.JaxTestCase): init_val = 1. err, ch_out = checked_f(init_val) out = f(init_val) - self.assertIs(err.get(), None) + self.assertIsNone(err.get()) self.assertArraysEqual(ch_out, out) init_val = 0. @@ -331,7 +385,7 @@ class CheckifyTransformTests(jtu.JaxTestCase): init_val = 1. err, ch_out = checked_f(init_val) out = f(init_val) - self.assertIs(err.get(), None) + self.assertIsNone(err.get()) self.assertArraysEqual(ch_out, out) init_val = 0. @@ -388,20 +442,20 @@ class CheckifyTransformTests(jtu.JaxTestCase): body_val = 1. err, _ = checked_f(cond_val, body_val) self.assertIsNotNone(err.get()) - self.assertStartsWith(err.get(), "nan generated by primitive sin") + self.assertStartsWith(err.get(), "nan generated by primitive: sin") cond_val = 1. body_val = jnp.inf err, _ = checked_f(cond_val, body_val) self.assertIsNotNone(err.get()) - self.assertStartsWith(err.get(), "nan generated by primitive cos") + self.assertStartsWith(err.get(), "nan generated by primitive: cos") cond_val = jnp.inf body_val = jnp.inf err, _ = checked_f(cond_val, body_val) self.assertIsNotNone(err.get()) # first error which occurs is in cond - self.assertStartsWith(err.get(), "nan generated by primitive sin") + self.assertStartsWith(err.get(), "nan generated by primitive: sin") def test_pjit(self): def f(x): @@ -448,8 +502,8 @@ class CheckifyTransformTests(jtu.JaxTestCase): @parameterized.named_parameters( ("assert", checkify.user_checks, "must be negative!"), - ("div", {checkify.ErrorCategory.DIV}, "division by zero"), - ("nan", {checkify.ErrorCategory.NAN}, "nan generated"), + ("div", checkify.div_checks, "division by zero"), + ("nan", checkify.nan_checks, "nan generated"), ("oob", checkify.index_checks, "out-of-bounds indexing"), ("automatic_checks", checkify.automatic_checks, "division by zero"), ) @@ -477,7 +531,7 @@ class CheckifyTransformTests(jtu.JaxTestCase): return f(jnp.inf) err, _ = g(2.) self.assertIsNotNone(err.get()) - self.assertStartsWith(err.get(), "nan generated by primitive sin") + self.assertStartsWith(err.get(), "nan generated by primitive: sin") @jtu.skip_on_devices("tpu") def test_post_process_map(self): @@ -485,11 +539,11 @@ class CheckifyTransformTests(jtu.JaxTestCase): def g(x): @jax.pmap def f(y): - return jnp.sin(x * y) + return jnp.sin(x * y), jnp.cos(x * y) return f(jnp.array([jnp.inf]))[0] err, _ = g(2.) self.assertIsNotNone(err.get()) - self.assertStartsWith(err.get(), 'nan generated by primitive sin') + self.assertStartsWith(err.get(), "nan generated by primitive: sin") @jtu.skip_on_devices("tpu") def test_custom_jvp(self): @@ -508,13 +562,13 @@ class CheckifyTransformTests(jtu.JaxTestCase): self.assertIsNone(err.get()) err, y = f(jnp.inf) self.assertIsNotNone(err.get()) - self.assertStartsWith(err.get(), 'nan generated by primitive sin') + self.assertStartsWith(err.get(), 'nan generated by primitive: sin') # When we hit the custom jvp rule with jvp-of-checkify, no checks are added. (err, y), (errdot, ydot) = jax.jvp(f, (3.,), (1.,)) # doesn't crash self.assertIsNone(err.get()) # no error - self.assertEmpty(err.msgs) # and no checks were added! - self.assertEmpty(errdot.msgs) + self.assertEmpty(err._metadata) # and no checks were added! + self.assertEmpty(errdot._metadata) y_expected, ydot_expected = jax.jvp(jnp.sin, (3.,), (1.,)) self.assertAllClose(y, y_expected) self.assertAllClose(ydot, ydot_expected) @@ -528,12 +582,12 @@ class CheckifyTransformTests(jtu.JaxTestCase): errors=checkify.float_checks) err, (y, ydot) = g(3., 1.) # doesn't crash self.assertIsNone(err.get()) # no error - self.assertNotEmpty(err.msgs) # but checks were added! + self.assertNotEmpty(err._metadata) # but checks were added! self.assertAllClose(y, jnp.sin(3.)) self.assertAllClose(ydot, jnp.cos(3.)) err, _ = g(jnp.inf, 1.) self.assertIsNotNone(err.get()) # yes error - self.assertStartsWith(err.get(), 'nan generated by primitive sin') + self.assertStartsWith(err.get(), 'nan generated by primitive: sin') @jtu.skip_on_devices("tpu") def test_custom_vjp(self): @@ -556,31 +610,31 @@ class CheckifyTransformTests(jtu.JaxTestCase): # no differentiation, yes error err, y = f(jnp.inf) self.assertIsNotNone(err.get()) - self.assertStartsWith(err.get(), 'nan generated by primitive sin') + self.assertStartsWith(err.get(), 'nan generated by primitive: sin') # When we hit the custom vjp rule with vjp-of-checkify, no checks are added. (err, y), f_vjp = jax.vjp(f, 3.) self.assertIsNone(err.get()) # no error - self.assertEmpty(err.msgs) # and no checks were added! + self.assertEmpty(err._metadata) # and no checks were added! # Checkify-of-vjp adds checks (unlike vjp-of-checkify above). err, y = checkify.checkify(jax.grad(sin), errors=checkify.float_checks)(3.) self.assertIsNone(err.get()) # no error - self.assertNotEmpty(err.msgs) # but checks were added! + self.assertNotEmpty(err._metadata) # but checks were added! err, y = checkify.checkify(jax.grad(sin), errors=checkify.float_checks)(jnp.inf) self.assertIsNotNone(err.get()) - self.assertStartsWith(err.get(), "nan generated by primitive sin") + self.assertStartsWith(err.get(), "nan generated by primitive: sin") def test_scan_consts(self): def f(xs): def scan_body(carry, _): # closes oves xs return carry+1, xs[carry] - return lax.scan(scan_body, 1, xs)[1] + return lax.scan(scan_body, 1, xs) checked_f = checkify.checkify(f, errors=checkify.index_checks) - err, _ = checked_f(jnp.ones((7, 3))) + err, _ = checked_f(jnp.ones((7,))) self.assertIsNotNone(err.get()) self.assertStartsWith(err.get(), "out-of-bounds indexing") @@ -649,14 +703,59 @@ class CheckifyTransformTests(jtu.JaxTestCase): errors=checkify.float_checks) cf(jax.random.PRNGKey(123)) # does not crash. + def test_pmap_one_device(self): + @jax.pmap + def f(x, y): + return x/y + cf = checkify.checkify(f, errors=checkify.automatic_checks) + errs, _ = cf(jnp.ones((1,)), jnp.zeros((1,))) + self.assertIsNotNone(errs.get()) + self.assertIn("division by zero", errs.get()) + + def test_psum_nan_check(self): + @partial(jax.vmap, axis_name="i") + def f(x, y): + return lax.psum((x, y), axis_name="i") + + cf = checkify.checkify(f, errors=checkify.nan_checks) + err, _ = cf(jnp.array([-jnp.inf, 0, jnp.inf]), jnp.ones((3, 2))) + self.assertIsNotNone(err.get()) + self.assertStartsWith(err.get(), "nan generated by primitive: psum") + + def test_different_payload_effects(self): + def f(x, y): + x = x[y] + checkify.check(jnp.all(x > 0), "{x}", x=x) + return x + + f = checkify.checkify(f, errors=checkify.all_checks) + err, _ = jax.vmap(f)(jnp.ones((2, 3))*-1, jnp.array([0, 5])) + self.assertIsNotNone(err.get()) + + def test_effects_total_ordering(self): + sds0 = jax.ShapeDtypeStruct((2,), jnp.float32) + sds1 = jax.ShapeDtypeStruct((2,), jnp.int32) + sds2 = jax.ShapeDtypeStruct((3,), jnp.int32) + self.assertTotallyOrdered( + [ErrorEffect(FailedCheckError, (sds0,))], + [ErrorEffect(FailedCheckError, (sds0, sds0))], + [ErrorEffect(FailedCheckError, (sds1,))], + [ErrorEffect(FailedCheckError, (sds1, sds0))], + [ErrorEffect(FailedCheckError, (sds2,))], + [ErrorEffect(OOBError, (sds0,))], + [ErrorEffect(OOBError, (sds0, sds0))], + ) + + +@jtu.with_config(jax_check_tracer_leaks=True) class AssertPrimitiveTests(jtu.JaxTestCase): def test_assert_primitive_impl(self): def f(): checkify.check(False, "hi") - with self.assertRaisesRegex(ValueError, "hi"): + with self.assertRaisesRegex(JaxRuntimeError, "hi"): f() def test_assert_primitive_lowering(self): @@ -668,10 +767,24 @@ class AssertPrimitiveTests(jtu.JaxTestCase): f() def test_assert_primitive_jaxpr_effects(self): - def f(): - checkify.check(False, "hi") + def f(x): + checkify.check(False, "hi: {}", x) - self.assertSetEqual(jax.make_jaxpr(f)().effects, {CheckEffect}) + jaxpr = jax.make_jaxpr(f)(jnp.ones(4, jnp.int32)) + self.assertSetEqual(jaxpr.effects, + {ErrorEffect(FailedCheckError, ( + jax.ShapeDtypeStruct((0,), jnp.int32), + jax.ShapeDtypeStruct((4,), jnp.int32),))}) + def g(x, y): + checkify.check(False, "hi: {} {}", x, y) + + self.assertSetEqual( + jax.make_jaxpr(g)( + jnp.ones(4, jnp.int32), jnp.ones(2, jnp.float32)).effects, + {ErrorEffect(FailedCheckError, ( + jax.ShapeDtypeStruct((0,), jnp.int32), + jax.ShapeDtypeStruct((4,), jnp.int32), + jax.ShapeDtypeStruct((2,), jnp.float32)))}) def test_assert_primitive_eval_shape(self): # The check is abstractly evaluated but not lowered. @@ -798,10 +911,10 @@ class AssertPrimitiveTests(jtu.JaxTestCase): mult_batch_fail = jnp.array([[0.5, 0.5], [1, 1], [2, 2]]) f(no_failures) - with self.assertRaisesRegex(ValueError, "x must sum to one."): + with self.assertRaisesRegex(JaxRuntimeError, "x must sum to one."): f(one_batch_fails) - with self.assertRaisesRegex(ValueError, "x must sum to one."): + with self.assertRaisesRegex(JaxRuntimeError, "x must sum to one."): f(mult_batch_fail) checked_f = checkify.checkify(f) @@ -817,10 +930,13 @@ class AssertPrimitiveTests(jtu.JaxTestCase): self.assertStartsWith(err.get(), "x must sum to one") def test_check_error(self): + def g(): + checkify.check(False, "hi") def f(): - checkify.check_error(checkify.Error(True, 0, {0: "hi"})) + err, _ = checkify.checkify(g)() + checkify.check_error(err) - with self.assertRaisesRegex(ValueError, "hi"): + with self.assertRaisesRegex(JaxRuntimeError, "hi"): f() f = checkify.checkify(f) @@ -872,7 +988,7 @@ class AssertPrimitiveTests(jtu.JaxTestCase): python_should_be_running = False f(True) - with self.assertRaisesRegex(ValueError, "foo"): + with self.assertRaisesRegex(JaxRuntimeError, "foo"): f(False) def test_cond_of_named_call(self): @@ -950,10 +1066,12 @@ class AssertPrimitiveTests(jtu.JaxTestCase): # self.assertIsNone(err.get()) def test_assert_cond_no_data_dependence(self): + def true_fun(): + return checkify.check(False, "hi!") + def false_fun(): + return checkify.check(False, "bye!") def f(): - return jax.lax.cond(True, - lambda: checkify.check(False, "hi!"), - lambda: checkify.check(False, "bye!")) + return jax.lax.cond(True, true_fun, false_fun) f = checkify.checkify(f) err, _ = f() @@ -973,16 +1091,6 @@ class AssertPrimitiveTests(jtu.JaxTestCase): self.assertIsNotNone(err.get()) self.assertStartsWith(err.get(), "hi!") - def test_psum_nan_check(self): - @partial(jax.vmap, axis_name="i") - def f(x, y): - return lax.psum((x, y), axis_name="i") - - cf = checkify.checkify(f, errors=checkify.nan_checks) - err, _ = cf(jnp.array([-jnp.inf, 0, jnp.inf]), jnp.ones((3, 2))) - self.assertIsNotNone(err.get()) - self.assertStartsWith(err.get(), "nan generated by primitive psum") - class LowerableChecksTest(jtu.JaxTestCase): def setUp(self):