diff --git a/jax/_src/checkify.py b/jax/_src/checkify.py index 89d48268c..6926b8700 100644 --- a/jax/_src/checkify.py +++ b/jax/_src/checkify.py @@ -575,6 +575,28 @@ def nan_error_check(prim, error, enabled_errors, *in_vals, **params): msg = f'nan generated by primitive {prim.name} at {summary()}' return out, assert_func(error, any_nans, msg, None) + +# All primitives which can generate a NaN. +nan_primitives = [lax.acos_p, lax.acosh_p, lax.add_p, lax.asin_p, lax.asinh_p, + lax.atan2_p, lax.atan_p, lax.atanh_p, lax.bessel_i0e_p, + lax.bessel_i1e_p, lax.cbrt_p, lax.conv_general_dilated_p, + lax.cos_p, lax.cosh_p, lax.cumlogsumexp_p, lax.cummax_p, + lax.cummin_p, lax.cumprod_p, lax.cumsum_p, lax.digamma_p, + lax.dot_general_p, lax.erf_inv_p, lax.erf_p, lax.erfc_p, + lax.exp_p, lax.expm1_p, lax.fft_p, lax.igamma_grad_a_p, + lax.igamma_p, lax.igammac_p, lax.integer_pow_p, lax.lgamma_p, + lax.linear_solve_p, lax.log1p_p, lax.log_p, lax.logistic_p, + lax.mul_p, lax.pad_p, lax.pow_p, lax.psum_p, + lax.random_gamma_grad_p, lax.reduce_p, lax.reduce_prod_p, + lax.reduce_sum_p, lax.reduce_window_p, + lax.reduce_window_sum_p, lax.regularized_incomplete_beta_p, + lax.rem_p, lax.rng_uniform_p, lax.rsqrt_p, lax.sin_p, + lax.sinh_p, lax.sqrt_p, lax.sub_p, lax.tan_p, lax.tanh_p] + +for prim in nan_primitives: + error_checks[prim] = partial(nan_error_check, prim) + + def gather_error_check(error, enabled_errors, operand, start_indices, *, dimension_numbers, slice_sizes, unique_indices, indices_are_sorted, mode, fill_value): @@ -790,70 +812,6 @@ def pjit_error_check(error, enabled_errors, *vals_in, jaxpr, error_checks[pjit.pjit_p] = pjit_error_check -def add_nan_check(prim): - error_checks[prim] = partial(nan_error_check, prim) - -add_nan_check(lax.floor_p) -add_nan_check(lax.ceil_p) -add_nan_check(lax.round_p) -add_nan_check(lax.sign_p) -add_nan_check(lax.shift_left_p) -add_nan_check(lax.shift_right_arithmetic_p) -add_nan_check(lax.shift_right_logical_p) -add_nan_check(lax.bitcast_convert_type_p) -add_nan_check(lax.real_p) -add_nan_check(lax.complex_p) -add_nan_check(lax.conj_p) -add_nan_check(lax.imag_p) -add_nan_check(lax.add_p) -add_nan_check(lax.sub_p) -add_nan_check(lax.convert_element_type_p) -add_nan_check(lax.broadcast_in_dim_p) -add_nan_check(lax.concatenate_p) -add_nan_check(lax.pad_p) -add_nan_check(lax.reshape_p) -add_nan_check(lax.rev_p) -add_nan_check(lax.transpose_p) -add_nan_check(lax.slice_p) -add_nan_check(lax.reduce_sum_p) -add_nan_check(lax.reduce_window_sum_p) -add_nan_check(lax.fft_p) -add_nan_check(lax.cumsum_p) -add_nan_check(lax.cumprod_p) -add_nan_check(lax.cumlogsumexp_p) -add_nan_check(lax.cummax_p) -add_nan_check(lax.cummin_p) -add_nan_check(lax.erf_p) -add_nan_check(lax.expm1_p) -add_nan_check(lax.log1p_p) -add_nan_check(lax.sqrt_p) -add_nan_check(lax.rsqrt_p) -add_nan_check(lax.asinh_p) -add_nan_check(lax.acosh_p) -add_nan_check(lax.atanh_p) -add_nan_check(lax.erfc_p) -add_nan_check(lax.rem_p) -add_nan_check(lax.clamp_p) -add_nan_check(lax.erf_inv_p) -add_nan_check(lax.exp_p) -add_nan_check(lax.pow_p) -add_nan_check(lax.integer_pow_p) -add_nan_check(lax.tanh_p) -add_nan_check(lax.log_p) -add_nan_check(lax.atan2_p) -add_nan_check(lax.sin_p) -add_nan_check(lax.cos_p) -add_nan_check(lax.sinh_p) -add_nan_check(lax.cosh_p) -add_nan_check(lax.dot_general_p) -add_nan_check(lax.mul_p) -add_nan_check(lax.conv_general_dilated_p) -add_nan_check(lax.reduce_max_p) -add_nan_check(lax.reduce_min_p) -add_nan_check(lax.abs_p) -add_nan_check(lax.select_n_p) -add_nan_check(lax.max_p) -add_nan_check(lax.min_p) def assert_discharge_rule(error, enabled_errors, err, code, payload, *, msgs): if ErrorCategory.USER_CHECK not in enabled_errors: