mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Merge pull request #9178 from jakevdp:sort-corner-cases
PiperOrigin-RevId: 421646783
This commit is contained in:
commit
cd73a4195f
@ -23,6 +23,13 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
|
||||
environment variable, or the ```--jax_host_callback_ad_transforms``` flag.
|
||||
Additionally, added documentation for how to implement the old behavior
|
||||
using JAX custom AD APIs ({jax-issue}`#8678`).
|
||||
* Sorting now matches the behavior of NumPy for ``0.0`` and ``NaN`` regardless of the
|
||||
bit representation. In particular, ``0.0`` and ``-0.0`` are now treated as equivalent,
|
||||
where previously ``-0.0`` was treated as less than ``0.0``. Additionally, all ``NaN``
|
||||
representations are now treated as equivalent and sorted to the end of the array.
|
||||
Previously negative ``NaN`` values were sorted to the front of the array, and ``NaN``
|
||||
values with different internal bit representations were not treated as equivalent, and
|
||||
were sorted according to those bit patterns ({jax-issue}`#9178`).
|
||||
|
||||
* Bug fixes:
|
||||
* host_callback now supports ad_checkpoint.checkpoint ({jax-issue}`#8907`).
|
||||
|
@ -960,8 +960,11 @@ def _reduce_and(operand: Array, axes: Sequence[int]) -> Array:
|
||||
def sort(operand: Union[Array, Sequence[Array]], dimension: int = -1,
|
||||
is_stable: bool = True, num_keys: int = 1) -> Union[Array, Tuple[Array, ...]]:
|
||||
"""Wraps XLA's `Sort
|
||||
<https://www.tensorflow.org/xla/operation_semantics#sort>`_
|
||||
operator.
|
||||
<https://www.tensorflow.org/xla/operation_semantics#sort>`_ operator.
|
||||
|
||||
For floating point inputs, -0.0 and 0.0 are treated as equivalent, and NaN values
|
||||
are sorted to the end of the array. For complex inputs, the sort order is
|
||||
lexicographic over the real and imaginary parts, with the real part primary.
|
||||
|
||||
Args:
|
||||
operand : Array or sequence of arrays
|
||||
@ -3671,8 +3674,10 @@ def _float_to_int_for_sort(x):
|
||||
# x = bit_cast<int32>(f);
|
||||
# y = x < 0 ? int32_max - x : x;
|
||||
# then y is ordered as an int32 such that finite values have the obvious
|
||||
# order, -0 is ordered before 0, and -NaN and NaN appear at the beginning
|
||||
# and end of the ordering.
|
||||
# order. In this scheme, -0 would be before 0, and -NaN and NaN appear at
|
||||
# the beginning and end of the ordering. This causes issues for stable
|
||||
# sorts, so we avoid this by standardizing the representation of zeros
|
||||
# and NaNs in the output.
|
||||
# Note that in order to keep -x from overflowing, we calculate
|
||||
# int32_max - x as unsigned, and then convert back to signed.
|
||||
if x.dtype == dtypes.bfloat16:
|
||||
@ -3683,6 +3688,17 @@ def _float_to_int_for_sort(x):
|
||||
|
||||
signed = bitcast_convert_type(x, signed_dtype)
|
||||
unsigned = bitcast_convert_type(x, unsigned_dtype)
|
||||
|
||||
# We cannot standardize zeros in x because XLA elides this in some cases.
|
||||
# We cannot standardize NaNs in x because it triggers jax.debug_nans.
|
||||
# So instead we do these replacements in the signed integer representation.
|
||||
|
||||
# Standardize zeros:
|
||||
signed = select(eq(x, _zero(x)), _zeros(signed), signed)
|
||||
# Standardize nans:
|
||||
signed_nan = x.dtype.type(np.nan).view(signed_dtype)
|
||||
signed = select(_isnan(x), full_like(signed, signed_nan), signed)
|
||||
|
||||
flipped = bitcast_convert_type(
|
||||
sub(unsigned_dtype.type(np.iinfo(signed_dtype).max), unsigned), signed_dtype)
|
||||
return select(lt(signed, _zero(signed)), flipped, signed)
|
||||
@ -3690,7 +3706,8 @@ def _float_to_int_for_sort(x):
|
||||
# Default comparator that sorts the operands lexicographically on the
|
||||
# first `num_keys` arguments.
|
||||
# For floating point types, a total order is created where
|
||||
# -NaN < -infinity < ... < -0 < 0 < ... < infinity < NaN.
|
||||
# -infinity < ... < 0 < ... < infinity < NaN.
|
||||
# 0.0 and -0.0 are treated as equivalent, as are all NaN representations.
|
||||
# For complex types, the (real, imag) pairs are sorted lexicographically
|
||||
# (following NumPy's semantics).
|
||||
# This code adds complex-number support and lexicographic ordering to the algorithm from:
|
||||
@ -4364,6 +4381,9 @@ _two: Callable = partial(full_like, shape=(), fill_value=2)
|
||||
dtype: Callable = partial(dtypes.dtype, canonicalize=True)
|
||||
_dtype: Callable = partial(dtypes.dtype, canonicalize=True)
|
||||
|
||||
def _isnan(x) -> bool:
|
||||
return ne(x, x)
|
||||
|
||||
def _iscomplex(x) -> bool:
|
||||
return dtypes.issubdtype(_dtype(x), np.complexfloating)
|
||||
|
||||
|
@ -2979,18 +2979,20 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": f"_dtype={dtype.__name__}", "dtype": dtype}
|
||||
for dtype in inexact_dtypes))
|
||||
def testSearchsortedNans(self, dtype):
|
||||
{"testcase_name": f"_dtype={dtype.__name__}_side={side}", "dtype": dtype, "side": side}
|
||||
for dtype in inexact_dtypes
|
||||
for side in ['left', 'right']))
|
||||
def testSearchsortedNans(self, dtype, side):
|
||||
if np.issubdtype(dtype, np.complexfloating):
|
||||
raise SkipTest("Known failure for complex inputs; see #9107")
|
||||
sorted = jnp.array([-np.nan, -np.inf, -1, 0, 1, np.inf, np.nan], dtype=dtype)
|
||||
self.assertArraysEqual(
|
||||
jnp.searchsorted(sorted, sorted, side='left'),
|
||||
jnp.arange(len(sorted)))
|
||||
self.assertArraysEqual(
|
||||
jnp.searchsorted(sorted, sorted, side='right'),
|
||||
jnp.arange(1, 1 + len(sorted)))
|
||||
x = np.array([-np.inf, -1.0, 0.0, -0.0, 1.0, np.inf, np.nan, -np.nan], dtype=dtype)
|
||||
# The sign bit should not matter for 0.0 or NaN, so searchsorted on the above should be
|
||||
# equivalent to searchsorted on the following:
|
||||
x_equiv = np.array([0, 1, 2, 2, 3, 4, 5, 5])
|
||||
|
||||
fun = partial(jnp.searchsorted, side=side)
|
||||
self.assertArraysEqual(fun(x, x), fun(x_equiv, x_equiv))
|
||||
self.assertArraysEqual(jax.jit(fun)(x, x), fun(x_equiv, x_equiv))
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_x={}_bins={}_right={}_reverse={}".format(
|
||||
|
@ -1872,6 +1872,19 @@ class LaxTest(jtu.JaxTestCase):
|
||||
fun = lambda x: lax.sort(x, dimension=axis, is_stable=is_stable)
|
||||
self._CompileAndCheck(fun, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": f"_dtype={dtype.__name__}", "dtype": dtype}
|
||||
for dtype in float_dtypes))
|
||||
def testSortFloatSpecialValues(self, dtype):
|
||||
# Test confirms that
|
||||
# - NaNs are sorted to the end, regardless of representation
|
||||
# - sign bit of 0.0 is ignored
|
||||
x = jnp.array([-np.inf, 0.0, -0.0, np.inf, np.nan, -np.nan], dtype=dtype)
|
||||
index = lax.iota(dtypes.int_, x.size)
|
||||
argsort = lambda x: lax.sort_key_val(x, lax.iota(dtypes.int_, x.size), is_stable=True)[1]
|
||||
self.assertArraysEqual(argsort(x), index)
|
||||
self.assertArraysEqual(jax.jit(argsort)(x), index)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_shape={}_axis={}_isstable={}".format(
|
||||
jtu.format_shape_dtype_string(shape, dtype), axis, is_stable),
|
||||
|
Loading…
x
Reference in New Issue
Block a user