diff --git a/CHANGELOG.md b/CHANGELOG.md index 140867054..b86f6d67a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -23,6 +23,13 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK. environment variable, or the ```--flax_host_callback_ad_transforms``` flag. Additionally, added documentation for how to implement the old behavior using JAX custom AD APIs ({jax-issue}`#8678`). + * Sorting now matches the behavior of NumPy for ``0.0`` and ``NaN`` regardless of the + bit representation. In particular, ``0.0`` and ``-0.0`` are now treated as equivalent, + where previously ``-0.0`` was treated as less than ``0.0``. Additionally all ``NaN`` + representations are now treated as equivalent and sorted to the end of the array. + Previously negative ``NaN`` values were sorted to the front of the array, and ``NaN`` + values with different internal bit representations were not treated as equivalent, and + were sorted according to those bit patterns ({jax-issue}`#9178`). * Bug fixes: * host_callback now supports ad_checkpoint.checkpoint ({jax-issue}`#8907`). diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index a5f4af77a..87e07e80c 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -960,8 +960,11 @@ def _reduce_and(operand: Array, axes: Sequence[int]) -> Array: def sort(operand: Union[Array, Sequence[Array]], dimension: int = -1, is_stable: bool = True, num_keys: int = 1) -> Union[Array, Tuple[Array, ...]]: """Wraps XLA's `Sort - `_ - operator. + `_ operator. + + For floating point inputs, -0.0 and 0.0 are treated as equivalent, and NaN values + are sorted to the end of the array. For complex inputs, the sort order is + lexicographic over the real and imaginary parts, with the real part primary. Args: operand : Array or sequence of arrays @@ -3671,8 +3674,10 @@ def _float_to_int_for_sort(x): # x = bit_cast(f); # y = x < 0 ? 
int32_max - x : x; # then y is ordered as an int32 such that finite values have the obvious - # order, -0 is ordered before 0, and -NaN and NaN appear at the beginning - # and end of the ordering. + # order. In this scheme, -0 would be before 0, and -NaN and NaN appear at + # the beginning and end of the ordering. This causes issues for stable + # sorts, so we avoid this by standardizing the representation of zeros + # and NaNs in the output. # Note that in order to avoid -x to overflow, we calculate # int32_max - x as unsigned, and then convert back to signed. if x.dtype == dtypes.bfloat16: @@ -3683,6 +3688,17 @@ def _float_to_int_for_sort(x): signed = bitcast_convert_type(x, signed_dtype) unsigned = bitcast_convert_type(x, unsigned_dtype) + + # We cannot standardize zeros in x because XLA elides this in some cases. + # We cannot standardize NaNs in x because it triggers jax.debug_nans. + # So instead we do these replacements in the signed integer representation. + + # Standardize zeros: + signed = select(eq(x, _zero(x)), _zeros(signed), signed) + # Standardize nans: + signed_nan = x.dtype.type(np.nan).view(signed_dtype) + signed = select(_isnan(x), full_like(signed, signed_nan), signed) + flipped = bitcast_convert_type( sub(unsigned_dtype.type(np.iinfo(signed_dtype).max), unsigned), signed_dtype) return select(lt(signed, _zero(signed)), flipped, signed) @@ -3690,7 +3706,8 @@ def _float_to_int_for_sort(x): # Default comparator that sorts the operands lexicographically on the # first `num_keys` arguments. # For floating point types, a total order is created where -# -NaN < -infinity < ... < -0 < 0 < ... < infinity < NaN. +# -infinity < ... < 0 < ... < infinity < NaN. +# 0.0 and -0.0 are treated as equivalent, as are all NaN representations. # For complex types, the (real, imag) pairs are sorted lexicographically # (following NumPy's semantics).
# This code adds complex-number support and lexicographic ordering to the algorithm from: @@ -4364,6 +4381,9 @@ _two: Callable = partial(full_like, shape=(), fill_value=2) dtype: Callable = partial(dtypes.dtype, canonicalize=True) _dtype: Callable = partial(dtypes.dtype, canonicalize=True) +def _isnan(x) -> bool: + return ne(x, x) + def _iscomplex(x) -> bool: return dtypes.issubdtype(_dtype(x), np.complexfloating) diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index f821b58e4..cf3e084d6 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -2979,18 +2979,20 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): self._CompileAndCheck(jnp_fun, args_maker) @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": f"_dtype={dtype.__name__}", "dtype": dtype} - for dtype in inexact_dtypes)) - def testSearchsortedNans(self, dtype): + {"testcase_name": f"_dtype={dtype.__name__}_side={side}", "dtype": dtype, "side": side} + for dtype in inexact_dtypes + for side in ['left', 'right'])) + def testSearchsortedNans(self, dtype, side): if np.issubdtype(dtype, np.complexfloating): raise SkipTest("Known failure for complex inputs; see #9107") - sorted = jnp.array([-np.nan, -np.inf, -1, 0, 1, np.inf, np.nan], dtype=dtype) - self.assertArraysEqual( - jnp.searchsorted(sorted, sorted, side='left'), - jnp.arange(len(sorted))) - self.assertArraysEqual( - jnp.searchsorted(sorted, sorted, side='right'), - jnp.arange(1, 1 + len(sorted))) + x = np.array([-np.inf, -1.0, 0.0, -0.0, 1.0, np.inf, np.nan, -np.nan], dtype=dtype) + # The sign bit should not matter for 0.0 or NaN, so searchsorted on the above should be + # equivalent to searchsorted on the following: + x_equiv = np.array([0, 1, 2, 2, 3, 4, 5, 5]) + + fun = partial(jnp.searchsorted, side=side) + self.assertArraysEqual(fun(x, x), fun(x_equiv, x_equiv)) + self.assertArraysEqual(jax.jit(fun)(x, x), fun(x_equiv, x_equiv)) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name":
"_x={}_bins={}_right={}_reverse={}".format( diff --git a/tests/lax_test.py b/tests/lax_test.py index 8141f4459..f785fddc9 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -1872,6 +1872,19 @@ class LaxTest(jtu.JaxTestCase): fun = lambda x: lax.sort(x, dimension=axis, is_stable=is_stable) self._CompileAndCheck(fun, args_maker) + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": f"_dtype={dtype.__name__}", "dtype": dtype} + for dtype in float_dtypes)) + def testSortFloatSpecialValues(self, dtype): + # Test confirms that + # - NaNs are sorted to the end, regardless of representation + # - sign bit of 0.0 is ignored + x = jnp.array([-np.inf, 0.0, -0.0, np.inf, np.nan, -np.nan], dtype=dtype) + index = lax.iota(dtypes.int_, x.size) + argsort = lambda x: lax.sort_key_val(x, lax.iota(dtypes.int_, x.size), is_stable=True)[1] + self.assertArraysEqual(argsort(x), index) + self.assertArraysEqual(jax.jit(argsort)(x), index) + @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape={}_axis={}_isstable={}".format( jtu.format_shape_dtype_string(shape, dtype), axis, is_stable),