rocm_jax/jax/_src/checkify.py

1274 lines
49 KiB
Python
Raw Normal View History

# Copyright 2021 The JAX Authors.
2021-10-29 09:23:27 -07:00
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import dataclasses
import functools
2021-10-29 09:23:27 -07:00
import itertools as it
import types
from typing import Union, Optional, Callable, Dict, Tuple, TypeVar, FrozenSet, Iterable, Type, Set, List
2021-10-29 09:23:27 -07:00
import jax
from jax import lax
from jax._src import linear_util as lu
from jax._src import core
from jax._src import prng
from jax._src import source_info_util
from jax._src import traceback_util
from jax._src.config import config
from jax._src.lax import control_flow as cf
from jax._src.sharding import OpShardingSharding
from jax._src.typing import Array
from jax._src.util import (as_hashable_function, unzip2, split_list, safe_map,
safe_zip)
2021-10-29 09:23:27 -07:00
from jax.api_util import flatten_fun
from jax.api_util import flatten_fun_nokwargs
2022-07-29 17:27:40 +01:00
from jax.experimental import maps
from jax.experimental import pjit
from jax.interpreters import ad
from jax.interpreters import batching
from jax.interpreters import mlir
2021-10-29 09:23:27 -07:00
from jax.interpreters import partial_eval as pe
from jax.tree_util import tree_flatten
from jax.tree_util import tree_map
from jax.tree_util import tree_unflatten
import jax.numpy as jnp
import jax.tree_util as jtu
import numpy as np
2021-10-29 09:23:27 -07:00
# Register this file with the source-info and traceback exclusion lists.
source_info_util.register_exclusion(__file__)
traceback_util.register_exclusion(__file__)
# From here on `map`/`zip` refer to `safe_map`/`safe_zip`; the builtins stay
# reachable under the `unsafe_` names.
map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip
# Type aliases used in the annotations of this module.
Bool = Union[bool, Array]
Int = Union[int, Array]
ErrorCategory = Type['JaxException']
Payload = List[Union[np.ndarray, Array]]
PyTreeDef = jtu.PyTreeDef
2021-10-29 09:23:27 -07:00
## Utils
def popattr(obj, attrname):
  """Deletes attribute `attrname` from `obj` and returns the value it held.

  Raises AttributeError (from `getattr`) when the attribute does not exist.
  """
  popped = getattr(obj, attrname)
  delattr(obj, attrname)
  return popped
def setnewattr(obj, name, val):
  """Sets `obj.<name> = val`, asserting that the attribute is not set yet."""
  # A fresh object can never be a legitimately stored value, so it reliably
  # signals "attribute missing" (unlike None).
  missing = object()
  current = getattr(obj, name, missing)
  assert current is missing
  setattr(obj, name, val)
# Concrete errors
class JaxException(Exception):
  """Python exception which can contain an error message with JAX run-time info."""

  def __init__(self, traceback_info):
    # TODO(lenamartens): re-enable tracebacks when they don't leak tracers.
    # self.with_traceback(self.traceback_info)
    self.traceback_info = traceback_info

  def __init_subclass__(cls):
    # Every subclass is registered as a pytree node class, using the
    # `tree_flatten`/`tree_unflatten` pair below (or the subclass' overrides).
    jtu.register_pytree_node_class(cls)

  def tree_flatten(self):
    """Returns `(children, aux_data)`: no children, traceback info as aux."""
    children = []
    aux_data = self.traceback_info
    return (children, aux_data)

  @classmethod
  def tree_unflatten(cls, metadata, payload):
    """Rebuilds the exception from `metadata`; `payload` is ignored."""
    del payload
    return cls(metadata)

  def get_effect_type(self) -> core.Effect:
    """Effect describing this error type; the base class returns None."""
    pass
@functools.total_ordering
@dataclasses.dataclass(eq=True, frozen=True)
class ErrorEffect:
  """Effect identifying an error type together with its payload signature."""
  error_type: Type[JaxException]
  shape_dtypes: Tuple[jax.ShapeDtypeStruct, ...]

  def __post_init__(self):
    # Each newly created effect is added to the control-flow and lowering
    # effect sets.
    cf.allowed_effects.add(self)
    mlir.lowerable_effects.add(self)

  def __lt__(self, other: 'ErrorEffect'):
    def sort_key(effect):
      type_name = str(effect.error_type)
      # dtype is not comparable, so its string form is compared instead.
      signature = tuple((sd.shape, str(sd.dtype))
                        for sd in effect.shape_dtypes)
      return (type_name, signature)

    return sort_key(self) < sort_key(other)
class DivisionByZeroError(JaxException):
  """Error recorded when a division by zero is detected."""

  def get_effect_type(self):
    # This error carries no run-time payload.
    return ErrorEffect(DivisionByZeroError, ())

  def __str__(self):
    return f'division by zero at {self.traceback_info}'
class NaNError(JaxException):
  """Error recorded when a primitive produced a NaN."""

  def __init__(self, traceback_info, primitive_name):
    super().__init__(traceback_info)
    self.prim = primitive_name

  def __str__(self):
    return f'nan generated by primitive: {self.prim} at {self.traceback_info}'

  def get_effect_type(self):
    # This error carries no run-time payload.
    return ErrorEffect(NaNError, ())

  def tree_flatten(self):
    """Returns no children; traceback info and primitive name are aux data."""
    aux_data = (self.traceback_info, self.prim)
    return ([], aux_data)

  @classmethod
  def tree_unflatten(cls, metadata, _):
    """Rebuilds the error from its `(traceback_info, primitive_name)` aux data."""
    traceback_info, primitive_name = metadata
    return cls(traceback_info, primitive_name)
class OOBError(JaxException):
  """Error recorded when an out-of-bounds index is detected."""

  def __init__(self, traceback_info, primitive_name, operand_shape, payload):
    super().__init__(traceback_info)
    self.prim = primitive_name
    self.operand_shape = operand_shape
    # `__str__` reads entries 0, 1 and 2 as (index, axis, axis size).
    self._payload = payload

  def get_effect_type(self):
    payload_struct = jax.ShapeDtypeStruct((3,), jnp.int32)
    return ErrorEffect(OOBError, (payload_struct,))

  def tree_flatten(self):
    """Returns the payload as the only child; everything else is aux data."""
    aux_data = (self.traceback_info, self.prim, self.operand_shape)
    return ([self._payload], aux_data)

  @classmethod
  def tree_unflatten(cls, metadata, payload):
    """Rebuilds the error from aux data and the single payload child."""
    return cls(*metadata, payload[0])

  def __str__(self):
    payload = self._payload
    return (f'out-of-bounds indexing for array of '
            f'shape {self.operand_shape}: '
            f'index {payload[0]} is out of bounds for axis '
            f'{payload[1]} with size {payload[2]}. '
            f'Failed at {self.traceback_info}')
class FailedCheckError(JaxException):
  """Error recorded when a user-written check fails.

  `fmt_string` is formatted with the stored positional and keyword arguments
  when the error is converted to a string.
  """

  def __init__(self, traceback_info, fmt_string, *a, **k):
    super().__init__(traceback_info)
    self.fmt_string = fmt_string
    self.args = a
    self.kwargs = k

  def __str__(self):
    formatted = self.fmt_string.format(*self.args, **self.kwargs)
    return formatted + f' (check failed at {self.traceback_info})'

  def get_effect_type(self):
    format_leaves = jtu.tree_leaves((self.args, self.kwargs))
    leaf_structs = tuple(jax.ShapeDtypeStruct(leaf.shape, leaf.dtype)
                         for leaf in format_leaves)
    # Need a 0-size array here for data-dependence.
    data_dependence = jax.ShapeDtypeStruct((0,), jnp.int32)
    return ErrorEffect(FailedCheckError, (data_dependence, *leaf_structs))

  def tree_flatten(self):
    """Children: a 0-size array plus the format args; aux: traceback, format."""
    children = (jnp.array([], jnp.int32), self.args, self.kwargs)
    aux_data = (self.traceback_info, self.fmt_string)
    return (children, aux_data)

  @classmethod
  def tree_unflatten(cls, metadata, payload):
    """Rebuilds the error, dropping the leading 0-size array child."""
    _, args, kwargs = payload
    return cls(*metadata, *args, **kwargs)
@dataclasses.dataclass
class BatchedError(JaxException):
  """Collection of errors keyed by the mapped (batch) index they occurred at."""
  error_mapping: Dict[Tuple[int, ...], JaxException]

  def __post_init__(self):
    # The traceback info of the first stored error represents the whole batch.
    stored_errors = list(self.error_mapping.values())
    super().__init__(stored_errors[0].traceback_info)

  def __str__(self):
    lines = []
    for idx, e in self.error_mapping.items():
      index_text = ", ".join(str(i) for i in idx)
      lines.append(f'at mapped index {index_text}: {e}')
    return '\n'.join(lines)
# Error Value
@jtu.register_pytree_node_class
@dataclasses.dataclass(frozen=True)
class Error:
  """Functional error value threaded through checkified computations.

  The `_pred`, `_code` and `_payload` dictionaries are keyed by `ErrorEffect`.
  As a pytree, those three are the children and `_metadata` is the aux data.
  """
  _pred: Dict[ErrorEffect, Bool]  # per effect: did an error happen?
  _code: Dict[ErrorEffect, Int]  # per effect: code of the recorded error.
  _metadata: Dict[Int, PyTreeDef]  # mapping of code to JaxException treedef.
  _payload: Dict[ErrorEffect, Payload]  # per effect: exception pytree leaves.

  def get(self) -> Optional[str]:
    """Returns error message if error happened, None if no error happened."""
    exp = self.get_exception()
    if exp is not None:
      return str(exp)
    return None

  def get_exception(self) -> Optional[JaxException]:
    """Returns Python exception if error happened, None if no error happened."""
    # Any predicate with a non-empty shape means the error value is batched.
    if any(map(np.shape, self._pred.values())):
      return self._get_batched_exception()

    # Among the effects whose predicate is set, report the smallest code.
    min_code = None
    cur_effect = None
    for error_effect, code in self._code.items():
      if self._pred[error_effect]:
        if min_code is None or code < min_code:
          min_code = code
          cur_effect = error_effect

    if cur_effect is not None:
      return tree_unflatten(self._metadata[int(min_code)],  # type: ignore
                            self._payload[cur_effect])
    return None

  def throw(self):
    """Checks this error value through `_check_error`."""
    _check_error(self)

  def __str__(self):
    return f'Error({self.get()})'

  # Internal helpers

  def _get_batched_exception(self) -> Optional['BatchedError']:
    """Returns a `BatchedError` over all failing batch indices, or None.

    For every index of the batch shape (taken from the first predicate), the
    failing effect with the smallest code is rebuilt into an exception.
    """
    shape = np.shape(list(self._pred.values())[0])
    error_mapping = {}
    for idx in np.ndindex(*shape):
      min_code = None
      cur_effect = None
      for error_effect, code in self._code.items():
        if self._pred[error_effect][idx]:  # type: ignore
          if min_code is None or code[idx] < min_code:
            min_code = code[idx]  # type: ignore
            cur_effect = error_effect

      if cur_effect is not None:
        payload = tree_map(lambda x, i=idx: x[i], self._payload[cur_effect])
        jax_error = tree_unflatten(self._metadata[int(min_code)], payload)  # type: ignore
        error_mapping[idx] = jax_error

    # No batch element failed: report "no error". Constructing a BatchedError
    # from an empty mapping would fail, since it reads its first entry.
    if not error_mapping:
      return None
    return BatchedError(error_mapping)

  def _update(self, effect_type: ErrorEffect, pred, code, metadata, payload):
    """Returns a new Error with the entries for `effect_type` replaced.

    `metadata` (a code -> treedef mapping) is merged into the existing one.
    """
    new_errs = {**self._pred, **{effect_type: pred}}  # type: ignore
    new_codes = {**self._code, **{effect_type: code}}  # type: ignore
    new_payload = {**self._payload, **{effect_type: payload}}  # type: ignore
    new_metadata = {**self._metadata, **metadata}
    return Error(new_errs, new_codes, new_metadata, new_payload)

  def _add_placeholder_effects(self, effects: Set[ErrorEffect]):
    """Fill out Error with `effects` and arrays of ones for their payloads."""
    new_err = self._pred.copy()
    new_code = self._code.copy()
    new_payload = self._payload.copy()
    for effect in effects:
      if effect not in self._pred.keys():
        new_err[effect] = False
        new_payload[effect] = list(
            tree_map(lambda a: jnp.ones(a.shape, a.dtype), effect.shape_dtypes))
        # The error value associated with this effect will never become True, so
        # we don't need to set a meaningful code.
        new_code[effect] = -1
    return Error(new_err, new_code, self._metadata, new_payload)

  def _replace(self, *args, **kwargs):
    """Thin wrapper around `dataclasses.replace` for this (frozen) dataclass."""
    return dataclasses.replace(self, *args, **kwargs)

  # PyTree methods

  def tree_flatten(self):
    return ((self._pred, self._code, self._payload), (self._metadata))

  @classmethod
  def tree_unflatten(cls, metadata, data):
    pred, code, payload = data
    return cls(pred, code, metadata, payload)
# Error with no effects recorded: empty pred, code, metadata and payload dicts.
init_error = Error({}, {}, {}, {}) # value used as initial (empty) error.
2021-10-29 09:23:27 -07:00
# Each call returns the next integer of a module-wide counter starting at 1.
next_code = it.count(1).__next__ # globally unique ids, could be uuid4
def assert_func(error: Error, pred: Bool, new_error: JaxException) -> Error:
  """Returns `error` updated with `new_error`, to be reported when `pred` holds.

  A fresh code is drawn for the new error; the exception is flattened so that
  its leaves travel as payload and its treedef is stored under that code.
  """
  error_code = next_code()
  effect = new_error.get_effect_type()
  payload_leaves, exception_treedef = tree_flatten(new_error)
  code_to_treedef = {error_code: exception_treedef}
  return update_error(error, pred, error_code, code_to_treedef, payload_leaves,
                      effect)
def update_error(error, pred, code, metadata, payload, effect_type):
  """Merges a new potential error of `effect_type` into `error`.

  An error already recorded for `effect_type` takes precedence: its code and
  payload are kept, and the new ones are only used when it had not fired.
  """
  already_failed = error._pred.get(effect_type, False)
  merged_pred = already_failed | pred
  merged_code = lax.select(already_failed, error._code.get(effect_type, -1),
                           code)
  existing_payload = error._payload.get(effect_type, None)
  if existing_payload is None:
    # First error of this effect type: nothing to select against.
    merged_payload = payload
  else:
    keep_existing = functools.partial(lax.select, already_failed)
    merged_payload = tree_map(keep_existing, existing_payload, payload)
  return error._update(effect_type, merged_pred, merged_code, metadata,
                       merged_payload)
2021-10-29 09:23:27 -07:00
2021-10-29 09:23:27 -07:00
## Checkify transformation for plumbing functional error values.
class CheckifyTracer(core.Tracer):
  """Tracer of `CheckifyTrace`: a thin wrapper around the underlying value."""

  def __init__(self, trace, val):
    self._trace = trace
    self.val = val

  @property
  def aval(self):
    """Abstract value of the wrapped value."""
    return core.get_aval(self.val)

  def full_lower(self):
    return self
class CheckifyTrace(core.Trace):
  """Trace which threads an `Error` value through the traced computation.

  The running error is stored as the `error` attribute of the main trace. The
  call-like rules pop it, pass it into the sub-computation as extra leading
  arguments, and store the error coming back out as the new running error.
  """

  def __init__(self, main: core.MainTrace, sublevel: core.Sublevel,
               enabled_errors: FrozenSet['ErrorCategory']) -> None:
    self.main = main
    self.level = main.level
    self.sublevel = sublevel
    # Stored on the main trace, where `process_primitive` reads it.
    self.main.enabled_errors = enabled_errors

  def pure(self, val):
    """Wraps a value in a `CheckifyTracer` of this trace."""
    return CheckifyTracer(self, val)

  lift = pure

  def sublift(self, tracer):
    return CheckifyTracer(self, tracer.val)

  def process_primitive(self, primitive, tracers, params):
    """Applies the primitive's error-check rule if one exists, else binds it."""
    in_vals = [t.val for t in tracers]
    check_rule = error_checks.get(primitive)
    if check_rule:
      out, self.main.error = check_rule(self.main.error, self.main.enabled_errors,  # type: ignore
                                        *in_vals, **params)
    else:
      out = primitive.bind(*in_vals, **params)
    if primitive.multiple_results:
      return [CheckifyTracer(self, x) for x in out]
    return CheckifyTracer(self, out)

  def process_call(self, primitive, f, tracers, params):
    """Binds a call primitive with the error threaded through as extra args."""
    in_vals = [t.val for t in tracers]
    cur_error = popattr(self.main, 'error')
    flat_vals, in_tree = tree_flatten((cur_error, *in_vals))
    f = checkify_subtrace(f, self.main)
    f, out_tree = flatten_fun_nokwargs(f, in_tree)
    if 'donated_invars' in params:
      # The added error inputs are never donated.
      num_error_vals = len(jtu.tree_leaves(cur_error))
      params = dict(params, donated_invars=(*[False]*num_error_vals,
                                            *params['donated_invars']))
    all_vals = primitive.bind(f, *flat_vals, **params)
    new_error, *out_vals = tree_unflatten(out_tree(), all_vals)
    setnewattr(self.main, 'error', new_error)
    return [CheckifyTracer(self, x) for x in out_vals]

  def process_map(self, primitive, f, tracers, params):
    """Like `process_call`, for map primitives.

    Error inputs get `None` in-axes, error outputs get out-axis 0, and the
    mapped output error is reduced with `_reduce_any_error`.
    """
    in_vals = [t.val for t in tracers]
    cur_error = popattr(self.main, 'error')
    flat_vals, in_tree = tree_flatten((cur_error, *in_vals))
    num_error_vals = len(jtu.tree_leaves(cur_error))
    f = checkify_subtrace(f, self.main)
    f, out_tree = flatten_fun_nokwargs(f, in_tree)

    @as_hashable_function(closure=params['out_axes_thunk'])
    def new_out_axes_thunk():
      out_val_axes = params['out_axes_thunk']()
      # Every output leaf that is not a regular output is an error leaf.
      out_err_num = out_tree().num_leaves - len(out_val_axes)
      return (*(0,)*out_err_num, *out_val_axes)

    map_params = dict(
        params,
        in_axes=(*(None,)*num_error_vals, *params['in_axes']),
        out_axes_thunk=new_out_axes_thunk,
        donated_invars=(*(False,)*num_error_vals, *params['donated_invars']))
    all_vals = primitive.bind(f, *flat_vals, **map_params)
    new_error, *out_vals = tree_unflatten(out_tree(), all_vals)
    new_error = _reduce_any_error(new_error)
    setnewattr(self.main, 'error', new_error)
    return [CheckifyTracer(self, x) for x in out_vals]

  def post_process_call(self, primitive, tracers, params):
    """Returns error leaves followed by the values, plus a `todo` callback.

    `todo` splits the error leaves off again and restores them as the running
    error of the main trace.
    """
    vals = [t.val for t in tracers]
    main = self.main
    cur_error = popattr(main, 'error')
    err_leaves, err_tree = tree_flatten(cur_error)
    setnewattr(main, 'err_tree', err_tree)

    def todo(vals):
      err_tree = popattr(main, 'err_tree')
      err_vals, vals = split_list(vals, [err_tree.num_leaves])
      setnewattr(main, 'error', tree_unflatten(err_tree, err_vals))
      trace = main.with_cur_sublevel()
      return [CheckifyTracer(trace, x) for x in vals]

    return (*err_leaves, *vals), todo

  def post_process_map(self, primitive, tracers, params):
    """Map version of `post_process_call`.

    Additionally reduces the restored error with `_reduce_any_error` and
    returns a transform which prepends out-axis 0 for every error leaf.
    """
    vals = [t.val for t in tracers]
    main = self.main
    cur_error = popattr(main, 'error')
    err_leaves, err_tree = tree_flatten(cur_error)
    num_err_leaves = len(err_leaves)
    setnewattr(main, 'err_tree', err_tree)

    def todo(vals):
      err_tree = popattr(main, 'err_tree')
      err_vals, vals = split_list(vals, [err_tree.num_leaves])
      restored = tree_unflatten(err_tree, err_vals)
      restored = _reduce_any_error(restored)
      setnewattr(main, 'error', restored)
      trace = main.with_cur_sublevel()
      return [CheckifyTracer(trace, x) for x in vals]

    def out_axes_transform(out_axes):
      return (*(0,)*num_err_leaves, *out_axes)

    return (*err_leaves, *vals), (todo, out_axes_transform)

  def process_custom_jvp_call(self, prim, f, jvp, tracers):
    """Binds a custom_jvp call with the error threaded through."""
    in_vals = [t.val for t in tracers]
    cur_error = popattr(self.main, 'error')
    err_vals, err_tree = tree_flatten(cur_error)
    flat_vals, in_tree = tree_flatten((cur_error, *in_vals))
    num_error_vals = len(err_vals)
    f = checkify_subtrace(f, self.main)
    f, f_out_tree = flatten_fun_nokwargs(f, in_tree)
    jvp, jvp_err_tree = checkify_custom_jvp_subtrace(jvp, self.main,
                                                     num_error_vals, err_tree)
    all_outs = prim.bind(f, jvp, *flat_vals)
    # `fst` tells whether the aux output comes from `f` (first) or from `jvp`.
    fst, out_tree = lu.merge_linear_aux(f_out_tree, jvp_err_tree)
    if fst:
      out_err, *out_vals = tree_unflatten(out_tree, all_outs)
    else:
      err_vals, out_vals = split_list(all_outs, [num_error_vals])
      # forward input error values to output
      out_err = tree_unflatten(out_tree, err_vals)
    setattr(self.main, 'error', out_err)
    return [CheckifyTracer(self, x) for x in out_vals]

  def post_process_custom_jvp_call(self, tracers, jvp_was_run):
    """Like `post_process_call`; not implemented when the JVP rule was run."""
    if jvp_was_run:
      msg = ('support for custom_jvp rules which close over checkify values is '
             'not implemented. If you see this, open an issue at '
             'https://github.com/google/jax/issues!')
      raise NotImplementedError(msg)
    vals = [t.val for t in tracers]
    main = self.main
    cur_error = popattr(main, 'error')
    err_leaves, err_tree = tree_flatten(cur_error)

    def todo(vals):
      err_vals, vals = split_list(vals, [len(err_leaves)])
      setnewattr(main, 'error', tree_unflatten(err_tree, err_vals))
      trace = main.with_cur_sublevel()
      return [CheckifyTracer(trace, x) for x in vals]

    return (*err_leaves, *vals), todo

  def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees):
    """Binds a custom_vjp call with the error threaded through."""
    in_vals = [t.val for t in tracers]
    cur_error = popattr(self.main, 'error')
    err_vals, err_tree = tree_flatten(cur_error)
    flat_vals, in_tree = tree_flatten((cur_error, *in_vals))
    num_error_vals = len(err_vals)
    fun = checkify_subtrace(fun, self.main)
    fun, fun_out_tree = flatten_fun_nokwargs(fun, in_tree)
    fwd, fwd_err_tree = checkify_custom_vjp_subtrace(fwd, self.main,
                                                     err_tree, num_error_vals)
    all_out_vals = prim.bind(fun, fwd, bwd, *flat_vals, out_trees=out_trees)
    # `fst` tells whether the aux output comes from `fun` (first) or `fwd`.
    fst, out_tree = lu.merge_linear_aux(fun_out_tree, fwd_err_tree)
    if fst:
      new_error, *out = tree_unflatten(out_tree, all_out_vals)
    else:
      _, out = split_list(all_out_vals, [num_error_vals])
      # forward input error values to output
      new_error = tree_unflatten(err_tree, err_vals)
    setattr(self.main, 'error', new_error)
    return [CheckifyTracer(self, x) for x in out]
def _reduce_any_error(error: Error):
  """Reduces an `Error` with array-valued entries to one entry per effect.

  For every effect, the entry at the index which sorts last by predicate value
  is selected, and predicate, code and payload are all indexed there.
  """
  reduced = init_error
  for effect in error._pred.keys():
    per_effect = (error._pred[effect],
                  error._code[effect],
                  error._payload[effect])
    # Sorting the predicates puts True after False, so the last index points
    # at a failing element whenever there is one.
    pick = jnp.argsort(per_effect[0])[-1]
    pred, code, payload = tree_map(lambda x, idx=pick: x[idx], per_effect)
    reduced = reduced._update(effect, pred, code, {}, payload)

  reduced = reduced._replace(_metadata=error._metadata)
  return reduced
2021-10-29 09:23:27 -07:00
# Signature of a per-primitive error-check rule (see the trailing comment).
ErrorCheckRule = Callable # (Error, FrozenSet[ErrorCategory], *in_vals, **params) -> (Any, Error)
2021-10-29 09:23:27 -07:00
# Registry mapping a primitive to its error-check rule; primitives without an
# entry are bound unchanged by CheckifyTrace.process_primitive.
error_checks: Dict[core.Primitive, ErrorCheckRule] = {}
def checkify_flat(fun: lu.WrappedFun, enabled_errors: FrozenSet['ErrorCategory'],
                  *args):
  """Runs `fun` on flat `args` under checkify, starting from the empty error.

  Returns:
    A pair `(error, outvals)` of the resulting error and the list of outputs.
  """
  checked = checkify_traceable(checkify_subtrace(fun), enabled_errors)
  error, *outvals = checked.call_wrapped(init_error, *args)
  return error, outvals
2021-10-29 09:23:27 -07:00
@lu.transformation
def checkify_traceable(enabled_errors, error, *args):
  """Runs the wrapped function under a new `CheckifyTrace` main trace.

  The main trace is passed on as the first argument of the wrapped function.
  """
  with core.new_main(CheckifyTrace, enabled_errors=enabled_errors) as main:
    traced_args = (main, error, *args)
    results = yield traced_args, {}
    # Drop the local reference to the main trace before leaving its context.
    del main
  yield results
@lu.transformation
def checkify_subtrace(main, error, *args):
  """Runs the wrapped function on `CheckifyTracer`s of `main`.

  `error` is installed as the running error on `main` for the duration of the
  call; the error found there afterwards is yielded in front of the outputs.
  """
  setnewattr(main, 'error', error)
  trace = main.with_cur_sublevel()
  in_tracers = [CheckifyTracer(trace, arg) for arg in args]
  out = yield in_tracers, {}
  out_vals = [trace.full_raise(o).val for o in out]
  # Take the running error off the main trace again.
  final_error = main.error
  del main.error
  yield (final_error, *out_vals)
2021-10-29 09:23:27 -07:00
2022-02-09 11:18:40 -08:00
@lu.transformation_with_aux
def checkify_custom_jvp_subtrace(main, num_error_vals, out_tree, *args):
  # Like checkify_subtrace, but used specifically on the custom JVP rules
  # associated with a custom_jvp. This code is called in the context of a
  # jvp-of-checkify-of-custom_jvp. It takes both primal and tangent inputs,
  # flattened into a single args tuple, and similarly must produce flattened
  # primal and tangent outputs. Both primals and tangents include error values,
  # but the tangent error values are trivially zero.
  # The types to have in mind are:
  #   jvp : (a -> b) -> (a, T a) -> (b, T b)
  #   checkify : (a -> b) -> a -> Err b
  #   jvp-of-checkify : (a -> b) -> (a, T a) -> (Err b, T (Err b))
  # where because Err is a pytree, we necessarily have T (Err b) = Err' (T b)
  # where the other Err' components are trivial (of float0 dtype).
  # Semantically, we don't add checks to the JVP rule. To check the result of a
  # JVP rule, one must instead use checkify-of-jvp. Thus this implementation
  # just forwards the input error and code (and trivial tangents) to the output.
  del main
  # First half of `args` are primals, second half tangents; each half starts
  # with `num_error_vals` error values.
  half, remainder = divmod(len(args), 2)
  assert not remainder
  primal_args, tangent_args = args[:half], args[half:]
  err_primals, primals = split_list(primal_args, [num_error_vals])
  err_tangents, tangents = split_list(tangent_args, [num_error_vals])
  outs = yield (*primals, *tangents), {}
  # The outputs are split the same way: primal outputs, then tangent outputs.
  out_half, remainder = divmod(len(outs), 2)
  assert not remainder
  out_primals, out_tangents = outs[:out_half], outs[out_half:]
  yield (*err_primals, *out_primals, *err_tangents, *out_tangents), out_tree
2021-10-29 09:23:27 -07:00
2022-02-09 12:04:34 -08:00
@lu.transformation_with_aux
def checkify_custom_vjp_subtrace(main, err_tree, num_error_vals, *args):
  """Calls the wrapped function without the leading error values.

  The function's outputs are yielded unchanged, with `err_tree` as aux output.
  """
  del main
  # We don't add any checks; just drop input error values.
  _, non_error_args = split_list(args, [num_error_vals])
  results = yield non_error_args, {}
  yield results, err_tree
2022-02-09 12:04:34 -08:00
@lu.transformation_with_aux
def query_error_effects(*args):
  """Passes results through; the aux output is the set of error effects.

  The wrapped function must return the error as its first result.
  """
  error, *rest = yield args, {}
  effects = set(error._pred.keys())
  yield (error, *rest), effects
def checkify_jaxpr(jaxpr, error,
                   enabled_errors) -> Tuple[core.ClosedJaxpr,
                                            Tuple[PyTreeDef,
                                                  FrozenSet[ErrorEffect]]]:
  """Checkifies a closed jaxpr; see `checkify_fun_to_jaxpr` for the result."""
  wrapped_fun = lu.wrap_init(core.jaxpr_as_fun(jaxpr))
  return checkify_fun_to_jaxpr(wrapped_fun, error, enabled_errors,
                               jaxpr.in_avals)
2021-12-08 15:07:43 +00:00
def checkify_fun_to_jaxpr(
    f, error, enabled_errors,
    in_avals) -> Tuple[core.ClosedJaxpr, Tuple[PyTreeDef, FrozenSet[ErrorEffect]]]:
  """Traces the checkified version of `f` to a closed jaxpr.

  The traced function takes the leaves of `error` followed by the regular
  inputs. Returns the closed jaxpr together with a pair of the output tree and
  the set of error effects found while tracing.
  """
  error_leaves = jtu.tree_leaves(error)
  f = checkify_subtrace(f)
  f = checkify_traceable(f, enabled_errors)
  f, error_effect = query_error_effects(f)
  in_tree = jtu.tree_structure((error, *in_avals))
  f, out_tree = flatten_fun_nokwargs(f, in_tree)
  error_avals = [core.raise_to_shaped(core.get_aval(leaf))
                 for leaf in error_leaves]
  avals_in = [*error_avals, *in_avals]
  jaxpr_out, _, literals_out = pe.trace_to_jaxpr_dynamic(f, avals_in)
  closed_jaxpr = core.ClosedJaxpr(jaxpr_out, literals_out)
  return (closed_jaxpr, (out_tree(), error_effect()))
2021-10-29 09:23:27 -07:00
def check(pred: Bool, msg: str, *fmt_args, **fmt_kwargs) -> None:
  """Check a predicate, add an error with msg if predicate is False.

  This is an effectful operation, and can't be staged (jitted/scanned/...).
  Before staging a function with checks, :func:`~checkify` it!

  Args:
    pred: if False, a FailedCheckError error is added.
    msg: error message if error is added. Can be a format string.
    fmt_args, fmt_kwargs: Positional and keyword formatting arguments for
      `msg`, eg.:
      ``check(.., "check failed on values {} and {named_arg}", x, named_arg=y)``
      Note that these arguments can be traced values allowing you to add
      run-time values to the error message.
      Note that tracking these run-time arrays will increase your memory usage,
      even if no error happens.

  For example:

    >>> import jax
    >>> import jax.numpy as jnp
    >>> from jax.experimental import checkify
    >>> def f(x):
    ...   checkify.check(x>0, "{x} needs to be positive!", x=x)
    ...   return 1/x
    >>> checked_f = checkify.checkify(f)
    >>> err, out = jax.jit(checked_f)(-3.)
    >>> err.throw()  # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
      ...
    jax._src.checkify.JaxRuntimeError: -3. needs to be positive!

  """
  # Third argument: this is a regular check, not a debug check.
  is_debug_check = False
  _check(pred, msg, is_debug_check, *fmt_args, **fmt_kwargs)
def _check(pred, msg, debug, *fmt_args, **fmt_kwargs):
  """Shared implementation of `check` and `debug_check`.

  Raises:
    TypeError: if `pred` is not a scalar boolean.
  """
  if not is_scalar_pred(pred):
    if debug:
      prim_name = 'debug_check'
    else:
      prim_name = 'check'
    raise TypeError(f'{prim_name} takes a scalar pred as argument, got {pred}')
  failure = FailedCheckError(summary(), msg, *fmt_args, **fmt_kwargs)
  # The error fires when the predicate does NOT hold.
  error = assert_func(init_error, jnp.logical_not(pred), failure)
  _check_error(error, debug=debug)
def _check_error(error, *, debug=False):
  """Binds the `check` primitive on the leaves of `error`."""
  error = tree_map(core.raise_as_much_as_possible, error)
  # Predicates with a non-empty shape are reduced before binding.
  is_batched = any(np.shape(pred) for pred in error._pred.values())
  if is_batched:
    error = _reduce_any_error(error)
  err_args, tree_def = tree_flatten(error)
  return check_p.bind(*err_args, err_tree=tree_def, debug=debug)
def is_scalar_pred(pred) -> bool:
  """Returns True for a Python bool or a boolean `jnp.ndarray` of shape ()."""
  if isinstance(pred, bool):
    return True
  if not isinstance(pred, jnp.ndarray):
    return False
  return pred.shape == () and pred.dtype == jnp.dtype('bool')
def debug_check(pred: Bool, msg: str, *fmt_args, **fmt_kwargs) -> None:
  """Check a predicate when running under checkify, otherwise is a no-op.

  A `debug_check` will only be run if it is transformed by :func:`~checkify`,
  otherwise the check will be dropped.

  Args:
    pred: if False, a FailedCheckError error is added.
    msg: error message if error is added.
    fmt_args, fmt_kwargs: Positional and keyword formatting arguments for
      `msg`, eg.:
      ``debug_check(.., "check failed on values {} and {named}", x, named=y)``
      Note that these arguments can be traced values allowing you to add
      run-time values to the error message.
      Note that tracking these run-time arrays will increase your memory usage,
      even if no error happens.

  For example:

    >>> import jax
    >>> import jax.numpy as jnp
    >>> from jax.experimental import checkify
    >>> def f(x):
    ...   checkify.debug_check(x!=0, "cannot be zero!")
    ...   return x
    >>> _ = f(0)  # running without checkify means no debug_check is run.
    >>> checked_f = checkify.checkify(f)
    >>> err, out = jax.jit(checked_f)(0)  # running with checkify runs debug_check.
    >>> err.throw()  # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
      ...
    jax._src.checkify.JaxRuntimeError: cannot be zero!

  """
  # Third argument: mark this as a debug check.
  is_debug_check = True
  _check(pred, msg, is_debug_check, *fmt_args, **fmt_kwargs)
def check_error(error: Error) -> None:
  """Raise an Exception if ``error`` represents a failure. Functionalized by :func:`~checkify`.

  The semantics of this function are equivalent to:

  >>> def check_error(err: Error) -> None:
  ...   err.throw()  # can raise ValueError

  But unlike that implementation, ``check_error`` can be functionalized using
  the :func:`~checkify` transformation.

  This function is similar to :func:`~check` but with a different signature: whereas
  :func:`~check` takes as arguments a boolean predicate and a new error message
  string, this function takes an ``Error`` value as argument. Both :func:`~check`
  and this function raise a Python Exception on failure (a side-effect), and
  thus cannot be staged out by :func:`~jax.jit`, :func:`~jax.pmap`,
  :func:`~jax.lax.scan`, etc. Both also can be functionalized by using
  :func:`~checkify`.

  But unlike :func:`~check`, this function is like a direct inverse of
  :func:`~checkify`: whereas :func:`~checkify` takes as input a function which
  can raise a Python Exception and produces a new function without that effect
  but which produces an ``Error`` value as output, this ``check_error``
  function can accept an ``Error`` value as input and can produce the
  side-effect of raising an Exception. That is, while :func:`~checkify` goes
  from functionalizable Exception effect to error value, this ``check_error``
  goes from error value to functionalizable Exception effect.

  ``check_error`` is useful when you want to turn checks represented by an
  ``Error`` value (produced by functionalizing ``checks`` via
  :func:`~checkify`) back into Python Exceptions.

  For example, you might want to functionalize part of your program through
  checkify, stage out your functionalized code through :func:`~jax.jit`, then
  re-inject your error value outside of the :func:`~jax.jit`:

  >>> import jax
  >>> from jax.experimental import checkify
  >>> def f(x):
  ...   checkify.check(x>0, "must be positive!")
  ...   return x
  >>> def with_inner_jit(x):
  ...   checked_f = checkify.checkify(f)
  ...   # a checkified function can be jitted
  ...   error, out = jax.jit(checked_f)(x)
  ...   checkify.check_error(error)
  ...   return out
  >>> _ = with_inner_jit(1)  # no failed check
  >>> with_inner_jit(-1)  # doctest: +IGNORE_EXCEPTION_DETAIL
  Traceback (most recent call last):
    ...
  jax._src.checkify.JaxRuntimeError: must be positive!
  >>> # can re-checkify
  >>> error, _ = checkify.checkify(with_inner_jit)(-1)

  Args:
    error: Error to check.

  Raises:
    ValueError: if ``error`` is not an ``Error`` instance.
  """
  if not isinstance(error, Error):
    raise ValueError('check_error takes an Error as argument, '
                     f'got type {type(error)} instead.')
  _check_error(error, debug=False)
## check primitive
# Primitive bound by `_check_error`; its rules are registered below.
check_p = core.Primitive('check')
check_p.multiple_results = True # zero results
# TODO(lenamartens): inherit from Exception instead of ValueError.
class JaxRuntimeError(ValueError):
  """Exception raised by the `check` primitive when an error was recorded."""
@check_p.def_impl
def check_impl(*args, err_tree, debug):
  """Eager implementation of `check`.

  Raises:
    JaxRuntimeError: if the error rebuilt from `args` holds an exception.
  """
  if not debug:
    error = tree_unflatten(err_tree, args)
    exc = error.get_exception()
    if exc:
      raise JaxRuntimeError(str(exc)) from exc
  # For debug checks this is a NOOP (check will only trigger when discharged).
  return []
@check_p.def_effectful_abstract_eval
def check_abstract_eval(*args, err_tree, debug):
  """Abstract evaluation: no outputs; the effects are the error's effect keys."""
  del debug
  abstract_error = tree_unflatten(err_tree, args)
  effects = set(abstract_error._pred.keys())
  return [], effects
2022-09-23 15:10:41 +01:00
# TODO(lenamartens) add in-depth error explanation to link to in module docs.
# Shared exception instance raised by the lowering rules of `check_p` below.
functionalization_error = ValueError(
'Cannot abstractly evaluate a checkify.check which was not'
' functionalized. This probably means you tried to stage'
' (jit/scan/pmap/...) a `check` without functionalizing it'
' through `checkify.checkify`.'
)
def check_lowering_rule(ctx, *args, err_tree, debug):
  """Lowers `check` to a Python callback which re-checks the error.

  Raises:
    ValueError: `functionalization_error`, unless the
      `jax_experimental_unsafe_xla_runtime_errors` config option is set.
  """
  if debug:
    # NOOP (check will only trigger when discharged)
    return []
  if not config.jax_experimental_unsafe_xla_runtime_errors:
    raise functionalization_error

  host_callback = functools.partial(python_err, err_tree)
  out_op, _, keep_alive = mlir.emit_python_callback(
      ctx,
      callback=host_callback,
      token=None,
      operands=args,
      operand_avals=list(ctx.avals_in),
      result_avals=list(ctx.avals_out),
      has_side_effect=True)
  ctx.module_context.add_keepalive(keep_alive)
  return out_op
def check_lowering_rule_unsupported(*a, debug, **k):
  """Lowering rule which always raises, except for debug checks (a no-op).

  Raises:
    ValueError: `functionalization_error` whenever `debug` is false.
  """
  if not debug:
    raise functionalization_error
  return []
def python_err(err_tree, *args):
  """Callback body: rebuilds the error from `args` and checks it."""
  rebuilt_error = tree_unflatten(err_tree, args)
  _check_error(rebuilt_error)
  return []
# TPU gets the rule which rejects un-functionalized checks; CPU and GPU get
# the callback-based lowering.
mlir.register_lowering(
    check_p, check_lowering_rule_unsupported, platform='tpu')
mlir.register_lowering(
    check_p, check_lowering_rule, platform='cpu')
mlir.register_lowering(
    check_p, check_lowering_rule, platform='gpu')
def check_batching_rule(batched_args, batch_dims, *, err_tree, debug):
  """Batching rule for `check`.

  Moves every argument's batch dimension to the front, rebuilds the error and
  checks it again. Returns no outputs and no output batch dimensions.
  """
  args_with_dims = list(zip(batched_args, batch_dims))
  # The batch size is read off the first argument which is actually mapped.
  size = next(x.shape[dim] for x, dim in args_with_dims
              if dim is not batching.not_mapped)
  front_batched = [batching.bdim_at_front(a, d, size)
                   for a, d in args_with_dims]
  err = tree_unflatten(err_tree, front_batched)
  _check_error(err, debug=debug)
  return [], []
batching.primitive_batchers[check_p] = check_batching_rule
def check_jvp_rule(primals, _, *, err_tree, debug):
  """JVP rule for `check`: returns no primal and no tangent outputs."""
  # Check primals, discard tangents.
  check_p.bind(*primals, err_tree=err_tree, debug=debug)
  no_primals_out, no_tangents_out = [], []
  return no_primals_out, no_tangents_out
ad.primitive_jvps[check_p] = check_jvp_rule
2021-10-29 09:23:27 -07:00
## checkify rules
def _get_current_traceback(skip_frames = 0) -> Optional[types.TracebackType]:
  """Builds a traceback object from the current Python call stack.

  Args:
    skip_frames: number of innermost frames to leave out unconditionally.

  Returns:
    A traceback chaining the remaining frames accepted by
    `traceback_util.include_frame`, or None if no frame was kept.
  """
  # TODO(lenamartens): use c++ version from XLA?
  import inspect
  tb = None
  # Only the frame objects are used, so `context=0` skips the source-line
  # lookup which `inspect.stack()` performs for every frame by default.
  for frame_info in inspect.stack(context=0):
    frame = frame_info.frame
    if skip_frames:
      skip_frames -= 1
    elif not traceback_util.include_frame(frame):
      continue
    else:
      # Frames are visited innermost first, so each new entry becomes the head
      # of the chain and points to the more deeply nested ones.
      tb = types.TracebackType(tb, frame, frame.f_lasti, frame.f_lineno)
  return tb
def summary() -> str:
  """Returns the summarized current source info as a string."""
  current_source_info = source_info_util.current()
  summarized = source_info_util.summarize(current_source_info)
  return str(summarized)
def nan_error_check(prim, error, enabled_errors, *in_vals, **params):
  """Binds `prim` and checks its output for NaNs.

  Returns:
    A pair of the primitive's output and the updated error.
  """
  prim_out = prim.bind(*in_vals, **params)
  updated_error = check_nans(prim, error, enabled_errors, prim_out)
  return prim_out, updated_error
def check_nans(prim, error, enabled_errors, out):
  """Adds a `NaNError` to `error` if `out` contains a NaN.

  `error` is returned unchanged when `NaNError` is not in `enabled_errors`.
  """
  if NaNError not in enabled_errors:
    return error

  def isnan(x):
    # PRNG key arrays are never reported as containing NaNs.
    if isinstance(x, prng.PRNGKeyArray):
      return False
    return jnp.any(jnp.isnan(x))

  if prim.multiple_results:
    per_output = [isnan(x) for x in out]
    any_nans = jnp.any(jnp.array(per_output))
  else:
    any_nans = isnan(out)
  return assert_func(error, any_nans, NaNError(summary(), prim.name))
2021-10-29 09:23:27 -07:00
# All primitives which can generate a NaN.
nan_primitives = [
    lax.acos_p, lax.acosh_p, lax.add_p, lax.asin_p, lax.asinh_p,
    lax.atan2_p, lax.atan_p, lax.atanh_p, lax.bessel_i0e_p,
    lax.bessel_i1e_p, lax.cbrt_p, lax.conv_general_dilated_p,
    lax.cos_p, lax.cosh_p, lax.cumlogsumexp_p, lax.cummax_p,
    lax.cummin_p, lax.cumprod_p, lax.cumsum_p, lax.digamma_p,
    lax.dot_general_p, lax.erf_inv_p, lax.erf_p, lax.erfc_p,
    lax.exp_p, lax.expm1_p, lax.fft_p, lax.igamma_grad_a_p,
    lax.igamma_p, lax.igammac_p, lax.integer_pow_p, lax.lgamma_p,
    lax.linear_solve_p, lax.log1p_p, lax.log_p, lax.logistic_p,
    lax.mul_p, lax.pad_p, lax.pow_p, lax.psum_p,
    lax.random_gamma_grad_p, lax.reduce_p, lax.reduce_prod_p,
    lax.reduce_sum_p, lax.reduce_window_p,
    lax.reduce_window_sum_p, lax.regularized_incomplete_beta_p,
    lax.rem_p, lax.rng_uniform_p, lax.rsqrt_p, lax.sin_p,
    lax.sinh_p, lax.sqrt_p, lax.sub_p, lax.tan_p, lax.tanh_p,
]

# Each of them gets `nan_error_check`, specialized to the primitive, as rule.
for prim in nan_primitives:
  error_checks[prim] = functools.partial(nan_error_check, prim)
def gather_error_check(error, enabled_errors, operand, start_indices, *,
                       dimension_numbers, slice_sizes, unique_indices,
                       indices_are_sorted, mode, fill_value):
  """Bind ``lax.gather_p`` and record an ``OOBError`` for out-of-bounds starts.

  The gather is always bound with the given parameters. When ``OOBError`` is
  enabled, each entry of ``start_indices`` is compared against ``0`` and
  against ``operand dimension - slice size`` for the operand dimension it maps
  to through ``dimension_numbers.start_index_map``.

  Returns:
    A pair ``(out, error)``: the gather result and the (possibly updated)
    error value.
  """
  out = lax.gather_p.bind(
      operand, start_indices, dimension_numbers=dimension_numbers,
      slice_sizes=slice_sizes, unique_indices=unique_indices,
      indices_are_sorted=indices_are_sorted, mode=mode, fill_value=fill_value)

  if OOBError not in enabled_errors:
    return out, error

  # compare to OOB masking logic in lax._gather_translation_rule
  index_map = np.array(dimension_numbers.start_index_map)
  largest_start = np.array(operand.shape)[index_map]
  largest_start -= np.array(slice_sizes)[index_map]
  # Prepend one size-1 axis per leading dimension of start_indices so the
  # per-component bounds broadcast against it.
  leading_axes = tuple(range(len(start_indices.shape) - 1))
  largest_start = jnp.expand_dims(largest_start, axis=leading_axes)

  too_low = start_indices < 0
  too_high = start_indices > largest_start.astype(start_indices.dtype)
  oob_mask = too_low | too_high

  payload = oob_payload(oob_mask, start_indices,
                        dimension_numbers.start_index_map, operand.shape)
  any_oob = jnp.any(oob_mask)
  oob_error = OOBError(summary(), "gather", operand.shape, payload)
  return out, assert_func(error, any_oob, oob_error)

error_checks[lax.gather_p] = gather_error_check
def div_error_check(error, enabled_errors, x, y):
  """Checks for division by zero and NaN.

  When ``DivisionByZeroError`` is enabled, an assertion that fires if any
  element of ``y`` equals zero is added to ``error``. The division itself is
  then bound through ``nan_error_check``, which also applies the NaN check.

  Returns:
    The ``(out, error)`` pair produced by ``nan_error_check`` for ``lax.div_p``.
  """
  updated_error = error
  if DivisionByZeroError in enabled_errors:
    divisor_has_zero = jnp.any(jnp.equal(y, 0))
    updated_error = assert_func(updated_error, divisor_has_zero,
                                DivisionByZeroError(summary()))
  return nan_error_check(lax.div_p, updated_error, enabled_errors, x, y)

error_checks[lax.div_p] = div_error_check
def oob_payload(oob_mask, indices, dims_map, operand_shape):
  """Describe the first out-of-bounds index so it can go in the error message.

  Args:
    oob_mask: boolean mask marking the out-of-bounds entries of ``indices``.
    indices: the index array that was checked.
    dims_map: for each position along the last axis of ``indices``, the
      operand axis that index component addresses.
    operand_shape: shape of the indexed operand.

  Returns:
    An ``int32`` array ``[oob_index, oob_axis, oob_axis_size]``.
  """
  # argmin over the negated mask gives the flat position of the first True.
  first_flat = jnp.argmin(jnp.logical_not(oob_mask))
  position = jnp.unravel_index(first_flat, indices.shape)
  # The last coordinate selects which index component was out of bounds.
  axis = jnp.array(dims_map)[position[-1]]
  axis_size = jnp.array(operand_shape)[axis]
  bad_index = jnp.ravel(indices)[first_flat]
  return jnp.array([bad_index, axis, axis_size], dtype=jnp.int32)
def scatter_oob(operand, indices, updates, dnums):
  """Compute whether any scatter index is out of bounds, plus an error payload.

  Args:
    operand: the array being scattered into (only its shape is read).
    indices: the scatter indices.
    updates: the update values (only its shape is read).
    dnums: the scatter dimension numbers; ``inserted_window_dims``,
      ``update_window_dims`` and ``scatter_dims_to_operand_dims`` are read.

  Returns:
    A pair ``(any_oob, payload)`` where ``any_oob`` is ``jnp.any`` of the
    out-of-bounds mask and ``payload`` comes from ``oob_payload``.
  """
  # Ref: see clamping code used in scatter_translation_rule
  slice_sizes = []
  window_pos = 0
  for operand_dim in range(len(operand.shape)):
    if operand_dim in dnums.inserted_window_dims:
      size = 1
    else:
      size = updates.shape[dnums.update_window_dims[window_pos]]
      window_pos += 1
    slice_sizes.append(size)

  largest_start = np.array(
      [operand.shape[d] - slice_sizes[d]
       for d in dnums.scatter_dims_to_operand_dims],
      np.int64)
  # Cap the bound at the largest value the index dtype can hold before it is
  # cast to that dtype below.
  largest_start = np.minimum(largest_start, np.iinfo(indices.dtype).max)
  last_axis = len(indices.shape) - 1
  largest_start = lax.broadcast_in_dim(largest_start, indices.shape,
                                       (last_axis,))

  below = jnp.less(indices, 0)
  above = jnp.greater(indices, largest_start.astype(indices.dtype))
  oob_mask = jnp.logical_or(below, above)
  payload = oob_payload(oob_mask, indices,
                        dnums.scatter_dims_to_operand_dims, operand.shape)
  return jnp.any(oob_mask), payload
def scatter_error_check(prim, error, enabled_errors, operand, indices, updates,
                        *, update_jaxpr, update_consts, dimension_numbers,
                        indices_are_sorted, unique_indices, mode):
  """Checks if indices are within bounds and update does not generate NaN.

  The scatter primitive is always bound with the given parameters. The two
  checks are independent of each other: the bounds check runs when
  ``OOBError`` is enabled, and the NaN check runs when ``NaNError`` is
  enabled.

  Args:
    prim: the scatter primitive to bind (one of the ``lax.scatter*_p``
      primitives registered below).
    error: the error value accumulated so far.
    enabled_errors: the enabled error classes.
    operand, indices, updates: the scatter operands.
    update_jaxpr, update_consts, dimension_numbers, indices_are_sorted,
      unique_indices, mode: scatter parameters, forwarded unchanged.

  Returns:
    A pair ``(out, error)``: the scatter result and the updated error value.
  """
  out = prim.bind(
      operand, indices, updates, update_jaxpr=update_jaxpr,
      update_consts=update_consts, dimension_numbers=dimension_numbers,
      indices_are_sorted=indices_are_sorted, unique_indices=unique_indices,
      mode=mode)

  # Only the bounds check is gated on OOBError. Returning early here would
  # also skip the NaN check, so NaN-only configurations would never check
  # scatter outputs.
  if OOBError in enabled_errors:
    out_of_bounds, payload = scatter_oob(operand, indices, updates,
                                         dimension_numbers)
    oob_error = OOBError(summary(), prim.name, operand.shape, payload)
    error = assert_func(error, out_of_bounds, oob_error)

  # check_nans returns `error` unchanged when NaNError is not enabled.
  return out, check_nans(prim, error, enabled_errors, out)

error_checks[lax.scatter_p] = functools.partial(scatter_error_check, lax.scatter_p)
error_checks[lax.scatter_add_p] = functools.partial(scatter_error_check,
                                                    lax.scatter_add_p)
error_checks[lax.scatter_mul_p] = functools.partial(scatter_error_check,
                                                    lax.scatter_mul_p)
error_checks[lax.scatter_min_p] = functools.partial(scatter_error_check,
                                                    lax.scatter_min_p)
error_checks[lax.scatter_max_p] = functools.partial(scatter_error_check,
                                                    lax.scatter_max_p)
def cond_error_check(error, enabled_errors, index, *ops, branches, linear):
  """Checkify every branch of a ``cond`` and bind the checked ``cond``.

  Each branch is checkified twice: a first pass with the incoming ``error``
  collects the effects of every branch, and a second pass uses an error that
  has placeholders for the union of those effects.

  Returns:
    A pair ``(out, error)``: the non-error outputs of the checked ``cond``
    and an error whose metadata is merged across all branches.
  """
  # First pass: only the effects reported for each branch are needed.
  first_pass = [checkify_jaxpr(branch, error, enabled_errors)
                for branch in branches]
  branch_effects = [effects for _, (_, effects) in first_pass]
  merged_error = error._add_placeholder_effects(set().union(*branch_effects))

  # Second pass: re-checkify every branch against the merged error.
  second_pass = [checkify_jaxpr(branch, merged_error, enabled_errors)
                 for branch in branches]
  new_branches = tuple(checked for checked, _ in second_pass)
  out_trees = tuple(tree for _, (tree, _) in second_pass)

  flat_error, _ = tree_flatten(merged_error)
  # The error values are passed as extra leading operands, marked non-linear.
  new_linear = (*[False] * len(flat_error), *linear)
  err_and_outs = lax.cond_p.bind(
      index, *flat_error, *ops,
      branches=new_branches, linear=new_linear)

  # we need to merge metadata across out_trees (a tuple)
  # maybe there's a better way to do this, but we can use the outs
  # to unflatten all trees.
  err0, *out = tree_unflatten(out_trees[0], err_and_outs)
  merged_metadata = err0._metadata
  for other_tree in out_trees[1:]:
    branch_err, *_ = tree_unflatten(other_tree, err_and_outs)
    merged_metadata = {**merged_metadata, **branch_err._metadata}
  return out, err0._replace(_metadata=merged_metadata)

error_checks[lax.cond_p] = cond_error_check
def scan_error_check(error, enabled_errors, *in_flat, reverse, length, jaxpr,
                     num_consts, num_carry, linear, unroll):
  """Checkify the body of a ``scan`` and bind the checked ``scan``.

  The body is checkified once to discover its effects, then again with an
  error that has placeholders for those effects. The flattened error values
  are passed to the checked scan as additional carry values.

  Returns:
    A pair ``(out, error)``: the non-error outputs of the checked scan and
    the error it produced.
  """
  consts, carry, xs = split_list(in_flat, [num_consts, num_carry])

  # First pass: only the body's effects are needed.
  _, (_, effects) = checkify_jaxpr(jaxpr, error, enabled_errors)
  merged_error = error._add_placeholder_effects(effects)
  checked_body, (out_tree, _) = checkify_jaxpr(jaxpr, merged_error,
                                               enabled_errors)

  flat_error_vals, _ = tree_flatten(merged_error)
  num_error_vals = len(flat_error_vals)
  # The checked body takes the error values first; move the consts in front
  # of them so the binders match the order of `new_in_flat` below.
  tomove = ([False] * num_error_vals + [True] * len(consts)
            + [False] * (len(carry) + len(xs)))
  checked_jaxpr = pe.move_binders_to_front(checked_body, tomove)

  new_linear = (*[False] * num_error_vals, *linear)
  new_in_flat = [*consts, *flat_error_vals, *carry, *xs]
  err_and_out = lax.scan_p.bind(
      *new_in_flat, reverse=reverse, length=length, jaxpr=checked_jaxpr,
      num_consts=len(consts), num_carry=len(carry) + num_error_vals,
      linear=new_linear, unroll=unroll)
  err, *out = tree_unflatten(out_tree, err_and_out)
  return out, err

error_checks[lax.scan_p] = scan_error_check
def checkify_while_body_jaxpr(cond_jaxpr, body_jaxpr, error, enabled_errors, c_consts):
  """Checkify a while-loop body that also evaluates the loop condition.

  The function that gets checkified runs the body and then applies the
  condition to the body's outputs, discarding the condition's result, so
  that errors of the *next* condition evaluation are included.

  Args:
    cond_jaxpr: the loop condition jaxpr.
    body_jaxpr: the loop body jaxpr; its ``in_avals`` are the input avals of
      the checkified function.
    error: the error value to checkify with.
    enabled_errors: the enabled error classes.
    c_consts: the constants passed to the condition ahead of the body outputs.

  Returns:
    Whatever ``checkify_fun_to_jaxpr`` returns for the combined function.
  """
  run_cond = core.jaxpr_as_fun(cond_jaxpr)
  run_body = core.jaxpr_as_fun(body_jaxpr)

  def new_body_f(*vals):
    body_out = run_body(*vals)
    # This checks if the next cond application will error
    run_cond(*c_consts, *body_out)
    return body_out

  return checkify_fun_to_jaxpr(lu.wrap_init(new_body_f), error, enabled_errors,
                               body_jaxpr.in_avals)
def ignore_error_output_jaxpr(jaxpr, num_error_vals):
"""Constructs a checked jaxpr which does not output its error value."""
2021-12-08 15:07:43 +00:00
consts = jaxpr.consts
jaxpr = jaxpr.jaxpr
new_jaxpr = jaxpr.replace(outvars=jaxpr.outvars[num_error_vals:])
2021-12-08 15:07:43 +00:00
return core.ClosedJaxpr(new_jaxpr, consts)
def while_loop_error_check(error, enabled_errors, *in_flat, cond_nconsts,
                           cond_jaxpr, body_nconsts, body_jaxpr):
  """Checkify a ``while`` loop and bind the checked loop.

  The condition is evaluated once up front so errors of its first application
  are recorded. The body is checkified together with the condition (see
  ``checkify_while_body_jaxpr``), and the condition passed to the loop is a
  checked condition with its error outputs removed. The flattened error
  values are placed between the body consts and the carry.

  Raises:
    ValueError: if the condition's first output aval has a non-empty shape
      (a batched while-loop).

  Returns:
    A pair ``(out, error)``: the non-error outputs of the checked loop and
    the error it produced.
  """
  if cond_jaxpr.out_avals[0].shape:
    # TODO(lenamartens, sharadmv): support batched while.
    raise ValueError('Checkify does not support batched while-loops '
                     '(checkify-of-vmap-of-while). \nHint: if possible, move '
                     'the vmap to the outer level to get '
                     'vmap-of-checkify-of-while.')

  initial_err_vals, _ = tree_flatten(error)
  c_consts, b_consts, carry = split_list(in_flat, [cond_nconsts, body_nconsts])

  # Check if the first cond application will error.
  first_cond_jaxpr, (cond_out_tree, _) = checkify_jaxpr(
      cond_jaxpr, error, enabled_errors)
  first_cond_outs = core.jaxpr_as_fun(first_cond_jaxpr)(
      *initial_err_vals, *c_consts, *carry)
  error, _ = tree_unflatten(cond_out_tree, first_cond_outs)

  # First body pass: only the effects are needed.
  _, (_, error_effects) = checkify_while_body_jaxpr(
      cond_jaxpr, body_jaxpr, error, enabled_errors, c_consts)
  # merged error!
  error = error._add_placeholder_effects(error_effects)
  body_with_err_first, (body_out_tree, _) = checkify_while_body_jaxpr(
      cond_jaxpr, body_jaxpr, error, enabled_errors, c_consts)

  err_vals = jtu.tree_leaves(error)
  num_error_vals = len(err_vals)
  carry_flags = [False] * len(carry)
  # Move the body consts in front of the error values.
  body_to_move = [False] * num_error_vals + [True] * body_nconsts + carry_flags
  checked_body_jaxpr = pe.move_binders_to_front(body_with_err_first,
                                                body_to_move)

  # The loop's cond must not output error values, so they are dropped from
  # the checked cond before its consts are moved to the front.
  checked_cond_jaxpr, _ = checkify_jaxpr(cond_jaxpr, error, enabled_errors)
  cond_without_err_out = ignore_error_output_jaxpr(checked_cond_jaxpr,
                                                   num_error_vals)
  cond_to_move = [False] * num_error_vals + [True] * cond_nconsts + carry_flags
  compat_cond_jaxpr = pe.move_binders_to_front(cond_without_err_out,
                                               cond_to_move)

  new_in_flat = [*c_consts, *b_consts, *err_vals, *carry]
  all_out_vals = lax.while_p.bind(
      *new_in_flat, cond_nconsts=cond_nconsts, cond_jaxpr=compat_cond_jaxpr,
      body_nconsts=body_nconsts, body_jaxpr=checked_body_jaxpr)
  # body_out_tree will have all the metadata of cond because it executes a cond!
  # only need to merge metadata on the input error.
  error, *out = tree_unflatten(body_out_tree, all_out_vals)
  return out, error

error_checks[lax.while_p] = while_loop_error_check
def pjit_error_check(error, enabled_errors, *vals_in, jaxpr,
                     in_shardings, out_shardings, resource_env,
                     donated_invars, name,
                     in_positional_semantics, out_positional_semantics,
                     keep_unused, inline):
  """Checkify the jaxpr of a ``pjit`` and bind the checked ``pjit``.

  The flattened error values are prepended to the inputs. Every per-argument
  parameter gets a matching prefix: a replicated sharding over the physical
  mesh devices, a positional-semantics value, and ``False`` for donation.

  Returns:
    A pair ``(out, error)``: the non-error outputs of the checked ``pjit``
    and the error it produced.
  """
  checked_jaxpr, (out_tree, effects) = checkify_jaxpr(jaxpr, error,
                                                      enabled_errors)
  out_error = error._add_placeholder_effects(effects)

  flat_error_vals = jtu.tree_leaves(error)
  num_error_vals = len(flat_error_vals)
  num_out_error_vals = len(jtu.tree_leaves(out_error))
  new_vals_in = [*flat_error_vals, *vals_in]

  # Error values use a sharding replicated over all devices of the mesh.
  replicated = OpShardingSharding.get_replicated(
      list(resource_env.physical_mesh.devices.flat))
  new_in_shardings = (*[replicated] * num_error_vals, *in_shardings)
  new_out_shardings = (*[replicated] * num_out_error_vals, *out_shardings)

  if config.jax_array:
    pos_sem = maps._PositionalSemantics.GLOBAL
  else:
    pos_sem = maps._positional_semantics.val

  # Wrap a single positional-semantics value in a tuple so it can be spliced.
  if not isinstance(in_positional_semantics, Iterable):
    in_positional_semantics = (in_positional_semantics,)
  if not isinstance(out_positional_semantics, Iterable):
    out_positional_semantics = (out_positional_semantics,)
  error_pos_sems = [pos_sem] * num_error_vals
  new_positional_sems_in = (*error_pos_sems, *in_positional_semantics)
  # NOTE(review): this output prefix is sized by the *input* error's leaves,
  # while `new_out_shardings` is sized by `out_error`'s leaves. Confirm that
  # this asymmetry is intended.
  new_positional_sems_out = (*error_pos_sems, *out_positional_semantics)
  new_donated_invars = (*[False] * num_error_vals, *donated_invars)

  err_and_out = pjit.pjit_p.bind(
      *new_vals_in,
      jaxpr=checked_jaxpr,
      in_shardings=new_in_shardings,
      out_shardings=new_out_shardings,
      resource_env=resource_env,
      donated_invars=new_donated_invars,
      name=name,
      in_positional_semantics=new_positional_sems_in,
      out_positional_semantics=new_positional_sems_out,
      keep_unused=keep_unused,
      inline=inline)
  err, *out = tree_unflatten(out_tree, err_and_out)
  return out, err

error_checks[pjit.pjit_p] = pjit_error_check
def check_discharge_rule(error, enabled_errors, *args, err_tree, debug):
  """Fold the error carried by a ``check_p`` call into ``error``.

  ``args`` are unflattened with ``err_tree`` into an error value. Each of its
  effects whose ``error_type`` is in ``enabled_errors`` is merged into the
  incoming ``error``; the remaining effects are collected into a separate
  error that is currently not used (see the TODO below).

  Returns:
    A pair ``([], error)``: no outputs, and the error with the enabled
    effects merged in.
  """
  del debug
  new_error = tree_unflatten(err_tree, args)
  # Split up new_error into error to be functionalized if it's included in
  # enabled_errors (=discharged_error) and an error to be defunctionalized if
  # it's not included (=recharged_error)
  discharged_error, recharged_error = error, init_error
  for effect, pred in new_error._pred.items():
    code = new_error._code[effect]
    payload = new_error._payload[effect]
    if effect.error_type in enabled_errors:
      discharged_error = update_error(discharged_error, pred, code, {},
                                      payload, effect)
    else:
      recharged_error = update_error(recharged_error, pred, code, {},
                                     payload, effect)

  incoming_metadata = new_error._metadata
  # On key clashes, metadata already on discharged_error takes precedence.
  discharged_error = discharged_error._replace(
      _metadata={**incoming_metadata, **discharged_error._metadata})
  recharged_error = recharged_error._replace(_metadata=incoming_metadata)
  # TODO(lenamartens): we actually need to recharge, but this would be a
  # breaking API change so leaving for a follow-up.
  # check_error(recharged_error)
  return [], discharged_error

error_checks[check_p] = check_discharge_rule
## checkify api

# Sets of error classes that can be passed as the `errors` argument of
# `checkify`. The first four each hold a single error class; the rest are
# unions of those.
user_checks = frozenset((FailedCheckError,))
nan_checks = frozenset((NaNError,))
index_checks = frozenset((OOBError,))
div_checks = frozenset((DivisionByZeroError,))
float_checks = nan_checks.union(div_checks)
automatic_checks = float_checks.union(index_checks)
all_checks = automatic_checks.union(user_checks)

# Type variable for the return type of the function passed to `checkify`.
Out = TypeVar('Out')
def checkify(fun: Callable[..., Out],
             errors: FrozenSet[ErrorCategory] = user_checks
             ) -> Callable[..., Tuple[Error, Out]]:
  """Functionalize `check` calls in `fun`, and optionally add run-time error checks.

  Run-time errors are either user-added :func:`~check` assertions, or
  automatically added checks like NaN checks, depending on the ``errors``
  argument.

  The returned function will return an Error object `err` along with the output
  of the original function. ``err.get()`` will either return ``None`` (if no
  error occurred) or a string containing an error message. This error message
  will correspond to the first error which occurred. ``err.throw()`` will raise
  a ValueError with the error message if an error occurred.

  By default only user-added :func:`~check` assertions are enabled. You can
  enable automatic checks through the ``errors`` argument.

  The automatic check sets which can be enabled, and when an error is generated:
    - ``user_checks``: a :func:`~check` evaluated to False.
    - ``nan_checks``: a floating-point operation generated a NaN value
      as output.
    - ``div_checks``: a division by zero.
    - ``index_checks``: an index was out-of-bounds.

  Multiple categories can be enabled together by passing in an error `Set` (eg.
  ``errors=nan_checks``). Multiple sets can be re-combined (eg.
  ``errors=float_checks|user_checks``)

  Args:
    fun: Callable which can contain user checks (see :func:`~check`).
    errors: A set of ErrorCategory values which defines the set of enabled
      checks. By default only explicit ``checks`` are enabled
      (``user_checks``). You can also for example enable NAN and
      DIV errors by passing the ``float_checks`` set, or for
      example combine multiple sets through set operations
      (``float_checks | user_checks``)

  Returns:
    A function which accepts the same arguments as ``fun`` and returns as output
    a pair where the first element is an ``Error`` value, representing the first
    failed :func:`~check`, and the second element is the original output of
    ``fun``.

  For example:

    >>> import jax
    >>> import jax.numpy as jnp
    >>> from jax.experimental import checkify
    >>>
    >>> @jax.jit
    ... def f(x):
    ...   y = jnp.sin(x)
    ...   return x+y
    >>> err, out = checkify.checkify(f, errors=checkify.float_checks)(jnp.inf)
    >>> err.throw()  # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
      ...
    jax._src.checkify.JaxRuntimeError: nan generated by primitive: sin

  """
  @traceback_util.api_boundary
  def checked_fun(*args, **kwargs):
    # Flatten positional and keyword arguments together so the function can
    # be run on a flat list of leaves.
    flat_args, in_tree = tree_flatten((args, kwargs))
    flat_fun, out_tree_thunk = flatten_fun(lu.wrap_init(fun), in_tree)
    error, flat_out = checkify_flat(flat_fun, errors, *flat_args)
    # The output tree is only available after the flat function has run.
    return error, tree_unflatten(out_tree_thunk(), flat_out)
  return checked_fun