Checkify: add and remove primitives which are checked for NaN outputs.

Started from all primitives exported from `jax.lax` and removed
a primitive when:
- its output is int/bool (but what if the output is complex?)
- it does not generate NaNs, ie. if the input does not contain a NaN
  value, the output will not contain a NaN value (eg.
  reshape/concatenate/..., max/..)
- it's already handled by other rules (eg. div, gather/scatter and
  scan/cond/while)

Compared to the previous set:

added: {logistic, custom_linear_solve, igammac, igamma_grad_a, psum,
igamma, reduce, tan, rng_uniform, lgamma, digamma,
regularized_incomplete_beta, reduce_prod, reduce_window, cbrt,
bessel_i0e, random_gamma_grad, bessel_i1e}

removed: {shift_left, concatenate, complex, shift_right_arithmetic,
convert_element_type, conj, sign, round, shift_right_logical,
reduce_max, bitcast_convert_type, real, max, reduce_min, rev, slice,
min, imag, clamp, floor, select_n}
This commit is contained in:
lenamartens 2022-11-14 17:50:46 +00:00
parent a13541441b
commit ba3d4423eb

View File

@ -575,6 +575,28 @@ def nan_error_check(prim, error, enabled_errors, *in_vals, **params):
msg = f'nan generated by primitive {prim.name} at {summary()}'
return out, assert_func(error, any_nans, msg, None)
# All primitives which can generate a NaN.
nan_primitives = [lax.acos_p, lax.acosh_p, lax.add_p, lax.asin_p, lax.asinh_p,
lax.atan2_p, lax.atan_p, lax.atanh_p, lax.bessel_i0e_p,
lax.bessel_i1e_p, lax.cbrt_p, lax.conv_general_dilated_p,
lax.cos_p, lax.cosh_p, lax.cumlogsumexp_p, lax.cummax_p,
lax.cummin_p, lax.cumprod_p, lax.cumsum_p, lax.digamma_p,
lax.dot_general_p, lax.erf_inv_p, lax.erf_p, lax.erfc_p,
lax.exp_p, lax.expm1_p, lax.fft_p, lax.igamma_grad_a_p,
lax.igamma_p, lax.igammac_p, lax.integer_pow_p, lax.lgamma_p,
lax.linear_solve_p, lax.log1p_p, lax.log_p, lax.logistic_p,
lax.mul_p, lax.pad_p, lax.pow_p, lax.psum_p,
lax.random_gamma_grad_p, lax.reduce_p, lax.reduce_prod_p,
lax.reduce_sum_p, lax.reduce_window_p,
lax.reduce_window_sum_p, lax.regularized_incomplete_beta_p,
lax.rem_p, lax.rng_uniform_p, lax.rsqrt_p, lax.sin_p,
lax.sinh_p, lax.sqrt_p, lax.sub_p, lax.tan_p, lax.tanh_p]
for prim in nan_primitives:
error_checks[prim] = partial(nan_error_check, prim)
def gather_error_check(error, enabled_errors, operand, start_indices, *,
dimension_numbers, slice_sizes, unique_indices,
indices_are_sorted, mode, fill_value):
@ -790,70 +812,6 @@ def pjit_error_check(error, enabled_errors, *vals_in, jaxpr,
error_checks[pjit.pjit_p] = pjit_error_check
def add_nan_check(prim):
error_checks[prim] = partial(nan_error_check, prim)
add_nan_check(lax.floor_p)
add_nan_check(lax.ceil_p)
add_nan_check(lax.round_p)
add_nan_check(lax.sign_p)
add_nan_check(lax.shift_left_p)
add_nan_check(lax.shift_right_arithmetic_p)
add_nan_check(lax.shift_right_logical_p)
add_nan_check(lax.bitcast_convert_type_p)
add_nan_check(lax.real_p)
add_nan_check(lax.complex_p)
add_nan_check(lax.conj_p)
add_nan_check(lax.imag_p)
add_nan_check(lax.add_p)
add_nan_check(lax.sub_p)
add_nan_check(lax.convert_element_type_p)
add_nan_check(lax.broadcast_in_dim_p)
add_nan_check(lax.concatenate_p)
add_nan_check(lax.pad_p)
add_nan_check(lax.reshape_p)
add_nan_check(lax.rev_p)
add_nan_check(lax.transpose_p)
add_nan_check(lax.slice_p)
add_nan_check(lax.reduce_sum_p)
add_nan_check(lax.reduce_window_sum_p)
add_nan_check(lax.fft_p)
add_nan_check(lax.cumsum_p)
add_nan_check(lax.cumprod_p)
add_nan_check(lax.cumlogsumexp_p)
add_nan_check(lax.cummax_p)
add_nan_check(lax.cummin_p)
add_nan_check(lax.erf_p)
add_nan_check(lax.expm1_p)
add_nan_check(lax.log1p_p)
add_nan_check(lax.sqrt_p)
add_nan_check(lax.rsqrt_p)
add_nan_check(lax.asinh_p)
add_nan_check(lax.acosh_p)
add_nan_check(lax.atanh_p)
add_nan_check(lax.erfc_p)
add_nan_check(lax.rem_p)
add_nan_check(lax.clamp_p)
add_nan_check(lax.erf_inv_p)
add_nan_check(lax.exp_p)
add_nan_check(lax.pow_p)
add_nan_check(lax.integer_pow_p)
add_nan_check(lax.tanh_p)
add_nan_check(lax.log_p)
add_nan_check(lax.atan2_p)
add_nan_check(lax.sin_p)
add_nan_check(lax.cos_p)
add_nan_check(lax.sinh_p)
add_nan_check(lax.cosh_p)
add_nan_check(lax.dot_general_p)
add_nan_check(lax.mul_p)
add_nan_check(lax.conv_general_dilated_p)
add_nan_check(lax.reduce_max_p)
add_nan_check(lax.reduce_min_p)
add_nan_check(lax.abs_p)
add_nan_check(lax.select_n_p)
add_nan_check(lax.max_p)
add_nan_check(lax.min_p)
def assert_discharge_rule(error, enabled_errors, err, code, payload, *, msgs):
if ErrorCategory.USER_CHECK not in enabled_errors: