Merge pull request #9178 from jakevdp:sort-corner-cases

PiperOrigin-RevId: 421646783
This commit is contained in:
jax authors 2022-01-13 13:33:02 -08:00
commit cd73a4195f
4 changed files with 57 additions and 15 deletions

View File

@ -23,6 +23,13 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
environment variable, or the ```--flax_host_callback_ad_transforms``` flag.
Additionally, added documentation for how to implement the old behavior
using JAX custom AD APIs ({jax-issue}`#8678`).
* Sorting now matches the behavior of NumPy for ``0.0`` and ``NaN`` regardless of the
bit representation. In particular, ``0.0`` and ``-0.0`` are now treated as equivalent,
where previously ``-0.0`` was treated as less than ``0.0``. Additionally all ``NaN``
representations are now treated as equivalent and sorted to the end of the array.
Previously negative ``NaN`` values were sorted to the front of the array, and ``NaN``
values with different internal bit representations were not treated as equivalent, and
were sorted according to those bit patterns ({jax-issue}`#9178`).
* Bug fixes:
* host_callback now supports ad_checkpoint.checkpoint ({jax-issue}`#8907`).

View File

@ -960,8 +960,11 @@ def _reduce_and(operand: Array, axes: Sequence[int]) -> Array:
def sort(operand: Union[Array, Sequence[Array]], dimension: int = -1,
is_stable: bool = True, num_keys: int = 1) -> Union[Array, Tuple[Array, ...]]:
"""Wraps XLA's `Sort
<https://www.tensorflow.org/xla/operation_semantics#sort>`_
operator.
<https://www.tensorflow.org/xla/operation_semantics#sort>`_ operator.
For floating point inputs, -0.0 and 0.0 are treated as equivalent, and NaN values
are sorted to the end of the array. For complex inputs, the sort order is
lexicographic over the real and imaginary parts, with the real part primary.
Args:
operand : Array or sequence of arrays
@ -3671,8 +3674,10 @@ def _float_to_int_for_sort(x):
# x = bit_cast<int32>(f);
# y = x < 0 ? int32_max - x : x;
# then y is ordered as an int32 such that finite values have the obvious
# order, -0 is ordered before 0, and -NaN and NaN appear at the beginning
# and end of the ordering.
# order. In this scheme, -0 would be before 0, and -NaN and NaN appear at
# the beginning and end of the ordering. This causes issues for stable
# sorts, so we avoid this by standardizing the representation of zeros
# and NaNs in the output.
# Note that in order to avoid -x to overflow, we calculate
# int32_max - x as unsigned, and then convert back to signed.
if x.dtype == dtypes.bfloat16:
@ -3683,6 +3688,17 @@ def _float_to_int_for_sort(x):
signed = bitcast_convert_type(x, signed_dtype)
unsigned = bitcast_convert_type(x, unsigned_dtype)
# We cannot standardize zeros in x because XLA elides this is some cases.
# We cannot standardize NaNs in x because it triggers jax.debug_nans
# So instead we do these replacements in the signed integer representation.
# Standardize zeros:
signed = select(eq(x, _zero(x)), _zeros(signed), signed)
# Standardize nans:
signed_nan = x.dtype.type(np.nan).view(signed_dtype)
signed = select(_isnan(x), full_like(signed, signed_nan), signed)
flipped = bitcast_convert_type(
sub(unsigned_dtype.type(np.iinfo(signed_dtype).max), unsigned), signed_dtype)
return select(lt(signed, _zero(signed)), flipped, signed)
@ -3690,7 +3706,8 @@ def _float_to_int_for_sort(x):
# Default comparator that sorts the operands lexicographically on the
# first `num_keys` arguments.
# For floating point types, a total order is created where
# -NaN < -infinity < ... < -0 < 0 < ... < infinity < NaN.
# -infinity < ... < 0 < ... < infinity < NaN.
# 0.0 and -0.0 are treated as equivalent, as are all NaN representations.
# For complex types, the (real, imag) pairs are sorted lexicographically
# (following NumPy's semantics).
# This code adds complex-number support and lexicographic ordering to the algorithm from:
@ -4364,6 +4381,9 @@ _two: Callable = partial(full_like, shape=(), fill_value=2)
dtype: Callable = partial(dtypes.dtype, canonicalize=True)
_dtype: Callable = partial(dtypes.dtype, canonicalize=True)
def _isnan(x) -> bool:
return ne(x, x)
def _iscomplex(x) -> bool:
return dtypes.issubdtype(_dtype(x), np.complexfloating)

View File

@ -2979,18 +2979,20 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": f"_dtype={dtype.__name__}", "dtype": dtype}
for dtype in inexact_dtypes))
def testSearchsortedNans(self, dtype):
{"testcase_name": f"_dtype={dtype.__name__}_side={side}", "dtype": dtype, "side": side}
for dtype in inexact_dtypes
for side in ['left', 'right']))
def testSearchsortedNans(self, dtype, side):
if np.issubdtype(dtype, np.complexfloating):
raise SkipTest("Known failure for complex inputs; see #9107")
sorted = jnp.array([-np.nan, -np.inf, -1, 0, 1, np.inf, np.nan], dtype=dtype)
self.assertArraysEqual(
jnp.searchsorted(sorted, sorted, side='left'),
jnp.arange(len(sorted)))
self.assertArraysEqual(
jnp.searchsorted(sorted, sorted, side='right'),
jnp.arange(1, 1 + len(sorted)))
x = np.array([-np.inf, -1.0, 0.0, -0.0, 1.0, np.inf, np.nan, -np.nan], dtype=dtype)
# The sign bit should not matter for 0.0 or NaN, so argsorting the above should be
# equivalent to argsorting the following:
x_equiv = np.array([0, 1, 2, 2, 3, 4, 5, 5])
fun = partial(jnp.searchsorted, side=side)
self.assertArraysEqual(fun(x, x), fun(x_equiv, x_equiv))
self.assertArraysEqual(jax.jit(fun)(x, x), fun(x_equiv, x_equiv))
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_x={}_bins={}_right={}_reverse={}".format(

View File

@ -1872,6 +1872,19 @@ class LaxTest(jtu.JaxTestCase):
fun = lambda x: lax.sort(x, dimension=axis, is_stable=is_stable)
self._CompileAndCheck(fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": f"_dtype={dtype.__name__}", "dtype": dtype}
for dtype in float_dtypes))
def testSortFloatSpecialValues(self, dtype):
# Test confirms that
# - NaNs are sorted to the end, regardless of representation
# - sign bit of 0.0 is ignored
x = jnp.array([-np.inf, 0.0, -0.0, np.inf, np.nan, -np.nan], dtype=dtype)
index = lax.iota(dtypes.int_, x.size)
argsort = lambda x: lax.sort_key_val(x, lax.iota(dtypes.int_, x.size), is_stable=True)[1]
self.assertArraysEqual(argsort(x), index)
self.assertArraysEqual(jax.jit(argsort)(x), index)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}_axis={}_isstable={}".format(
jtu.format_shape_dtype_string(shape, dtype), axis, is_stable),