Fix debug_nans false positive in jnp.quantile

This commit is contained in:
Jake VanderPlas 2024-11-05 15:36:14 -08:00
parent ea1e879577
commit 44c6883cee

View File

@ -2360,7 +2360,8 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None,
index[axis] = high
high_value = a[tuple(index)]
else:
a = _where(any(lax_internal._isnan(a), axis=axis, keepdims=True), np.nan, a)
with jax.debug_nans(False):
a = _where(any(lax_internal._isnan(a), axis=axis, keepdims=True), np.nan, a)
a = lax.sort(a, dimension=axis)
n = lax.convert_element_type(a_shape[axis], lax_internal._dtype(q))
q = lax.mul(q, n - 1)