rocm_jax/jax/_src/checkify.py

928 lines
37 KiB
Python

# Copyright 2021 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import enum
from dataclasses import dataclass
from functools import partial
import itertools as it
from typing import Union, Optional, Callable, Dict, Tuple, TypeVar, FrozenSet, Iterable
import numpy as np
import jax.numpy as jnp
from jax import core
from jax import linear_util as lu
from jax.api_util import flatten_fun
from jax.experimental import pjit
from jax.experimental import maps
from jax.interpreters import ad
from jax.interpreters import batching
from jax.interpreters import mlir
from jax.interpreters import partial_eval as pe
from jax.experimental.sharding import OpShardingSharding
from jax.tree_util import tree_flatten, tree_unflatten, register_pytree_node
from jax._src import source_info_util, traceback_util
from jax._src.lax import control_flow as cf
from jax._src.config import config
from jax import lax
from jax._src.util import (as_hashable_function, unzip2, split_list, safe_map,
safe_zip)
source_info_util.register_exclusion(__file__)
traceback_util.register_exclusion(__file__)
map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip
## Utils
def popattr(obj, attrname):
val = getattr(obj, attrname)
delattr(obj, attrname)
return val
def setnewattr(obj, name, val):
sentinel = object()
assert getattr(obj, name, sentinel) is sentinel
setattr(obj, name, val)
## Error value data type and functional assert.
Bool = Union[bool, core.Tracer]
Int = Union[int, core.Tracer]
Payload = Union[np.ndarray, jnp.ndarray, core.Tracer]
# For now, the payload needs to be a fixed-size array: 3 int32s, used for the
# OOB message.
# TODO(lenamartens): Relax this fixed-size constraint.
init_payload = lambda: np.ones((3,), np.int32)
def _format_msg(msg, payloads):
payload_mapping = {}
for i, pl in enumerate(payloads):
payload_mapping[f'payload{i}'] = pl
return msg.format(**payload_mapping)
@dataclass(frozen=True)
class Error:
err: Bool
code: Int
msgs: Dict[int, str]
# There might be many msgs with a {payload}, but only one msg will
# ever be active for an Error instance, so only one Payload is tracked.
payload: Payload
def __init__(self, err: Bool, code: Int, msgs: Dict[int, str], payload: Optional[Payload] = None):
# We can't directly assign to members of a frozen dataclass, even in __init__.
object.__setattr__(self, "err", err)
object.__setattr__(self, "code", code)
object.__setattr__(self, "msgs", msgs)
object.__setattr__(self, "payload",
init_payload() if payload is None else payload)
def get(self) -> Optional[str]:
"""Returns error message is error happened, None if no error happened."""
assert np.shape(self.err) == np.shape(self.code)
if np.size(self.err) == 1:
if self.err:
return _format_msg(self.msgs[int(self.code)], self.payload)
else:
return '\n'.join(
f'at mapped index {", ".join(map(str, idx))}: ' # type: ignore
f'{_format_msg(self.msgs[int(self.code[idx])], self.payload[idx])}' # type: ignore
for idx, e in np.ndenumerate(self.err) if e) or None
return None
def throw(self):
"""Throw ValueError with error message if error happened."""
err = self.get()
if err:
raise ValueError(err)
register_pytree_node(Error,
lambda e: ((e.err, e.code, e.payload),
tuple(sorted(e.msgs.items()))),
lambda msgs, data: Error(data[0], data[1], # type: ignore
dict(msgs), data[2])) # type: ignore
init_error = Error(False, 0, {})
next_code = it.count(1).__next__ # globally unique ids, could be uuid4
def assert_func(error: Error, err: Bool, msg: str,
payload: Optional[Payload]) -> Error:
code = next_code()
payload = init_payload() if payload is None else payload
out_err = error.err | err
out_code = lax.select(error.err, error.code, code)
out_payload = lax.select(error.err, error.payload, payload)
return Error(out_err, out_code, {code: msg, **error.msgs}, out_payload)
## Checkify transformation for plumbing functional error values.
class CheckifyTracer(core.Tracer):
def __init__(self, trace, val):
self._trace = trace
self.val = val
core.get_aval(val), val
aval = property(lambda self: core.get_aval(self.val))
full_lower = lambda self: self
class CheckifyTrace(core.Trace):
pure = lift = lambda self, val: CheckifyTracer(self, val)
def __init__(self, main: core.MainTrace, sublevel: core.Sublevel,
enabled_errors: FrozenSet['ErrorCategory']) -> None:
self.main = main
self.level = main.level
self.sublevel = sublevel
self.main.enabled_errors = enabled_errors
def sublift(self, tracer):
return CheckifyTracer(self, tracer.val)
def process_primitive(self, primitive, tracers, params):
in_vals = [t.val for t in tracers]
rule = error_checks.get(primitive)
if rule:
out, self.main.error = rule(self.main.error, self.main.enabled_errors, # type: ignore
*in_vals, **params)
else:
out = primitive.bind(*in_vals, **params)
if primitive.multiple_results:
return [CheckifyTracer(self, x) for x in out]
else:
return CheckifyTracer(self, out)
def process_call(self, primitive, f, tracers, params):
in_vals = [t.val for t in tracers]
e = popattr(self.main, 'error')
f, msgs = checkify_subtrace(f, self.main, tuple(e.msgs.items()))
if 'donated_invars' in params:
params = dict(params, donated_invars=(False, False, False,
*params['donated_invars']))
err, code, payload, *out_vals = primitive.bind(f, e.err, e.code, e.payload,
*in_vals, **params)
setnewattr(self.main, 'error', Error(err, code, msgs(), payload))
return [CheckifyTracer(self, x) for x in out_vals]
def process_map(self, primitive, f, tracers, params):
in_vals = [t.val for t in tracers]
e = popattr(self.main, 'error')
f, msgs = checkify_subtrace(f, self.main, tuple(e.msgs.items()))
@as_hashable_function(closure=params['out_axes_thunk'])
def new_out_axes_thunk():
return (0, 0, 0, *params['out_axes_thunk']())
params_ = dict(params, in_axes=(None, None, None, *params['in_axes']),
out_axes_thunk=new_out_axes_thunk,
donated_invars=(False, False, False, *params['donated_invars']))
errs, codes, payloads, *outs = primitive.bind(f, e.err, e.code, e.payload,
*in_vals, **params_)
err, code, payload = _reduce_any_error(errs, codes, payloads)
setnewattr(self.main, 'error', Error(err, code, msgs(), payload))
return [CheckifyTracer(self, x) for x in outs]
def post_process_call(self, primitive, tracers, params):
vals = [t.val for t in tracers]
main = self.main
e = popattr(main, 'error')
err, code, payload, main.msgs = e.err, e.code, e.payload, e.msgs
def todo(vals):
err, code, payload, *vals = vals
setnewattr(main, 'error', Error(err, code, popattr(main, 'msgs'), payload))
trace = main.with_cur_sublevel()
return [CheckifyTracer(trace, x) for x in vals]
return (err, code, payload, *vals), todo
def post_process_map(self, primitive, tracers, params):
vals = [t.val for t in tracers]
main = self.main
e = popattr(main, 'error')
err, code, payload, main.msgs = e.err, e.code, e.payload, e.msgs
def todo(vals):
errs, codes, payloads, *vals = vals
err, code, payload = _reduce_any_error(errs, codes, payloads)
setnewattr(main, 'error', Error(err, code, popattr(main, 'msgs'), payload))
trace = main.with_cur_sublevel()
return [CheckifyTracer(trace, x) for x in vals]
def out_axes_transform(out_axes):
return (0, 0, 0, *out_axes)
return (err, code, payload, *vals), (todo, out_axes_transform)
def process_custom_jvp_call(self, prim, fun, jvp, tracers):
in_vals = [t.val for t in tracers]
e = popattr(self.main, 'error')
msgs = tuple(e.msgs.items())
fun, msgs1 = checkify_subtrace(fun, self.main, msgs)
jvp, msgs2 = checkify_custom_jvp_subtrace(jvp, self.main, msgs)
err, code, payload, *out_vals = prim.bind(fun, jvp, e.err, e.code,
e.payload, *in_vals)
fst, out_msgs = lu.merge_linear_aux(msgs1, msgs2)
setattr(self.main, 'error', Error(err, code, out_msgs, payload))
return [CheckifyTracer(self, x) for x in out_vals]
def post_process_custom_jvp_call(self, out_tracers, jvp_was_run):
if jvp_was_run:
msg = ("support for custom_jvp rules which close over checkify values is "
"not implemented. If you see this, open an issue at "
"https://github.com/google/jax/issues!")
raise NotImplementedError(msg)
vals = [t.val for t in out_tracers]
main = self.main
e = popattr(main, 'error')
err, code, payload, main.msgs = e.err, e.code, e.payload, e.msgs
def todo(vals):
err, code, payload, *vals = vals
setnewattr(main, 'error', Error(err, code, popattr(main, 'msgs'), payload))
trace = main.with_cur_sublevel()
return [CheckifyTracer(trace, x) for x in vals]
return (err, code, payload, *vals), todo
def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees):
in_vals = [t.val for t in tracers]
e = popattr(self.main, 'error')
msgs = tuple(e.msgs.items())
fun, msgs1 = checkify_subtrace(fun, self.main, msgs)
fwd, msgs2 = checkify_custom_vjp_subtrace(fwd, self.main, msgs)
out = prim.bind(fun, fwd, bwd, e.err, e.code, e.payload,
*in_vals, out_trees=out_trees)
fst, out_msgs = lu.merge_linear_aux(msgs1, msgs2)
if fst:
err, code, payload, *out = out
else:
err, code, payload = e.err, e.code, e.payload # forward input error values to output
setattr(self.main, 'error', Error(err, code, out_msgs, payload))
return [CheckifyTracer(self, x) for x in out]
def _reduce_any_error(errs, codes, payloads):
reduced_idx = jnp.argsort(errs)[-1]
return errs[reduced_idx], codes[reduced_idx], payloads[reduced_idx]
ErrorCheckRule = Callable # (Error, FrozenSet[ErrorCategory], *in_vals, **params) -> (Any, Error)
error_checks: Dict[core.Primitive, ErrorCheckRule] = {}
def checkify_flat(fun: lu.WrappedFun, enabled_errors: FrozenSet['ErrorCategory'],
*args):
fun, msgs = checkify_subtrace(fun)
fun = checkify_traceable(fun, tuple(init_error.msgs.items()), enabled_errors)
err, code, payload, *outvals = fun.call_wrapped(init_error.err,
init_error.code,
init_error.payload, *args)
return (err, code, payload, outvals), msgs()
@lu.transformation
def checkify_traceable(msgs, enabled_errors, err, code, payload, *args):
with core.new_main(CheckifyTrace, enabled_errors=enabled_errors) as main:
outs = yield (main, msgs, err, code, payload, *args), {}
del main
yield outs
@lu.transformation_with_aux
def checkify_subtrace(main, msgs, err, code, payload, *args):
setnewattr(main, 'error', Error(err, code, dict(msgs), payload))
trace = main.with_cur_sublevel()
in_tracers = [CheckifyTracer(trace, x) for x in args]
out = yield in_tracers, {}
out_tracers = map(trace.full_raise, out)
out_vals = [t.val for t in out_tracers]
err, code, payload, msgs = main.error.err, main.error.code, main.error.payload, main.error.msgs
del main.error
yield (err, code, payload, *out_vals), msgs
@lu.transformation_with_aux
def checkify_custom_jvp_subtrace(main, msgs, *args):
# Like checkify_subtrace, but used specifically on the custom JVP rules
# associated with a custom_jvp. This code is called in the context of a
# jvp-of-checkify-of-custom_jvp. It takes both primal and tangent inputs,
# flattened into a single args tuple, and similarly must produce flattened
# primal and tangent outputs. Both primals and tangents include error values,
# but the tangent error values are trivially zero.
# The types to have in mind are:
# jvp : (a -> b) -> (a, T a) -> (b, T b)
# checkify : (a -> b) -> a -> Err b
# jvp-of-checkify : (a -> b) -> (a, T a) -> (Err b, T (Err b))
# where because Err is a pytree, we necessarily have T (Err b) = Err' (T b)
# where the other Err' components are trivial (of float0 dtype).
# Semantically, we don't add checks to the JVP rule. To check the result of a
# JVP rule, one must instead use checkify-of-jvp. Thus this implementation
# just forwards the input error and code (and trivial tangents) to the output.
del main
n, ragged = divmod(len(args), 2)
assert not ragged
(err,), (code,), (payload,), primals = split_list(args[:n], [1, 1, 1])
(err_dot,), (code_dot,), (pl_dot,), tangents = split_list(args[n:], [1, 1, 1])
outs = yield (*primals, *tangents), {}
m, ragged = divmod(len(outs), 2)
assert not ragged
out_primals, out_tangents = outs[:m], outs[m:]
yield (err, code, payload, *out_primals,
err_dot, code_dot, pl_dot, *out_tangents), dict(msgs)
@lu.transformation_with_aux
def checkify_custom_vjp_subtrace(main, msgs, err, code, payload, *args):
# We don't add any checks; just drop input error values.
del main, err, code, payload
outs = yield args, {}
yield outs, dict(msgs)
# TODO take (error_aval, code_aval) instead of error here?
def checkify_jaxpr(jaxpr, error, enabled_errors):
f = lu.wrap_init(core.jaxpr_as_fun(jaxpr))
return checkify_fun_to_jaxpr(f, error, enabled_errors, jaxpr.in_avals)
def checkify_fun_to_jaxpr(f, error, enabled_errors, in_avals):
f, msgs = checkify_subtrace(f)
f = checkify_traceable(f, tuple(error.msgs.items()), enabled_errors)
err_aval = core.raise_to_shaped(core.get_aval(error.err))
code_aval = core.raise_to_shaped(core.get_aval(error.code))
payload_aval = core.raise_to_shaped(core.get_aval(error.payload))
avals_in = [err_aval, code_aval, payload_aval, *in_avals]
jaxpr_out, _, literals_out = pe.trace_to_jaxpr_dynamic(f, avals_in)
return core.ClosedJaxpr(jaxpr_out, literals_out), msgs()
## assert primitive
def check(pred: Bool, msg: str) -> None:
"""Check a predicate, add an error with msg if predicate is False.
This is an effectful operation, and can't be staged (jitted/scanned/...).
Before staging a function with checks, ``checkify`` it!
Args:
pred: if False, an error is added.
msg: error message if error is added.
For example:
>>> import jax
>>> import jax.numpy as jnp
>>> from jax.experimental import checkify
>>> def f(x):
... checkify.check(x!=0, "cannot be zero!")
... return 1/x
>>> checked_f = checkify.checkify(f)
>>> err, out = jax.jit(checked_f)(0)
>>> err.throw() # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
ValueError: cannot be zero! (check failed at ...)
"""
if not is_scalar_pred(pred):
raise TypeError(f'check takes a scalar pred as argument, got {pred}')
code = next_code()
msg += f' (check failed at {summary()})'
return check_error(Error(jnp.logical_not(pred), code, {code: msg}))
def is_scalar_pred(pred) -> bool:
return (isinstance(pred, bool) or
isinstance(pred, jnp.ndarray) and pred.shape == () and
pred.dtype == jnp.dtype('bool'))
def check_error(error: Error) -> None:
"""Raise an Exception if ``error`` represents a failure. Functionalized by ``checkify``.
The semantics of this function are equivalent to:
>>> def check_error(err: Error) -> None:
... err.throw() # can raise ValueError
But unlike that implementation, ``check_error`` can be functionalized using
the ``checkify`` transformation.
This function is similar to ``check`` but with a different signature: whereas
``check`` takes as arguments a boolean predicate and a new error message
string, this function takes an ``Error`` value as argument. Both ``check``
and this function raise a Python Exception on failure (a side-effect), and
thus cannot be staged out by ``jit``, ``pmap``, ``scan``, etc. Both also can
be functionalized by using ``checkify``.
But unlike ``check``, this function is like a direct inverse of ``checkify``:
whereas ``checkify`` takes as input a function which can raise a Python
Exception and produces a new function without that effect but which produces
an ``Error`` value as output, this ``check_error`` function can accept an
``Error`` value as input and can produce the side-effect of raising an
Exception. That is, while ``checkify`` goes from functionalizable Exception
effect to error value, this ``check_error`` goes from error value to
functionalizable Exception effect.
``check_error`` is useful when you want to turn checks represented by an
``Error`` value (produced by functionalizing ``checks`` via ``checkify``)
back into Python Exceptions.
Args:
error: Error to check.
For example, you might want to functionalize part of your program through
checkify, stage out your functionalized code through ``jit``, then re-inject
your error value outside of the ``jit``:
>>> import jax
>>> from jax.experimental import checkify
>>> def f(x):
... checkify.check(x>0, "must be positive!")
... return x
>>> def with_inner_jit(x):
... checked_f = checkify.checkify(f)
... # a checkified function can be jitted
... error, out = jax.jit(checked_f)(x)
... checkify.check_error(error)
... return out
>>> _ = with_inner_jit(1) # no failed check
>>> with_inner_jit(-1) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
ValueError: must be positive!
>>> # can re-checkify
>>> error, _ = checkify.checkify(with_inner_jit)(-1)
"""
if np.shape(error.err):
err, code, payload = _reduce_any_error(error.err, error.code, error.payload)
else:
err, code, payload = error.err, error.code, error.payload
err = core.raise_as_much_as_possible(err)
return assert_p.bind(err, code, payload, msgs=error.msgs)
assert_p = core.Primitive('assert') # TODO: rename to check?
assert_p.multiple_results = True # zero results
@assert_p.def_impl
def assert_impl(err, code, payload, *, msgs):
Error(err, code, msgs, payload).throw()
return []
CheckEffect = object()
@assert_p.def_effectful_abstract_eval
def assert_abstract_eval(err, code, payload, *, msgs):
return [], {CheckEffect}
def assert_lowering_rule(*a, **k):
# TODO(lenamartens): actually throw an error through emit_python_callable
# TODO(lenamartens) add in-depth error explanation to link to in module docs.
raise ValueError('Cannot abstractly evaluate a checkify.check which was not'
' functionalized. This probably means you tried to stage'
' (jit/scan/pmap/...) a `check` without functionalizing it'
' through `checkify.checkify`.'
)
mlir.register_lowering(assert_p, assert_lowering_rule)
mlir.lowerable_effects.add(CheckEffect)
cf.allowed_effects.add(CheckEffect)
def assert_batching_rule(batched_args, batch_dims, *, msgs):
size = next(x.shape[dim] for x, dim in zip(batched_args, batch_dims)
if dim is not batching.not_mapped)
err, code, payload = (batching.bdim_at_front(a, d, size)
for a, d in zip(batched_args, batch_dims))
err = Error(err, code, msgs, payload)
check_error(err)
return [], []
batching.primitive_batchers[assert_p] = assert_batching_rule
def assert_jvp_rule(primals, _, *, msgs):
# Check primals, discard tangents.
assert_p.bind(*primals, msgs=msgs)
return [], []
ad.primitive_jvps[assert_p] = assert_jvp_rule
## checkify rules
def summary() -> str:
return str(source_info_util.summarize(source_info_util.current()))
def nan_error_check(prim, error, enabled_errors, *in_vals, **params):
out = prim.bind(*in_vals, **params)
if ErrorCategory.NAN not in enabled_errors:
return out, error
any_nans = jnp.any(jnp.isnan(out))
msg = f'nan generated by primitive {prim.name} at {summary()}'
return out, assert_func(error, any_nans, msg, None)
def gather_error_check(error, enabled_errors, operand, start_indices, *,
dimension_numbers, slice_sizes, unique_indices,
indices_are_sorted, mode, fill_value):
out = lax.gather_p.bind(
operand, start_indices, dimension_numbers=dimension_numbers,
slice_sizes=slice_sizes, unique_indices=unique_indices,
indices_are_sorted=indices_are_sorted, mode=mode, fill_value=fill_value)
if ErrorCategory.OOB not in enabled_errors:
return out, error
# compare to OOB masking logic in lax._gather_translation_rule
dnums = dimension_numbers
operand_dims = np.array(operand.shape)
num_batch_dims = len(start_indices.shape) - 1
upper_bound = operand_dims[np.array(dnums.start_index_map)]
upper_bound -= np.array(slice_sizes)[np.array(dnums.start_index_map)]
upper_bound = jnp.expand_dims(upper_bound, axis=tuple(range(num_batch_dims)))
out_of_bounds = (start_indices < 0) | (start_indices > upper_bound.astype(start_indices.dtype))
# Get first OOB index, axis and axis size so it can be added to the error msg.
flat_idx = jnp.argmin(jnp.logical_not(out_of_bounds))
multi_idx = jnp.unravel_index(flat_idx, start_indices.shape)
oob_axis = jnp.array(dnums.start_index_map)[multi_idx[-1]]
oob_axis_size = jnp.array(operand.shape)[oob_axis]
oob_index = jnp.ravel(start_indices)[flat_idx]
payload = jnp.array([oob_index, oob_axis, oob_axis_size], dtype=jnp.int32)
msg = (f'out-of-bounds indexing at {summary()} for array of '
f'shape {operand.shape}: '
'index {payload0} is out of bounds for axis {payload1} '
'with size {payload2}.')
return out, assert_func(error, jnp.any(out_of_bounds), msg, payload)
error_checks[lax.gather_p] = gather_error_check
def div_error_check(error, enabled_errors, x, y):
"""Checks for division by zero and NaN."""
if ErrorCategory.DIV in enabled_errors:
any_zero = jnp.any(jnp.equal(y, 0))
msg = f'divided by zero at {summary()}'
error = assert_func(error, any_zero, msg, None)
return nan_error_check(lax.div_p, error, enabled_errors, x, y)
error_checks[lax.div_p] = div_error_check
def scatter_oob(operand, indices, updates, dnums):
# Ref: see clamping code used in scatter_translation_rule
slice_sizes = []
pos = 0
for i in range(len(operand.shape)):
if i in dnums.inserted_window_dims:
slice_sizes.append(1)
else:
slice_sizes.append(updates.shape[dnums.update_window_dims[pos]])
pos += 1
upper_bound = np.array([operand.shape[i] - slice_sizes[i]
for i in dnums.scatter_dims_to_operand_dims],
np.int64)
upper_bound = np.minimum(upper_bound, np.iinfo(indices.dtype).max)
upper_bound = lax.broadcast_in_dim(upper_bound, indices.shape,
(len(indices.shape) - 1,))
lower_oob = jnp.any(jnp.less(indices, 0))
upper_oob = jnp.any(jnp.greater(indices, upper_bound.astype(indices.dtype)))
return jnp.logical_or(lower_oob, upper_oob)
def scatter_error_check(prim, error, enabled_errors, operand, indices, updates,
*, update_jaxpr, update_consts, dimension_numbers,
indices_are_sorted, unique_indices, mode):
"""Checks if indices are within bounds and update does not generate NaN."""
out = prim.bind(
operand, indices, updates, update_jaxpr=update_jaxpr,
update_consts=update_consts, dimension_numbers=dimension_numbers,
indices_are_sorted=indices_are_sorted, unique_indices=unique_indices,
mode=mode)
if ErrorCategory.OOB not in enabled_errors:
return out, error
out_of_bounds = scatter_oob(operand, indices, updates, dimension_numbers)
oob_msg = f'out-of-bounds indexing while updating at {summary()}'
oob_error = assert_func(error, out_of_bounds, oob_msg, None)
any_nans = jnp.any(jnp.isnan(out))
nan_msg = f'nan generated by primitive {prim.name} at {summary()}'
return out, assert_func(oob_error, any_nans, nan_msg, None)
error_checks[lax.scatter_p] = partial(scatter_error_check, lax.scatter_p)
error_checks[lax.scatter_add_p] = partial(scatter_error_check, lax.scatter_add_p)
error_checks[lax.scatter_mul_p] = partial(scatter_error_check, lax.scatter_mul_p)
error_checks[lax.scatter_min_p] = partial(scatter_error_check, lax.scatter_min_p)
error_checks[lax.scatter_max_p] = partial(scatter_error_check, lax.scatter_max_p)
def cond_error_check(error, enabled_errors, index, *ops, branches, linear):
new_branches, msgs_ = unzip2(checkify_jaxpr(jxpr, error, enabled_errors)
for jxpr in branches)
new_linear = (False, False, False, *linear)
err, code, payload, *outs = lax.cond_p.bind(
index, error.err, error.code, error.payload, *ops,
branches=tuple(new_branches), linear=new_linear)
new_msgs = {k:v for d in it.chain([error.msgs], msgs_) for k, v in d.items()}
return outs, Error(err, code, new_msgs, payload)
error_checks[lax.cond_p] = cond_error_check
def scan_error_check(error, enabled_errors, *in_flat, reverse, length, jaxpr,
num_consts, num_carry, linear, unroll):
consts, carry, xs = split_list(in_flat, [num_consts, num_carry])
checked_jaxpr_, msgs_ = checkify_jaxpr(jaxpr, error, enabled_errors)
tomove = [False] * 3 + [True] * len(consts) + [False] * (len(carry) + len(xs))
checked_jaxpr = pe.move_binders_to_front(checked_jaxpr_, tomove)
new_linear = (False, False, False, *linear)
new_in_flat = [*consts, error.err, error.code, error.payload, *carry, *xs]
err, code, payload, *outs = lax.scan_p.bind(
*new_in_flat, reverse=reverse, length=length, jaxpr=checked_jaxpr,
num_consts=len(consts), num_carry=len(carry)+3,
linear=new_linear, unroll=unroll)
new_msgs = {**error.msgs, **msgs_}
return outs, Error(err, code, new_msgs, payload)
error_checks[lax.scan_p] = scan_error_check
def checkify_while_body_jaxpr(cond_jaxpr, body_jaxpr, error, enabled_errors, c_consts):
cond_f = core.jaxpr_as_fun(cond_jaxpr)
body_f = core.jaxpr_as_fun(body_jaxpr)
def new_body_f(*vals):
out = body_f(*vals)
# This checks if the next cond application will error
_ = cond_f(*c_consts, *out)
return out
return checkify_fun_to_jaxpr(lu.wrap_init(new_body_f), error, enabled_errors,
body_jaxpr.in_avals)
def ignore_error_output_jaxpr(jaxpr):
"""Constructs a checked jaxpr which does not output its error value."""
consts = jaxpr.consts
jaxpr = jaxpr.jaxpr
new_jaxpr = jaxpr.replace(outvars=jaxpr.outvars[3:])
return core.ClosedJaxpr(new_jaxpr, consts)
def batch_error(err, code, payload, batch_shape):
err = jnp.broadcast_to(err, batch_shape)
code = jnp.broadcast_to(code, batch_shape)
payload = jnp.broadcast_to(payload, batch_shape+(3,))
return err, code, payload
def unbatch_error(err, code, payload):
err = err.ravel()[0]
code = code.ravel()[0]
payload = payload.reshape(-1, 3)[0]
return err, code, payload
def trivial_batched_jaxpr(jaxpr, batch_shape, batched_err):
fun = core.jaxpr_as_fun(jaxpr)
def g(err, code, payload, *a):
err_args = unbatch_error(err, code, payload)
err, code, payload, *out = fun(*err_args, *a)
err, code, payload = batch_error(err, code, payload, batch_shape)
return (err, code, payload, *out)
error_avals = map(lambda x: core.raise_to_shaped(core.get_aval(x)), batched_err)
new_jaxpr, _, literals_out = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(g), [*error_avals, *jaxpr.in_avals[3:]])
return core.ClosedJaxpr(new_jaxpr, literals_out)
def while_loop_error_check(error, enabled_errors, *in_flat, cond_nconsts,
cond_jaxpr, body_nconsts, body_jaxpr):
batch_shape = cond_jaxpr.out_avals[0].shape
if batch_shape:
err_args = batch_error(error.err, error.code, error.payload, batch_shape)
else:
err_args = [error.err, error.code, error.payload]
c_consts, b_consts, carry = split_list(in_flat, [cond_nconsts, body_nconsts])
# Check if the first cond application will error.
checked_cond_jaxpr, msgs_cond = checkify_jaxpr(cond_jaxpr, error,
enabled_errors)
if batch_shape:
checked_cond_jaxpr = trivial_batched_jaxpr(checked_cond_jaxpr, batch_shape, err_args)
cond_err, cond_code, cond_payload, _ = core.jaxpr_as_fun(checked_cond_jaxpr)(
*err_args, *c_consts, *carry)
checked_body_jaxpr_, msgs_body = checkify_while_body_jaxpr(
cond_jaxpr, body_jaxpr, error, enabled_errors, c_consts)
if batch_shape:
checked_body_jaxpr_ = trivial_batched_jaxpr(checked_body_jaxpr_, batch_shape, err_args)
to_move = [False] * 3 + [True] * body_nconsts + [False] * len(carry)
checked_body_jaxpr = pe.move_binders_to_front(checked_body_jaxpr_, to_move)
compat_cond_jaxpr_ = ignore_error_output_jaxpr(checked_cond_jaxpr)
to_move = [False] * 3 + [True] * cond_nconsts + [False] * len(carry)
compat_cond_jaxpr = pe.move_binders_to_front(compat_cond_jaxpr_, to_move)
new_in_flat = [*c_consts, *b_consts, cond_err, cond_code, cond_payload, *carry]
err, code, payload, *out = lax.while_p.bind(
*new_in_flat, cond_nconsts=cond_nconsts, cond_jaxpr=compat_cond_jaxpr,
body_nconsts=body_nconsts, body_jaxpr=checked_body_jaxpr)
new_msgs = {**error.msgs, **msgs_body, **msgs_cond}
if batch_shape:
err, code, payload = unbatch_error(err, code, payload)
return out, Error(err, code, new_msgs, payload)
error_checks[lax.while_p] = while_loop_error_check
def pjit_error_check(error, enabled_errors, *vals_in, jaxpr,
in_shardings, out_shardings, resource_env,
donated_invars, name,
in_positional_semantics, out_positional_semantics):
checked_jaxpr, msgs = checkify_jaxpr(jaxpr, error, enabled_errors)
new_vals_in = [error.err, error.code, error.payload, *vals_in]
sharding = OpShardingSharding.get_replicated(
list(resource_env.physical_mesh.devices.flat))
new_in_shardings = (*[sharding] * 3, *in_shardings)
new_out_shardings = (*[sharding] * 3, *out_shardings)
if config.jax_array:
pos_sem = maps._PositionalSemantics.GLOBAL
else:
pos_sem = maps._positional_semantics.val
if not isinstance(in_positional_semantics, Iterable):
in_positional_semantics = (in_positional_semantics,)
if not isinstance(out_positional_semantics, Iterable):
out_positional_semantics = (out_positional_semantics,)
new_positional_sems_in = (*[pos_sem] * 3, *in_positional_semantics)
new_positional_sems_out = (*[pos_sem] * 3, *out_positional_semantics)
new_donated_invars = (*[False] * 3, *donated_invars)
err, code, payload, *vals_out = pjit.pjit_p.bind(
*new_vals_in,
jaxpr=checked_jaxpr,
in_shardings=new_in_shardings,
out_shardings=new_out_shardings,
resource_env=resource_env,
donated_invars=new_donated_invars,
name=name,
in_positional_semantics=new_positional_sems_in,
out_positional_semantics=new_positional_sems_out)
return vals_out, Error(err, code, msgs, payload)
error_checks[pjit.pjit_p] = pjit_error_check
def add_nan_check(prim):
error_checks[prim] = partial(nan_error_check, prim)
add_nan_check(lax.floor_p)
add_nan_check(lax.ceil_p)
add_nan_check(lax.round_p)
add_nan_check(lax.sign_p)
add_nan_check(lax.shift_left_p)
add_nan_check(lax.shift_right_arithmetic_p)
add_nan_check(lax.shift_right_logical_p)
add_nan_check(lax.bitcast_convert_type_p)
add_nan_check(lax.real_p)
add_nan_check(lax.complex_p)
add_nan_check(lax.conj_p)
add_nan_check(lax.imag_p)
add_nan_check(lax.add_p)
add_nan_check(lax.sub_p)
add_nan_check(lax.convert_element_type_p)
add_nan_check(lax.broadcast_in_dim_p)
add_nan_check(lax.concatenate_p)
add_nan_check(lax.pad_p)
add_nan_check(lax.reshape_p)
add_nan_check(lax.rev_p)
add_nan_check(lax.transpose_p)
add_nan_check(lax.slice_p)
add_nan_check(lax.reduce_sum_p)
add_nan_check(lax.reduce_window_sum_p)
add_nan_check(lax.fft_p)
add_nan_check(lax.cumsum_p)
add_nan_check(lax.cumprod_p)
add_nan_check(lax.cummax_p)
add_nan_check(lax.cummin_p)
add_nan_check(lax.erf_p)
add_nan_check(lax.expm1_p)
add_nan_check(lax.log1p_p)
add_nan_check(lax.sqrt_p)
add_nan_check(lax.rsqrt_p)
add_nan_check(lax.asinh_p)
add_nan_check(lax.acosh_p)
add_nan_check(lax.atanh_p)
add_nan_check(lax.erfc_p)
add_nan_check(lax.rem_p)
add_nan_check(lax.clamp_p)
add_nan_check(lax.erf_inv_p)
add_nan_check(lax.exp_p)
add_nan_check(lax.pow_p)
add_nan_check(lax.integer_pow_p)
add_nan_check(lax.tanh_p)
add_nan_check(lax.log_p)
add_nan_check(lax.atan2_p)
add_nan_check(lax.sin_p)
add_nan_check(lax.cos_p)
add_nan_check(lax.sinh_p)
add_nan_check(lax.cosh_p)
add_nan_check(lax.dot_general_p)
add_nan_check(lax.mul_p)
add_nan_check(lax.conv_general_dilated_p)
add_nan_check(lax.reduce_max_p)
add_nan_check(lax.reduce_min_p)
add_nan_check(lax.abs_p)
add_nan_check(lax.select_n_p)
add_nan_check(lax.max_p)
add_nan_check(lax.min_p)
def assert_discharge_rule(error, enabled_errors, err, code, payload, *, msgs):
if ErrorCategory.USER_CHECK not in enabled_errors:
return [], error
out_err = error.err | err
out_code = lax.select(error.err, error.code, code)
return [], Error(out_err, out_code, {**error.msgs, **msgs}, payload)
error_checks[assert_p] = assert_discharge_rule
## checkify api
ErrorCategory = enum.Enum('ErrorCategory', ['NAN', 'OOB', 'DIV', 'USER_CHECK'])
user_checks = frozenset({ErrorCategory.USER_CHECK})
nan_checks = frozenset({ErrorCategory.NAN})
index_checks = frozenset({ErrorCategory.OOB})
div_checks = frozenset({ErrorCategory.DIV})
float_checks = nan_checks | div_checks
automatic_checks = float_checks | index_checks
all_checks = automatic_checks | user_checks
Out = TypeVar('Out')
def checkify(fun: Callable[..., Out],
errors: FrozenSet[ErrorCategory] = user_checks
) -> Callable[..., Tuple[Error, Out]]:
"""Functionalize `check` calls in `fun`, and optionally add run-time error checks.
Run-time errors are either user-added ``checkify.check`` assertions, or
automatically added checks like NaN checks, depending on the ``errors``
argument.
The returned function will return an Error object `err` along with the output
of the original function. ``err.get()`` will either return ``None`` (if no
error occurred) or a string containing an error message. This error message
will correspond to the first error which occurred. ``err.throw()`` will raise
a ValueError with the error message if an error occurred.
By default only user-added ``checkify.check`` assertions are enabled. You can
enable automatic checks through the ``errors`` argument.
The automatic check sets which can be enabled, and when an error is generated:
- ``user_checks``: a ``checkify.check`` evaluated to False.
- ``nan_checks``: a floating-point operation generated a NaN value
as output.
- ``div_checks``: a division by zero.
- ``index_checks``: an index was out-of-bounds.
Multiple categories can be enabled together by creating a `Set` (eg.
``errors={ErrorCategory.NAN, ErrorCategory.OOB}``). Multiple sets can be
re-combined (eg. ``errors=float_checks|user_checks``)
Args:
fun: Callable which can contain user checks (see ``check``).
errors: A set of ErrorCategory values which defines the set of enabled
checks. By default only explicit ``checks`` are enabled
(``user_checks``). You can also for example enable NAN and
DIV errors by passing the ``float_checks`` set, or for
example combine multiple sets through set operations
(``float_checks | user_checks``)
Returns:
A function which accepts the same arguments as ``fun`` and returns as output
a pair where the first element is an ``Error`` value, representing the first
failed ``check``, and the second element is the original output of ``fun``.
For example:
>>> import jax
>>> import jax.numpy as jnp
>>> from jax.experimental import checkify
>>>
>>> @jax.jit
... def f(x):
... y = jnp.sin(x)
... return x+y
>>> err, out = checkify.checkify(f, errors=checkify.float_checks)(jnp.inf)
>>> err.throw() # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
ValueError: nan generated by primitive sin
"""
@traceback_util.api_boundary
def checked_fun(*args, **kwargs):
args_flat, in_tree = tree_flatten((args, kwargs))
f, out_tree = flatten_fun(lu.wrap_init(fun), in_tree)
(err, code, payload, out_flat), msgs = checkify_flat(f, errors, *args_flat)
out = tree_unflatten(out_tree(), out_flat)
return Error(err, code, msgs, payload), out
return checked_fun