Merge pull request #24210 from jakevdp:nan-to-num-doc

PiperOrigin-RevId: 684080842
This commit is contained in:
jax authors 2024-10-09 10:10:01 -07:00
commit c2deae8aca

View File

@ -3400,11 +3400,53 @@ def fix(x: ArrayLike, out: None = None) -> Array:
return where(lax.ge(x, zero), ufuncs.floor(x), ufuncs.ceil(x))
@util.implements(np.nan_to_num)
@jit
def nan_to_num(x: ArrayLike, copy: bool = True, nan: ArrayLike = 0.0,
posinf: ArrayLike | None = None,
neginf: ArrayLike | None = None) -> Array:
"""Replace NaN and infinite entries in an array.
JAX implementation of :func:`numpy.nan_to_num`.
Args:
x: array of values to be replaced. If it does not have an inexact
dtype it will be returned unmodified.
copy: unused by JAX
nan: value to substitute for NaN entries. Defaults to 0.0.
posinf: value to substitute for positive infinite entries.
Defaults to the maximum representable value.
neginf: value to substitute for positive infinite entries.
Defaults to the minimum representable value.
Returns:
A copy of ``x`` with the requested substitutions.
See also:
- :func:`jax.numpy.isnan`: return True where the array contains NaN
- :func:`jax.numpy.isposinf`: return True where the array contains +inf
- :func:`jax.numpy.isneginf`: return True where the array contains -inf
Examples:
>>> x = jnp.array([0, jnp.nan, 1, jnp.inf, 2, -jnp.inf])
Default substitution values:
>>> jnp.nan_to_num(x)
Array([ 0.0000000e+00, 0.0000000e+00, 1.0000000e+00, 3.4028235e+38,
2.0000000e+00, -3.4028235e+38], dtype=float32)
Overriding substitutions for ``-inf`` and ``+inf``:
>>> jnp.nan_to_num(x, posinf=999, neginf=-999)
Array([ 0., 0., 1., 999., 2., -999.], dtype=float32)
If you only wish to substitute for NaN values while leaving ``inf`` values
untouched, using :func:`~jax.numpy.where` with :func:`jax.numpy.isnan` is
a better option:
>>> jnp.where(jnp.isnan(x), 0, x)
Array([ 0., 0., 1., inf, 2., -inf], dtype=float32)
"""
del copy
util.check_arraylike("nan_to_num", x)
dtype = _dtype(x)