2021-10-29 09:23:27 -07:00
|
|
|
# Copyright 2021 Google LLC
|
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
|
|
|
|
from dataclasses import dataclass
|
|
|
|
from functools import partial
|
|
|
|
import itertools as it
|
|
|
|
from typing import Union, Optional, Callable, Dict
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
|
|
|
|
import jax.numpy as jnp
|
|
|
|
|
|
|
|
from jax import core
|
|
|
|
from jax import linear_util as lu
|
|
|
|
from jax.api_util import flatten_fun
|
|
|
|
from jax.interpreters import partial_eval as pe
|
|
|
|
from jax.tree_util import tree_flatten, tree_unflatten, register_pytree_node
|
|
|
|
from jax._src import source_info_util, traceback_util
|
2021-12-16 22:44:05 +00:00
|
|
|
from jax import lax
|
2021-12-06 19:36:35 +00:00
|
|
|
from jax._src.util import as_hashable_function, unzip2, split_list
|
2021-10-29 09:23:27 -07:00
|
|
|
|
|
|
|
source_info_util.register_exclusion(__file__)
|
|
|
|
traceback_util.register_exclusion(__file__)
|
|
|
|
|
|
|
|
|
|
|
|
## Utils
|
|
|
|
|
|
|
|
def popattr(obj, attrname):
  """Delete attribute `attrname` from `obj` and return the value it held.

  Raises AttributeError (from `getattr`) if the attribute does not exist.
  """
  removed = getattr(obj, attrname)
  delattr(obj, attrname)
  return removed
|
|
|
|
|
|
|
|
def setnewattr(obj, name, val):
  """Set `obj.<name> = val`, asserting the attribute is not already present."""
  # Internal invariant: callers must have removed any previous value first.
  assert not hasattr(obj, name)
  setattr(obj, name, val)
|
|
|
|
|
2021-12-02 11:33:56 -08:00
|
|
|
|
2021-10-29 09:23:27 -07:00
|
|
|
## Error value data type and functional assert.
|
|
|
|
|
|
|
|
@dataclass(frozen=True)
class Error:
  """Functional error value: a flag, an error code, and a code->message table."""
  # True (or a truthy array entry) where an error has been recorded.
  err: Union[bool, core.Tracer]
  # Key into `msgs` identifying which check failed.
  code: Union[int, core.Tracer]
  # Maps error codes to human-readable messages.
  msgs: Dict[int, str]

  def get(self) -> Optional[str]:
    """Return the error message(s), or None if no error is flagged.

    When `err` holds more than one element, one line is produced per flagged
    index, each prefixed with that index.
    """
    assert np.shape(self.err) == np.shape(self.code)
    if np.size(self.err) != 1:
      lines = [
          f'at mapped index {", ".join(map(str, idx))}: '  # type: ignore
          f'{self.msgs[int(self.code[idx])]}'              # type: ignore
          for idx, flagged in np.ndenumerate(self.err) if flagged]
      # An empty join yields '', which is mapped to None.
      return '\n'.join(lines) or None
    return self.msgs[int(self.code)] if self.err else None
|
|
|
|
|
|
|
|
# Register Error as a pytree node: (err, code) are the children, and the
# message table (as a sorted tuple of items) is the auxiliary data.
register_pytree_node(Error,
                     lambda e: ((e.err, e.code), tuple(sorted(e.msgs.items()))),
                     lambda msgs, data: Error(*data, dict(msgs)))  # type: ignore

# Starting error value: nothing flagged, code 0, empty message table.
init_error = Error(False, 0, {})
next_code = it.count(1).__next__  # globally unique ids, could be uuid4

# Aliases for values that are either a concrete Python scalar or a tracer.
Bool = Union[bool, core.Tracer]
Int = Union[int, core.Tracer]
|
|
|
|
|
|
|
|
def assert_func(error: Error, pred: Bool, msg: str) -> Error:
  """Functionally assert `pred`, returning an updated error value.

  A fresh code is allocated for `msg`. The returned error is flagged if
  `error` was already flagged or `pred` is false. The previous code is kept
  when `error.err` is already set; otherwise the fresh code is selected.
  """
  fresh_code = next_code()
  failed = jnp.logical_not(pred)
  return Error(error.err | failed,
               lax.select(error.err, error.code, fresh_code),
               {fresh_code: msg, **error.msgs})
|
|
|
|
|
2021-12-02 11:33:56 -08:00
|
|
|
|
2021-10-29 09:23:27 -07:00
|
|
|
## Checkify transformation for plumbing functional error values.
|
|
|
|
|
2021-12-02 14:26:58 -08:00
|
|
|
class ErrorTracer(core.Tracer):
  """Tracer used by ErrorTrace; it wraps one underlying value in `val`."""

  def __init__(self, trace, val):
    self._trace = trace
    self.val = val
    # The tuple built here is discarded. NOTE(review): presumably kept so that
    # core.get_aval raises early for values without an abstract value; confirm.
    core.get_aval(val), val
  # The abstract value is computed on demand from the wrapped value.
  aval = property(lambda self: core.get_aval(self.val))
  # Lowering is the identity: the tracer returns itself.
  full_lower = lambda self: self
|
|
|
|
|
2021-10-29 09:23:27 -07:00
|
|
|
class ErrorTrace(core.Trace):
  """Trace that threads an `Error` value (stored on `self.main.error`) through
  primitive applications, using the rules registered in `error_checks`."""

  # Both `pure` and `lift` wrap a raw value in an ErrorTracer of this trace.
  pure = lift = lambda self, val: ErrorTracer(self, val)

  def sublift(self, tracer):
    """Re-wrap the value of `tracer` in an ErrorTracer belonging to this trace."""
    return ErrorTracer(self, tracer.val)

  def process_primitive(self, primitive, tracers, params):
    """Apply `primitive` to the unwrapped inputs, via its check rule if any.

    A registered rule receives the current error plus the raw inputs and
    returns `(out, new_error)`; the new error replaces `self.main.error`.
    Without a rule the primitive is bound unchanged.
    """
    in_vals = [t.val for t in tracers]
    rule = error_checks.get(primitive)
    if rule:
      out, self.main.error = rule(self.main.error, *in_vals, **params)  # type: ignore
    else:
      out = primitive.bind(*in_vals, **params)
    if primitive.multiple_results:
      return [ErrorTracer(self, x) for x in out]
    else:
      return ErrorTracer(self, out)

  def process_call(self, primitive, f, tracers, params):
    """Bind a call primitive with (err, code) as two extra leading arguments
    and outputs."""
    in_vals = [t.val for t in tracers]
    # Take the error off `main`; check_errors_subtrace installs its own (via
    # setnewattr) when the wrapped function runs.
    e = popattr(self.main, 'error')
    f, msgs = check_errors_subtrace(f, self.main, tuple(e.msgs.items()))
    # Two extra `False` donation flags for the added err and code arguments.
    params_ = dict(params, donated_invars=(False, False, *params['donated_invars']))
    err, code, *out_vals = primitive.bind(f, e.err, e.code, *in_vals, **params_)
    # `msgs` is the aux-output thunk; it is called only after the bind above.
    setnewattr(self.main, 'error', Error(err, code, msgs()))
    return [ErrorTracer(self, x) for x in out_vals]

  def process_map(self, primitive, f, tracers, params):
    """Like `process_call`, for map primitives: err and code are passed with
    in_axes `None`, returned with out axis 0, and then reduced to one pair."""
    in_vals = [t.val for t in tracers]
    e = popattr(self.main, 'error')
    f, msgs = check_errors_subtrace(f, self.main, tuple(e.msgs.items()))

    @as_hashable_function(closure=params['out_axes_thunk'])
    def new_out_axes_thunk():
      # Prepend out axes (0, 0) for the added err and code outputs.
      return (0, 0, *params['out_axes_thunk']())

    params_ = dict(params, in_axes=(None, None, *params['in_axes']),
                   out_axes_thunk=new_out_axes_thunk,
                   donated_invars=(False, False, *params['donated_invars']))
    errs, codes, *outs = primitive.bind(f, e.err, e.code, *in_vals, **params_)
    # Collapse the per-instance (err, code) pairs along axis 0.
    err, code = _reduce_any_error(errs, codes)
    setnewattr(self.main, 'error', Error(err, code, msgs()))
    return [ErrorTracer(self, x) for x in outs]

  def post_process_call(self, primitive, tracers, params):
    """Return the call outputs prefixed with (err, code), plus a `todo`
    callback that re-wraps the non-error outputs as ErrorTracers."""
    vals = [t.val for t in tracers]
    main = self.main
    e = popattr(self.main, 'error')
    # The message table is stashed on `main.msgs`.
    err, code, main.msgs = e.err, e.code, e.msgs
    def todo(vals):
      trace = main.with_cur_sublevel()
      # NOTE(review): err and code are unpacked here but not used again in
      # this function; confirm whether they should be restored onto `main`.
      err, code, *vals = vals
      return [ErrorTracer(trace, x) for x in vals]
    return (err, code, *vals), todo

  def post_process_map(self, primitive, tracers, params):
    """Map counterpart of `post_process_call`; additionally returns a
    transform that prepends out axes (0, 0) for err and code."""
    vals = [t.val for t in tracers]
    main = self.main
    e = popattr(self.main, 'error')
    err, code, main.msgs = e.err, e.code, e.msgs
    def todo(vals):
      trace = main.with_cur_sublevel()
      # NOTE(review): as in post_process_call, err and code are dropped here.
      err, code, *vals = vals
      return [ErrorTracer(trace, x) for x in vals]
    def out_axes_transform(out_axes):
      return (0, 0, *out_axes)
    return (err, code, *vals), (todo, out_axes_transform)
|
|
|
|
|
|
|
|
def _reduce_any_error(errs, codes):
|
|
|
|
errs_, codes_ = lax.sort_key_val(errs, codes, dimension=0)
|
|
|
|
return errs_[-1], codes_[-1]
|
|
|
|
|
|
|
|
# A check rule is called as rule(error, *in_vals, **params) and returns
# (outputs, new_error); see ErrorTrace.process_primitive.
ErrorCheckRule = Callable
# Registry mapping each primitive to its check rule.
error_checks: Dict[core.Primitive, ErrorCheckRule] = {}
|
|
|
|
|
|
|
|
def check_errors_flat(fun: lu.WrappedFun, *args):
  """Run the flat function `fun` on `args` under error checking.

  Returns:
    `((err, code, out_vals), msgs)`, where `out_vals` is the list of outputs
    of `fun` and `msgs` is the message table produced by the subtrace.
  """
  traced, msgs_thunk = check_errors_subtrace(fun)
  wrapped = check_errors_toplevel(traced)
  err, code, *out_vals = wrapped.call_wrapped(*args)
  # The aux thunk is only read after the wrapped function has been called.
  return (err, code, out_vals), msgs_thunk()
|
|
|
|
|
|
|
|
@lu.transformation
def check_errors_toplevel(*args):
  """Transformation that opens a new ErrorTrace main and starts from
  `init_error`.

  The wrapped function is called with `(main, msgs, err, code, *args)`, which
  matches the parameters of `check_errors_subtrace`.
  """
  error = init_error
  with core.new_main(ErrorTrace) as main:
    # The message table is passed as a tuple of items.
    msgs = tuple(error.msgs.items())
    outs = yield (main, msgs, error.err, error.code, *args), {}
    # Drop the local reference to `main` before leaving the context manager.
    del main
  yield outs
|
|
|
|
|
|
|
|
@lu.transformation_with_aux
def check_errors_subtrace(main, msgs, err, code, *args):
  """Transformation that runs the wrapped function on ErrorTracers.

  Installs `Error(err, code, dict(msgs))` on `main`, calls the function with
  `args` wrapped as tracers, and yields `(err, code, *out_vals)` as the
  outputs, with the final message table as the auxiliary output.
  """
  # setnewattr asserts that no error is currently installed on `main`.
  setnewattr(main, 'error', Error(err, code, dict(msgs)))
  trace = main.with_cur_sublevel()
  in_tracers = [ErrorTracer(trace, x) for x in args]
  out = yield in_tracers, {}
  out_tracers = map(trace.full_raise, out)
  out_vals = [t.val for t in out_tracers]
  # Read the (possibly updated) error back off `main`, then remove it.
  err, code, msgs = main.error.err, main.error.code, main.error.msgs
  del main.error
  yield (err, code, *out_vals), msgs
|
|
|
|
|
2021-12-08 15:07:43 +00:00
|
|
|
def checkify_fun_to_jaxpr(f, error, in_avals):
  """Trace `f` under error checking into a closed jaxpr.

  The resulting jaxpr takes `(err, code, *inputs)`; the err and code avals
  are derived from `error`.

  Returns:
    `(closed_jaxpr, msgs)`, where `msgs` is the message table from the trace.
  """
  checked, msgs_thunk = check_errors_subtrace(f)
  checked = check_errors_traceable(checked, tuple(error.msgs.items()))
  error_avals = [core.raise_to_shaped(core.get_aval(x))
                 for x in (error.err, error.code)]
  jaxpr_out, _, literals_out = pe.trace_to_jaxpr_dynamic(
      checked, [*error_avals, *in_avals])
  closed = core.ClosedJaxpr(jaxpr_out, literals_out)
  return closed, msgs_thunk()
|
|
|
|
|
2021-12-08 15:07:43 +00:00
|
|
|
# TODO take (error_aval, code_aval) instead of error here?
def checkify_jaxpr(jaxpr, error):
  """Checkify a closed jaxpr by turning it into a function and re-tracing it.

  Returns whatever `checkify_fun_to_jaxpr` returns for that function.
  """
  as_fun = core.jaxpr_as_fun(jaxpr)
  return checkify_fun_to_jaxpr(lu.wrap_init(as_fun), error, jaxpr.in_avals)
|
|
|
|
|
2021-10-29 09:23:27 -07:00
|
|
|
# TODO dedup with check_errors_toplevel
@lu.transformation
def check_errors_traceable(msgs, err, code, *args):
  """Like `check_errors_toplevel`, but err and code are explicit inputs
  instead of being taken from `init_error`; `msgs` is supplied when the
  transformation is applied."""
  with core.new_main(ErrorTrace) as main:
    outs = yield (main, msgs, err, code, *args), {}
    # Drop the local reference to `main` before leaving the context manager.
    del main
  yield outs
|
|
|
|
|
|
|
|
|
2021-12-02 11:33:56 -08:00
|
|
|
## assert primitive
|
|
|
|
|
|
|
|
def assert_(pred: Bool, msg: str) -> None:
  """Assert `pred` with message `msg` by binding the `assert` primitive.

  Raises:
    TypeError: if `pred` is not a scalar boolean (see `is_scalar_pred`).
  """
  if not is_scalar_pred(pred):
    raise TypeError(f"assert_ takes a scalar pred as argument, got {pred}")
  fresh_code = next_code()
  msg_table = {fresh_code: msg}
  return assert2_(pred, fresh_code, msg_table)
|
|
|
|
|
2021-12-16 10:57:19 -08:00
|
|
|
def is_scalar_pred(pred) -> bool:
  """Return True for a Python bool or a shape-() jnp array of dtype bool."""
  if isinstance(pred, bool):
    return True
  if not isinstance(pred, jnp.ndarray):
    return False
  return pred.shape == () and pred.dtype == jnp.dtype('bool')
|
|
|
|
|
2021-12-02 14:26:58 -08:00
|
|
|
def assert2_(pred: Bool, code: Int, msgs: Dict[int, str]) -> None:
  """Bind the `assert` primitive on `pred` and `code`, with `msgs` as a param."""
  bound = assert_p.bind(pred, code, msgs=msgs)
  return bound
|
2021-12-02 11:33:56 -08:00
|
|
|
|
|
|
|
# Primitive behind assert_/assert2_. It is declared multiple-results so that
# its implementation can return an empty list.
assert_p = core.Primitive('assert')
assert_p.multiple_results = True  # zero results
|
|
|
|
|
|
|
|
@assert_p.def_impl
def assert_impl(pred, code, *, msgs):
  """Eager implementation of the `assert` primitive.

  Args:
    pred: the predicate that is expected to hold.
    code: key into `msgs` selecting the message for this assertion.
    msgs: mapping from error codes to messages.

  Returns:
    An empty list (the primitive has zero results).

  Raises:
    AssertionError: with `msgs[int(code)]` as message when `pred` is falsy.
  """
  # Raise explicitly instead of using an `assert` statement: `assert` is
  # removed under `python -O`, which would silently disable this check.
  if not pred:
    raise AssertionError(msgs[int(code)])
  return []
|
|
|
|
|
|
|
|
@assert_p.def_abstract_eval
def assert_abstract_eval(pred, code, *, msgs):
  """Abstract-eval rule for `assert`: always raises, since the primitive
  cannot be staged out as-is."""
  raise Exception("can't be staged!")
|
|
|
|
|
|
|
|
|
2021-10-29 09:23:27 -07:00
|
|
|
## checkify rules
|
|
|
|
|
2021-12-02 14:26:58 -08:00
|
|
|
def summary() -> str:
  """Return a string summarizing the current source info."""
  current = source_info_util.current()
  return str(source_info_util.summarize(current))
|
|
|
|
|
2021-10-29 09:23:27 -07:00
|
|
|
def nan_error_check(prim, error, *in_vals, **params):
  """Bind `prim` and record an error if any element of its output is NaN.

  Returns:
    `(out, new_error)`.
  """
  result = prim.bind(*in_vals, **params)
  nan_free = jnp.logical_not(jnp.any(jnp.isnan(result)))
  message = f"nan generated by primitive {prim.name} at {summary()}"
  updated = assert_func(error, nan_free, message)
  return result, updated
|
|
|
|
|
|
|
|
def gather_error_check(error, operand, start_indices, *,
                       dimension_numbers, slice_sizes, unique_indices,
                       indices_are_sorted, mode, fill_value):
  """Bind `gather` and record an error if any start index is out of bounds.

  An index is in bounds when it lies in `[0, operand_dim - slice_size]` for
  the operand dimension it maps to.

  Returns:
    `(out, new_error)`.
  """
  gathered = lax.gather_p.bind(
      operand, start_indices, dimension_numbers=dimension_numbers,
      slice_sizes=slice_sizes, unique_indices=unique_indices,
      indices_are_sorted=indices_are_sorted, mode=mode, fill_value=fill_value)

  # compare to OOB masking logic in lax._gather_translation_rule
  index_map = np.array(dimension_numbers.start_index_map)
  dim_sizes = np.array(operand.shape)[index_map]
  window_sizes = np.array(slice_sizes)[index_map]
  upper_bound = dim_sizes - window_sizes
  nonnegative = start_indices >= 0
  below_limit = start_indices <= upper_bound
  all_inbounds = jnp.all(nonnegative & below_limit)

  message = f"out-of-bounds indexing at {summary()}"
  return gathered, assert_func(error, all_inbounds, message)
error_checks[lax.gather_p] = gather_error_check
|
2021-10-29 09:23:27 -07:00
|
|
|
|
2021-12-20 15:56:50 +00:00
|
|
|
def div_error_check(error, x, y):
  """Checks for division by zero and NaN."""
  has_zero = jnp.any(jnp.equal(y, 0))
  all_nonzero = jnp.logical_not(has_zero)
  message = f'divided by zero at {summary()}'
  # The divide-by-zero check is recorded first; the NaN check runs on the
  # result of the division afterwards.
  with_zero_check = assert_func(error, all_nonzero, message)
  return nan_error_check(lax.div_p, with_zero_check, x, y)
error_checks[lax.div_p] = div_error_check
|
|
|
|
|
2021-12-23 15:23:58 +00:00
|
|
|
def scatter_in_bounds(operand, indices, updates, dnums):
  """Return a scalar boolean: are all scatter `indices` within bounds?

  For each scattered operand dimension the largest valid index is the operand
  size minus the size of the update window in that dimension.
  """
  # Ref: see clamping code used in scatter_translation_rule
  slice_sizes = []
  window_pos = 0
  for axis in range(len(operand.shape)):
    if axis not in dnums.inserted_window_dims:
      slice_sizes.append(updates.shape[dnums.update_window_dims[window_pos]])
      window_pos += 1
    else:
      # Inserted window dims have a window of size 1.
      slice_sizes.append(1)

  limits = [operand.shape[d] - slice_sizes[d]
            for d in dnums.scatter_dims_to_operand_dims]
  upper_bound = np.array(limits, np.int64)
  # Cap the bound at the largest value the index dtype can hold.
  upper_bound = np.minimum(upper_bound, np.iinfo(indices.dtype).max)
  last_axis = len(indices.shape) - 1
  upper_bound = lax.broadcast_in_dim(upper_bound, indices.shape, (last_axis,))

  lower_in_bounds = jnp.all(jnp.greater_equal(indices, 0))
  upper_in_bounds = jnp.all(jnp.less_equal(indices, upper_bound))
  return jnp.logical_and(lower_in_bounds, upper_in_bounds)
|
|
|
|
|
|
|
|
def scatter_error_check(prim, error, operand, indices, updates, *,
                        update_jaxpr, update_consts,
                        dimension_numbers, indices_are_sorted,
                        unique_indices, mode):
  """Checks if indices are within bounds and update does not generate NaN."""
  result = prim.bind(
      operand, indices, updates, update_jaxpr=update_jaxpr,
      update_consts=update_consts, dimension_numbers=dimension_numbers,
      indices_are_sorted=indices_are_sorted, unique_indices=unique_indices,
      mode=mode)

  # First record the bounds check ...
  in_bounds = scatter_in_bounds(operand, indices, updates, dimension_numbers)
  oob_msg = f'out-of-bounds indexing while updating at {summary()}'
  after_oob = assert_func(error, in_bounds, oob_msg)

  # ... then the NaN check on the scatter result.
  nan_free = jnp.logical_not(jnp.any(jnp.isnan(result)))
  nan_msg = f'nan generated by primitive {prim.name} at {summary()}'
  return result, assert_func(after_oob, nan_free, nan_msg)

# Register the rule for every scatter variant, each bound to its own primitive.
for _scatter_prim in (lax.scatter_p, lax.scatter_add_p, lax.scatter_mul_p,
                      lax.scatter_min_p, lax.scatter_max_p):
  error_checks[_scatter_prim] = partial(scatter_error_check, _scatter_prim)
del _scatter_prim
|
|
|
|
|
2021-10-29 09:23:27 -07:00
|
|
|
def cond_error_check(error, index, *ops, branches, linear):
  """Checkify rule for `cond`: every branch is checkified and (err, code) are
  passed as two extra operands and returned as two extra outputs.

  Returns:
    `(outs, new_error)`, where the new error's message table merges
    `error.msgs` with the tables of all branches.
  """
  checked = [checkify_jaxpr(branch, error) for branch in branches]
  new_branches, branch_msgs = unzip2(checked)
  # err and code are never linear.
  new_linear = (False, False, *linear)
  err, code, *outs = lax.cond_p.bind(
      index, error.err, error.code, *ops,
      branches=tuple(new_branches), linear=new_linear)
  merged = dict(error.msgs)
  for table in branch_msgs:
    merged.update(table)
  return outs, Error(err, code, merged)
error_checks[lax.cond_p] = cond_error_check
|
2021-10-29 09:23:27 -07:00
|
|
|
|
2021-12-06 19:36:35 +00:00
|
|
|
def scan_error_check(error, *in_flat, reverse, length, jaxpr, num_consts, num_carry, linear, unroll):
  """Checkify rule for `scan`: (err, code) are threaded through the loop as
  two extra carry values.

  Args:
    error: the current Error value.
    *in_flat: the scan operands, split by `num_consts` and `num_carry` into
      consts, carry and xs.
    reverse, length, unroll: forwarded to `scan_p` unchanged.
    jaxpr: closed jaxpr of the scan body.
    num_consts, num_carry: number of const and carry operands in `in_flat`.
    linear: per-operand flags, positionally aligned with `in_flat`.

  Returns:
    `(outs, new_error)`, where `outs` are the scan outputs without the two
    leading error values.
  """
  consts, carry, xs = split_list(in_flat, [num_consts, num_carry])
  checked_jaxpr_, msgs_ = checkify_jaxpr(jaxpr, error)

  # checkify_jaxpr puts the (err, code) binders in front of all original
  # inputs, while the operands below are laid out as
  # [consts, err, code, carry, xs]. Move the const binders ahead of
  # (err, code) so binders and operands line up.
  n_consts = len(consts)
  inner = checked_jaxpr_.jaxpr
  err_code_vars, other_vars = inner.invars[:2], inner.invars[2:]
  new_invars = [*other_vars[:n_consts], *err_code_vars, *other_vars[n_consts:]]
  checked_jaxpr = core.ClosedJaxpr(
      core.Jaxpr(inner.constvars, new_invars, inner.outvars, inner.eqns),
      checked_jaxpr_.consts)

  # Keep `linear` aligned with the operands: the two new (non-linear) entries
  # go right after the consts.
  new_linear = (*linear[:n_consts], False, False, *linear[n_consts:])
  new_in_flat = [*consts, error.err, error.code, *carry, *xs]
  # `new_in_flat` already starts with the consts; they are not passed again.
  err, code, *outs = lax.scan_p.bind(
      *new_in_flat,
      reverse=reverse, length=length, jaxpr=checked_jaxpr,
      num_consts=len(consts), num_carry=len(carry)+2,
      linear=new_linear, unroll=unroll)
  new_msgs = {**error.msgs, **msgs_}
  return outs, Error(err, code, new_msgs)
error_checks[lax.scan_p] = scan_error_check
|
2021-12-06 19:36:35 +00:00
|
|
|
|
2021-12-08 15:07:43 +00:00
|
|
|
def checkify_while_body_jaxpr(cond_jaxpr, body_jaxpr, error):
  """Checkify a while-loop body, also evaluating the loop condition on the
  body's outputs so that errors in the next condition check are recorded."""
  run_cond = core.jaxpr_as_fun(cond_jaxpr)
  run_body = core.jaxpr_as_fun(body_jaxpr)

  def new_body_f(*vals):
    out = run_body(*vals)
    run_cond(*out)  # this checks if the next cond application will error
    return out

  wrapped = lu.wrap_init(new_body_f)
  return checkify_fun_to_jaxpr(wrapped, error, body_jaxpr.in_avals)
|
|
|
|
|
|
|
|
def ignore_errors_jaxpr(jaxpr, error):
  """Constructs a jaxpr which takes two extra args but ignores them."""
  err_aval = core.raise_to_shaped(core.get_aval(error.err))
  code_aval = core.raise_to_shaped(core.get_aval(error.code))
  inner = jaxpr.jaxpr
  fresh_var = core.gensym([inner])
  # Two fresh input variables (err first, then code) go in front of the
  # original inputs; nothing in the equations refers to them.
  err_var = fresh_var(err_aval)
  code_var = fresh_var(code_aval)
  widened = core.Jaxpr(inner.constvars, (err_var, code_var, *inner.invars),
                       inner.outvars, inner.eqns)
  return core.ClosedJaxpr(widened, jaxpr.consts)
|
|
|
|
|
|
|
|
def while_loop_error_check(error, *in_flat, cond_nconsts, cond_jaxpr, body_nconsts, body_jaxpr):
  """Checkify rule for `while`: (err, code) are threaded through the loop.

  The condition is checked once before the loop (for its first application);
  the checkified body re-evaluates the condition on its own outputs, so
  errors in later condition checks are recorded too. The condition jaxpr
  passed to the loop itself only accepts and ignores the two error values.

  Returns:
    `(out, new_error)`; the new message table merges `error.msgs` with the
    body and cond tables.
  """
  checked_cond_jaxpr, msgs_cond = checkify_jaxpr(cond_jaxpr, error)
  checked_cond_fun = core.jaxpr_as_fun(checked_cond_jaxpr)
  # Check if the first cond application will error.
  # NOTE(review): all of `in_flat` (cond consts, body consts and carry) is
  # passed here; this looks like it only matches cond_jaxpr's inputs when
  # there are no body consts -- confirm.
  cond_err, cond_code, _ = checked_cond_fun(error.err, error.code, *in_flat)

  checked_body_jaxpr, msgs_body = checkify_while_body_jaxpr(cond_jaxpr, body_jaxpr, error)
  compat_cond_jaxpr = ignore_errors_jaxpr(cond_jaxpr, error)
  c_consts, b_consts, carry = split_list(in_flat, [cond_nconsts, body_nconsts])
  # The loop starts from the error state produced by the first cond check.
  # NOTE(review): the checked body and compat cond jaxprs take (err, code)
  # as their first two inputs, while the operands here put the consts first;
  # presumably this only lines up when both const counts are zero -- confirm.
  new_in_flat = [*c_consts, *b_consts, cond_err, cond_code, *carry]
  err, code, *out = lax.while_p.bind(
      *new_in_flat,
      cond_nconsts=cond_nconsts,
      cond_jaxpr=compat_cond_jaxpr,
      body_nconsts=body_nconsts,
      body_jaxpr=checked_body_jaxpr)
  new_msgs = {**error.msgs, **msgs_body, **msgs_cond}
  return out, Error(err, code, new_msgs)
error_checks[lax.while_p] = while_loop_error_check
|
2021-12-08 15:07:43 +00:00
|
|
|
|
2021-12-02 11:33:56 -08:00
|
|
|
# TODO(mattjj,lenamartens): currently we bundle effectful-assert-discharging
# with the error-check-adding transformation (checkify), but they could be
# separated into two orthogonal transformations.
def assert_discharge_rule(error, pred, code, *, msgs):
  """Checkify rule for the `assert` primitive: fold the assertion into
  `error` instead of raising.

  Returns:
    `([], new_error)`; the primitive itself has no outputs.
  """
  failed = jnp.logical_not(pred)
  # Keep the existing code if an error was already recorded.
  selected_code = lax.select(error.err, error.code, code)
  merged = dict(error.msgs)
  merged.update(msgs)
  return [], Error(error.err | failed, selected_code, merged)
error_checks[assert_p] = assert_discharge_rule
|
|
|
|
|
2021-10-29 09:23:27 -07:00
|
|
|
|
|
|
|
## checkify api
|
|
|
|
|
|
|
|
def checkify(fun: Callable) -> Callable:
  """Wrap `fun` so that it returns `(error, out)` instead of `out`.

  `error` is an `Error` value built from the err flag, code and message table
  collected while running `fun` under error checking.
  """
  @traceback_util.api_boundary
  def checked_fun(*args, **kwargs):
    flat_args, in_tree = tree_flatten((args, kwargs))
    flat_fun, out_tree_thunk = flatten_fun(lu.wrap_init(fun), in_tree)
    (err, code, flat_outs), msgs = check_errors_flat(flat_fun, *flat_args)
    result = tree_unflatten(out_tree_thunk(), flat_outs)
    return Error(err, code, msgs), result
  return checked_fun
|
2021-12-16 22:44:05 +00:00
|
|
|
|
|
|
|
## NaN error rule table
|
|
|
|
|
|
|
|
def add_nan_check(prim):
  """Register `nan_error_check`, bound to `prim`, as the check rule for `prim`."""
  rule = partial(nan_error_check, prim)
  error_checks[prim] = rule
|
|
|
|
|
|
|
|
# Primitives whose outputs are checked for NaNs, registered in this order.
for _nan_prim in (
    lax.floor_p,
    lax.ceil_p,
    lax.round_p,
    lax.sign_p,
    lax.shift_left_p,
    lax.shift_right_arithmetic_p,
    lax.shift_right_logical_p,
    lax.bitcast_convert_type_p,
    lax.real_p,
    lax.complex_p,
    lax.conj_p,
    lax.imag_p,
    lax.add_p,
    lax.sub_p,
    lax.convert_element_type_p,
    lax.broadcast_in_dim_p,
    lax.concatenate_p,
    lax.pad_p,
    lax.reshape_p,
    lax.rev_p,
    lax.transpose_p,
    lax.slice_p,
    lax.reduce_sum_p,
    lax.reduce_window_sum_p,
    lax.fft_p,
    lax.cumsum_p,
    lax.cumprod_p,
    lax.cummax_p,
    lax.cummin_p,
    lax.erf_p,
    lax.expm1_p,
    lax.log1p_p,
    lax.sqrt_p,
    lax.rsqrt_p,
    lax.asinh_p,
    lax.acosh_p,
    lax.atanh_p,
    lax.erfc_p,
    lax.rem_p,
    lax.clamp_p,
    lax.erf_inv_p,
    lax.exp_p,
    lax.pow_p,
    lax.integer_pow_p,
    lax.tanh_p,
    lax.log_p,
    lax.atan2_p,
    lax.sin_p,
    lax.cos_p,
    lax.sinh_p,
    lax.cosh_p,
    lax.dot_general_p,
    lax.mul_p,
    lax.conv_general_dilated_p,
    lax.reduce_max_p,
    lax.reduce_min_p,
    lax.abs_p,
    lax.select_p,
    lax.max_p,
    lax.min_p,
):
  add_nan_check(_nan_prim)
del _nan_prim
|