# Copyright 2021 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import dataclasses
import functools
import itertools as it
import types
from typing import Union, Optional, Callable, Dict, Tuple, TypeVar, FrozenSet, Iterable, Type, Set, List

import jax
from jax import core
from jax import lax
from jax import linear_util as lu
from jax._src import prng
from jax._src import source_info_util
from jax._src import traceback_util
from jax._src.config import config
from jax._src.lax import control_flow as cf
from jax._src.sharding import OpShardingSharding
from jax._src.typing import Array
from jax._src.util import (as_hashable_function, unzip2, split_list, safe_map,
                           safe_zip)
from jax.api_util import flatten_fun
from jax.api_util import flatten_fun_nokwargs
from jax.experimental import maps
from jax.experimental import pjit
from jax.interpreters import ad
from jax.interpreters import batching
from jax.interpreters import mlir
from jax.interpreters import partial_eval as pe
from jax.tree_util import tree_flatten
from jax.tree_util import tree_map
from jax.tree_util import tree_unflatten
import jax.numpy as jnp
import jax.tree_util as jtu
import numpy as np

source_info_util.register_exclusion(__file__)
traceback_util.register_exclusion(__file__)

map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip

Bool = Union[bool, Array]
Int = Union[int, Array]
ErrorCategory = Type['JaxException']
Payload = List[Union[np.ndarray, Array]]
PyTreeDef = jtu.PyTreeDef

## Utils

def popattr(obj, attrname):
  val = getattr(obj, attrname)
  delattr(obj, attrname)
  return val

def setnewattr(obj, name, val):
  sentinel = object()
  assert getattr(obj, name, sentinel) is sentinel
  setattr(obj, name, val)

# Concrete errors

class JaxException(Exception):
  """Python exception which can contain an error message with JAX run-time info."""

  def __init__(self, traceback_info):
    self.traceback_info = traceback_info
    # TODO(lenamartens): re-enable tracebacks when they don't leak tracers.
    # self.with_traceback(self.traceback_info)

  def __init_subclass__(cls):
    jtu.register_pytree_node_class(cls)

  def tree_flatten(self):
    return ([], self.traceback_info)

  @classmethod
  def tree_unflatten(cls, metadata, payload):
    del payload
    return cls(metadata)

  def get_effect_type(self) -> core.Effect:
    pass


@functools.total_ordering
@dataclasses.dataclass(eq=True, frozen=True)
class ErrorEffect:
  error_type: Type[JaxException]
  shape_dtypes: Tuple[jax.ShapeDtypeStruct, ...]

  def __post_init__(self):
    cf.allowed_effects.add(self)
    mlir.lowerable_effects.add(self)

  def __lt__(self, other: 'ErrorEffect'):
    shape_dtypes = lambda x: tuple((sd.shape, str(sd.dtype))  # dtype is not comparable
                                   for sd in x.shape_dtypes)
    unpack = lambda x: (str(x.error_type), shape_dtypes(x))
    return (unpack(self) < unpack(other))


class DivisionByZeroError(JaxException):

  def __str__(self):
    return f'division by zero at {self.traceback_info}'

  def get_effect_type(self):
    return ErrorEffect(DivisionByZeroError, ())

class NaNError(JaxException):

  def __init__(self, traceback_info, primitive_name):
    super().__init__(traceback_info)
    self.prim = primitive_name

  def tree_flatten(self):
    return ([], (self.traceback_info, self.prim))

  @classmethod
  def tree_unflatten(cls, metadata, _):
    return cls(*metadata)

  def get_effect_type(self):
    return ErrorEffect(NaNError, ())

  def __str__(self):
    return f'nan generated by primitive: {self.prim} at {self.traceback_info}'

class OOBError(JaxException):

  def __init__(self, traceback_info, primitive_name, operand_shape, payload):
    super().__init__(traceback_info)
    self.prim = primitive_name
    self.operand_shape = operand_shape
    self._payload = payload

  def tree_flatten(self):
    return ([self._payload], (self.traceback_info, self.prim, self.operand_shape))

  @classmethod
  def tree_unflatten(cls, metadata, payload):
    return cls(*metadata, payload[0])

  def __str__(self):
    return (f'out-of-bounds indexing for array of '
            f'shape {self.operand_shape}: '
            f'index {self._payload[0]} is out of bounds for axis '
            f'{self._payload[1]} with size {self._payload[2]}. '
            f'Failed at {self.traceback_info}')

  def get_effect_type(self):
    return ErrorEffect(OOBError, (jax.ShapeDtypeStruct((3,), jnp.int32),))

class FailedCheckError(JaxException):

  def __init__(self, traceback_info, fmt_string, *a, **k):
    super().__init__(traceback_info)
    self.fmt_string = fmt_string
    self.args = a
    self.kwargs = k

  def tree_flatten(self):
    return ((jnp.array([], jnp.int32), self.args, self.kwargs),
            (self.traceback_info, self.fmt_string))

  @classmethod
  def tree_unflatten(cls, metadata, payload):
    _, args, kwargs = payload
    return cls(*metadata, *args, **kwargs)

  def __str__(self):
    return (self.fmt_string.format(*self.args, **self.kwargs)
            + f' (check failed at {self.traceback_info})')

  def get_effect_type(self):
    vals = jtu.tree_leaves((self.args, self.kwargs))
    return ErrorEffect(
        FailedCheckError,
        # Need a 0-size array here for data-dependence.
        (jax.ShapeDtypeStruct((0,), jnp.int32),
         *tuple(jax.ShapeDtypeStruct(x.shape, x.dtype) for x in vals)))

@dataclasses.dataclass
class BatchedError(JaxException):
  error_mapping: Dict[Tuple[int, ...], JaxException]

  def __post_init__(self):
    traceback_info = list(self.error_mapping.values())[0].traceback_info
    super().__init__(traceback_info)

  def __str__(self):
    return '\n'.join(f'at mapped index {", ".join(map(str, idx))}: {e}'
                     for idx, e in self.error_mapping.items())
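
# Informal sketch (comments only, not executed anywhere in this module): each
# concrete JaxException above is registered as a pytree so it can be split
# into traced payload leaves and static metadata, carried through a jaxpr, and
# reconstructed on the host. Hypothetically:
#
#   exc = OOBError('<traceback>', 'gather', (4,),
#                  jnp.array([7, 0, 4], jnp.int32))
#   leaves, treedef = jtu.tree_flatten(exc)          # leaves: [payload array]
#   same_exc = jtu.tree_unflatten(treedef, leaves)
#
# The payload (here the offending index, axis and axis size) travels as array
# data, while the traceback string and primitive name stay in the treedef.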


# Error Value

@jtu.register_pytree_node_class
@dataclasses.dataclass(frozen=True)
class Error:
  _pred: Dict[ErrorEffect, Bool]
  _code: Dict[ErrorEffect, Int]
  _metadata: Dict[Int, PyTreeDef]  # mapping of code to JaxException treedef.
  _payload: Dict[ErrorEffect, Payload]

  def get(self) -> Optional[str]:
    """Returns error message if error happened, None if no error happened."""
    exp = self.get_exception()
    if exp is not None:
      return str(exp)
    return None

  def get_exception(self) -> Optional[JaxException]:
    """Returns Python exception if error happened, None if no error happened."""
    if any(map(np.shape, self._pred.values())):
      return self._get_batched_exception()
    else:
      min_code = None
      cur_effect = None
      for error_effect, code in self._code.items():
        if self._pred[error_effect]:
          if min_code is None or code < min_code:
            min_code = code
            cur_effect = error_effect

      if cur_effect is not None:
        return tree_unflatten(self._metadata[int(min_code)],  # type: ignore
                              self._payload[cur_effect])
      return None

  def throw(self):
    _check_error(self)

  def __str__(self):
    return f'Error({self.get()})'

  # Internal helpers

  def _get_batched_exception(self):
    shape = np.shape(list(self._pred.values())[0])
    error_mapping = {}
    for idx in np.ndindex(*shape):
      min_code = None
      cur_effect = None
      for error_effect, code in self._code.items():
        if self._pred[error_effect][idx]:  # type: ignore
          if min_code is None or code[idx] < min_code:
            min_code = code[idx]  # type: ignore
            cur_effect = error_effect

      if cur_effect is not None:
        payload = tree_map(lambda x, i=idx: x[i], self._payload[cur_effect])
        jax_error = tree_unflatten(self._metadata[int(min_code)], payload)  # type: ignore
        error_mapping[idx] = jax_error
    return BatchedError(error_mapping)

  def _update(self, effect_type: ErrorEffect, pred, code, metadata, payload):
    new_errs = {**self._pred, **{effect_type: pred}}  # type: ignore
    new_codes = {**self._code, **{effect_type: code}}  # type: ignore
    new_payload = {**self._payload, **{effect_type: payload}}  # type: ignore
    new_metadata = {**self._metadata, **metadata}
    return Error(new_errs, new_codes, new_metadata, new_payload)

  def _add_placeholder_effects(self, effects: Set[ErrorEffect]):
    """Fill out Error with `effects` and np.ones arrays of their payloads."""
    new_err = self._pred.copy()
    new_code = self._code.copy()
    new_payload = self._payload.copy()
    for effect in effects:
      if effect not in self._pred.keys():
        new_err[effect] = False
        new_payload[effect] = list(
            tree_map(lambda a: jnp.ones(a.shape, a.dtype), effect.shape_dtypes))
        # The error value associated with this effect will never become True,
        # so we don't need to set a meaningful code.
        new_code[effect] = -1
    return Error(new_err, new_code, self._metadata, new_payload)

  def _replace(self, *args, **kwargs):
    return dataclasses.replace(self, *args, **kwargs)

  # PyTree methods

  def tree_flatten(self):
    return ((self._pred, self._code, self._payload), (self._metadata))

  @classmethod
  def tree_unflatten(cls, metadata, data):
    pred, code, payload = data
    return cls(pred, code, metadata, payload)

init_error = Error({}, {}, {}, {})  # value used as initial (empty) error.
next_code = it.count(1).__next__  # globally unique ids, could be uuid4


def assert_func(error: Error, pred: Bool, new_error: JaxException) -> Error:
  code = next_code()
  effect_type = new_error.get_effect_type()
  new_payload, new_metadata = tree_flatten(new_error)
  return update_error(error, pred, code, {code: new_metadata}, new_payload, effect_type)


def update_error(error, pred, code, metadata, payload, effect_type):
  err_of_type = error._pred.get(effect_type, False)
  out_err = err_of_type | pred
  out_code = lax.select(err_of_type, error._code.get(effect_type, -1), code)
  cur_payload = error._payload.get(effect_type, None)
  if cur_payload is not None:
    out_payload = tree_map(functools.partial(lax.select, err_of_type), cur_payload, payload)
  else:
    out_payload = payload
  return error._update(effect_type, out_err, out_code, metadata, out_payload)
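
# Informal sketch (comments only, not executed here): `assert_func` folds a
# new potential error into an existing Error value with "first error wins"
# semantics. Because `update_error` selects the existing code and payload
# whenever an error of the same effect type is already set, only the earliest
# failure is reported. Hypothetically:
#
#   err = assert_func(init_error, jnp.array(True),
#                     FailedCheckError(summary(), "first failure"))
#   err = assert_func(err, jnp.array(True),
#                     FailedCheckError(summary(), "second failure"))
#   err.get()  # mentions "first failure" only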


## Checkify transformation for plumbing functional error values.

class CheckifyTracer(core.Tracer):
  def __init__(self, trace, val):
    self._trace = trace
    self.val = val
  aval = property(lambda self: core.get_aval(self.val))
  full_lower = lambda self: self

class CheckifyTrace(core.Trace):
  pure = lift = lambda self, val: CheckifyTracer(self, val)

  def __init__(self, main: core.MainTrace, sublevel: core.Sublevel,
               enabled_errors: FrozenSet['ErrorCategory']) -> None:
    self.main = main
    self.level = main.level
    self.sublevel = sublevel
    self.main.enabled_errors = enabled_errors

  def sublift(self, tracer):
    return CheckifyTracer(self, tracer.val)

  def process_primitive(self, primitive, tracers, params):
    in_vals = [t.val for t in tracers]
    rule = error_checks.get(primitive)
    if rule:
      out, self.main.error = rule(self.main.error, self.main.enabled_errors,  # type: ignore
                                  *in_vals, **params)
    else:
      out = primitive.bind(*in_vals, **params)
    if primitive.multiple_results:
      return [CheckifyTracer(self, x) for x in out]
    else:
      return CheckifyTracer(self, out)

  def process_call(self, primitive, f, tracers, params):
    in_vals = [t.val for t in tracers]
    e = popattr(self.main, 'error')
    flat_vals, in_tree = tree_flatten((e, *in_vals))
    f = checkify_subtrace(f, self.main)
    f, out_tree = flatten_fun_nokwargs(f, in_tree)
    if 'donated_invars' in params:
      params = dict(params, donated_invars=(*[False]*len(jtu.tree_leaves(e)),
                                            *params['donated_invars']))
    all_vals = primitive.bind(f, *flat_vals, **params)
    error, *out_vals = tree_unflatten(out_tree(), all_vals)
    setnewattr(self.main, 'error', error)
    return [CheckifyTracer(self, x) for x in out_vals]

  def process_map(self, primitive, f, tracers, params):
    in_vals = [t.val for t in tracers]
    e = popattr(self.main, 'error')
    flat_vals, in_tree = tree_flatten((e, *in_vals))
    num_error_vals = len(jtu.tree_leaves(e))
    f = checkify_subtrace(f, self.main)
    f, out_tree = flatten_fun_nokwargs(f, in_tree)

    @as_hashable_function(closure=params['out_axes_thunk'])
    def new_out_axes_thunk():
      out_val_axes = params['out_axes_thunk']()
      out_err_num = out_tree().num_leaves - len(out_val_axes)
      return (*(0,)*out_err_num, *out_val_axes)

    params_ = dict(params, in_axes=(*(None,)*num_error_vals, *params['in_axes']),
                   out_axes_thunk=new_out_axes_thunk,
                   donated_invars=(*(False,)*num_error_vals, *params['donated_invars']))
    all_vals = primitive.bind(f, *flat_vals, **params_)
    error, *out_vals = tree_unflatten(out_tree(), all_vals)
    error = _reduce_any_error(error)
    setnewattr(self.main, 'error', error)
    return [CheckifyTracer(self, x) for x in out_vals]

  def post_process_call(self, primitive, tracers, params):
    vals = [t.val for t in tracers]
    main = self.main
    e = popattr(main, 'error')
    err_leaves, err_tree = tree_flatten(e)
    setnewattr(main, 'err_tree', err_tree)
    def todo(vals):
      err_tree = popattr(main, 'err_tree')
      err_vals, vals = split_list(vals, [err_tree.num_leaves])
      setnewattr(main, 'error', tree_unflatten(err_tree, err_vals))
      trace = main.with_cur_sublevel()
      return [CheckifyTracer(trace, x) for x in vals]
    return (*err_leaves, *vals), todo

  def post_process_map(self, primitive, tracers, params):
    vals = [t.val for t in tracers]
    main = self.main
    e = popattr(main, 'error')
    err_leaves, err_tree = tree_flatten(e)
    num_err_leaves = len(err_leaves)
    setnewattr(main, 'err_tree', err_tree)
    def todo(vals):
      err_tree = popattr(main, 'err_tree')
      err_vals, vals = split_list(vals, [err_tree.num_leaves])
      error = tree_unflatten(err_tree, err_vals)
      error = _reduce_any_error(error)
      setnewattr(main, 'error', error)
      trace = main.with_cur_sublevel()
      return [CheckifyTracer(trace, x) for x in vals]
    def out_axes_transform(out_axes):
      return (*(0,)*num_err_leaves, *out_axes)
    return (*err_leaves, *vals), (todo, out_axes_transform)

  def process_custom_jvp_call(self, prim, f, jvp, tracers):
    in_vals = [t.val for t in tracers]
    e = popattr(self.main, 'error')
    err_vals, err_tree = tree_flatten(e)
    flat_vals, in_tree = tree_flatten((e, *in_vals))
    num_error_vals = len(err_vals)
    f = checkify_subtrace(f, self.main)
    f, f_out_tree = flatten_fun_nokwargs(f, in_tree)
    jvp, jvp_err_tree = checkify_custom_jvp_subtrace(jvp, self.main,
                                                     num_error_vals, err_tree)
    all_outs = prim.bind(f, jvp, *flat_vals)
    fst, out_tree = lu.merge_linear_aux(f_out_tree, jvp_err_tree)
    if fst:
      out_err, *out_vals = tree_unflatten(out_tree, all_outs)
    else:
      err_vals, out_vals = split_list(all_outs, [num_error_vals])
      # forward input error values to output
      out_err = tree_unflatten(out_tree, err_vals)
    setattr(self.main, 'error', out_err)
    return [CheckifyTracer(self, x) for x in out_vals]

  def post_process_custom_jvp_call(self, tracers, jvp_was_run):
    if jvp_was_run:
      msg = ('support for custom_jvp rules which close over checkify values is '
             'not implemented. If you see this, open an issue at '
             'https://github.com/google/jax/issues!')
      raise NotImplementedError(msg)
    vals = [t.val for t in tracers]
    main = self.main
    e = popattr(main, 'error')
    err_leaves, err_tree = tree_flatten(e)
    def todo(vals):
      err_vals, vals = split_list(vals, [len(err_leaves)])
      setnewattr(main, 'error', tree_unflatten(err_tree, err_vals))
      trace = main.with_cur_sublevel()
      return [CheckifyTracer(trace, x) for x in vals]
    return (*err_leaves, *vals), todo

  def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees):
    in_vals = [t.val for t in tracers]
    e = popattr(self.main, 'error')
    err_vals, err_tree = tree_flatten(e)
    flat_vals, in_tree = tree_flatten((e, *in_vals))
    num_error_vals = len(err_vals)

    fun = checkify_subtrace(fun, self.main)
    fun, fun_out_tree = flatten_fun_nokwargs(fun, in_tree)
    fwd, fwd_err_tree = checkify_custom_vjp_subtrace(fwd, self.main,
                                                     err_tree, num_error_vals)

    all_out_vals = prim.bind(fun, fwd, bwd, *flat_vals, out_trees=out_trees)
    fst, out_tree = lu.merge_linear_aux(fun_out_tree, fwd_err_tree)
    if fst:
      error, *out = tree_unflatten(out_tree, all_out_vals)
    else:
      _, out = split_list(all_out_vals, [num_error_vals])
      # forward input error values to output
      error = tree_unflatten(err_tree, err_vals)
    setattr(self.main, 'error', error)
    return [CheckifyTracer(self, x) for x in out]


def _reduce_any_error(error: Error):
  out_error = init_error
  for error_effect in error._pred.keys():
    errs, codes, payloads = (error._pred[error_effect],
                             error._code[error_effect],
                             error._payload[error_effect])
    reduced_idx = jnp.argsort(errs)[-1]
    pred, code, payload = tree_map(lambda x, idx=reduced_idx: x[idx],
                                   (errs, codes, payloads))
    out_error = out_error._update(error_effect, pred, code, {}, payload)

  out_error = out_error._replace(_metadata=error._metadata)
  return out_error


ErrorCheckRule = Callable  # (Error, FrozenSet[ErrorCategory], *in_vals, **params) -> (Any, Error)
error_checks: Dict[core.Primitive, ErrorCheckRule] = {}
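
# Informal sketch (hypothetical, not registered anywhere): an ErrorCheckRule
# receives the incoming error value, the set of enabled error categories, and
# the primitive's inputs/params, and returns the primitive's outputs together
# with a possibly-updated error. A minimal rule for a hypothetical unary
# primitive `some_prim` could look like:
#
#   def some_prim_error_check(error, enabled_errors, x, **params):
#     out = some_prim.bind(x, **params)
#     if NaNError in enabled_errors:
#       error = assert_func(error, jnp.any(jnp.isnan(out)),
#                           NaNError(summary(), some_prim.name))
#     return out, error
#
#   error_checks[some_prim] = some_prim_error_check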


def checkify_flat(fun: lu.WrappedFun, enabled_errors: FrozenSet['ErrorCategory'],
                  *args):
  fun = checkify_subtrace(fun)
  fun = checkify_traceable(fun, enabled_errors)
  error, *outvals = fun.call_wrapped(init_error, *args)
  return error, outvals

@lu.transformation
def checkify_traceable(enabled_errors, error, *args):
  with core.new_main(CheckifyTrace, enabled_errors=enabled_errors) as main:
    outs = yield (main, error, *args), {}
    del main
  yield outs

@lu.transformation
def checkify_subtrace(main, error, *args):
  setnewattr(main, 'error', error)
  trace = main.with_cur_sublevel()
  in_tracers = [CheckifyTracer(trace, x) for x in args]
  out = yield in_tracers, {}
  out_tracers = map(trace.full_raise, out)
  out_vals = [t.val for t in out_tracers]
  error = main.error
  del main.error
  yield (error, *out_vals)

@lu.transformation_with_aux
def checkify_custom_jvp_subtrace(main, num_error_vals, out_tree, *args):
  # Like checkify_subtrace, but used specifically on the custom JVP rules
  # associated with a custom_jvp. This code is called in the context of a
  # jvp-of-checkify-of-custom_jvp. It takes both primal and tangent inputs,
  # flattened into a single args tuple, and similarly must produce flattened
  # primal and tangent outputs. Both primals and tangents include error values,
  # but the tangent error values are trivially zero.
  # The types to have in mind are:
  #   jvp : (a -> b) -> (a, T a) -> (b, T b)
  #   checkify : (a -> b) -> a -> Err b
  #   jvp-of-checkify : (a -> b) -> (a, T a) -> (Err b, T (Err b))
  # where because Err is a pytree, we necessarily have T (Err b) = Err' (T b)
  # where the other Err' components are trivial (of float0 dtype).
  # Semantically, we don't add checks to the JVP rule. To check the result of a
  # JVP rule, one must instead use checkify-of-jvp. Thus this implementation
  # just forwards the input error and code (and trivial tangents) to the output.
  del main
  n, ragged = divmod(len(args), 2)
  assert not ragged
  err_primals, primals = split_list(args[:n], [num_error_vals])
  err_tangents, tangents = split_list(args[n:], [num_error_vals])
  outs = yield (*primals, *tangents), {}
  m, ragged = divmod(len(outs), 2)
  assert not ragged
  out_primals, out_tangents = outs[:m], outs[m:]
  yield (*err_primals, *out_primals, *err_tangents, *out_tangents), out_tree

@lu.transformation_with_aux
def checkify_custom_vjp_subtrace(main, err_tree, num_error_vals, *args):
  del main
  # We don't add any checks; just drop input error values.
  _, args = split_list(args, [num_error_vals])
  outs = yield args, {}
  yield outs, err_tree

@lu.transformation_with_aux
def query_error_effects(*args):
  (error, *outs) = yield args, {}
  yield (error, *outs), set(error._pred.keys())

def checkify_jaxpr(jaxpr, error,
                   enabled_errors) -> Tuple[core.ClosedJaxpr,
                                            Tuple[PyTreeDef,
                                                  FrozenSet[ErrorEffect]]]:
  f = lu.wrap_init(core.jaxpr_as_fun(jaxpr))
  return checkify_fun_to_jaxpr(f, error, enabled_errors, jaxpr.in_avals)

def checkify_fun_to_jaxpr(
    f, error, enabled_errors,
    in_avals) -> Tuple[core.ClosedJaxpr, Tuple[PyTreeDef, FrozenSet[ErrorEffect]]]:
  flat_error_vals, in_tree = tree_flatten(error)
  f = checkify_subtrace(f)
  f = checkify_traceable(f, enabled_errors)
  f, error_effect = query_error_effects(f)
  in_tree = jtu.tree_structure((error, *in_avals))
  f, out_tree = flatten_fun_nokwargs(f, in_tree)
  err_vals = map(lambda x: core.raise_to_shaped(core.get_aval(x)),
                 flat_error_vals)
  avals_in = [*err_vals, *in_avals]
  jaxpr_out, _, literals_out = pe.trace_to_jaxpr_dynamic(f, avals_in)
  return (core.ClosedJaxpr(jaxpr_out, literals_out), (out_tree(), error_effect()))


def check(pred: Bool, msg: str, *fmt_args, **fmt_kwargs) -> None:
  """Check a predicate, add an error with msg if predicate is False.

  This is an effectful operation, and can't be staged (jitted/scanned/...).
  Before staging a function with checks, :func:`~checkify` it!

  Args:
    pred: if False, a FailedCheckError error is added.
    msg: error message if error is added. Can be a format string.
    fmt_args, fmt_kwargs: Positional and keyword formatting arguments for
      `msg`, e.g.:
      ``check(.., "check failed on values {} and {named_arg}", x, named_arg=y)``
      Note that these arguments can be traced values allowing you to add
      run-time values to the error message.
      Note that tracking these run-time arrays will increase your memory usage,
      even if no error happens.

  For example:

    >>> import jax
    >>> import jax.numpy as jnp
    >>> from jax.experimental import checkify
    >>> def f(x):
    ...   checkify.check(x>0, "{x} needs to be positive!", x=x)
    ...   return 1/x
    >>> checked_f = checkify.checkify(f)
    >>> err, out = jax.jit(checked_f)(-3.)
    >>> err.throw()  # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
      ...
    jax._src.checkify.JaxRuntimeError: -3. needs to be positive!

  """
  _check(pred, msg, False, *fmt_args, **fmt_kwargs)

def _check(pred, msg, debug, *fmt_args, **fmt_kwargs):
  if not is_scalar_pred(pred):
    prim_name = 'debug_check' if debug else 'check'
    raise TypeError(f'{prim_name} takes a scalar pred as argument, got {pred}')
  new_error = FailedCheckError(summary(), msg, *fmt_args, **fmt_kwargs)
  error = assert_func(init_error, jnp.logical_not(pred), new_error)
  _check_error(error, debug=debug)

def _check_error(error, *, debug=False):
  error = tree_map(core.raise_as_much_as_possible, error)
  if any(map(np.shape, error._pred.values())):
    error = _reduce_any_error(error)
  err_args, tree_def = tree_flatten(error)

  return check_p.bind(*err_args, err_tree=tree_def, debug=debug)


def is_scalar_pred(pred) -> bool:
  return (isinstance(pred, bool) or
          isinstance(pred, jnp.ndarray) and pred.shape == () and
          pred.dtype == jnp.dtype('bool'))
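
# For instance (informal): a scalar predicate like ``x > 0`` for scalar ``x``
# passes this test, while an array predicate should first be reduced, e.g.
# ``check(jnp.all(x > 0), ...)``.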


def debug_check(pred: Bool, msg: str, *fmt_args, **fmt_kwargs) -> None:
  """Check a predicate when running under checkify, otherwise is a no-op.

  A `debug_check` will only be run if it is transformed by :func:`~checkify`,
  otherwise the check will be dropped.

  Args:
    pred: if False, a FailedCheckError error is added.
    msg: error message if error is added.
    fmt_args, fmt_kwargs: Positional and keyword formatting arguments for
      `msg`, e.g.:
      ``debug_check(.., "check failed on values {} and {named}", x, named=y)``
      Note that these arguments can be traced values allowing you to add
      run-time values to the error message.
      Note that tracking these run-time arrays will increase your memory usage,
      even if no error happens.

  For example:

    >>> import jax
    >>> import jax.numpy as jnp
    >>> from jax.experimental import checkify
    >>> def f(x):
    ...   checkify.debug_check(x!=0, "cannot be zero!")
    ...   return x
    >>> _ = f(0)  # running without checkify means no debug_check is run.
    >>> checked_f = checkify.checkify(f)
    >>> err, out = jax.jit(checked_f)(0)  # running with checkify runs debug_check.
    >>> err.throw()  # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
      ...
    jax._src.checkify.JaxRuntimeError: cannot be zero!

  """
  _check(pred, msg, True, *fmt_args, **fmt_kwargs)


def check_error(error: Error) -> None:
  """Raise an Exception if ``error`` represents a failure. Functionalized by :func:`~checkify`.

  The semantics of this function are equivalent to:

  >>> def check_error(err: Error) -> None:
  ...   err.throw()  # can raise ValueError

  But unlike that implementation, ``check_error`` can be functionalized using
  the :func:`~checkify` transformation.

  This function is similar to :func:`~check` but with a different signature:
  whereas :func:`~check` takes as arguments a boolean predicate and a new error
  message string, this function takes an ``Error`` value as argument. Both
  :func:`~check` and this function raise a Python Exception on failure (a
  side-effect), and thus cannot be staged out by :func:`~jax.jit`,
  :func:`~jax.pmap`, :func:`~jax.lax.scan`, etc. Both also can be
  functionalized by using :func:`~checkify`.

  But unlike :func:`~check`, this function is like a direct inverse of
  :func:`~checkify`: whereas :func:`~checkify` takes as input a function which
  can raise a Python Exception and produces a new function without that effect
  but which produces an ``Error`` value as output, this ``check_error``
  function can accept an ``Error`` value as input and can produce the
  side-effect of raising an Exception. That is, while :func:`~checkify` goes
  from functionalizable Exception effect to error value, this ``check_error``
  goes from error value to functionalizable Exception effect.

  ``check_error`` is useful when you want to turn checks represented by an
  ``Error`` value (produced by functionalizing ``checks`` via
  :func:`~checkify`) back into Python Exceptions.

  Args:
    error: Error to check.

  For example, you might want to functionalize part of your program through
  checkify, stage out your functionalized code through :func:`~jax.jit`, then
  re-inject your error value outside of the :func:`~jax.jit`:

    >>> import jax
    >>> from jax.experimental import checkify
    >>> def f(x):
    ...   checkify.check(x>0, "must be positive!")
    ...   return x
    >>> def with_inner_jit(x):
    ...   checked_f = checkify.checkify(f)
    ...   # a checkified function can be jitted
    ...   error, out = jax.jit(checked_f)(x)
    ...   checkify.check_error(error)
    ...   return out
    >>> _ = with_inner_jit(1)  # no failed check
    >>> with_inner_jit(-1)  # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
      ...
    jax._src.checkify.JaxRuntimeError: must be positive!
    >>> # can re-checkify
    >>> error, _ = checkify.checkify(with_inner_jit)(-1)
  """
  if not isinstance(error, Error):
    raise ValueError('check_error takes an Error as argument, '
                     f'got type {type(error)} instead.')

  _check_error(error, debug=False)


## check primitive

check_p = core.Primitive('check')
check_p.multiple_results = True  # zero results

# TODO(lenamartens): inherit from Exception instead of ValueError.
class JaxRuntimeError(ValueError):
  pass

@check_p.def_impl
def check_impl(*args, err_tree, debug):
  if debug:
    # NOOP (check will only trigger when discharged)
    return []
  error = tree_unflatten(err_tree, args)
  exc = error.get_exception()
  if exc:
    raise JaxRuntimeError(str(exc)) from exc
  return []

@check_p.def_effectful_abstract_eval
def check_abstract_eval(*args, err_tree, debug):
  del debug
  return [], set(tree_unflatten(err_tree, args)._pred.keys())

# TODO(lenamartens) add in-depth error explanation to link to in module docs.
functionalization_error = ValueError(
    'Cannot abstractly evaluate a checkify.check which was not'
    ' functionalized. This probably means you tried to stage'
    ' (jit/scan/pmap/...) a `check` without functionalizing it'
    ' through `checkify.checkify`.'
    )

def check_lowering_rule(ctx, *args, err_tree, debug):
  if debug:
    # NOOP (check will only trigger when discharged)
    return []
  if not config.jax_experimental_unsafe_xla_runtime_errors:
    raise functionalization_error

  out_op, _, keep_alive = mlir.emit_python_callback(
      ctx, callback=functools.partial(python_err, err_tree),
      token=None,
      operands=args,
      operand_avals=list(ctx.avals_in),
      result_avals=list(ctx.avals_out),
      has_side_effect=True)
  ctx.module_context.add_keepalive(keep_alive)
  return out_op

def check_lowering_rule_unsupported(*a, debug, **k):
  if debug:
    return []
  raise functionalization_error

def python_err(err_tree, *args):
  error = tree_unflatten(err_tree, args)
  _check_error(error)
  return []

mlir.register_lowering(check_p, check_lowering_rule_unsupported,
                       platform='tpu')
mlir.register_lowering(check_p, check_lowering_rule,
                       platform='cpu')
mlir.register_lowering(check_p, check_lowering_rule,
                       platform='gpu')

def check_batching_rule(batched_args, batch_dims, *, err_tree, debug):
  size = next(x.shape[dim] for x, dim in zip(batched_args, batch_dims)
              if dim is not batching.not_mapped)
  batched_args = (batching.bdim_at_front(a, d, size)
                  for a, d in zip(batched_args, batch_dims))
  err = tree_unflatten(err_tree, batched_args)
  _check_error(err, debug=debug)
  return [], []
batching.primitive_batchers[check_p] = check_batching_rule

def check_jvp_rule(primals, _, *, err_tree, debug):
  # Check primals, discard tangents.
  check_p.bind(*primals, err_tree=err_tree, debug=debug)
  return [], []
ad.primitive_jvps[check_p] = check_jvp_rule

## checkify rules

def _get_current_traceback(skip_frames = 0) -> Optional[types.TracebackType]:
  # TODO(lenamartens): use c++ version from XLA?
  tb = None
  import inspect
  for frame_info in inspect.stack():
    frame = frame_info.frame
    if skip_frames:
      skip_frames -= 1
    elif not traceback_util.include_frame(frame):
      continue
    else:
      tb = types.TracebackType(tb, frame, frame.f_lasti, frame.f_lineno)
  return tb

def summary() -> str:
  return str(source_info_util.summarize(source_info_util.current()))

def nan_error_check(prim, error, enabled_errors, *in_vals, **params):
  out = prim.bind(*in_vals, **params)
  err = check_nans(prim, error, enabled_errors, out)
  return out, err

def check_nans(prim, error, enabled_errors, out):
  if NaNError not in enabled_errors:
    return error

  def isnan(x):
    if isinstance(x, prng.PRNGKeyArray):
      return False
    return jnp.any(jnp.isnan(x))

  any_nans = (jnp.any(jnp.array([isnan(x) for x in out]))
              if prim.multiple_results else isnan(out))
  return assert_func(error, any_nans, NaNError(summary(), prim.name))


# All primitives which can generate a NaN.
nan_primitives = [lax.acos_p, lax.acosh_p, lax.add_p, lax.asin_p, lax.asinh_p,
                  lax.atan2_p, lax.atan_p, lax.atanh_p, lax.bessel_i0e_p,
                  lax.bessel_i1e_p, lax.cbrt_p, lax.conv_general_dilated_p,
                  lax.cos_p, lax.cosh_p, lax.cumlogsumexp_p, lax.cummax_p,
                  lax.cummin_p, lax.cumprod_p, lax.cumsum_p, lax.digamma_p,
                  lax.dot_general_p, lax.erf_inv_p, lax.erf_p, lax.erfc_p,
                  lax.exp_p, lax.expm1_p, lax.fft_p, lax.igamma_grad_a_p,
                  lax.igamma_p, lax.igammac_p, lax.integer_pow_p, lax.lgamma_p,
                  lax.linear_solve_p, lax.log1p_p, lax.log_p, lax.logistic_p,
                  lax.mul_p, lax.pad_p, lax.pow_p, lax.psum_p,
                  lax.random_gamma_grad_p, lax.reduce_p, lax.reduce_prod_p,
                  lax.reduce_sum_p, lax.reduce_window_p,
                  lax.reduce_window_sum_p, lax.regularized_incomplete_beta_p,
                  lax.rem_p, lax.rng_uniform_p, lax.rsqrt_p, lax.sin_p,
                  lax.sinh_p, lax.sqrt_p, lax.sub_p, lax.tan_p, lax.tanh_p]

for prim in nan_primitives:
  error_checks[prim] = functools.partial(nan_error_check, prim)
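
# Informal usage sketch (comments only, nothing here is executed): with
# ``nan_checks`` enabled, any of the primitives above producing a NaN is
# reported through the functional error value instead of silently propagating,
# e.g.:
#
#   err, _ = checkify(lambda x: jnp.log(x), errors=nan_checks)(-1.)
#   err.get()  # "nan generated by primitive: log at ..."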


def gather_error_check(error, enabled_errors, operand, start_indices, *,
                       dimension_numbers, slice_sizes, unique_indices,
                       indices_are_sorted, mode, fill_value):
  out = lax.gather_p.bind(
      operand, start_indices, dimension_numbers=dimension_numbers,
      slice_sizes=slice_sizes, unique_indices=unique_indices,
      indices_are_sorted=indices_are_sorted, mode=mode, fill_value=fill_value)

  if OOBError not in enabled_errors:
    return out, error

  # compare to OOB masking logic in lax._gather_translation_rule
  dnums = dimension_numbers
  operand_dims = np.array(operand.shape)
  num_batch_dims = len(start_indices.shape) - 1

  upper_bound = operand_dims[np.array(dnums.start_index_map)]
  upper_bound -= np.array(slice_sizes)[np.array(dnums.start_index_map)]
  upper_bound = jnp.expand_dims(upper_bound, axis=tuple(range(num_batch_dims)))
  oob_mask = (start_indices < 0) | (start_indices > upper_bound.astype(start_indices.dtype))

  payload = oob_payload(oob_mask, start_indices, dnums.start_index_map, operand.shape)
  return out, assert_func(error, jnp.any(oob_mask), OOBError(summary(), "gather", operand.shape, payload))
error_checks[lax.gather_p] = gather_error_check
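
# Informal usage sketch (comments only): with ``index_checks`` enabled,
# out-of-bounds gathers such as ``x[5]`` on a length-3 array are reported
# instead of being clamped or filled, e.g.:
#
#   err, _ = checkify(lambda x: x[5], errors=index_checks)(jnp.arange(3))
#   err.get()  # "out-of-bounds indexing for array of shape (3,): ..."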


def div_error_check(error, enabled_errors, x, y):
  """Checks for division by zero and NaN."""
  if DivisionByZeroError in enabled_errors:
    any_zero = jnp.any(jnp.equal(y, 0))
    error = assert_func(error, any_zero, DivisionByZeroError(summary()))
  return nan_error_check(lax.div_p, error, enabled_errors, x, y)
error_checks[lax.div_p] = div_error_check
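
# Informal usage sketch (comments only): ``div_checks`` flags a zero
# denominator even though floating-point division would otherwise just
# produce inf, e.g.:
#
#   err, _ = checkify(lambda x, y: x / y, errors=div_checks)(1., 0.)
#   err.get()  # "division by zero at ..."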


def oob_payload(oob_mask, indices, dims_map, operand_shape):
  # Get first OOB index, axis and axis size so it can be added to the error msg.
  flat_idx = jnp.argmin(jnp.logical_not(oob_mask))
  multi_idx = jnp.unravel_index(flat_idx, indices.shape)
  oob_axis = jnp.array(dims_map)[multi_idx[-1]]
  oob_axis_size = jnp.array(operand_shape)[oob_axis]
  oob_index = jnp.ravel(indices)[flat_idx]
  payload = jnp.array([oob_index, oob_axis, oob_axis_size], dtype=jnp.int32)
  return payload

def scatter_oob(operand, indices, updates, dnums):
  # Ref: see clamping code used in scatter_translation_rule
  slice_sizes = []
  pos = 0
  for i in range(len(operand.shape)):
    if i in dnums.inserted_window_dims:
      slice_sizes.append(1)
    else:
      slice_sizes.append(updates.shape[dnums.update_window_dims[pos]])
      pos += 1

  upper_bound = np.array([operand.shape[i] - slice_sizes[i]
                          for i in dnums.scatter_dims_to_operand_dims],
                         np.int64)
  upper_bound = np.minimum(upper_bound, np.iinfo(indices.dtype).max)
  upper_bound = lax.broadcast_in_dim(upper_bound, indices.shape,
                                     (len(indices.shape) - 1,))

  lower_oob = jnp.less(indices, 0)
  upper_oob = jnp.greater(indices, upper_bound.astype(indices.dtype))
  oob_mask = jnp.logical_or(lower_oob, upper_oob)
  payload = oob_payload(oob_mask, indices,
                        dnums.scatter_dims_to_operand_dims, operand.shape)
  return jnp.any(oob_mask), payload

def scatter_error_check(prim, error, enabled_errors, operand, indices, updates,
                        *, update_jaxpr, update_consts, dimension_numbers,
                        indices_are_sorted, unique_indices, mode):
  """Checks if indices are within bounds and update does not generate NaN."""
  out = prim.bind(
      operand, indices, updates, update_jaxpr=update_jaxpr,
      update_consts=update_consts, dimension_numbers=dimension_numbers,
      indices_are_sorted=indices_are_sorted, unique_indices=unique_indices,
      mode=mode)

  if OOBError not in enabled_errors:
    return out, error

  out_of_bounds, payload = scatter_oob(operand, indices, updates, dimension_numbers)
  oob_error = OOBError(summary(), prim.name, operand.shape, payload)
  error = assert_func(error, out_of_bounds, oob_error)
  return out, check_nans(prim, error, enabled_errors, out)
error_checks[lax.scatter_p] = functools.partial(scatter_error_check, lax.scatter_p)
error_checks[lax.scatter_add_p] = functools.partial(scatter_error_check,
                                                    lax.scatter_add_p)
error_checks[lax.scatter_mul_p] = functools.partial(scatter_error_check,
                                                    lax.scatter_mul_p)
error_checks[lax.scatter_min_p] = functools.partial(scatter_error_check,
                                                    lax.scatter_min_p)
error_checks[lax.scatter_max_p] = functools.partial(scatter_error_check,
                                                    lax.scatter_max_p)

def cond_error_check(error, enabled_errors, index, *ops, branches, linear):
  _, out_trees_and_effects = unzip2(checkify_jaxpr(jxpr, error,
                                                   enabled_errors)
                                    for jxpr in branches)
  _, effects = unzip2(out_trees_and_effects)

  merged_error = error._add_placeholder_effects(set().union(*effects))
  new_branches, out_trees_and_effects = unzip2(checkify_jaxpr(jxpr, merged_error,
                                                              enabled_errors)
                                               for jxpr in branches)
  out_trees, _ = unzip2(out_trees_and_effects)

  flat_error, _ = tree_flatten(merged_error)
  new_linear = (*[False] * len(flat_error), *linear)
  err_and_outs = lax.cond_p.bind(
      index, *flat_error, *ops,
      branches=tuple(new_branches), linear=new_linear)

  # we need to merge metadata across out_trees (a tuple)
  # maybe there's a better way to do this, but we can use the outs
  # to unflatten all trees.
  err0, *out = tree_unflatten(out_trees[0], err_and_outs)
  merged_metadata = err0._metadata
  for tr in out_trees[1:]:
    err, *_ = tree_unflatten(tr, err_and_outs)
    merged_metadata = {**merged_metadata, **err._metadata}
  return out, err0._replace(_metadata=merged_metadata)
error_checks[lax.cond_p] = cond_error_check

def scan_error_check(error, enabled_errors, *in_flat, reverse, length, jaxpr,
                     num_consts, num_carry, linear, unroll):
  consts, carry, xs = split_list(in_flat, [num_consts, num_carry])
  _, (_, effects) = checkify_jaxpr(jaxpr, error, enabled_errors)
  merged_error = error._add_placeholder_effects(effects)
  checked_jaxpr_, (out_tree, _) = checkify_jaxpr(jaxpr, merged_error, enabled_errors)

  flat_error_vals, _ = tree_flatten(merged_error)
  tomove = [False] * len(flat_error_vals) + [True] * len(consts) + [False] * (len(carry) + len(xs))
  checked_jaxpr = pe.move_binders_to_front(checked_jaxpr_, tomove)
  new_linear = (*[False] * len(flat_error_vals), *linear)
  new_in_flat = [*consts, *flat_error_vals, *carry, *xs]
  err_and_out = lax.scan_p.bind(
      *new_in_flat, reverse=reverse, length=length, jaxpr=checked_jaxpr,
      num_consts=len(consts), num_carry=len(carry)+len(flat_error_vals),
      linear=new_linear, unroll=unroll)
  err, *out = tree_unflatten(out_tree, err_and_out)
  return out, err

error_checks[lax.scan_p] = scan_error_check
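
# Informal usage sketch (comments only): checks inside a scanned body are
# plumbed out through the carry by the rule above, so checkify-of-scan works,
# e.g. (hypothetical):
#
#   def body(carry, x):
#     check(x < 10, "saw {x}, expected it to be < 10", x=x)
#     return carry + x, x
#
#   def f(xs):
#     return lax.scan(body, 0., xs)
#
#   err, _ = checkify(f, errors=user_checks)(jnp.arange(20.))
#   err.get()  # reports the first iteration whose check failed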


def checkify_while_body_jaxpr(cond_jaxpr, body_jaxpr, error, enabled_errors, c_consts):
  cond_f = core.jaxpr_as_fun(cond_jaxpr)
  body_f = core.jaxpr_as_fun(body_jaxpr)
  def new_body_f(*vals):
    out = body_f(*vals)
    # This checks if the next cond application will error
    _ = cond_f(*c_consts, *out)
    return out
  return checkify_fun_to_jaxpr(lu.wrap_init(new_body_f), error, enabled_errors,
                               body_jaxpr.in_avals)

def ignore_error_output_jaxpr(jaxpr, num_error_vals):
  """Constructs a checked jaxpr which does not output its error value."""
  consts = jaxpr.consts
  jaxpr = jaxpr.jaxpr
  new_jaxpr = jaxpr.replace(outvars=jaxpr.outvars[num_error_vals:])
  return core.ClosedJaxpr(new_jaxpr, consts)

def while_loop_error_check(error, enabled_errors, *in_flat, cond_nconsts,
                           cond_jaxpr, body_nconsts, body_jaxpr):
  if cond_jaxpr.out_avals[0].shape:
    # TODO(lenamartens, sharadmv): support batched while.
    raise ValueError('Checkify does not support batched while-loops '
                     '(checkify-of-vmap-of-while). \nHint: if possible, move '
                     'the vmap to the outer level to get '
                     'vmap-of-checkify-of-while.')

  err_vals, _ = tree_flatten(error)
  c_consts, b_consts, carry = split_list(in_flat, [cond_nconsts, body_nconsts])

  # Check if the first cond application will error.
  checked_cond_jaxpr, (cond_out_tree, _) = checkify_jaxpr(
      cond_jaxpr, error, enabled_errors)
  outs = core.jaxpr_as_fun(checked_cond_jaxpr)(*err_vals, *c_consts, *carry)
  error, _ = tree_unflatten(cond_out_tree, outs)

  checked_body_jaxpr_, (_, error_effects) = checkify_while_body_jaxpr(
      cond_jaxpr, body_jaxpr, error, enabled_errors, c_consts)
  # merged error!
  error = error._add_placeholder_effects(error_effects)
  checked_body_jaxpr_, (body_out_tree, _) = checkify_while_body_jaxpr(
      cond_jaxpr, body_jaxpr, error, enabled_errors, c_consts)
  err_vals = jtu.tree_leaves(error)
  num_error_vals = len(err_vals)
  to_move = [False] * num_error_vals + [True] * body_nconsts + [False] * len(carry)
  checked_body_jaxpr = pe.move_binders_to_front(checked_body_jaxpr_, to_move)

  checked_cond_jaxpr, _ = checkify_jaxpr(cond_jaxpr, error, enabled_errors)
  compat_cond_jaxpr_ = ignore_error_output_jaxpr(checked_cond_jaxpr, num_error_vals)
  to_move = [False] * num_error_vals + [True] * cond_nconsts + [False] * len(carry)
  compat_cond_jaxpr = pe.move_binders_to_front(compat_cond_jaxpr_, to_move)
  new_in_flat = [*c_consts, *b_consts, *err_vals, *carry]

  all_out_vals = lax.while_p.bind(
      *new_in_flat, cond_nconsts=cond_nconsts, cond_jaxpr=compat_cond_jaxpr,
      body_nconsts=body_nconsts, body_jaxpr=checked_body_jaxpr)
  # body_out_tree will have all the metadata of cond because it executes a cond!
  # only need to merge metadata on the input error.
  error, *out = tree_unflatten(body_out_tree, all_out_vals)
  return out, error
error_checks[lax.while_p] = while_loop_error_check


def pjit_error_check(error, enabled_errors, *vals_in, jaxpr,
                     in_shardings, out_shardings, resource_env,
                     donated_invars, name,
                     in_positional_semantics, out_positional_semantics):
  checked_jaxpr, (out_tree, effects) = checkify_jaxpr(jaxpr, error,
                                                      enabled_errors)
  out_error = error._add_placeholder_effects(effects)

  flat_error_vals = jtu.tree_leaves(error)
  num_error_vals = len(flat_error_vals)
  new_vals_in = [*flat_error_vals, *vals_in]

  sharding = OpShardingSharding.get_replicated(
      list(resource_env.physical_mesh.devices.flat))
  new_in_shardings = (*[sharding] * num_error_vals, *in_shardings)
  new_out_shardings = (*[sharding] * len(jtu.tree_leaves(out_error)),
                       *out_shardings)

  if config.jax_array:
    pos_sem = maps._PositionalSemantics.GLOBAL
  else:
    pos_sem = maps._positional_semantics.val

  if not isinstance(in_positional_semantics, Iterable):
    in_positional_semantics = (in_positional_semantics,)
  if not isinstance(out_positional_semantics, Iterable):
    out_positional_semantics = (out_positional_semantics,)
  new_positional_sems_in = (*[pos_sem] * num_error_vals,
                            *in_positional_semantics)
  new_positional_sems_out = (*[pos_sem] * num_error_vals,
                             *out_positional_semantics)
  new_donated_invars = (*[False] * num_error_vals, *donated_invars)

  err_and_out = pjit.pjit_p.bind(
      *new_vals_in,
      jaxpr=checked_jaxpr,
      in_shardings=new_in_shardings,
      out_shardings=new_out_shardings,
      resource_env=resource_env,
      donated_invars=new_donated_invars,
      name=name,
      in_positional_semantics=new_positional_sems_in,
      out_positional_semantics=new_positional_sems_out)
  err, *out = tree_unflatten(out_tree, err_and_out)
  return out, err
error_checks[pjit.pjit_p] = pjit_error_check


def check_discharge_rule(error, enabled_errors, *args, err_tree, debug):
  del debug
  new_error = tree_unflatten(err_tree, args)
  # Split up new_error into error to be functionalized if it's included in
  # enabled_errors (=discharged_error) and an error to be defunctionalized if
  # it's not included (=recharged_error)
  discharged_error = error
  recharged_error = init_error
  for error_effect in new_error._pred.keys():
    pred = new_error._pred[error_effect]
    code = new_error._code[error_effect]
    payload = new_error._payload[error_effect]
    if error_effect.error_type in enabled_errors:
      discharged_error = update_error(discharged_error, pred, code, {}, payload,
                                      error_effect)
    else:
      recharged_error = update_error(recharged_error, pred, code, {}, payload,
                                     error_effect)

  discharged_error = discharged_error._replace(
      _metadata={**new_error._metadata, **discharged_error._metadata})
  recharged_error = recharged_error._replace(_metadata=new_error._metadata)
  # TODO(lenamartens): we actually need to recharge, but this would be a
  # breaking API change so leaving for a follow-up.
  # check_error(recharged_error)
  return [], discharged_error
error_checks[check_p] = check_discharge_rule


## checkify api

user_checks = frozenset({FailedCheckError})
nan_checks = frozenset({NaNError})
index_checks = frozenset({OOBError})
div_checks = frozenset({DivisionByZeroError})
float_checks = nan_checks | div_checks
automatic_checks = float_checks | index_checks
all_checks = automatic_checks | user_checks
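
# Informal sketch (comments only): these frozensets can be combined with the
# usual set operations to pick which categories a given checkify call should
# track, e.g.:
#
#   errors = float_checks | user_checks      # NaN, div-by-zero and explicit checks
#   err, out = checkify(f, errors=errors)(x)  # assuming some function `f` and input `x`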

Out = TypeVar('Out')


def checkify(fun: Callable[..., Out],
             errors: FrozenSet[ErrorCategory] = user_checks
             ) -> Callable[..., Tuple[Error, Out]]:
  """Functionalize `check` calls in `fun`, and optionally add run-time error checks.

  Run-time errors are either user-added :func:`~check` assertions, or
  automatically added checks like NaN checks, depending on the ``errors``
  argument.

  The returned function will return an Error object `err` along with the output
  of the original function. ``err.get()`` will either return ``None`` (if no
  error occurred) or a string containing an error message. This error message
  will correspond to the first error which occurred. ``err.throw()`` will raise
  a ValueError with the error message if an error occurred.

  By default only user-added :func:`~check` assertions are enabled. You can
  enable automatic checks through the ``errors`` argument.

  The automatic check sets which can be enabled, and when an error is generated:
    - ``user_checks``: a :func:`~check` evaluated to False.
    - ``nan_checks``: a floating-point operation generated a NaN value
      as output.
    - ``div_checks``: a division by zero.
    - ``index_checks``: an index was out-of-bounds.

  Multiple categories can be enabled together by passing in an error `Set`
  (e.g. ``errors=nan_checks``). Multiple sets can be re-combined
  (e.g. ``errors=float_checks|user_checks``)

  Args:
    fun: Callable which can contain user checks (see :func:`~check`).
    errors: A set of ErrorCategory values which defines the set of enabled
      checks. By default only explicit ``checks`` are enabled
      (``user_checks``). You can also for example enable NaN and
      division-by-zero errors by passing the ``float_checks`` set, or
      combine multiple sets through set operations
      (``float_checks | user_checks``)

  Returns:
    A function which accepts the same arguments as ``fun`` and returns as output
    a pair where the first element is an ``Error`` value, representing the first
    failed :func:`~check`, and the second element is the original output of
    ``fun``.

  For example:

    >>> import jax
    >>> import jax.numpy as jnp
    >>> from jax.experimental import checkify
    >>>
    >>> @jax.jit
    ... def f(x):
    ...   y = jnp.sin(x)
    ...   return x+y
    >>> err, out = checkify.checkify(f, errors=checkify.float_checks)(jnp.inf)
    >>> err.throw()  # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
      ...
    jax._src.checkify.JaxRuntimeError: nan generated by primitive: sin
  """
  @traceback_util.api_boundary
  def checked_fun(*args, **kwargs):
    args_flat, in_tree = tree_flatten((args, kwargs))
    f, out_tree = flatten_fun(lu.wrap_init(fun), in_tree)
    error, out_flat = checkify_flat(f, errors, *args_flat)
    out = tree_unflatten(out_tree(), out_flat)
    return error, out
  return checked_fun
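
# Informal note (comments only): a checkified function composes with other
# transformations. For instance, under vmap the returned Error carries one
# predicate per batch element, and if any element failed its check,
# ``err.throw()`` raises a BatchedError listing each failing index, e.g.
# (hypothetical `f` containing a `check`):
#
#   err, out = checkify(jax.vmap(f), errors=user_checks)(jnp.arange(3.))
#   err.throw()  # raises only if some mapped element failed its check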