mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00

I initially wanted to upgrade to 1.15, but it seems to have a bug in how ternary expressions are type-checked. For example, given `def f(x: int) -> str: ...` and `def g(x: int) -> str: ...`, the assignment `callback = f if ... else g` gives `callback` the type `object`!
1403 lines
56 KiB
Python
1403 lines
56 KiB
Python
# Copyright 2021 The JAX Authors.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
from __future__ import annotations
|
|
|
|
from collections.abc import Callable, Sequence
|
|
import dataclasses
|
|
import functools
|
|
import itertools as it
|
|
from typing import TypeVar, Any, Union
|
|
|
|
import numpy as np
|
|
|
|
import jax.numpy as jnp
|
|
from jax import dtypes
|
|
from jax import lax
|
|
|
|
from jax.experimental import shard_map
|
|
from jax._src import api
|
|
from jax._src import api_util
|
|
from jax._src import ad_checkpoint
|
|
from jax._src import linear_util as lu
|
|
from jax._src import config
|
|
from jax._src import core
|
|
from jax._src import custom_derivatives
|
|
from jax._src import effects
|
|
from jax._src import pjit
|
|
from jax._src import mesh as mesh_lib
|
|
from jax._src import sharding_impls
|
|
from jax._src import source_info_util
|
|
from jax._src import traceback_util
|
|
from jax._src import tree_util as jtu
|
|
from jax._src.ad_util import SymbolicZero
|
|
from jax._src.interpreters import ad
|
|
from jax._src.interpreters import batching
|
|
from jax._src.interpreters import mlir
|
|
from jax._src.interpreters import partial_eval as pe
|
|
from jax._src.tree_util import tree_flatten
|
|
from jax._src.tree_util import tree_map
|
|
from jax._src.tree_util import tree_unflatten
|
|
from jax._src.typing import Array
|
|
from jax._src.util import (as_hashable_function, split_list, safe_map, safe_zip,
|
|
unzip3, weakref_lru_cache, HashableWrapper)
|
|
|
|
# Frames from this file are implementation details; keep them out of the
# user-facing source info and the filtered tracebacks.
source_info_util.register_exclusion(__file__)
traceback_util.register_exclusion(__file__)

# Keep the builtins reachable as `unsafe_*`, then shadow them for the rest of
# the module with the `safe_*` variants imported from jax._src.util.
unsafe_map = map
unsafe_zip = zip
map = safe_map
zip = safe_zip

# Type aliases used in annotations throughout this module.
Bool = Union[bool, Array]
Int = Union[int, Array]
ErrorCategory = type['JaxException']
Payload = list[Union[np.ndarray, Array]]
PyTreeDef = jtu.PyTreeDef
Out = TypeVar('Out')
|
|
|
|
## Utils
|
|
|
|
def popattr(obj, attrname):
  """Deletes attribute `attrname` from `obj` and returns its previous value.

  Raises AttributeError (from getattr) if the attribute does not exist.
  """
  popped = getattr(obj, attrname)
  delattr(obj, attrname)
  return popped
|
|
|
|
def setnewattr(obj, name, val):
  """Sets `obj.<name> = val`, asserting that the attribute does not exist yet."""
  missing = object()
  current = getattr(obj, name, missing)
  assert current is missing
  setattr(obj, name, val)
|
|
|
|
# Concrete errors
|
|
|
|
class JaxException(Exception):
  """Python exception which can contain an error message with JAX run-time info.

  Every subclass is automatically registered as a pytree node class, so an
  exception can be flattened into (payload leaves, static metadata) and later
  rebuilt with `tree_unflatten`.
  """

  def __init__(self, traceback_info):
    self.traceback_info = traceback_info
    # TODO(lenamartens): re-enable tracebacks when they don't leak tracers.
    # self.with_traceback(self.traceback_info)

  def __init_subclass__(cls, **kwargs):
    # Chain to the next class in the MRO so mixins that customize subclass
    # creation keep working; then register the subclass as a pytree node.
    super().__init_subclass__(**kwargs)
    jtu.register_pytree_node_class(cls)

  def tree_flatten(self):
    # No array leaves: the traceback info is static metadata.
    return ([], self.traceback_info)

  @classmethod
  def tree_unflatten(cls, metadata, payload):
    del payload
    return cls(metadata)

  def get_effect_type(self) -> ErrorEffect:
    """Returns the ErrorEffect describing this error; subclasses override."""
    raise NotImplementedError
|
|
|
|
|
|
@functools.total_ordering
@dataclasses.dataclass(eq=True, frozen=True)
class ErrorEffect(effects.Effect):
  """Effect identifying an error type together with its payload signature."""
  error_type: type[JaxException]
  shape_dtypes: tuple[api.ShapeDtypeStruct, ...]

  def __lt__(self, other: ErrorEffect):
    def sort_key(effect):
      # dtypes are not orderable, so compare their string forms instead.
      payload_sig = tuple((sd.shape, str(sd.dtype))
                          for sd in effect.shape_dtypes)
      return str(effect.error_type), payload_sig
    return sort_key(self) < sort_key(other)


# ErrorEffects may appear inside control flow and are lowerable.
effects.control_flow_allowed_effects.add_type(ErrorEffect)
effects.lowerable_effects.add_type(ErrorEffect)
|
|
|
|
class DivisionByZeroError(JaxException):
  """Error recorded when a checked division has a zero divisor."""

  def get_effect_type(self):
    # This error carries no runtime payload.
    return ErrorEffect(DivisionByZeroError, ())

  def __str__(self):
    return 'division by zero'
|
|
|
|
class NaNError(JaxException):
  """Error recorded when a checked primitive produces a NaN."""

  def __init__(self, traceback_info, primitive_name):
    super().__init__(traceback_info)
    self.prim = primitive_name

  def __str__(self):
    return f'nan generated by primitive: {self.prim}.'

  def get_effect_type(self):
    return ErrorEffect(NaNError, ())

  def tree_flatten(self):
    # No array leaves; traceback and primitive name are static metadata.
    static = (self.traceback_info, self.prim)
    return ([], static)

  @classmethod
  def tree_unflatten(cls, metadata, _):
    traceback_info, prim = metadata
    return cls(traceback_info, prim)
|
|
|
|
class OOBError(JaxException):
  """Error recorded for out-of-bounds indexing.

  The payload is a length-3 int32 array: (offending index, axis, axis size).
  """

  def __init__(self, traceback_info, primitive_name, operand_shape, payload):
    super().__init__(traceback_info)
    self.prim = primitive_name
    self.operand_shape = operand_shape
    self._payload = payload

  def tree_flatten(self):
    leaves = [self._payload]
    static = (self.traceback_info, self.prim, self.operand_shape)
    return (leaves, static)

  @classmethod
  def tree_unflatten(cls, metadata, payload):
    return cls(*metadata, payload[0])

  def __str__(self):
    index, axis, size = self._payload[0], self._payload[1], self._payload[2]
    return (f'out-of-bounds indexing for array of '
            f'shape {self.operand_shape}: '
            f'index {index} is out of bounds for axis '
            f'{axis} with size {size}. ')

  def get_effect_type(self):
    return ErrorEffect(OOBError, (api.ShapeDtypeStruct((3,), jnp.int32),))
|
|
|
|
class FailedCheckError(JaxException):
  """Error recorded when a user `check` predicate fails.

  The message is `fmt_string` formatted with the positional and keyword
  arguments, which are carried as pytree leaves (the runtime payload).
  """

  def __init__(self, traceback_info, fmt_string, *a, **k):
    super().__init__(traceback_info)
    self.fmt_string = fmt_string
    self.args = a
    self.kwargs = k

  def tree_flatten(self):
    leaves = (self.args, self.kwargs)
    treedef = (self.traceback_info, self.fmt_string)
    return leaves, treedef

  @classmethod
  def tree_unflatten(cls, metadata, payload):
    traceback_info, fmt_string = metadata
    args, kwargs = payload
    return cls(traceback_info, fmt_string, *args, **kwargs)

  def __str__(self):
    message = self.fmt_string.format(*self.args, **self.kwargs)
    return message + ' (`check` failed)'

  def get_effect_type(self):
    payload_leaves = jtu.tree_leaves((self.args, self.kwargs))
    shape_dtypes = tuple(api.ShapeDtypeStruct(leaf.shape, leaf.dtype)
                         for leaf in payload_leaves)
    return ErrorEffect(FailedCheckError, shape_dtypes)
|
|
|
|
@dataclasses.dataclass
class BatchedError(JaxException):
  """Per-index errors from a batched error value.

  `error_mapping` maps a batch index tuple to the error raised at that index.
  """
  error_mapping: dict[tuple[int, ...], JaxException]

  def __post_init__(self):
    # The first recorded error's traceback stands in for the whole batch.
    first_error = [*self.error_mapping.values()][0]
    super().__init__(first_error.traceback_info)

  def __str__(self):
    lines = []
    for idx, err in self.error_mapping.items():
      index_str = ", ".join(map(str, idx))
      lines.append(f'at mapped index {index_str}: {err}')
    return '\n'.join(lines)
|
|
|
|
|
|
# Error Value
|
|
|
|
@jtu.register_pytree_node_class
@dataclasses.dataclass(frozen=True)
class Error:
  """Functional error value threaded through a checkified computation.

  `_pred`, `_code` and `_payload` are keyed by ErrorEffect; `_metadata` maps
  each error code to the treedef needed to rebuild its JaxException.
  """
  _pred: dict[ErrorEffect, Bool]  # whether an error of this effect occurred
  _code: dict[ErrorEffect, Int]  # code of the error recorded for this effect
  _metadata: dict[Int, PyTreeDef]  # mapping of code to JaxException treedef.
  _payload: dict[ErrorEffect, Payload]  # runtime payload leaves per effect

  def get(self) -> str | None:
    """Returns error message if error happened, None if no error happened."""
    exc = self.get_exception()
    return None if exc is None else str(exc)

  def get_exception(self) -> JaxException | None:
    """Returns Python exception if error happened, None if no error happened."""
    if any(map(np.shape, self._pred.values())):
      # Non-scalar predicates: the error value is batched (e.g. by vmap).
      return self._get_batched_exception()
    best_code = None
    best_effect = None
    for effect, code in self._code.items():
      if not self._pred[effect]:
        continue
      # Codes come from an increasing counter; report the smallest one.
      if best_code is None or code < best_code:
        best_code, best_effect = code, effect
    if best_effect is None:
      return None
    return tree_unflatten(self._metadata[int(best_code)],  # type: ignore
                          self._payload[best_effect])

  def throw(self):
    _check_error(self)

  def __str__(self):
    return f'Error({self.get()})'

  # Internal helpers

  def _get_batched_exception(self) -> BatchedError | None:
    batch_shape = np.shape(list(self._pred.values())[0])
    per_index_errors = {}
    for idx in np.ndindex(*batch_shape):
      best_code = None
      best_effect = None
      for effect, codes in self._code.items():
        if not self._pred[effect][idx]:  # type: ignore
          continue
        if best_code is None or codes[idx] < best_code:  # type: ignore
          best_code, best_effect = codes[idx], effect  # type: ignore
      if best_effect is None:
        continue
      payload = tree_map(lambda x, i=idx: x[i], self._payload[best_effect])
      per_index_errors[idx] = tree_unflatten(
          self._metadata[int(best_code)], payload)  # type: ignore
    return BatchedError(per_index_errors) if per_index_errors else None

  def _update(self, effect_type: ErrorEffect, pred, code, metadata, payload):
    return Error({**self._pred, effect_type: pred},
                 {**self._code, effect_type: code},
                 {**self._metadata, **metadata},
                 {**self._payload, effect_type: payload})

  def _add_placeholder_effects(self, effects: set[ErrorEffect]):
    """Fill out Error with `effects` and jnp.ones arrays of their payloads."""
    preds = dict(self._pred)
    codes = dict(self._code)
    payloads = dict(self._payload)
    for effect in effects:
      if effect in self._pred:
        continue
      preds[effect] = False
      payloads[effect] = list(
          tree_map(lambda sd: jnp.ones(sd.shape, sd.dtype), effect.shape_dtypes))
      # This predicate is always False, so the code is never read.
      codes[effect] = -1
    return Error(preds, codes, self._metadata, payloads)

  def _replace(self, *args, **kwargs):
    return dataclasses.replace(self, *args, **kwargs)

  # PyTree methods

  def tree_flatten(self):
    children = (self._pred, self._code, self._payload)
    return children, self._metadata

  @classmethod
  def tree_unflatten(cls, metadata, data):
    pred, code, payload = data
    return cls(_pred=pred, _code=code, _metadata=metadata, _payload=payload)
|
|
|
|
# The empty error value every checkified computation starts from.
init_error = Error(_pred={}, _code={}, _metadata={}, _payload={})
# Globally unique error codes handed out in increasing order (could be uuid4).
next_code = functools.partial(next, it.count(1))
|
|
|
|
def assert_func(error: Error, pred: Bool, new_error: JaxException) -> Error:
  """Records `new_error` in `error`, conditioned on `pred`.

  A fresh code is allocated for `new_error`, and its treedef is stored as
  metadata under that code so the exception can be rebuilt from its payload.
  """
  code = next_code()
  effect = new_error.get_effect_type()
  payload_leaves, exc_treedef = tree_flatten(new_error)
  return update_error(error, pred, code, {code: exc_treedef}, payload_leaves,
                      effect)
|
|
|
|
def update_error(error, pred, code, metadata, payload, effect_type):
  """Merges a new (pred, code, payload) for `effect_type` into `error`.

  Where an error of this effect type is already set, its code and payload are
  kept (via lax.select); otherwise the new ones are used. The resulting
  predicate is the OR of the old and new predicates.
  """
  already_set = error._pred.get(effect_type, False)
  merged_pred = already_set | pred
  merged_code = lax.select(already_set, error._code.get(effect_type, -1), code)
  old_payload = error._payload.get(effect_type, None)
  if old_payload is None:
    merged_payload = payload
  else:
    merged_payload = tree_map(functools.partial(lax.select, already_set),
                              old_payload, payload)
  return error._update(effect_type, merged_pred, merged_code, metadata,
                       merged_payload)
|
|
|
|
|
|
## Checkify transformation for plumbing functional error values.
|
|
|
|
@lu.transformation_with_aux2
def _flatten_and_get_error_metadata_thunk(f, store, *invals):
  """Flattens `(error, out)` from `f`; stores its treedef and error effects."""
  err, result = f(*invals)
  flat_vals, flat_tree = jtu.tree_flatten((err, result))
  store.store((flat_tree, set(err._pred)))
  return flat_vals
|
|
|
|
def default_checkify_rule(primitive: core.Primitive, error: Error,
                          enabled_errors, *invals: core.Value,
                          **params: Any) -> tuple[Error, Sequence[core.Value]]:
  """Default rule for primitives in `checkify` interpreter.

  Primitives without a `call_jaxpr` param are bound unchanged and leave the
  error untouched. Call and map primitives are re-bound on a checkified
  version of their body, with the flattened error values prepended to both
  the inputs and the outputs.
  """
  if 'call_jaxpr' not in params:
    # Default non-HOP case: just call primitive and don't update error.
    return error, primitive.bind(*invals, **params)

  # Code below handles call- and map-primitives, by recursively calling
  # checkify_jaxpr.
  err_vals, err_tree = jtu.tree_flatten(error)
  num_error_vals = len(err_vals)
  if 'donated_invars' in params:
    # The prepended error inputs are never donated.
    params = dict(params, donated_invars=(*[False]*num_error_vals,
                                          *params['donated_invars']))

  # call_jaxpr handling
  call_jaxpr = params.pop('call_jaxpr')
  if isinstance(call_jaxpr, core.ClosedJaxpr):  # handle closed_call_p
    jaxpr, consts = call_jaxpr.jaxpr, call_jaxpr.consts
  else:
    jaxpr, consts = call_jaxpr, ()
  # Consts are wrapped so that the partial application below is hashable.
  consts_ = tuple(HashableWrapper(c) for c in consts)
  partial_checkify = lu.hashable_partial(
      lu.wrap_init(checkify_jaxpr_flat_hashable, debug_info=jaxpr.debug_info),
      jaxpr, consts_, enabled_errors, err_tree)
  # `metadata()` yields (out_tree, error effects) once the body has run.
  partial_checkify, metadata = _flatten_and_get_error_metadata_thunk(
      partial_checkify)

  # map-specific params handling.
  if isinstance(primitive, core.MapPrimitive):
    # Update `in_axes` and `out_axes_thunk` params for map primitive.
    out_val_axes = params.pop('out_axes')

    @as_hashable_function(closure=out_val_axes)
    def out_axes_thunk():
      # Error outputs come first and are mapped along axis 0. Their count is
      # only known after the body has been traced, hence the thunk.
      out_err_num = metadata()[0].num_leaves - len(out_val_axes)
      return (*(0,)*out_err_num, *out_val_axes)

    # Error inputs are passed unmapped (in_axes=None).
    params = dict(params,
                  in_axes=(*(None,)*num_error_vals, *params['in_axes']),
                  out_axes_thunk=out_axes_thunk)

  all_vals = primitive.bind(partial_checkify, *err_vals, *invals, **params)

  out_tree, _ = metadata()
  error, out_vals = tree_unflatten(out_tree, all_vals)
  if isinstance(primitive, core.MapPrimitive):
    # The error comes back mapped along axis 0; reduce it to a single error.
    error = _reduce_any_error(error)
  return error, out_vals
|
|
|
|
def checkify_jaxpr(jaxpr: core.ClosedJaxpr, enabled_errors,
                   error: Error, *args) -> tuple[Error, list[core.Value]]:
  """Interprets closed `jaxpr` on `args`, threading `error` through it."""
  flat_error, error_tree = jtu.tree_flatten(error)
  return checkify_jaxpr_flat(jaxpr.jaxpr, jaxpr.consts, enabled_errors,
                             error_tree, *flat_error, *args)
|
|
|
|
def checkify_jaxpr_flat(jaxpr: core.Jaxpr, consts: Sequence[core.Value],
                        enabled_errors, err_tree: PyTreeDef,
                        *args: core.Value) -> tuple[Error, list[Any]]:
  """Interprets `jaxpr`, applying each primitive's checkify rule.

  `args` holds the flattened error leaves (described by `err_tree`) followed
  by the jaxpr's own inputs. Returns the final error and the jaxpr outputs.
  """
  env: dict[core.Var, Any] = {}
  flat_err, jaxpr_args = split_list(args, [err_tree.num_leaves])
  error = jtu.tree_unflatten(err_tree, flat_err)

  last_used = core.last_used(jaxpr)

  def read(atom: core.Atom):
    # Literals carry their value inline; variables are looked up in env.
    return atom.val if isinstance(atom, core.Literal) else env[atom]

  def write(var: core.Var, val: Any):
    env[var] = val

  map(write, jaxpr.constvars, consts)
  map(write, jaxpr.invars, jaxpr_args)

  for eqn in jaxpr.eqns:
    invals = map(read, eqn.invars)
    rule = error_checks.get(
        eqn.primitive, functools.partial(default_checkify_rule, eqn.primitive))
    name_stack = source_info_util.current_name_stack() + eqn.source_info.name_stack
    with source_info_util.user_context(eqn.source_info.traceback,
                                       name_stack=name_stack):
      error, outvals = rule(error, enabled_errors, *invals, **eqn.params)
    if not eqn.primitive.multiple_results:
      outvals = [outvals]
    map(write, eqn.outvars, outvals)
    # Drop values no later equation reads.
    core.clean_up_dead_vars(eqn, env, last_used)

  return error, map(read, jaxpr.outvars)
|
|
|
|
def checkify_jaxpr_flat_hashable(jaxpr, hashable_consts, enabled_errors,
                                 err_tree, *args):
  """Like `checkify_jaxpr_flat`, but takes consts wrapped in HashableWrapper."""
  unwrapped_consts = tuple(wrapped.x for wrapped in hashable_consts)
  return checkify_jaxpr_flat(jaxpr, unwrapped_consts, enabled_errors,
                             err_tree, *args)
|
|
|
|
@lu.transformation_with_aux2
def flatten_fun_output(f, store, *args):
  """Flattens the output of `f`, storing its treedef in `store`."""
  flat_out, out_tree = tree_flatten(f(*args))
  store.store(out_tree)
  return flat_out
|
|
|
|
|
|
def _reduce_any_error(error: Error):
  """Reduces an error whose leaves carry a mapped axis to a single error.

  For each effect, one index is chosen where the predicate is set (if any),
  and the pred/code/payload at that index are kept.
  """
  reduced = init_error
  for effect in error._pred.keys():
    preds = error._pred[effect]
    codes = error._code[effect]
    payloads = error._payload[effect]
    # argsort orders False before True, so the last index holds a set
    # predicate whenever one exists.
    pick = jnp.argsort(preds)[-1]
    pred, code, payload = tree_map(lambda leaf, i=pick: leaf[i],
                                   (preds, codes, payloads))
    reduced = reduced._update(effect, pred, code, {}, payload)
  return reduced._replace(_metadata=error._metadata)
|
|
|
|
## check_p primitive
|
|
|
|
check_p = core.Primitive('check')
check_p.multiple_results = True  # zero results


def _pp_check(eqn, context, settings) -> core.pp.Doc:
  """Pretty-prints a `check` equation, omitting its `err_tree` param."""
  if settings.source_info:
    annotation = source_info_util.summarize(eqn.source_info)
  else:
    annotation = None
  if settings.name_stack:
    name_stack_annotation = f'[{eqn.source_info.name_stack}]'
  else:
    name_stack_annotation = None
  shown_params = sorted(
      (key, val) for key, val in eqn.params.items() if key != "err_tree")
  pieces = [
      core.pp.text(eqn.primitive.name, annotation=name_stack_annotation),
      core.pp_kv_pairs(shown_params, context, settings),
      core.pp.text(" ") + core.pp_vars(eqn.invars, context),
  ]
  return core.pp.concat([core.pp.text("", annotation), *pieces])


core.pp_eqn_rules[check_p] = _pp_check
|
|
|
|
# TODO(lenamartens): inherit from Exception instead of ValueError.
class JaxRuntimeError(ValueError):
  """Raised when a functionalized checkify error is found to be set."""
|
|
|
|
@check_p.def_impl
def check_impl(*args, err_tree, debug):
  """Impl rule for `check`: raises JaxRuntimeError if the error is set."""
  if debug:
    # NOOP (check will only trigger when discharged)
    return []
  exc = tree_unflatten(err_tree, args).get_exception()
  if not exc:
    return []
  filtered_tb = traceback_util.filter_traceback(
      exc.traceback_info.as_python_traceback())
  exc.with_traceback(filtered_tb)
  raise JaxRuntimeError(str(exc)) from exc
|
|
|
|
@check_p.def_effectful_abstract_eval
def check_abstract_eval(*args, err_tree, debug):
  """Abstract eval: no outputs; effects are the error's ErrorEffects."""
  del debug
  error = tree_unflatten(err_tree, args)
  return [], set(error._pred)
|
|
|
|
# TODO(lenamartens) add in-depth error explanation to link to in module docs.
functionalization_error = ValueError(' '.join([
    'Cannot abstractly evaluate a checkify.check which was not',
    'functionalized. This probably means you tried to stage',
    '(jit/scan/pmap/...) a `check` without functionalizing it',
    'through `checkify.checkify`.',
]))
|
|
|
|
def check_lowering_rule(ctx, *args, err_tree, debug):
  """Lowers `check` to a side-effecting Python callback (`python_err`).

  Raises `functionalization_error` unless `config.xla_runtime_errors` is set.
  """
  if debug:
    # NOOP (check will only trigger when discharged)
    return []
  if not config.xla_runtime_errors.value:
    raise functionalization_error

  raise_on_error = functools.partial(python_err, err_tree)
  out_op, _, _ = mlir.emit_python_callback(
      ctx,
      callback=raise_on_error,
      token=None,
      operands=args,
      operand_avals=list(ctx.avals_in),
      result_avals=list(ctx.avals_out),
      has_side_effect=True,
  )
  return out_op
|
|
|
|
def check_lowering_rule_unsupported(*a, debug, **k):
  """Lowering where runtime checks are unsupported: only debug checks pass."""
  if not debug:
    raise functionalization_error
  return []
|
|
|
|
def python_err(err_tree, *args):
  """Host callback: rebuilds the error from `args` and passes it to `_check_error`."""
  _check_error(tree_unflatten(err_tree, args))
  return []
|
|
|
|
# TPU: only debug checks can be lowered. CPU/GPU: lowered to a Python
# callback (when config.xla_runtime_errors is enabled).
mlir.register_lowering(check_p, check_lowering_rule_unsupported, platform='tpu')
mlir.register_lowering(check_p, check_lowering_rule, platform='cpu')
mlir.register_lowering(check_p, check_lowering_rule, platform='gpu')
|
|
|
|
def check_batching_rule(batched_args, batch_dims, *, err_tree, debug):
  """Batching rule for `check`.

  Moves every batch dimension to the front (broadcasting unmapped inputs) and
  hands the resulting batched error to `_check_error`. `check` has no outputs.
  """
  axis_size = next(arg.shape[bdim]
                   for arg, bdim in zip(batched_args, batch_dims)
                   if bdim is not batching.not_mapped)
  moved = (batching.bdim_at_front(arg, bdim, axis_size)
           for arg, bdim in zip(batched_args, batch_dims))
  _check_error(tree_unflatten(err_tree, moved), debug=debug)
  return [], []

batching.primitive_batchers[check_p] = check_batching_rule
|
|
|
|
def check_jvp_rule(primals, _, *, err_tree, debug):
  """JVP rule for `check`: re-binds it on the primals; tangents are dropped."""
  check_p.bind(*primals, err_tree=err_tree, debug=debug)
  no_primals_out, no_tangents_out = [], []
  return no_primals_out, no_tangents_out

ad.primitive_jvps[check_p] = check_jvp_rule
|
|
|
|
## checkify rules
|
|
|
|
# Rule signature:
#   (Error, FrozenSet[ErrorCategory], *in_vals, **params) -> (Error, Any)
# i.e. a rule returns the updated error followed by the primitive's outputs.
ErrorCheckRule = Callable
# Per-primitive checkify rules; primitives missing here fall back to
# `default_checkify_rule`.
error_checks: dict[core.Primitive, ErrorCheckRule] = {}
|
|
|
|
|
|
def get_traceback():
  """Returns the traceback of the current source-info context."""
  current_info = source_info_util.current()
  return current_info.traceback
|
|
|
|
def nan_error_check(prim, error, enabled_errors, *in_vals, **params):
  """Binds `prim` and records a NaNError if its output contains NaNs."""
  out = prim.bind(*in_vals, **params)
  return check_nans(prim, error, enabled_errors, out), out
|
|
|
|
def check_nans(prim, error, enabled_errors, out):
  """Records a NaNError for `prim` if any value in `out` is NaN.

  Returns `error` unchanged when NaNError is not enabled. PRNG key arrays are
  treated as never containing NaNs.
  """
  if NaNError not in enabled_errors:
    return error

  def has_nan(x):
    if jnp.issubdtype(x.dtype, dtypes.prng_key):
      return False
    return jnp.any(jnp.isnan(x))

  if prim.multiple_results:
    any_nans = jnp.any(jnp.array([has_nan(x) for x in out]))
  else:
    any_nans = has_nan(out)
  return assert_func(error, any_nans, NaNError(get_traceback(), prim.name))
|
|
|
|
|
|
# All primitives which can generate a NaN. Each one gets `nan_error_check`
# registered as its checkify rule below.
nan_primitives = [
    lax.acos_p, lax.acosh_p, lax.add_p, lax.asin_p, lax.asinh_p, lax.atan2_p,
    lax.atan_p, lax.atanh_p, lax.bessel_i0e_p, lax.bessel_i1e_p, lax.cbrt_p,
    lax.conv_general_dilated_p, lax.cos_p, lax.cosh_p, lax.cumlogsumexp_p,
    lax.cummax_p, lax.cummin_p, lax.cumprod_p, lax.cumsum_p, lax.digamma_p,
    lax.dot_general_p, lax.erf_inv_p, lax.erf_p, lax.erfc_p, lax.exp_p,
    lax.expm1_p, lax.fft_p, lax.igamma_grad_a_p, lax.igamma_p, lax.igammac_p,
    lax.integer_pow_p, lax.lgamma_p, lax.linear_solve_p, lax.log1p_p,
    lax.log_p, lax.logistic_p, lax.mul_p, lax.pad_p, lax.pow_p, lax.psum_p,
    lax.random_gamma_grad_p, lax.reduce_p, lax.reduce_prod_p,
    lax.reduce_sum_p, lax.reduce_window_p, lax.reduce_window_sum_p,
    lax.regularized_incomplete_beta_p, lax.rem_p, lax.rng_uniform_p,
    lax.rsqrt_p, lax.sin_p, lax.sinh_p, lax.sqrt_p, lax.sub_p, lax.tan_p,
    lax.tanh_p,
]

for _prim in nan_primitives:
  error_checks[_prim] = functools.partial(nan_error_check, _prim)
|
|
|
|
|
|
def dynamic_slice_error_check(error, enabled_errors, operand, *start_indices, slice_sizes):
  """Binds dynamic_slice and records an OOBError for out-of-bounds starts.

  A start index is out of bounds if it is negative or if the slice starting
  there would extend past the end of its axis.
  """
  out = lax.dynamic_slice_p.bind(operand, *start_indices, slice_sizes=slice_sizes)

  if OOBError not in enabled_errors or not start_indices:
    # Nothing to check (a 0-d operand has no start indices).
    return error, out

  # Largest valid start index along each axis. Comparing against this bound,
  # instead of computing `start + size`, cannot overflow for huge indices.
  max_start = np.array(operand.shape) - np.array(slice_sizes)
  start_indices = jnp.array(start_indices)
  oob_mask = (start_indices < 0) | (start_indices > max_start)

  payload = oob_payload(oob_mask, start_indices, range(operand.ndim), operand.shape)
  error = assert_func(error, jnp.any(oob_mask),
                      OOBError(get_traceback(), "dynamic_slice", operand.shape, payload))
  return error, out

error_checks[lax.dynamic_slice_p] = dynamic_slice_error_check
|
|
|
|
def dynamic_update_slice_error_check(error, enabled_errors, operand, update, *start_indices):
  """Binds dynamic_update_slice and records an OOBError for OOB starts.

  A start index is out of bounds if it is negative or if the update placed
  there would extend past the end of its axis.
  """
  out = lax.dynamic_update_slice_p.bind(operand, update, *start_indices)

  if OOBError not in enabled_errors or not start_indices:
    # Nothing to check (a 0-d operand has no start indices).
    return error, out

  # Largest valid start index along each axis. Comparing against this bound,
  # instead of computing `start + size`, cannot overflow for huge indices.
  max_start = np.array(operand.shape) - np.array(update.shape)
  start_indices = jnp.array(start_indices)
  oob_mask = (start_indices < 0) | (start_indices > max_start)

  payload = oob_payload(oob_mask, start_indices, range(operand.ndim), operand.shape)
  error = assert_func(error, jnp.any(oob_mask),
                      OOBError(get_traceback(), "dynamic_update_slice", operand.shape, payload))
  return error, out

error_checks[lax.dynamic_update_slice_p] = dynamic_update_slice_error_check
|
|
|
|
def gather_error_check(error, enabled_errors, operand, start_indices, *,
                       dimension_numbers, slice_sizes, unique_indices,
                       indices_are_sorted, mode, fill_value):
  """Binds gather and records an OOBError for out-of-bounds start indices.

  A start index is out of bounds if it is negative or larger than the axis
  size minus the slice size along the operand axis it indexes.
  """
  out = lax.gather_p.bind(
      operand, start_indices, dimension_numbers=dimension_numbers,
      slice_sizes=slice_sizes, unique_indices=unique_indices,
      indices_are_sorted=indices_are_sorted, mode=mode, fill_value=fill_value)

  if OOBError not in enabled_errors:
    return error, out

  # compare to OOB masking logic in lax._gather_translation_rule
  dnums = dimension_numbers
  operand_dims = np.array(operand.shape)
  num_batch_dims = len(start_indices.shape) - 1
  # An explicit integer dtype keeps the fancy indexing below valid even when
  # `start_index_map` is empty (np.array(()) would be float64).
  index_map = np.array(dnums.start_index_map, dtype=np.int64)

  upper_bound = operand_dims[index_map] - np.array(slice_sizes)[index_map]
  # Clamp before casting to the index dtype so the bound cannot wrap around
  # for narrow index dtypes (e.g. int8); mirrors the clamp in `scatter_oob`.
  index_max = min(np.iinfo(start_indices.dtype).max, np.iinfo(np.int64).max)
  upper_bound = np.minimum(upper_bound, index_max)
  upper_bound = jnp.expand_dims(upper_bound, axis=tuple(range(num_batch_dims)))
  oob_mask = (start_indices < 0) | (start_indices > upper_bound.astype(start_indices.dtype))

  payload = oob_payload(oob_mask, start_indices, dnums.start_index_map, operand.shape)
  error = assert_func(error, jnp.any(oob_mask),
                      OOBError(get_traceback(), "gather", operand.shape, payload))
  return error, out

error_checks[lax.gather_p] = gather_error_check
|
|
|
|
def div_error_check(error, enabled_errors, x, y):
  """Checks for division by zero and NaN."""
  if DivisionByZeroError in enabled_errors:
    divisor_has_zero = jnp.any(jnp.equal(y, 0))
    zero_div = DivisionByZeroError(get_traceback())
    error = assert_func(error, divisor_has_zero, zero_div)
  # The NaN check (if enabled) is applied on top of the zero check.
  return nan_error_check(lax.div_p, error, enabled_errors, x, y)

error_checks[lax.div_p] = div_error_check
|
|
|
|
def oob_payload(oob_mask, indices, dims_map, operand_shape):
  """Describes the first out-of-bounds index for the OOBError message.

  Args:
    oob_mask: boolean array, True where the matching entry of `indices` is
      out of bounds.
    indices: index array with as many elements as `oob_mask`; positions along
      its last axis correspond to entries of `dims_map`.
    dims_map: maps a position along the last axis of `indices` to an axis of
      the indexed operand.
    operand_shape: shape of the indexed operand.

  Returns:
    An int32 array `[index, axis, axis_size]` for the first (row-major) True
    entry of `oob_mask`. If no entry is True, entry 0 is described instead.
    If `oob_mask` has no elements, all zeros are returned.
  """
  if oob_mask.size == 0:
    # Nothing was indexed, so nothing can be out of bounds. The payload is
    # never displayed, but must keep its (3,) int32 signature; argmin of an
    # empty array would raise.
    return jnp.zeros((3,), dtype=jnp.int32)
  # argmin over the negated mask finds the first True entry of `oob_mask`.
  flat_idx = jnp.argmin(jnp.logical_not(oob_mask))
  multi_idx = jnp.unravel_index(flat_idx, indices.shape)
  oob_axis = jnp.array(dims_map)[multi_idx[-1]]
  oob_axis_size = jnp.array(operand_shape)[oob_axis]
  oob_index = jnp.ravel(indices)[flat_idx]
  return jnp.array([oob_index, oob_axis, oob_axis_size], dtype=jnp.int32)
|
|
|
|
def scatter_oob(operand, indices, updates, dnums):
  """Returns (any scatter index out of bounds?, OOB payload)."""
  # Ref: see clamping code used in scatter_translation_rule
  window_pos = 0
  slice_sizes = []
  for axis in range(len(operand.shape)):
    if axis not in dnums.inserted_window_dims:
      slice_sizes.append(updates.shape[dnums.update_window_dims[window_pos]])
      window_pos += 1
    else:
      slice_sizes.append(1)

  # Largest valid start index per scattered axis, clamped to the index dtype's
  # maximum so the cast below cannot wrap around.
  upper_bound = np.array(
      [operand.shape[axis] - slice_sizes[axis]
       for axis in dnums.scatter_dims_to_operand_dims],
      np.int64)
  upper_bound = np.minimum(upper_bound, np.iinfo(indices.dtype).max)
  upper_bound = lax.broadcast_in_dim(upper_bound, indices.shape,
                                     (len(indices.shape) - 1,))

  oob_mask = jnp.logical_or(
      jnp.less(indices, 0),
      jnp.greater(indices, upper_bound.astype(indices.dtype)))
  payload = oob_payload(oob_mask, indices,
                        dnums.scatter_dims_to_operand_dims, operand.shape)
  return jnp.any(oob_mask), payload
|
|
|
|
def scatter_error_check(prim, error, enabled_errors, operand, indices, updates,
                        *, update_jaxpr, update_consts, dimension_numbers,
                        indices_are_sorted, unique_indices, mode):
  """Checks if indices are within bounds and update does not generate NaN.

  The two checks are independent: the NaN check runs whenever NaNError is
  enabled, even if OOBError is not.
  """
  out = prim.bind(
      operand, indices, updates, update_jaxpr=update_jaxpr,
      update_consts=update_consts, dimension_numbers=dimension_numbers,
      indices_are_sorted=indices_are_sorted, unique_indices=unique_indices,
      mode=mode)

  if OOBError in enabled_errors:
    out_of_bounds, payload = scatter_oob(operand, indices, updates, dimension_numbers)
    oob_error = OOBError(get_traceback(), prim.name, operand.shape, payload)
    error = assert_func(error, out_of_bounds, oob_error)

  # check_nans returns `error` unchanged unless NaNError is enabled.
  error = check_nans(prim, error, enabled_errors, out)
  return error, out

error_checks[lax.scatter_p] = functools.partial(scatter_error_check, lax.scatter_p)
error_checks[lax.scatter_add_p] = functools.partial(scatter_error_check,
                                                    lax.scatter_add_p)
error_checks[lax.scatter_mul_p] = functools.partial(scatter_error_check,
                                                    lax.scatter_mul_p)
error_checks[lax.scatter_min_p] = functools.partial(scatter_error_check,
                                                    lax.scatter_min_p)
error_checks[lax.scatter_max_p] = functools.partial(scatter_error_check,
                                                    lax.scatter_max_p)
|
|
|
|
# HOP error check rules
|
|
|
|
@weakref_lru_cache
def jaxpr_to_checkify_jaxpr(
    jaxpr: core.ClosedJaxpr, enabled_errors, err_tree: PyTreeDef,
    *flat_err_and_in_vals) -> tuple[core.ClosedJaxpr, PyTreeDef, set[ErrorEffect]]:
  """Traces a checkified version of `jaxpr` into a new ClosedJaxpr.

  `flat_err_and_in_vals` are the (abstract) flattened error values followed by
  the jaxpr inputs. Returns the checked jaxpr, the treedef of its
  `(error, outputs)` result, and the set of error effects it may record.
  """
  body = lu.wrap_init(
      functools.partial(checkify_jaxpr_flat, jaxpr.jaxpr, jaxpr.consts,
                        enabled_errors, err_tree),
      debug_info=jaxpr.jaxpr.debug_info)
  body, get_metadata = _flatten_and_get_error_metadata_thunk(body)

  traced, _, consts, () = pe.trace_to_jaxpr_dynamic(body, flat_err_and_in_vals)
  checked = core.ClosedJaxpr(traced, consts)
  out_tree, error_effects = get_metadata()
  return checked, out_tree, error_effects
|
|
|
|
def cond_error_check(error: Error, enabled_errors, index, *ops, branches):
  """Checkifies every branch of a `cond` and binds the checked branches.

  All branches must produce the same error structure, so the incoming error is
  first padded with placeholder entries for every effect any branch may record.
  """
  def checkify_branches(err):
    # Checkify each branch against the flattened structure of `err`.
    flat_err, tree = jtu.tree_flatten(err)
    avals = map(core.get_aval, [*flat_err, *ops])
    checked = [jaxpr_to_checkify_jaxpr(branch, enabled_errors, tree, *avals)
               for branch in branches]
    return flat_err, checked

  # First pass: only collect the error effects of every branch.
  _, first_pass = checkify_branches(error)
  all_effects = set().union(*(effs for _, _, effs in first_pass))
  merged_error = error._add_placeholder_effects(all_effects)

  # Second pass: checkify the branches against the merged error.
  err_vals, checked = checkify_branches(merged_error)
  new_branches, out_trees, _ = unzip3(checked)
  err_and_outs = lax.cond_p.bind(index, *err_vals, *ops,
                                 branches=tuple(new_branches))

  # Each branch can contribute its own code->treedef metadata; merge it all.
  first_err, out = tree_unflatten(out_trees[0], err_and_outs)
  merged_metadata = dict(first_err._metadata)
  for branch_tree in out_trees[1:]:
    branch_err, _ = tree_unflatten(branch_tree, err_and_outs)
    merged_metadata.update(branch_err._metadata)
  return first_err._replace(_metadata=merged_metadata), out

error_checks[lax.cond_p] = cond_error_check
|
|
|
|
def scan_error_check(error, enabled_errors, *in_flat, reverse, length, jaxpr,
                     num_consts, num_carry, linear, unroll, _split_transpose):
  """Checkifies a scan body, carrying the error as extra loop state.

  The error values are inserted between the consts and the original carry.
  Before that, the error is padded with placeholder entries for every effect
  the body may record, so the carried error has the same type on input and
  output of each iteration.
  """
  consts, carry, xs = split_list(in_flat, [num_consts, num_carry])
  xs_mapped = [core.mapped_aval(length, 0, core.get_aval(val)) for val in xs]
  # Query body effects to create a merged error containing all effects (such
  # that in and out carried error are of the same type).
  err_vals, err_tree = jtu.tree_flatten(error)
  new_in_aval = map(core.get_aval, [*err_vals, *consts, *carry]) + xs_mapped
  _, _, effects = jaxpr_to_checkify_jaxpr(jaxpr, enabled_errors,
                                          err_tree, *new_in_aval)

  merged_error = error._add_placeholder_effects(effects)
  err_vals, err_tree = jtu.tree_flatten(merged_error)

  # Create checked-jaxpr, with the needed pre-processing on the inputs.
  new_in_aval = map(core.get_aval, [*err_vals, *consts, *carry]) + xs_mapped
  checked_jaxpr_, out_tree, _ = jaxpr_to_checkify_jaxpr(jaxpr, enabled_errors,
                                                        err_tree, *new_in_aval)

  # The checked body takes (errors, consts, carry, xs); scan wants its consts
  # first, so move the const binders to the front.
  tomove = ([False] * len(err_vals) + [True] * len(consts)
            + [False] * (len(carry) + len(xs)))
  checked_jaxpr = pe.move_binders_to_front(checked_jaxpr_, tomove)
  new_in_flat = [*consts, *err_vals, *carry, *xs]
  # `linear` has one flag per original operand (consts, carry, xs). Insert the
  # (never linear) error flags right after the consts so every flag stays
  # aligned with its operand in `new_in_flat`.
  new_linear = (*linear[:len(consts)], *[False] * len(err_vals),
                *linear[len(consts):])
  err_and_out = lax.scan_p.bind(
      *new_in_flat, reverse=reverse, length=length, jaxpr=checked_jaxpr,
      num_consts=len(consts), num_carry=len(carry)+len(err_vals),
      linear=new_linear, unroll=unroll, _split_transpose=_split_transpose)
  err, out = tree_unflatten(out_tree, err_and_out)
  return err, out

error_checks[lax.scan_p] = scan_error_check
|
|
|
|
def checkify_while_body_jaxpr(
    cond_jaxpr: core.ClosedJaxpr, body_jaxpr: core.ClosedJaxpr,
    enabled_errors, error: Error,
    c_consts_num: int) -> tuple[core.ClosedJaxpr, PyTreeDef, set[ErrorEffect]]:
  """Checkify a while-loop body that also re-evaluates the loop condition.

  The body is wrapped so that, after computing the next carry, it applies
  `cond_jaxpr` to that carry and discards the predicate. Any checks raised by
  the *next* condition evaluation are therefore reported from the body.

  Args:
    cond_jaxpr: the loop's condition jaxpr; its first `c_consts_num` inputs
      are its constants.
    body_jaxpr: the loop's body jaxpr.
    enabled_errors: the set of enabled error categories.
    error: the input error; it fixes the error pytree structure.
    c_consts_num: number of constant inputs of `cond_jaxpr`.

  Returns:
    The ``(checked_jaxpr, out_tree, error_effects)`` triple produced by
    `jaxpr_to_checkify_jaxpr`. The checked jaxpr takes
    ``(*error_leaves, *cond_consts, *body_inputs)``.
  """
  run_cond = core.jaxpr_as_fun(cond_jaxpr)
  run_body = core.jaxpr_as_fun(body_jaxpr)

  def body_then_cond(*cond_consts_and_body_args):
    cond_consts, body_args = split_list(cond_consts_and_body_args,
                                        [c_consts_num])
    next_carry = run_body(*body_args)
    # Evaluate the next cond application only so that its checks are traced;
    # the predicate itself is not needed.
    run_cond(*cond_consts, *next_carry)
    return next_carry

  cond_const_avals = cond_jaxpr.in_avals[:c_consts_num]
  wrapped = lu.wrap_init(body_then_cond,
                         debug_info=body_jaxpr.jaxpr.debug_info)
  fused_jaxpr, _, (), () = pe.trace_to_jaxpr_dynamic(
      wrapped, [*cond_const_avals, *body_jaxpr.in_avals])
  fused_closed = pe.close_jaxpr(fused_jaxpr)

  err_leaves, err_tree = jtu.tree_flatten(error)
  err_avals = map(core.get_aval, err_leaves)
  checked_jaxpr, out_tree, error_effects = jaxpr_to_checkify_jaxpr(
      fused_closed, enabled_errors, err_tree,
      *err_avals, *cond_const_avals, *body_jaxpr.in_avals)
  return checked_jaxpr, out_tree, error_effects
|
|
|
|
|
|
@weakref_lru_cache
def ignore_error_output_jaxpr(jaxpr, num_error_vals: int):
  """Constructs a checked jaxpr which does not output its error value.

  Args:
    jaxpr: a closed, checkified jaxpr whose first `num_error_vals` outputs
      are the flattened error.
    num_error_vals: number of leading outputs to drop.

  Returns:
    A ``core.ClosedJaxpr`` with the same inputs and constants that returns
    only the outputs after the first `num_error_vals`.
  """
  open_jaxpr = jaxpr.jaxpr
  kept_outvars = open_jaxpr.outvars[num_error_vals:]
  return core.ClosedJaxpr(open_jaxpr.replace(outvars=kept_outvars),
                          jaxpr.consts)
|
|
|
|
def while_loop_error_check(error, enabled_errors, *in_flat, cond_nconsts,
                           cond_jaxpr, body_nconsts, body_jaxpr):
  """Checkify rule for `lax.while_p`.

  The loop condition is checkified once on the initial carry here. The
  checkified body re-applies the condition to every new carry (see
  `checkify_while_body_jaxpr`), so checks inside the condition are covered on
  every evaluation. The error is threaded through the loop as extra carry
  values placed ahead of the original carry.

  Raises:
    ValueError: if the condition's predicate output is not a scalar
      (a batched while loop, i.e. checkify-of-vmap-of-while).
  """
  if cond_jaxpr.out_avals[0].shape:
    # TODO(lenamartens, sharadmv): support batched while.
    raise ValueError('Checkify does not support batched while-loops '
                     '(checkify-of-vmap-of-while). \nHint: if possible, move '
                     'the vmap to the outer level to get '
                     'vmap-of-checkify-of-while.')

  # `in_flat` is laid out as (cond consts, body consts, initial carry).
  c_consts, b_consts, carry = split_list(in_flat, [cond_nconsts, body_nconsts])
  # Check if the first cond application will error.
  error, _ = checkify_jaxpr(cond_jaxpr, enabled_errors, error, *c_consts, *carry)

  # Checkify the body once only to learn which error effects it can raise;
  # the jaxpr and out tree from this call are discarded.
  _, _, error_effects = checkify_while_body_jaxpr(cond_jaxpr, body_jaxpr,
                                                  enabled_errors, error,
                                                  cond_nconsts)
  # merged error! Add placeholders for every effect found above, so that the
  # error carried into the loop already has the structure the body produces.
  error = error._add_placeholder_effects(error_effects)
  err_vals, err_tree = jtu.tree_flatten(error)
  # Checkify the body again, against the merged error. Its inputs are
  # (err, cond consts, body consts, carry).
  checked_body_jaxpr_, body_out_tree, _ = checkify_while_body_jaxpr(
    cond_jaxpr, body_jaxpr, enabled_errors, error, cond_nconsts)
  num_error_vals = len(err_vals)
  # Reorder the body's inputs to (cond consts, body consts, err, carry):
  # all consts first, then the loop carry.
  to_move = ([False] * num_error_vals + [True] * cond_nconsts
             + [True] * body_nconsts + [False] * len(carry))
  checked_body_jaxpr = pe.move_binders_to_front(checked_body_jaxpr_, to_move)

  # Checkify the cond against the same merged error, drop its error outputs
  # (keeping only the original cond outputs), and move its consts to the
  # front, giving inputs (cond consts, err, carry).
  cond_in_flat = [*err_vals, *c_consts, *carry]
  cond_in_flat = map(core.get_aval, cond_in_flat)
  checked_cond_jaxpr, _, _ = jaxpr_to_checkify_jaxpr(cond_jaxpr, enabled_errors,
                                                     err_tree, *cond_in_flat)
  compat_cond_jaxpr_ = ignore_error_output_jaxpr(checked_cond_jaxpr, num_error_vals)
  to_move = [False] * num_error_vals + [True] * cond_nconsts + [False] * len(carry)
  compat_cond_jaxpr = pe.move_binders_to_front(compat_cond_jaxpr_, to_move)

  # Operands are: the cond's consts; then the body's consts, which contain the
  # cond consts again because the body re-evaluates the cond; then the carry
  # (err, original carry).
  new_in_flat = [*c_consts, *c_consts, *b_consts, *err_vals, *carry]
  all_out_vals = lax.while_p.bind(
      *new_in_flat, cond_nconsts=cond_nconsts, cond_jaxpr=compat_cond_jaxpr,
      body_nconsts=cond_nconsts+body_nconsts, body_jaxpr=checked_body_jaxpr)
  # body_out_tree will have all the metadata of cond because it executes a cond!
  error, out = tree_unflatten(body_out_tree, all_out_vals)
  return error, out
error_checks[lax.while_p] = while_loop_error_check
|
|
|
|
def pjit_error_check(error, enabled_errors, *vals_in, jaxpr,
                     in_shardings, out_shardings,
                     in_layouts, out_layouts,
                     resource_env, donated_invars, name, inline, keep_unused,
                     compiler_options_kvs):
  """Checkify rule for `pjit_p`.

  Checkifies the inner jaxpr with the error leaves prepended to its inputs.
  It then re-binds `pjit_p` with every per-input and per-output parameter
  padded for the extra error values: unspecified shardings, ``None`` layouts,
  and no donation.
  """
  err_leaves, err_tree = jtu.tree_flatten(error)
  operands = [*err_leaves, *vals_in]
  operand_avals = tuple(map(core.get_aval, operands))
  checked_jaxpr, out_tree, _ = jaxpr_to_checkify_jaxpr(
      jaxpr, enabled_errors, err_tree, *operand_avals)

  # The number of error leaves among the outputs is derived from `out_tree`,
  # separately from the number of input error leaves.
  n_in_errs = len(err_leaves)
  n_out_errs = out_tree.num_leaves - len(out_shardings)

  def padded(count, filler, rest):
    # Prepend `count` copies of `filler` to `rest`, returning a tuple.
    return (filler,) * count + tuple(rest)

  unspecified = sharding_impls.UNSPECIFIED
  err_and_out = pjit.pjit_p.bind(
      *operands,
      jaxpr=checked_jaxpr,
      in_shardings=padded(n_in_errs, unspecified, in_shardings),
      out_shardings=padded(n_out_errs, unspecified, out_shardings),
      in_layouts=padded(n_in_errs, None, in_layouts),
      out_layouts=padded(n_out_errs, None, out_layouts),
      resource_env=resource_env,
      donated_invars=padded(n_in_errs, False, donated_invars),
      name=name,
      inline=inline,
      keep_unused=keep_unused,
      compiler_options_kvs=compiler_options_kvs,
  )
  return tree_unflatten(out_tree, err_and_out)
error_checks[pjit.pjit_p] = pjit_error_check
|
|
|
|
|
|
def remat_error_check(error, enabled_errors, *vals_in, jaxpr, **params):
  """Checkify rule for `ad_checkpoint.remat_p`.

  Checkifies the remat jaxpr with the error leaves prepended to its inputs,
  then re-binds `remat_p` with the checked (open) jaxpr. All other params
  are forwarded unchanged.
  """
  err_leaves, err_tree = jtu.tree_flatten(error)
  operands = [*err_leaves, *vals_in]
  operand_avals = tuple(map(core.get_aval, operands))
  closed_checked, out_tree, _ = jaxpr_to_checkify_jaxpr(
      pe.close_jaxpr(jaxpr), enabled_errors, err_tree, *operand_avals)
  # The incoming `jaxpr` is open (it is closed above), and so is the one passed
  # back to remat_p. Unpacking the consts into `()` insists that the checked
  # jaxpr has no constants that would otherwise be dropped.
  open_checked, () = closed_checked.jaxpr, closed_checked.consts
  results = ad_checkpoint.remat_p.bind(*operands, jaxpr=open_checked, **params)
  return tree_unflatten(out_tree, results)
error_checks[ad_checkpoint.remat_p] = remat_error_check
|
|
|
|
|
|
def shard_map_error_check(
    error: Error, enabled_errors, *vals_in,
    jaxpr: core.Jaxpr, in_names, out_names, **kwargs
):
  """Checkify rule for `shard_map_p`.

  Input error leaves are passed replicated (empty in-names). The body is
  checkified per shard. Each output error leaf gets a new leading axis that is
  sharded over all mesh axes, so the per-shard errors are gathered into one
  stacked output error.

  Raises:
    ValueError: if no `mesh` is given in `kwargs`, or if an input aval type
      has no entry in `core.shard_aval_handlers`.
  """
  if (mesh := kwargs.get('mesh')) is None:
    raise ValueError('Mesh must be provided for shard_map with checkify.')

  err_vals, err_tree = jtu.tree_flatten(error)
  num_error_vals = len(err_vals)
  # Replicated sharding for in errors.
  new_in_names = (*([{}] * num_error_vals), *in_names)
  new_vals_in = [*err_vals, *vals_in]
  in_avals = list(map(core.get_aval, new_vals_in))
  # Convert each global input aval into its per-shard aval according to its
  # in-names, using the handler registered for its aval type.
  for i, v in enumerate(in_avals):
    if not (sharder := core.shard_aval_handlers.get(type(v))):
      raise ValueError(f'Unsupported aval type: {type(v)}')
    in_avals[i] = sharder(mesh, new_in_names[i], v)

  # Trace the body the way shard_map sees it: with the mesh axes in the axis
  # environment and the manual version of `mesh` as the abstract mesh.
  with (core.extend_axis_env_nd(mesh.shape.items()),
        mesh_lib.set_abstract_mesh(shard_map._as_manual_mesh(mesh))):
    # jaxpr to checked_jaxpr
    checked_jaxpr, out_tree, _ = jaxpr_to_checkify_jaxpr(
        pe.close_jaxpr(jaxpr), enabled_errors, err_tree, *in_avals
    )
  num_out_error_vals = out_tree.num_leaves - len(out_names)

  # Wraps the checked body so that every output error leaf gets a new leading
  # axis of size 1; that axis is what `new_out_names` shards below.
  def expand_errors_leading_dim(*xs):
    outs = core.eval_jaxpr(checked_jaxpr.jaxpr, checked_jaxpr.consts, *xs)
    errs, outs = split_list(outs, [num_out_error_vals])
    errs = [lax.expand_dims(e, [0]) for e in errs]
    return *errs, *outs

  with core.extend_axis_env_nd(mesh.shape.items()):
    jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(
        lu.wrap_init(expand_errors_leading_dim,
                     debug_info=checked_jaxpr.jaxpr.debug_info),
        checked_jaxpr.in_avals
    )
  checked_jaxpr = core.ClosedJaxpr(jaxpr, consts)

  # Update shard_map params to account for extra error values.
  # Use fully sharded partitioning for out errors.
  new_out_names = (*([{0: mesh.axis_names}] * num_out_error_vals), *out_names)
  # shard_map_p.bind takes the body as a function in addition to the `jaxpr`
  # param; evaluate the (re-traced) checked jaxpr.
  subfun = lu.hashable_partial(
      lu.wrap_init(core.eval_jaxpr, debug_info=checked_jaxpr.jaxpr.debug_info),
      checked_jaxpr.jaxpr, checked_jaxpr.consts
  )
  new_params = dict(
      jaxpr=checked_jaxpr.jaxpr,
      in_names=new_in_names,
      out_names=new_out_names,
      **kwargs,
  )
  # Normalize the params through get_bind_params. The sub-functions it returns
  # are discarded, because `subfun` above is passed to bind instead.
  _, new_params = shard_map.shard_map_p.get_bind_params(new_params)

  err_and_out = shard_map.shard_map_p.bind(subfun, *new_vals_in, **new_params)
  return tree_unflatten(out_tree, err_and_out)
error_checks[shard_map.shard_map_p] = shard_map_error_check
|
|
|
|
def custom_jvp_call_rule(in_err: Error,
                         enabled_errors: set, *in_vals, num_consts,
                         jvp_jaxpr_fun: lu.WrappedFun,
                         call_jaxpr: core.ClosedJaxpr, **params):
  """Checkify rule for `custom_jvp_call_p`.

  Re-binds `custom_jvp_call_p` with the error leaves prepended to the inputs
  (``(*err_vals, *in_vals)``). The primal function is replaced by a
  checkified version of `call_jaxpr`. The JVP rule is replaced by one (built
  with `lift_jvp`) that forwards the error values unchanged. `num_consts`,
  `jvp_jaxpr_fun` and `call_jaxpr` are consumed here; only the remaining
  `params` are forwarded to bind. Returns ``(out_err, out_vals)``.
  """
  # The types to have in mind are:
  #   jvp : (a -> b) -> (a, T a) -> (b, T b)
  #   checkify : (a -> b) -> a -> Err b
  #   jvp-of-checkify : (a -> b) -> (a, T a) -> (Err b, T (Err b))
  # where because Err is a pytree, we necessarily have T (Err b) = Err' (T b)
  # where the other Err' components are trivial (of float0 dtype).
  # Semantically, we don't add checks to the JVP rule. To check the result of a
  # JVP rule, one must instead use checkify-of-jvp. Thus this implementation
  # just forwards the input error and code (and trivial tangents) to the output.
  err_vals, err_tree = jtu.tree_flatten(in_err)
  partial_checkify = lu.wrap_init(
      functools.partial(checkify_jaxpr_flat, call_jaxpr.jaxpr,
                        call_jaxpr.consts, enabled_errors, err_tree),
      debug_info=call_jaxpr.jaxpr.debug_info)
  partial_checkify, f_metadata = _flatten_and_get_error_metadata_thunk(
      partial_checkify)
  jvp = lift_jvp(err_tree.num_leaves, num_consts, jvp_jaxpr_fun)
  jvp, jvp_out_tree = flatten_fun_output(jvp)
  all_outs = custom_derivatives.custom_jvp_call_p.bind(
      partial_checkify, jvp, *err_vals, *in_vals, **params)
  # `fst` reports which of the two wrapped functions stored its output
  # metadata. Presumably True when the checkified primal function ran and
  # False when the JVP rule ran -- TODO confirm against lu.merge_linear_aux.
  fst, out_metadata = lu.merge_linear_aux(f_metadata, jvp_out_tree)
  if fst:
    # Unflatten with the error+output tree recorded by the checkified function.
    err_and_out_tree, _ = out_metadata
    out_err, out_vals = tree_unflatten(err_and_out_tree, all_outs)
  else:
    # The leading outputs are the forwarded input error leaves.
    err_vals, out_vals = split_list(all_outs, [len(err_vals)])
    # forward input error to output
    out_err = jtu.tree_unflatten(err_tree, err_vals)
  return out_err, out_vals
error_checks[custom_derivatives.custom_jvp_call_p] = custom_jvp_call_rule
|
|
|
|
# Compared to custom_derivatives.lift_jvp, we're handling the extra inputs and
|
|
# outputs that checkify adds (just forwarding the error data's primal and
|
|
# tangent components). The jaxpr in jvp_jaxpr_fun doesn't expect those.
|
|
# TODO(mattjj): can we simplify this, or dedup with custom_derivatives.lift_jvp?
|
|
# Adding another layer of lu.transformation was tricky, though maybe doable.
|
|
def lift_jvp(num_errs: int, num_consts: int,
|
|
jvp_jaxpr_fun: lu.WrappedFun) -> lu.WrappedFun:
|
|
def jvp(*xs):
|
|
n, ragged = divmod(len(xs), 2)
|
|
assert not ragged
|
|
primals, tangents = xs[num_consts+num_errs:n], xs[n+num_consts+num_errs:]
|
|
zeros = [type(t) is SymbolicZero for t in tangents]
|
|
jvp_jaxpr, jvp_consts, out_zeros = jvp_jaxpr_fun.call_wrapped(*zeros)
|
|
nonzero_tangents = [t for t in tangents if type(t) is not SymbolicZero]
|
|
out = core.eval_jaxpr(jvp_jaxpr, jvp_consts, *primals, *nonzero_tangents)
|
|
out_primals, nz_out_tangents = split_list(out, [len(out_zeros)])
|
|
nz_out_tangents_ = iter(nz_out_tangents)
|
|
out_tangents = [SymbolicZero(core.get_aval(p).to_tangent_aval())
|
|
if z else next(nz_out_tangents_)
|
|
for p, z in zip(out_primals, out_zeros)]
|
|
assert next(nz_out_tangents_, None) is None
|
|
primal_errs = xs[num_consts:num_consts+num_errs]
|
|
tangent_errs = xs[n+num_consts:n+num_consts+num_errs]
|
|
return [*primal_errs, *out_primals, *tangent_errs, *out_tangents]
|
|
return lu.wrap_init(jvp, debug_info=jvp_jaxpr_fun.debug_info)
|
|
|
|
def custom_vjp_call_jaxpr_rule(in_err, enabled_errors, *in_vals,
                               fun_jaxpr: core.ClosedJaxpr,
                               fwd_jaxpr_thunk, num_consts,
                               bwd: lu.WrappedFun, out_trees,
                               symbolic_zeros: bool):
  """Checkify rule for `custom_vjp_call_jaxpr_p`.

  Binds `custom_vjp_call_p` with the error leaves prepended to the inputs
  (``(*err_vals, *in_vals)``). The primal function is a checkified version of
  `fun_jaxpr`. The forward rule strips the error (and constant) inputs before
  evaluating the user's fwd jaxpr. The backward rule returns ``None``
  cotangents for the error inputs. Returns ``(out_err, out_vals)``.
  """
  err_vals, err_tree = jtu.tree_flatten(in_err)
  num_errs = err_tree.num_leaves
  checkified_fun = lu.wrap_init(
      functools.partial(checkify_jaxpr_flat, fun_jaxpr.jaxpr,
                        fun_jaxpr.consts, enabled_errors, err_tree),
      debug_info=fun_jaxpr.jaxpr.debug_info)
  checkified_fun, fun_metadata = _flatten_and_get_error_metadata_thunk(
      checkified_fun)

  def checkified_fwd(*args):
    # TODO(lenamartens, sharadmv): why not checkify here?
    # `args` alternates (value, zero-flag) pairs; split them, then drop the
    # leading error entries.
    xs, zeros = args[::2], args[1::2]
    xs, zeros = xs[num_errs:], zeros[num_errs:]
    fwd_jaxpr, fwd_consts = fwd_jaxpr_thunk(*zeros)
    # The fwd jaxpr is evaluated without the leading constant inputs.
    xs_without_consts = xs[num_consts:]
    return core.eval_jaxpr(fwd_jaxpr, fwd_consts, *xs_without_consts)

  # TODO(necula): the fwd result_paths are not quite the same as fun_jaxpr
  checkified_fwd_wrapped = lu.wrap_init(checkified_fwd,
                                        debug_info=fun_jaxpr.jaxpr.debug_info)
  # Prepend one `None` cotangent per error input to the user's bwd outputs.
  bwd_ = lu.wrap_init(lambda *args: (*(None,)*num_errs, *bwd.call_wrapped(*args)),
                      debug_info=bwd.debug_info)
  checkified_fwd_wrapped, fwd_out_tree = flatten_fun_output(checkified_fwd_wrapped)
  all_outs = custom_derivatives.custom_vjp_call_p.bind(
      checkified_fun, checkified_fwd_wrapped,
      bwd_, *err_vals, *in_vals, out_trees=out_trees,
      symbolic_zeros=symbolic_zeros)
  # `fst` reports which wrapped function stored its output metadata;
  # presumably True when the checkified primal function ran -- TODO confirm
  # against lu.merge_linear_aux.
  fst, out_metadata = lu.merge_linear_aux(fun_metadata, fwd_out_tree)
  if fst:
    err_and_out_tree, _ = out_metadata
    out_err, out_vals = tree_unflatten(err_and_out_tree, all_outs)
  else:
    # The fwd path adds no error outputs; pass the input error through.
    out_err, out_vals = in_err, all_outs
  return out_err, out_vals
error_checks[custom_derivatives.custom_vjp_call_jaxpr_p] = custom_vjp_call_jaxpr_rule
|
|
|
|
|
|
def check_discharge_rule(error, enabled_errors, *args, err_tree, debug):
  """Checkify rule for `check_p`.

  Rebuilds the error carried by the `check` from `args` and `err_tree`, then
  splits it by category:

  * entries whose error type is in `enabled_errors` are merged into the
    running `error` (they are functionalized);
  * the remaining entries are collected into a separate error that is
    currently dropped (see the TODO below).

  The rule produces no output values, hence the empty list.
  """
  del debug
  new_error = tree_unflatten(err_tree, args)
  discharged_error = error
  recharged_error = init_error
  for effect, pred in new_error._pred.items():
    code = new_error._code[effect]
    payload = new_error._payload[effect]
    if effect.error_type in enabled_errors:
      discharged_error = update_error(
          discharged_error, pred, code, {}, payload, effect)
    else:
      recharged_error = update_error(
          recharged_error, pred, code, {}, payload, effect)

  # On key clashes, metadata already in `discharged_error` wins.
  discharged_error = discharged_error._replace(
      _metadata={**new_error._metadata, **discharged_error._metadata})
  recharged_error = recharged_error._replace(_metadata=new_error._metadata)
  # TODO(lenamartens): we actually need to recharge, but this would be a
  # breaking API change so leaving for a follow-up.
  # check_error(recharged_error)
  return discharged_error, []
error_checks[check_p] = check_discharge_rule
|
|
|
|
|
|
## checkify public api

# Error categories accepted by `checkify(..., errors=...)`. Each is a
# frozenset of error classes, and categories can be combined with `|`.
user_checks = frozenset({FailedCheckError})      # explicit `check` calls
nan_checks = frozenset({NaNError})               # NaN-producing operations
index_checks = frozenset({OOBError})             # out-of-bounds indexing
div_checks = frozenset({DivisionByZeroError})    # division by zero

# Unions of the basic categories above (set union is order-independent).
float_checks = div_checks | nan_checks
automatic_checks = index_checks | float_checks
all_checks = user_checks | automatic_checks
|
|
|
|
|
|
def checkify(f: Callable[..., Out],
             errors: frozenset[ErrorCategory] = user_checks
             ) -> Callable[..., tuple[Error, Out]]:
  """Functionalize `check` calls in `f`, and optionally add run-time error checks.

  Run-time errors are either user-added :func:`~check` assertions, or
  automatically added checks like NaN checks, depending on the ``errors``
  argument.

  The returned function will return an Error object `err` along with the output
  of the original function. ``err.get()`` will either return ``None`` (if no
  error occurred) or a string containing an error message. This error message
  will correspond to the first error which occurred. ``err.throw()`` will raise
  a ValueError with the error message if an error occurred.

  By default only user-added :func:`~check` assertions are enabled. You can
  enable automatic checks through the ``errors`` argument.

  The automatic check sets which can be enabled, and when an error is generated:
    - ``user_checks``: a :func:`~check` evaluated to False.
    - ``nan_checks``: a floating-point operation generated a NaN value
      as output.
    - ``div_checks``: a division by zero.
    - ``index_checks``: an index was out-of-bounds.

  Multiple categories can be enabled together by passing a set of error
  categories (eg. ``errors=nan_checks``). Multiple sets can be re-combined
  with set operations (eg. ``errors=float_checks|user_checks``).

  Args:
    f: Callable which can contain user checks (see :func:`~check`).
    errors: A set of ErrorCategory values which defines the set of enabled
      checks. By default only explicit ``checks`` are enabled
      (``user_checks``). You can also for example enable NAN and
      DIV errors by passing the ``float_checks`` set, or for
      example combine multiple sets through set operations
      (``float_checks | user_checks``)
  Returns:
    A function which accepts the same arguments as ``f`` and returns as output
    a pair where the first element is an ``Error`` value, representing the first
    failed :func:`~check`, and the second element is the original output of
    ``f``.

  For example:

    >>> import jax
    >>> import jax.numpy as jnp
    >>> from jax.experimental import checkify
    >>>
    >>> @jax.jit
    ... def f(x):
    ...   y = jnp.sin(x)
    ...   return x+y
    >>> err, out = checkify.checkify(f, errors=checkify.float_checks)(jnp.inf)
    >>> err.throw()  # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
      ...
    jax._src.checkify.JaxRuntimeError: nan generated by primitive: sin
  """
  @traceback_util.api_boundary
  def checked_fun(*args, **kwargs):
    # close over all arguments so they're not turned into abstract values.
    in_tree = jtu.tree_structure(((), {}))
    closed_f = lambda: f(*args, **kwargs)
    # stage: trace `f` (with all arguments closed over) to a jaxpr.
    debug = api_util.debug_info("checkify", f, args, kwargs)
    fun_, out_tree = api_util.flatten_fun(lu.wrap_init(closed_f,
                                                       debug_info=debug),
                                          in_tree)
    jaxpr_, _, consts, () = pe.trace_to_jaxpr_dynamic(fun_, ())
    # Turn the captured constants into explicit inputs of the closed jaxpr.
    jaxpr = pe.close_jaxpr(pe.convert_constvars_jaxpr(jaxpr_))
    # checkify: interpret the jaxpr starting from an empty error.
    error, out_flat = checkify_jaxpr(jaxpr, errors, init_error, *consts)
    return error, jtu.tree_unflatten(out_tree(), out_flat)
  return checked_fun
|
|
|
|
def check(pred: Bool, msg: str,
          *fmt_args,
          debug: bool = False,
          **fmt_kwargs,
          ) -> None:
  """Check a predicate, add an error with msg if predicate is False.

  This is an effectful operation, and can't be staged (jitted/scanned/...).
  Before staging a function with checks, :func:`~checkify` it!

  Args:
    pred: if False, a FailedCheckError error is added.
    msg: error message if error is added. Can be a format string.
    debug: Whether to turn on debugging mode. If True, the check is dropped
      unless the function is transformed by :func:`~checkify` (like
      :func:`~debug_check`). If False, the check must be functionalized using
      checkify.checkify.
    fmt_args, fmt_kwargs: Positional and keyword formatting arguments for
      `msg`, eg.:
      ``check(.., "check failed on values {} and {named_arg}", x, named_arg=y)``
      Note that these arguments can be traced values allowing you to add
      run-time values to the error message.
      Note that tracking these run-time arrays will increase your memory usage,
      even if no error happens.

  For example:

    >>> import jax
    >>> import jax.numpy as jnp
    >>> from jax.experimental import checkify
    >>> def f(x):
    ...   checkify.check(x>0, "{x} needs to be positive!", x=x)
    ...   return 1/x
    >>> checked_f = checkify.checkify(f)
    >>> err, out = jax.jit(checked_f)(-3.)
    >>> err.throw()  # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
      ...
    jax._src.checkify.JaxRuntimeError: -3. needs to be positive!

  """
  _check(pred, msg, debug, *fmt_args, **fmt_kwargs)
|
|
|
|
def _check(pred, msg, debug, *fmt_args, **fmt_kwargs):
  """Shared implementation of `check` and `debug_check`.

  Validates `pred` and the formatting arguments, builds a FailedCheckError
  that fires when `pred` is False, and emits it through `_check_error`.

  Raises:
    TypeError: if `pred` is not a scalar boolean, or a formatting argument
      leaf is not an array.
  """
  if not is_scalar_pred(pred):
    prim_name = 'debug_check' if debug else 'check'
    raise TypeError(f'{prim_name} takes a scalar pred as argument, got {pred}')
  for arg in jtu.tree_leaves((fmt_args, fmt_kwargs)):
    if isinstance(arg, (Array, np.ndarray)):
      continue
    raise TypeError('Formatting arguments to checkify.check need to be '
                    'PyTrees of arrays, but got '
                    f'{arg!r} of type {type(arg)}.')
  failure = FailedCheckError(get_traceback(), msg, *fmt_args, **fmt_kwargs)
  failed_when = jnp.logical_not(pred)
  err = assert_func(init_error, failed_when, failure)
  _check_error(err, debug=debug)
|
|
|
|
def _check_error(error, *, debug=False):
  """Bind `check_p` on the flattened `error`.

  If any predicate in the error is non-scalar, the error is first reduced
  with `_reduce_any_error`.
  """
  if any(np.shape(p) for p in error._pred.values()):
    error = _reduce_any_error(error)
  flat_err, err_treedef = tree_flatten(error)
  return check_p.bind(*flat_err, err_tree=err_treedef, debug=debug)
|
|
|
|
|
|
def is_scalar_pred(pred) -> bool:
  """Return True if `pred` is a Python bool or a 0-d boolean jax Array."""
  if isinstance(pred, bool):
    return True
  if not isinstance(pred, Array):
    return False
  return pred.shape == () and pred.dtype == jnp.dtype('bool')
|
|
|
|
|
|
def debug_check(pred: Bool, msg: str, *fmt_args, **fmt_kwargs) -> None:
  """Check a predicate when running under checkify, otherwise is a no-op.

  A `debug_check` is only run when the enclosing function is transformed by
  :func:`~checkify`; without checkify the check is dropped. It behaves like
  :func:`~check` with ``debug=True``.

  Args:
    pred: if False, a FailedCheckError error is added.
    msg: error message if error is added.
    fmt_args, fmt_kwargs: Positional and keyword formatting arguments for
      `msg`, eg.:
      ``debug_check(.., "check failed on values {} and {named}", x, named=y)``
      Note that these arguments can be traced values allowing you to add
      run-time values to the error message.
      Note that tracking these run-time arrays will increase your memory usage,
      even if no error happens.

  For example:

    >>> import jax
    >>> import jax.numpy as jnp
    >>> from jax.experimental import checkify
    >>> def f(x):
    ...   checkify.debug_check(x!=0, "cannot be zero!")
    ...   return x
    >>> _ = f(0)  # running without checkify means no debug_check is run.
    >>> checked_f = checkify.checkify(f)
    >>> err, out = jax.jit(checked_f)(0)  # running with checkify runs debug_check.
    >>> err.throw()  # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
      ...
    jax._src.checkify.JaxRuntimeError: cannot be zero!

  """
  debug_mode = True  # always a debug check
  _check(pred, msg, debug_mode, *fmt_args, **fmt_kwargs)
|
|
|
|
|
|
def check_error(error: Error) -> None:
  """Raise an Exception if ``error`` represents a failure. Functionalized by :func:`~checkify`.

  The semantics of this function are equivalent to:

  >>> def check_error(err: Error) -> None:
  ...   err.throw()  # can raise ValueError

  But unlike that implementation, ``check_error`` can be functionalized using
  the :func:`~checkify` transformation.

  This function is similar to :func:`~check` but with a different signature:
  whereas :func:`~check` takes as arguments a boolean predicate and a new error
  message string, this function takes an ``Error`` value as argument. Both
  :func:`~check` and this function raise a Python Exception on failure (a
  side-effect), and thus cannot be staged out by :func:`~jax.jit`,
  :func:`~jax.pmap`, :func:`~jax.lax.scan`, etc. Both also can be
  functionalized by using :func:`~checkify`.

  But unlike :func:`~check`, this function is like a direct inverse of
  :func:`~checkify`. :func:`~checkify` takes as input a function which can
  raise a Python Exception and produces a new function without that effect
  but which produces an ``Error`` value as output. Conversely, this
  ``check_error`` function accepts an ``Error`` value as input and can produce
  the side-effect of raising an Exception. That is, while :func:`~checkify`
  goes from functionalizable Exception effect to error value, this
  ``check_error`` goes from error value to functionalizable Exception effect.

  ``check_error`` is useful when you want to turn checks represented by an
  ``Error`` value (produced by functionalizing ``checks`` via
  :func:`~checkify`) back into Python Exceptions.

  Args:
    error: Error to check.

  Raises:
    TypeError: if ``error`` is not an ``Error``.

  For example, you might want to functionalize part of your program through
  checkify, stage out your functionalized code through :func:`~jax.jit`, then
  re-inject your error value outside of the :func:`~jax.jit`:

  >>> import jax
  >>> from jax.experimental import checkify
  >>> def f(x):
  ...   checkify.check(x>0, "must be positive!")
  ...   return x
  >>> def with_inner_jit(x):
  ...   checked_f = checkify.checkify(f)
  ...   # a checkified function can be jitted
  ...   error, out = jax.jit(checked_f)(x)
  ...   checkify.check_error(error)
  ...   return out
  >>> _ = with_inner_jit(1)  # no failed check
  >>> with_inner_jit(-1)  # doctest: +IGNORE_EXCEPTION_DETAIL
  Traceback (most recent call last):
    ...
  jax._src.checkify.JaxRuntimeError: must be positive!
  >>> # can re-checkify
  >>> error, _ = checkify.checkify(with_inner_jit)(-1)
  """
  if not isinstance(error, Error):
    raise TypeError('check_error takes an Error as argument, '
                    f'got type {type(error)} instead.')
  _check_error(error, debug=False)
|