mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #24210 from jakevdp:nan-to-num-doc
PiperOrigin-RevId: 684080842
This commit is contained in:
commit
c2deae8aca
@ -3400,11 +3400,53 @@ def fix(x: ArrayLike, out: None = None) -> Array:
|
||||
return where(lax.ge(x, zero), ufuncs.floor(x), ufuncs.ceil(x))
|
||||
|
||||
|
||||
@util.implements(np.nan_to_num)
|
||||
@jit
|
||||
def nan_to_num(x: ArrayLike, copy: bool = True, nan: ArrayLike = 0.0,
|
||||
posinf: ArrayLike | None = None,
|
||||
neginf: ArrayLike | None = None) -> Array:
|
||||
"""Replace NaN and infinite entries in an array.
|
||||
|
||||
JAX implementation of :func:`numpy.nan_to_num`.
|
||||
|
||||
Args:
|
||||
x: array of values to be replaced. If it does not have an inexact
|
||||
dtype it will be returned unmodified.
|
||||
copy: unused by JAX
|
||||
nan: value to substitute for NaN entries. Defaults to 0.0.
|
||||
posinf: value to substitute for positive infinite entries.
|
||||
Defaults to the maximum representable value.
|
||||
neginf: value to substitute for positive infinite entries.
|
||||
Defaults to the minimum representable value.
|
||||
|
||||
Returns:
|
||||
A copy of ``x`` with the requested substitutions.
|
||||
|
||||
See also:
|
||||
- :func:`jax.numpy.isnan`: return True where the array contains NaN
|
||||
- :func:`jax.numpy.isposinf`: return True where the array contains +inf
|
||||
- :func:`jax.numpy.isneginf`: return True where the array contains -inf
|
||||
|
||||
Examples:
|
||||
>>> x = jnp.array([0, jnp.nan, 1, jnp.inf, 2, -jnp.inf])
|
||||
|
||||
Default substitution values:
|
||||
|
||||
>>> jnp.nan_to_num(x)
|
||||
Array([ 0.0000000e+00, 0.0000000e+00, 1.0000000e+00, 3.4028235e+38,
|
||||
2.0000000e+00, -3.4028235e+38], dtype=float32)
|
||||
|
||||
Overriding substitutions for ``-inf`` and ``+inf``:
|
||||
|
||||
>>> jnp.nan_to_num(x, posinf=999, neginf=-999)
|
||||
Array([ 0., 0., 1., 999., 2., -999.], dtype=float32)
|
||||
|
||||
If you only wish to substitute for NaN values while leaving ``inf`` values
|
||||
untouched, using :func:`~jax.numpy.where` with :func:`jax.numpy.isnan` is
|
||||
a better option:
|
||||
|
||||
>>> jnp.where(jnp.isnan(x), 0, x)
|
||||
Array([ 0., 0., 1., inf, 2., -inf], dtype=float32)
|
||||
"""
|
||||
del copy
|
||||
util.check_arraylike("nan_to_num", x)
|
||||
dtype = _dtype(x)
|
||||
|
Loading…
x
Reference in New Issue
Block a user