Incorporate a few recent NumPy API extensions. (#3586)

This commit is contained in:
Peter Hawkins 2020-06-28 11:27:02 -04:00 committed by GitHub
parent 17fc8b75c2
commit 1f2025e12f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1496,15 +1496,19 @@ def isnan(x):
lax.bitwise_not(isinf(x)))
@_wraps(np.nan_to_num)
def nan_to_num(x, copy=True):
def nan_to_num(x, copy=True, nan=0.0, posinf=None, neginf=None):
del copy
dtype = _dtype(x)
if issubdtype(dtype, complexfloating):
return lax.complex(nan_to_num(lax.real(x)), nan_to_num(lax.imag(x)))
return lax.complex(
nan_to_num(lax.real(x), nan=nan, posinf=posinf, neginf=neginf),
nan_to_num(lax.imag(x), nan=nan, posinf=posinf, neginf=neginf))
info = finfo(dtypes.canonicalize_dtype(dtype))
x = where(isnan(x), _constant_like(x, 0), x)
x = where(isposinf(x), _constant_like(x, info.max), x)
x = where(isneginf(x), _constant_like(x, info.min), x)
posinf = info.max if posinf is None else posinf
neginf = info.min if neginf is None else neginf
x = where(isnan(x), _constant_like(x, nan), x)
x = where(isposinf(x), _constant_like(x, posinf), x)
x = where(isneginf(x), _constant_like(x, neginf), x)
return x
### Reducers
@ -1709,9 +1713,9 @@ def allclose(a, b, rtol=1e-05, atol=1e-08):
@_wraps(np.count_nonzero)
def count_nonzero(a, axis=None):
def count_nonzero(a, axis=None, keepdims=False):
return sum(lax.ne(a, _constant_like(a, 0)), axis=axis,
dtype=dtypes.canonicalize_dtype(np.int_))
dtype=dtypes.canonicalize_dtype(np.int_), keepdims=keepdims)
_NONZERO_DOC = """\
@ -2232,12 +2236,17 @@ def ones(shape, dtype=None):
@_wraps(np.array_equal)
def array_equal(a1, a2):
def array_equal(a1, a2, equal_nan=False):
try:
a1, a2 = asarray(a1), asarray(a2)
except Exception:
return False
return shape(a1) == shape(a2) and all(asarray(a1 == a2))
if shape(a1) != shape(a2):
return False
eq = asarray(a1 == a2)
if equal_nan:
eq = logical_or(eq, logical_and(isnan(a1), isnan(a2)))
return all(eq)
# We can't create uninitialized arrays in XLA; use zeros for empty.