mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Checkify: add and remove primitives which are checked for NaN outputs.
Started from all primitives exported from `jax.lax` and removed a primitive when: - its output is int/bool (but what if the output is complex?) - it does not generate NaNs, ie. if the input does not contain a NaN value, the output will not contain a NaN value (eg. reshape/concatenate/..., max/..) - it's already handled by other rules (eg. div, gather/scatter and scan/cond/while) Compared to the previous set: added: {logistic, custom_linear_solve, igammac, igamma_grad_a, psum, igamma, reduce, tan, rng_uniform, lgamma, digamma, regularized_incomplete_beta, reduce_prod, reduce_window, cbrt, bessel_i0e, random_gamma_grad, bessel_i1e} removed: {shift_left, concatenate, complex, shift_right_arithmetic, convert_element_type, conj, sign, round, shift_right_logical, reduce_max, bitcast_convert_type, real, max, reduce_min, rev, slice, min, imag, clamp, floor, select_n}
This commit is contained in:
parent
a13541441b
commit
ba3d4423eb
@ -575,6 +575,28 @@ def nan_error_check(prim, error, enabled_errors, *in_vals, **params):
|
||||
msg = f'nan generated by primitive {prim.name} at {summary()}'
|
||||
return out, assert_func(error, any_nans, msg, None)
|
||||
|
||||
|
||||
# All primitives which can generate a NaN.
|
||||
nan_primitives = [lax.acos_p, lax.acosh_p, lax.add_p, lax.asin_p, lax.asinh_p,
|
||||
lax.atan2_p, lax.atan_p, lax.atanh_p, lax.bessel_i0e_p,
|
||||
lax.bessel_i1e_p, lax.cbrt_p, lax.conv_general_dilated_p,
|
||||
lax.cos_p, lax.cosh_p, lax.cumlogsumexp_p, lax.cummax_p,
|
||||
lax.cummin_p, lax.cumprod_p, lax.cumsum_p, lax.digamma_p,
|
||||
lax.dot_general_p, lax.erf_inv_p, lax.erf_p, lax.erfc_p,
|
||||
lax.exp_p, lax.expm1_p, lax.fft_p, lax.igamma_grad_a_p,
|
||||
lax.igamma_p, lax.igammac_p, lax.integer_pow_p, lax.lgamma_p,
|
||||
lax.linear_solve_p, lax.log1p_p, lax.log_p, lax.logistic_p,
|
||||
lax.mul_p, lax.pad_p, lax.pow_p, lax.psum_p,
|
||||
lax.random_gamma_grad_p, lax.reduce_p, lax.reduce_prod_p,
|
||||
lax.reduce_sum_p, lax.reduce_window_p,
|
||||
lax.reduce_window_sum_p, lax.regularized_incomplete_beta_p,
|
||||
lax.rem_p, lax.rng_uniform_p, lax.rsqrt_p, lax.sin_p,
|
||||
lax.sinh_p, lax.sqrt_p, lax.sub_p, lax.tan_p, lax.tanh_p]
|
||||
|
||||
for prim in nan_primitives:
|
||||
error_checks[prim] = partial(nan_error_check, prim)
|
||||
|
||||
|
||||
def gather_error_check(error, enabled_errors, operand, start_indices, *,
|
||||
dimension_numbers, slice_sizes, unique_indices,
|
||||
indices_are_sorted, mode, fill_value):
|
||||
@ -790,70 +812,6 @@ def pjit_error_check(error, enabled_errors, *vals_in, jaxpr,
|
||||
error_checks[pjit.pjit_p] = pjit_error_check
|
||||
|
||||
|
||||
def add_nan_check(prim):
|
||||
error_checks[prim] = partial(nan_error_check, prim)
|
||||
|
||||
add_nan_check(lax.floor_p)
|
||||
add_nan_check(lax.ceil_p)
|
||||
add_nan_check(lax.round_p)
|
||||
add_nan_check(lax.sign_p)
|
||||
add_nan_check(lax.shift_left_p)
|
||||
add_nan_check(lax.shift_right_arithmetic_p)
|
||||
add_nan_check(lax.shift_right_logical_p)
|
||||
add_nan_check(lax.bitcast_convert_type_p)
|
||||
add_nan_check(lax.real_p)
|
||||
add_nan_check(lax.complex_p)
|
||||
add_nan_check(lax.conj_p)
|
||||
add_nan_check(lax.imag_p)
|
||||
add_nan_check(lax.add_p)
|
||||
add_nan_check(lax.sub_p)
|
||||
add_nan_check(lax.convert_element_type_p)
|
||||
add_nan_check(lax.broadcast_in_dim_p)
|
||||
add_nan_check(lax.concatenate_p)
|
||||
add_nan_check(lax.pad_p)
|
||||
add_nan_check(lax.reshape_p)
|
||||
add_nan_check(lax.rev_p)
|
||||
add_nan_check(lax.transpose_p)
|
||||
add_nan_check(lax.slice_p)
|
||||
add_nan_check(lax.reduce_sum_p)
|
||||
add_nan_check(lax.reduce_window_sum_p)
|
||||
add_nan_check(lax.fft_p)
|
||||
add_nan_check(lax.cumsum_p)
|
||||
add_nan_check(lax.cumprod_p)
|
||||
add_nan_check(lax.cumlogsumexp_p)
|
||||
add_nan_check(lax.cummax_p)
|
||||
add_nan_check(lax.cummin_p)
|
||||
add_nan_check(lax.erf_p)
|
||||
add_nan_check(lax.expm1_p)
|
||||
add_nan_check(lax.log1p_p)
|
||||
add_nan_check(lax.sqrt_p)
|
||||
add_nan_check(lax.rsqrt_p)
|
||||
add_nan_check(lax.asinh_p)
|
||||
add_nan_check(lax.acosh_p)
|
||||
add_nan_check(lax.atanh_p)
|
||||
add_nan_check(lax.erfc_p)
|
||||
add_nan_check(lax.rem_p)
|
||||
add_nan_check(lax.clamp_p)
|
||||
add_nan_check(lax.erf_inv_p)
|
||||
add_nan_check(lax.exp_p)
|
||||
add_nan_check(lax.pow_p)
|
||||
add_nan_check(lax.integer_pow_p)
|
||||
add_nan_check(lax.tanh_p)
|
||||
add_nan_check(lax.log_p)
|
||||
add_nan_check(lax.atan2_p)
|
||||
add_nan_check(lax.sin_p)
|
||||
add_nan_check(lax.cos_p)
|
||||
add_nan_check(lax.sinh_p)
|
||||
add_nan_check(lax.cosh_p)
|
||||
add_nan_check(lax.dot_general_p)
|
||||
add_nan_check(lax.mul_p)
|
||||
add_nan_check(lax.conv_general_dilated_p)
|
||||
add_nan_check(lax.reduce_max_p)
|
||||
add_nan_check(lax.reduce_min_p)
|
||||
add_nan_check(lax.abs_p)
|
||||
add_nan_check(lax.select_n_p)
|
||||
add_nan_check(lax.max_p)
|
||||
add_nan_check(lax.min_p)
|
||||
|
||||
def assert_discharge_rule(error, enabled_errors, err, code, payload, *, msgs):
|
||||
if ErrorCategory.USER_CHECK not in enabled_errors:
|
||||
|
Loading…
x
Reference in New Issue
Block a user