mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 21:06:06 +00:00

This change re-introduces symbolic zero support for `custom_vjp`. This time:
* The forward rule API is slightly different, accepting two-field records at pytree leaves rather than pairs.
* In the default setting, where `symbolic_zeros` is not set, there are no new requirements on the pytree node definitions that are involved in the primal arguments. This avoids any change in behavior on the default path. In particular, custom pytree node definitions that aren't completely polymorphic in unflattening can remain as they are.
* There is an additional test involving a custom pytree node.
1243 lines
49 KiB
Python
1243 lines
49 KiB
Python
# Copyright 2021 The JAX Authors.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
from __future__ import annotations
|
|
|
|
import dataclasses
|
|
import functools
|
|
import itertools as it
|
|
from typing import (Union, Optional, Callable, Dict, Tuple, TypeVar,
|
|
FrozenSet, Type, Set, List, Sequence, Any)
|
|
|
|
import numpy as np
|
|
|
|
import jax.numpy as jnp
|
|
from jax import lax
|
|
|
|
from jax._src import api
|
|
from jax._src import linear_util as lu
|
|
from jax._src import core
|
|
from jax._src import custom_derivatives
|
|
from jax._src import effects
|
|
from jax._src import pjit
|
|
from jax._src import prng
|
|
from jax._src import source_info_util
|
|
from jax._src import traceback_util
|
|
from jax._src import tree_util as jtu
|
|
from jax._src.api_util import flatten_fun
|
|
from jax._src.config import config
|
|
from jax._src.interpreters import ad
|
|
from jax._src.interpreters import batching
|
|
from jax._src.interpreters import mlir
|
|
from jax._src.interpreters import partial_eval as pe
|
|
from jax._src.tree_util import tree_flatten
|
|
from jax._src.tree_util import tree_map
|
|
from jax._src.tree_util import tree_unflatten
|
|
from jax._src.typing import Array
|
|
from jax._src.util import (as_hashable_function, split_list, safe_map, safe_zip,
|
|
unzip3, weakref_lru_cache)
|
|
|
|
source_info_util.register_exclusion(__file__)
|
|
traceback_util.register_exclusion(__file__)
|
|
|
|
map, unsafe_map = safe_map, map
|
|
zip, unsafe_zip = safe_zip, zip
|
|
|
|
Bool = Union[bool, Array]
|
|
Int = Union[int, Array]
|
|
ErrorCategory = Type['JaxException']
|
|
Payload = List[Union[np.ndarray, Array]]
|
|
PyTreeDef = jtu.PyTreeDef
|
|
Out = TypeVar('Out')
|
|
|
|
## Utils
|
|
|
|
def popattr(obj, attrname):
  """Deletes attribute `attrname` from `obj` and returns the value it held."""
  popped = getattr(obj, attrname)
  delattr(obj, attrname)
  return popped
|
|
|
|
def setnewattr(obj, name, val):
  """Sets `obj.<name> = val`, asserting the attribute does not exist yet."""
  missing = object()
  current = getattr(obj, name, missing)
  assert current is missing
  setattr(obj, name, val)
|
|
|
|
# Concrete errors
|
|
|
|
class JaxException(Exception):
  """Python exception which can contain an error message with JAX run-time info.

  Every subclass is registered as a pytree node class when it is defined. The
  flatten/unflatten pair below keeps `traceback_info` as aux data and has no
  children.
  """

  def __init__(self, traceback_info):
    # TODO(lenamartens): re-enable tracebacks when they don't leak tracers.
    # self.with_traceback(self.traceback_info)
    self.traceback_info = traceback_info

  def __init_subclass__(cls):
    # Runs once per subclass definition.
    jtu.register_pytree_node_class(cls)

  def tree_flatten(self):
    """Returns `(children, aux_data)`: no children, traceback info as aux."""
    children = []
    return (children, self.traceback_info)

  @classmethod
  def tree_unflatten(cls, metadata, payload):
    """Rebuilds the exception from `metadata`; `payload` is discarded."""
    del payload
    return cls(metadata)

  def get_effect_type(self) -> ErrorEffect:
    """Returns the effect for this error; not implemented on the base class."""
    raise NotImplementedError
|
|
|
|
|
|
@functools.total_ordering
@dataclasses.dataclass(eq=True, frozen=True)
class ErrorEffect(effects.Effect):
  """Effect keyed by an error type and the shape/dtype structs of its payload."""
  error_type: Type[JaxException]
  shape_dtypes: Tuple[api.ShapeDtypeStruct, ...]

  def __lt__(self, other: 'ErrorEffect'):
    def sort_key(effect):
      # dtype is not comparable, so its string form is used for ordering.
      payload_key = tuple((sd.shape, str(sd.dtype))
                          for sd in effect.shape_dtypes)
      return (str(effect.error_type), payload_key)

    return sort_key(self) < sort_key(other)

effects.control_flow_allowed_effects.add_type(ErrorEffect)
effects.lowerable_effects.add_type(ErrorEffect)
|
|
|
|
class DivisionByZeroError(JaxException):
  """Error describing a division by zero at a recorded source location."""

  def get_effect_type(self):
    # This error carries no payload arrays, hence the empty tuple.
    return ErrorEffect(DivisionByZeroError, ())

  def __str__(self):
    location = self.traceback_info
    return f'division by zero at {location}'
|
|
|
|
class NaNError(JaxException):
  """Error describing a NaN produced by a named primitive."""

  def __init__(self, traceback_info, primitive_name):
    super().__init__(traceback_info)
    self.prim = primitive_name

  def tree_flatten(self):
    """Returns no children; traceback info and primitive name are aux data."""
    aux = (self.traceback_info, self.prim)
    return ([], aux)

  @classmethod
  def tree_unflatten(cls, metadata, _):
    """Rebuilds the error from the `(traceback_info, primitive_name)` aux."""
    traceback_info, primitive_name = metadata
    return cls(traceback_info, primitive_name)

  def get_effect_type(self):
    # This error carries no payload arrays, hence the empty tuple.
    return ErrorEffect(NaNError, ())

  def __str__(self):
    where = self.traceback_info
    return f'nan generated by primitive: {self.prim} at {where}'
|
|
|
|
class OOBError(JaxException):
  """Error describing an out-of-bounds index.

  `payload` is indexed as `[index, axis, axis_size]` when the message is
  formatted.
  """

  def __init__(self, traceback_info, primitive_name, operand_shape, payload):
    super().__init__(traceback_info)
    self.prim = primitive_name
    self.operand_shape = operand_shape
    self._payload = payload

  def tree_flatten(self):
    """Returns the payload as the only child; everything else is aux data."""
    aux = (self.traceback_info, self.prim, self.operand_shape)
    return ([self._payload], aux)

  @classmethod
  def tree_unflatten(cls, metadata, payload):
    """Rebuilds the error from aux `metadata` and the single payload child."""
    traceback_info, primitive_name, operand_shape = metadata
    return cls(traceback_info, primitive_name, operand_shape, payload[0])

  def __str__(self):
    oob_index = self._payload[0]
    axis = self._payload[1]
    axis_size = self._payload[2]
    return (f'out-of-bounds indexing for array of '
            f'shape {self.operand_shape}: '
            f'index {oob_index} is out of bounds for axis '
            f'{axis} with size {axis_size}. '
            f'Failed at {self.traceback_info}')

  def get_effect_type(self):
    # The payload is declared as a length-3 int32 vector.
    payload_struct = api.ShapeDtypeStruct((3,), jnp.int32)
    return ErrorEffect(OOBError, (payload_struct,))
|
|
|
|
class FailedCheckError(JaxException):
  """Error whose message is `fmt_string` formatted with stored args/kwargs."""

  def __init__(self, traceback_info, fmt_string, *a, **k):
    super().__init__(traceback_info)
    self.fmt_string = fmt_string
    self.args = a
    self.kwargs = k

  def tree_flatten(self):
    """Returns the format arguments as children, the rest as aux data."""
    leaves = (self.args, self.kwargs)
    treedef = (self.traceback_info, self.fmt_string)
    return (leaves, treedef)

  @classmethod
  def tree_unflatten(cls, metadata, payload):
    """Rebuilds the error from `(traceback_info, fmt_string)` and its args."""
    traceback_info, fmt_string = metadata
    args, kwargs = payload
    return cls(traceback_info, fmt_string, *args, **kwargs)

  def __str__(self):
    message = self.fmt_string.format(*self.args, **self.kwargs)
    return message + f' (check failed at {self.traceback_info})'

  def get_effect_type(self):
    # One shape/dtype struct per leaf of the format arguments.
    structs = tuple(api.ShapeDtypeStruct(leaf.shape, leaf.dtype)
                    for leaf in jtu.tree_leaves((self.args, self.kwargs)))
    return ErrorEffect(FailedCheckError, structs)
|
|
|
|
@dataclasses.dataclass
class BatchedError(JaxException):
  """Collection of errors keyed by the mapped index at which each occurred."""
  error_mapping: Dict[Tuple[int, ...], JaxException]

  def __post_init__(self):
    # The traceback info of the first mapped error stands in for the batch.
    mapped_errors = list(self.error_mapping.values())
    super().__init__(mapped_errors[0].traceback_info)

  def __str__(self):
    lines = []
    for idx, exc in self.error_mapping.items():
      index_str = ", ".join(str(i) for i in idx)
      lines.append(f'at mapped index {index_str}: {exc}')
    return '\n'.join(lines)
|
|
|
|
|
|
# Error Value
|
|
|
|
@jtu.register_pytree_node_class
@dataclasses.dataclass(frozen=True)
class Error:
  """Functional error value holding a predicate, code and payload per effect.

  `_pred`, `_code` and `_payload` are keyed by `ErrorEffect`; `_metadata` maps
  an error code to the treedef used to rebuild the matching `JaxException`.
  """
  _pred: Dict[ErrorEffect, Bool]
  _code: Dict[ErrorEffect, Int]
  _metadata: Dict[Int, PyTreeDef]  # mapping of code to JaxException treedef.
  _payload: Dict[ErrorEffect, Payload]

  def get(self) -> Optional[str]:
    """Returns error message if error happened, None if no error happened."""
    exc = self.get_exception()
    return None if exc is None else str(exc)

  def get_exception(self) -> Optional[JaxException]:
    """Returns Python exception if error happened, None if no error happened."""
    # A predicate with a non-empty shape is handled by the batched path.
    if any(map(np.shape, self._pred.values())):
      return self._get_batched_exception()

    # Among the effects whose predicate is set, pick the lowest code.
    lowest_code = None
    chosen_effect = None
    for effect, code in self._code.items():
      if not self._pred[effect]:
        continue
      if lowest_code is None or code < lowest_code:
        lowest_code = code
        chosen_effect = effect

    if chosen_effect is None:
      return None
    return tree_unflatten(self._metadata[int(lowest_code)],  # type: ignore
                          self._payload[chosen_effect])

  def throw(self):
    _check_error(self)

  def __str__(self):
    message = self.get()
    return f'Error({message})'

  # Internal helpers

  def _get_batched_exception(self) -> Optional[BatchedError]:
    """Builds a `BatchedError` with one entry per index that has an error."""
    batch_shape = np.shape(list(self._pred.values())[0])
    per_index_errors = {}
    for idx in np.ndindex(*batch_shape):
      lowest_code = None
      chosen_effect = None
      for effect, code in self._code.items():
        if not self._pred[effect][idx]:  # type: ignore
          continue
        if lowest_code is None or code[idx] < lowest_code:
          lowest_code = code[idx]  # type: ignore
          chosen_effect = effect

      if chosen_effect is None:
        continue
      payload = tree_map(lambda x, i=idx: x[i], self._payload[chosen_effect])
      per_index_errors[idx] = tree_unflatten(
          self._metadata[int(lowest_code)], payload)  # type: ignore

    return BatchedError(per_index_errors) if per_index_errors else None

  def _update(self, effect_type: ErrorEffect, pred, code, metadata, payload):
    """Returns a new Error with `effect_type` set and `metadata` merged in."""
    preds = {**self._pred, effect_type: pred}  # type: ignore
    codes = {**self._code, effect_type: code}  # type: ignore
    payloads = {**self._payload, effect_type: payload}  # type: ignore
    merged_metadata = {**self._metadata, **metadata}
    return Error(preds, codes, merged_metadata, payloads)

  def _add_placeholder_effects(self, effects: Set[ErrorEffect]):
    """Fill out Error with `effects` and np.ones arrays of their payloads."""
    preds = self._pred.copy()
    codes = self._code.copy()
    payloads = self._payload.copy()
    for effect in effects:
      if effect in self._pred:
        continue
      preds[effect] = False
      payloads[effect] = list(
          tree_map(lambda a: jnp.ones(a.shape, a.dtype), effect.shape_dtypes))
      # The error value associated with this effect will never become True, so
      # we don't need to set a meaningful code.
      codes[effect] = -1
    return Error(preds, codes, self._metadata, payloads)

  def _replace(self, *args, **kwargs):
    return dataclasses.replace(self, *args, **kwargs)

  # PyTree methods

  def tree_flatten(self):
    """Returns predicates, codes and payloads as children; metadata as aux."""
    children = (self._pred, self._code, self._payload)
    return (children, self._metadata)

  @classmethod
  def tree_unflatten(cls, metadata, data):
    pred, code, payload = data
    return cls(pred, code, metadata, payload)
|
|
|
|
# Empty Error: no predicates, codes, metadata or payloads.
init_error = Error({}, {}, {}, {})  # value used as initial (empty) error.
# Each call returns the next integer of a module-level counter starting at 1.
next_code = it.count(1).__next__  # globally unique ids, could be uuid4
|
|
|
|
def assert_func(error: Error, pred: Bool, new_error: JaxException) -> Error:
  """Records `new_error` in `error` under predicate `pred`.

  A fresh code is drawn from `next_code`, the exception is flattened into
  payload leaves and a treedef, and the result is merged via `update_error`.
  """
  code = next_code()
  effect_type = new_error.get_effect_type()
  payload_leaves, exc_treedef = tree_flatten(new_error)
  code_to_treedef = {code: exc_treedef}
  return update_error(error, pred, code, code_to_treedef, payload_leaves,
                      effect_type)
|
|
|
|
def update_error(error, pred, code, metadata, payload, effect_type):
  """Merges a new (pred, code, payload) for `effect_type` into `error`.

  When the predicate already stored for `effect_type` is set, the stored code
  and payload are selected; otherwise the new ones are.
  """
  already_set = error._pred.get(effect_type, False)
  prev_code = error._code.get(effect_type, -1)
  prev_payload = error._payload.get(effect_type, None)

  merged_pred = already_set | pred
  merged_code = lax.select(already_set, prev_code, code)
  if prev_payload is None:
    merged_payload = payload
  else:
    merged_payload = tree_map(
        lambda old, new: lax.select(already_set, old, new),
        prev_payload, payload)
  return error._update(effect_type, merged_pred, merged_code, metadata,
                       merged_payload)
|
|
|
|
|
|
## Checkify transformation for plumbing functional error values.
|
|
|
|
@lu.transformation_with_aux
def _flatten_and_get_error_metadata_thunk(*invals):
  # The wrapped function yields an (error, outputs) pair; it is flattened here
  # and the treedef plus the set of error effects are emitted as aux output.
  error, out = yield invals, {}
  flat_outs, out_tree = jtu.tree_flatten((error, out))
  error_effects = set(error._pred)
  yield flat_outs, (out_tree, error_effects)
|
|
|
|
def default_checkify_rule(primitive: core.Primitive, error: Error,
                          enabled_errors, *invals: core.Value,
                          **params: Any) -> Tuple[Error, Sequence[core.Value]]:
  """Default rule for primitives in `checkify` interpreter.

  Primitives without a `call_jaxpr` param are bound unchanged and the error is
  passed through. Otherwise the call jaxpr is checkified recursively, with the
  flattened error values prepended to the primitive's operands.
  """
  if 'call_jaxpr' not in params:
    # Default non-HOP case: just call primitive and don't update error.
    return error, primitive.bind(*invals, **params)

  # Code below handles call- and map-primitives, by recursively calling
  # checkify_jaxpr.
  is_map = isinstance(primitive, core.MapPrimitive)
  err_vals, err_tree = jtu.tree_flatten(error)
  num_error_vals = len(err_vals)
  if 'donated_invars' in params:
    # The prepended error operands are marked as not donated.
    donated = (*[False] * num_error_vals, *params['donated_invars'])
    params = dict(params, donated_invars=donated)

  # call_jaxpr handling
  call_jaxpr = params.pop('call_jaxpr')
  wrapped_checkify = lu.wrap_init(checkify_jaxpr_flat)
  partial_checkify = lu.hashable_partial(
      wrapped_checkify, call_jaxpr, (), enabled_errors, err_tree)
  partial_checkify, metadata = _flatten_and_get_error_metadata_thunk(
      partial_checkify)

  # map-specific params handling.
  if is_map:
    # Update `in_axes` and `out_axes_thunk` params for map primitive.
    out_val_axes = params.pop('out_axes')

    @as_hashable_function(closure=out_val_axes)
    def out_axes_thunk():
      # Error outputs come first and are mapped along axis 0.
      out_err_num = metadata()[0].num_leaves - len(out_val_axes)
      return (*(0,) * out_err_num, *out_val_axes)

    new_in_axes = (*(None,) * num_error_vals, *params['in_axes'])
    params = dict(params, in_axes=new_in_axes, out_axes_thunk=out_axes_thunk)

  all_vals = primitive.bind(partial_checkify, *err_vals, *invals, **params)

  out_tree, _ = metadata()
  error, out_vals = tree_unflatten(out_tree, all_vals)
  if is_map:
    error = _reduce_any_error(error)
  return error, out_vals
|
|
|
|
def get_shaped_aval(val):
  """Returns the abstract value of `val` passed through `core.raise_to_shaped`."""
  aval = core.get_aval(val)
  return core.raise_to_shaped(aval)
|
|
|
|
def initial_style_jaxpr(
    fun: Callable, in_tree: PyTreeDef, in_avals: Sequence[core.AbstractValue]
  ) -> Tuple[core.Jaxpr, List[Any], PyTreeDef]:
  """Traces `fun` to a jaxpr through the cached `_initial_style_jaxpr`."""
  # The avals are converted to a tuple before being handed to the cached helper.
  avals_tuple = tuple(in_avals)
  return _initial_style_jaxpr(fun, in_tree, avals_tuple)
|
|
|
|
@weakref_lru_cache
def _initial_style_jaxpr(fun, in_tree, in_avals):
  """Traces `fun` with flattened inputs; returns (jaxpr, consts, out_tree)."""
  # like control_flow._initial_style_jaxpr, but use flatten_fun not _nokwargs
  wrapped_fun = lu.wrap_init(fun)
  flat_fun, out_tree_thunk = flatten_fun(wrapped_fun, in_tree)
  debug = pe.debug_info(fun, in_tree, out_tree_thunk, False, 'checkify')
  jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)
  return jaxpr, consts, out_tree_thunk()
|
|
|
|
|
|
def checkify_jaxpr(jaxpr: core.ClosedJaxpr, enabled_errors,
                   error: Error, *args) -> Tuple[Error, List[core.Value]]:
  """Checkifies a closed jaxpr by flattening `error` and delegating.

  The error is flattened to leaves plus a treedef, and the work is done by
  `checkify_jaxpr_flat` on the underlying jaxpr and its consts.
  """
  err_leaves, err_tree = jtu.tree_flatten(error)
  flat_args = (*err_leaves, *args)
  return checkify_jaxpr_flat(jaxpr.jaxpr, jaxpr.consts,
                             enabled_errors, err_tree, *flat_args)
|
|
|
|
def checkify_jaxpr_flat(jaxpr: core.Jaxpr, consts: Sequence[core.Value],
                        enabled_errors, err_tree: PyTreeDef,
                        *args: core.Value) -> Tuple[Error, List[Any]]:
  """Interprets `jaxpr` equation by equation while threading an `Error`.

  `args` holds the `err_tree.num_leaves` flattened error values followed by the
  jaxpr inputs. Each equation is evaluated with the rule registered in
  `error_checks` for its primitive, or with `default_checkify_rule`.

  Returns:
    The final error and the values of the jaxpr's output variables.
  """
  err_vals, in_args = split_list(args, [err_tree.num_leaves])
  error = jtu.tree_unflatten(err_tree, err_vals)
  env: Dict[core.Var, Any] = {}

  def read_env(var: core.Atom):
    # Literals carry their own value; everything else lives in the env.
    return var.val if isinstance(var, core.Literal) else env[var]

  def write_env(var: core.Var, val: Any):
    env[var] = val

  # `map` is the module's safe_map alias here, so these calls run eagerly.
  map(write_env, jaxpr.constvars, consts)
  map(write_env, jaxpr.invars, in_args)

  # interpreter loop
  for eqn in jaxpr.eqns:
    prim = eqn.primitive
    src = eqn.source_info
    invals = map(read_env, eqn.invars)
    checkify_rule = error_checks.get(
        prim, functools.partial(default_checkify_rule, prim))
    name_stack = source_info_util.current_name_stack() + src.name_stack
    with source_info_util.user_context(src.traceback, name_stack=name_stack):
      error, outvals = checkify_rule(error, enabled_errors,
                                     *invals, **eqn.params)
    if not prim.multiple_results:
      write_env(eqn.outvars[0], outvals)
    else:
      map(write_env, eqn.outvars, outvals)

  return error, map(read_env, jaxpr.outvars)
|
|
|
|
@lu.transformation_with_aux
def flatten_fun_output(*args):
  # Calls the wrapped function with no keyword arguments and flattens its
  # result; `tree_flatten` returns the (leaves, treedef) pair that is yielded.
  result = yield args, {}
  flat_result = tree_flatten(result)
  yield flat_result
|
|
|
|
|
|
def _reduce_any_error(error: Error):
  """Reduces each effect's entries to a single one, preferring a set predicate.

  For every effect, the index that sorts last in the predicate array is taken
  and the predicate, code and payload at that index are kept.
  """
  reduced = init_error
  for effect in error._pred:
    batched = (error._pred[effect], error._code[effect],
               error._payload[effect])
    # argsort orders False before True, so the last index is True if any is.
    last_idx = jnp.argsort(batched[0])[-1]
    pred, code, payload = tree_map(lambda x, idx=last_idx: x[idx], batched)
    reduced = reduced._update(effect, pred, code, {}, payload)

  # Metadata is carried over wholesale from the input error.
  return reduced._replace(_metadata=error._metadata)
|
|
|
|
## check_p primitive
|
|
|
|
# Primitive named 'check'; its impl and rules are registered further down.
check_p = core.Primitive('check')
# The primitive's implementation returns an empty list of results.
check_p.multiple_results = True  # zero results
|
|
|
|
# TODO(lenamartens): inherit from Exception instead of ValueError.
class JaxRuntimeError(ValueError):
  """Exception raised by `check_impl` when the checked error is set."""
|
|
|
|
@check_p.def_impl
def check_impl(*args, err_tree, debug):
  """Implementation of `check_p`: raises if the rebuilt error is set.

  Raises:
    JaxRuntimeError: when not in debug mode and the error holds an exception;
      the original exception is chained as the cause.
  """
  if not debug:
    exc = tree_unflatten(err_tree, args).get_exception()
    if exc:
      raise JaxRuntimeError(str(exc)) from exc
  # In debug mode this is a NOOP (check will only trigger when discharged).
  return []
|
|
|
|
@check_p.def_effectful_abstract_eval
def check_abstract_eval(*args, err_tree, debug):
  """Abstract eval of `check_p`: no outputs, effects taken from the error."""
  del debug
  err = tree_unflatten(err_tree, args)
  return [], set(err._pred)
|
|
|
|
# TODO(lenamartens) add in-depth error explanation to link to in module docs.
# Shared exception instance raised by the lowering rules below when a `check`
# reaches lowering without being functionalized.
functionalization_error = ValueError(
    'Cannot abstractly evaluate a checkify.check which was not'
    ' functionalized. This probably means you tried to stage'
    ' (jit/scan/pmap/...) a `check` without functionalizing it'
    ' through `checkify.checkify`.'
    )
|
|
|
|
def check_lowering_rule(ctx, *args, err_tree, debug):
  """Lowers `check_p` to a Python callback that runs `python_err`.

  Raises:
    ValueError: `functionalization_error`, unless the
      `jax_experimental_unsafe_xla_runtime_errors` config flag is set.
  """
  if debug:
    # NOOP (check will only trigger when discharged)
    return []
  if not config.jax_experimental_unsafe_xla_runtime_errors:
    raise functionalization_error

  callback = functools.partial(python_err, err_tree)
  out_op, _, keep_alive = mlir.emit_python_callback(
      ctx,
      callback=callback,
      token=None,
      operands=args,
      operand_avals=list(ctx.avals_in),
      result_avals=list(ctx.avals_out),
      has_side_effect=True)
  ctx.module_context.add_keepalive(keep_alive)
  return out_op
|
|
|
|
def check_lowering_rule_unsupported(*a, debug, **k):
  """Lowering rule that only accepts debug checks; otherwise it raises."""
  if not debug:
    raise functionalization_error
  return []
|
|
|
|
def python_err(err_tree, *args):
  """Rebuilds the error from `args`, checks it, and returns no results."""
  _check_error(tree_unflatten(err_tree, args))
  return []
|
|
|
|
# On TPU the rule that rejects non-debug checks is registered.
mlir.register_lowering(check_p, check_lowering_rule_unsupported,
                       platform='tpu')
# CPU and GPU share the callback-based lowering rule.
mlir.register_lowering(check_p, check_lowering_rule,
                       platform='cpu')
mlir.register_lowering(check_p, check_lowering_rule,
                       platform='gpu')
|
|
|
|
def check_batching_rule(batched_args, batch_dims, *, err_tree, debug):
  """Batching rule for `check_p`: moves batch dims to the front and re-checks."""
  # Batch size is read off the first argument that is actually mapped.
  size = next(arg.shape[dim] for arg, dim in zip(batched_args, batch_dims)
              if dim is not batching.not_mapped)
  front_batched = [batching.bdim_at_front(arg, dim, size)
                   for arg, dim in zip(batched_args, batch_dims)]
  err = tree_unflatten(err_tree, front_batched)
  _check_error(err, debug=debug)
  return [], []
batching.primitive_batchers[check_p] = check_batching_rule
|
|
|
|
def check_jvp_rule(primals, _, *, err_tree, debug):
  """JVP rule for `check_p`: re-binds the check on the primal values only."""
  # Check primals, discard tangents.
  check_p.bind(*primals, err_tree=err_tree, debug=debug)
  primals_out, tangents_out = [], []
  return primals_out, tangents_out
ad.primitive_jvps[check_p] = check_jvp_rule
|
|
|
|
## checkify rules
|
|
|
|
# Signature of a checkify rule; the interpreter unpacks each rule's result as
# `error, outvals`.
ErrorCheckRule = Callable  # (Error, FrozenSet[ErrorCategory], *in_vals, **params) -> (Any, Error)
# Registry of checkify rules, keyed by primitive.
error_checks: Dict[core.Primitive, ErrorCheckRule] = {}
|
|
|
|
|
|
def summary() -> str:
  """Returns the summarized current source info as a string."""
  current_source_info = source_info_util.current()
  return str(source_info_util.summarize(current_source_info))
|
|
|
|
def nan_error_check(prim, error, enabled_errors, *in_vals, **params):
  """Binds `prim` and records a NaN error if its output contains NaNs."""
  out = prim.bind(*in_vals, **params)
  return check_nans(prim, error, enabled_errors, out), out
|
|
|
|
def check_nans(prim, error, enabled_errors, out):
  """Adds a `NaNError` to `error` predicated on `out` containing a NaN.

  Returns `error` unchanged when `NaNError` is not in `enabled_errors`.
  """
  if NaNError not in enabled_errors:
    return error

  def isnan(x):
    # PRNG key arrays are reported as NaN-free without inspecting them.
    if isinstance(x, prng.PRNGKeyArray):
      return False
    return jnp.any(jnp.isnan(x))

  if prim.multiple_results:
    any_nans = jnp.any(jnp.array([isnan(x) for x in out]))
  else:
    any_nans = isnan(out)
  nan_error = NaNError(summary(), prim.name)
  return assert_func(error, any_nans, nan_error)
|
|
|
|
|
|
# All primitives which can generate a NaN.
nan_primitives = [lax.acos_p, lax.acosh_p, lax.add_p, lax.asin_p, lax.asinh_p,
                  lax.atan2_p, lax.atan_p, lax.atanh_p, lax.bessel_i0e_p,
                  lax.bessel_i1e_p, lax.cbrt_p, lax.conv_general_dilated_p,
                  lax.cos_p, lax.cosh_p, lax.cumlogsumexp_p, lax.cummax_p,
                  lax.cummin_p, lax.cumprod_p, lax.cumsum_p, lax.digamma_p,
                  lax.dot_general_p, lax.erf_inv_p, lax.erf_p, lax.erfc_p,
                  lax.exp_p, lax.expm1_p, lax.fft_p, lax.igamma_grad_a_p,
                  lax.igamma_p, lax.igammac_p, lax.integer_pow_p, lax.lgamma_p,
                  lax.linear_solve_p, lax.log1p_p, lax.log_p, lax.logistic_p,
                  lax.mul_p, lax.pad_p, lax.pow_p, lax.psum_p,
                  lax.random_gamma_grad_p, lax.reduce_p, lax.reduce_prod_p,
                  lax.reduce_sum_p, lax.reduce_window_p,
                  lax.reduce_window_sum_p, lax.regularized_incomplete_beta_p,
                  lax.rem_p, lax.rng_uniform_p, lax.rsqrt_p, lax.sin_p,
                  lax.sinh_p, lax.sqrt_p, lax.sub_p, lax.tan_p, lax.tanh_p]

# Register the generic NaN rule for each primitive above; the primitive is
# bound as the rule's first argument.
for _prim in nan_primitives:
  error_checks[_prim] = functools.partial(nan_error_check, _prim)
|
|
|
|
|
|
def dynamic_slice_error_check(error, enabled_errors, operand, *start_indices, slice_sizes):
  """Checkify rule for `dynamic_slice`: flags out-of-bounds start indices.

  A start index is out of bounds when it is negative or when the start plus
  the slice size exceeds the operand dimension.
  """
  out = lax.dynamic_slice_p.bind(operand, *start_indices, slice_sizes=slice_sizes)
  if OOBError not in enabled_errors:
    return error, out

  starts = jnp.array(start_indices)
  sizes = np.array(slice_sizes)
  dims = np.array(operand.shape)
  below = starts < 0
  above = starts + sizes > dims
  oob_mask = below | above

  payload = oob_payload(oob_mask, starts, range(operand.ndim), operand.shape)
  oob_error = OOBError(summary(), "dynamic_slice", operand.shape, payload)
  error = assert_func(error, jnp.any(oob_mask), oob_error)
  return error, out
error_checks[lax.dynamic_slice_p] = dynamic_slice_error_check
|
|
|
|
def gather_error_check(error, enabled_errors, operand, start_indices, *,
                       dimension_numbers, slice_sizes, unique_indices,
                       indices_are_sorted, mode, fill_value):
  """Checkify rule for `gather`: flags out-of-bounds start indices."""
  out = lax.gather_p.bind(
      operand, start_indices, dimension_numbers=dimension_numbers,
      slice_sizes=slice_sizes, unique_indices=unique_indices,
      indices_are_sorted=indices_are_sorted, mode=mode, fill_value=fill_value)

  if OOBError not in enabled_errors:
    return error, out

  # compare to OOB masking logic in lax._gather_translation_rule
  dnums = dimension_numbers
  index_map = np.array(dnums.start_index_map)
  num_batch_dims = len(start_indices.shape) - 1

  # Largest valid start per indexed dimension: dimension size minus slice size.
  upper_bound = np.array(operand.shape)[index_map]
  upper_bound -= np.array(slice_sizes)[index_map]
  upper_bound = jnp.expand_dims(upper_bound, axis=tuple(range(num_batch_dims)))
  too_low = start_indices < 0
  too_high = start_indices > upper_bound.astype(start_indices.dtype)
  oob_mask = too_low | too_high

  payload = oob_payload(oob_mask, start_indices, dnums.start_index_map, operand.shape)
  oob_error = OOBError(summary(), "gather", operand.shape, payload)
  error = assert_func(error, jnp.any(oob_mask), oob_error)
  return error, out
error_checks[lax.gather_p] = gather_error_check
|
|
|
|
def div_error_check(error, enabled_errors, x, y):
  """Checks for division by zero and NaN."""
  if DivisionByZeroError in enabled_errors:
    divisor_has_zero = jnp.any(jnp.equal(y, 0))
    zero_error = DivisionByZeroError(summary())
    error = assert_func(error, divisor_has_zero, zero_error)
  # The division itself is bound inside the NaN rule.
  return nan_error_check(lax.div_p, error, enabled_errors, x, y)
error_checks[lax.div_p] = div_error_check
|
|
|
|
def oob_payload(oob_mask, indices, dims_map, operand_shape):
  """Builds the int32 payload `[index, axis, axis_size]` for an OOB error.

  The entry reported is the first one (in flattened order) where `oob_mask`
  is set; its position along the last axis of `indices` selects the operand
  axis through `dims_map`.
  """
  # Get first OOB index, axis and axis size so it can be added to the error msg.
  first_flat = jnp.argmin(jnp.logical_not(oob_mask))
  oob_index = jnp.ravel(indices)[first_flat]
  position = jnp.unravel_index(first_flat, indices.shape)
  oob_axis = jnp.array(dims_map)[position[-1]]
  oob_axis_size = jnp.array(operand_shape)[oob_axis]
  return jnp.array([oob_index, oob_axis, oob_axis_size], dtype=jnp.int32)
|
|
|
|
def scatter_oob(operand, indices, updates, dnums):
  """Returns `(any_out_of_bounds, payload)` for a scatter's indices."""
  # Ref: see clamping code used in scatter_translation_rule
  slice_sizes = []
  window_pos = 0
  for dim in range(len(operand.shape)):
    if dim in dnums.inserted_window_dims:
      slice_sizes.append(1)
      continue
    slice_sizes.append(updates.shape[dnums.update_window_dims[window_pos]])
    window_pos += 1

  scatter_dims = dnums.scatter_dims_to_operand_dims
  upper_bound = np.array(
      [operand.shape[d] - slice_sizes[d] for d in scatter_dims], np.int64)
  # Keep the bound representable in the index dtype before the cast below.
  upper_bound = np.minimum(upper_bound, np.iinfo(indices.dtype).max)
  upper_bound = lax.broadcast_in_dim(upper_bound, indices.shape,
                                     (len(indices.shape) - 1,))

  too_low = jnp.less(indices, 0)
  too_high = jnp.greater(indices, upper_bound.astype(indices.dtype))
  oob_mask = jnp.logical_or(too_low, too_high)
  payload = oob_payload(oob_mask, indices, scatter_dims, operand.shape)
  return jnp.any(oob_mask), payload
|
|
|
|
def scatter_error_check(prim, error, enabled_errors, operand, indices, updates,
                        *, update_jaxpr, update_consts, dimension_numbers,
                        indices_are_sorted, unique_indices, mode):
  """Checks if indices are within bounds and update does not generate NaN.

  The two checks are independent: the bounds check runs only when `OOBError`
  is enabled, and the NaN check is left to `check_nans`, which returns the
  error unchanged when `NaNError` is not enabled.

  Returns:
    A pair of the updated error and the scatter output.
  """
  out = prim.bind(
      operand, indices, updates, update_jaxpr=update_jaxpr,
      update_consts=update_consts, dimension_numbers=dimension_numbers,
      indices_are_sorted=indices_are_sorted, unique_indices=unique_indices,
      mode=mode)

  if OOBError in enabled_errors:
    out_of_bounds, payload = scatter_oob(operand, indices, updates,
                                         dimension_numbers)
    oob_error = OOBError(summary(), prim.name, operand.shape, payload)
    error = assert_func(error, out_of_bounds, oob_error)

  # Previously an early return for disabled OOB checks skipped this NaN check
  # as well, so NaN-only checking never inspected scatter outputs.
  error = check_nans(prim, error, enabled_errors, out)
  return error, out
error_checks[lax.scatter_p] = functools.partial(scatter_error_check, lax.scatter_p)
error_checks[lax.scatter_add_p] = functools.partial(scatter_error_check,
                                                    lax.scatter_add_p)
error_checks[lax.scatter_mul_p] = functools.partial(scatter_error_check,
                                                    lax.scatter_mul_p)
error_checks[lax.scatter_min_p] = functools.partial(scatter_error_check,
                                                    lax.scatter_min_p)
error_checks[lax.scatter_max_p] = functools.partial(scatter_error_check,
                                                    lax.scatter_max_p)
|
|
|
|
# HOP error check rules
|
|
|
|
@weakref_lru_cache
def jaxpr_to_checkify_jaxpr(
    jaxpr: core.ClosedJaxpr, enabled_errors, err_tree: PyTreeDef,
    *flat_err_and_in_vals) -> Tuple[core.ClosedJaxpr, PyTreeDef, Set[ErrorEffect]]:
  """Traces the checkified version of `jaxpr` into a new closed jaxpr.

  Returns:
    The checked jaxpr, the output treedef of `(error, outputs)`, and the set
    of error effects found in the traced error.
  """
  checkify_fn = functools.partial(
      checkify_jaxpr_flat, jaxpr.jaxpr, jaxpr.consts, enabled_errors, err_tree)
  wrapped, metadata = _flatten_and_get_error_metadata_thunk(
      lu.wrap_init(checkify_fn))

  traced_jaxpr, _, traced_consts = pe.trace_to_jaxpr_dynamic(
      wrapped, flat_err_and_in_vals)
  checked_jaxpr = core.ClosedJaxpr(traced_jaxpr, traced_consts)
  # The aux thunk is read after tracing.
  out_tree, error_effects = metadata()
  return checked_jaxpr, out_tree, error_effects
|
|
|
|
def cond_error_check(error: Error, enabled_errors, index, *ops, branches, linear):
  """Checkify rule for `cond`: every branch takes and returns the same error.

  The branches are checkified twice: once to collect their error effects, and
  once more with an error that has placeholders for all of those effects.
  """
  # Get the error-effects out of all branches so the cond can be called with
  # a merged error with all these effects.
  err_vals, err_tree = jtu.tree_flatten(error)
  in_avals = map(get_shaped_aval, [*err_vals, *ops])
  branch_effects = []
  for jxpr in branches:
    _, _, found = jaxpr_to_checkify_jaxpr(jxpr, enabled_errors, err_tree,
                                          *in_avals)
    branch_effects.append(found)
  merged_error = error._add_placeholder_effects(set().union(*branch_effects))
  err_vals, err_tree = jtu.tree_flatten(merged_error)
  # The prepended error operands are marked as not linear.
  new_linear = (*[False] * len(err_vals), *linear)

  # Update branch jaxprs to be checkified jaxprs.
  in_avals = map(get_shaped_aval, [*err_vals, *ops])
  new_branches, out_trees, _ = unzip3(
      jaxpr_to_checkify_jaxpr(
          jxpr, enabled_errors, err_tree, *in_avals) for jxpr in branches)

  err_and_outs = lax.cond_p.bind(
      index, *err_vals, *ops,
      branches=tuple(new_branches), linear=new_linear)

  # we need to merge metadata across out_trees (a tuple)
  first_err, out = tree_unflatten(out_trees[0], err_and_outs)
  merged_metadata = first_err._metadata
  for other_tree in out_trees[1:]:
    other_err, _ = tree_unflatten(other_tree, err_and_outs)
    merged_metadata = {**merged_metadata, **other_err._metadata}
  return first_err._replace(_metadata=merged_metadata), out
error_checks[lax.cond_p] = cond_error_check
|
|
|
|
def scan_error_check(error, enabled_errors, *in_flat, reverse, length, jaxpr,
                     num_consts, num_carry, linear, unroll):
  """Checkify rule for `scan`: the error value is added to the scan carry.

  The body is checkified twice: once to find its error effects, and once more
  with an error that has placeholders for those effects, so that the error
  entering and leaving the body have the same structure.

  Returns:
    A pair of the error after the scan and the scan outputs.
  """
  consts, carry, xs = split_list(in_flat, [num_consts, num_carry])
  xs_mapped = [core.mapped_aval(length, 0, get_shaped_aval(val)) for val in xs]
  # Query body effects to create a merged error containing all effects (such
  # that in and out carried error are of the same type).
  err_vals, err_tree = jtu.tree_flatten(error)
  new_in_aval = map(get_shaped_aval, [*err_vals, *consts, *carry]) + xs_mapped
  _, _, body_effects = jaxpr_to_checkify_jaxpr(jaxpr, enabled_errors,
                                               err_tree, *new_in_aval)

  merged_error = error._add_placeholder_effects(body_effects)
  err_vals, err_tree = jtu.tree_flatten(merged_error)

  # Create checked-jaxpr, with the needed pre-processing on the inputs.
  new_in_aval = map(get_shaped_aval, [*err_vals, *consts, *carry]) + xs_mapped
  checked_jaxpr_, out_tree, _ = jaxpr_to_checkify_jaxpr(jaxpr, enabled_errors,
                                                        err_tree, *new_in_aval)

  # The checked jaxpr takes (error, consts, carry, xs); move the consts to the
  # front so the binders are ordered (consts, error, carry, xs).
  tomove = ([False] * len(err_vals) + [True] * len(consts)
            + [False] * (len(carry) + len(xs)))
  checked_jaxpr = pe.move_binders_to_front(checked_jaxpr_, tomove)
  new_in_flat = [*consts, *err_vals, *carry, *xs]
  # Keep `linear` aligned with `new_in_flat`: the flags for the consts stay in
  # front and the (non-linear) error values are inserted after them.
  new_linear = (*linear[:len(consts)],
                *[False] * len(err_vals),
                *linear[len(consts):])
  err_and_out = lax.scan_p.bind(
      *new_in_flat, reverse=reverse, length=length, jaxpr=checked_jaxpr,
      num_consts=len(consts), num_carry=len(carry)+len(err_vals),
      linear=new_linear, unroll=unroll)
  err, out = tree_unflatten(out_tree, err_and_out)
  return err, out

error_checks[lax.scan_p] = scan_error_check
|
|
|
|
def checkify_while_body_jaxpr(
    cond_jaxpr: core.ClosedJaxpr, body_jaxpr: core.ClosedJaxpr,
    enabled_errors, error: Error,
    c_consts) -> Tuple[core.ClosedJaxpr, PyTreeDef, Set[ErrorEffect]]:
  """Checkifies a while-loop body that also evaluates the loop condition.

  The body is re-traced so that the condition is applied to the body's output
  (its result is discarded), and that jaxpr is then checkified.
  """
  run_cond = core.jaxpr_as_fun(cond_jaxpr)
  run_body = core.jaxpr_as_fun(body_jaxpr)

  def body_then_cond(*vals):
    out = run_body(*vals)
    # This checks if the next cond application will error
    _ = run_cond(*c_consts, *out)
    return out

  traced, _, traced_consts = pe.trace_to_jaxpr_dynamic(
      lu.wrap_init(body_then_cond), body_jaxpr.in_avals)
  body_with_cond = core.ClosedJaxpr(traced, traced_consts)

  err_vals, err_tree = jtu.tree_flatten(error)
  err_avals = map(get_shaped_aval, err_vals)
  flat_err_and_in_vals = [*err_avals, *body_jaxpr.in_avals]
  checked_jaxpr, out_tree, error_effects = jaxpr_to_checkify_jaxpr(
      body_with_cond, enabled_errors, err_tree, *flat_err_and_in_vals)
  return checked_jaxpr, out_tree, error_effects
|
|
|
|
def ignore_error_output_jaxpr(jaxpr, num_error_vals):
  """Constructs a checked jaxpr which does not output its error value."""
  inner = jaxpr.jaxpr
  # The first `num_error_vals` output variables are dropped.
  kept_outvars = inner.outvars[num_error_vals:]
  return core.ClosedJaxpr(inner.replace(outvars=kept_outvars), jaxpr.consts)
|
|
|
|
def while_loop_error_check(error, enabled_errors, *in_flat, cond_nconsts,
                           cond_jaxpr, body_nconsts, body_jaxpr):
  """Error-check rule for ``lax.while_p``.

  The incoming error values are added to the loop inputs after the cond and
  body consts and ahead of the original carry. Both the first cond
  application (checked eagerly, before the loop) and every later cond
  application (checked inside the rewritten body) contribute to the
  resulting error.

  Args:
    error: incoming error value.
    enabled_errors: set of enabled error categories, forwarded to the
      checkify helpers.
    *in_flat: flat inputs of the original while: cond consts, body consts,
      then carry.
    cond_nconsts: number of leading cond consts in ``in_flat``.
    cond_jaxpr: closed jaxpr of the loop predicate.
    body_nconsts: number of body consts following the cond consts.
    body_jaxpr: closed jaxpr of the loop body.

  Returns:
    An ``(error, out)`` pair unflattened from the outputs of the new while.

  Raises:
    ValueError: if the predicate's output aval has a non-empty shape
      (batched while-loops are not supported).
  """
  if cond_jaxpr.out_avals[0].shape:
    # TODO(lenamartens, sharadmv): support batched while.
    raise ValueError('Checkify does not support batched while-loops '
                     '(checkify-of-vmap-of-while). \nHint: if possible, move '
                     'the vmap to the outer level to get '
                     'vmap-of-checkify-of-while.')

  c_consts, b_consts, carry = split_list(in_flat, [cond_nconsts, body_nconsts])
  # Check if the first cond application will error.
  error, _ = checkify_jaxpr(cond_jaxpr, enabled_errors, error, *c_consts, *carry)

  # First pass over the body: only the set of error effects is needed, so the
  # error can be extended with placeholders for them before the real pass.
  _, _, error_effects = checkify_while_body_jaxpr(cond_jaxpr, body_jaxpr,
                                                  enabled_errors, error, c_consts)
  # merged error!
  error = error._add_placeholder_effects(error_effects)
  err_vals, err_tree = jtu.tree_flatten(error)
  # Second pass: checkify the body against the merged error structure.
  checked_body_jaxpr_, body_out_tree, _ = checkify_while_body_jaxpr(
      cond_jaxpr, body_jaxpr, enabled_errors, error, c_consts)
  num_error_vals = len(err_vals)
  # The checked body takes [errors, body consts, carry]; move the body consts
  # to the front.
  to_move = [False] * num_error_vals + [True] * body_nconsts + [False] * len(carry)
  checked_body_jaxpr = pe.move_binders_to_front(checked_body_jaxpr_, to_move)

  cond_in_flat = [*err_vals, *c_consts, *carry]
  cond_in_flat = map(get_shaped_aval, cond_in_flat)
  checked_cond_jaxpr, _, _ = jaxpr_to_checkify_jaxpr(cond_jaxpr, enabled_errors,
                                                    err_tree, *cond_in_flat)
  # The predicate handed to while_p must not output error values, so they are
  # dropped from the checked cond's outputs.
  compat_cond_jaxpr_ = ignore_error_output_jaxpr(checked_cond_jaxpr, num_error_vals)
  to_move = [False] * num_error_vals + [True] * cond_nconsts + [False] * len(carry)
  compat_cond_jaxpr = pe.move_binders_to_front(compat_cond_jaxpr_, to_move)

  new_in_flat = [*c_consts, *b_consts, *err_vals, *carry]
  all_out_vals = lax.while_p.bind(
      *new_in_flat, cond_nconsts=cond_nconsts, cond_jaxpr=compat_cond_jaxpr,
      body_nconsts=body_nconsts, body_jaxpr=checked_body_jaxpr)
  # body_out_tree will have all the metadata of cond because it executes a cond!
  error, out = tree_unflatten(body_out_tree, all_out_vals)
  return error, out

error_checks[lax.while_p] = while_loop_error_check
|
|
|
|
def pjit_error_check(error, enabled_errors, *vals_in, jaxpr,
                     in_shardings, out_shardings, resource_env,
                     donated_invars, name,
                     inline, keep_unused):
  """Error-check rule for ``pjit.pjit_p``.

  Checkifies the inner ``jaxpr`` with the flattened ``error`` values prepended
  to its inputs, then re-binds ``pjit_p`` with the sharding and donation
  parameters extended to cover the extra error inputs/outputs.

  Args:
    error: incoming error value.
    enabled_errors: set of enabled error categories.
    *vals_in: original inputs of the pjit call.
    jaxpr: the jaxpr wrapped by the pjit call.
    in_shardings, out_shardings, resource_env, donated_invars, name, inline,
      keep_unused: the original ``pjit_p`` parameters.

  Returns:
    The result of unflattening the new pjit outputs with the checked jaxpr's
    output tree.
  """
  # jaxpr to checked_jaxpr
  err_vals, err_tree = jtu.tree_flatten(error)
  new_vals_in = [*err_vals, *vals_in]
  in_avals = tuple(map(get_shaped_aval, new_vals_in))
  checked_jaxpr, out_tree, _ = jaxpr_to_checkify_jaxpr(jaxpr, enabled_errors,
                                                       err_tree, *in_avals)

  # Update pjit params to account for extra error values.
  num_error_vals = len(err_vals)
  # Output leaves not accounted for by the original out_shardings are the
  # error outputs.
  num_out_error_vals = out_tree.num_leaves - len(out_shardings)
  sharding = pjit._UNSPECIFIED

  new_in_shardings = (*[sharding] * num_error_vals, *in_shardings)
  new_out_shardings = (*[sharding] * num_out_error_vals, *out_shardings)
  # Error inputs are never donated.
  new_donated_invars = (*[False] * num_error_vals, *donated_invars)

  err_and_out = pjit.pjit_p.bind(
      *new_vals_in,
      jaxpr=checked_jaxpr,
      in_shardings=new_in_shardings,
      out_shardings=new_out_shardings,
      resource_env=resource_env,
      donated_invars=new_donated_invars,
      name=name,
      inline=inline,
      keep_unused=keep_unused,
  )
  return tree_unflatten(out_tree, err_and_out)

error_checks[pjit.pjit_p] = pjit_error_check
|
|
|
|
def custom_jvp_call_rule(in_err, enabled_errors, *in_vals, num_consts,
                         jvp_jaxpr_thunk, call_jaxpr, **params):
  """Error-check rule for ``custom_derivatives.custom_jvp_call_p``.

  The primal function (``call_jaxpr``) is checkified; the JVP rule is
  evaluated as-is, without added checks.

  Args:
    in_err: incoming error value.
    enabled_errors: set of enabled error categories.
    *in_vals: original inputs of the custom-JVP call.
    num_consts: number of leading values dropped from each of the primal and
      tangent halves before the JVP jaxpr is evaluated.
    jvp_jaxpr_thunk: zero-argument callable returning ``(jvp_jaxpr, consts)``.
    call_jaxpr: closed jaxpr of the primal function.
    **params: remaining parameters, forwarded to the new bind.

  Returns:
    An ``(out_err, out_vals)`` pair.
  """
  # The types to have in mind are:
  #   jvp : (a -> b) -> (a, T a) -> (b, T b)
  #   checkify : (a -> b) -> a -> Err b
  #   jvp-of-checkify : (a -> b) -> (a, T a) -> (Err b, T (Err b))
  # where because Err is a pytree, we necessarily have T (Err b) = Err' (T b)
  # where the other Err' components are trivial (of float0 dtype).
  # Semantically, we don't add checks to the JVP rule. To check the result of a
  # JVP rule, one must instead use checkify-of-jvp. Thus this implementation
  # just forwards the input error and code (and trivial tangents) to the output.
  err_vals, err_tree = jtu.tree_flatten(in_err)
  partial_checkify = lu.wrap_init(
      functools.partial(checkify_jaxpr_flat, call_jaxpr.jaxpr,
                        call_jaxpr.consts, enabled_errors, err_tree))
  partial_checkify, f_metadata = _flatten_and_get_error_metadata_thunk(
      partial_checkify)

  # Construct the default jvp function, without checkify-ing.
  @lu.wrap_init
  def jvp(*xs):
    # TODO(lenamartens, sharadmv): why not checkify here?
    jvp_jaxpr, jvp_consts = jvp_jaxpr_thunk()
    # xs holds the primals followed by an equal number of tangents.
    n, ragged = divmod(len(xs), 2)
    assert not ragged
    primals, tangents = xs[num_consts:n], xs[n+num_consts:]
    return core.eval_jaxpr(jvp_jaxpr, jvp_consts, *primals, *tangents)

  jvp, jvp_out_tree = flatten_fun_output(jvp)
  all_outs = custom_derivatives.custom_jvp_call_p.bind(
      partial_checkify, jvp, *err_vals, *in_vals, **params)
  # Only one of the two metadata thunks is usable; `fst` says whether it is
  # the checkified primal's (f_metadata) or the jvp's (jvp_out_tree).
  fst, out_metadata = lu.merge_linear_aux(f_metadata, jvp_out_tree)
  if fst:
    err_and_out_tree, _ = out_metadata
    out_err, out_vals = tree_unflatten(err_and_out_tree, all_outs)
  else:
    err_vals, out_vals = split_list(all_outs, [len(err_vals)])
    # forward input error to output
    out_err = jtu.tree_unflatten(err_tree, err_vals)
  return out_err, out_vals

error_checks[custom_derivatives.custom_jvp_call_p] = custom_jvp_call_rule
|
|
|
|
def custom_vjp_call_jaxpr_rule(in_err, enabled_errors, *in_vals, fun_jaxpr,
                               fwd_jaxpr_thunk, num_consts, bwd, out_trees,
                               symbolic_zeros):
  """Error-check rule for ``custom_derivatives.custom_vjp_call_jaxpr_p``.

  The primal function (``fun_jaxpr``) is checkified; the forward rule is
  evaluated as-is, without added checks, and ``bwd`` is passed through
  untouched.

  Args:
    in_err: incoming error value.
    enabled_errors: set of enabled error categories.
    *in_vals: original inputs of the custom-VJP call.
    fun_jaxpr: closed jaxpr of the primal function.
    fwd_jaxpr_thunk: callable that, given the per-argument ``zeros`` flags,
      returns ``(fwd_jaxpr, consts)``.
    num_consts: number of leading values dropped before evaluating the
      forward jaxpr.
    bwd: backward rule, forwarded to the new bind.
    out_trees: forwarded to the new bind.
    symbolic_zeros: forwarded to the new bind.

  Returns:
    An ``(out_err, out_vals)`` pair.
  """
  err_vals, err_tree = jtu.tree_flatten(in_err)
  fun = lu.wrap_init(
      functools.partial(checkify_jaxpr_flat, fun_jaxpr.jaxpr,
                        fun_jaxpr.consts, enabled_errors, err_tree))
  fun, fun_metadata = _flatten_and_get_error_metadata_thunk(fun)

  @lu.wrap_init
  def fwd(*args):
    # TODO(lenamartens, sharadmv): why not checkify here?
    # Arguments arrive interleaved: value, zero-flag, value, zero-flag, ...
    xs, zeros = args[::2], args[1::2]
    fwd_jaxpr, fwd_consts = fwd_jaxpr_thunk(*zeros)
    xs_without_consts = xs[num_consts:]
    return core.eval_jaxpr(fwd_jaxpr, fwd_consts, *xs_without_consts)

  fwd, fwd_out_tree = flatten_fun_output(fwd)
  all_outs = custom_derivatives.custom_vjp_call_p.bind(
      fun, fwd, bwd, *err_vals, *in_vals, out_trees=out_trees,
      symbolic_zeros=symbolic_zeros)
  # Only one of the two metadata thunks is usable; `fst` says whether it is
  # the checkified primal's (fun_metadata) or the fwd rule's (fwd_out_tree).
  fst, out_metadata = lu.merge_linear_aux(fun_metadata, fwd_out_tree)
  if fst:
    err_and_out_tree, _ = out_metadata
    out_err, out_vals = tree_unflatten(err_and_out_tree, all_outs)
  else:
    err_vals, out_vals = split_list(all_outs, [len(err_vals)])
    # forward input error to output
    out_err = jtu.tree_unflatten(err_tree, err_vals)
  return out_err, out_vals

error_checks[custom_derivatives.custom_vjp_call_jaxpr_p] = custom_vjp_call_jaxpr_rule
|
|
|
|
def check_discharge_rule(error, enabled_errors, *args, err_tree, debug):
  """Error-check rule for ``check_p``.

  Rebuilds the error carried by the check from ``args``/``err_tree`` and
  merges each of its entries whose error type is enabled into the incoming
  ``error``.

  Args:
    error: incoming error value; the base of the discharged error.
    enabled_errors: set of enabled error categories.
    *args: flat leaves of the error carried by the check.
    err_tree: tree definition used to unflatten ``args``.
    debug: unused.

  Returns:
    ``(discharged_error, [])``: the merged error and an empty output list.
  """
  del debug
  new_error = tree_unflatten(err_tree, args)
  # Split up new_error into error to be functionalized if it's included in
  # enabled_errors (=discharged_error) and an error to be defunctionalized if
  # it's not included (=recharged_error)
  discharged_error = error
  recharged_error = init_error
  for error_effect, pred in new_error._pred.items():
    code = new_error._code[error_effect]
    payload = new_error._payload[error_effect]
    if error_effect.error_type in enabled_errors:
      discharged_error = update_error(discharged_error, pred, code, {}, payload,
                                      error_effect)
    else:
      recharged_error = update_error(recharged_error, pred, code, {}, payload,
                                     error_effect)

  # On key clashes the metadata already present on discharged_error wins.
  discharged_error = discharged_error._replace(
      _metadata={**new_error._metadata, **discharged_error._metadata})
  recharged_error = recharged_error._replace(_metadata=new_error._metadata)
  # TODO(lenamartens): we actually need to recharge, but this would be a
  # breaking API change so leaving for a follow-up.
  # check_error(recharged_error)
  return discharged_error, []

error_checks[check_p] = check_discharge_rule
|
|
|
|
|
|
## checkify public api

# Predefined sets of error categories, usable as the `errors` argument of
# `checkify`. The composite sets below are unions of the basic ones.
user_checks = frozenset({FailedCheckError})
nan_checks = frozenset({NaNError})
index_checks = frozenset({OOBError})
div_checks = frozenset({DivisionByZeroError})
float_checks = nan_checks | div_checks
automatic_checks = float_checks | index_checks
all_checks = automatic_checks | user_checks
|
|
|
|
|
|
def checkify(f: Callable[..., Out],
             errors: FrozenSet[ErrorCategory] = user_checks
             ) -> Callable[..., Tuple[Error, Out]]:
  """Functionalize `check` calls in `f`, and optionally add run-time error checks.

  Run-time errors are either user-added :func:`~check` assertions, or
  automatically added checks like NaN checks, depending on the ``errors``
  argument.

  The returned function will return an Error object `err` along with the output
  of the original function. ``err.get()`` will either return ``None`` (if no
  error occurred) or a string containing an error message. This error message
  will correspond to the first error which occurred. ``err.throw()`` will raise
  a ValueError with the error message if an error occurred.

  By default only user-added :func:`~check` assertions are enabled. You can
  enable automatic checks through the ``errors`` argument.

  The automatic check sets which can be enabled, and when an error is generated:
    - ``user_checks``: a :func:`~check` evaluated to False.
    - ``nan_checks``: a floating-point operation generated a NaN value
      as output.
    - ``div_checks``: a division by zero.
    - ``index_checks``: an index was out-of-bounds.

  Multiple categories can be enabled together by passing in an error `Set` (eg.
  ``errors=nan_checks``). Multiple sets can be re-combined (eg.
  ``errors=float_checks|user_checks``)

  Args:
    f: Callable which can contain user checks (see :func:`~check`).
    errors: A set of ErrorCategory values which defines the set of enabled
      checks. By default only explicit ``checks`` are enabled
      (``user_checks``). You can also for example enable NAN and
      DIV errors by passing the ``float_checks`` set, or for
      example combine multiple sets through set operations
      (``float_checks | user_checks``)
  Returns:
    A function which accepts the same arguments as ``f`` and returns as output
    a pair where the first element is an ``Error`` value, representing the first
    failed :func:`~check`, and the second element is the original output of
    ``f``.

  For example:

    >>> import jax
    >>> import jax.numpy as jnp
    >>> from jax.experimental import checkify
    >>>
    >>> @jax.jit
    ... def f(x):
    ...   y = jnp.sin(x)
    ...   return x+y
    >>> err, out = checkify.checkify(f, errors=checkify.float_checks)(jnp.inf)
    >>> err.throw()  # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
      ...
    jax._src.checkify.JaxRuntimeError: nan generated by primitive: sin
  """
  @traceback_util.api_boundary
  def checked_fun(*args, **kwargs):
    # stage:
    flat_args, in_tree = tree_flatten((args, kwargs))
    in_avals = map(get_shaped_aval, flat_args)
    jaxpr_, consts, out_tree = initial_style_jaxpr(f, in_tree, in_avals)
    jaxpr = pe.close_jaxpr(pe.convert_constvars_jaxpr(jaxpr_))
    # checkify: the leaves flattened above are reused as the runtime inputs,
    # after the consts that were converted into leading jaxpr inputs.
    error, out_flat = checkify_jaxpr(jaxpr, errors, init_error,
                                     *consts, *flat_args)
    return error, jtu.tree_unflatten(out_tree, out_flat)
  return checked_fun
|
|
|
|
def check(pred: Bool, msg: str, *fmt_args, **fmt_kwargs) -> None:
  """Check a predicate, add an error with msg if predicate is False.

  This is an effectful operation, and can't be staged (jitted/scanned/...).
  Before staging a function with checks, :func:`~checkify` it!

  Args:
    pred: if False, a FailedCheckError error is added.
    msg: error message if error is added. Can be a format string.
    fmt_args, fmt_kwargs: Positional and keyword formatting arguments for
      `msg`, eg.:
      ``check(.., "check failed on values {} and {named_arg}", x, named_arg=y)``
      Note that these arguments can be traced values allowing you to add
      run-time values to the error message.
      Note that tracking these run-time arrays will increase your memory usage,
      even if no error happens.

  For example:

    >>> import jax
    >>> import jax.numpy as jnp
    >>> from jax.experimental import checkify
    >>> def f(x):
    ...   checkify.check(x>0, "{x} needs to be positive!", x=x)
    ...   return 1/x
    >>> checked_f = checkify.checkify(f)
    >>> err, out = jax.jit(checked_f)(-3.)
    >>> err.throw()  # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
      ...
    jax._src.checkify.JaxRuntimeError: -3. needs to be positive!

  """
  # Third positional argument is the `debug` flag; False marks a regular check.
  _check(pred, msg, False, *fmt_args, **fmt_kwargs)
|
|
|
|
def _check(pred, msg, debug, *fmt_args, **fmt_kwargs):
  """Shared implementation of ``check`` and ``debug_check``.

  Validates the arguments, builds a ``FailedCheckError`` that is active when
  ``pred`` is False, and hands the resulting error to ``_check_error``.

  Args:
    pred: predicate; must satisfy ``is_scalar_pred``.
    msg: error message (may be a format string).
    debug: True when called via ``debug_check``; forwarded to ``_check_error``
      and used to name the entry point in error messages.
    *fmt_args, **fmt_kwargs: formatting arguments for ``msg``; every pytree
      leaf must be an ``Array`` or ``np.ndarray``.

  Raises:
    TypeError: if ``pred`` is not a scalar boolean predicate, or if a
      formatting argument leaf is not an array.
  """
  prim_name = 'debug_check' if debug else 'check'
  if not is_scalar_pred(pred):
    raise TypeError(f'{prim_name} takes a scalar pred as argument, got {pred}')
  for arg in jtu.tree_leaves((fmt_args, fmt_kwargs)):
    if not isinstance(arg, (Array, np.ndarray)):
      # Name the entry point the user actually called (check vs debug_check).
      raise TypeError(f'Formatting arguments to checkify.{prim_name} need to be '
                      'PyTrees of arrays, but got '
                      f'{repr(arg)} of type {type(arg)}.')
  new_error = FailedCheckError(summary(), msg, *fmt_args, **fmt_kwargs)
  # The error fires when the predicate is False, hence the negation.
  error = assert_func(init_error, jnp.logical_not(pred), new_error)
  _check_error(error, debug=debug)
|
|
|
|
def _check_error(error, *, debug=False):
  """Binds ``check_p`` on the flattened ``error``.

  If any predicate stored in ``error._pred`` has a non-empty shape, the error
  is first passed through ``_reduce_any_error``.

  Args:
    error: error value to check.
    debug: forwarded to ``check_p`` as the ``debug`` parameter.

  Returns:
    The result of ``check_p.bind``.
  """
  if any(map(np.shape, error._pred.values())):
    error = _reduce_any_error(error)
  err_args, tree_def = tree_flatten(error)

  return check_p.bind(*err_args, err_tree=tree_def, debug=debug)
|
|
|
|
|
|
def is_scalar_pred(pred) -> bool:
  """Returns whether ``pred`` is a valid scalar boolean predicate.

  A predicate is accepted when it is either a Python ``bool``, or an ``Array``
  with shape ``()`` and boolean dtype.
  """
  if isinstance(pred, bool):
    return True
  if not isinstance(pred, Array):
    return False
  return pred.shape == () and pred.dtype == jnp.dtype('bool')
|
|
|
|
|
|
def debug_check(pred: Bool, msg: str, *fmt_args, **fmt_kwargs) -> None:
  """Check a predicate when running under checkify, otherwise is a no-op.

  A `debug_check` will only be run if it is transformed by :func:`~checkify`,
  otherwise the check will be dropped.

  Args:
    pred: if False, a FailedCheckError error is added.
    msg: error message if error is added.
    fmt_args, fmt_kwargs: Positional and keyword formatting arguments for
      `msg`, eg.:
      ``debug_check(.., "check failed on values {} and {named}", x, named=y)``
      Note that these arguments can be traced values allowing you to add
      run-time values to the error message.
      Note that tracking these run-time arrays will increase your memory usage,
      even if no error happens.

  For example:

    >>> import jax
    >>> import jax.numpy as jnp
    >>> from jax.experimental import checkify
    >>> def f(x):
    ...   checkify.debug_check(x!=0, "cannot be zero!")
    ...   return x
    >>> _ = f(0)  # running without checkify means no debug_check is run.
    >>> checked_f = checkify.checkify(f)
    >>> err, out = jax.jit(checked_f)(0)  # running with checkify runs debug_check.
    >>> err.throw()  # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
      ...
    jax._src.checkify.JaxRuntimeError: cannot be zero!

  """
  # Third positional argument is the `debug` flag; True marks a debug check.
  _check(pred, msg, True, *fmt_args, **fmt_kwargs)
|
|
|
|
|
|
def check_error(error: Error) -> None:
  """Raise an Exception if ``error`` represents a failure. Functionalized by :func:`~checkify`.

  The semantics of this function are equivalent to:

  >>> def check_error(err: Error) -> None:
  ...   err.throw()  # can raise ValueError

  But unlike that implementation, ``check_error`` can be functionalized using
  the :func:`~checkify` transformation.

  This function is similar to :func:`~check` but with a different signature: whereas
  :func:`~check` takes as arguments a boolean predicate and a new error message
  string, this function takes an ``Error`` value as argument. Both :func:`~check`
  and this function raise a Python Exception on failure (a side-effect), and
  thus cannot be staged out by :func:`~jax.jit`, :func:`~jax.pmap`,
  :func:`~jax.lax.scan`, etc. Both also can
  be functionalized by using :func:`~checkify`.

  But unlike :func:`~check`, this function is like a direct inverse of
  :func:`~checkify`: whereas :func:`~checkify` takes as input a function which
  can raise a Python Exception and produces a new function without that effect
  but which produces an ``Error`` value as output, this ``check_error``
  function can accept an ``Error`` value as input and can produce the
  side-effect of raising an Exception. That is, while :func:`~checkify` goes
  from functionalizable Exception effect to error value, this ``check_error``
  goes from error value to functionalizable Exception effect.

  ``check_error`` is useful when you want to turn checks represented by an
  ``Error`` value (produced by functionalizing ``checks`` via
  :func:`~checkify`) back into Python Exceptions.

  Args:
    error: Error to check.

  Raises:
    ValueError: if ``error`` is not an ``Error`` instance.

  For example, you might want to functionalize part of your program through
  checkify, stage out your functionalized code through :func:`~jax.jit`, then
  re-inject your error value outside of the :func:`~jax.jit`:

  >>> import jax
  >>> from jax.experimental import checkify
  >>> def f(x):
  ...   checkify.check(x>0, "must be positive!")
  ...   return x
  >>> def with_inner_jit(x):
  ...   checked_f = checkify.checkify(f)
  ...   # a checkified function can be jitted
  ...   error, out = jax.jit(checked_f)(x)
  ...   checkify.check_error(error)
  ...   return out
  >>> _ = with_inner_jit(1)  # no failed check
  >>> with_inner_jit(-1)  # doctest: +IGNORE_EXCEPTION_DETAIL
  Traceback (most recent call last):
    ...
  jax._src.checkify.JaxRuntimeError: must be positive!
  >>> # can re-checkify
  >>> error, _ = checkify.checkify(with_inner_jit)(-1)
  """
  if not isinstance(error, Error):
    raise ValueError('check_error takes an Error as argument, '
                     f'got type {type(error)} instead.')
  _check_error(error, debug=False)
|