mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Fix debug_nans false positive in jnp.quantile
This commit is contained in:
parent
ea1e879577
commit
44c6883cee
@ -2360,7 +2360,8 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None,
|
||||
index[axis] = high
|
||||
high_value = a[tuple(index)]
|
||||
else:
|
||||
a = _where(any(lax_internal._isnan(a), axis=axis, keepdims=True), np.nan, a)
|
||||
with jax.debug_nans(False):
|
||||
a = _where(any(lax_internal._isnan(a), axis=axis, keepdims=True), np.nan, a)
|
||||
a = lax.sort(a, dimension=axis)
|
||||
n = lax.convert_element_type(a_shape[axis], lax_internal._dtype(q))
|
||||
q = lax.mul(q, n - 1)
|
||||
|
Loading…
x
Reference in New Issue
Block a user