diff --git a/jax/_src/checkify.py b/jax/_src/checkify.py
index 53b337952..e916dd7c1 100644
--- a/jax/_src/checkify.py
+++ b/jax/_src/checkify.py
@@ -16,12 +16,13 @@ import dataclasses
 import functools
 import itertools as it
 import types
-from typing import Union, Optional, Callable, Dict, Tuple, TypeVar, FrozenSet, Iterable, Type, Set, List
+from typing import Union, Optional, Callable, Dict, Tuple, TypeVar, FrozenSet, Iterable, Type, Set, List, Sequence, Any
 
 import jax
 from jax import lax
 from jax._src import linear_util as lu
 from jax._src import core
+from jax._src import custom_derivatives
 from jax._src import prng
 from jax._src import source_info_util
 from jax._src import traceback_util
@@ -29,10 +30,8 @@ from jax._src.config import config
 from jax._src.lax import control_flow as cf
 from jax._src.sharding import OpShardingSharding
 from jax._src.typing import Array
-from jax._src.util import (as_hashable_function, unzip2, split_list, safe_map,
-                           safe_zip)
+from jax._src.util import (as_hashable_function, unzip2, split_list, safe_map, safe_zip)
 from jax.api_util import flatten_fun
-from jax.api_util import flatten_fun_nokwargs
 from jax.experimental import maps
 from jax.experimental import pjit
 from jax.interpreters import ad
@@ -57,6 +56,7 @@ Int = Union[int, Array]
 ErrorCategory = Type['JaxException']
 Payload = List[Union[np.ndarray, Array]]
 PyTreeDef = jtu.PyTreeDef
+Out = TypeVar('Out')
 
 ## Utils
 
@@ -173,12 +173,12 @@ class FailedCheckError(JaxException):
     self.kwargs = k
 
   def tree_flatten(self):
-    return ((jnp.array([], jnp.int32), self.args, self.kwargs),
-            (self.traceback_info, self.fmt_string))
+    return ((self.args, self.kwargs),  # leaves
+            (self.traceback_info, self.fmt_string))  # treedef
 
   @classmethod
   def tree_unflatten(cls, metadata, payload):
-    _, args, kwargs = payload
+    args, kwargs = payload
     return cls(*metadata, *args, **kwargs)
 
   def __str__(self):
@@ -189,9 +189,7 @@ class FailedCheckError(JaxException):
     vals = jtu.tree_leaves((self.args, self.kwargs))
     return ErrorEffect(
         FailedCheckError,
-        # Need a 0-size array here for data-dependence.
-        (jax.ShapeDtypeStruct((0,), jnp.int32),
-         *tuple(jax.ShapeDtypeStruct(x.shape, x.dtype) for x in vals)))
+        tuple(jax.ShapeDtypeStruct(x.shape, x.dtype) for x in vals))
 
 @dataclasses.dataclass
 class BatchedError(JaxException):
@@ -326,169 +324,101 @@ def update_error(error, pred, code, metadata, payload, effect_type):
 
 ## Checkify transformation for plumbing functional error values.
 
-class CheckifyTracer(core.Tracer):
-  def __init__(self, trace, val):
-    self._trace = trace
-    self.val = val
-  aval = property(lambda self: core.get_aval(self.val))
-  full_lower = lambda self: self
+@lu.transformation_with_aux
+def _flatten_and_get_error_metadata_thunk(*invals):
+  error, out = yield invals, {}
+  out_vals, out_tree = jtu.tree_flatten((error, out))
+  yield out_vals, (out_tree, set(error._pred.keys()))
 
-class CheckifyTrace(core.Trace):
-  pure = lift = lambda self, val: CheckifyTracer(self, val)
+def default_checkify_rule(primitive: core.Primitive, error: Error,
+                          enabled_errors, *invals: core.Value,
+                          **params: Any) -> Tuple[Error, Sequence[core.Value]]:
+  """Default rule for primitives in `checkify` interpreter."""
+  if 'call_jaxpr' not in params:
+    # Default non-HOP case: just call primitive and don't update error.
+ return error, primitive.bind(*invals, **params) - def __init__(self, main: core.MainTrace, sublevel: core.Sublevel, - enabled_errors: FrozenSet['ErrorCategory']) -> None: - self.main = main - self.level = main.level - self.sublevel = sublevel - self.main.enabled_errors = enabled_errors + # Code below handles call- and map-primitives, by recursively calling + # checkify_jaxpr. + err_vals, err_tree = jtu.tree_flatten(error) + num_error_vals = len(err_vals) + if 'donated_invars' in params: + params = dict(params, donated_invars=(*[False]*num_error_vals, + *params['donated_invars'])) - def sublift(self, tracer): - return CheckifyTracer(self, tracer.val) + # call_jaxpr handling + call_jaxpr = params.pop('call_jaxpr') + partial_checkify = lu.wrap_init( + functools.partial(checkify_jaxpr_flat, call_jaxpr, (), enabled_errors, + err_tree)) + partial_checkify, metadata = _flatten_and_get_error_metadata_thunk( + partial_checkify) - def process_primitive(self, primitive, tracers, params): - in_vals = [t.val for t in tracers] - rule = error_checks.get(primitive) - if rule: - out, self.main.error = rule(self.main.error, self.main.enabled_errors, # type: ignore - *in_vals, **params) - else: - out = primitive.bind(*in_vals, **params) - if primitive.multiple_results: - return [CheckifyTracer(self, x) for x in out] - else: - return CheckifyTracer(self, out) + # map-specific params handling. + if isinstance(primitive, core.MapPrimitive): + # Update `in_axes` and `out_axes_thunk` params for map primitive. + out_val_axes = params.pop('out_axes') - def process_call(self, primitive, f, tracers, params): - in_vals = [t.val for t in tracers] - e = popattr(self.main, 'error') - flat_vals, in_tree = tree_flatten((e, *in_vals)) - f = checkify_subtrace(f, self.main) - f, out_tree = flatten_fun_nokwargs(f, in_tree) - if 'donated_invars' in params: - params = dict(params, donated_invars=(*[False]*len(jtu.tree_leaves(e)), - *params['donated_invars'])) - all_vals = primitive.bind(f, *flat_vals, **params) - error, *out_vals = tree_unflatten(out_tree(), all_vals) - setnewattr(self.main, 'error', error) - return [CheckifyTracer(self, x) for x in out_vals] - - def process_map(self, primitive, f, tracers, params): - in_vals = [t.val for t in tracers] - e = popattr(self.main, 'error') - flat_vals, in_tree = tree_flatten((e, *in_vals)) - num_error_vals = len(jtu.tree_leaves(e)) - f = checkify_subtrace(f, self.main) - f, out_tree = flatten_fun_nokwargs(f, in_tree) - - @as_hashable_function(closure=params['out_axes_thunk']) - def new_out_axes_thunk(): - out_val_axes = params['out_axes_thunk']() - out_err_num = out_tree().num_leaves - len(out_val_axes) + @as_hashable_function(closure=out_val_axes) + def out_axes_thunk(): + out_err_num = metadata()[0].num_leaves - len(out_val_axes) return (*(0,)*out_err_num, *out_val_axes) - params_ = dict(params, in_axes=(*(None,)*num_error_vals, *params['in_axes']), - out_axes_thunk=new_out_axes_thunk, - donated_invars=(*(False,)*num_error_vals, *params['donated_invars'])) - all_vals = primitive.bind(f, *flat_vals, **params_) - error, *out_vals = tree_unflatten(out_tree(), all_vals) - error = _reduce_any_error(error) - setnewattr(self.main, 'error', error) - return [CheckifyTracer(self, x) for x in out_vals] + params = dict(params, + in_axes=(*(None,)*num_error_vals, *params['in_axes']), + out_axes_thunk=out_axes_thunk) - def post_process_call(self, primitive, tracers, params): - vals = [t.val for t in tracers] - main = self.main - e = popattr(main, 'error') - err_leaves, err_tree = 
tree_flatten(e) - setnewattr(main, 'err_tree', err_tree) - def todo(vals): - err_tree = popattr(main, 'err_tree') - err_vals, vals = split_list(vals, [err_tree.num_leaves]) - setnewattr(main, 'error', tree_unflatten(err_tree, err_vals)) - trace = main.with_cur_sublevel() - return [CheckifyTracer(trace, x) for x in vals] - return (*err_leaves, *vals), todo + all_vals = primitive.bind(partial_checkify, *err_vals, *invals, **params) - def post_process_map(self, primitive, tracers, params): - vals = [t.val for t in tracers] - main = self.main - e = popattr(main, 'error') - err_leaves, err_tree = tree_flatten(e) - num_err_leaves = len(err_leaves) - setnewattr(main, 'err_tree', err_tree) - def todo(vals): - err_tree = popattr(main, 'err_tree') - err_vals, vals = split_list(vals, [err_tree.num_leaves]) - error = tree_unflatten(err_tree, err_vals) - error = _reduce_any_error(error) - setnewattr(main, 'error', error) - trace = main.with_cur_sublevel() - return [CheckifyTracer(trace, x) for x in vals] - def out_axes_transform(out_axes): - return (*(0,)*num_err_leaves, *out_axes) - return (*err_leaves, *vals), (todo, out_axes_transform) + out_tree, _ = metadata() + error, out_vals = tree_unflatten(out_tree, all_vals) + return error, out_vals - def process_custom_jvp_call(self, prim, f, jvp, tracers): - in_vals = [t.val for t in tracers] - e = popattr(self.main, 'error') - err_vals, err_tree = tree_flatten(e) - flat_vals, in_tree = tree_flatten((e, *in_vals)) - num_error_vals = len(err_vals) - f = checkify_subtrace(f, self.main) - f, f_out_tree = flatten_fun_nokwargs(f, in_tree) - jvp, jvp_err_tree = checkify_custom_jvp_subtrace(jvp, self.main, - num_error_vals, err_tree) - all_outs = prim.bind(f, jvp, *flat_vals) - fst, out_tree = lu.merge_linear_aux(f_out_tree, jvp_err_tree) - if fst: - out_err, *out_vals = tree_unflatten(out_tree, all_outs) +def get_shaped_aval(val): + return core.raise_to_shaped(core.get_aval(val)) + +def checkify_jaxpr(jaxpr: core.ClosedJaxpr, enabled_errors, + error: Error, *args) -> Tuple[Error, List[core.Value]]: + err_vals, err_tree = jtu.tree_flatten(error) + return checkify_jaxpr_flat(jaxpr.jaxpr, jaxpr.consts, + enabled_errors, err_tree, *err_vals, *args) + +def checkify_jaxpr_flat(jaxpr: core.Jaxpr, consts: Sequence[core.Value], + enabled_errors, err_tree: PyTreeDef, + *args: core.Value) -> Tuple[Error, List[Any]]: + env: Dict[core.Var, Any] = {} + err_vals, in_args = split_list(args, [err_tree.num_leaves]) + error = jtu.tree_unflatten(err_tree, err_vals) + + def read_env(var: core.Atom): + if isinstance(var, core.Literal): + return var.val + return env[var] + + def write_env(var: core.Var, val: Any): + env[var] = val + + map(write_env, jaxpr.constvars, consts) + map(write_env, jaxpr.invars, in_args) + + # interpreter loop + for eqn in jaxpr.eqns: + invals = map(read_env, eqn.invars) + checkify_rule = error_checks.get( + eqn.primitive, functools.partial(default_checkify_rule, eqn.primitive)) + error, outvals = checkify_rule(error, enabled_errors, *invals, **eqn.params) + if eqn.primitive.multiple_results: + map(write_env, eqn.outvars, outvals) else: - err_vals, out_vals = split_list(all_outs, [num_error_vals]) - # forward input error values to output - out_err = tree_unflatten(out_tree, err_vals) - setattr(self.main, 'error', out_err) - return [CheckifyTracer(self, x) for x in out_vals] + write_env(eqn.outvars[0], outvals) - def post_process_custom_jvp_call(self, tracers, jvp_was_run): - if jvp_was_run: - msg = ('support for custom_jvp rules which close over checkify 
values is ' - 'not implemented. If you see this, open an issue at ' - 'https://github.com/google/jax/issues!') - raise NotImplementedError(msg) - vals = [t.val for t in tracers] - main = self.main - e = popattr(main, 'error') - err_leaves, err_tree = tree_flatten(e) - def todo(vals): - err_vals, vals = split_list(vals, [len(err_leaves)]) - setnewattr(main, 'error', tree_unflatten(err_tree, err_vals)) - trace = main.with_cur_sublevel() - return [CheckifyTracer(trace, x) for x in vals] - return (*err_leaves, *vals), todo + return error, map(read_env, jaxpr.outvars) - def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees): - in_vals = [t.val for t in tracers] - e = popattr(self.main, 'error') - err_vals, err_tree = tree_flatten(e) - flat_vals, in_tree = tree_flatten((e, *in_vals)) - num_error_vals = len(err_vals) +@lu.transformation_with_aux +def flatten_fun_output(*args): + ans = yield args, {} + yield tree_flatten(ans) - fun = checkify_subtrace(fun, self.main) - fun, fun_out_tree = flatten_fun_nokwargs(fun, in_tree) - fwd, fwd_err_tree = checkify_custom_vjp_subtrace(fwd, self.main, - err_tree, num_error_vals) - - all_out_vals = prim.bind(fun, fwd, bwd, *flat_vals, out_trees=out_trees) - fst, out_tree = lu.merge_linear_aux(fun_out_tree, fwd_err_tree) - if fst: - error, *out = tree_unflatten(out_tree, all_out_vals) - else: - _, out = split_list(all_out_vals, [num_error_vals]) - # forward input error values to output - error = tree_unflatten(err_tree, err_vals) - setattr(self.main, 'error', error) - return [CheckifyTracer(self, x) for x in out] def _reduce_any_error(error: Error): out_error = init_error @@ -504,43 +434,468 @@ def _reduce_any_error(error: Error): out_error = out_error._replace(_metadata=error._metadata) return out_error +## check_p primitive + +check_p = core.Primitive('check') +check_p.multiple_results = True # zero results + +# TODO(lenamartens): inherit from Exception instead of ValueError. +class JaxRuntimeError(ValueError): + pass + +@check_p.def_impl +def check_impl(*args, err_tree, debug): + if debug: + # NOOP (check will only trigger when discharged) + return [] + error = tree_unflatten(err_tree, args) + exc = error.get_exception() + if exc: + raise JaxRuntimeError(str(exc)) from exc + return [] + +@check_p.def_effectful_abstract_eval +def check_abstract_eval(*args, err_tree, debug): + del debug + return [], set(tree_unflatten(err_tree, args)._pred.keys()) + +# TODO(lenamartens) add in-depth error explanation to link to in module docs. +functionalization_error = ValueError( + 'Cannot abstractly evaluate a checkify.check which was not' + ' functionalized. This probably means you tried to stage' + ' (jit/scan/pmap/...) a `check` without functionalizing it' + ' through `checkify.checkify`.' 
+ ) + +def check_lowering_rule(ctx, *args, err_tree, debug): + if debug: + # NOOP (check will only trigger when discharged) + return [] + if not config.jax_experimental_unsafe_xla_runtime_errors: + raise functionalization_error + + out_op, _, keep_alive = mlir.emit_python_callback( + ctx, callback=functools.partial(python_err, err_tree), + token=None, + operands=args, + operand_avals=list(ctx.avals_in), + result_avals=list(ctx.avals_out), + has_side_effect=True) + ctx.module_context.add_keepalive(keep_alive) + return out_op + +def check_lowering_rule_unsupported(*a, debug, **k): + if debug: + return [] + raise functionalization_error + +def python_err(err_tree, *args): + error = tree_unflatten(err_tree, args) + _check_error(error) + return [] + +mlir.register_lowering(check_p, check_lowering_rule_unsupported, + platform='tpu') +mlir.register_lowering(check_p, check_lowering_rule, + platform='cpu') +mlir.register_lowering(check_p, check_lowering_rule, + platform='gpu') + +def check_batching_rule(batched_args, batch_dims, *, err_tree, debug): + size = next(x.shape[dim] for x, dim in zip(batched_args, batch_dims) + if dim is not batching.not_mapped) + batched_args = (batching.bdim_at_front(a, d, size) + for a, d in zip(batched_args, batch_dims)) + err = tree_unflatten(err_tree, batched_args) + _check_error(err, debug=debug) + return [], [] +batching.primitive_batchers[check_p] = check_batching_rule + +def check_jvp_rule(primals, _, *, err_tree, debug): + # Check primals, discard tangents. + check_p.bind(*primals, err_tree=err_tree, debug=debug) + return [], [] +ad.primitive_jvps[check_p] = check_jvp_rule + +## checkify rules + ErrorCheckRule = Callable # (Error, FrozenSet[ErrorCategory], *in_vals, **params) -> (Any, Error) error_checks: Dict[core.Primitive, ErrorCheckRule] = {} -def checkify_flat(fun: lu.WrappedFun, enabled_errors: FrozenSet['ErrorCategory'], - *args): - fun = checkify_subtrace(fun) - fun = checkify_traceable(fun, enabled_errors) - error, *outvals = fun.call_wrapped(init_error, *args) - return error, outvals -@lu.transformation -def checkify_traceable(enabled_errors, error, *args): - with core.new_main(CheckifyTrace, enabled_errors=enabled_errors) as main: - outs = yield (main, error, *args), {} - del main - yield outs +def _get_current_traceback(skip_frames = 0) -> Optional[types.TracebackType]: + # TODO(lenamartens): use c++ version from XLA? + tb = None + import inspect + for frame_info in inspect.stack(): + frame = frame_info.frame + if skip_frames: + skip_frames -= 1 + elif not traceback_util.include_frame(frame): + continue + else: + tb = types.TracebackType(tb, frame, frame.f_lasti, frame.f_lineno) + return tb -@lu.transformation -def checkify_subtrace(main, error, *args): - setnewattr(main, 'error', error) - trace = main.with_cur_sublevel() - in_tracers = [CheckifyTracer(trace, x) for x in args] - out = yield in_tracers, {} - out_tracers = map(trace.full_raise, out) - out_vals = [t.val for t in out_tracers] - error = main.error - del main.error - yield (error, *out_vals) +def summary() -> str: + return str(source_info_util.summarize(source_info_util.current())) -@lu.transformation_with_aux -def checkify_custom_jvp_subtrace(main, num_error_vals, out_tree, *args): - # Like checkify_subtrace, but used specifically on the custom JVP rules - # associated with a custom_jvp. This code is called in the context of a - # jvp-of-checkify-of-custom_jvp. 
It takes both primal and tangent inputs, - # flattened into a single args tuple, and similarly must produce flattened - # primal and tangent outputs. Both primals and tangents include error values, - # but the tangent error values are trivially zero. +def nan_error_check(prim, error, enabled_errors, *in_vals, **params): + out = prim.bind(*in_vals, **params) + err = check_nans(prim, error, enabled_errors, out) + return err, out + +def check_nans(prim, error, enabled_errors, out): + if NaNError not in enabled_errors: + return error + + def isnan(x): + if isinstance(x, prng.PRNGKeyArray): + return False + return jnp.any(jnp.isnan(x)) + + any_nans = (jnp.any(jnp.array([isnan(x) for x in out])) + if prim.multiple_results else isnan(out)) + return assert_func(error, any_nans, NaNError(summary(), prim.name)) + + +# All primitives which can generate a NaN. +nan_primitives = [lax.acos_p, lax.acosh_p, lax.add_p, lax.asin_p, lax.asinh_p, + lax.atan2_p, lax.atan_p, lax.atanh_p, lax.bessel_i0e_p, + lax.bessel_i1e_p, lax.cbrt_p, lax.conv_general_dilated_p, + lax.cos_p, lax.cosh_p, lax.cumlogsumexp_p, lax.cummax_p, + lax.cummin_p, lax.cumprod_p, lax.cumsum_p, lax.digamma_p, + lax.dot_general_p, lax.erf_inv_p, lax.erf_p, lax.erfc_p, + lax.exp_p, lax.expm1_p, lax.fft_p, lax.igamma_grad_a_p, + lax.igamma_p, lax.igammac_p, lax.integer_pow_p, lax.lgamma_p, + lax.linear_solve_p, lax.log1p_p, lax.log_p, lax.logistic_p, + lax.mul_p, lax.pad_p, lax.pow_p, lax.psum_p, + lax.random_gamma_grad_p, lax.reduce_p, lax.reduce_prod_p, + lax.reduce_sum_p, lax.reduce_window_p, + lax.reduce_window_sum_p, lax.regularized_incomplete_beta_p, + lax.rem_p, lax.rng_uniform_p, lax.rsqrt_p, lax.sin_p, + lax.sinh_p, lax.sqrt_p, lax.sub_p, lax.tan_p, lax.tanh_p] + +for _prim in nan_primitives: + error_checks[_prim] = functools.partial(nan_error_check, _prim) + + +def gather_error_check(error, enabled_errors, operand, start_indices, *, + dimension_numbers, slice_sizes, unique_indices, + indices_are_sorted, mode, fill_value): + out = lax.gather_p.bind( + operand, start_indices, dimension_numbers=dimension_numbers, + slice_sizes=slice_sizes, unique_indices=unique_indices, + indices_are_sorted=indices_are_sorted, mode=mode, fill_value=fill_value) + + if OOBError not in enabled_errors: + return error, out + + # compare to OOB masking logic in lax._gather_translation_rule + dnums = dimension_numbers + operand_dims = np.array(operand.shape) + num_batch_dims = len(start_indices.shape) - 1 + + upper_bound = operand_dims[np.array(dnums.start_index_map)] + upper_bound -= np.array(slice_sizes)[np.array(dnums.start_index_map)] + upper_bound = jnp.expand_dims(upper_bound, axis=tuple(range(num_batch_dims))) + oob_mask = (start_indices < 0) | (start_indices > upper_bound.astype(start_indices.dtype)) + + payload = oob_payload(oob_mask, start_indices, dnums.start_index_map, operand.shape) + error = assert_func(error, jnp.any(oob_mask), OOBError(summary(), "gather", operand.shape, payload)) + return error, out +error_checks[lax.gather_p] = gather_error_check + +def div_error_check(error, enabled_errors, x, y): + """Checks for division by zero and NaN.""" + if DivisionByZeroError in enabled_errors: + any_zero = jnp.any(jnp.equal(y, 0)) + error = assert_func(error, any_zero, DivisionByZeroError(summary())) + return nan_error_check(lax.div_p, error, enabled_errors, x, y) +error_checks[lax.div_p] = div_error_check + +def oob_payload(oob_mask, indices, dims_map, operand_shape): + # Get first OOB index, axis and axis size so it can be added to the error msg. 
+ flat_idx = jnp.argmin(jnp.logical_not(oob_mask)) + multi_idx = jnp.unravel_index(flat_idx, indices.shape) + oob_axis = jnp.array(dims_map)[multi_idx[-1]] + oob_axis_size = jnp.array(operand_shape)[oob_axis] + oob_index = jnp.ravel(indices)[flat_idx] + payload = jnp.array([oob_index, oob_axis, oob_axis_size], dtype=jnp.int32) + return payload + +def scatter_oob(operand, indices, updates, dnums): + # Ref: see clamping code used in scatter_translation_rule + slice_sizes = [] + pos = 0 + for i in range(len(operand.shape)): + if i in dnums.inserted_window_dims: + slice_sizes.append(1) + else: + slice_sizes.append(updates.shape[dnums.update_window_dims[pos]]) + pos += 1 + + upper_bound = np.array([operand.shape[i] - slice_sizes[i] + for i in dnums.scatter_dims_to_operand_dims], + np.int64) + upper_bound = np.minimum(upper_bound, np.iinfo(indices.dtype).max) + upper_bound = lax.broadcast_in_dim(upper_bound, indices.shape, + (len(indices.shape) - 1,)) + + lower_oob = jnp.less(indices, 0) + upper_oob = jnp.greater(indices, upper_bound.astype(indices.dtype)) + oob_mask = jnp.logical_or(lower_oob, upper_oob) + payload = oob_payload(oob_mask, indices, + dnums.scatter_dims_to_operand_dims, operand.shape) + return jnp.any(oob_mask), payload + +def scatter_error_check(prim, error, enabled_errors, operand, indices, updates, + *, update_jaxpr, update_consts, dimension_numbers, + indices_are_sorted, unique_indices, mode): + """Checks if indices are within bounds and update does not generate NaN.""" + out = prim.bind( + operand, indices, updates, update_jaxpr=update_jaxpr, + update_consts=update_consts, dimension_numbers=dimension_numbers, + indices_are_sorted=indices_are_sorted, unique_indices=unique_indices, + mode=mode) + + if OOBError not in enabled_errors: + return error, out + + out_of_bounds, payload = scatter_oob(operand, indices, updates, dimension_numbers) + oob_error = OOBError(summary(), prim.name, operand.shape, payload) + error = assert_func(error, out_of_bounds, oob_error) + error = check_nans(prim, error, enabled_errors, out) + return error, out +error_checks[lax.scatter_p] = functools.partial(scatter_error_check, lax.scatter_p) +error_checks[lax.scatter_add_p] = functools.partial(scatter_error_check, + lax.scatter_add_p) +error_checks[lax.scatter_mul_p] = functools.partial(scatter_error_check, + lax.scatter_mul_p) +error_checks[lax.scatter_min_p] = functools.partial(scatter_error_check, + lax.scatter_min_p) +error_checks[lax.scatter_max_p] = functools.partial(scatter_error_check, + lax.scatter_max_p) + +# HOP error check rules + +def get_error_effects_from_jaxpr(closed_jaxpr: core.ClosedJaxpr, + enabled_errors, + error, + *args) -> Set[ErrorEffect]: + """Probes a jaxpr for its error effects.""" + err_vals, err_tree = jtu.tree_flatten(error) + checkify_fun = lu.wrap_init( + functools.partial(checkify_jaxpr_flat, closed_jaxpr.jaxpr, + closed_jaxpr.literals, enabled_errors, err_tree)) + checkify_fun, metadata = _flatten_and_get_error_metadata_thunk(checkify_fun) + in_avals = map(get_shaped_aval, [*err_vals, *args]) + pe.trace_to_jaxpr_final(checkify_fun, in_avals) + _, error_effects = metadata() + return error_effects + +def jaxpr_to_checkify_jaxpr( + jaxpr: core.ClosedJaxpr, enabled_errors, err_tree: PyTreeDef, + *flat_err_and_in_vals) -> Tuple[core.ClosedJaxpr, PyTreeDef, Set[ErrorEffect]]: + checkify_jaxpr_partial = functools.partial(checkify_jaxpr_flat, jaxpr.jaxpr, + jaxpr.consts, enabled_errors, + err_tree) + fun = lu.wrap_init(checkify_jaxpr_partial) + fun, metadata = 
_flatten_and_get_error_metadata_thunk(fun) + + new_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(fun, flat_err_and_in_vals) + checked_jaxpr = core.ClosedJaxpr(new_jaxpr, consts) + out_tree, error_effects = metadata() + return checked_jaxpr, out_tree, error_effects + +def cond_error_check(error: Error, enabled_errors, index, *ops, branches, linear): + # Get the error-effects out of all branches so the cond can be called with + # a merged error with all these effects. + effects = [get_error_effects_from_jaxpr(jxpr, enabled_errors, error, *ops) + for jxpr in branches] + merged_error = error._add_placeholder_effects(set().union(*effects)) + err_vals, err_tree = jtu.tree_flatten(merged_error) + new_linear = (*[False] * len(err_vals), *linear) + + # Update branch jaxprs to be checkified jaxprs. + checked_branch_partials = tuple( + functools.partial(checkify_jaxpr_flat, closed_jaxpr.jaxpr, + closed_jaxpr.consts, enabled_errors, err_tree) + for closed_jaxpr in branches) + checked_branch_funs_ = map(lu.wrap_init, checked_branch_partials) + checked_branch_funs, out_trees_and_effects = unzip2( + map(_flatten_and_get_error_metadata_thunk, checked_branch_funs_)) + in_vals = jtu.tree_leaves((merged_error, ops)) + in_avals = tuple(map(get_shaped_aval, in_vals)) + def to_jaxpr(fun, in_avals): + jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(fun, in_avals) + return core.ClosedJaxpr(jaxpr, consts) + new_branches = map( + lambda fun: to_jaxpr(fun, in_avals), + checked_branch_funs) + + + err_and_outs = lax.cond_p.bind( + index, *err_vals, *ops, + branches=tuple(new_branches), linear=new_linear) + + # we need to merge metadata across out_trees (a tuple) + out_trees, _ = unzip2(map(lambda fun: fun(), out_trees_and_effects)) + err0, out = tree_unflatten(out_trees[0], err_and_outs) + merged_metadata = err0._metadata + for tr in out_trees[1:]: + err, _ = tree_unflatten(tr, err_and_outs) + merged_metadata = {**merged_metadata, **err._metadata} + return err0._replace(_metadata=merged_metadata), out +error_checks[lax.cond_p] = cond_error_check + +def scan_error_check(error, enabled_errors, *in_flat, reverse, length, jaxpr, + num_consts, num_carry, linear, unroll): + + consts, carry, xs = split_list(in_flat, [num_consts, num_carry]) + # Query body effects to create a merged error containing all effects (such + # that in and out carried error are of the same type). + effects = get_error_effects_from_jaxpr(jaxpr, enabled_errors, error, *in_flat) + merged_error = error._add_placeholder_effects(effects) + err_vals, err_tree = jtu.tree_flatten(merged_error) + + # Create checked-jaxpr, with the needed pre-processing on the inputs. 
+ xs_mapped = [core.mapped_aval(length, 0, get_shaped_aval(val)) for val in xs] + new_in_aval = map(get_shaped_aval, [*err_vals, *consts, *carry]) + xs_mapped + checked_jaxpr_, out_tree, _ = jaxpr_to_checkify_jaxpr(jaxpr, enabled_errors, + err_tree, *new_in_aval) + + new_in_flat = [*consts, *err_vals, *carry, *xs] + new_linear = (*[False] * len(err_vals), *linear) + tomove = ([False] * len(err_vals) + [True] * len(consts) + + [False] * (len(carry) + len(xs))) + checked_jaxpr = pe.move_binders_to_front(checked_jaxpr_, tomove) + new_in_flat = [*consts, *err_vals, *carry, *xs] + err_and_out = lax.scan_p.bind( + *new_in_flat, reverse=reverse, length=length, jaxpr=checked_jaxpr, + num_consts=len(consts), num_carry=len(carry)+len(err_vals), + linear=new_linear, unroll=unroll) + err, out = tree_unflatten(out_tree, err_and_out) + return err, out + +error_checks[lax.scan_p] = scan_error_check + +def checkify_while_body_jaxpr( + cond_jaxpr: core.ClosedJaxpr, body_jaxpr: core.ClosedJaxpr, + enabled_errors, error: Error, + c_consts) -> Tuple[core.ClosedJaxpr, PyTreeDef, Set[ErrorEffect]]: + cond_f = core.jaxpr_as_fun(cond_jaxpr) + body_f = core.jaxpr_as_fun(body_jaxpr) + def new_body_f(*vals): + out = body_f(*vals) + # This checks if the next cond application will error + _ = cond_f(*c_consts, *out) + return out + new_body_f_ = lu.wrap_init(new_body_f) + jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(new_body_f_, body_jaxpr.in_avals) + closed_jaxpr = core.ClosedJaxpr(jaxpr, consts) + err_vals, err_tree = jtu.tree_flatten(error) + err_vals = map(get_shaped_aval, err_vals) + flat_err_and_in_vals = [*err_vals, *body_jaxpr.in_avals] + jaxpr, out_tree, error_effects = jaxpr_to_checkify_jaxpr( + closed_jaxpr, enabled_errors, err_tree, *flat_err_and_in_vals) + return jaxpr, out_tree, error_effects + +def ignore_error_output_jaxpr(jaxpr, num_error_vals): + """Constructs a checked jaxpr which does not output its error value.""" + consts = jaxpr.consts + jaxpr = jaxpr.jaxpr + new_jaxpr = jaxpr.replace(outvars=jaxpr.outvars[num_error_vals:]) + return core.ClosedJaxpr(new_jaxpr, consts) + +def while_loop_error_check(error, enabled_errors, *in_flat, cond_nconsts, + cond_jaxpr, body_nconsts, body_jaxpr): + if cond_jaxpr.out_avals[0].shape: + # TODO(lenamartens, sharadmv): support batched while. + raise ValueError('Checkify does not support batched while-loops ' + '(checkify-of-vmap-of-while). \nHint: if possible, move ' + 'the vmap to the outer level to get ' + 'vmap-of-checkify-of-while.') + + c_consts, b_consts, carry = split_list(in_flat, [cond_nconsts, body_nconsts]) + # Check if the first cond application will error. + error, _ = checkify_jaxpr(cond_jaxpr, enabled_errors, error, *c_consts, *carry) + + _, _, error_effects = checkify_while_body_jaxpr(cond_jaxpr, body_jaxpr, + enabled_errors, error, c_consts) + # merged error! 
+ error = error._add_placeholder_effects(error_effects) + err_vals, err_tree = jtu.tree_flatten(error) + checked_body_jaxpr_, body_out_tree, _ = checkify_while_body_jaxpr( + cond_jaxpr, body_jaxpr, enabled_errors, error, c_consts) + num_error_vals = len(err_vals) + to_move = [False] * num_error_vals + [True] * body_nconsts + [False] * len(carry) + checked_body_jaxpr = pe.move_binders_to_front(checked_body_jaxpr_, to_move) + + cond_in_flat = [*err_vals, *c_consts, *carry] + cond_in_flat = map(get_shaped_aval, cond_in_flat) + checked_cond_jaxpr, _, _ = jaxpr_to_checkify_jaxpr(cond_jaxpr, enabled_errors, + err_tree, *cond_in_flat) + compat_cond_jaxpr_ = ignore_error_output_jaxpr(checked_cond_jaxpr, num_error_vals) + to_move = [False] * num_error_vals + [True] * cond_nconsts + [False] * len(carry) + compat_cond_jaxpr = pe.move_binders_to_front(compat_cond_jaxpr_, to_move) + + new_in_flat = [*c_consts, *b_consts, *err_vals, *carry] + all_out_vals = lax.while_p.bind( + *new_in_flat, cond_nconsts=cond_nconsts, cond_jaxpr=compat_cond_jaxpr, + body_nconsts=body_nconsts, body_jaxpr=checked_body_jaxpr) + # body_out_tree will have all the metadata of cond because it executes a cond! + error, out = tree_unflatten(body_out_tree, all_out_vals) + return error, out +error_checks[lax.while_p] = while_loop_error_check + +def pjit_error_check(error, enabled_errors, *vals_in, jaxpr, + in_shardings, out_shardings, resource_env, + donated_invars, name, + in_positional_semantics, out_positional_semantics, inline): + # jaxpr to checked_jaxpr + err_vals, err_tree = jtu.tree_flatten(error) + new_vals_in = [*err_vals, *vals_in] + in_avals = tuple(map(get_shaped_aval, new_vals_in)) + checked_jaxpr, out_tree, _ = jaxpr_to_checkify_jaxpr(jaxpr, enabled_errors, + err_tree, *in_avals) + + # Update pjit params to account for extra error values. 
+ num_error_vals = len(err_vals) + num_out_error_vals = out_tree.num_leaves - len(out_shardings) + sharding = OpShardingSharding.get_replicated( + list(resource_env.physical_mesh.devices.flat)) + new_in_shardings = (*[sharding] * num_error_vals, *in_shardings) + new_out_shardings = (*[sharding] * num_out_error_vals, *out_shardings) + + pos_sem = (maps._PositionalSemantics.GLOBAL if jax.config.jax_array + else maps._positional_semantics.val) + if not isinstance(in_positional_semantics, Iterable): + in_positional_semantics = (in_positional_semantics,) + if not isinstance(out_positional_semantics, Iterable): + out_positional_semantics = (out_positional_semantics,) + new_positional_sems_in = (*[pos_sem] * num_error_vals, + *in_positional_semantics) + new_positional_sems_out = (*[pos_sem] * num_error_vals, + *out_positional_semantics) + new_donated_invars = (*[False] * num_error_vals, *donated_invars) + + err_and_out = pjit.pjit_p.bind( + *new_vals_in, + jaxpr=checked_jaxpr, + in_shardings=new_in_shardings, + out_shardings=new_out_shardings, + resource_env=resource_env, + donated_invars=new_donated_invars, + name=name, + in_positional_semantics=new_positional_sems_in, + out_positional_semantics=new_positional_sems_out, + inline=inline) + return tree_unflatten(out_tree, err_and_out) +error_checks[pjit.pjit_p] = pjit_error_check + +def custom_jvp_call_rule(in_err, enabled_errors, *in_vals, num_consts, + jvp_jaxpr_thunk, call_jaxpr, **params): # The types to have in mind are: # jvp : (a -> b) -> (a, T a) -> (b, T b) # checkify : (a -> b) -> a -> Err b @@ -550,53 +905,180 @@ def checkify_custom_jvp_subtrace(main, num_error_vals, out_tree, *args): # Semantically, we don't add checks to the JVP rule. To check the result of a # JVP rule, one must instead use checkify-of-jvp. Thus this implementation # just forwards the input error and code (and trivial tangents) to the output. - del main - n, ragged = divmod(len(args), 2) - assert not ragged - err_primals, primals = split_list(args[:n], [num_error_vals]) - err_tangents, tangents = split_list(args[n:], [num_error_vals]) - outs = yield (*primals, *tangents), {} - m, ragged = divmod(len(outs), 2) - assert not ragged - out_primals, out_tangents = outs[:m], outs[m:] - yield (*err_primals, *out_primals, *err_tangents, *out_tangents), out_tree + err_vals, err_tree = jtu.tree_flatten(in_err) + partial_checkify = lu.wrap_init( + functools.partial(checkify_jaxpr_flat, call_jaxpr.jaxpr, + call_jaxpr.consts, enabled_errors, err_tree)) + partial_checkify, f_metadata = _flatten_and_get_error_metadata_thunk( + partial_checkify) -@lu.transformation_with_aux -def checkify_custom_vjp_subtrace(main, err_tree, num_error_vals, *args): - del main - # We don't add any checks; just drop input error values. - _, args = split_list(args, [num_error_vals]) - outs = yield args, {} - yield outs, err_tree + # Construct the defaul jvp function, without checkify-ing. + @lu.wrap_init + def jvp(*xs): + # TODO(lenamartens, sharadmv): why not checkify here? 
+ jvp_jaxpr, jvp_consts = jvp_jaxpr_thunk() + n, ragged = divmod(len(xs), 2) + assert not ragged + primals, tangents = xs[num_consts:n], xs[n+num_consts:] + return core.eval_jaxpr(jvp_jaxpr, jvp_consts, *primals, *tangents) -@lu.transformation_with_aux -def query_error_effects(*args): - (error, *outs) = yield args, {} - yield (error, *outs), set(error._pred.keys()) + jvp, jvp_out_tree = flatten_fun_output(jvp) + all_outs = custom_derivatives.custom_jvp_call_p.bind(partial_checkify, jvp, + *err_vals, *in_vals) + fst, out_metadata = lu.merge_linear_aux(f_metadata, jvp_out_tree) + if fst: + err_and_out_tree, _ = out_metadata + out_err, out_vals = tree_unflatten(err_and_out_tree, all_outs) + else: + err_vals, out_vals = split_list(all_outs, [len(err_vals)]) + # forward input error to output + out_err = jtu.tree_unflatten(err_tree, err_vals) + return out_err, out_vals +error_checks[custom_derivatives.custom_jvp_call_p] = custom_jvp_call_rule -def checkify_jaxpr(jaxpr, error, - enabled_errors) -> Tuple[core.ClosedJaxpr, - Tuple[PyTreeDef, - FrozenSet[ErrorEffect]]]: - f = lu.wrap_init(core.jaxpr_as_fun(jaxpr)) - return checkify_fun_to_jaxpr(f, error, enabled_errors, jaxpr.in_avals) +def custom_vjp_call_jaxpr_rule(in_err, enabled_errors, *in_vals, fun_jaxpr, + fwd_jaxpr_thunk, num_consts, bwd, out_trees): + err_vals, err_tree = jtu.tree_flatten(in_err) + fun = lu.wrap_init( + functools.partial(checkify_jaxpr_flat, fun_jaxpr.jaxpr, + fun_jaxpr.consts, enabled_errors, err_tree)) + fun, fun_metadata = _flatten_and_get_error_metadata_thunk(fun) -def checkify_fun_to_jaxpr( - f, error, enabled_errors, - in_avals) -> Tuple[core.ClosedJaxpr, Tuple[PyTreeDef, FrozenSet[ErrorEffect]]]: - flat_error_vals, in_tree = tree_flatten(error) - f = checkify_subtrace(f) - f = checkify_traceable(f, enabled_errors) - f, error_effect = query_error_effects(f) - in_tree = jtu.tree_structure((error, *in_avals)) - f, out_tree = flatten_fun_nokwargs(f, in_tree) - err_vals = map(lambda x: core.raise_to_shaped(core.get_aval(x)), - flat_error_vals) - avals_in = [*err_vals, *in_avals] - jaxpr_out, _, literals_out = pe.trace_to_jaxpr_dynamic(f, avals_in) - return (core.ClosedJaxpr(jaxpr_out, literals_out), (out_tree(), error_effect())) + @lu.wrap_init + def fwd(*xs): + # TODO(lenamartens, sharadmv): why not checkify here? 
+ fwd_jaxpr, fwd_consts = fwd_jaxpr_thunk() + xs_without_consts = xs[num_consts:] + return core.eval_jaxpr(fwd_jaxpr, fwd_consts, *xs_without_consts) + + fwd, fwd_out_tree = flatten_fun_output(fwd) + all_outs = custom_derivatives.custom_vjp_call_p.bind( + fun, fwd, bwd, *err_vals, *in_vals, out_trees=out_trees) + fst, out_metadata = lu.merge_linear_aux(fun_metadata, fwd_out_tree) + if fst: + err_and_out_tree, _ = out_metadata + out_err, out_vals = tree_unflatten(err_and_out_tree, all_outs) + else: + err_vals, out_vals = split_list(all_outs, [len(err_vals)]) + # forward input error to output + out_err = jtu.tree_unflatten(err_tree, err_vals) + return out_err, out_vals +error_checks[custom_derivatives.custom_vjp_call_jaxpr_p] = custom_vjp_call_jaxpr_rule + +def check_discharge_rule(error, enabled_errors, *args, err_tree, debug): + del debug + new_error = tree_unflatten(err_tree, args) + # Split up new_error into error to be functionalized if it's included in + # enabled_errors (=discharged_error) and an error to be defunctionalized if + # it's not included (=recharged_error) + discharged_error = error + recharged_error = init_error + for error_effect in new_error._pred.keys(): + pred = new_error._pred[error_effect] + code = new_error._code[error_effect] + payload = new_error._payload[error_effect] + if error_effect.error_type in enabled_errors: + discharged_error = update_error(discharged_error, pred, code, {}, payload, + error_effect) + else: + recharged_error = update_error(recharged_error, pred, code, {}, payload, + error_effect) + + discharged_error = discharged_error._replace( + _metadata={**new_error._metadata, **discharged_error._metadata}) + recharged_error = recharged_error._replace(_metadata=new_error._metadata) + # TODO(lenamartens): we actually need to recharge, but this would be a + # breaking API change so leaving for a follow-up. + # check_error(recharged_error) + return discharged_error, [] +error_checks[check_p] = check_discharge_rule +## checkify public api + +user_checks = frozenset({FailedCheckError}) +nan_checks = frozenset({NaNError}) +index_checks = frozenset({OOBError}) +div_checks = frozenset({DivisionByZeroError}) +float_checks = nan_checks | div_checks +automatic_checks = float_checks | index_checks +all_checks = automatic_checks | user_checks + + +def checkify(f: Callable[..., Out], + errors: FrozenSet[ErrorCategory] = user_checks + ) -> Callable[..., Tuple[Error, Out]]: + """Functionalize `check` calls in `fun`, and optionally add run-time error checks. + + Run-time errors are either user-added :func:`~check` assertions, or + automatically added checks like NaN checks, depending on the ``errors`` + argument. + + The returned function will return an Error object `err` along with the output + of the original function. ``err.get()`` will either return ``None`` (if no + error occurred) or a string containing an error message. This error message + will correspond to the first error which occurred. ``err.throw()`` will raise + a ValueError with the error message if an error occurred. + + By default only user-added :func:`~check` assertions are enabled. You can + enable automatic checks through the ``errors`` argument. + + The automatic check sets which can be enabled, and when an error is generated: + - ``user_checks``: a :func:`~check` evaluated to False. + - ``nan_checks``: a floating-point operation generated a NaN value + as output. + - ``div_checks``: a division by zero. + - ``index_checks``: an index was out-of-bounds. 
+ + Multiple categories can be enabled together by passing in an error `Set` (eg. + ``errors=nan_checks``). Multiple sets can be re-combined (eg. + ``errors=float_checks|user_checks``) + + Args: + fun: Callable which can contain user checks (see :func:`~check`). + errors: A set of ErrorCategory values which defines the set of enabled + checks. By default only explicit ``checks`` are enabled + (``user_checks``). You can also for example enable NAN and + DIV errors by passing the ``float_checks`` set, or for + example combine multiple sets through set operations + (``float_checks | user_checks``) + Returns: + A function which accepts the same arguments as ``fun`` and returns as output + a pair where the first element is an ``Error`` value, representing the first + failed :func:`~check`, and the second element is the original output of + ``fun``. + + For example: + + >>> import jax + >>> import jax.numpy as jnp + >>> from jax.experimental import checkify + >>> + >>> @jax.jit + ... def f(x): + ... y = jnp.sin(x) + ... return x+y + >>> err, out = checkify.checkify(f, errors=checkify.float_checks)(jnp.inf) + >>> err.throw() # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + jax._src.checkify.JaxRuntimeError: nan generated by primitive: sin + """ + @traceback_util.api_boundary + def checked_fun(*args, **kwargs): + # stage: + fun = lu.wrap_init(f) + flat_args, in_tree = jtu.tree_flatten((args, kwargs)) + flat_fun, out_tree = flatten_fun(fun, in_tree) + flat_avals = map(get_shaped_aval, flat_args) + jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, flat_avals) + out_tree = out_tree() + # checkify: + flat_args = jtu.tree_leaves((args, kwargs)) + error, out_flat = checkify_jaxpr(core.ClosedJaxpr(jaxpr, consts), errors, + init_error, *flat_args) + return error, jtu.tree_unflatten(out_tree, out_flat) + return checked_fun def check(pred: Bool, msg: str, *fmt_args, **fmt_kwargs) -> None: """Check a predicate, add an error with msg if predicate is False. @@ -757,517 +1239,4 @@ def check_error(error: Error) -> None: if not isinstance(error, Error): raise ValueError('check_error takes an Error as argument, ' f'got type {type(error)} instead.') - _check_error(error, debug=False) - - -## check primitive - -check_p = core.Primitive('check') -check_p.multiple_results = True # zero results - -# TODO(lenamartens): inherit from Exception instead of ValueError. -class JaxRuntimeError(ValueError): - pass - -@check_p.def_impl -def check_impl(*args, err_tree, debug): - if debug: - # NOOP (check will only trigger when discharged) - return [] - error = tree_unflatten(err_tree, args) - exc = error.get_exception() - if exc: - raise JaxRuntimeError(str(exc)) from exc - return [] - -@check_p.def_effectful_abstract_eval -def check_abstract_eval(*args, err_tree, debug): - del debug - return [], set(tree_unflatten(err_tree, args)._pred.keys()) - -# TODO(lenamartens) add in-depth error explanation to link to in module docs. -functionalization_error = ValueError( - 'Cannot abstractly evaluate a checkify.check which was not' - ' functionalized. This probably means you tried to stage' - ' (jit/scan/pmap/...) a `check` without functionalizing it' - ' through `checkify.checkify`.' 
- ) - -def check_lowering_rule(ctx, *args, err_tree, debug): - if debug: - # NOOP (check will only trigger when discharged) - return [] - if not config.jax_experimental_unsafe_xla_runtime_errors: - raise functionalization_error - - out_op, _, keep_alive = mlir.emit_python_callback( - ctx, callback=functools.partial(python_err, err_tree), - token=None, - operands=args, - operand_avals=list(ctx.avals_in), - result_avals=list(ctx.avals_out), - has_side_effect=True) - ctx.module_context.add_keepalive(keep_alive) - return out_op - -def check_lowering_rule_unsupported(*a, debug, **k): - if debug: - return [] - raise functionalization_error - -def python_err(err_tree, *args): - error = tree_unflatten(err_tree, args) - _check_error(error) - return [] - -mlir.register_lowering(check_p, check_lowering_rule_unsupported, - platform='tpu') -mlir.register_lowering(check_p, check_lowering_rule, - platform='cpu') -mlir.register_lowering(check_p, check_lowering_rule, - platform='gpu') - -def check_batching_rule(batched_args, batch_dims, *, err_tree, debug): - size = next(x.shape[dim] for x, dim in zip(batched_args, batch_dims) - if dim is not batching.not_mapped) - batched_args = (batching.bdim_at_front(a, d, size) - for a, d in zip(batched_args, batch_dims)) - err = tree_unflatten(err_tree, batched_args) - _check_error(err, debug=debug) - return [], [] -batching.primitive_batchers[check_p] = check_batching_rule - -def check_jvp_rule(primals, _, *, err_tree, debug): - # Check primals, discard tangents. - check_p.bind(*primals, err_tree=err_tree, debug=debug) - return [], [] -ad.primitive_jvps[check_p] = check_jvp_rule - -## checkify rules - -def _get_current_traceback(skip_frames = 0) -> Optional[types.TracebackType]: - # TODO(lenamartens): use c++ version from XLA? - tb = None - import inspect - for frame_info in inspect.stack(): - frame = frame_info.frame - if skip_frames: - skip_frames -= 1 - elif not traceback_util.include_frame(frame): - continue - else: - tb = types.TracebackType(tb, frame, frame.f_lasti, frame.f_lineno) - return tb - -def summary() -> str: - return str(source_info_util.summarize(source_info_util.current())) - -def nan_error_check(prim, error, enabled_errors, *in_vals, **params): - out = prim.bind(*in_vals, **params) - err = check_nans(prim, error, enabled_errors, out) - return out, err - -def check_nans(prim, error, enabled_errors, out): - if NaNError not in enabled_errors: - return error - - def isnan(x): - if isinstance(x, prng.PRNGKeyArray): - return False - return jnp.any(jnp.isnan(x)) - - any_nans = (jnp.any(jnp.array([isnan(x) for x in out])) - if prim.multiple_results else isnan(out)) - return assert_func(error, any_nans, NaNError(summary(), prim.name)) - - -# All primitives which can generate a NaN. 
-nan_primitives = [lax.acos_p, lax.acosh_p, lax.add_p, lax.asin_p, lax.asinh_p, - lax.atan2_p, lax.atan_p, lax.atanh_p, lax.bessel_i0e_p, - lax.bessel_i1e_p, lax.cbrt_p, lax.conv_general_dilated_p, - lax.cos_p, lax.cosh_p, lax.cumlogsumexp_p, lax.cummax_p, - lax.cummin_p, lax.cumprod_p, lax.cumsum_p, lax.digamma_p, - lax.dot_general_p, lax.erf_inv_p, lax.erf_p, lax.erfc_p, - lax.exp_p, lax.expm1_p, lax.fft_p, lax.igamma_grad_a_p, - lax.igamma_p, lax.igammac_p, lax.integer_pow_p, lax.lgamma_p, - lax.linear_solve_p, lax.log1p_p, lax.log_p, lax.logistic_p, - lax.mul_p, lax.pad_p, lax.pow_p, lax.psum_p, - lax.random_gamma_grad_p, lax.reduce_p, lax.reduce_prod_p, - lax.reduce_sum_p, lax.reduce_window_p, - lax.reduce_window_sum_p, lax.regularized_incomplete_beta_p, - lax.rem_p, lax.rng_uniform_p, lax.rsqrt_p, lax.sin_p, - lax.sinh_p, lax.sqrt_p, lax.sub_p, lax.tan_p, lax.tanh_p] - -for prim in nan_primitives: - error_checks[prim] = functools.partial(nan_error_check, prim) - - -def gather_error_check(error, enabled_errors, operand, start_indices, *, - dimension_numbers, slice_sizes, unique_indices, - indices_are_sorted, mode, fill_value): - out = lax.gather_p.bind( - operand, start_indices, dimension_numbers=dimension_numbers, - slice_sizes=slice_sizes, unique_indices=unique_indices, - indices_are_sorted=indices_are_sorted, mode=mode, fill_value=fill_value) - - if OOBError not in enabled_errors: - return out, error - - # compare to OOB masking logic in lax._gather_translation_rule - dnums = dimension_numbers - operand_dims = np.array(operand.shape) - num_batch_dims = len(start_indices.shape) - 1 - - upper_bound = operand_dims[np.array(dnums.start_index_map)] - upper_bound -= np.array(slice_sizes)[np.array(dnums.start_index_map)] - upper_bound = jnp.expand_dims(upper_bound, axis=tuple(range(num_batch_dims))) - oob_mask = (start_indices < 0) | (start_indices > upper_bound.astype(start_indices.dtype)) - - payload = oob_payload(oob_mask, start_indices, dnums.start_index_map, operand.shape) - return out, assert_func(error, jnp.any(oob_mask), OOBError(summary(), "gather", operand.shape, payload)) -error_checks[lax.gather_p] = gather_error_check - -def div_error_check(error, enabled_errors, x, y): - """Checks for division by zero and NaN.""" - if DivisionByZeroError in enabled_errors: - any_zero = jnp.any(jnp.equal(y, 0)) - error = assert_func(error, any_zero, DivisionByZeroError(summary())) - return nan_error_check(lax.div_p, error, enabled_errors, x, y) -error_checks[lax.div_p] = div_error_check - -def oob_payload(oob_mask, indices, dims_map, operand_shape): - # Get first OOB index, axis and axis size so it can be added to the error msg. 
- flat_idx = jnp.argmin(jnp.logical_not(oob_mask)) - multi_idx = jnp.unravel_index(flat_idx, indices.shape) - oob_axis = jnp.array(dims_map)[multi_idx[-1]] - oob_axis_size = jnp.array(operand_shape)[oob_axis] - oob_index = jnp.ravel(indices)[flat_idx] - payload = jnp.array([oob_index, oob_axis, oob_axis_size], dtype=jnp.int32) - return payload - -def scatter_oob(operand, indices, updates, dnums): - # Ref: see clamping code used in scatter_translation_rule - slice_sizes = [] - pos = 0 - for i in range(len(operand.shape)): - if i in dnums.inserted_window_dims: - slice_sizes.append(1) - else: - slice_sizes.append(updates.shape[dnums.update_window_dims[pos]]) - pos += 1 - - upper_bound = np.array([operand.shape[i] - slice_sizes[i] - for i in dnums.scatter_dims_to_operand_dims], - np.int64) - upper_bound = np.minimum(upper_bound, np.iinfo(indices.dtype).max) - upper_bound = lax.broadcast_in_dim(upper_bound, indices.shape, - (len(indices.shape) - 1,)) - - lower_oob = jnp.less(indices, 0) - upper_oob = jnp.greater(indices, upper_bound.astype(indices.dtype)) - oob_mask = jnp.logical_or(lower_oob, upper_oob) - payload = oob_payload(oob_mask, indices, - dnums.scatter_dims_to_operand_dims, operand.shape) - return jnp.any(oob_mask), payload - -def scatter_error_check(prim, error, enabled_errors, operand, indices, updates, - *, update_jaxpr, update_consts, dimension_numbers, - indices_are_sorted, unique_indices, mode): - """Checks if indices are within bounds and update does not generate NaN.""" - out = prim.bind( - operand, indices, updates, update_jaxpr=update_jaxpr, - update_consts=update_consts, dimension_numbers=dimension_numbers, - indices_are_sorted=indices_are_sorted, unique_indices=unique_indices, - mode=mode) - - if OOBError not in enabled_errors: - return out, error - - out_of_bounds, payload = scatter_oob(operand, indices, updates, dimension_numbers) - oob_error = OOBError(summary(), prim.name, operand.shape, payload) - error = assert_func(error, out_of_bounds, oob_error) - return out, check_nans(prim, error, enabled_errors, out) -error_checks[lax.scatter_p] = functools.partial(scatter_error_check, lax.scatter_p) -error_checks[lax.scatter_add_p] = functools.partial(scatter_error_check, - lax.scatter_add_p) -error_checks[lax.scatter_mul_p] = functools.partial(scatter_error_check, - lax.scatter_mul_p) -error_checks[lax.scatter_min_p] = functools.partial(scatter_error_check, - lax.scatter_min_p) -error_checks[lax.scatter_max_p] = functools.partial(scatter_error_check, - lax.scatter_max_p) - -def cond_error_check(error, enabled_errors, index, *ops, branches, linear): - _, out_trees_and_effects = unzip2(checkify_jaxpr(jxpr, error, - enabled_errors) - for jxpr in branches) - _, effects = unzip2(out_trees_and_effects) - - merged_error = error._add_placeholder_effects(set().union(*effects)) - new_branches, out_trees_and_effects = unzip2(checkify_jaxpr(jxpr, merged_error, - enabled_errors) - for jxpr in branches) - out_trees, _ = unzip2(out_trees_and_effects) - - flat_error, _ = tree_flatten(merged_error) - new_linear = (*[False] * len(flat_error), *linear) - err_and_outs = lax.cond_p.bind( - index, *flat_error, *ops, - branches=tuple(new_branches), linear=new_linear) - - # we need to merge metadata across out_trees (a tuple) - # maybe there's a better way to do this, but we can use the outs - # to unflatten all trees. 
- err0, *out = tree_unflatten(out_trees[0], err_and_outs) - merged_metadata = err0._metadata - for tr in out_trees[1:]: - err, *_ = tree_unflatten(tr, err_and_outs) - merged_metadata = {**merged_metadata, **err._metadata} - return out, err0._replace(_metadata=merged_metadata) -error_checks[lax.cond_p] = cond_error_check - -def scan_error_check(error, enabled_errors, *in_flat, reverse, length, jaxpr, - num_consts, num_carry, linear, unroll): - consts, carry, xs = split_list(in_flat, [num_consts, num_carry]) - _, (_, effects) = checkify_jaxpr(jaxpr, error, enabled_errors) - merged_error = error._add_placeholder_effects(effects) - checked_jaxpr_, (out_tree, _) = checkify_jaxpr(jaxpr, merged_error, enabled_errors) - - flat_error_vals, _ = tree_flatten(merged_error) - tomove = [False] * len(flat_error_vals) + [True] * len(consts) + [False] * (len(carry) + len(xs)) - checked_jaxpr = pe.move_binders_to_front(checked_jaxpr_, tomove) - new_linear = (*[False] * len(flat_error_vals), *linear) - new_in_flat = [*consts, *flat_error_vals, *carry, *xs] - err_and_out = lax.scan_p.bind( - *new_in_flat, reverse=reverse, length=length, jaxpr=checked_jaxpr, - num_consts=len(consts), num_carry=len(carry)+len(flat_error_vals), - linear=new_linear, unroll=unroll) - err, *out = tree_unflatten(out_tree, err_and_out) - return out, err - -error_checks[lax.scan_p] = scan_error_check - -def checkify_while_body_jaxpr(cond_jaxpr, body_jaxpr, error, enabled_errors, c_consts): - cond_f = core.jaxpr_as_fun(cond_jaxpr) - body_f = core.jaxpr_as_fun(body_jaxpr) - def new_body_f(*vals): - out = body_f(*vals) - # This checks if the next cond application will error - _ = cond_f(*c_consts, *out) - return out - return checkify_fun_to_jaxpr(lu.wrap_init(new_body_f), error, enabled_errors, - body_jaxpr.in_avals) - -def ignore_error_output_jaxpr(jaxpr, num_error_vals): - """Constructs a checked jaxpr which does not output its error value.""" - consts = jaxpr.consts - jaxpr = jaxpr.jaxpr - new_jaxpr = jaxpr.replace(outvars=jaxpr.outvars[num_error_vals:]) - return core.ClosedJaxpr(new_jaxpr, consts) - -def while_loop_error_check(error, enabled_errors, *in_flat, cond_nconsts, - cond_jaxpr, body_nconsts, body_jaxpr): - if cond_jaxpr.out_avals[0].shape: - # TODO(lenamartens, sharadmv): support batched while. - raise ValueError('Checkify does not support batched while-loops ' - '(checkify-of-vmap-of-while). \nHint: if possible, move ' - 'the vmap to the outer level to get ' - 'vmap-of-checkify-of-while.') - - err_vals, _ = tree_flatten(error) - c_consts, b_consts, carry = split_list(in_flat, [cond_nconsts, body_nconsts]) - - # Check if the first cond application will error. - checked_cond_jaxpr, (cond_out_tree, _) = checkify_jaxpr( - cond_jaxpr, error, enabled_errors) - outs = core.jaxpr_as_fun(checked_cond_jaxpr)(*err_vals, *c_consts, *carry) - error, _ = tree_unflatten(cond_out_tree, outs) - - checked_body_jaxpr_, (_, error_effects) = checkify_while_body_jaxpr( - cond_jaxpr, body_jaxpr, error, enabled_errors, c_consts) - # merged error! 
- error = error._add_placeholder_effects(error_effects) - checked_body_jaxpr_, (body_out_tree, _) = checkify_while_body_jaxpr( - cond_jaxpr, body_jaxpr, error, enabled_errors, c_consts) - err_vals = jtu.tree_leaves(error) - num_error_vals = len(err_vals) - to_move = [False] * num_error_vals + [True] * body_nconsts + [False] * len(carry) - checked_body_jaxpr = pe.move_binders_to_front(checked_body_jaxpr_, to_move) - - checked_cond_jaxpr, _ = checkify_jaxpr(cond_jaxpr, error, enabled_errors) - compat_cond_jaxpr_ = ignore_error_output_jaxpr(checked_cond_jaxpr, num_error_vals) - to_move = [False] * num_error_vals + [True] * cond_nconsts + [False] * len(carry) - compat_cond_jaxpr = pe.move_binders_to_front(compat_cond_jaxpr_, to_move) - new_in_flat = [*c_consts, *b_consts, *err_vals, *carry] - - all_out_vals = lax.while_p.bind( - *new_in_flat, cond_nconsts=cond_nconsts, cond_jaxpr=compat_cond_jaxpr, - body_nconsts=body_nconsts, body_jaxpr=checked_body_jaxpr) - # body_out_tree will have all the metadata of cond because it executes a cond! - # only need to merge metadata on the input error. - error, *out = tree_unflatten(body_out_tree, all_out_vals) - return out, error -error_checks[lax.while_p] = while_loop_error_check - - -def pjit_error_check(error, enabled_errors, *vals_in, jaxpr, - in_shardings, out_shardings, resource_env, - donated_invars, name, - in_positional_semantics, out_positional_semantics, - keep_unused, inline): - checked_jaxpr, (out_tree, effects) = checkify_jaxpr(jaxpr, error, - enabled_errors) - out_error = error._add_placeholder_effects(effects) - - flat_error_vals = jtu.tree_leaves(error) - num_error_vals = len(flat_error_vals) - new_vals_in = [*flat_error_vals, *vals_in] - - sharding = OpShardingSharding.get_replicated( - list(resource_env.physical_mesh.devices.flat)) - new_in_shardings = (*[sharding] * num_error_vals, *in_shardings) - new_out_shardings = (*[sharding] * len(jtu.tree_leaves(out_error)), - *out_shardings) - - if config.jax_array: - pos_sem = maps._PositionalSemantics.GLOBAL - else: - pos_sem = maps._positional_semantics.val - - if not isinstance(in_positional_semantics, Iterable): - in_positional_semantics = (in_positional_semantics,) - if not isinstance(out_positional_semantics, Iterable): - out_positional_semantics = (out_positional_semantics,) - new_positional_sems_in = (*[pos_sem] * num_error_vals, - *in_positional_semantics) - new_positional_sems_out = (*[pos_sem] * num_error_vals, - *out_positional_semantics) - new_donated_invars = (*[False] * num_error_vals, *donated_invars) - - err_and_out = pjit.pjit_p.bind( - *new_vals_in, - jaxpr=checked_jaxpr, - in_shardings=new_in_shardings, - out_shardings=new_out_shardings, - resource_env=resource_env, - donated_invars=new_donated_invars, - name=name, - in_positional_semantics=new_positional_sems_in, - out_positional_semantics=new_positional_sems_out, - keep_unused=keep_unused, - inline=inline) - err, *out = tree_unflatten(out_tree, err_and_out) - return out, err -error_checks[pjit.pjit_p] = pjit_error_check - - -def check_discharge_rule(error, enabled_errors, *args, err_tree, debug): - del debug - new_error = tree_unflatten(err_tree, args) - # Split up new_error into error to be functionalized if it's included in - # enabled_errors (=discharged_error) and an error to be defunctionalized if - # it's not included (=recharged_error) - discharged_error = error - recharged_error = init_error - for error_effect in new_error._pred.keys(): - pred = new_error._pred[error_effect] - code = new_error._code[error_effect] - 
payload = new_error._payload[error_effect] - if error_effect.error_type in enabled_errors: - discharged_error = update_error(discharged_error, pred, code, {}, payload, - error_effect) - else: - recharged_error = update_error(recharged_error, pred, code, {}, payload, - error_effect) - - discharged_error = discharged_error._replace( - _metadata={**new_error._metadata, **discharged_error._metadata}) - recharged_error = recharged_error._replace(_metadata=new_error._metadata) - # TODO(lenamartens): we actually need to recharge, but this would be a - # breaking API change so leaving for a follow-up. - # check_error(recharged_error) - return [], discharged_error -error_checks[check_p] = check_discharge_rule - - -## checkify api - -user_checks = frozenset({FailedCheckError}) -nan_checks = frozenset({NaNError}) -index_checks = frozenset({OOBError}) -div_checks = frozenset({DivisionByZeroError}) -float_checks = nan_checks | div_checks -automatic_checks = float_checks | index_checks -all_checks = automatic_checks | user_checks - -Out = TypeVar('Out') - - -def checkify(fun: Callable[..., Out], - errors: FrozenSet[ErrorCategory] = user_checks - ) -> Callable[..., Tuple[Error, Out]]: - """Functionalize `check` calls in `fun`, and optionally add run-time error checks. - - Run-time errors are either user-added :func:`~check` assertions, or - automatically added checks like NaN checks, depending on the ``errors`` - argument. - - The returned function will return an Error object `err` along with the output - of the original function. ``err.get()`` will either return ``None`` (if no - error occurred) or a string containing an error message. This error message - will correspond to the first error which occurred. ``err.throw()`` will raise - a ValueError with the error message if an error occurred. - - By default only user-added :func:`~check` assertions are enabled. You can - enable automatic checks through the ``errors`` argument. - - The automatic check sets which can be enabled, and when an error is generated: - - ``user_checks``: a :func:`~check` evaluated to False. - - ``nan_checks``: a floating-point operation generated a NaN value - as output. - - ``div_checks``: a division by zero. - - ``index_checks``: an index was out-of-bounds. - - Multiple categories can be enabled together by passing in an error `Set` (eg. - ``errors=nan_checks``). Multiple sets can be re-combined (eg. - ``errors=float_checks|user_checks``) - - Args: - fun: Callable which can contain user checks (see :func:`~check`). - errors: A set of ErrorCategory values which defines the set of enabled - checks. By default only explicit ``checks`` are enabled - (``user_checks``). You can also for example enable NAN and - DIV errors by passing the ``float_checks`` set, or for - example combine multiple sets through set operations - (``float_checks | user_checks``) - Returns: - A function which accepts the same arguments as ``fun`` and returns as output - a pair where the first element is an ``Error`` value, representing the first - failed :func:`~check`, and the second element is the original output of - ``fun``. - - For example: - - >>> import jax - >>> import jax.numpy as jnp - >>> from jax.experimental import checkify - >>> - >>> @jax.jit - ... def f(x): - ... y = jnp.sin(x) - ... return x+y - >>> err, out = checkify.checkify(f, errors=checkify.float_checks)(jnp.inf) - >>> err.throw() # doctest: +IGNORE_EXCEPTION_DETAIL - Traceback (most recent call last): - ... 
- jax._src.checkify.JaxRuntimeError: nan generated by primitive: sin - """ - @traceback_util.api_boundary - def checked_fun(*args, **kwargs): - args_flat, in_tree = tree_flatten((args, kwargs)) - f, out_tree = flatten_fun(lu.wrap_init(fun), in_tree) - error, out_flat = checkify_flat(f, errors, *args_flat) - out = tree_unflatten(out_tree(), out_flat) - return error, out - return checked_fun diff --git a/jax/_src/checkify.py.orig b/jax/_src/checkify.py.orig new file mode 100644 index 000000000..e3526c255 --- /dev/null +++ b/jax/_src/checkify.py.orig @@ -0,0 +1,1760 @@ +# Copyright 2021 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import dataclasses +import functools +import itertools as it +import types +from typing import Union, Optional, Callable, Dict, Tuple, TypeVar, FrozenSet, Iterable, Type, Set, List, Sequence, Any + +import jax +from jax import lax +from jax import linear_util as lu +<<<<<<< HEAD +from jax._src import core +======= +from jax._src import custom_derivatives +>>>>>>> da3607926 (Checkify: switch to initial-style.) +from jax._src import prng +from jax._src import source_info_util +from jax._src import traceback_util +from jax._src.config import config +from jax._src.lax import control_flow as cf +from jax._src.sharding import OpShardingSharding +from jax._src.typing import Array +from jax._src.util import (as_hashable_function, unzip2, split_list, safe_map, safe_zip) +from jax.api_util import flatten_fun +from jax.experimental import maps +from jax.experimental import pjit +from jax.interpreters import ad +from jax.interpreters import batching +from jax.interpreters import mlir +from jax.interpreters import partial_eval as pe +from jax.tree_util import tree_flatten +from jax.tree_util import tree_map +from jax.tree_util import tree_unflatten +import jax.numpy as jnp +import jax.tree_util as jtu +import numpy as np + +source_info_util.register_exclusion(__file__) +traceback_util.register_exclusion(__file__) + +map, unsafe_map = safe_map, map +zip, unsafe_zip = safe_zip, zip + +Bool = Union[bool, Array] +Int = Union[int, Array] +ErrorCategory = Type['JaxException'] +Payload = List[Union[np.ndarray, Array]] +PyTreeDef = jtu.PyTreeDef +Out = TypeVar('Out') + +## Utils + +def popattr(obj, attrname): + val = getattr(obj, attrname) + delattr(obj, attrname) + return val + +def setnewattr(obj, name, val): + sentinel = object() + assert getattr(obj, name, sentinel) is sentinel + setattr(obj, name, val) + +# Concrete errors + +class JaxException(Exception): + """Python exception which can contain an error message with JAX run-time info.""" + + def __init__(self, traceback_info): + self.traceback_info = traceback_info + # TODO(lenamartens): re-enable tracebacks when they don't leak tracers. 
+ # self.with_traceback(self.traceback_info) + + def __init_subclass__(cls): + jtu.register_pytree_node_class(cls) + + def tree_flatten(self): + return ([], self.traceback_info) + + @classmethod + def tree_unflatten(cls, metadata, payload): + del payload + return cls(metadata) + + def get_effect_type(self) -> core.Effect: + pass + + +@functools.total_ordering +@dataclasses.dataclass(eq=True, frozen=True) +class ErrorEffect: + error_type: Type[JaxException] + shape_dtypes: Tuple[jax.ShapeDtypeStruct, ...] + + def __post_init__(self): + cf.allowed_effects.add(self) + mlir.lowerable_effects.add(self) + + def __lt__(self, other: 'ErrorEffect'): + shape_dtypes = lambda x: tuple((sd.shape, str(sd.dtype)) # dtype is not comparable + for sd in x.shape_dtypes) + unpack = lambda x: (str(x.error_type), shape_dtypes(x)) + return (unpack(self) < unpack(other)) + + +class DivisionByZeroError(JaxException): + + def __str__(self): + return f'division by zero at {self.traceback_info}' + + def get_effect_type(self): + return ErrorEffect(DivisionByZeroError, ()) + +class NaNError(JaxException): + + def __init__(self, traceback_info, primitive_name): + super().__init__(traceback_info) + self.prim = primitive_name + + def tree_flatten(self): + return ([], (self.traceback_info, self.prim)) + + @classmethod + def tree_unflatten(cls, metadata, _): + return cls(*metadata) + + def get_effect_type(self): + return ErrorEffect(NaNError, ()) + + def __str__(self): + return f'nan generated by primitive: {self.prim} at {self.traceback_info}' + +class OOBError(JaxException): + + def __init__(self, traceback_info, primitive_name, operand_shape, payload): + super().__init__(traceback_info) + self.prim = primitive_name + self.operand_shape = operand_shape + self._payload = payload + + def tree_flatten(self): + return ([self._payload], (self.traceback_info, self.prim, self.operand_shape)) + + @classmethod + def tree_unflatten(cls, metadata, payload): + return cls(*metadata, payload[0]) + + def __str__(self): + return (f'out-of-bounds indexing for array of ' + f'shape {self.operand_shape}: ' + f'index {self._payload[0]} is out of bounds for axis ' + f'{self._payload[1]} with size {self._payload[2]}. 
' + f'Failed at {self.traceback_info}') + + def get_effect_type(self): + return ErrorEffect(OOBError, (jax.ShapeDtypeStruct((3,), jnp.int32),)) + +class FailedCheckError(JaxException): + + def __init__(self, traceback_info, fmt_string, *a, **k): + super().__init__(traceback_info) + self.fmt_string = fmt_string + self.args = a + self.kwargs = k + + def tree_flatten(self): + return ((self.args, self.kwargs), # leaves + (self.traceback_info, self.fmt_string)) # treedef + + @classmethod + def tree_unflatten(cls, metadata, payload): + args, kwargs = payload + return cls(*metadata, *args, **kwargs) + + def __str__(self): + return (self.fmt_string.format(*self.args, **self.kwargs) + + f' (check failed at {self.traceback_info})') + + def get_effect_type(self): + vals = jtu.tree_leaves((self.args, self.kwargs)) + return ErrorEffect( + FailedCheckError, + tuple(jax.ShapeDtypeStruct(x.shape, x.dtype) for x in vals)) + +@dataclasses.dataclass +class BatchedError(JaxException): + error_mapping: Dict[Tuple[int, ...], JaxException] + + def __post_init__(self): + traceback_info = list(self.error_mapping.values())[0].traceback_info + super().__init__(traceback_info) + + + def __str__(self): + return '\n'.join(f'at mapped index {", ".join(map(str, idx))}: {e}' + for idx, e in self.error_mapping.items()) + + +# Error Value + +@jtu.register_pytree_node_class +@dataclasses.dataclass(frozen=True) +class Error: + _pred: Dict[ErrorEffect, Bool] + _code: Dict[ErrorEffect, Int] + _metadata: Dict[Int, PyTreeDef] # mapping of code to JaxException treedef. + _payload: Dict[ErrorEffect, Payload] + + def get(self) -> Optional[str]: + """Returns error message if error happened, None if no error happened.""" + exp = self.get_exception() + if exp is not None: + return str(exp) + return None + + def get_exception(self) -> Optional[JaxException]: + """Returns Python exception if error happened, None if no error happened.""" + if any(map(np.shape, self._pred.values())): + return self._get_batched_exception() + else: + min_code = None + cur_effect = None + for error_effect, code in self._code.items(): + if self._pred[error_effect]: + if min_code is None or code < min_code: + min_code = code + cur_effect = error_effect + + if cur_effect is not None: + return tree_unflatten(self._metadata[int(min_code)], # type: ignore + self._payload[cur_effect]) + return None + + def throw(self): + _check_error(self) + + def __str__(self): + return f'Error({self.get()})' + + # Internal helpers + + def _get_batched_exception(self): + shape = np.shape(list(self._pred.values())[0]) + error_mapping = {} + for idx in np.ndindex(*shape): + min_code = None + cur_effect = None + for error_effect, code in self._code.items(): + if self._pred[error_effect][idx]: # type: ignore + if min_code is None or code[idx] < min_code: + min_code = code[idx] # type: ignore + cur_effect = error_effect + + if cur_effect is not None: + payload = tree_map(lambda x, i=idx: x[i], self._payload[cur_effect]) + jax_error = tree_unflatten(self._metadata[int(min_code)], payload) # type: ignore + error_mapping[idx] = jax_error + return BatchedError(error_mapping) + + def _update(self, effect_type: ErrorEffect, pred, code, metadata, payload): + new_errs = {**self._pred, **{effect_type: pred}} # type: ignore + new_codes = {**self._code, **{effect_type: code}} # type: ignore + new_payload = {**self._payload, **{effect_type: payload}} # type: ignore + new_metadata = {**self._metadata, **metadata} + return Error(new_errs, new_codes, new_metadata, new_payload) + + def 
_add_placeholder_effects(self, effects: Set[ErrorEffect]): + """Fill out Error with `effects` and np.ones arrays of their payloads.""" + new_err = self._pred.copy() + new_code = self._code.copy() + new_payload = self._payload.copy() + for effect in effects: + if effect not in self._pred.keys(): + new_err[effect] = False + new_payload[effect] = list( + tree_map(lambda a: jnp.ones(a.shape, a.dtype), effect.shape_dtypes)) + # The error value associated with this effect will never become True, so + # we don't need to set a meaningful code. + new_code[effect] = -1 + return Error(new_err, new_code, self._metadata, new_payload) + + def _replace(self, *args, **kwargs): + return dataclasses.replace(self, *args, **kwargs) + + # PyTree methods + + def tree_flatten(self): + return ((self._pred, self._code, self._payload), (self._metadata)) + + @classmethod + def tree_unflatten(cls, metadata, data): + pred, code, payload = data + return cls(pred, code, metadata, payload) + +init_error = Error({}, {}, {}, {}) # value used as initial (empty) error. +next_code = it.count(1).__next__ # globally unique ids, could be uuid4 + +def assert_func(error: Error, pred: Bool, new_error: JaxException) -> Error: + code = next_code() + effect_type = new_error.get_effect_type() + new_payload, new_metadata = tree_flatten(new_error) + return update_error(error, pred, code, {code: new_metadata}, new_payload, effect_type) + +def update_error(error, pred, code, metadata, payload, effect_type): + err_of_type = error._pred.get(effect_type, False) + out_err = err_of_type | pred + out_code = lax.select(err_of_type, error._code.get(effect_type, -1), code) + cur_payload = error._payload.get(effect_type, None) + if cur_payload is not None: + out_payload = tree_map(functools.partial(lax.select, err_of_type), cur_payload, payload) + else: + out_payload = payload + return error._update(effect_type, out_err, out_code, metadata, out_payload) + + +## Checkify transformation for plumbing functional error values. + +@lu.transformation_with_aux +def _flatten_and_get_error_metadata_thunk(*invals): + error, out = yield invals, {} + out_vals, out_tree = jtu.tree_flatten((error, out)) + yield out_vals, (out_tree, set(error._pred.keys())) + +def default_checkify_rule(primitive: core.Primitive, error: Error, + enabled_errors, *invals: core.Value, + **params: Any) -> Tuple[Error, Sequence[core.Value]]: + """Default rule for primitives in `checkify` interpreter.""" + if 'call_jaxpr' not in params: + # Default non-HOP case: just call primitive and don't update error. + return error, primitive.bind(*invals, **params) + + # Code below handles call- and map-primitives, by recursively calling + # checkify_jaxpr. + err_vals, err_tree = jtu.tree_flatten(error) + num_error_vals = len(err_vals) + if 'donated_invars' in params: + params = dict(params, donated_invars=(*[False]*num_error_vals, + *params['donated_invars'])) + + # call_jaxpr handling + call_jaxpr = params.pop('call_jaxpr') + partial_checkify = lu.wrap_init( + functools.partial(checkify_jaxpr_flat, call_jaxpr, (), enabled_errors, + err_tree)) + partial_checkify, metadata = _flatten_and_get_error_metadata_thunk( + partial_checkify) + + # map-specific params handling. + if isinstance(primitive, core.MapPrimitive): + # Update `in_axes` and `out_axes_thunk` params for map primitive. 
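+    # Error inputs are broadcast to every mapped instance (their in_axes are
+    # None), while each instance's error outputs are mapped along axis 0 (see
+    # the out_axes_thunk constructed below).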
+ out_val_axes = params.pop('out_axes') + + @as_hashable_function(closure=out_val_axes) + def out_axes_thunk(): + out_err_num = metadata()[0].num_leaves - len(out_val_axes) + return (*(0,)*out_err_num, *out_val_axes) + + params = dict(params, + in_axes=(*(None,)*num_error_vals, *params['in_axes']), + out_axes_thunk=out_axes_thunk) + + all_vals = primitive.bind(partial_checkify, *err_vals, *invals, **params) + + out_tree, _ = metadata() + error, out_vals = tree_unflatten(out_tree, all_vals) + return error, out_vals + +def get_shaped_aval(val): + return core.raise_to_shaped(core.get_aval(val)) + +def checkify_jaxpr(jaxpr: core.ClosedJaxpr, enabled_errors, + error: Error, *args) -> Tuple[Error, List[core.Value]]: + err_vals, err_tree = jtu.tree_flatten(error) + return checkify_jaxpr_flat(jaxpr.jaxpr, jaxpr.consts, + enabled_errors, err_tree, *err_vals, *args) + +def checkify_jaxpr_flat(jaxpr: core.Jaxpr, consts: Sequence[core.Value], + enabled_errors, err_tree: PyTreeDef, + *args: core.Value) -> Tuple[Error, List[Any]]: + env: Dict[core.Var, Any] = {} + err_vals, in_args = split_list(args, [err_tree.num_leaves]) + error = jtu.tree_unflatten(err_tree, err_vals) + + def read_env(var: core.Atom): + if isinstance(var, core.Literal): + return var.val + return env[var] + + def write_env(var: core.Var, val: Any): + env[var] = val + + map(write_env, jaxpr.constvars, consts) + map(write_env, jaxpr.invars, in_args) + + # interpreter loop + for eqn in jaxpr.eqns: + invals = map(read_env, eqn.invars) + checkify_rule = error_checks.get( + eqn.primitive, functools.partial(default_checkify_rule, eqn.primitive)) + error, outvals = checkify_rule(error, enabled_errors, *invals, **eqn.params) + if eqn.primitive.multiple_results: + map(write_env, eqn.outvars, outvals) + else: + write_env(eqn.outvars[0], outvals) + + return error, map(read_env, jaxpr.outvars) + +@lu.transformation_with_aux +def flatten_fun_output(*args): + ans = yield args, {} + yield tree_flatten(ans) + + +def _reduce_any_error(error: Error): + out_error = init_error + for error_effect in error._pred.keys(): + errs, codes, payloads = (error._pred[error_effect], + error._code[error_effect], + error._payload[error_effect]) + reduced_idx = jnp.argsort(errs)[-1] + pred, code, payload = tree_map(lambda x, idx=reduced_idx: x[idx], + (errs, codes, payloads)) + out_error = out_error._update(error_effect, pred, code, {}, payload) + + out_error = out_error._replace(_metadata=error._metadata) + return out_error + +## check_p primitive + +check_p = core.Primitive('check') +check_p.multiple_results = True # zero results + +# TODO(lenamartens): inherit from Exception instead of ValueError. +class JaxRuntimeError(ValueError): + pass + +@check_p.def_impl +def check_impl(*args, err_tree, debug): + if debug: + # NOOP (check will only trigger when discharged) + return [] + error = tree_unflatten(err_tree, args) + exc = error.get_exception() + if exc: + raise JaxRuntimeError(str(exc)) from exc + return [] + +@check_p.def_effectful_abstract_eval +def check_abstract_eval(*args, err_tree, debug): + del debug + return [], set(tree_unflatten(err_tree, args)._pred.keys()) + +# TODO(lenamartens) add in-depth error explanation to link to in module docs. +functionalization_error = ValueError( + 'Cannot abstractly evaluate a checkify.check which was not' + ' functionalized. This probably means you tried to stage' + ' (jit/scan/pmap/...) a `check` without functionalizing it' + ' through `checkify.checkify`.' 
+ ) + +def check_lowering_rule(ctx, *args, err_tree, debug): + if debug: + # NOOP (check will only trigger when discharged) + return [] + if not config.jax_experimental_unsafe_xla_runtime_errors: + raise functionalization_error + + out_op, _, keep_alive = mlir.emit_python_callback( + ctx, callback=functools.partial(python_err, err_tree), + token=None, + operands=args, + operand_avals=list(ctx.avals_in), + result_avals=list(ctx.avals_out), + has_side_effect=True) + ctx.module_context.add_keepalive(keep_alive) + return out_op + +def check_lowering_rule_unsupported(*a, debug, **k): + if debug: + return [] + raise functionalization_error + +def python_err(err_tree, *args): + error = tree_unflatten(err_tree, args) + _check_error(error) + return [] + +mlir.register_lowering(check_p, check_lowering_rule_unsupported, + platform='tpu') +mlir.register_lowering(check_p, check_lowering_rule, + platform='cpu') +mlir.register_lowering(check_p, check_lowering_rule, + platform='gpu') + +def check_batching_rule(batched_args, batch_dims, *, err_tree, debug): + size = next(x.shape[dim] for x, dim in zip(batched_args, batch_dims) + if dim is not batching.not_mapped) + batched_args = (batching.bdim_at_front(a, d, size) + for a, d in zip(batched_args, batch_dims)) + err = tree_unflatten(err_tree, batched_args) + _check_error(err, debug=debug) + return [], [] +batching.primitive_batchers[check_p] = check_batching_rule + +def check_jvp_rule(primals, _, *, err_tree, debug): + # Check primals, discard tangents. + check_p.bind(*primals, err_tree=err_tree, debug=debug) + return [], [] +ad.primitive_jvps[check_p] = check_jvp_rule + +## checkify rules + +ErrorCheckRule = Callable # (Error, FrozenSet[ErrorCategory], *in_vals, **params) -> (Any, Error) +error_checks: Dict[core.Primitive, ErrorCheckRule] = {} + + +def _get_current_traceback(skip_frames = 0) -> Optional[types.TracebackType]: + # TODO(lenamartens): use c++ version from XLA? + tb = None + import inspect + for frame_info in inspect.stack(): + frame = frame_info.frame + if skip_frames: + skip_frames -= 1 + elif not traceback_util.include_frame(frame): + continue + else: + tb = types.TracebackType(tb, frame, frame.f_lasti, frame.f_lineno) + return tb + +def summary() -> str: + return str(source_info_util.summarize(source_info_util.current())) + +def nan_error_check(prim, error, enabled_errors, *in_vals, **params): + out = prim.bind(*in_vals, **params) + err = check_nans(prim, error, enabled_errors, out) + return err, out + +def check_nans(prim, error, enabled_errors, out): + if NaNError not in enabled_errors: + return error + + def isnan(x): + if isinstance(x, prng.PRNGKeyArray): + return False + return jnp.any(jnp.isnan(x)) + + any_nans = (jnp.any(jnp.array([isnan(x) for x in out])) + if prim.multiple_results else isnan(out)) + return assert_func(error, any_nans, NaNError(summary(), prim.name)) + + +# All primitives which can generate a NaN. 
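+# Each of these is registered below with the generic nan_error_check rule,
+# which binds the primitive and then checks its output for NaNs when the
+# NaN check is enabled.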
+nan_primitives = [lax.acos_p, lax.acosh_p, lax.add_p, lax.asin_p, lax.asinh_p, + lax.atan2_p, lax.atan_p, lax.atanh_p, lax.bessel_i0e_p, + lax.bessel_i1e_p, lax.cbrt_p, lax.conv_general_dilated_p, + lax.cos_p, lax.cosh_p, lax.cumlogsumexp_p, lax.cummax_p, + lax.cummin_p, lax.cumprod_p, lax.cumsum_p, lax.digamma_p, + lax.dot_general_p, lax.erf_inv_p, lax.erf_p, lax.erfc_p, + lax.exp_p, lax.expm1_p, lax.fft_p, lax.igamma_grad_a_p, + lax.igamma_p, lax.igammac_p, lax.integer_pow_p, lax.lgamma_p, + lax.linear_solve_p, lax.log1p_p, lax.log_p, lax.logistic_p, + lax.mul_p, lax.pad_p, lax.pow_p, lax.psum_p, + lax.random_gamma_grad_p, lax.reduce_p, lax.reduce_prod_p, + lax.reduce_sum_p, lax.reduce_window_p, + lax.reduce_window_sum_p, lax.regularized_incomplete_beta_p, + lax.rem_p, lax.rng_uniform_p, lax.rsqrt_p, lax.sin_p, + lax.sinh_p, lax.sqrt_p, lax.sub_p, lax.tan_p, lax.tanh_p] + +for _prim in nan_primitives: + error_checks[_prim] = functools.partial(nan_error_check, _prim) + + +def gather_error_check(error, enabled_errors, operand, start_indices, *, + dimension_numbers, slice_sizes, unique_indices, + indices_are_sorted, mode, fill_value): + out = lax.gather_p.bind( + operand, start_indices, dimension_numbers=dimension_numbers, + slice_sizes=slice_sizes, unique_indices=unique_indices, + indices_are_sorted=indices_are_sorted, mode=mode, fill_value=fill_value) + + if OOBError not in enabled_errors: + return error, out + + # compare to OOB masking logic in lax._gather_translation_rule + dnums = dimension_numbers + operand_dims = np.array(operand.shape) + num_batch_dims = len(start_indices.shape) - 1 + + upper_bound = operand_dims[np.array(dnums.start_index_map)] + upper_bound -= np.array(slice_sizes)[np.array(dnums.start_index_map)] + upper_bound = jnp.expand_dims(upper_bound, axis=tuple(range(num_batch_dims))) + oob_mask = (start_indices < 0) | (start_indices > upper_bound.astype(start_indices.dtype)) + + payload = oob_payload(oob_mask, start_indices, dnums.start_index_map, operand.shape) + error = assert_func(error, jnp.any(oob_mask), OOBError(summary(), "gather", operand.shape, payload)) + return error, out +error_checks[lax.gather_p] = gather_error_check + +def div_error_check(error, enabled_errors, x, y): + """Checks for division by zero and NaN.""" + if DivisionByZeroError in enabled_errors: + any_zero = jnp.any(jnp.equal(y, 0)) + error = assert_func(error, any_zero, DivisionByZeroError(summary())) + return nan_error_check(lax.div_p, error, enabled_errors, x, y) +error_checks[lax.div_p] = div_error_check + +def oob_payload(oob_mask, indices, dims_map, operand_shape): + # Get first OOB index, axis and axis size so it can be added to the error msg. 
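+  # argmin returns the first occurrence of the minimum, so argmin of the
+  # negated mask picks out the first True (out-of-bounds) position.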
+ flat_idx = jnp.argmin(jnp.logical_not(oob_mask)) + multi_idx = jnp.unravel_index(flat_idx, indices.shape) + oob_axis = jnp.array(dims_map)[multi_idx[-1]] + oob_axis_size = jnp.array(operand_shape)[oob_axis] + oob_index = jnp.ravel(indices)[flat_idx] + payload = jnp.array([oob_index, oob_axis, oob_axis_size], dtype=jnp.int32) + return payload + +def scatter_oob(operand, indices, updates, dnums): + # Ref: see clamping code used in scatter_translation_rule + slice_sizes = [] + pos = 0 + for i in range(len(operand.shape)): + if i in dnums.inserted_window_dims: + slice_sizes.append(1) + else: + slice_sizes.append(updates.shape[dnums.update_window_dims[pos]]) + pos += 1 + + upper_bound = np.array([operand.shape[i] - slice_sizes[i] + for i in dnums.scatter_dims_to_operand_dims], + np.int64) + upper_bound = np.minimum(upper_bound, np.iinfo(indices.dtype).max) + upper_bound = lax.broadcast_in_dim(upper_bound, indices.shape, + (len(indices.shape) - 1,)) + + lower_oob = jnp.less(indices, 0) + upper_oob = jnp.greater(indices, upper_bound.astype(indices.dtype)) + oob_mask = jnp.logical_or(lower_oob, upper_oob) + payload = oob_payload(oob_mask, indices, + dnums.scatter_dims_to_operand_dims, operand.shape) + return jnp.any(oob_mask), payload + +def scatter_error_check(prim, error, enabled_errors, operand, indices, updates, + *, update_jaxpr, update_consts, dimension_numbers, + indices_are_sorted, unique_indices, mode): + """Checks if indices are within bounds and update does not generate NaN.""" + out = prim.bind( + operand, indices, updates, update_jaxpr=update_jaxpr, + update_consts=update_consts, dimension_numbers=dimension_numbers, + indices_are_sorted=indices_are_sorted, unique_indices=unique_indices, + mode=mode) + + if OOBError not in enabled_errors: + return error, out + + out_of_bounds, payload = scatter_oob(operand, indices, updates, dimension_numbers) + oob_error = OOBError(summary(), prim.name, operand.shape, payload) + error = assert_func(error, out_of_bounds, oob_error) + error = check_nans(prim, error, enabled_errors, out) + return error, out +error_checks[lax.scatter_p] = functools.partial(scatter_error_check, lax.scatter_p) +error_checks[lax.scatter_add_p] = functools.partial(scatter_error_check, + lax.scatter_add_p) +error_checks[lax.scatter_mul_p] = functools.partial(scatter_error_check, + lax.scatter_mul_p) +error_checks[lax.scatter_min_p] = functools.partial(scatter_error_check, + lax.scatter_min_p) +error_checks[lax.scatter_max_p] = functools.partial(scatter_error_check, + lax.scatter_max_p) + +# HOP error check rules + +def get_error_effects_from_jaxpr(closed_jaxpr: core.ClosedJaxpr, + enabled_errors, + error, + *args) -> Set[ErrorEffect]: + """Probes a jaxpr for its error effects.""" + err_vals, err_tree = jtu.tree_flatten(error) + checkify_fun = lu.wrap_init( + functools.partial(checkify_jaxpr_flat, closed_jaxpr.jaxpr, + closed_jaxpr.literals, enabled_errors, err_tree)) + checkify_fun, metadata = _flatten_and_get_error_metadata_thunk(checkify_fun) + in_avals = map(get_shaped_aval, [*err_vals, *args]) + pe.trace_to_jaxpr_final(checkify_fun, in_avals) + _, error_effects = metadata() + return error_effects + +def jaxpr_to_checkify_jaxpr( + jaxpr: core.ClosedJaxpr, enabled_errors, err_tree: PyTreeDef, + *flat_err_and_in_vals) -> Tuple[core.ClosedJaxpr, PyTreeDef, Set[ErrorEffect]]: + checkify_jaxpr_partial = functools.partial(checkify_jaxpr_flat, jaxpr.jaxpr, + jaxpr.consts, enabled_errors, + err_tree) + fun = lu.wrap_init(checkify_jaxpr_partial) + fun, metadata = 
_flatten_and_get_error_metadata_thunk(fun) + + new_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(fun, flat_err_and_in_vals) + checked_jaxpr = core.ClosedJaxpr(new_jaxpr, consts) + out_tree, error_effects = metadata() + return checked_jaxpr, out_tree, error_effects + +def cond_error_check(error: Error, enabled_errors, index, *ops, branches, linear): + # Get the error-effects out of all branches so the cond can be called with + # a merged error with all these effects. + effects = [get_error_effects_from_jaxpr(jxpr, enabled_errors, error, *ops) + for jxpr in branches] + merged_error = error._add_placeholder_effects(set().union(*effects)) + err_vals, err_tree = jtu.tree_flatten(merged_error) + new_linear = (*[False] * len(err_vals), *linear) + + # Update branch jaxprs to be checkified jaxprs. + checked_branch_partials = tuple( + functools.partial(checkify_jaxpr_flat, closed_jaxpr.jaxpr, + closed_jaxpr.consts, enabled_errors, err_tree) + for closed_jaxpr in branches) + checked_branch_funs_ = map(lu.wrap_init, checked_branch_partials) + checked_branch_funs, out_trees_and_effects = unzip2( + map(_flatten_and_get_error_metadata_thunk, checked_branch_funs_)) + in_vals = jtu.tree_leaves((merged_error, ops)) + in_avals = tuple(map(get_shaped_aval, in_vals)) + def to_jaxpr(fun, in_avals): + jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(fun, in_avals) + return core.ClosedJaxpr(jaxpr, consts) + new_branches = map( + lambda fun: to_jaxpr(fun, in_avals), + checked_branch_funs) + + + err_and_outs = lax.cond_p.bind( + index, *err_vals, *ops, + branches=tuple(new_branches), linear=new_linear) + + # we need to merge metadata across out_trees (a tuple) + out_trees, _ = unzip2(map(lambda fun: fun(), out_trees_and_effects)) + err0, out = tree_unflatten(out_trees[0], err_and_outs) + merged_metadata = err0._metadata + for tr in out_trees[1:]: + err, _ = tree_unflatten(tr, err_and_outs) + merged_metadata = {**merged_metadata, **err._metadata} + return err0._replace(_metadata=merged_metadata), out +error_checks[lax.cond_p] = cond_error_check + +def scan_error_check(error, enabled_errors, *in_flat, reverse, length, jaxpr, + num_consts, num_carry, linear, unroll): + + consts, carry, xs = split_list(in_flat, [num_consts, num_carry]) + # Query body effects to create a merged error containing all effects (such + # that in and out carried error are of the same type). + effects = get_error_effects_from_jaxpr(jaxpr, enabled_errors, error, *in_flat) + merged_error = error._add_placeholder_effects(effects) + err_vals, err_tree = jtu.tree_flatten(merged_error) + + # Create checked-jaxpr, with the needed pre-processing on the inputs. 
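+  # The scanned-over xs have their leading axis stripped (core.mapped_aval) so
+  # the body jaxpr is traced with per-iteration avals; the flattened error
+  # values are threaded through the loop as extra carry values (num_carry
+  # below is extended to count them).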
+ xs_mapped = [core.mapped_aval(length, 0, get_shaped_aval(val)) for val in xs] + new_in_aval = map(get_shaped_aval, [*err_vals, *consts, *carry]) + xs_mapped + checked_jaxpr_, out_tree, _ = jaxpr_to_checkify_jaxpr(jaxpr, enabled_errors, + err_tree, *new_in_aval) + + new_in_flat = [*consts, *err_vals, *carry, *xs] + new_linear = (*[False] * len(err_vals), *linear) + tomove = ([False] * len(err_vals) + [True] * len(consts) + + [False] * (len(carry) + len(xs))) + checked_jaxpr = pe.move_binders_to_front(checked_jaxpr_, tomove) + new_in_flat = [*consts, *err_vals, *carry, *xs] + err_and_out = lax.scan_p.bind( + *new_in_flat, reverse=reverse, length=length, jaxpr=checked_jaxpr, + num_consts=len(consts), num_carry=len(carry)+len(err_vals), + linear=new_linear, unroll=unroll) + err, out = tree_unflatten(out_tree, err_and_out) + return err, out + +error_checks[lax.scan_p] = scan_error_check + +def checkify_while_body_jaxpr( + cond_jaxpr: core.ClosedJaxpr, body_jaxpr: core.ClosedJaxpr, + enabled_errors, error: Error, + c_consts) -> Tuple[core.ClosedJaxpr, PyTreeDef, Set[ErrorEffect]]: + cond_f = core.jaxpr_as_fun(cond_jaxpr) + body_f = core.jaxpr_as_fun(body_jaxpr) + def new_body_f(*vals): + out = body_f(*vals) + # This checks if the next cond application will error + _ = cond_f(*c_consts, *out) + return out + new_body_f_ = lu.wrap_init(new_body_f) + jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(new_body_f_, body_jaxpr.in_avals) + closed_jaxpr = core.ClosedJaxpr(jaxpr, consts) + err_vals, err_tree = jtu.tree_flatten(error) + err_vals = map(get_shaped_aval, err_vals) + flat_err_and_in_vals = [*err_vals, *body_jaxpr.in_avals] + jaxpr, out_tree, error_effects = jaxpr_to_checkify_jaxpr( + closed_jaxpr, enabled_errors, err_tree, *flat_err_and_in_vals) + return jaxpr, out_tree, error_effects + +def ignore_error_output_jaxpr(jaxpr, num_error_vals): + """Constructs a checked jaxpr which does not output its error value.""" + consts = jaxpr.consts + jaxpr = jaxpr.jaxpr + new_jaxpr = jaxpr.replace(outvars=jaxpr.outvars[num_error_vals:]) + return core.ClosedJaxpr(new_jaxpr, consts) + +def while_loop_error_check(error, enabled_errors, *in_flat, cond_nconsts, + cond_jaxpr, body_nconsts, body_jaxpr): + if cond_jaxpr.out_avals[0].shape: + # TODO(lenamartens, sharadmv): support batched while. + raise ValueError('Checkify does not support batched while-loops ' + '(checkify-of-vmap-of-while). \nHint: if possible, move ' + 'the vmap to the outer level to get ' + 'vmap-of-checkify-of-while.') + + c_consts, b_consts, carry = split_list(in_flat, [cond_nconsts, body_nconsts]) + # Check if the first cond application will error. + error, _ = checkify_jaxpr(cond_jaxpr, enabled_errors, error, *c_consts, *carry) + + _, _, error_effects = checkify_while_body_jaxpr(cond_jaxpr, body_jaxpr, + enabled_errors, error, c_consts) + # merged error! 
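+  # The carried error must have a fixed pytree structure across iterations, so
+  # pre-populate it with placeholder payloads for every effect the checked body
+  # (and the cond it re-runs) might raise.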
+ error = error._add_placeholder_effects(error_effects) + err_vals, err_tree = jtu.tree_flatten(error) + checked_body_jaxpr_, body_out_tree, _ = checkify_while_body_jaxpr( + cond_jaxpr, body_jaxpr, enabled_errors, error, c_consts) + num_error_vals = len(err_vals) + to_move = [False] * num_error_vals + [True] * body_nconsts + [False] * len(carry) + checked_body_jaxpr = pe.move_binders_to_front(checked_body_jaxpr_, to_move) + + cond_in_flat = [*err_vals, *c_consts, *carry] + cond_in_flat = map(get_shaped_aval, cond_in_flat) + checked_cond_jaxpr, _, _ = jaxpr_to_checkify_jaxpr(cond_jaxpr, enabled_errors, + err_tree, *cond_in_flat) + compat_cond_jaxpr_ = ignore_error_output_jaxpr(checked_cond_jaxpr, num_error_vals) + to_move = [False] * num_error_vals + [True] * cond_nconsts + [False] * len(carry) + compat_cond_jaxpr = pe.move_binders_to_front(compat_cond_jaxpr_, to_move) + + new_in_flat = [*c_consts, *b_consts, *err_vals, *carry] + all_out_vals = lax.while_p.bind( + *new_in_flat, cond_nconsts=cond_nconsts, cond_jaxpr=compat_cond_jaxpr, + body_nconsts=body_nconsts, body_jaxpr=checked_body_jaxpr) + # body_out_tree will have all the metadata of cond because it executes a cond! + error, out = tree_unflatten(body_out_tree, all_out_vals) + return error, out +error_checks[lax.while_p] = while_loop_error_check + +def pjit_error_check(error, enabled_errors, *vals_in, jaxpr, + in_shardings, out_shardings, resource_env, + donated_invars, name, + in_positional_semantics, out_positional_semantics, inline): + # jaxpr to checked_jaxpr + err_vals, err_tree = jtu.tree_flatten(error) + new_vals_in = [*err_vals, *vals_in] + in_avals = tuple(map(get_shaped_aval, new_vals_in)) + checked_jaxpr, out_tree, _ = jaxpr_to_checkify_jaxpr(jaxpr, enabled_errors, + err_tree, *in_avals) + + # Update pjit params to account for extra error values. 
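+  # The prepended error inputs and outputs get replicated shardings and are
+  # never donated; matching positional-semantics entries are prepended as well.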
+  num_error_vals = len(err_vals)
+  num_out_error_vals = out_tree.num_leaves - len(out_shardings)
+  sharding = OpShardingSharding.get_replicated(
+      list(resource_env.physical_mesh.devices.flat))
+  new_in_shardings = (*[sharding] * num_error_vals, *in_shardings)
+  new_out_shardings = (*[sharding] * num_out_error_vals, *out_shardings)
+
+  pos_sem = (maps._PositionalSemantics.GLOBAL if jax.config.jax_array
+             else maps._positional_semantics.val)
+  if not isinstance(in_positional_semantics, Iterable):
+    in_positional_semantics = (in_positional_semantics,)
+  if not isinstance(out_positional_semantics, Iterable):
+    out_positional_semantics = (out_positional_semantics,)
+  new_positional_sems_in = (*[pos_sem] * num_error_vals,
+                            *in_positional_semantics)
+  new_positional_sems_out = (*[pos_sem] * num_error_vals,
+                             *out_positional_semantics)
+  new_donated_invars = (*[False] * num_error_vals, *donated_invars)
+
+  err_and_out = pjit.pjit_p.bind(
+      *new_vals_in,
+      jaxpr=checked_jaxpr,
+      in_shardings=new_in_shardings,
+      out_shardings=new_out_shardings,
+      resource_env=resource_env,
+      donated_invars=new_donated_invars,
+      name=name,
+      in_positional_semantics=new_positional_sems_in,
+      out_positional_semantics=new_positional_sems_out,
+      inline=inline)
+  return tree_unflatten(out_tree, err_and_out)
+error_checks[pjit.pjit_p] = pjit_error_check
+
+def custom_jvp_call_rule(in_err, enabled_errors, *in_vals, num_consts,
+                         jvp_jaxpr_thunk, call_jaxpr, **params):
+  # The types to have in mind are:
+  #  jvp : (a -> b) -> (a, T a) -> (b, T b)
+  #  checkify : (a -> b) -> a -> Err b
+  #  jvp-of-checkify : (a -> b) -> (a, T a) -> (Err b, T (Err b))
+  # where because Err is a pytree, we necessarily have T (Err b) = Err' (T b)
+  # where the other Err' components are trivial (of float0 dtype).
+  # Semantically, we don't add checks to the JVP rule. To check the result of a
+  # JVP rule, one must instead use checkify-of-jvp. Thus this implementation
+  # just forwards the input error and code (and trivial tangents) to the output.
+  err_vals, err_tree = jtu.tree_flatten(in_err)
+  partial_checkify = lu.wrap_init(
+      functools.partial(checkify_jaxpr_flat, call_jaxpr.jaxpr,
+                        call_jaxpr.consts, enabled_errors, err_tree))
+  partial_checkify, f_metadata = _flatten_and_get_error_metadata_thunk(
+      partial_checkify)
+
+  # Construct the default jvp function, without checkify-ing.
+  @lu.wrap_init
+  def jvp(*xs):
+    # TODO(lenamartens, sharadmv): why not checkify here?
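+    # xs packs (consts + primals) followed by their tangents; the slicing
+    # below drops the const entries from both halves before evaluating the
+    # jvp jaxpr.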
+ jvp_jaxpr, jvp_consts = jvp_jaxpr_thunk() + n, ragged = divmod(len(xs), 2) + assert not ragged + primals, tangents = xs[num_consts:n], xs[n+num_consts:] + return core.eval_jaxpr(jvp_jaxpr, jvp_consts, *primals, *tangents) + + jvp, jvp_out_tree = flatten_fun_output(jvp) + all_outs = custom_derivatives.custom_jvp_call_p.bind(partial_checkify, jvp, + *err_vals, *in_vals) + fst, out_metadata = lu.merge_linear_aux(f_metadata, jvp_out_tree) + if fst: + err_and_out_tree, _ = out_metadata + out_err, out_vals = tree_unflatten(err_and_out_tree, all_outs) + else: + err_vals, out_vals = split_list(all_outs, [len(err_vals)]) + # forward input error to output + out_err = jtu.tree_unflatten(err_tree, err_vals) + return out_err, out_vals +error_checks[custom_derivatives.custom_jvp_call_p] = custom_jvp_call_rule + +def custom_vjp_call_jaxpr_rule(in_err, enabled_errors, *in_vals, fun_jaxpr, + fwd_jaxpr_thunk, num_consts, bwd, out_trees): + err_vals, err_tree = jtu.tree_flatten(in_err) + fun = lu.wrap_init( + functools.partial(checkify_jaxpr_flat, fun_jaxpr.jaxpr, + fun_jaxpr.consts, enabled_errors, err_tree)) + fun, fun_metadata = _flatten_and_get_error_metadata_thunk(fun) + + @lu.wrap_init + def fwd(*xs): + # TODO(lenamartens, sharadmv): why not checkify here? + fwd_jaxpr, fwd_consts = fwd_jaxpr_thunk() + xs_without_consts = xs[num_consts:] + return core.eval_jaxpr(fwd_jaxpr, fwd_consts, *xs_without_consts) + + fwd, fwd_out_tree = flatten_fun_output(fwd) + all_outs = custom_derivatives.custom_vjp_call_p.bind( + fun, fwd, bwd, *err_vals, *in_vals, out_trees=out_trees) + fst, out_metadata = lu.merge_linear_aux(fun_metadata, fwd_out_tree) + if fst: + err_and_out_tree, _ = out_metadata + out_err, out_vals = tree_unflatten(err_and_out_tree, all_outs) + else: + err_vals, out_vals = split_list(all_outs, [len(err_vals)]) + # forward input error to output + out_err = jtu.tree_unflatten(err_tree, err_vals) + return out_err, out_vals +error_checks[custom_derivatives.custom_vjp_call_jaxpr_p] = custom_vjp_call_jaxpr_rule + +def check_discharge_rule(error, enabled_errors, *args, err_tree, debug): + del debug + new_error = tree_unflatten(err_tree, args) + # Split up new_error into error to be functionalized if it's included in + # enabled_errors (=discharged_error) and an error to be defunctionalized if + # it's not included (=recharged_error) + discharged_error = error + recharged_error = init_error + for error_effect in new_error._pred.keys(): + pred = new_error._pred[error_effect] + code = new_error._code[error_effect] + payload = new_error._payload[error_effect] + if error_effect.error_type in enabled_errors: + discharged_error = update_error(discharged_error, pred, code, {}, payload, + error_effect) + else: + recharged_error = update_error(recharged_error, pred, code, {}, payload, + error_effect) + + discharged_error = discharged_error._replace( + _metadata={**new_error._metadata, **discharged_error._metadata}) + recharged_error = recharged_error._replace(_metadata=new_error._metadata) + # TODO(lenamartens): we actually need to recharge, but this would be a + # breaking API change so leaving for a follow-up. 
+ # check_error(recharged_error) + return discharged_error, [] +error_checks[check_p] = check_discharge_rule + + +## checkify public api + +user_checks = frozenset({FailedCheckError}) +nan_checks = frozenset({NaNError}) +index_checks = frozenset({OOBError}) +div_checks = frozenset({DivisionByZeroError}) +float_checks = nan_checks | div_checks +automatic_checks = float_checks | index_checks +all_checks = automatic_checks | user_checks + + +def checkify(f: Callable[..., Out], + errors: FrozenSet[ErrorCategory] = user_checks + ) -> Callable[..., Tuple[Error, Out]]: + """Functionalize `check` calls in `fun`, and optionally add run-time error checks. + + Run-time errors are either user-added :func:`~check` assertions, or + automatically added checks like NaN checks, depending on the ``errors`` + argument. + + The returned function will return an Error object `err` along with the output + of the original function. ``err.get()`` will either return ``None`` (if no + error occurred) or a string containing an error message. This error message + will correspond to the first error which occurred. ``err.throw()`` will raise + a ValueError with the error message if an error occurred. + + By default only user-added :func:`~check` assertions are enabled. You can + enable automatic checks through the ``errors`` argument. + + The automatic check sets which can be enabled, and when an error is generated: + - ``user_checks``: a :func:`~check` evaluated to False. + - ``nan_checks``: a floating-point operation generated a NaN value + as output. + - ``div_checks``: a division by zero. + - ``index_checks``: an index was out-of-bounds. + + Multiple categories can be enabled together by passing in an error `Set` (eg. + ``errors=nan_checks``). Multiple sets can be re-combined (eg. + ``errors=float_checks|user_checks``) + + Args: + fun: Callable which can contain user checks (see :func:`~check`). + errors: A set of ErrorCategory values which defines the set of enabled + checks. By default only explicit ``checks`` are enabled + (``user_checks``). You can also for example enable NAN and + DIV errors by passing the ``float_checks`` set, or for + example combine multiple sets through set operations + (``float_checks | user_checks``) + Returns: + A function which accepts the same arguments as ``fun`` and returns as output + a pair where the first element is an ``Error`` value, representing the first + failed :func:`~check`, and the second element is the original output of + ``fun``. + + For example: + + >>> import jax + >>> import jax.numpy as jnp + >>> from jax.experimental import checkify + >>> + >>> @jax.jit + ... def f(x): + ... y = jnp.sin(x) + ... return x+y + >>> err, out = checkify.checkify(f, errors=checkify.float_checks)(jnp.inf) + >>> err.throw() # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... 
+ jax._src.checkify.JaxRuntimeError: nan generated by primitive: sin + """ + @traceback_util.api_boundary + def checked_fun(*args, **kwargs): + # stage: + fun = lu.wrap_init(f) + flat_args, in_tree = jtu.tree_flatten((args, kwargs)) + flat_fun, out_tree = flatten_fun(fun, in_tree) + flat_avals = map(get_shaped_aval, flat_args) + jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, flat_avals) + out_tree = out_tree() + # checkify: + flat_args = jtu.tree_leaves((args, kwargs)) + error, out_flat = checkify_jaxpr(core.ClosedJaxpr(jaxpr, consts), errors, + init_error, *flat_args) + return error, jtu.tree_unflatten(out_tree, out_flat) + return checked_fun + +def check(pred: Bool, msg: str, *fmt_args, **fmt_kwargs) -> None: + """Check a predicate, add an error with msg if predicate is False. + + This is an effectful operation, and can't be staged (jitted/scanned/...). + Before staging a function with checks, :func:`~checkify` it! + + Args: + pred: if False, a FailedCheckError error is added. + msg: error message if error is added. Can be a format string. + fmt_args, fmt_kwargs: Positional and keyword formatting arguments for + `msg`, eg.: + ``check(.., "check failed on values {} and {named_arg}", x, named_arg=y)`` + Note that these arguments can be traced values allowing you to add + run-time values to the error message. + Note that tracking these run-time arrays will increase your memory usage, + even if no error happens. + + For example: + + >>> import jax + >>> import jax.numpy as jnp + >>> from jax.experimental import checkify + >>> def f(x): + ... checkify.check(x>0, "{x} needs to be positive!", x=x) + ... return 1/x + >>> checked_f = checkify.checkify(f) + >>> err, out = jax.jit(checked_f)(-3.) + >>> err.throw() # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + jax._src.checkify.JaxRuntimeError: -3. needs to be positive! + + """ + _check(pred, msg, False, *fmt_args, **fmt_kwargs) + +def _check(pred, msg, debug, *fmt_args, **fmt_kwargs): + if not is_scalar_pred(pred): + prim_name = 'debug_check' if debug else 'check' + raise TypeError(f'{prim_name} takes a scalar pred as argument, got {pred}') + new_error = FailedCheckError(summary(), msg, *fmt_args, **fmt_kwargs) + error = assert_func(init_error, jnp.logical_not(pred), new_error) + _check_error(error, debug=debug) + +def _check_error(error, *, debug=False): + error = tree_map(core.raise_as_much_as_possible, error) + if any(map(np.shape, error._pred.values())): + error = _reduce_any_error(error) + err_args, tree_def = tree_flatten(error) + + return check_p.bind(*err_args, err_tree=tree_def, debug=debug) + + +def is_scalar_pred(pred) -> bool: + return (isinstance(pred, bool) or + isinstance(pred, jnp.ndarray) and pred.shape == () and + pred.dtype == jnp.dtype('bool')) + + +def debug_check(pred: Bool, msg: str, *fmt_args, **fmt_kwargs) -> None: + """Check a predicate when running under checkify, otherwise is a no-op. + + A `debug_check` will only be run if it is transformed by :func:`~checkify`, + otherwise the check will be dropped. + + Args: + pred: if False, a FailedCheckError error is added. + msg: error message if error is added. + fmt_args, fmt_kwargs: Positional and keyword formatting arguments for + `msg`, eg.: + ``debug_check(.., "check failed on values {} and {named}", x, named=y)`` + Note that these arguments can be traced values allowing you to add + run-time values to the error message. + Note that tracking these run-time arrays will increase your memory usage, + even if no error happens. 
+ + For example: + + >>> import jax + >>> import jax.numpy as jnp + >>> from jax.experimental import checkify + >>> def f(x): + ... checkify.debug_check(x!=0, "cannot be zero!") + ... return x + >>> _ = f(0) # running without checkify means no debug_check is run. + >>> checked_f = checkify.checkify(f) + >>> err, out = jax.jit(checked_f)(0) # running with checkify runs debug_check. + >>> err.throw() # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + jax._src.checkify.JaxRuntimeError: cannot be zero! + + """ + _check(pred, msg, True, *fmt_args, **fmt_kwargs) + + +def check_error(error: Error) -> None: + """Raise an Exception if ``error`` represents a failure. Functionalized by :func:`~checkify`. + + The semantics of this function are equivalent to: + + >>> def check_error(err: Error) -> None: + ... err.throw() # can raise ValueError + + But unlike that implementation, ``check_error`` can be functionalized using + the :func:`~checkify` transformation. + + This function is similar to :func:`~check` but with a different signature: whereas + :func:`~check` takes as arguments a boolean predicate and a new error message + string, this function takes an ``Error`` value as argument. Both :func:`~check` + and this function raise a Python Exception on failure (a side-effect), and + thus cannot be staged out by :func:`~jax.jit`, :func:`~jax.pmap`, + :func:`~jax.lax.scan`, etc. Both also can + be functionalized by using :func:`~checkify`. + + But unlike :func:`~check`, this function is like a direct inverse of + :func:`~checkify`: + whereas :func:`~checkify` takes as input a function which + can raise a Python + Exception and produces a new function without that effect but which produces + an ``Error`` value as output, this ``check_error`` function can accept an + ``Error`` value as input and can produce the side-effect of raising an + Exception. That is, while :func:`~checkify` goes from + functionalizable Exception + effect to error value, this ``check_error`` goes from error value to + functionalizable Exception effect. + + ``check_error`` is useful when you want to turn checks represented by an + ``Error`` value (produced by functionalizing ``checks`` via + :func:`~checkify`) back into Python Exceptions. + + Args: + error: Error to check. + + For example, you might want to functionalize part of your program through + checkify, stage out your functionalized code through :func:`~jax.jit`, then + re-inject your error value outside of the :func:`~jax.jit`: + + >>> import jax + >>> from jax.experimental import checkify + >>> def f(x): + ... checkify.check(x>0, "must be positive!") + ... return x + >>> def with_inner_jit(x): + ... checked_f = checkify.checkify(f) + ... # a checkified function can be jitted + ... error, out = jax.jit(checked_f)(x) + ... checkify.check_error(error) + ... return out + >>> _ = with_inner_jit(1) # no failed check + >>> with_inner_jit(-1) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + jax._src.JaxRuntimeError: must be positive! + >>> # can re-checkify + >>> error, _ = checkify.checkify(with_inner_jit)(-1) + """ + if not isinstance(error, Error): + raise ValueError('check_error takes an Error as argument, ' + f'got type {type(error)} instead.') + _check_error(error, debug=False) +<<<<<<< HEAD + + +## check primitive + +check_p = core.Primitive('check') +check_p.multiple_results = True # zero results + +# TODO(lenamartens): inherit from Exception instead of ValueError. 
+class JaxRuntimeError(ValueError): + pass + +@check_p.def_impl +def check_impl(*args, err_tree, debug): + if debug: + # NOOP (check will only trigger when discharged) + return [] + error = tree_unflatten(err_tree, args) + exc = error.get_exception() + if exc: + raise JaxRuntimeError(str(exc)) from exc + return [] + +@check_p.def_effectful_abstract_eval +def check_abstract_eval(*args, err_tree, debug): + del debug + return [], set(tree_unflatten(err_tree, args)._pred.keys()) + +# TODO(lenamartens) add in-depth error explanation to link to in module docs. +functionalization_error = ValueError( + 'Cannot abstractly evaluate a checkify.check which was not' + ' functionalized. This probably means you tried to stage' + ' (jit/scan/pmap/...) a `check` without functionalizing it' + ' through `checkify.checkify`.' + ) + +def check_lowering_rule(ctx, *args, err_tree, debug): + if debug: + # NOOP (check will only trigger when discharged) + return [] + if not config.jax_experimental_unsafe_xla_runtime_errors: + raise functionalization_error + + out_op, _, keep_alive = mlir.emit_python_callback( + ctx, callback=functools.partial(python_err, err_tree), + token=None, + operands=args, + operand_avals=list(ctx.avals_in), + result_avals=list(ctx.avals_out), + has_side_effect=True) + ctx.module_context.add_keepalive(keep_alive) + return out_op + +def check_lowering_rule_unsupported(*a, debug, **k): + if debug: + return [] + raise functionalization_error + +def python_err(err_tree, *args): + error = tree_unflatten(err_tree, args) + _check_error(error) + return [] + +mlir.register_lowering(check_p, check_lowering_rule_unsupported, + platform='tpu') +mlir.register_lowering(check_p, check_lowering_rule, + platform='cpu') +mlir.register_lowering(check_p, check_lowering_rule, + platform='gpu') + +def check_batching_rule(batched_args, batch_dims, *, err_tree, debug): + size = next(x.shape[dim] for x, dim in zip(batched_args, batch_dims) + if dim is not batching.not_mapped) + batched_args = (batching.bdim_at_front(a, d, size) + for a, d in zip(batched_args, batch_dims)) + err = tree_unflatten(err_tree, batched_args) + _check_error(err, debug=debug) + return [], [] +batching.primitive_batchers[check_p] = check_batching_rule + +def check_jvp_rule(primals, _, *, err_tree, debug): + # Check primals, discard tangents. + check_p.bind(*primals, err_tree=err_tree, debug=debug) + return [], [] +ad.primitive_jvps[check_p] = check_jvp_rule + +## checkify rules + +def _get_current_traceback(skip_frames = 0) -> Optional[types.TracebackType]: + # TODO(lenamartens): use c++ version from XLA? 
+ tb = None + import inspect + for frame_info in inspect.stack(): + frame = frame_info.frame + if skip_frames: + skip_frames -= 1 + elif not traceback_util.include_frame(frame): + continue + else: + tb = types.TracebackType(tb, frame, frame.f_lasti, frame.f_lineno) + return tb + +def summary() -> str: + return str(source_info_util.summarize(source_info_util.current())) + +def nan_error_check(prim, error, enabled_errors, *in_vals, **params): + out = prim.bind(*in_vals, **params) + err = check_nans(prim, error, enabled_errors, out) + return out, err + +def check_nans(prim, error, enabled_errors, out): + if NaNError not in enabled_errors: + return error + + def isnan(x): + if isinstance(x, prng.PRNGKeyArray): + return False + return jnp.any(jnp.isnan(x)) + + any_nans = (jnp.any(jnp.array([isnan(x) for x in out])) + if prim.multiple_results else isnan(out)) + return assert_func(error, any_nans, NaNError(summary(), prim.name)) + + +# All primitives which can generate a NaN. +nan_primitives = [lax.acos_p, lax.acosh_p, lax.add_p, lax.asin_p, lax.asinh_p, + lax.atan2_p, lax.atan_p, lax.atanh_p, lax.bessel_i0e_p, + lax.bessel_i1e_p, lax.cbrt_p, lax.conv_general_dilated_p, + lax.cos_p, lax.cosh_p, lax.cumlogsumexp_p, lax.cummax_p, + lax.cummin_p, lax.cumprod_p, lax.cumsum_p, lax.digamma_p, + lax.dot_general_p, lax.erf_inv_p, lax.erf_p, lax.erfc_p, + lax.exp_p, lax.expm1_p, lax.fft_p, lax.igamma_grad_a_p, + lax.igamma_p, lax.igammac_p, lax.integer_pow_p, lax.lgamma_p, + lax.linear_solve_p, lax.log1p_p, lax.log_p, lax.logistic_p, + lax.mul_p, lax.pad_p, lax.pow_p, lax.psum_p, + lax.random_gamma_grad_p, lax.reduce_p, lax.reduce_prod_p, + lax.reduce_sum_p, lax.reduce_window_p, + lax.reduce_window_sum_p, lax.regularized_incomplete_beta_p, + lax.rem_p, lax.rng_uniform_p, lax.rsqrt_p, lax.sin_p, + lax.sinh_p, lax.sqrt_p, lax.sub_p, lax.tan_p, lax.tanh_p] + +for prim in nan_primitives: + error_checks[prim] = functools.partial(nan_error_check, prim) + + +def gather_error_check(error, enabled_errors, operand, start_indices, *, + dimension_numbers, slice_sizes, unique_indices, + indices_are_sorted, mode, fill_value): + out = lax.gather_p.bind( + operand, start_indices, dimension_numbers=dimension_numbers, + slice_sizes=slice_sizes, unique_indices=unique_indices, + indices_are_sorted=indices_are_sorted, mode=mode, fill_value=fill_value) + + if OOBError not in enabled_errors: + return out, error + + # compare to OOB masking logic in lax._gather_translation_rule + dnums = dimension_numbers + operand_dims = np.array(operand.shape) + num_batch_dims = len(start_indices.shape) - 1 + + upper_bound = operand_dims[np.array(dnums.start_index_map)] + upper_bound -= np.array(slice_sizes)[np.array(dnums.start_index_map)] + upper_bound = jnp.expand_dims(upper_bound, axis=tuple(range(num_batch_dims))) + oob_mask = (start_indices < 0) | (start_indices > upper_bound.astype(start_indices.dtype)) + + payload = oob_payload(oob_mask, start_indices, dnums.start_index_map, operand.shape) + return out, assert_func(error, jnp.any(oob_mask), OOBError(summary(), "gather", operand.shape, payload)) +error_checks[lax.gather_p] = gather_error_check + +def div_error_check(error, enabled_errors, x, y): + """Checks for division by zero and NaN.""" + if DivisionByZeroError in enabled_errors: + any_zero = jnp.any(jnp.equal(y, 0)) + error = assert_func(error, any_zero, DivisionByZeroError(summary())) + return nan_error_check(lax.div_p, error, enabled_errors, x, y) +error_checks[lax.div_p] = div_error_check + +def oob_payload(oob_mask, indices, 
+                dims_map, operand_shape):
+  # Get the first OOB index, axis and axis size so it can be added to the
+  # error message.
+  flat_idx = jnp.argmin(jnp.logical_not(oob_mask))
+  multi_idx = jnp.unravel_index(flat_idx, indices.shape)
+  oob_axis = jnp.array(dims_map)[multi_idx[-1]]
+  oob_axis_size = jnp.array(operand_shape)[oob_axis]
+  oob_index = jnp.ravel(indices)[flat_idx]
+  payload = jnp.array([oob_index, oob_axis, oob_axis_size], dtype=jnp.int32)
+  return payload
+
+def scatter_oob(operand, indices, updates, dnums):
+  # Ref: see clamping code used in scatter_translation_rule.
+  slice_sizes = []
+  pos = 0
+  for i in range(len(operand.shape)):
+    if i in dnums.inserted_window_dims:
+      slice_sizes.append(1)
+    else:
+      slice_sizes.append(updates.shape[dnums.update_window_dims[pos]])
+      pos += 1
+
+  upper_bound = np.array([operand.shape[i] - slice_sizes[i]
+                          for i in dnums.scatter_dims_to_operand_dims],
+                         np.int64)
+  upper_bound = np.minimum(upper_bound, np.iinfo(indices.dtype).max)
+  upper_bound = lax.broadcast_in_dim(upper_bound, indices.shape,
+                                     (len(indices.shape) - 1,))
+
+  lower_oob = jnp.less(indices, 0)
+  upper_oob = jnp.greater(indices, upper_bound.astype(indices.dtype))
+  oob_mask = jnp.logical_or(lower_oob, upper_oob)
+  payload = oob_payload(oob_mask, indices,
+                        dnums.scatter_dims_to_operand_dims, operand.shape)
+  return jnp.any(oob_mask), payload
+
+def scatter_error_check(prim, error, enabled_errors, operand, indices, updates,
+                        *, update_jaxpr, update_consts, dimension_numbers,
+                        indices_are_sorted, unique_indices, mode):
+  """Checks if indices are within bounds and update does not generate NaN."""
+  out = prim.bind(
+      operand, indices, updates, update_jaxpr=update_jaxpr,
+      update_consts=update_consts, dimension_numbers=dimension_numbers,
+      indices_are_sorted=indices_are_sorted, unique_indices=unique_indices,
+      mode=mode)
+
+  if OOBError not in enabled_errors:
+    return out, error
+
+  out_of_bounds, payload = scatter_oob(operand, indices, updates,
+                                       dimension_numbers)
+  oob_error = OOBError(summary(), prim.name, operand.shape, payload)
+  error = assert_func(error, out_of_bounds, oob_error)
+  return out, check_nans(prim, error, enabled_errors, out)
+error_checks[lax.scatter_p] = functools.partial(scatter_error_check,
+                                                lax.scatter_p)
+error_checks[lax.scatter_add_p] = functools.partial(scatter_error_check,
+                                                    lax.scatter_add_p)
+error_checks[lax.scatter_mul_p] = functools.partial(scatter_error_check,
+                                                    lax.scatter_mul_p)
+error_checks[lax.scatter_min_p] = functools.partial(scatter_error_check,
+                                                    lax.scatter_min_p)
+error_checks[lax.scatter_max_p] = functools.partial(scatter_error_check,
+                                                    lax.scatter_max_p)
+
+def cond_error_check(error, enabled_errors, index, *ops, branches, linear):
+  # Checkify each branch once to discover its error effects, then checkify
+  # again against the merged error so all branches share one error signature.
+  _, out_trees_and_effects = unzip2(
+      checkify_jaxpr(jxpr, error, enabled_errors) for jxpr in branches)
+  _, effects = unzip2(out_trees_and_effects)
+
+  merged_error = error._add_placeholder_effects(set().union(*effects))
+  new_branches, out_trees_and_effects = unzip2(
+      checkify_jaxpr(jxpr, merged_error, enabled_errors) for jxpr in branches)
+  out_trees, _ = unzip2(out_trees_and_effects)
+
+  flat_error, _ = tree_flatten(merged_error)
+  new_linear = (*[False] * len(flat_error), *linear)
+  err_and_outs = lax.cond_p.bind(
+      index, *flat_error, *ops,
+      branches=tuple(new_branches), linear=new_linear)
+
+  # We need to merge metadata across out_trees (a tuple). Maybe there's a
+  # better way to do this, but we can use the outs to unflatten all trees.
+  err0, *out = tree_unflatten(out_trees[0], err_and_outs)
+  merged_metadata = err0._metadata
+  for tr in out_trees[1:]:
+    err, *_ = tree_unflatten(tr, err_and_outs)
+    merged_metadata = {**merged_metadata, **err._metadata}
+  return out, err0._replace(_metadata=merged_metadata)
+error_checks[lax.cond_p] = cond_error_check
+
+def scan_error_check(error, enabled_errors, *in_flat, reverse, length, jaxpr,
+                     num_consts, num_carry, linear, unroll):
+  consts, carry, xs = split_list(in_flat, [num_consts, num_carry])
+  # Checkify the body once to discover its error effects, then checkify again
+  # against the merged error so the carried error has a stable structure.
+  _, (_, effects) = checkify_jaxpr(jaxpr, error, enabled_errors)
+  merged_error = error._add_placeholder_effects(effects)
+  checked_jaxpr_, (out_tree, _) = checkify_jaxpr(jaxpr, merged_error,
+                                                 enabled_errors)
+
+  flat_error_vals, _ = tree_flatten(merged_error)
+  tomove = ([False] * len(flat_error_vals) + [True] * len(consts)
+            + [False] * (len(carry) + len(xs)))
+  checked_jaxpr = pe.move_binders_to_front(checked_jaxpr_, tomove)
+  new_linear = (*[False] * len(flat_error_vals), *linear)
+  new_in_flat = [*consts, *flat_error_vals, *carry, *xs]
+  err_and_out = lax.scan_p.bind(
+      *new_in_flat, reverse=reverse, length=length, jaxpr=checked_jaxpr,
+      num_consts=len(consts), num_carry=len(carry)+len(flat_error_vals),
+      linear=new_linear, unroll=unroll)
+  err, *out = tree_unflatten(out_tree, err_and_out)
+  return out, err
+
+error_checks[lax.scan_p] = scan_error_check
+
+def checkify_while_body_jaxpr(cond_jaxpr, body_jaxpr, error, enabled_errors,
+                              c_consts):
+  cond_f = core.jaxpr_as_fun(cond_jaxpr)
+  body_f = core.jaxpr_as_fun(body_jaxpr)
+  def new_body_f(*vals):
+    out = body_f(*vals)
+    # This checks if the next cond application will error.
+    _ = cond_f(*c_consts, *out)
+    return out
+  return checkify_fun_to_jaxpr(lu.wrap_init(new_body_f), error, enabled_errors,
+                               body_jaxpr.in_avals)
+
+def ignore_error_output_jaxpr(jaxpr, num_error_vals):
+  """Constructs a checked jaxpr which does not output its error value."""
+  consts = jaxpr.consts
+  jaxpr = jaxpr.jaxpr
+  new_jaxpr = jaxpr.replace(outvars=jaxpr.outvars[num_error_vals:])
+  return core.ClosedJaxpr(new_jaxpr, consts)
+
+def while_loop_error_check(error, enabled_errors, *in_flat, cond_nconsts,
+                           cond_jaxpr, body_nconsts, body_jaxpr):
+  if cond_jaxpr.out_avals[0].shape:
+    # TODO(lenamartens, sharadmv): support batched while.
+    raise ValueError('Checkify does not support batched while-loops '
+                     '(checkify-of-vmap-of-while). \nHint: if possible, move '
+                     'the vmap to the outer level to get '
+                     'vmap-of-checkify-of-while.')
+
+  err_vals, _ = tree_flatten(error)
+  c_consts, b_consts, carry = split_list(in_flat, [cond_nconsts, body_nconsts])
+
+  # Check if the first cond application will error.
+  checked_cond_jaxpr, (cond_out_tree, _) = checkify_jaxpr(
+      cond_jaxpr, error, enabled_errors)
+  outs = core.jaxpr_as_fun(checked_cond_jaxpr)(*err_vals, *c_consts, *carry)
+  error, _ = tree_unflatten(cond_out_tree, outs)
+
+  checked_body_jaxpr_, (_, error_effects) = checkify_while_body_jaxpr(
+      cond_jaxpr, body_jaxpr, error, enabled_errors, c_consts)
+  # Merge the body's error effects into the carried error, then re-checkify
+  # the body against the merged error.
+  error = error._add_placeholder_effects(error_effects)
+  checked_body_jaxpr_, (body_out_tree, _) = checkify_while_body_jaxpr(
+      cond_jaxpr, body_jaxpr, error, enabled_errors, c_consts)
+  err_vals = jtu.tree_leaves(error)
+  num_error_vals = len(err_vals)
+  to_move = ([False] * num_error_vals + [True] * body_nconsts
+             + [False] * len(carry))
+  checked_body_jaxpr = pe.move_binders_to_front(checked_body_jaxpr_, to_move)
+
+  checked_cond_jaxpr, _ = checkify_jaxpr(cond_jaxpr, error, enabled_errors)
+  compat_cond_jaxpr_ = ignore_error_output_jaxpr(checked_cond_jaxpr,
+                                                 num_error_vals)
+  to_move = ([False] * num_error_vals + [True] * cond_nconsts
+             + [False] * len(carry))
+  compat_cond_jaxpr = pe.move_binders_to_front(compat_cond_jaxpr_, to_move)
+  new_in_flat = [*c_consts, *b_consts, *err_vals, *carry]
+
+  all_out_vals = lax.while_p.bind(
+      *new_in_flat, cond_nconsts=cond_nconsts, cond_jaxpr=compat_cond_jaxpr,
+      body_nconsts=body_nconsts, body_jaxpr=checked_body_jaxpr)
+  # body_out_tree already carries all of cond's metadata, because the checked
+  # body also runs cond; we only need to merge metadata on the input error.
+  error, *out = tree_unflatten(body_out_tree, all_out_vals)
+  return out, error
+error_checks[lax.while_p] = while_loop_error_check
+
+
+def pjit_error_check(error, enabled_errors, *vals_in, jaxpr,
+                     in_shardings, out_shardings, resource_env,
+                     donated_invars, name,
+                     in_positional_semantics, out_positional_semantics,
+                     keep_unused, inline):
+  checked_jaxpr, (out_tree, effects) = checkify_jaxpr(jaxpr, error,
+                                                      enabled_errors)
+  out_error = error._add_placeholder_effects(effects)
+
+  flat_error_vals = jtu.tree_leaves(error)
+  num_error_vals = len(flat_error_vals)
+  new_vals_in = [*flat_error_vals, *vals_in]
+
+  # Error values are replicated across the mesh.
+  sharding = OpShardingSharding.get_replicated(
+      list(resource_env.physical_mesh.devices.flat))
+  new_in_shardings = (*[sharding] * num_error_vals, *in_shardings)
+  new_out_shardings = (*[sharding] * len(jtu.tree_leaves(out_error)),
+                       *out_shardings)
+
+  if config.jax_array:
+    pos_sem = maps._PositionalSemantics.GLOBAL
+  else:
+    pos_sem = maps._positional_semantics.val
+
+  if not isinstance(in_positional_semantics, Iterable):
+    in_positional_semantics = (in_positional_semantics,)
+  if not isinstance(out_positional_semantics, Iterable):
+    out_positional_semantics = (out_positional_semantics,)
+  new_positional_sems_in = (*[pos_sem] * num_error_vals,
+                            *in_positional_semantics)
+  new_positional_sems_out = (*[pos_sem] * num_error_vals,
+                             *out_positional_semantics)
+  new_donated_invars = (*[False] * num_error_vals, *donated_invars)
+
+  err_and_out = pjit.pjit_p.bind(
+      *new_vals_in,
+      jaxpr=checked_jaxpr,
+      in_shardings=new_in_shardings,
+      out_shardings=new_out_shardings,
+      resource_env=resource_env,
+      donated_invars=new_donated_invars,
+      name=name,
+      in_positional_semantics=new_positional_sems_in,
+      out_positional_semantics=new_positional_sems_out,
+      keep_unused=keep_unused,
+      inline=inline)
+  err, *out = tree_unflatten(out_tree, err_and_out)
+  return out, err
+error_checks[pjit.pjit_p] = pjit_error_check
+
+
+def check_discharge_rule(error, enabled_errors, *args, err_tree, debug):
+  del debug
+  new_error = tree_unflatten(err_tree, args)
+  # Split new_error into an error to be functionalized if its category is
+  # included in enabled_errors (=discharged_error), and an error to be
+  # defunctionalized if it is not included (=recharged_error).
+  discharged_error = error
+  recharged_error = init_error
+  for error_effect in new_error._pred.keys():
+    pred = new_error._pred[error_effect]
+    code = new_error._code[error_effect]
+    payload = new_error._payload[error_effect]
+    if error_effect.error_type in enabled_errors:
+      discharged_error = update_error(discharged_error, pred, code, {}, payload,
+                                      error_effect)
+    else:
+      recharged_error = update_error(recharged_error, pred, code, {}, payload,
+                                     error_effect)
+
+  discharged_error = discharged_error._replace(
+      _metadata={**new_error._metadata, **discharged_error._metadata})
+  recharged_error = recharged_error._replace(_metadata=new_error._metadata)
+  # TODO(lenamartens): we actually need to recharge, but this would be a
+  # breaking API change so leaving for a follow-up.
+  # check_error(recharged_error)
+  return [], discharged_error
+error_checks[check_p] = check_discharge_rule
+
+
+## checkify api
+
+user_checks = frozenset({FailedCheckError})
+nan_checks = frozenset({NaNError})
+index_checks = frozenset({OOBError})
+div_checks = frozenset({DivisionByZeroError})
+float_checks = nan_checks | div_checks
+automatic_checks = float_checks | index_checks
+all_checks = automatic_checks | user_checks
+
+Out = TypeVar('Out')
+
+
+def checkify(fun: Callable[..., Out],
+             errors: FrozenSet[ErrorCategory] = user_checks
+             ) -> Callable[..., Tuple[Error, Out]]:
+  """Functionalize `check` calls in `fun`, and optionally add run-time error checks.
+
+  Run-time errors are either user-added :func:`~check` assertions, or
+  automatically added checks like NaN checks, depending on the ``errors``
+  argument.
+
+  The returned function will return an Error object `err` along with the output
+  of the original function. ``err.get()`` will either return ``None`` (if no
+  error occurred) or a string containing an error message. This error message
+  will correspond to the first error which occurred. ``err.throw()`` will raise
+  a ValueError with the error message if an error occurred.
+
+  By default only user-added :func:`~check` assertions are enabled. You can
+  enable automatic checks through the ``errors`` argument.
+
+  The automatic check sets that can be enabled, and the condition under which
+  each generates an error:
+    - ``user_checks``: a :func:`~check` evaluated to False.
+    - ``nan_checks``: a floating-point operation generated a NaN value
+      as output.
+    - ``div_checks``: a division by zero.
+    - ``index_checks``: an index was out-of-bounds.
+
+  Multiple categories can be enabled together by passing in an error ``Set``
+  (e.g. ``errors=nan_checks``). Multiple sets can be combined
+  (e.g. ``errors=float_checks | user_checks``).
+
+  Args:
+    fun: Callable which can contain user checks (see :func:`~check`).
+    errors: A set of ErrorCategory values which defines the set of enabled
+      checks. By default only explicit ``checks`` are enabled
+      (``user_checks``). You can also, for example, enable NaN and
+      division-by-zero errors by passing the ``float_checks`` set, or
+      combine multiple sets through set operations
+      (``float_checks | user_checks``).
+
+  Returns:
+    A function which accepts the same arguments as ``fun`` and returns as
+    output a pair where the first element is an ``Error`` value, representing
+    the first failed :func:`~check`, and the second element is the original
+    output of ``fun``.
+
+  For example:
+
+    >>> import jax
+    >>> import jax.numpy as jnp
+    >>> from jax.experimental import checkify
+    >>>
+    >>> @jax.jit
+    ... def f(x):
+    ...   y = jnp.sin(x)
+    ...   return x+y
+    >>> err, out = checkify.checkify(f, errors=checkify.float_checks)(jnp.inf)
+    >>> err.throw()  # doctest: +IGNORE_EXCEPTION_DETAIL
+    Traceback (most recent call last):
+      ...
+    jax._src.checkify.JaxRuntimeError: nan generated by primitive: sin
+  """
+  @traceback_util.api_boundary
+  def checked_fun(*args, **kwargs):
+    args_flat, in_tree = tree_flatten((args, kwargs))
+    f, out_tree = flatten_fun(lu.wrap_init(fun), in_tree)
+    error, out_flat = checkify_flat(f, errors, *args_flat)
+    out = tree_unflatten(out_tree(), out_flat)
+    return error, out
+  return checked_fun
diff --git a/tests/checkify_test.py b/tests/checkify_test.py
index c7838287a..1c8fb8caf 100644
--- a/tests/checkify_test.py
+++ b/tests/checkify_test.py
@@ -544,7 +544,7 @@ class CheckifyTransformTests(jtu.JaxTestCase):
       return f(jnp.array([jnp.inf]))[0]
     err, _ = g(2.)
     self.assertIsNotNone(err.get())
-    self.assertStartsWith(err.get(), "nan generated by primitive: sin")
+    self.assertIn("nan generated by primitive: sin", err.get())
 
   @jtu.skip_on_devices("tpu")
   def test_custom_jvp(self):
@@ -774,7 +774,6 @@ class AssertPrimitiveTests(jtu.JaxTestCase):
     jaxpr = jax.make_jaxpr(f)(jnp.ones(4, jnp.int32))
     self.assertSetEqual(jaxpr.effects,
                         {ErrorEffect(FailedCheckError, (
-                            jax.ShapeDtypeStruct((0,), jnp.int32),
                             jax.ShapeDtypeStruct((4,), jnp.int32),))})
     def g(x, y):
       checkify.check(False, "hi: {} {}", x, y)
@@ -783,7 +782,6 @@ class AssertPrimitiveTests(jtu.JaxTestCase):
         jax.make_jaxpr(g)(
             jnp.ones(4, jnp.int32), jnp.ones(2, jnp.float32)).effects,
         {ErrorEffect(FailedCheckError, (
-            jax.ShapeDtypeStruct((0,), jnp.int32),
             jax.ShapeDtypeStruct((4,), jnp.int32),
             jax.ShapeDtypeStruct((2,), jnp.float32)))})
 
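Note (not part of the patch): a minimal usage sketch of the public checkify API
whose docstring is updated above. Function and variable names here are
illustrative only.

    import jax
    import jax.numpy as jnp
    from jax.experimental import checkify

    @jax.jit
    def safe_log(x):
      # This check is functionalized by checkify.checkify: under jit it is
      # staged out and reported through the returned Error value instead of
      # raising eagerly.
      checkify.check(jnp.all(x > 0), "input must be positive")
      return jnp.log(x)

    # Enable user checks plus the automatic NaN / division-by-zero checks.
    checked_log = checkify.checkify(
        safe_log, errors=checkify.user_checks | checkify.float_checks)

    err, out = checked_log(jnp.array([1., 2., 3.]))
    err.throw()       # no error was recorded, so this is a no-op
    err, out = checked_log(jnp.array([-1., 2., 3.]))
    print(err.get())  # prints the message of the first recorded error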