jnp.unique: properly handle NaN values

This commit is contained in:
Jake VanderPlas 2022-01-13 15:54:07 -08:00
parent cd73a4195f
commit bd157cf056
3 changed files with 38 additions and 1 deletions

View File

@ -30,6 +30,9 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
Previously negative ``NaN`` values were sorted to the front of the array, and ``NaN``
values with different internal bit representations were not treated as equivalent, and
were sorted according to those bit patterns ({jax-issue}`#9178`).
* {func}`jax.numpy.unique` now treats ``NaN`` values in the same way as `np.unique` in
NumPy versions 1.21 and newer: at most one ``NaN`` value will appear in the uniquified
output ({jax-issue}`#9184`).
* Bug fixes:
* host_callback now supports ad_checkpoint.checkpoint ({jax-issue}`#8907`).

View File

@ -5596,6 +5596,11 @@ def take_along_axis(arr, indices, axis: Optional[int]):
@partial(jit, static_argnums=1)
def _unique_sorted_mask(ar, axis):
aux = moveaxis(ar, axis, 0)
if issubdtype(aux.dtype, np.complexfloating):
# Work around issue in sorting of complex numbers with Nan only in the
# imaginary component. This can be removed if sorting in this situation
# is fixed to match numpy.
aux = where(isnan(aux), lax._const(aux, nan), aux)
size, *out_shape = aux.shape
if _prod(out_shape) == 0:
size = 1
@ -5604,7 +5609,13 @@ def _unique_sorted_mask(ar, axis):
perm = lexsort(aux.reshape(size, _prod(out_shape)).T[::-1])
aux = aux[perm]
if aux.size:
mask = ones(size, dtype=bool).at[1:].set(any(aux[1:] != aux[:-1], tuple(range(1, aux.ndim))))
if issubdtype(aux.dtype, inexact):
# This is appropriate for both float and complex due to the documented behavior of np.unique:
# See https://github.com/numpy/numpy/blob/v1.22.0/numpy/lib/arraysetops.py#L212-L220
neq = lambda x, y: lax.ne(x, y) & ~(isnan(x) & isnan(y))
else:
neq = lax.ne
mask = ones(size, dtype=bool).at[1:].set(any(neq(aux[1:], aux[:-1]), tuple(range(1, aux.ndim))))
else:
mask = zeros(size, dtype=bool)
return aux, mask, perm

View File

@ -2479,6 +2479,29 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
@unittest.skipIf(numpy_version < (1, 21), "Numpy < 1.21 does not properly handle NaN values in unique.")
@parameterized.named_parameters(jtu.cases_from_list(
  {"testcase_name": f"_{dtype.__name__}", "dtype": dtype}
  for dtype in inexact_dtypes))
def testUniqueNans(self, dtype):
  """Compare jnp.unique against np.unique on inputs containing NaNs.

  The input mixes signed zeros, a repeated finite value, and both `nan` and
  `-nan`. For complex dtypes every (real, imag) pairing of those values is
  used, so NaN shows up in either component or in both.
  """
  # Request every optional output so index/inverse/count results are compared
  # in addition to the unique values themselves.
  unique_kwargs = dict(return_index=True, return_inverse=True, return_counts=True)

  def args_maker():
    values = [-0.0, 0.0, 1.0, 1.0, np.nan, -np.nan]
    if np.issubdtype(dtype, np.complexfloating):
      values = [complex(real, imag)
                for real, imag in itertools.product(values, repeat=2)]
    return [np.array(values, dtype=dtype)]

  def np_fun(x):
    out_dtype = x.dtype
    # numpy unique fails for bfloat16 NaNs, so the reference is computed in
    # float64 and the unique values are cast back to the input dtype.
    operand = x.astype('float64') if x.dtype == jnp.bfloat16 else x
    uniques, *extras = np.unique(operand, **unique_kwargs)
    return (uniques.astype(out_dtype), *extras)

  jnp_fun = partial(jnp.unique, **unique_kwargs)
  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_fixed_size={}".format(fixed_size),
"fixed_size": fixed_size}