2022-09-22 12:26:48 -07:00
|
|
|
# Copyright 2021 The JAX Authors.
|
2021-10-29 09:23:27 -07:00
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
|
2022-01-10 18:21:41 +00:00
|
|
|
import enum
|
2021-10-29 09:23:27 -07:00
|
|
|
from dataclasses import dataclass
|
|
|
|
from functools import partial
|
|
|
|
import itertools as it
|
2022-07-29 17:27:40 +01:00
|
|
|
from typing import Union, Optional, Callable, Dict, Tuple, TypeVar, FrozenSet, Iterable
|
2021-10-29 09:23:27 -07:00
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
|
|
|
|
import jax.numpy as jnp
|
|
|
|
|
|
|
|
from jax import core
|
|
|
|
from jax import linear_util as lu
|
|
|
|
from jax.api_util import flatten_fun
|
2022-07-29 17:27:40 +01:00
|
|
|
from jax.experimental import pjit
|
|
|
|
from jax.experimental import maps
|
2022-09-08 05:25:19 -07:00
|
|
|
from jax.interpreters import ad
|
2022-06-23 17:23:43 +01:00
|
|
|
from jax.interpreters import batching
|
2022-04-19 16:08:09 +01:00
|
|
|
from jax.interpreters import mlir
|
2021-10-29 09:23:27 -07:00
|
|
|
from jax.interpreters import partial_eval as pe
|
2022-09-27 10:06:10 -07:00
|
|
|
from jax._src.sharding import OpShardingSharding
|
2021-10-29 09:23:27 -07:00
|
|
|
from jax.tree_util import tree_flatten, tree_unflatten, register_pytree_node
|
|
|
|
from jax._src import source_info_util, traceback_util
|
2022-04-19 16:08:09 +01:00
|
|
|
from jax._src.lax import control_flow as cf
|
2022-08-16 16:51:26 -07:00
|
|
|
from jax._src.config import config
|
2022-11-09 12:08:57 +00:00
|
|
|
from jax._src import prng
|
2021-12-16 22:44:05 +00:00
|
|
|
from jax import lax
|
2022-09-23 09:59:46 -07:00
|
|
|
from jax._src.typing import Array
|
2022-01-18 22:22:57 -08:00
|
|
|
from jax._src.util import (as_hashable_function, unzip2, split_list, safe_map,
|
|
|
|
safe_zip)
|
2021-10-29 09:23:27 -07:00
|
|
|
|
|
|
|
source_info_util.register_exclusion(__file__)
|
|
|
|
traceback_util.register_exclusion(__file__)
|
|
|
|
|
2022-01-18 22:22:57 -08:00
|
|
|
map, unsafe_map = safe_map, map
|
|
|
|
zip, unsafe_zip = safe_zip, zip
|
|
|
|
|
2021-10-29 09:23:27 -07:00
|
|
|
|
|
|
|
## Utils
|
|
|
|
|
|
|
|
def popattr(obj, attrname):
  """Delete attribute `attrname` from `obj` and return the value it held.

  Raises AttributeError if the attribute is not present.
  """
  removed = getattr(obj, attrname)
  delattr(obj, attrname)
  return removed
|
|
|
|
|
|
|
|
def setnewattr(obj, name, val):
  """Set `obj.name = val`, asserting that the attribute does not exist yet."""
  # Only a *new* attribute may be set; an existing one signals a logic error.
  assert not hasattr(obj, name)
  setattr(obj, name, val)
|
|
|
|
|
2021-12-02 11:33:56 -08:00
|
|
|
|
2021-10-29 09:23:27 -07:00
|
|
|
## Error value data type and functional assert.
|
|
|
|
|
2022-09-23 09:59:46 -07:00
|
|
|
# Type aliases for the fields of an `Error`: each may be either a plain
# Python/NumPy value or a JAX `Array`.
Bool = Union[bool, Array]
Int = Union[int, Array]
Payload = Union[np.ndarray, Array]

# For now, the payload needs to be a fixed-size array: 3 int32s, used for the
# OOB message.
# TODO(lenamartens): Relax this fixed-size constraint.
# Default payload used whenever an Error is built without an explicit one.
init_payload = np.ones((3,), np.int32)
|
2022-03-30 17:19:15 +01:00
|
|
|
|
|
|
|
|
|
|
|
def _format_msg(msg, payloads):
|
|
|
|
payload_mapping = {}
|
|
|
|
for i, pl in enumerate(payloads):
|
|
|
|
payload_mapping[f'payload{i}'] = pl
|
|
|
|
return msg.format(**payload_mapping)
|
|
|
|
|
2022-01-10 12:25:47 +00:00
|
|
|
|
2021-10-29 09:23:27 -07:00
|
|
|
@dataclass(frozen=True)
class Error:
  """Functional error value: a flag, a message code, messages and a payload.

  `err` says whether an error happened, `code` is the key into `msgs` of the
  message to report, and `payload` supplies the values substituted into that
  message's `{payloadN}` placeholders (see `get`).
  """
  err: Bool
  code: Int
  msgs: Dict[int, str]
  # There might be many msgs with a {payload}, but only one msg will
  # ever be active for an Error instance, so only one Payload is tracked.
  payload: Payload

  def __init__(self, err: Bool, code: Int, msgs: Dict[int, str], payload: Optional[Payload] = None):
    # We can't directly assign to members of a frozen dataclass, even in __init__.
    object.__setattr__(self, "err", err)
    object.__setattr__(self, "code", code)
    object.__setattr__(self, "msgs", msgs)
    # Fall back to the module-level default payload when none is given.
    object.__setattr__(self, "payload",
                       init_payload if payload is None else payload)

  def get(self) -> Optional[str]:
    """Returns error message if error happened, None if no error happened."""
    assert np.shape(self.err) == np.shape(self.code)
    if np.size(self.err) == 1:
      # Single error flag: look up and format the one active message.
      if self.err:
        return _format_msg(self.msgs[int(self.code)], self.payload)
    else:
      # Multiple error flags: emit one line per index whose flag is set; an
      # empty join result is turned into None by the trailing `or None`.
      return '\n'.join(
          f'at mapped index {", ".join(map(str, idx))}: ' # type: ignore
          f'{_format_msg(self.msgs[int(self.code[idx])], self.payload[idx])}' # type: ignore
          for idx, e in np.ndenumerate(self.err) if e) or None
    return None

  def throw(self):
    """Pass this error to `check_error`."""
    check_error(self)

  def __str__(self):
    return f'Error({self.get()})'
|
|
|
|
|
|
|
|
|
|
|
|
def raise_error(error):
  """Raise a ValueError carrying `error.get()` if that message is truthy."""
  message = error.get()
  if not message:
    return
  raise ValueError(message)
|
2022-01-10 12:25:47 +00:00
|
|
|
|
2022-03-08 14:41:18 +00:00
|
|
|
|
2021-10-29 09:23:27 -07:00
|
|
|
# Register Error as a pytree: (err, code, payload) are the children, and the
# messages, as a sorted tuple of (code, msg) items, are the auxiliary data.
register_pytree_node(Error,
                     lambda e: ((e.err, e.code, e.payload),
                                tuple(sorted(e.msgs.items()))),
                     lambda msgs, data: Error(data[0], data[1], # type: ignore
                                              dict(msgs), data[2])) # type: ignore

# The "no error" value: flag False, code 0, no messages, default payload.
init_error = Error(False, 0, {})
next_code = it.count(1).__next__ # globally unique ids, could be uuid4
|
|
|
|
|
|
|
|
|
2022-09-08 08:58:44 -07:00
|
|
|
def assert_func(error: Error, err: Bool, msg: str,
                payload: Optional[Payload]) -> Error:
  """Functionally add a potential error to `error`.

  A fresh code is allocated for `msg`. The returned Error is flagged if
  either `error` was already flagged or `err` is set. When `error` was already
  flagged its code and payload are kept; otherwise the new code and payload
  are selected.

  Args:
    error: the Error accumulated so far.
    err: whether the new check failed.
    msg: message registered under the new code.
    payload: payload for `msg`, or None to use `init_payload`.

  Returns:
    A new Error whose messages include `msg`.
  """
  new_code = next_code()
  if payload is None:
    payload = init_payload
  had_error = error.err
  merged_err = had_error | err
  # An earlier error takes precedence over the one being added.
  merged_code = lax.select(had_error, error.code, new_code)
  merged_payload = lax.select(had_error, error.payload, payload)
  merged_msgs = {new_code: msg, **error.msgs}
  return Error(merged_err, merged_code, merged_msgs, merged_payload)
|
2021-10-29 09:23:27 -07:00
|
|
|
|
2021-12-02 11:33:56 -08:00
|
|
|
|
2021-10-29 09:23:27 -07:00
|
|
|
## Checkify transformation for plumbing functional error values.
|
|
|
|
|
2022-01-10 21:29:12 -08:00
|
|
|
class CheckifyTracer(core.Tracer):
  """Tracer of the checkify trace; simply boxes the underlying value `val`."""

  def __init__(self, trace, val):
    self.val = val
    self._trace = trace

  @property
  def aval(self):
    # The abstract value is that of the wrapped value.
    return core.get_aval(self.val)

  def full_lower(self):
    return self
|
|
|
|
|
2022-01-10 21:29:12 -08:00
|
|
|
class CheckifyTrace(core.Trace):
  """Trace that threads an Error value through the traced computation.

  The current Error is kept as the `error` attribute of the main trace object
  (`self.main.error`); the set of enabled error categories is kept as
  `self.main.enabled_errors`.
  """
  pure = lift = lambda self, val: CheckifyTracer(self, val)

  def __init__(self, main: core.MainTrace, sublevel: core.Sublevel,
               enabled_errors: FrozenSet['ErrorCategory']) -> None:
    self.main = main
    self.level = main.level
    self.sublevel = sublevel
    # Stored on `main` so the rules in process_primitive can read it.
    self.main.enabled_errors = enabled_errors

  def sublift(self, tracer):
    return CheckifyTracer(self, tracer.val)

  def process_primitive(self, primitive, tracers, params):
    """Apply the registered error-check rule for `primitive`, if any."""
    in_vals = [t.val for t in tracers]
    rule = error_checks.get(primitive)
    if rule:
      # The rule returns the primitive's output and the updated Error.
      out, self.main.error = rule(self.main.error, self.main.enabled_errors, # type: ignore
                                  *in_vals, **params)
    else:
      # No rule registered: bind the primitive unchanged.
      out = primitive.bind(*in_vals, **params)
    if primitive.multiple_results:
      return [CheckifyTracer(self, x) for x in out]
    else:
      return CheckifyTracer(self, out)

  def process_call(self, primitive, f, tracers, params):
    """Checkify a call primitive, passing the error in and out as values."""
    in_vals = [t.val for t in tracers]
    # Take the error off `main`; checkify_subtrace sets it again inside `f`.
    e = popattr(self.main, 'error')
    f, msgs = checkify_subtrace(f, self.main, tuple(e.msgs.items()))
    if 'donated_invars' in params:
      # Three extra, non-donated inputs: err, code and payload.
      params = dict(params, donated_invars=(False, False, False,
                                            *params['donated_invars']))
    err, code, payload, *out_vals = primitive.bind(f, e.err, e.code, e.payload,
                                                   *in_vals, **params)
    setnewattr(self.main, 'error', Error(err, code, msgs(), payload))
    return [CheckifyTracer(self, x) for x in out_vals]

  def process_map(self, primitive, f, tracers, params):
    """Checkify a map primitive, reducing the per-instance errors to one."""
    in_vals = [t.val for t in tracers]
    e = popattr(self.main, 'error')
    f, msgs = checkify_subtrace(f, self.main, tuple(e.msgs.items()))

    @as_hashable_function(closure=params['out_axes_thunk'])
    def new_out_axes_thunk():
      # The three error outputs are mapped along axis 0.
      return (0, 0, 0, *params['out_axes_thunk']())

    # The three error inputs use in_axes None and are never donated.
    params_ = dict(params, in_axes=(None, None, None, *params['in_axes']),
                   out_axes_thunk=new_out_axes_thunk,
                   donated_invars=(False, False, False, *params['donated_invars']))
    errs, codes, payloads, *outs = primitive.bind(f, e.err, e.code, e.payload,
                                                  *in_vals, **params_)
    # Collapse the mapped error values into a single error.
    err, code, payload = _reduce_any_error(errs, codes, payloads)
    setnewattr(self.main, 'error', Error(err, code, msgs(), payload))
    return [CheckifyTracer(self, x) for x in outs]

  def post_process_call(self, primitive, tracers, params):
    """Return the error values as extra outputs, plus a `todo` to restore them."""
    vals = [t.val for t in tracers]
    main = self.main
    e = popattr(main, 'error')
    # msgs are parked on `main` until `todo` rebuilds the Error.
    err, code, payload, main.msgs = e.err, e.code, e.payload, e.msgs
    def todo(vals):
      err, code, payload, *vals = vals
      setnewattr(main, 'error', Error(err, code, popattr(main, 'msgs'), payload))
      trace = main.with_cur_sublevel()
      return [CheckifyTracer(trace, x) for x in vals]
    return (err, code, payload, *vals), todo

  def post_process_map(self, primitive, tracers, params):
    """Like post_process_call, but reduces mapped errors and adjusts out_axes."""
    vals = [t.val for t in tracers]
    main = self.main
    e = popattr(main, 'error')
    err, code, payload, main.msgs = e.err, e.code, e.payload, e.msgs
    def todo(vals):
      errs, codes, payloads, *vals = vals
      err, code, payload = _reduce_any_error(errs, codes, payloads)
      setnewattr(main, 'error', Error(err, code, popattr(main, 'msgs'), payload))
      trace = main.with_cur_sublevel()
      return [CheckifyTracer(trace, x) for x in vals]
    def out_axes_transform(out_axes):
      # Account for the three error outputs placed before the real outputs.
      return (0, 0, 0, *out_axes)
    return (err, code, payload, *vals), (todo, out_axes_transform)

  def process_custom_jvp_call(self, prim, fun, jvp, tracers):
    """Checkify a custom_jvp call; both `fun` and `jvp` take the error values."""
    in_vals = [t.val for t in tracers]
    e = popattr(self.main, 'error')
    msgs = tuple(e.msgs.items())
    fun, msgs1 = checkify_subtrace(fun, self.main, msgs)
    jvp, msgs2 = checkify_custom_jvp_subtrace(jvp, self.main, msgs)
    err, code, payload, *out_vals = prim.bind(fun, jvp, e.err, e.code,
                                              e.payload, *in_vals)
    # Take the msgs from whichever of the two wrapped functions stored them.
    fst, out_msgs = lu.merge_linear_aux(msgs1, msgs2)
    # NOTE(review): uses setattr rather than setnewattr, unlike process_call.
    setattr(self.main, 'error', Error(err, code, out_msgs, payload))
    return [CheckifyTracer(self, x) for x in out_vals]

  def post_process_custom_jvp_call(self, out_tracers, jvp_was_run):
    """Like post_process_call; not implemented when the JVP rule was run."""
    if jvp_was_run:
      msg = ("support for custom_jvp rules which close over checkify values is "
             "not implemented. If you see this, open an issue at "
             "https://github.com/google/jax/issues!")
      raise NotImplementedError(msg)
    vals = [t.val for t in out_tracers]
    main = self.main
    e = popattr(main, 'error')
    err, code, payload, main.msgs = e.err, e.code, e.payload, e.msgs
    def todo(vals):
      err, code, payload, *vals = vals
      setnewattr(main, 'error', Error(err, code, popattr(main, 'msgs'), payload))
      trace = main.with_cur_sublevel()
      return [CheckifyTracer(trace, x) for x in vals]
    return (err, code, payload, *vals), todo

  def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees):
    """Checkify a custom_vjp call; `fwd` drops the error values it is given."""
    in_vals = [t.val for t in tracers]
    e = popattr(self.main, 'error')
    msgs = tuple(e.msgs.items())
    fun, msgs1 = checkify_subtrace(fun, self.main, msgs)
    fwd, msgs2 = checkify_custom_vjp_subtrace(fwd, self.main, msgs)
    out = prim.bind(fun, fwd, bwd, e.err, e.code, e.payload,
                    *in_vals, out_trees=out_trees)
    fst, out_msgs = lu.merge_linear_aux(msgs1, msgs2)
    # presumably `fst` is true when `fun` (not `fwd`) was the one run; only
    # then do the outputs start with the error values -- TODO confirm.
    if fst:
      err, code, payload, *out = out
    else:
      err, code, payload = e.err, e.code, e.payload # forward input error values to output
    setattr(self.main, 'error', Error(err, code, out_msgs, payload))
    return [CheckifyTracer(self, x) for x in out]
|
|
|
|
|
2022-03-08 14:41:18 +00:00
|
|
|
def _reduce_any_error(errs, codes, payloads):
|
2022-03-30 17:19:15 +01:00
|
|
|
reduced_idx = jnp.argsort(errs)[-1]
|
|
|
|
return errs[reduced_idx], codes[reduced_idx], payloads[reduced_idx]
|
2021-10-29 09:23:27 -07:00
|
|
|
|
2022-01-18 22:22:57 -08:00
|
|
|
ErrorCheckRule = Callable # (Error, FrozenSet[ErrorCategory], *in_vals, **params) -> (Any, Error)
# Registry of per-primitive checkify rules, consulted by
# CheckifyTrace.process_primitive.
error_checks: Dict[core.Primitive, ErrorCheckRule] = {}
|
|
|
|
|
2022-01-18 22:22:57 -08:00
|
|
|
def checkify_flat(fun: lu.WrappedFun, enabled_errors: FrozenSet['ErrorCategory'],
                  *args):
  """Run the flat function `fun` on `args` under the checkify transformation.

  The computation starts from `init_error`.

  Returns:
    A pair `((err, code, payload, outvals), msgs)`, where `outvals` is the
    list of outputs of `fun` and `msgs` are the messages collected while
    tracing.
  """
  checked, msgs_thunk = checkify_subtrace(fun)
  start_msgs = tuple(init_error.msgs.items())
  checked = checkify_traceable(checked, start_msgs, enabled_errors)
  results = checked.call_wrapped(
      init_error.err, init_error.code, init_error.payload, *args)
  err, code, payload, *outvals = results
  return (err, code, payload, outvals), msgs_thunk()
|
2021-10-29 09:23:27 -07:00
|
|
|
|
|
|
|
@lu.transformation
def checkify_traceable(msgs, enabled_errors, err, code, payload, *args):
  # Opens a new CheckifyTrace main trace and hands it, together with the
  # messages and error values, to the wrapped function as leading arguments.
  with core.new_main(CheckifyTrace, enabled_errors=enabled_errors) as main:
    outs = yield (main, msgs, err, code, payload, *args), {}
    # Drop the local reference to `main` before leaving the `with` block.
    del main
  yield outs
|
|
|
|
|
|
|
|
@lu.transformation_with_aux
def checkify_subtrace(main, msgs, err, code, payload, *args):
  # Installs the incoming error on `main`, runs the wrapped function on
  # CheckifyTracers, then returns the final error values ahead of the outputs.
  # The final messages are produced as the auxiliary output.
  setnewattr(main, 'error', Error(err, code, dict(msgs), payload))
  trace = main.with_cur_sublevel()
  in_tracers = [CheckifyTracer(trace, x) for x in args]
  out = yield in_tracers, {}
  out_tracers = map(trace.full_raise, out)
  out_vals = [t.val for t in out_tracers]
  # Read the (possibly updated) error back from `main`, then remove it.
  err, code, payload, msgs = main.error.err, main.error.code, main.error.payload, main.error.msgs
  del main.error
  yield (err, code, payload, *out_vals), msgs
|
2021-10-29 09:23:27 -07:00
|
|
|
|
2022-02-09 11:18:40 -08:00
|
|
|
@lu.transformation_with_aux
def checkify_custom_jvp_subtrace(main, msgs, *args):
  # Like checkify_subtrace, but used specifically on the custom JVP rules
  # associated with a custom_jvp. This code is called in the context of a
  # jvp-of-checkify-of-custom_jvp. It takes both primal and tangent inputs,
  # flattened into a single args tuple, and similarly must produce flattened
  # primal and tangent outputs. Both primals and tangents include error values,
  # but the tangent error values are trivially zero.
  # The types to have in mind are:
  #   jvp : (a -> b) -> (a, T a) -> (b, T b)
  #   checkify : (a -> b) -> a -> Err b
  #   jvp-of-checkify : (a -> b) -> (a, T a) -> (Err b, T (Err b))
  # where because Err is a pytree, we necessarily have T (Err b) = Err' (T b)
  # where the other Err' components are trivial (of float0 dtype).
  # Semantically, we don't add checks to the JVP rule. To check the result of a
  # JVP rule, one must instead use checkify-of-jvp. Thus this implementation
  # just forwards the input error and code (and trivial tangents) to the output.
  del main
  # First half of args: primals; second half: tangents.
  n, ragged = divmod(len(args), 2)
  assert not ragged
  # Peel the three error values off the front of each half.
  (err,), (code,), (payload,), primals = split_list(args[:n], [1, 1, 1])
  (err_dot,), (code_dot,), (pl_dot,), tangents = split_list(args[n:], [1, 1, 1])
  outs = yield (*primals, *tangents), {}
  m, ragged = divmod(len(outs), 2)
  assert not ragged
  out_primals, out_tangents = outs[:m], outs[m:]
  # Re-attach the unchanged error values in front of each output half; the
  # incoming messages are returned unchanged as the auxiliary output.
  yield (err, code, payload, *out_primals,
         err_dot, code_dot, pl_dot, *out_tangents), dict(msgs)
|
2021-10-29 09:23:27 -07:00
|
|
|
|
2022-02-09 12:04:34 -08:00
|
|
|
@lu.transformation_with_aux
def checkify_custom_vjp_subtrace(main, msgs, err, code, payload, *args):
  # We don't add any checks; just drop input error values.
  del main, err, code, payload
  outs = yield args, {}
  # Outputs pass through untouched; the incoming messages are the aux output.
  yield outs, dict(msgs)
|
|
|
|
|
2021-12-08 15:07:43 +00:00
|
|
|
# TODO take (error_aval, code_aval) instead of error here?
def checkify_jaxpr(jaxpr, error, enabled_errors):
  """Checkify a closed jaxpr; see `checkify_fun_to_jaxpr` for the result."""
  as_fun = core.jaxpr_as_fun(jaxpr)
  wrapped = lu.wrap_init(as_fun)
  in_avals = jaxpr.in_avals
  return checkify_fun_to_jaxpr(wrapped, error, enabled_errors, in_avals)
|
2021-12-08 15:07:43 +00:00
|
|
|
|
2022-01-10 18:21:41 +00:00
|
|
|
def checkify_fun_to_jaxpr(f, error, enabled_errors, in_avals):
  """Trace the checkified version of `f` to a ClosedJaxpr.

  The traced function takes the error flag, code and payload of `error` as
  three leading inputs, followed by inputs of avals `in_avals`.

  Returns:
    A pair of the resulting `core.ClosedJaxpr` and the collected messages.
  """
  checked, msgs_thunk = checkify_subtrace(f)
  checked = checkify_traceable(checked, tuple(error.msgs.items()),
                               enabled_errors)
  # Shaped avals of the three error components, in (err, code, payload) order.
  error_avals = [core.raise_to_shaped(core.get_aval(component))
                 for component in (error.err, error.code, error.payload)]
  all_avals = [*error_avals, *in_avals]
  traced_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(checked, all_avals)
  return core.ClosedJaxpr(traced_jaxpr, consts), msgs_thunk()
|
2021-10-29 09:23:27 -07:00
|
|
|
|
|
|
|
|
2021-12-02 11:33:56 -08:00
|
|
|
## assert primitive
|
|
|
|
|
2022-01-10 12:25:47 +00:00
|
|
|
def check(pred: Bool, msg: str) -> None:
  """Check a predicate, add an error with msg if predicate is False.

  This is an effectful operation, and can't be staged (jitted/scanned/...).
  Before staging a function with checks, :func:`~checkify` it!

  Args:
    pred: if False, an error is added.
    msg: error message if error is added.

  Raises:
    TypeError: if ``pred`` is not a scalar boolean.

  For example:

    >>> import jax
    >>> import jax.numpy as jnp
    >>> from jax.experimental import checkify
    >>> def f(x):
    ...   checkify.check(x!=0, "cannot be zero!")
    ...   return 1/x
    >>> checked_f = checkify.checkify(f)
    >>> err, out = jax.jit(checked_f)(0)
    >>> err.throw()  # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
      ...
    ValueError: cannot be zero! (check failed at ...)

  """
  if not is_scalar_pred(pred):
    raise TypeError(f'check takes a scalar pred as argument, got {pred}')
  new_code = next_code()
  # Append where the check was made to the user-provided message.
  located_msg = msg + f' (check failed at {summary()})'
  failed = jnp.logical_not(pred)
  return check_error(Error(failed, new_code, {new_code: located_msg}))
|
2021-12-02 14:26:58 -08:00
|
|
|
|
2021-12-16 10:57:19 -08:00
|
|
|
def is_scalar_pred(pred) -> bool:
  """Return True iff `pred` is a Python bool or a rank-0 boolean jnp array."""
  if isinstance(pred, bool):
    return True
  if not isinstance(pred, jnp.ndarray):
    return False
  return pred.shape == () and pred.dtype == jnp.dtype('bool')
|
|
|
|
|
2022-01-10 12:25:47 +00:00
|
|
|
def check_error(error: Error) -> None:
  """Raise an Exception if ``error`` represents a failure. Functionalized by :func:`~checkify`.

  The semantics of this function are equivalent to:

  >>> def check_error(err: Error) -> None:
  ...   err.throw()  # can raise ValueError

  But unlike that implementation, ``check_error`` can be functionalized using
  the :func:`~checkify` transformation.

  This function is similar to :func:`~check` but with a different signature:
  :func:`~check` takes a boolean predicate and a new error message string,
  whereas this function takes an ``Error`` value. Both raise a Python
  Exception on failure (a side-effect), so neither can be staged out by
  :func:`~jax.jit`, :func:`~jax.pmap`, :func:`~jax.lax.scan`, etc., and both
  can be functionalized by using :func:`~checkify`.

  Unlike :func:`~check`, this function is like a direct inverse of
  :func:`~checkify`: :func:`~checkify` takes a function which can raise a
  Python Exception and produces a new function without that effect but which
  produces an ``Error`` value as output, while ``check_error`` accepts an
  ``Error`` value as input and can produce the side-effect of raising an
  Exception. That is, :func:`~checkify` goes from functionalizable Exception
  effect to error value, and ``check_error`` goes from error value to
  functionalizable Exception effect.

  ``check_error`` is useful when you want to turn checks represented by an
  ``Error`` value (produced by functionalizing ``checks`` via
  :func:`~checkify`) back into Python Exceptions.

  Args:
    error: Error to check.

  Raises:
    ValueError: if ``error`` is not an ``Error`` instance.

  For example, you might want to functionalize part of your program through
  checkify, stage out your functionalized code through :func:`~jax.jit`, then
  re-inject your error value outside of the :func:`~jax.jit`:

  >>> import jax
  >>> from jax.experimental import checkify
  >>> def f(x):
  ...   checkify.check(x>0, "must be positive!")
  ...   return x
  >>> def with_inner_jit(x):
  ...   checked_f = checkify.checkify(f)
  ...   # a checkified function can be jitted
  ...   error, out = jax.jit(checked_f)(x)
  ...   checkify.check_error(error)
  ...   return out
  >>> _ = with_inner_jit(1)  # no failed check
  >>> with_inner_jit(-1)  # doctest: +IGNORE_EXCEPTION_DETAIL
  Traceback (most recent call last):
    ...
  ValueError: must be positive!
  >>> # can re-checkify
  >>> error, _ = checkify.checkify(with_inner_jit)(-1)
  """
  if not isinstance(error, Error):
    raise ValueError('check_error takes an Error as argument, '
                     f'got type {type(error)} instead.')

  err, code, payload = error.err, error.code, error.payload
  if np.shape(err):
    # Non-scalar error flags: first reduce to a single error.
    err, code, payload = _reduce_any_error(err, code, payload)

  err = core.raise_as_much_as_possible(err)
  return assert_p.bind(err, code, payload, msgs=error.msgs)
|
2022-01-10 12:25:47 +00:00
|
|
|
|
|
|
|
assert_p = core.Primitive('assert') # TODO: rename to check?
assert_p.multiple_results = True # zero results


@assert_p.def_impl
def assert_impl(err, code, payload, *, msgs):
  """Eager implementation of `assert_p`: raise if the error values say so."""
  failure = Error(err, code, msgs, payload)
  raise_error(failure)
  # The primitive has no outputs.
  return []
|
|
|
|
|
2022-04-19 16:08:09 +01:00
|
|
|
CheckEffect = object()


@assert_p.def_effectful_abstract_eval
def assert_abstract_eval(err, code, payload, *, msgs):
  """Abstract eval of `assert_p`: no outputs, carries the `CheckEffect`."""
  out_avals = []
  effects = {CheckEffect}
  return out_avals, effects
|
|
|
|
|
2022-09-23 15:10:41 +01:00
|
|
|
# TODO(lenamartens) add in-depth error explanation to link to in module docs.
# Exception instance raised by the lowering rules of `assert_p` when a check
# is lowered without having been functionalized.
functionalization_error = ValueError(
    'Cannot abstractly evaluate a checkify.check which was not'
    ' functionalized. This probably means you tried to stage'
    ' (jit/scan/pmap/...) a `check` without functionalizing it'
    ' through `checkify.checkify`.'
    )
|
|
|
|
|
|
|
|
def python_err(msgs, err, code, payload):
  """Rebuild an Error from its parts and pass it to `check_error`.

  Returns an empty list (there are no results).
  """
  check_error(Error(err, code, msgs, payload))
  return []
|
|
|
|
|
|
|
|
def assert_lowering_rule(ctx, err, code, payload, *, msgs):
  """Lower `assert_p` to a Python callback that performs the check.

  Only available when `config.jax_experimental_unsafe_xla_runtime_errors` is
  set; otherwise `functionalization_error` is raised.
  """
  if not config.jax_experimental_unsafe_xla_runtime_errors:
    raise functionalization_error

  def _run_check(*operands):
    return python_err(msgs, *operands)

  token_in = ctx.tokens_in.get(CheckEffect)[0]
  out_op, token_out, keep_alive = mlir.emit_python_callback(
      ctx,
      callback=_run_check,
      token=token_in,
      operands=[err, code, payload],
      operand_avals=list(ctx.avals_in),
      result_avals=list(ctx.avals_out),
      has_side_effect=True)
  # Thread the token produced by the callback to the outgoing token set.
  updated_tokens = ctx.tokens_in.update_tokens(
      mlir.TokenSet({CheckEffect: token_out}))
  ctx.set_tokens_out(updated_tokens)
  ctx.module_context.add_keepalive(keep_alive)
  return out_op
|
|
|
|
|
|
|
|
def assert_lowering_rule_unsupported(*a, **k):
  """Lowering rule that always raises `functionalization_error`."""
  raise functionalization_error

# TPU gets the always-failing rule; CPU and GPU get the callback-based rule.
mlir.register_lowering(assert_p, assert_lowering_rule_unsupported,
                       platform='tpu')
mlir.register_lowering(assert_p, assert_lowering_rule,
                       platform='cpu')
mlir.register_lowering(assert_p, assert_lowering_rule,
                       platform='gpu')
# Register CheckEffect with the lowering, control-flow and ordered-effect sets.
mlir.lowerable_effects.add(CheckEffect)
cf.allowed_effects.add(CheckEffect)
core.ordered_effects.add(CheckEffect)
|
2021-12-02 11:33:56 -08:00
|
|
|
|
2022-06-23 17:23:43 +01:00
|
|
|
|
|
|
|
def assert_batching_rule(batched_args, batch_dims, *, msgs):
  """Batching rule for `assert_p`: check the batched error values.

  All three operands (err, code, payload) are brought to a leading batch
  dimension and passed, as one Error, to `check_error`.
  """
  arg_dim_pairs = list(zip(batched_args, batch_dims))
  # Batch size is read off the first operand that is actually mapped.
  size = next(arg.shape[bdim] for arg, bdim in arg_dim_pairs
              if bdim is not batching.not_mapped)
  err, code, payload = [batching.bdim_at_front(arg, bdim, size)
                        for arg, bdim in arg_dim_pairs]
  check_error(Error(err, code, msgs, payload))
  return [], []

batching.primitive_batchers[assert_p] = assert_batching_rule
|
|
|
|
|
2022-09-08 05:25:19 -07:00
|
|
|
def assert_jvp_rule(primals, _, *, msgs):
  """JVP rule for `assert_p`: re-bind on the primals, ignore the tangents."""
  # Check primals, discard tangents.
  assert_p.bind(*primals, msgs=msgs)
  primals_out, tangents_out = [], []
  return primals_out, tangents_out

ad.primitive_jvps[assert_p] = assert_jvp_rule
|
|
|
|
|
2021-10-29 09:23:27 -07:00
|
|
|
## checkify rules
|
|
|
|
|
2021-12-02 14:26:58 -08:00
|
|
|
def summary() -> str:
  """Return the summarized current source info as a string."""
  current_info = source_info_util.current()
  return str(source_info_util.summarize(current_info))
|
|
|
|
|
2022-01-10 18:21:41 +00:00
|
|
|
def nan_error_check(prim, error, enabled_errors, *in_vals, **params):
  """Bind `prim` and record an error if its output contains a NaN.

  Args:
    prim: the primitive to bind.
    error: the Error accumulated so far.
    enabled_errors: set of enabled error categories; the NaN check is skipped
      unless `ErrorCategory.NAN` is in it.
    *in_vals: positional inputs of the primitive.
    **params: parameters of the primitive.

  Returns:
    A pair of the primitive's output and the (possibly updated) Error.
  """
  out = prim.bind(*in_vals, **params)
  if ErrorCategory.NAN not in enabled_errors:
    return out, error

  def isnan(x):
    # PRNG key arrays are never reported as containing NaNs.
    if isinstance(x, prng.PRNGKeyArray):
      return False
    return jnp.isnan(x)

  if prim.multiple_results:
    # Reduce every output to one scalar flag, then combine the flags.
    # (Passing a generator straight to jnp.any does not reduce the outputs.)
    per_output = [jnp.any(isnan(x)) for x in out]
    any_nans = jnp.any(jnp.array(per_output, dtype=bool))
  else:
    any_nans = jnp.any(isnan(out))
  msg = f'nan generated by primitive {prim.name} at {summary()}'
  return out, assert_func(error, any_nans, msg, None)
|
2021-10-29 09:23:27 -07:00
|
|
|
|
Checkify: add and remove primitives which are checked for NaN outputs.
Started from all primitives exported from `jax.lax` and removed
a primitive when:
- its output is int/bool (but what if the output is complex?)
- it does not generate NaNs, i.e. if the input does not contain a NaN
value, the output will not contain a NaN value (e.g.
reshape/concatenate/..., max/...)
- it's already handled by other rules (eg. div, gather/scatter and
scan/cond/while)
Compared to the previous set:
added: {logistic, custom_linear_solve, igammac, igamma_grad_a, psum,
igamma, reduce, tan, rng_uniform, lgamma, digamma,
regularized_incomplete_beta, reduce_prod, reduce_window, cbrt,
bessel_i0e, random_gamma_grad, bessel_i1e}
removed: {shift_left, concatenate, complex, shift_right_arithmetic,
convert_element_type, conj, sign, round, shift_right_logical,
reduce_max, bitcast_convert_type, real, max, reduce_min, rev, slice,
min, imag, clamp, floor, select_n}
2022-11-14 17:50:46 +00:00
|
|
|
|
|
|
|
# All primitives which can generate a NaN.
nan_primitives = [
    lax.acos_p,
    lax.acosh_p,
    lax.add_p,
    lax.asin_p,
    lax.asinh_p,
    lax.atan2_p,
    lax.atan_p,
    lax.atanh_p,
    lax.bessel_i0e_p,
    lax.bessel_i1e_p,
    lax.cbrt_p,
    lax.conv_general_dilated_p,
    lax.cos_p,
    lax.cosh_p,
    lax.cumlogsumexp_p,
    lax.cummax_p,
    lax.cummin_p,
    lax.cumprod_p,
    lax.cumsum_p,
    lax.digamma_p,
    lax.dot_general_p,
    lax.erf_inv_p,
    lax.erf_p,
    lax.erfc_p,
    lax.exp_p,
    lax.expm1_p,
    lax.fft_p,
    lax.igamma_grad_a_p,
    lax.igamma_p,
    lax.igammac_p,
    lax.integer_pow_p,
    lax.lgamma_p,
    lax.linear_solve_p,
    lax.log1p_p,
    lax.log_p,
    lax.logistic_p,
    lax.mul_p,
    lax.pad_p,
    lax.pow_p,
    lax.psum_p,
    lax.random_gamma_grad_p,
    lax.reduce_p,
    lax.reduce_prod_p,
    lax.reduce_sum_p,
    lax.reduce_window_p,
    lax.reduce_window_sum_p,
    lax.regularized_incomplete_beta_p,
    lax.rem_p,
    lax.rng_uniform_p,
    lax.rsqrt_p,
    lax.sin_p,
    lax.sinh_p,
    lax.sqrt_p,
    lax.sub_p,
    lax.tan_p,
    lax.tanh_p,
]

# Each listed primitive gets `nan_error_check` with the primitive pre-bound as
# its first argument.
for prim in nan_primitives:
  error_checks[prim] = partial(nan_error_check, prim)
|
|
|
|
|
|
|
|
|
2022-01-10 18:21:41 +00:00
|
|
|
def gather_error_check(error, enabled_errors, operand, start_indices, *,
                       dimension_numbers, slice_sizes, unique_indices,
                       indices_are_sorted, mode, fill_value):
  """Binds ``lax.gather_p`` and, when OOB checks are enabled, checks indices.

  Returns:
    A pair ``(out, error)``. ``out`` is the gather result. ``error`` is the
    incoming error if ``ErrorCategory.OOB`` is disabled, otherwise the result
    of ``assert_func`` on the "any start index out of bounds" flag, with a
    payload of (first OOB index, its operand axis, that axis' size).
  """
  out = lax.gather_p.bind(
      operand, start_indices,
      dimension_numbers=dimension_numbers,
      slice_sizes=slice_sizes,
      unique_indices=unique_indices,
      indices_are_sorted=indices_are_sorted,
      mode=mode,
      fill_value=fill_value)

  if ErrorCategory.OOB not in enabled_errors:
    return out, error

  # compare to OOB masking logic in lax._gather_translation_rule
  index_map = np.array(dimension_numbers.start_index_map)
  # Largest allowed start index per indexed axis: axis size minus slice size.
  upper_bound = np.array(operand.shape)[index_map]
  upper_bound -= np.array(slice_sizes)[index_map]
  # Prepend one singleton axis per leading (batch) axis of `start_indices`.
  batch_axes = tuple(range(len(start_indices.shape) - 1))
  upper_bound = jnp.expand_dims(upper_bound, axis=batch_axes)
  below_zero = start_indices < 0
  above_bound = start_indices > upper_bound.astype(start_indices.dtype)
  out_of_bounds = below_zero | above_bound

  # Get first OOB index, axis and axis size so it can be added to the error msg.
  # argmin over the negated mask is the flat position of the first True entry
  # (position 0 when no entry is True).
  first_oob = jnp.argmin(jnp.logical_not(out_of_bounds))
  position = jnp.unravel_index(first_oob, start_indices.shape)
  oob_axis = jnp.array(dimension_numbers.start_index_map)[position[-1]]
  oob_axis_size = jnp.array(operand.shape)[oob_axis]
  oob_index = jnp.ravel(start_indices)[first_oob]
  payload = jnp.array([oob_index, oob_axis, oob_axis_size], dtype=jnp.int32)

  # The {payloadN} placeholders are left unformatted here on purpose.
  msg = (f'out-of-bounds indexing at {summary()} for array of '
         f'shape {operand.shape}: '
         'index {payload0} is out of bounds for axis {payload1} '
         'with size {payload2}.')

  return out, assert_func(error, jnp.any(out_of_bounds), msg, payload)

error_checks[lax.gather_p] = gather_error_check
|
2021-10-29 09:23:27 -07:00
|
|
|
|
2022-01-10 18:21:41 +00:00
|
|
|
def div_error_check(error, enabled_errors, x, y):
  """Checks for division by zero and NaN.

  The zero-divisor assertion is only added when ``ErrorCategory.DIV`` is
  enabled. The division itself and the NaN check are delegated to
  ``nan_error_check`` with ``lax.div_p``.
  """
  div_checks_enabled = ErrorCategory.DIV in enabled_errors
  if div_checks_enabled:
    zero_divisor = jnp.any(jnp.equal(y, 0))
    error = assert_func(
        error, zero_divisor, f'division by zero at {summary()}', None)
  return nan_error_check(lax.div_p, error, enabled_errors, x, y)

error_checks[lax.div_p] = div_error_check
|
|
|
|
|
2022-09-08 08:58:44 -07:00
|
|
|
def scatter_oob(operand, indices, updates, dnums):
  """Return a scalar flag: is any scatter index out of bounds?

  An index is out of bounds when it is negative or greater than
  ``operand.shape[axis] - slice_size`` for the operand axis it addresses.
  """
  # Ref: see clamping code used in scatter_translation_rule
  # Inserted window dims have slice size 1; every other operand axis takes its
  # size from the next entry of `update_window_dims`, in order.
  window_dims = iter(dnums.update_window_dims)
  slice_sizes = [
      1 if axis in dnums.inserted_window_dims
      else updates.shape[next(window_dims)]
      for axis in range(len(operand.shape))
  ]

  limits = [operand.shape[axis] - slice_sizes[axis]
            for axis in dnums.scatter_dims_to_operand_dims]
  upper_bound = np.array(limits, np.int64)
  # Clamp to the largest value representable in the index dtype.
  upper_bound = np.minimum(upper_bound, np.iinfo(indices.dtype).max)
  last_axis = len(indices.shape) - 1
  upper_bound = lax.broadcast_in_dim(upper_bound, indices.shape, (last_axis,))

  any_negative = jnp.any(jnp.less(indices, 0))
  any_too_large = jnp.any(
      jnp.greater(indices, upper_bound.astype(indices.dtype)))
  return jnp.logical_or(any_negative, any_too_large)
|
2021-12-23 15:23:58 +00:00
|
|
|
|
2022-01-10 18:21:41 +00:00
|
|
|
def scatter_error_check(prim, error, enabled_errors, operand, indices, updates,
                        *, update_jaxpr, update_consts, dimension_numbers,
                        indices_are_sorted, unique_indices, mode):
  """Checks if indices are within bounds and update does not generate NaN.

  Returns:
    A pair ``(out, error)`` where ``out`` is the result of binding ``prim``.
    When ``ErrorCategory.OOB`` is disabled the incoming error is returned
    unchanged.
  """
  out = prim.bind(
      operand, indices, updates,
      update_jaxpr=update_jaxpr,
      update_consts=update_consts,
      dimension_numbers=dimension_numbers,
      indices_are_sorted=indices_are_sorted,
      unique_indices=unique_indices,
      mode=mode)

  if ErrorCategory.OOB not in enabled_errors:
    return out, error

  index_oob = scatter_oob(operand, indices, updates, dimension_numbers)
  error_after_oob = assert_func(
      error, index_oob,
      f'out-of-bounds indexing while updating at {summary()}', None)

  # NOTE(review): this NaN check runs whenever OOB checks are enabled; it is
  # not gated on ErrorCategory.NAN.
  output_has_nan = jnp.any(jnp.isnan(out))
  return out, assert_func(
      error_after_oob, output_has_nan,
      f'nan generated by primitive {prim.name} at {summary()}', None)

error_checks[lax.scatter_p] = partial(scatter_error_check, lax.scatter_p)
error_checks[lax.scatter_add_p] = partial(scatter_error_check,
                                          lax.scatter_add_p)
error_checks[lax.scatter_mul_p] = partial(scatter_error_check,
                                          lax.scatter_mul_p)
error_checks[lax.scatter_min_p] = partial(scatter_error_check,
                                          lax.scatter_min_p)
error_checks[lax.scatter_max_p] = partial(scatter_error_check,
                                          lax.scatter_max_p)
|
|
|
|
|
2022-01-10 18:21:41 +00:00
|
|
|
def cond_error_check(error, enabled_errors, index, *ops, branches, linear):
  """Checkify rule for ``lax.cond_p``.

  Every branch jaxpr is passed through ``checkify_jaxpr`` and the cond is
  re-bound with the three error values (err, code, payload) inserted right
  after ``index``.

  Returns:
    A pair ``(outs, error)``: the cond outputs with the three leading error
    values stripped off, and an ``Error`` rebuilt from them together with the
    merged messages.
  """
  checked = [checkify_jaxpr(branch, error, enabled_errors)
             for branch in branches]
  new_branches, msgs_ = unzip2(checked)
  # The three added error operands are marked as non-linear.
  new_linear = (False,) * 3 + tuple(linear)
  err, code, payload, *outs = lax.cond_p.bind(
      index, error.err, error.code, error.payload, *ops,
      branches=tuple(new_branches), linear=new_linear)
  # Merge the incoming messages with those of every branch; on a key clash the
  # later mapping wins.
  new_msgs = {}
  for branch_msgs in it.chain([error.msgs], msgs_):
    new_msgs.update(branch_msgs.items())
  return outs, Error(err, code, new_msgs, payload)

error_checks[lax.cond_p] = cond_error_check
|
2021-10-29 09:23:27 -07:00
|
|
|
|
2022-01-18 22:22:57 -08:00
|
|
|
def scan_error_check(error, enabled_errors, *in_flat, reverse, length, jaxpr,
                     num_consts, num_carry, linear, unroll):
  """Checkify rule for ``lax.scan_p``.

  The body jaxpr is passed through ``checkify_jaxpr`` and the scan is re-bound
  with the three error values (err, code, payload) added to the carry.

  Returns:
    A pair ``(outs, error)``: the scan outputs with the three leading error
    values stripped off, and an ``Error`` rebuilt from them together with the
    merged messages.
  """
  consts, carry, xs = split_list(in_flat, [num_consts, num_carry])
  checked_jaxpr_, msgs_ = checkify_jaxpr(jaxpr, error, enabled_errors)

  # The checked jaxpr has the three error binders first. Move the consts in
  # front of them so the error values sit at the start of the carry.
  n_consts, n_rest = len(consts), len(carry) + len(xs)
  tomove = [False] * 3 + [True] * n_consts + [False] * n_rest
  checked_jaxpr = pe.move_binders_to_front(checked_jaxpr_, tomove)

  err_vals = [error.err, error.code, error.payload]
  new_in_flat = [*consts, *err_vals, *carry, *xs]
  new_linear = (False,) * 3 + tuple(linear)
  err, code, payload, *outs = lax.scan_p.bind(
      *new_in_flat,
      reverse=reverse,
      length=length,
      jaxpr=checked_jaxpr,
      num_consts=n_consts,
      num_carry=len(carry) + 3,
      linear=new_linear,
      unroll=unroll)

  merged_msgs = {**error.msgs, **msgs_}
  return outs, Error(err, code, merged_msgs, payload)

error_checks[lax.scan_p] = scan_error_check
|
2021-12-06 19:36:35 +00:00
|
|
|
|
2022-02-22 23:08:22 +00:00
|
|
|
def checkify_while_body_jaxpr(cond_jaxpr, body_jaxpr, error, enabled_errors, c_consts):
  """Checkify a while-loop body that also evaluates the loop condition.

  The function handed to ``checkify_fun_to_jaxpr`` runs the body and then the
  condition on the body's outputs. The condition's result is discarded.

  Returns:
    Whatever ``checkify_fun_to_jaxpr`` returns for the wrapped body, traced
    with ``body_jaxpr.in_avals``.
  """
  run_cond = core.jaxpr_as_fun(cond_jaxpr)
  run_body = core.jaxpr_as_fun(body_jaxpr)

  def new_body_f(*vals):
    carry_out = run_body(*vals)
    # This checks if the next cond application will error
    run_cond(*c_consts, *carry_out)
    return carry_out

  wrapped_body = lu.wrap_init(new_body_f)
  return checkify_fun_to_jaxpr(wrapped_body, error, enabled_errors,
                               body_jaxpr.in_avals)
|
2021-12-08 15:07:43 +00:00
|
|
|
|
2022-04-19 16:08:09 +01:00
|
|
|
def ignore_error_output_jaxpr(jaxpr):
|
|
|
|
"""Constructs a checked jaxpr which does not output its error value."""
|
2021-12-08 15:07:43 +00:00
|
|
|
consts = jaxpr.consts
|
|
|
|
jaxpr = jaxpr.jaxpr
|
2022-04-19 16:08:09 +01:00
|
|
|
new_jaxpr = jaxpr.replace(outvars=jaxpr.outvars[3:])
|
2021-12-08 15:07:43 +00:00
|
|
|
return core.ClosedJaxpr(new_jaxpr, consts)
|
|
|
|
|
2022-01-18 22:22:57 -08:00
|
|
|
def while_loop_error_check(error, enabled_errors, *in_flat, cond_nconsts,
                           cond_jaxpr, body_nconsts, body_jaxpr):
  """Checkify rule for ``lax.while_p``.

  The condition is evaluated once up front through its checked jaxpr. The
  loop is then re-bound with a checked body and the three error values
  (err, code, payload) at the front of the carry.

  Returns:
    A pair ``(out, error)``: the loop outputs with the three leading error
    values stripped off, and an ``Error`` rebuilt from them together with the
    merged messages.

  Raises:
    ValueError: if the condition jaxpr's first output has a non-empty shape
      (batched while-loop).
  """
  if cond_jaxpr.out_avals[0].shape:
    # TODO(lenamartens, sharadmv): support batched while.
    raise ValueError('Checkify does not support batched while-loops '
                     '(checkify-of-vmap-of-while). \nHint: if possible, move '
                     'the vmap to the outer level to get '
                     'vmap-of-checkify-of-while.')
  err_args = [error.err, error.code, error.payload]

  c_consts, b_consts, carry = split_list(in_flat, [cond_nconsts, body_nconsts])
  n_carry = len(carry)

  # Check if the first cond application will error.
  checked_cond_jaxpr, msgs_cond = checkify_jaxpr(
      cond_jaxpr, error, enabled_errors)
  checked_cond_fun = core.jaxpr_as_fun(checked_cond_jaxpr)
  cond_err, cond_code, cond_payload, _ = checked_cond_fun(
      *err_args, *c_consts, *carry)

  checked_body_jaxpr_, msgs_body = checkify_while_body_jaxpr(
      cond_jaxpr, body_jaxpr, error, enabled_errors, c_consts)
  # Both checked jaxprs have the three error binders first. Their consts are
  # moved in front of them.
  body_to_move = [False] * 3 + [True] * body_nconsts + [False] * n_carry
  checked_body_jaxpr = pe.move_binders_to_front(
      checked_body_jaxpr_, body_to_move)

  compat_cond_jaxpr_ = ignore_error_output_jaxpr(checked_cond_jaxpr)
  cond_to_move = [False] * 3 + [True] * cond_nconsts + [False] * n_carry
  compat_cond_jaxpr = pe.move_binders_to_front(
      compat_cond_jaxpr_, cond_to_move)

  # The loop starts from the error produced by the up-front cond evaluation.
  new_in_flat = [*c_consts, *b_consts,
                 cond_err, cond_code, cond_payload,
                 *carry]

  err, code, payload, *out = lax.while_p.bind(
      *new_in_flat,
      cond_nconsts=cond_nconsts,
      cond_jaxpr=compat_cond_jaxpr,
      body_nconsts=body_nconsts,
      body_jaxpr=checked_body_jaxpr)
  merged_msgs = {**error.msgs, **msgs_body, **msgs_cond}

  return out, Error(err, code, merged_msgs, payload)

error_checks[lax.while_p] = while_loop_error_check
|
2021-12-08 15:07:43 +00:00
|
|
|
|
2022-07-29 17:27:40 +01:00
|
|
|
|
|
|
|
def pjit_error_check(error, enabled_errors, *vals_in, jaxpr,
                     in_shardings, out_shardings, resource_env,
                     donated_invars, name,
                     in_positional_semantics, out_positional_semantics):
  """Checkify rule for ``pjit.pjit_p``.

  The jaxpr is passed through ``checkify_jaxpr`` and pjit is re-bound with the
  three error values (err, code, payload) prepended to inputs and outputs.
  Every per-argument parameter gets three matching leading entries.

  Returns:
    A pair ``(vals_out, error)``: the pjit outputs with the three leading
    error values stripped off, and an ``Error`` rebuilt from them with the
    messages returned by ``checkify_jaxpr``.
  """
  checked_jaxpr, msgs = checkify_jaxpr(jaxpr, error, enabled_errors)
  new_vals_in = [error.err, error.code, error.payload, *vals_in]

  # The error values use a replicated sharding over the physical mesh devices.
  mesh_devices = list(resource_env.physical_mesh.devices.flat)
  sharding = OpShardingSharding.get_replicated(mesh_devices)
  err_shardings = (sharding,) * 3
  new_in_shardings = (*err_shardings, *in_shardings)
  new_out_shardings = (*err_shardings, *out_shardings)

  pos_sem = (maps._PositionalSemantics.GLOBAL if config.jax_array
             else maps._positional_semantics.val)
  err_pos_sems = (pos_sem,) * 3

  # A single non-iterable value is wrapped into a one-element tuple.
  if not isinstance(in_positional_semantics, Iterable):
    in_positional_semantics = (in_positional_semantics,)
  if not isinstance(out_positional_semantics, Iterable):
    out_positional_semantics = (out_positional_semantics,)
  new_positional_sems_in = (*err_pos_sems, *in_positional_semantics)
  new_positional_sems_out = (*err_pos_sems, *out_positional_semantics)
  # The three added error inputs are not donated.
  new_donated_invars = (False, False, False, *donated_invars)

  err, code, payload, *vals_out = pjit.pjit_p.bind(
      *new_vals_in,
      jaxpr=checked_jaxpr,
      in_shardings=new_in_shardings,
      out_shardings=new_out_shardings,
      resource_env=resource_env,
      donated_invars=new_donated_invars,
      name=name,
      in_positional_semantics=new_positional_sems_in,
      out_positional_semantics=new_positional_sems_out)
  return vals_out, Error(err, code, msgs, payload)

error_checks[pjit.pjit_p] = pjit_error_check
|
|
|
|
|
|
|
|
|
2022-01-10 21:29:12 -08:00
|
|
|
|
2022-09-08 08:58:44 -07:00
|
|
|
def assert_discharge_rule(error, enabled_errors, err, code, payload, *, msgs):
  """Checkify rule for ``assert_p``: fold the assert into the running error.

  Args:
    error: the incoming error value.
    enabled_errors: collection of enabled ``ErrorCategory`` values.
    err: error flag of this assert.
    code: error code of this assert.
    payload: payload of this assert.
    msgs: messages of this assert, merged into the incoming messages.

  Returns:
    A pair ``([], error)``. When ``ErrorCategory.USER_CHECK`` is disabled the
    incoming error is returned unchanged. Otherwise the flag is the OR of both
    flags, and code and payload are taken from the incoming error if it was
    already set, else from this assert.
  """
  if ErrorCategory.USER_CHECK not in enabled_errors:
    return [], error

  out_err = error.err | err
  # Code and payload must come from the same error. Selecting only the code
  # would pair an earlier error's message with this assert's payload.
  out_code = lax.select(error.err, error.code, code)
  out_payload = lax.select(error.err, error.payload, payload)
  return [], Error(out_err, out_code, {**error.msgs, **msgs}, out_payload)

error_checks[assert_p] = assert_discharge_rule
|
|
|
|
|
|
|
|
|
|
|
|
## checkify api
|
|
|
|
|
2022-01-10 12:25:47 +00:00
|
|
|
# Categories of run-time errors that checkify can report.
class ErrorCategory(enum.Enum):
  NAN = 1
  OOB = 2
  DIV = 3
  USER_CHECK = 4


# Single-category sets.
user_checks = frozenset([ErrorCategory.USER_CHECK])
nan_checks = frozenset([ErrorCategory.NAN])
index_checks = frozenset([ErrorCategory.OOB])
div_checks = frozenset([ErrorCategory.DIV])
# Unions of the single-category sets.
float_checks = nan_checks.union(div_checks)
automatic_checks = float_checks.union(index_checks)
all_checks = automatic_checks.union(user_checks)

Out = TypeVar('Out')
|
2022-01-10 12:25:47 +00:00
|
|
|
|
|
|
|
|
|
|
|
def checkify(fun: Callable[..., Out],
             errors: FrozenSet[ErrorCategory] = user_checks
             ) -> Callable[..., Tuple[Error, Out]]:
  """Functionalize `check` calls in `fun`, and optionally add run-time error checks.

  Run-time errors are either user-added :func:`~check` assertions, or
  automatically added checks like NaN checks, depending on the ``errors``
  argument.

  The returned function will return an Error object `err` along with the output
  of the original function. ``err.get()`` will either return ``None`` (if no
  error occurred) or a string containing an error message. This error message
  will correspond to the first error which occurred. ``err.throw()`` will raise
  a ValueError with the error message if an error occurred.

  By default only user-added :func:`~check` assertions are enabled. You can
  enable automatic checks through the ``errors`` argument.

  The automatic check sets which can be enabled, and when an error is generated:
    - ``user_checks``: a :func:`~check` evaluated to False.
    - ``nan_checks``: a floating-point operation generated a NaN value
      as output.
    - ``div_checks``: a division by zero.
    - ``index_checks``: an index was out-of-bounds.

  Multiple categories can be enabled together by creating a `Set` (eg.
  ``errors={ErrorCategory.NAN, ErrorCategory.OOB}``). Multiple sets can be
  re-combined (eg. ``errors=float_checks|user_checks``)

  Args:
    fun: Callable which can contain user checks (see :func:`~check`).
    errors: A set of ErrorCategory values which defines the set of enabled
      checks. By default only explicit ``checks`` are enabled
      (``user_checks``). You can also for example enable NAN and
      DIV errors by passing the ``float_checks`` set, or for
      example combine multiple sets through set operations
      (``float_checks | user_checks``)
  Returns:
    A function which accepts the same arguments as ``fun`` and returns as output
    a pair where the first element is an ``Error`` value, representing the first
    failed :func:`~check`, and the second element is the original output of ``fun``.

  For example:

    >>> import jax
    >>> import jax.numpy as jnp
    >>> from jax.experimental import checkify
    >>>
    >>> @jax.jit
    ... def f(x):
    ...   y = jnp.sin(x)
    ...   return x+y
    >>> err, out = checkify.checkify(f, errors=checkify.float_checks)(jnp.inf)
    >>> err.throw()  # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
      ...
    ValueError: nan generated by primitive sin

  """
  @traceback_util.api_boundary
  def checked_fun(*args, **kwargs):
    # Flatten (args, kwargs) so `fun` can be run on a flat list of leaves.
    flat_args, in_tree = tree_flatten((args, kwargs))
    flat_fun, out_tree = flatten_fun(lu.wrap_init(fun), in_tree)
    checked_outs, msgs = checkify_flat(flat_fun, errors, *flat_args)
    err, code, payload, out_flat = checked_outs
    # Rebuild the original output structure from the flat outputs.
    result = tree_unflatten(out_tree(), out_flat)
    return Error(err, code, msgs, payload), result
  return checked_fun
|