# rocm_jax/jax/_src/checkify.py
#
# 1271 lines
# 49 KiB
# Python

# Copyright 2021 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import dataclasses
import functools
import itertools as it
import types
from typing import Union, Optional, Callable, Dict, Tuple, TypeVar, FrozenSet, Iterable, Type, Set, List
import jax
from jax import core
from jax import lax
from jax import linear_util as lu
from jax._src import prng
from jax._src import source_info_util
from jax._src import traceback_util
from jax._src.config import config
from jax._src.lax import control_flow as cf
from jax._src.sharding import OpShardingSharding
from jax._src.typing import Array
from jax._src.util import (as_hashable_function, unzip2, split_list, safe_map,
safe_zip)
from jax.api_util import flatten_fun
from jax.api_util import flatten_fun_nokwargs
from jax.experimental import maps
from jax.experimental import pjit
from jax.interpreters import ad
from jax.interpreters import batching
from jax.interpreters import mlir
from jax.interpreters import partial_eval as pe
from jax.tree_util import tree_flatten
from jax.tree_util import tree_map
from jax.tree_util import tree_unflatten
import jax.numpy as jnp
import jax.tree_util as jtu
import numpy as np
# Keep this module's own frames out of user-facing source info / tracebacks.
source_info_util.register_exclusion(__file__)
traceback_util.register_exclusion(__file__)

# Shadow the builtins with the length-checking versions from jax._src.util and
# keep the originals reachable as `unsafe_*`.
unsafe_map = map
map = safe_map
unsafe_zip = zip
zip = safe_zip

# Type aliases used in annotations throughout this module.
Bool = Union[bool, Array]
Int = Union[int, Array]
ErrorCategory = Type['JaxException']
Payload = List[Union[np.ndarray, Array]]
PyTreeDef = jtu.PyTreeDef
## Utils

def popattr(obj, attrname):
  """Delete attribute `attrname` from `obj` and return the value it had.

  Propagates the AttributeError from `getattr` if the attribute is absent.
  """
  value = getattr(obj, attrname)
  delattr(obj, attrname)
  return value
def setnewattr(obj, name, val):
  """Set `obj.<name> = val`, asserting that `name` did not resolve before."""
  # A fresh object can never be a real attribute value, so it marks "absent".
  absent = object()
  assert getattr(obj, name, absent) is absent
  setattr(obj, name, val)
# Concrete errors

class JaxException(Exception):
  """Python exception which can contain an error message with JAX run-time info.

  Every subclass is registered as a pytree node class on definition, so
  instances flatten into (array payload leaves, static metadata). The base
  class has no payload; its metadata is `traceback_info`.
  """

  def __init__(self, traceback_info):
    self.traceback_info = traceback_info
    # TODO(lenamartens): re-enable tracebacks when they don't leak tracers.
    # self.with_traceback(self.traceback_info)

  def __init_subclass__(cls):
    jtu.register_pytree_node_class(cls)

  def tree_flatten(self):
    children = []
    aux_data = self.traceback_info
    return children, aux_data

  @classmethod
  def tree_unflatten(cls, metadata, payload):
    del payload  # the base exception carries no array payload
    return cls(metadata)

  def get_effect_type(self) -> core.Effect:
    # The base class has no effect; concrete subclasses override this.
    return None
@functools.total_ordering
@dataclasses.dataclass(eq=True, frozen=True)
class ErrorEffect:
  """Effect signalling that `error_type` may be raised with this payload layout.

  Constructing an instance adds it to `cf.allowed_effects` and
  `mlir.lowerable_effects`. Effects are ordered by the string form of the
  error type, then by the payload shapes and dtype names.
  """
  error_type: Type[JaxException]
  shape_dtypes: Tuple[jax.ShapeDtypeStruct, ...]

  def __post_init__(self):
    cf.allowed_effects.add(self)
    mlir.lowerable_effects.add(self)

  def __lt__(self, other: 'ErrorEffect'):
    def sort_key(effect):
      # dtype is not comparable, so compare its string form instead.
      signature = tuple((sd.shape, str(sd.dtype)) for sd in effect.shape_dtypes)
      return (str(effect.error_type), signature)
    return sort_key(self) < sort_key(other)
class DivisionByZeroError(JaxException):
  """Error for a division whose divisor contains a zero."""

  def get_effect_type(self):
    # No payload beyond the traceback info, hence no shape/dtypes.
    return ErrorEffect(DivisionByZeroError, ())

  def __str__(self):
    return f'division by zero at {self.traceback_info}'
class NaNError(JaxException):
  """Error for a primitive whose output contains a NaN."""

  def __init__(self, traceback_info, primitive_name):
    super().__init__(traceback_info)
    self.prim = primitive_name

  def __str__(self):
    return f'nan generated by primitive: {self.prim} at {self.traceback_info}'

  def get_effect_type(self):
    return ErrorEffect(NaNError, ())

  # Pytree protocol: no array leaves; (traceback_info, prim) is static data.
  def tree_flatten(self):
    static = (self.traceback_info, self.prim)
    return [], static

  @classmethod
  def tree_unflatten(cls, metadata, _):
    traceback_info, primitive_name = metadata
    return cls(traceback_info, primitive_name)
class OOBError(JaxException):
  """Error for an indexing primitive that received an out-of-bounds index.

  The payload is a length-3 int32 array: [index, axis, axis size].
  """

  def __init__(self, traceback_info, primitive_name, operand_shape, payload):
    super().__init__(traceback_info)
    self.prim = primitive_name
    self.operand_shape = operand_shape
    self._payload = payload

  def get_effect_type(self):
    return ErrorEffect(OOBError, (jax.ShapeDtypeStruct((3,), jnp.int32),))

  def __str__(self):
    bad_index = self._payload[0]
    axis = self._payload[1]
    axis_size = self._payload[2]
    return (f'out-of-bounds indexing for array of shape {self.operand_shape}: '
            f'index {bad_index} is out of bounds for axis {axis} '
            f'with size {axis_size}. Failed at {self.traceback_info}')

  # Pytree protocol: the payload array is the only leaf.
  def tree_flatten(self):
    static = (self.traceback_info, self.prim, self.operand_shape)
    return [self._payload], static

  @classmethod
  def tree_unflatten(cls, metadata, payload):
    traceback_info, prim, operand_shape = metadata
    return cls(traceback_info, prim, operand_shape, payload[0])
class FailedCheckError(JaxException):
  """Error from a failing user check; `fmt_string` is formatted with the
  captured (possibly traced) positional and keyword arguments."""

  def __init__(self, traceback_info, fmt_string, *a, **k):
    super().__init__(traceback_info)
    self.fmt_string = fmt_string
    self.args = a
    self.kwargs = k

  def __str__(self):
    message = self.fmt_string.format(*self.args, **self.kwargs)
    return message + f' (check failed at {self.traceback_info})'

  def tree_flatten(self):
    # The leading empty int32 array matches the 0-size entry that
    # get_effect_type puts first in the effect's shape_dtypes.
    placeholder = jnp.array([], jnp.int32)
    children = (placeholder, self.args, self.kwargs)
    return children, (self.traceback_info, self.fmt_string)

  @classmethod
  def tree_unflatten(cls, metadata, payload):
    traceback_info, fmt_string = metadata
    _, args, kwargs = payload
    return cls(traceback_info, fmt_string, *args, **kwargs)

  def get_effect_type(self):
    leaves = jtu.tree_leaves((self.args, self.kwargs))
    arg_layouts = [jax.ShapeDtypeStruct(v.shape, v.dtype) for v in leaves]
    # Need a 0-size array here for data-dependence.
    empty = jax.ShapeDtypeStruct((0,), jnp.int32)
    return ErrorEffect(FailedCheckError, (empty, *arg_layouts))
@dataclasses.dataclass
class BatchedError(JaxException):
  """Errors from batched (non-scalar) predicates, keyed by batch index tuple."""
  error_mapping: Dict[Tuple[int, ...], JaxException]

  def __post_init__(self):
    # Report the traceback of the first recorded error.
    first = list(self.error_mapping.values())[0]
    super().__init__(first.traceback_info)

  def __str__(self):
    lines = []
    for idx, err in self.error_mapping.items():
      where = ", ".join(map(str, idx))
      lines.append(f'at mapped index {where}: {err}')
    return '\n'.join(lines)
# Error Value

@jtu.register_pytree_node_class
@dataclasses.dataclass(frozen=True)
class Error:
  """Functional error value threaded through a checkified computation.

  `_pred`, `_code` and `_payload` are keyed by ErrorEffect, with one entry per
  kind of error that may have occurred. `_metadata` maps an error code to the
  treedef of the JaxException that code stands for. Predicates are scalars,
  or arrays with leading batch dimensions when the error has been batched.
  """
  _pred: Dict[ErrorEffect, Bool]  # whether an error of this kind happened.
  _code: Dict[ErrorEffect, Int]  # code of the recorded error of this kind.
  _metadata: Dict[Int, PyTreeDef]  # mapping of code to JaxException treedef.
  _payload: Dict[ErrorEffect, Payload]  # array leaves of that JaxException.

  def get(self) -> Optional[str]:
    """Returns error message if error happened, None if no error happened."""
    exp = self.get_exception()
    if exp is not None:
      return str(exp)
    return None

  def get_exception(self) -> Optional[JaxException]:
    """Returns Python exception if error happened, None if no error happened.

    If several kinds of error are set, the one with the smallest code is
    returned. If the predicates are batched (non-scalar), a BatchedError
    covering every failing batch index is returned instead. When no batch
    element failed, the result is None.
    """
    if any(map(np.shape, self._pred.values())):
      return self._get_batched_exception()
    min_code = None
    cur_effect = None
    for error_effect, code in self._code.items():
      if self._pred[error_effect]:
        if min_code is None or code < min_code:
          min_code = code
          cur_effect = error_effect
    if cur_effect is None:
      return None
    return tree_unflatten(self._metadata[int(min_code)],  # type: ignore
                          self._payload[cur_effect])

  def throw(self):
    """Raise (through the check primitive) if this error holds a failure."""
    _check_error(self)

  def __str__(self):
    return f'Error({self.get()})'

  # Internal helpers

  def _get_batched_exception(self) -> Optional[BatchedError]:
    """Per-batch-index version of `get_exception` for batched predicates.

    Returns:
      A BatchedError mapping each failing batch index to its exception, or
      None if no batch element recorded an error.
    """
    shape = np.shape(list(self._pred.values())[0])
    error_mapping = {}
    for idx in np.ndindex(*shape):
      min_code = None
      cur_effect = None
      for error_effect, code in self._code.items():
        if self._pred[error_effect][idx]:  # type: ignore
          if min_code is None or code[idx] < min_code:
            min_code = code[idx]  # type: ignore
            cur_effect = error_effect
      if cur_effect is not None:
        payload = tree_map(lambda x, i=idx: x[i], self._payload[cur_effect])
        jax_error = tree_unflatten(self._metadata[int(min_code)], payload)  # type: ignore
        error_mapping[idx] = jax_error
    # BatchedError needs at least one entry (it reads the first error's
    # traceback), and "nothing failed" should map to None, as in the
    # unbatched path.
    if not error_mapping:
      return None
    return BatchedError(error_mapping)

  def _update(self, effect_type: ErrorEffect, pred, code, metadata, payload):
    """Return a copy with the entry for `effect_type` replaced and `metadata`
    merged in."""
    new_errs = {**self._pred, **{effect_type: pred}}  # type: ignore
    new_codes = {**self._code, **{effect_type: code}}  # type: ignore
    new_payload = {**self._payload, **{effect_type: payload}}  # type: ignore
    new_metadata = {**self._metadata, **metadata}
    return Error(new_errs, new_codes, new_metadata, new_payload)

  def _add_placeholder_effects(self, effects: Set[ErrorEffect]):
    """Fill out Error with `effects` and np.ones arrays of their payloads."""
    new_err = self._pred.copy()
    new_code = self._code.copy()
    new_payload = self._payload.copy()
    for effect in effects:
      if effect not in self._pred.keys():
        new_err[effect] = False
        new_payload[effect] = list(
            tree_map(lambda a: jnp.ones(a.shape, a.dtype), effect.shape_dtypes))
        # The error value associated with this effect will never become True, so
        # we don't need to set a meaningful code.
        new_code[effect] = -1
    return Error(new_err, new_code, self._metadata, new_payload)

  def _replace(self, *args, **kwargs):
    return dataclasses.replace(self, *args, **kwargs)

  # PyTree methods

  def tree_flatten(self):
    # Predicates, codes and payloads are leaves; the code->treedef metadata is
    # static auxiliary data.
    return ((self._pred, self._code, self._payload), (self._metadata))

  @classmethod
  def tree_unflatten(cls, metadata, data):
    pred, code, payload = data
    return cls(pred, code, metadata, payload)
# The initial (empty) error: no error kinds are tracked yet.
init_error = Error(_pred={}, _code={}, _metadata={}, _payload={})

# Each call returns the next integer code, starting at 1.
# Codes are globally unique ids (could be uuid4).
next_code = functools.partial(next, it.count(1))
def assert_func(error: Error, pred: Bool, new_error: JaxException) -> Error:
  """Record `new_error` in `error`, to be reported where `pred` is True.

  A fresh code is allocated on every call, and the exception's treedef is
  stored as metadata under that code.
  """
  code = next_code()
  effect_type = new_error.get_effect_type()
  payload, treedef = tree_flatten(new_error)
  return update_error(error, pred, code,
                      metadata={code: treedef},
                      payload=payload,
                      effect_type=effect_type)
def update_error(error, pred, code, metadata, payload, effect_type):
  """Merge a new (pred, code, payload) for `effect_type` into `error`.

  Where an error of this effect type was already recorded, the existing code
  and payload are kept; otherwise the new ones are used. The predicate
  becomes the OR of the old and new predicates.
  """
  already_failed = error._pred.get(effect_type, False)
  merged_pred = already_failed | pred
  merged_code = lax.select(already_failed, error._code.get(effect_type, -1), code)
  previous_payload = error._payload.get(effect_type, None)
  if previous_payload is None:
    merged_payload = payload
  else:
    keep_previous = functools.partial(lax.select, already_failed)
    merged_payload = tree_map(keep_previous, previous_payload, payload)
  return error._update(effect_type, merged_pred, merged_code, metadata,
                       merged_payload)
## Checkify transformation for plumbing functional error values.

class CheckifyTracer(core.Tracer):
  """Tracer of the checkify trace; wraps a single underlying value `val`."""

  def __init__(self, trace, val):
    self._trace = trace
    self.val = val

  @property
  def aval(self):
    return core.get_aval(self.val)

  def full_lower(self):
    return self
class CheckifyTrace(core.Trace):
  """Trace that threads a functional Error value through traced primitives.

  The current Error lives on the MainTrace as `self.main.error`. Primitives
  with a rule in `error_checks` compute their output and an updated Error.
  Call/map/custom-derivative primitives pass the Error into their
  subcomputation as extra leading flattened leaves and read it back from the
  outputs.
  """
  # Lifting a plain value just wraps it in a CheckifyTracer.
  pure = lift = lambda self, val: CheckifyTracer(self, val)

  def __init__(self, main: core.MainTrace, sublevel: core.Sublevel,
               enabled_errors: FrozenSet['ErrorCategory']) -> None:
    self.main = main
    self.level = main.level
    self.sublevel = sublevel
    # Kept on the main trace; rules read it as `self.main.enabled_errors`.
    self.main.enabled_errors = enabled_errors

  def sublift(self, tracer):
    return CheckifyTracer(self, tracer.val)

  def process_primitive(self, primitive, tracers, params):
    in_vals = [t.val for t in tracers]
    rule = error_checks.get(primitive)
    if rule:
      # A checkify rule binds the primitive itself and returns (out, new error).
      out, self.main.error = rule(self.main.error, self.main.enabled_errors,  # type: ignore
                                  *in_vals, **params)
    else:
      out = primitive.bind(*in_vals, **params)
    if primitive.multiple_results:
      return [CheckifyTracer(self, x) for x in out]
    else:
      return CheckifyTracer(self, out)

  def process_call(self, primitive, f, tracers, params):
    in_vals = [t.val for t in tracers]
    # Detach the current error; checkify_subtrace installs it on `main` while
    # `f` runs and returns the final error as the first output.
    e = popattr(self.main, 'error')
    flat_vals, in_tree = tree_flatten((e, *in_vals))
    f = checkify_subtrace(f, self.main)
    f, out_tree = flatten_fun_nokwargs(f, in_tree)
    if 'donated_invars' in params:
      # The prepended error leaves are never donated.
      params = dict(params, donated_invars=(*[False]*len(jtu.tree_leaves(e)),
                                            *params['donated_invars']))
    all_vals = primitive.bind(f, *flat_vals, **params)
    error, *out_vals = tree_unflatten(out_tree(), all_vals)
    setnewattr(self.main, 'error', error)
    return [CheckifyTracer(self, x) for x in out_vals]

  def process_map(self, primitive, f, tracers, params):
    in_vals = [t.val for t in tracers]
    e = popattr(self.main, 'error')
    flat_vals, in_tree = tree_flatten((e, *in_vals))
    num_error_vals = len(jtu.tree_leaves(e))
    f = checkify_subtrace(f, self.main)
    f, out_tree = flatten_fun_nokwargs(f, in_tree)

    @as_hashable_function(closure=params['out_axes_thunk'])
    def new_out_axes_thunk():
      out_val_axes = params['out_axes_thunk']()
      # Error outputs come first and are mapped along axis 0.
      out_err_num = out_tree().num_leaves - len(out_val_axes)
      return (*(0,)*out_err_num, *out_val_axes)

    # Error inputs are unmapped (in_axes None) and never donated.
    params_ = dict(params, in_axes=(*(None,)*num_error_vals, *params['in_axes']),
                   out_axes_thunk=new_out_axes_thunk,
                   donated_invars=(*(False,)*num_error_vals, *params['donated_invars']))
    all_vals = primitive.bind(f, *flat_vals, **params_)
    error, *out_vals = tree_unflatten(out_tree(), all_vals)
    # The error came back mapped along axis 0; collapse it to a single error.
    error = _reduce_any_error(error)
    setnewattr(self.main, 'error', error)
    return [CheckifyTracer(self, x) for x in out_vals]

  def post_process_call(self, primitive, tracers, params):
    vals = [t.val for t in tracers]
    main = self.main
    # Emit the error leaves ahead of the values; `todo` splits them back off
    # and reinstalls the error on `main`.
    e = popattr(main, 'error')
    err_leaves, err_tree = tree_flatten(e)
    setnewattr(main, 'err_tree', err_tree)
    def todo(vals):
      err_tree = popattr(main, 'err_tree')
      err_vals, vals = split_list(vals, [err_tree.num_leaves])
      setnewattr(main, 'error', tree_unflatten(err_tree, err_vals))
      trace = main.with_cur_sublevel()
      return [CheckifyTracer(trace, x) for x in vals]
    return (*err_leaves, *vals), todo

  def post_process_map(self, primitive, tracers, params):
    vals = [t.val for t in tracers]
    main = self.main
    e = popattr(main, 'error')
    err_leaves, err_tree = tree_flatten(e)
    num_err_leaves = len(err_leaves)
    setnewattr(main, 'err_tree', err_tree)
    def todo(vals):
      err_tree = popattr(main, 'err_tree')
      err_vals, vals = split_list(vals, [err_tree.num_leaves])
      error = tree_unflatten(err_tree, err_vals)
      # Collapse the mapped (axis 0) error, as in process_map.
      error = _reduce_any_error(error)
      setnewattr(main, 'error', error)
      trace = main.with_cur_sublevel()
      return [CheckifyTracer(trace, x) for x in vals]
    def out_axes_transform(out_axes):
      # Error leaves are prepended and mapped along axis 0.
      return (*(0,)*num_err_leaves, *out_axes)
    return (*err_leaves, *vals), (todo, out_axes_transform)

  def process_custom_jvp_call(self, prim, f, jvp, tracers):
    in_vals = [t.val for t in tracers]
    e = popattr(self.main, 'error')
    err_vals, err_tree = tree_flatten(e)
    flat_vals, in_tree = tree_flatten((e, *in_vals))
    num_error_vals = len(err_vals)
    f = checkify_subtrace(f, self.main)
    f, f_out_tree = flatten_fun_nokwargs(f, in_tree)
    jvp, jvp_err_tree = checkify_custom_jvp_subtrace(jvp, self.main,
                                                     num_error_vals, err_tree)
    all_outs = prim.bind(f, jvp, *flat_vals)
    # `fst` tells whether `f` (True) or the JVP rule (False) was traced.
    fst, out_tree = lu.merge_linear_aux(f_out_tree, jvp_err_tree)
    if fst:
      out_err, *out_vals = tree_unflatten(out_tree, all_outs)
    else:
      err_vals, out_vals = split_list(all_outs, [num_error_vals])
      # forward input error values to output
      out_err = tree_unflatten(out_tree, err_vals)
    setattr(self.main, 'error', out_err)
    return [CheckifyTracer(self, x) for x in out_vals]

  def post_process_custom_jvp_call(self, tracers, jvp_was_run):
    if jvp_was_run:
      msg = ('support for custom_jvp rules which close over checkify values is '
             'not implemented. If you see this, open an issue at '
             'https://github.com/google/jax/issues!')
      raise NotImplementedError(msg)
    vals = [t.val for t in tracers]
    main = self.main
    e = popattr(main, 'error')
    err_leaves, err_tree = tree_flatten(e)
    def todo(vals):
      err_vals, vals = split_list(vals, [len(err_leaves)])
      setnewattr(main, 'error', tree_unflatten(err_tree, err_vals))
      trace = main.with_cur_sublevel()
      return [CheckifyTracer(trace, x) for x in vals]
    return (*err_leaves, *vals), todo

  def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees):
    in_vals = [t.val for t in tracers]
    e = popattr(self.main, 'error')
    err_vals, err_tree = tree_flatten(e)
    flat_vals, in_tree = tree_flatten((e, *in_vals))
    num_error_vals = len(err_vals)
    fun = checkify_subtrace(fun, self.main)
    fun, fun_out_tree = flatten_fun_nokwargs(fun, in_tree)
    fwd, fwd_err_tree = checkify_custom_vjp_subtrace(fwd, self.main,
                                                     err_tree, num_error_vals)
    all_out_vals = prim.bind(fun, fwd, bwd, *flat_vals, out_trees=out_trees)
    # `fst` tells whether `fun` (True) or `fwd` (False) was traced.
    fst, out_tree = lu.merge_linear_aux(fun_out_tree, fwd_err_tree)
    if fst:
      error, *out = tree_unflatten(out_tree, all_out_vals)
    else:
      _, out = split_list(all_out_vals, [num_error_vals])
      # forward input error values to output
      error = tree_unflatten(err_tree, err_vals)
    setattr(self.main, 'error', error)
    return [CheckifyTracer(self, x) for x in out]
def _reduce_any_error(error: Error):
  """Collapse an Error mapped along a leading axis into a single Error.

  For each effect, the chosen entry is at the index that `jnp.argsort` places
  last. For boolean predicates, that is a True position whenever any entry is
  True. The original metadata is carried over.
  """
  reduced = init_error
  for effect in error._pred:
    preds = error._pred[effect]
    chosen = jnp.argsort(preds)[-1]
    take = lambda x, i=chosen: x[i]
    pred, code, payload = tree_map(
        take, (preds, error._code[effect], error._payload[effect]))
    reduced = reduced._update(effect, pred, code, {}, payload)
  return reduced._replace(_metadata=error._metadata)
# Signature of a checkify rule:
#   (Error, FrozenSet[ErrorCategory], *in_vals, **params) -> (out, Error)
ErrorCheckRule = Callable
# Registry of checkify rules, keyed by the primitive they handle.
error_checks: Dict[core.Primitive, ErrorCheckRule] = dict()
def checkify_flat(fun: lu.WrappedFun, enabled_errors: FrozenSet['ErrorCategory'],
                  *args):
  """Run `fun` on flat `args` under checkify, starting from `init_error`.

  Returns:
    (error, list of output values).
  """
  wrapped = checkify_traceable(checkify_subtrace(fun), enabled_errors)
  error, *out_values = wrapped.call_wrapped(init_error, *args)
  return error, out_values
@lu.transformation
def checkify_traceable(enabled_errors, error, *args):
  """Open a new CheckifyTrace main trace around the wrapped function.

  The MainTrace is prepended to the wrapped function's arguments. The local
  reference is deleted before the `with` block exits.
  """
  with core.new_main(CheckifyTrace, enabled_errors=enabled_errors) as main_trace:
    results = yield (main_trace, error, *args), {}
    del main_trace
  yield results
@lu.transformation
def checkify_subtrace(main, error, *args):
  """Run the wrapped function on CheckifyTracers at the current sublevel.

  `error` is installed as `main.error` while the function runs and removed
  afterwards. The final error is returned ahead of the output values.
  """
  setnewattr(main, 'error', error)
  trace = main.with_cur_sublevel()
  tracers_in = [CheckifyTracer(trace, arg) for arg in args]
  outs = yield tracers_in, {}
  out_vals = [trace.full_raise(o).val for o in outs]
  final_error = popattr(main, 'error')
  yield (final_error, *out_vals)
@lu.transformation_with_aux
def checkify_custom_jvp_subtrace(main, num_error_vals, out_tree, *args):
  # Like checkify_subtrace, but wraps the JVP rule of a custom_jvp, in the
  # context of jvp-of-checkify-of-custom_jvp. `args` is (primals, tangents)
  # flattened into one tuple, and the outputs must be flattened the same way.
  # Both halves lead with error leaves; the tangent error leaves are trivial.
  #
  # Types to keep in mind:
  #   jvp             : (a -> b) -> (a, T a) -> (b, T b)
  #   checkify        : (a -> b) -> a -> Err b
  #   jvp-of-checkify : (a -> b) -> (a, T a) -> (Err b, T (Err b))
  # Err is a pytree, so T (Err b) = Err' (T b), where the extra Err'
  # components are trivial (float0 dtype).
  #
  # No checks are added to the JVP rule itself; checking a JVP rule requires
  # checkify-of-jvp. The input error (and its trivial tangent) is simply
  # forwarded to the output, and `out_tree` is returned as the aux value.
  del main
  half, leftover = divmod(len(args), 2)
  assert not leftover
  primal_args, tangent_args = args[:half], args[half:]
  err_primals, primals = split_list(primal_args, [num_error_vals])
  err_tangents, tangents = split_list(tangent_args, [num_error_vals])
  outs = yield (*primals, *tangents), {}
  out_half, leftover = divmod(len(outs), 2)
  assert not leftover
  yield ((*err_primals, *outs[:out_half], *err_tangents, *outs[out_half:]),
         out_tree)
@lu.transformation_with_aux
def checkify_custom_vjp_subtrace(main, err_tree, num_error_vals, *args):
  """Wrap a custom_vjp fwd rule: strip the leading error leaves from its
  inputs and report `err_tree` as the aux value. No checks are added."""
  del main
  non_error_args = list(args[num_error_vals:])
  outs = yield non_error_args, {}
  yield outs, err_tree
@lu.transformation_with_aux
def query_error_effects(*args):
  """Pass outputs through unchanged; the aux value is the set of ErrorEffects
  present in the returned error (the first output)."""
  error, *rest = yield args, {}
  present_effects = set(error._pred)
  yield (error, *rest), present_effects
def checkify_jaxpr(jaxpr, error,
                   enabled_errors) -> Tuple[core.ClosedJaxpr,
                                            Tuple[PyTreeDef,
                                                  FrozenSet[ErrorEffect]]]:
  """Checkify a ClosedJaxpr by tracing it as a function; see
  `checkify_fun_to_jaxpr` for the returned values."""
  as_function = core.jaxpr_as_fun(jaxpr)
  return checkify_fun_to_jaxpr(lu.wrap_init(as_function), error,
                               enabled_errors, jaxpr.in_avals)
def checkify_fun_to_jaxpr(
    f, error, enabled_errors,
    in_avals) -> Tuple[core.ClosedJaxpr, Tuple[PyTreeDef, FrozenSet[ErrorEffect]]]:
  """Trace `f` under checkify into a ClosedJaxpr.

  The jaxpr takes the flattened leaves of `error`, followed by values matching
  `in_avals`. It returns the flattened (error, *outputs).

  Returns:
    (closed_jaxpr, (out_tree, error_effects)): `out_tree` is the treedef of
    (error, *outputs), and `error_effects` is the set of ErrorEffects present
    in the output error.
  """
  # Only the leaves are needed here; the input treedef is rebuilt below from
  # (error, *in_avals).
  flat_error_vals, _ = tree_flatten(error)
  f = checkify_subtrace(f)
  f = checkify_traceable(f, enabled_errors)
  f, error_effects = query_error_effects(f)
  in_tree = jtu.tree_structure((error, *in_avals))
  f, out_tree = flatten_fun_nokwargs(f, in_tree)
  err_avals = [core.raise_to_shaped(core.get_aval(x)) for x in flat_error_vals]
  jaxpr_out, _, literals_out = pe.trace_to_jaxpr_dynamic(
      f, [*err_avals, *in_avals])
  return (core.ClosedJaxpr(jaxpr_out, literals_out),
          (out_tree(), error_effects()))
def check(pred: Bool, msg: str, *fmt_args, **fmt_kwargs) -> None:
  """Check a predicate, add an error with msg if predicate is False.

  A check is a side effect, so it cannot be staged out on its own (by jit,
  scan, ...). Wrap the function in :func:`~checkify` before staging it.

  Args:
    pred: scalar boolean; when False, a FailedCheckError is added.
    msg: message of the added error. May be a format string.
    fmt_args, fmt_kwargs: positional and keyword arguments used to format
      `msg`, e.g.
      ``check(.., "check failed on values {} and {named_arg}", x, named_arg=y)``.
      They may be traced values, so run-time values can appear in the
      message. Tracking these run-time arrays costs memory even when no error
      happens.

  For example:

  >>> import jax
  >>> import jax.numpy as jnp
  >>> from jax.experimental import checkify
  >>> def f(x):
  ...   checkify.check(x>0, "{x} needs to be positive!", x=x)
  ...   return 1/x
  >>> checked_f = checkify.checkify(f)
  >>> err, out = jax.jit(checked_f)(-3.)
  >>> err.throw()  # doctest: +IGNORE_EXCEPTION_DETAIL
  Traceback (most recent call last):
    ...
  jax._src.checkify.JaxRuntimeError: -3. needs to be positive!
  """
  is_debug = False
  _check(pred, msg, is_debug, *fmt_args, **fmt_kwargs)
def _check(pred, msg, debug, *fmt_args, **fmt_kwargs):
  """Shared implementation of `check` (debug=False) and `debug_check`
  (debug=True).

  Raises:
    TypeError: if `pred` is not a scalar boolean, or if a leaf of the
      formatting arguments is not array-like (lacks `shape`/`dtype`).
  """
  prim_name = 'debug_check' if debug else 'check'
  if not is_scalar_pred(pred):
    raise TypeError(f'{prim_name} takes a scalar pred as argument, got {pred}')
  # Formatting arguments become array payload of the error, and their shapes
  # and dtypes define its effect type. Reject non-array leaves here with a
  # clear message instead of failing later with an AttributeError.
  for arg in jtu.tree_leaves((fmt_args, fmt_kwargs)):
    if not (hasattr(arg, 'shape') and hasattr(arg, 'dtype')):
      raise TypeError(f'Formatting arguments to {prim_name} need to be '
                      f'PyTrees of arrays, but got {arg!r} of type '
                      f'{type(arg)}.')
  new_error = FailedCheckError(summary(), msg, *fmt_args, **fmt_kwargs)
  error = assert_func(init_error, jnp.logical_not(pred), new_error)
  _check_error(error, debug=debug)
def _check_error(error, *, debug=False):
  """Bind check_p on the leaves of `error`, first collapsing batched preds."""
  err = tree_map(core.raise_as_much_as_possible, error)
  is_batched = any(np.shape(p) for p in err._pred.values())
  if is_batched:
    err = _reduce_any_error(err)
  leaves, treedef = tree_flatten(err)
  return check_p.bind(*leaves, err_tree=treedef, debug=debug)
def is_scalar_pred(pred) -> bool:
  """True for a Python bool or a 0-d boolean JAX array."""
  if isinstance(pred, bool):
    return True
  if not isinstance(pred, jnp.ndarray):
    return False
  return pred.shape == () and pred.dtype == jnp.dtype('bool')
def debug_check(pred: Bool, msg: str, *fmt_args, **fmt_kwargs) -> None:
  """Check a predicate when running under checkify, otherwise is a no-op.

  Without :func:`~checkify` the check is dropped. Under :func:`~checkify` it
  is checked like :func:`~check`.

  Args:
    pred: scalar boolean; when False, a FailedCheckError is added.
    msg: message of the added error.
    fmt_args, fmt_kwargs: positional and keyword arguments used to format
      `msg`, e.g.
      ``debug_check(.., "check failed on values {} and {named}", x, named=y)``.
      They may be traced values, so run-time values can appear in the
      message. Tracking these run-time arrays costs memory even when no error
      happens.

  For example:

  >>> import jax
  >>> import jax.numpy as jnp
  >>> from jax.experimental import checkify
  >>> def f(x):
  ...   checkify.debug_check(x!=0, "cannot be zero!")
  ...   return x
  >>> _ = f(0)  # running without checkify means no debug_check is run.
  >>> checked_f = checkify.checkify(f)
  >>> err, out = jax.jit(checked_f)(0)  # running with checkify runs debug_check.
  >>> err.throw()  # doctest: +IGNORE_EXCEPTION_DETAIL
  Traceback (most recent call last):
    ...
  jax._src.checkify.JaxRuntimeError: cannot be zero!
  """
  is_debug = True
  _check(pred, msg, is_debug, *fmt_args, **fmt_kwargs)
def check_error(error: Error) -> None:
  """Raise an Exception if ``error`` represents a failure. Functionalized by :func:`~checkify`.

  The semantics of this function are equivalent to:

  >>> def check_error(err: Error) -> None:
  ...   err.throw()  # can raise ValueError

  But unlike that implementation, ``check_error`` can be functionalized using
  the :func:`~checkify` transformation.

  This function is similar to :func:`~check` but with a different signature:
  whereas :func:`~check` takes as arguments a boolean predicate and a new
  error message string, this function takes an ``Error`` value as argument.
  Both :func:`~check` and this function raise a Python Exception on failure
  (a side-effect), and thus cannot be staged out by :func:`~jax.jit`,
  :func:`~jax.pmap`, :func:`~jax.lax.scan`, etc. Both also can be
  functionalized by using :func:`~checkify`.

  But unlike :func:`~check`, this function is like a direct inverse of
  :func:`~checkify`. :func:`~checkify` takes as input a function which can
  raise a Python Exception, and produces a new function without that effect
  that returns an ``Error`` value instead. ``check_error`` accepts an
  ``Error`` value as input and can produce the side-effect of raising an
  Exception. That is, while :func:`~checkify` goes from a functionalizable
  Exception effect to an error value, ``check_error`` goes from an error value
  to a functionalizable Exception effect.

  ``check_error`` is useful when you want to turn checks represented by an
  ``Error`` value (produced by functionalizing ``checks`` via
  :func:`~checkify`) back into Python Exceptions.

  Args:
    error: Error to check.

  Raises:
    ValueError: if ``error`` is not an ``Error``.
    JaxRuntimeError: (a ``ValueError`` subclass) when run eagerly and
      ``error`` holds a failure.

  For example, you might want to functionalize part of your program through
  checkify, stage out your functionalized code through :func:`~jax.jit`, then
  re-inject your error value outside of the :func:`~jax.jit`:

  >>> import jax
  >>> from jax.experimental import checkify
  >>> def f(x):
  ...   checkify.check(x>0, "must be positive!")
  ...   return x
  >>> def with_inner_jit(x):
  ...   checked_f = checkify.checkify(f)
  ...   # a checkified function can be jitted
  ...   error, out = jax.jit(checked_f)(x)
  ...   checkify.check_error(error)
  ...   return out
  >>> _ = with_inner_jit(1)  # no failed check
  >>> with_inner_jit(-1)  # doctest: +IGNORE_EXCEPTION_DETAIL
  Traceback (most recent call last):
    ...
  jax._src.checkify.JaxRuntimeError: must be positive!
  >>> # can re-checkify
  >>> error, _ = checkify.checkify(with_inner_jit)(-1)
  """
  if not isinstance(error, Error):
    raise ValueError('check_error takes an Error as argument, '
                     f'got type {type(error)} instead.')
  _check_error(error, debug=False)
## check primitive

# Primitive with no outputs; its operands are the flattened leaves of an Error.
check_p = core.Primitive('check')
check_p.multiple_results = True  # zero results


# TODO(lenamartens): inherit from Exception instead of ValueError.
class JaxRuntimeError(ValueError):
  """Raised eagerly by the check primitive when its Error holds a failure."""
@check_p.def_impl
def check_impl(*args, err_tree, debug):
  """Eager implementation: rebuild the Error and raise if it holds a failure.

  Debug checks do nothing here (check will only trigger when discharged).
  """
  if debug:
    return []
  exc = tree_unflatten(err_tree, args).get_exception()
  if exc is None:
    return []
  raise JaxRuntimeError(str(exc)) from exc
@check_p.def_effectful_abstract_eval
def check_abstract_eval(*args, err_tree, debug):
  """No outputs; the effects are the ErrorEffects tracked by the Error."""
  del debug
  rebuilt = tree_unflatten(err_tree, args)
  return [], set(rebuilt._pred)
# TODO(lenamartens) add in-depth error explanation to link to in module docs.
# Raised when a non-debug check reaches lowering without being functionalized.
functionalization_error = ValueError(
    'Cannot abstractly evaluate a checkify.check which was not functionalized. '
    'This probably means you tried to stage (jit/scan/pmap/...) a `check` '
    'without functionalizing it through `checkify.checkify`.'
)
def check_lowering_rule(ctx, *args, err_tree, debug):
  """Lower check_p to a side-effecting Python callback that runs the check.

  Allowed only when `config.jax_experimental_unsafe_xla_runtime_errors` is
  set; otherwise `functionalization_error` is raised. Debug checks lower to
  nothing.
  """
  if debug:
    # NOOP (check will only trigger when discharged)
    return []
  if not config.jax_experimental_unsafe_xla_runtime_errors:
    # `functionalization_error` is a shared module-level instance. Clear any
    # traceback left by an earlier raise, so tracebacks (and the frames they
    # reference) don't accumulate across raises.
    raise functionalization_error.with_traceback(None)
  out_op, _, keep_alive = mlir.emit_python_callback(
      ctx, callback=functools.partial(python_err, err_tree),
      token=None,
      operands=args,
      operand_avals=list(ctx.avals_in),
      result_avals=list(ctx.avals_out),
      has_side_effect=True)
  ctx.module_context.add_keepalive(keep_alive)
  return out_op
def check_lowering_rule_unsupported(*a, debug, **k):
  """Lowering for platforms without runtime-error support.

  Debug checks lower to nothing; any other check raises
  `functionalization_error`.
  """
  if debug:
    return []
  # Reset the shared instance's traceback so repeated raises don't chain onto
  # (and keep alive) earlier tracebacks.
  raise functionalization_error.with_traceback(None)
def python_err(err_tree, *args):
  """Host-callback body: rebuild the Error from its leaves and check it."""
  _check_error(tree_unflatten(err_tree, args))
  no_results = []
  return no_results
# TPU uses the "unsupported" rule; CPU and GPU use the callback-based rule.
for _platform, _rule in (('tpu', check_lowering_rule_unsupported),
                         ('cpu', check_lowering_rule),
                         ('gpu', check_lowering_rule)):
  mlir.register_lowering(check_p, _rule, platform=_platform)
del _platform, _rule
def check_batching_rule(batched_args, batch_dims, *, err_tree, debug):
  """vmap rule: move every batch dim to the front, then check the batched
  Error. The primitive has no outputs."""
  axis_size = next(arg.shape[bd] for arg, bd in zip(batched_args, batch_dims)
                   if bd is not batching.not_mapped)
  leaves = [batching.bdim_at_front(arg, bd, axis_size)
            for arg, bd in zip(batched_args, batch_dims)]
  _check_error(tree_unflatten(err_tree, leaves), debug=debug)
  outs, out_dims = [], []
  return outs, out_dims


batching.primitive_batchers[check_p] = check_batching_rule
def check_jvp_rule(primals, _, *, err_tree, debug):
  """JVP rule: re-bind the check on the primals only and ignore tangents.

  The primitive has no outputs, so both returned lists are empty.
  """
  check_p.bind(*primals, err_tree=err_tree, debug=debug)
  primal_outs, tangent_outs = [], []
  return primal_outs, tangent_outs


ad.primitive_jvps[check_p] = check_jvp_rule
## checkify rules
def _get_current_traceback(skip_frames = 0) -> Optional[types.TracebackType]:
# TODO(lenamartens): use c++ version from XLA?
tb = None
import inspect
for frame_info in inspect.stack():
frame = frame_info.frame
if skip_frames:
skip_frames -= 1
elif not traceback_util.include_frame(frame):
continue
else:
tb = types.TracebackType(tb, frame, frame.f_lasti, frame.f_lineno)
return tb
def summary() -> str:
  """Summarize the current source location (from source_info_util) as a string."""
  current_info = source_info_util.current()
  return str(source_info_util.summarize(current_info))
def nan_error_check(prim, error, enabled_errors, *in_vals, **params):
  """Generic checkify rule: bind `prim`, then check its output(s) for NaNs."""
  result = prim.bind(*in_vals, **params)
  return result, check_nans(prim, error, enabled_errors, result)
def check_nans(prim, error, enabled_errors, out):
  """Add a NaNError to `error` if `out` contains a NaN.

  `out` is a single value, or a sequence of values when
  `prim.multiple_results`. PRNG key arrays are never treated as NaN. Returns
  `error` unchanged if NaNError is not enabled.
  """
  if NaNError not in enabled_errors:
    return error

  def has_nan(x):
    if isinstance(x, prng.PRNGKeyArray):
      return False
    return jnp.any(jnp.isnan(x))

  if prim.multiple_results:
    any_nans = jnp.any(jnp.array([has_nan(x) for x in out]))
  else:
    any_nans = has_nan(out)
  return assert_func(error, any_nans, NaNError(summary(), prim.name))
# All primitives which can generate a NaN. Each gets the generic
# `nan_error_check` rule specialized to that primitive.
nan_primitives = [
    lax.acos_p, lax.acosh_p, lax.add_p, lax.asin_p, lax.asinh_p, lax.atan2_p,
    lax.atan_p, lax.atanh_p, lax.bessel_i0e_p, lax.bessel_i1e_p, lax.cbrt_p,
    lax.conv_general_dilated_p, lax.cos_p, lax.cosh_p, lax.cumlogsumexp_p,
    lax.cummax_p, lax.cummin_p, lax.cumprod_p, lax.cumsum_p, lax.digamma_p,
    lax.dot_general_p, lax.erf_inv_p, lax.erf_p, lax.erfc_p, lax.exp_p,
    lax.expm1_p, lax.fft_p, lax.igamma_grad_a_p, lax.igamma_p, lax.igammac_p,
    lax.integer_pow_p, lax.lgamma_p, lax.linear_solve_p, lax.log1p_p,
    lax.log_p, lax.logistic_p, lax.mul_p, lax.pad_p, lax.pow_p, lax.psum_p,
    lax.random_gamma_grad_p, lax.reduce_p, lax.reduce_prod_p,
    lax.reduce_sum_p, lax.reduce_window_p, lax.reduce_window_sum_p,
    lax.regularized_incomplete_beta_p, lax.rem_p, lax.rng_uniform_p,
    lax.rsqrt_p, lax.sin_p, lax.sinh_p, lax.sqrt_p, lax.sub_p, lax.tan_p,
    lax.tanh_p,
]

for prim in nan_primitives:
  error_checks[prim] = functools.partial(nan_error_check, prim)
def gather_error_check(error, enabled_errors, operand, start_indices, *,
                       dimension_numbers, slice_sizes, unique_indices,
                       indices_are_sorted, mode, fill_value):
  """Checkify rule for gather: bind it, then flag out-of-bounds start indices.

  The bounds mirror the OOB masking logic in lax._gather_translation_rule.
  """
  out = lax.gather_p.bind(
      operand, start_indices, dimension_numbers=dimension_numbers,
      slice_sizes=slice_sizes, unique_indices=unique_indices,
      indices_are_sorted=indices_are_sorted, mode=mode, fill_value=fill_value)
  if OOBError not in enabled_errors:
    return out, error

  index_map = dimension_numbers.start_index_map
  gathered_axes = np.array(index_map)
  # Largest valid start index per indexed axis: axis size minus slice size.
  max_start = (np.array(operand.shape)[gathered_axes]
               - np.array(slice_sizes)[gathered_axes])
  batch_axes = tuple(range(len(start_indices.shape) - 1))
  max_start = jnp.expand_dims(max_start, axis=batch_axes)
  oob_mask = ((start_indices < 0) |
              (start_indices > max_start.astype(start_indices.dtype)))
  payload = oob_payload(oob_mask, start_indices, index_map, operand.shape)
  any_oob = jnp.any(oob_mask)
  oob_error = OOBError(summary(), "gather", operand.shape, payload)
  return out, assert_func(error, any_oob, oob_error)


error_checks[lax.gather_p] = gather_error_check
def div_error_check(error, enabled_errors, x, y):
  """Checks for division by zero and NaN."""
  checked = error
  if DivisionByZeroError in enabled_errors:
    divisor_has_zero = jnp.any(jnp.equal(y, 0))
    checked = assert_func(checked, divisor_has_zero,
                          DivisionByZeroError(summary()))
  return nan_error_check(lax.div_p, checked, enabled_errors, x, y)


error_checks[lax.div_p] = div_error_check
def oob_payload(oob_mask, indices, dims_map, operand_shape):
  """Describes the first out-of-bounds index for use in an error message.

  Args:
    oob_mask: boolean array marking out-of-bounds entries; it is unraveled with
      ``indices.shape``, so it is expected to have the same shape as
      ``indices``.
    indices: index array whose last axis selects operand axes via ``dims_map``.
    dims_map: sequence mapping positions along the last axis of ``indices`` to
      operand axes.
    operand_shape: shape of the indexed operand.

  Returns:
    An int32 array ``[index, axis, axis_size]`` for the first (row-major) True
    entry of ``oob_mask``. If no entry is True, flat position 0 is reported, so
    the payload is only meaningful when the mask has a True entry.
  """
  # On a boolean array, argmax yields the first True position (0 if none).
  first = jnp.argmax(oob_mask)
  position = jnp.unravel_index(first, indices.shape)
  axis = jnp.array(dims_map)[position[-1]]
  axis_size = jnp.array(operand_shape)[axis]
  bad_index = jnp.reshape(indices, (-1,))[first]
  return jnp.array([bad_index, axis, axis_size], dtype=jnp.int32)
def scatter_oob(operand, indices, updates, dnums):
  """Computes the out-of-bounds predicate and payload for a scatter.

  An index along an operand axis is treated as in bounds iff
  ``0 <= index <= axis_size - slice_size``. The slice size is 1 for inserted
  window dims; otherwise it is the size of the matching update window dim.

  Returns:
    ``(any_oob, payload)``: a boolean that is True if any index is out of
    bounds, and the payload built by ``oob_payload`` for the first such index.
  """
  # Ref: see clamping code used in scatter_translation_rule
  slice_sizes = []
  window_pos = 0
  for dim in range(len(operand.shape)):
    if dim not in dnums.inserted_window_dims:
      slice_sizes.append(updates.shape[dnums.update_window_dims[window_pos]])
      window_pos += 1
    else:
      slice_sizes.append(1)

  limits = np.array(
      [operand.shape[d] - slice_sizes[d]
       for d in dnums.scatter_dims_to_operand_dims],
      np.int64)
  # Keep the bound representable in the index dtype before casting below.
  limits = np.minimum(limits, np.iinfo(indices.dtype).max)
  # Broadcast along the trailing axis of `indices`, which selects operand axes.
  limits = lax.broadcast_in_dim(limits, indices.shape,
                                (len(indices.shape) - 1,))

  oob_mask = jnp.logical_or(
      jnp.less(indices, 0),
      jnp.greater(indices, limits.astype(indices.dtype)))
  payload = oob_payload(oob_mask, indices,
                        dnums.scatter_dims_to_operand_dims, operand.shape)
  return jnp.any(oob_mask), payload
def scatter_error_check(prim, error, enabled_errors, operand, indices, updates,
                        *, update_jaxpr, update_consts, dimension_numbers,
                        indices_are_sorted, unique_indices, mode):
  """Checks if indices are within bounds and update does not generate NaN.

  This is the shared checkify rule for the scatter primitives registered
  below; ``prim`` is the concrete primitive being checked.

  Returns:
    A pair ``(out, error)`` of the scatter output and the updated error.
  """
  out = prim.bind(
      operand, indices, updates, update_jaxpr=update_jaxpr,
      update_consts=update_consts, dimension_numbers=dimension_numbers,
      indices_are_sorted=indices_are_sorted, unique_indices=unique_indices,
      mode=mode)
  if OOBError in enabled_errors:
    out_of_bounds, payload = scatter_oob(operand, indices, updates,
                                         dimension_numbers)
    oob_error = OOBError(summary(), prim.name, operand.shape, payload)
    error = assert_func(error, out_of_bounds, oob_error)
  # The NaN check must not depend on whether OOB checks are enabled. An early
  # return would silently skip NaN checks on scatters when only `nan_checks`
  # are requested. check_nans receives `enabled_errors` and is presumably a
  # no-op when NaNError is disabled -- confirm against its definition.
  return out, check_nans(prim, error, enabled_errors, out)
error_checks[lax.scatter_p] = functools.partial(scatter_error_check,
                                                lax.scatter_p)
error_checks[lax.scatter_add_p] = functools.partial(scatter_error_check,
                                                    lax.scatter_add_p)
error_checks[lax.scatter_mul_p] = functools.partial(scatter_error_check,
                                                    lax.scatter_mul_p)
error_checks[lax.scatter_min_p] = functools.partial(scatter_error_check,
                                                    lax.scatter_min_p)
error_checks[lax.scatter_max_p] = functools.partial(scatter_error_check,
                                                    lax.scatter_max_p)
def cond_error_check(error, enabled_errors, index, *ops, branches, linear):
  """Checkify rule for ``lax.cond_p``.

  Every branch is checkified twice. The first pass only collects the error
  effects each branch can raise. All branches are then re-checkified against a
  single ``merged_error`` that has placeholders for the union of those
  effects, so every checked branch takes and returns an error with the same
  structure (presumably required for the branches of one cond -- confirm).

  Returns:
    A pair ``(out, err)`` of the cond outputs and the resulting error, whose
    metadata is merged across all branches.
  """
  # First pass: keep only the effects each branch can raise.
  _, out_trees_and_effects = unzip2(checkify_jaxpr(jxpr, error,
                                                   enabled_errors)
                                    for jxpr in branches)
  _, effects = unzip2(out_trees_and_effects)
  merged_error = error._add_placeholder_effects(set().union(*effects))
  # Second pass: re-checkify every branch against the merged error.
  new_branches, out_trees_and_effects = unzip2(checkify_jaxpr(jxpr, merged_error,
                                                             enabled_errors)
                                               for jxpr in branches)
  out_trees, _ = unzip2(out_trees_and_effects)
  flat_error, _ = tree_flatten(merged_error)
  # Error leaves are passed before the operands and are never linear.
  new_linear = (*[False] * len(flat_error), *linear)
  err_and_outs = lax.cond_p.bind(
      index, *flat_error, *ops,
      branches=tuple(new_branches), linear=new_linear)
  # we need to merge metadata across out_trees (a tuple)
  # maybe there's a better way to do this, but we can use the outs
  # to unflatten all trees.
  err0, *out = tree_unflatten(out_trees[0], err_and_outs)
  merged_metadata = err0._metadata
  for tr in out_trees[1:]:
    err, *_ = tree_unflatten(tr, err_and_outs)
    # Entries from later branches win on key collisions.
    merged_metadata = {**merged_metadata, **err._metadata}
  return out, err0._replace(_metadata=merged_metadata)
error_checks[lax.cond_p] = cond_error_check
def scan_error_check(error, enabled_errors, *in_flat, reverse, length, jaxpr,
                     num_consts, num_carry, linear, unroll):
  """Checkify rule for ``lax.scan_p``.

  The error is threaded through the scan as extra leading carry values. The
  body is checkified once to discover its error effects, then again against an
  error with placeholders for those effects. This keeps the error carry's
  structure the same on the input and output of each iteration.

  Returns:
    A pair ``(out, err)`` of the scan outputs and the resulting error.
  """
  consts, carry, xs = split_list(in_flat, [num_consts, num_carry])
  # First pass only discovers which error effects the body can raise.
  _, (_, effects) = checkify_jaxpr(jaxpr, error, enabled_errors)
  merged_error = error._add_placeholder_effects(effects)
  checked_jaxpr_, (out_tree, _) = checkify_jaxpr(jaxpr, merged_error,
                                                 enabled_errors)
  flat_error_vals, _ = tree_flatten(merged_error)
  num_error_vals = len(flat_error_vals)
  # The checked jaxpr's binders are [*error, *consts, *carry, *xs]. Move the
  # consts to the front so the order matches new_in_flat below.
  tomove = ([False] * num_error_vals + [True] * len(consts)
            + [False] * (len(carry) + len(xs)))
  checked_jaxpr = pe.move_binders_to_front(checked_jaxpr_, tomove)
  new_in_flat = [*consts, *flat_error_vals, *carry, *xs]
  # `linear` has one flag per original operand (consts, carry, xs). The error
  # values are inserted *after* the consts, so their (non-linear) flags must be
  # inserted at the same position to keep every flag aligned with its operand.
  new_linear = (*linear[:num_consts], *[False] * num_error_vals,
                *linear[num_consts:])
  err_and_out = lax.scan_p.bind(
      *new_in_flat, reverse=reverse, length=length, jaxpr=checked_jaxpr,
      num_consts=len(consts), num_carry=len(carry) + num_error_vals,
      linear=new_linear, unroll=unroll)
  err, *out = tree_unflatten(out_tree, err_and_out)
  return out, err
error_checks[lax.scan_p] = scan_error_check
def checkify_while_body_jaxpr(cond_jaxpr, body_jaxpr, error, enabled_errors, c_consts):
  """Checkifies a while-loop body that also re-evaluates the loop condition.

  The traced function runs ``body_jaxpr`` and then applies ``cond_jaxpr`` (with
  ``c_consts``) to the body's outputs, discarding the predicate. Any error the
  *next* condition evaluation would raise is therefore recorded inside the
  body.

  Returns:
    Whatever ``checkify_fun_to_jaxpr`` returns for the combined function.
  """
  run_cond = core.jaxpr_as_fun(cond_jaxpr)
  run_body = core.jaxpr_as_fun(body_jaxpr)

  def body_then_cond(*args):
    new_carry = run_body(*args)
    # The predicate is unused; only the errors it may raise matter here.
    run_cond(*c_consts, *new_carry)
    return new_carry

  return checkify_fun_to_jaxpr(lu.wrap_init(body_then_cond), error,
                               enabled_errors, body_jaxpr.in_avals)
def ignore_error_output_jaxpr(jaxpr, num_error_vals):
"""Constructs a checked jaxpr which does not output its error value."""
consts = jaxpr.consts
jaxpr = jaxpr.jaxpr
new_jaxpr = jaxpr.replace(outvars=jaxpr.outvars[num_error_vals:])
return core.ClosedJaxpr(new_jaxpr, consts)
def while_loop_error_check(error, enabled_errors, *in_flat, cond_nconsts,
                           cond_jaxpr, body_nconsts, body_jaxpr):
  """Checkify rule for ``lax.while_p``.

  The error is threaded through the loop as extra leading carry values, so the
  operands become ``[*c_consts, *b_consts, *err_vals, *carry]``.

  The loop condition is checked as follows:
    * once up front, by evaluating the checkified condition on the initial
      carry (covers the first condition evaluation);
    * inside the body, via ``checkify_while_body_jaxpr``, which re-applies the
      condition to each new carry (covers later evaluations);
    * the condition jaxpr bound in the loop is checkified too, but its error
      outputs are dropped by ``ignore_error_output_jaxpr`` so that it only
      returns the predicate.

  Raises:
    ValueError: if the condition's output is not a scalar (a batched
      while-loop), which is not supported.

  Returns:
    A pair ``(out, error)`` of the loop outputs and the resulting error.
  """
  if cond_jaxpr.out_avals[0].shape:
    # TODO(lenamartens, sharadmv): support batched while.
    raise ValueError('Checkify does not support batched while-loops '
                     '(checkify-of-vmap-of-while). \nHint: if possible, move '
                     'the vmap to the outer level to get '
                     'vmap-of-checkify-of-while.')
  err_vals, _ = tree_flatten(error)
  c_consts, b_consts, carry = split_list(in_flat, [cond_nconsts, body_nconsts])
  # Check if the first cond application will error.
  checked_cond_jaxpr, (cond_out_tree, _) = checkify_jaxpr(
      cond_jaxpr, error, enabled_errors)
  outs = core.jaxpr_as_fun(checked_cond_jaxpr)(*err_vals, *c_consts, *carry)
  error, _ = tree_unflatten(cond_out_tree, outs)
  # First pass over the body only discovers which error effects it can raise;
  # the jaxpr produced here is discarded.
  checked_body_jaxpr_, (_, error_effects) = checkify_while_body_jaxpr(
      cond_jaxpr, body_jaxpr, error, enabled_errors, c_consts)
  # Add placeholders for the body's effects so the error carried through the
  # loop has the same structure on entry and exit of each iteration.
  error = error._add_placeholder_effects(error_effects)
  checked_body_jaxpr_, (body_out_tree, _) = checkify_while_body_jaxpr(
      cond_jaxpr, body_jaxpr, error, enabled_errors, c_consts)
  err_vals = jtu.tree_leaves(error)
  num_error_vals = len(err_vals)
  # Body binders are [*err, *b_consts, *carry]; move the consts to the front.
  to_move = [False] * num_error_vals + [True] * body_nconsts + [False] * len(carry)
  checked_body_jaxpr = pe.move_binders_to_front(checked_body_jaxpr_, to_move)
  checked_cond_jaxpr, _ = checkify_jaxpr(cond_jaxpr, error, enabled_errors)
  # Drop the error outputs so the bound condition only returns the predicate.
  compat_cond_jaxpr_ = ignore_error_output_jaxpr(checked_cond_jaxpr, num_error_vals)
  # Cond binders are [*err, *c_consts, *carry]; move the consts to the front.
  to_move = [False] * num_error_vals + [True] * cond_nconsts + [False] * len(carry)
  compat_cond_jaxpr = pe.move_binders_to_front(compat_cond_jaxpr_, to_move)
  new_in_flat = [*c_consts, *b_consts, *err_vals, *carry]
  all_out_vals = lax.while_p.bind(
      *new_in_flat, cond_nconsts=cond_nconsts, cond_jaxpr=compat_cond_jaxpr,
      body_nconsts=body_nconsts, body_jaxpr=checked_body_jaxpr)
  # body_out_tree will have all the metadata of cond because it executes a cond!
  # only need to merge metadata on the input error.
  error, *out = tree_unflatten(body_out_tree, all_out_vals)
  return out, error
error_checks[lax.while_p] = while_loop_error_check
def pjit_error_check(error, enabled_errors, *vals_in, jaxpr,
                     in_shardings, out_shardings, resource_env,
                     donated_invars, name,
                     in_positional_semantics, out_positional_semantics):
  """Checkify rule for ``pjit.pjit_p``.

  The checked jaxpr takes the leaves of ``error`` as extra leading inputs and
  returns the leaves of the effect-extended output error as extra leading
  outputs. Every pjit parameter indexed per input or per output (shardings,
  positional semantics, donation) is extended to match. Error values are
  replicated, use the current positional semantics, and are never donated.

  Returns:
    A pair ``(out, err)`` of the pjit outputs and the resulting error.
  """
  checked_jaxpr, (out_tree, effects) = checkify_jaxpr(jaxpr, error,
                                                      enabled_errors)
  out_error = error._add_placeholder_effects(effects)
  # The checked jaxpr consumes the leaves of `error` but produces the leaves of
  # `out_error`. The two counts may differ, so track them separately.
  flat_error_vals = jtu.tree_leaves(error)
  num_error_vals = len(flat_error_vals)
  num_out_error_vals = len(jtu.tree_leaves(out_error))
  new_vals_in = [*flat_error_vals, *vals_in]

  # Error values are replicated across every device of the mesh.
  sharding = OpShardingSharding.get_replicated(
      list(resource_env.physical_mesh.devices.flat))
  new_in_shardings = (*[sharding] * num_error_vals, *in_shardings)
  new_out_shardings = (*[sharding] * num_out_error_vals, *out_shardings)

  if config.jax_array:
    pos_sem = maps._PositionalSemantics.GLOBAL
  else:
    pos_sem = maps._positional_semantics.val
  if not isinstance(in_positional_semantics, Iterable):
    in_positional_semantics = (in_positional_semantics,)
  if not isinstance(out_positional_semantics, Iterable):
    out_positional_semantics = (out_positional_semantics,)
  new_positional_sems_in = (*[pos_sem] * num_error_vals,
                            *in_positional_semantics)
  # Sized by the *output* error leaves, exactly like new_out_shardings.
  new_positional_sems_out = (*[pos_sem] * num_out_error_vals,
                             *out_positional_semantics)
  new_donated_invars = (*[False] * num_error_vals, *donated_invars)

  err_and_out = pjit.pjit_p.bind(
      *new_vals_in,
      jaxpr=checked_jaxpr,
      in_shardings=new_in_shardings,
      out_shardings=new_out_shardings,
      resource_env=resource_env,
      donated_invars=new_donated_invars,
      name=name,
      in_positional_semantics=new_positional_sems_in,
      out_positional_semantics=new_positional_sems_out)
  err, *out = tree_unflatten(out_tree, err_and_out)
  return out, err
error_checks[pjit.pjit_p] = pjit_error_check
def check_discharge_rule(error, enabled_errors, *args, err_tree, debug):
  """Checkify rule for ``check_p``: folds a staged check into ``error``.

  ``args`` are the leaves of an error value, rebuilt with ``err_tree``. Each of
  its error effects whose ``error_type`` is in ``enabled_errors`` is merged
  into the running ``error`` (functionalized). The remaining effects are
  collected into a separate error that is not yet re-raised (see the TODO
  below).

  Returns:
    An empty list of outputs and the updated error.
  """
  del debug
  incoming = tree_unflatten(err_tree, args)
  discharged_error, recharged_error = error, init_error
  for effect, pred in incoming._pred.items():
    code = incoming._code[effect]
    payload = incoming._payload[effect]
    if effect.error_type not in enabled_errors:
      recharged_error = update_error(recharged_error, pred, code, {}, payload,
                                     effect)
    else:
      discharged_error = update_error(discharged_error, pred, code, {},
                                      payload, effect)
  discharged_error = discharged_error._replace(
      _metadata={**incoming._metadata, **discharged_error._metadata})
  recharged_error = recharged_error._replace(_metadata=incoming._metadata)
  # TODO(lenamartens): the non-enabled effects in `recharged_error` should be
  # re-raised (by calling check_error on it), but doing so is a breaking API
  # change, so it is left for a follow-up.
  return [], discharged_error
error_checks[check_p] = check_discharge_rule
## checkify api

# Error category raised by explicit user `check` calls.
user_checks = frozenset({FailedCheckError})
# Single-category sets for the automatically inserted checks.
nan_checks = frozenset({NaNError})
index_checks = frozenset({OOBError})
div_checks = frozenset({DivisionByZeroError})

# Convenience unions of the sets above.
float_checks = nan_checks.union(div_checks)
automatic_checks = float_checks.union(index_checks)
all_checks = automatic_checks.union(user_checks)

Out = TypeVar('Out')
def checkify(fun: Callable[..., Out],
             errors: FrozenSet[ErrorCategory] = user_checks
             ) -> Callable[..., Tuple[Error, Out]]:
  """Functionalize `check` calls in `fun`, and optionally add run-time error checks.

  Run-time errors are either user-added :func:`~check` assertions, or
  automatically added checks like NaN checks, depending on the ``errors``
  argument.

  The returned function will return an Error object `err` along with the output
  of the original function. ``err.get()`` will either return ``None`` (if no
  error occurred) or a string containing an error message. This error message
  will correspond to the first error which occurred. ``err.throw()`` will raise
  a ValueError with the error message if an error occurred.

  By default only user-added :func:`~check` assertions are enabled. You can
  enable automatic checks through the ``errors`` argument.

  The check sets which can be enabled, and the condition under which each one
  reports an error:

    - ``user_checks``: a :func:`~check` evaluated to False.
    - ``nan_checks``: a floating-point operation generated a NaN value
      as output.
    - ``div_checks``: a division by zero.
    - ``index_checks``: an index was out-of-bounds.

  Pass one of these sets to enable its categories (e.g. ``errors=nan_checks``).
  Sets can be combined with set operations (e.g.
  ``errors=float_checks|user_checks``).

  Args:
    fun: Callable which can contain user checks (see :func:`~check`).
    errors: A set of ErrorCategory values which defines the set of enabled
      checks. By default only explicit ``checks`` are enabled
      (``user_checks``). You can also, for example, enable NaN and
      division-by-zero errors by passing the ``float_checks`` set, or combine
      multiple sets through set operations (``float_checks | user_checks``).

  Returns:
    A function which accepts the same arguments as ``fun`` and returns as output
    a pair where the first element is an ``Error`` value, representing the first
    error that occurred (a failed :func:`~check` or an enabled automatic
    check), and the second element is the original output of ``fun``.

  For example:

  >>> import jax
  >>> import jax.numpy as jnp
  >>> from jax.experimental import checkify
  >>>
  >>> @jax.jit
  ... def f(x):
  ...   y = jnp.sin(x)
  ...   return x+y
  >>> err, out = checkify.checkify(f, errors=checkify.float_checks)(jnp.inf)
  >>> err.throw()  # doctest: +IGNORE_EXCEPTION_DETAIL
  Traceback (most recent call last):
    ...
  jax._src.checkify.JaxRuntimeError: nan generated by primitive: sin
  """
  @traceback_util.api_boundary
  def checked_fun(*args, **kwargs):
    flat_args, args_tree = tree_flatten((args, kwargs))
    flat_fun, get_out_tree = flatten_fun(lu.wrap_init(fun), args_tree)
    # Tracing happens inside checkify_flat, so the output tree is only
    # available after it returns.
    err, flat_outs = checkify_flat(flat_fun, errors, *flat_args)
    return err, tree_unflatten(get_out_tree(), flat_outs)
  return checked_fun