mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 12:56:07 +00:00
Incorporate a few recent NumPy API extensions. (#3586)
This commit is contained in:
parent
17fc8b75c2
commit
1f2025e12f
@ -1496,15 +1496,19 @@ def isnan(x):
|
||||
lax.bitwise_not(isinf(x)))
|
||||
|
||||
@_wraps(np.nan_to_num)
|
||||
def nan_to_num(x, copy=True):
|
||||
def nan_to_num(x, copy=True, nan=0.0, posinf=None, neginf=None):
|
||||
del copy
|
||||
dtype = _dtype(x)
|
||||
if issubdtype(dtype, complexfloating):
|
||||
return lax.complex(nan_to_num(lax.real(x)), nan_to_num(lax.imag(x)))
|
||||
return lax.complex(
|
||||
nan_to_num(lax.real(x), nan=nan, posinf=posinf, neginf=neginf),
|
||||
nan_to_num(lax.imag(x), nan=nan, posinf=posinf, neginf=neginf))
|
||||
info = finfo(dtypes.canonicalize_dtype(dtype))
|
||||
x = where(isnan(x), _constant_like(x, 0), x)
|
||||
x = where(isposinf(x), _constant_like(x, info.max), x)
|
||||
x = where(isneginf(x), _constant_like(x, info.min), x)
|
||||
posinf = info.max if posinf is None else posinf
|
||||
neginf = info.min if neginf is None else neginf
|
||||
x = where(isnan(x), _constant_like(x, nan), x)
|
||||
x = where(isposinf(x), _constant_like(x, posinf), x)
|
||||
x = where(isneginf(x), _constant_like(x, neginf), x)
|
||||
return x
|
||||
|
||||
### Reducers
|
||||
@ -1709,9 +1713,9 @@ def allclose(a, b, rtol=1e-05, atol=1e-08):
|
||||
|
||||
|
||||
@_wraps(np.count_nonzero)
|
||||
def count_nonzero(a, axis=None):
|
||||
def count_nonzero(a, axis=None, keepdims=False):
|
||||
return sum(lax.ne(a, _constant_like(a, 0)), axis=axis,
|
||||
dtype=dtypes.canonicalize_dtype(np.int_))
|
||||
dtype=dtypes.canonicalize_dtype(np.int_), keepdims=keepdims)
|
||||
|
||||
|
||||
_NONZERO_DOC = """\
|
||||
@ -2232,12 +2236,17 @@ def ones(shape, dtype=None):
|
||||
|
||||
|
||||
@_wraps(np.array_equal)
|
||||
def array_equal(a1, a2):
|
||||
def array_equal(a1, a2, equal_nan=False):
|
||||
try:
|
||||
a1, a2 = asarray(a1), asarray(a2)
|
||||
except Exception:
|
||||
return False
|
||||
return shape(a1) == shape(a2) and all(asarray(a1 == a2))
|
||||
if shape(a1) != shape(a2):
|
||||
return False
|
||||
eq = asarray(a1 == a2)
|
||||
if equal_nan:
|
||||
eq = logical_or(eq, logical_and(isnan(a1), isnan(a2)))
|
||||
return all(eq)
|
||||
|
||||
|
||||
# We can't create uninitialized arrays in XLA; use zeros for empty.
|
||||
|
Loading…
x
Reference in New Issue
Block a user