Merge pull request #14370 from jakevdp:argpartition-impl

PiperOrigin-RevId: 508194466
This commit is contained in:
jax authors 2023-02-08 15:10:50 -08:00
commit ccb974a150
4 changed files with 85 additions and 22 deletions

View File

@ -68,6 +68,7 @@ namespace; they are listed below.
arctanh
argmax
argmin
argpartition
argsort
argwhere
around

View File

@ -3529,11 +3529,17 @@ def msort(a):
@_wraps(np.partition, lax_description="""
The jax version requires the ``kth`` argument to be a static integer rather than
a general array. This is implemented via two calls to :func:`jax.lax.top_k`.
The JAX version requires the ``kth`` argument to be a static integer rather than
a general array. This is implemented via two calls to :func:`jax.lax.top_k`. If
you're only accessing the top or bottom k values of the output, it may be more
efficient to call :func:`jax.lax.top_k` directly.
The JAX version differs from the NumPy version in the treatment of NaN entries;
NaNs which have the negative bit set are sorted to the beginning of the array.
""")
@partial(jit, static_argnames=['kth', 'axis'])
def partition(a: ArrayLike, kth: int, axis: int = -1) -> Array:
# TODO(jakevdp): handle NaN values like numpy.
_check_arraylike("partition", a)
arr = asarray(a)
if issubdtype(arr.dtype, np.complexfloating):
@ -3548,6 +3554,38 @@ def partition(a: ArrayLike, kth: int, axis: int = -1) -> Array:
return swapaxes(out, -1, axis)
@_wraps(np.argpartition, lax_description="""
The JAX version requires the ``kth`` argument to be a static integer rather than
a general array. This is implemented via two calls to :func:`jax.lax.top_k`. If
you're only accessing the top or bottom k values of the output, it may be more
efficient to call :func:`jax.lax.top_k` directly.
The JAX version differs from the NumPy version in the treatment of NaN entries;
NaNs which have the negative bit set are sorted to the beginning of the array.
""")
@partial(jit, static_argnames=['kth', 'axis'])
def argpartition(a: ArrayLike, kth: int, axis: int = -1) -> Array:
  # Returns indices along `axis` such that the first `kth + 1` indices select
  # the smallest entries (as ranked by top_k on the negated input) and the
  # remaining indices select everything else.
  # TODO(jakevdp): handle NaN values like numpy.
  # Validate under this function's own name so that error messages agree with
  # the NotImplementedError below.
  _check_arraylike("argpartition", a)
  arr = asarray(a)
  if issubdtype(arr.dtype, np.complexfloating):
    raise NotImplementedError("jnp.argpartition for complex dtype is not implemented.")
  # Normalize possibly-negative `axis` and `kth` to non-negative positions.
  axis = _canonicalize_axis(axis, arr.ndim)
  kth = _canonicalize_axis(kth, arr.shape[axis])
  # top_k is applied to the trailing dimension, so move the target axis last;
  # the swap is undone on the result before returning.
  arr = swapaxes(arr, axis, -1)
  if issubdtype(arr.dtype, np.unsignedinteger):
    # Unsigned negation wraps around: -0 stays 0, which would rank zero as the
    # largest entry. Negating arr + 1 instead is strictly decreasing over the
    # full unsigned range, so the ordering is reversed correctly.
    negated = -(arr + 1)
  else:
    negated = -arr
  # Indices of the kth + 1 smallest entries (largest of the negated array).
  bottom_ind = lax.top_k(negated, kth + 1)[1]
  # To avoid issues with duplicate values, we compute the top indices via a proxy:
  # an array of ones with the already-selected bottom positions zeroed, so the
  # second top_k call picks exactly the positions not yet chosen.
  set_to_zero = lambda a, i: a.at[i].set(0)
  # Each vmap peels off one leading dimension so that the update acts on
  # individual 1D rows of the trailing axis.
  for _ in range(arr.ndim - 1):
    set_to_zero = jax.vmap(set_to_zero)
  proxy = set_to_zero(ones(arr.shape), bottom_ind)
  top_ind = lax.top_k(proxy, arr.shape[-1] - kth - 1)[1]
  out = lax.concatenate([bottom_ind, top_ind], dimension=arr.ndim - 1)
  return swapaxes(out, -1, axis)
@partial(jit, static_argnums=(2,))
def _roll(a, shift, axis):
a_shape = shape(a)
@ -4947,19 +4985,6 @@ def _notimplemented_flat(self):
raise NotImplementedError("JAX DeviceArrays do not implement the arr.flat property: "
"consider arr.flatten() instead.")
### track unimplemented functions
_NOT_IMPLEMENTED_DESC = """
*** This function is not yet implemented by jax.numpy, and will raise NotImplementedError ***
"""
def _not_implemented(fun, module=None):
  """Build a placeholder for ``fun`` that always raises NotImplementedError.

  The placeholder is passed through ``_wraps`` with ``update_doc=False`` and
  ``_NOT_IMPLEMENTED_DESC`` as its ``lax_description``; ``module`` is forwarded
  to ``_wraps`` unchanged.
  """
  decorate = _wraps(fun, module=module, update_doc=False,
                    lax_description=_NOT_IMPLEMENTED_DESC)

  def wrapped(*args, **kwargs):
    # Arguments are accepted and ignored: every call fails the same way.
    raise NotImplementedError("Numpy function {} not yet implemented".format(fun))

  return decorate(wrapped)
@_wraps(np.place, lax_description="""
Numpy function :func:`numpy.place` is not available in JAX and will raise a
@ -5086,12 +5111,6 @@ _diff_methods = ["choose", "conj", "conjugate", "copy", "cumprod", "cumsum",
"ravel", "repeat", "sort", "squeeze", "std", "sum",
"swapaxes", "take", "trace", "var"]
# These methods are mentioned explicitly by nondiff_methods, so we create
# _not_implemented implementations of them here rather than in __init__.py.
# TODO(phawkins): implement these.
argpartition = _not_implemented(np.argpartition)
# Experimental support for NumPy's module dispatch with NEP-37.
# Currently requires https://github.com/seberg/numpy-dispatch
_JAX_ARRAY_TYPES = (device_array.DeviceArray, core.Tracer, ArrayImpl)

View File

@ -174,6 +174,7 @@ from jax._src.numpy.lax_numpy import (
nan_to_num as nan_to_num,
nanargmax as nanargmax,
nanargmin as nanargmin,
argpartition as argpartition,
nanmedian as nanmedian,
nanpercentile as nanpercentile,
nanquantile as nanquantile,

View File

@ -3599,10 +3599,12 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
jnp_output = jnp.partition(arg, axis=axis, kth=kth)
np_output = np.partition(arg, axis=axis, kth=kth)
# Assert that pivot point is equal
# Assert that pivot point is equal:
self.assertArraysEqual(
lax.index_in_dim(jnp_output, axis=axis, index=kth),
lax.index_in_dim(np_output, axis=axis, index=kth))
# Assert remaining values are correctly partitioned:
self.assertArraysEqual(
lax.sort(lax.slice_in_dim(jnp_output, start_index=0, limit_index=kth, axis=axis), dimension=axis),
lax.sort(lax.slice_in_dim(np_output, start_index=0, limit_index=kth, axis=axis), dimension=axis))
@ -3610,6 +3612,45 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
lax.sort(lax.slice_in_dim(jnp_output, start_index=kth + 1, limit_index=shape[axis], axis=axis), dimension=axis),
lax.sort(lax.slice_in_dim(np_output, start_index=kth + 1, limit_index=shape[axis], axis=axis), dimension=axis))
@jtu.sample_product(
  [dict(shape=shape, axis=axis, kth=kth)
    for shape in nonzerodim_shapes
    for axis in range(-len(shape), len(shape))
    for kth in range(-shape[axis], shape[axis])],
  dtype=default_dtypes,
)
def testArgpartition(self, shape, dtype, axis, kth):
  # Compare jnp.argpartition against np.argpartition for every axis/kth
  # combination, including negative values of both.
  make_input = jtu.rand_default(self.rng())
  operand = make_input(shape, dtype)
  jnp_indices = jnp.argpartition(operand, axis=axis, kth=kth)
  np_indices = np.argpartition(operand, axis=axis, kth=kth)

  # Assert that all indices are present
  self.assertArraysEqual(jnp.sort(jnp_indices, axis), np.sort(np_indices, axis), check_dtypes=False)

  # Because JAX & numpy may treat duplicates differently, we must compare values
  # rather than indices. The gather is vmapped over every axis except the
  # partition axis so that it indexes along that axis only.
  partition_axis = range(operand.ndim)[axis]  # non-negative form of `axis`
  gather = lambda x, ind: x[ind]
  for ax in range(operand.ndim):
    if ax == partition_axis:
      continue
    gather = jax.vmap(gather, in_axes=ax, out_axes=ax)
  jnp_values = gather(operand, jnp_indices)
  np_values = gather(operand, np_indices)

  def sorted_slice(values, start, limit):
    # Sort a slice along the partition axis so that order within it is ignored.
    piece = lax.slice_in_dim(values, start_index=start, limit_index=limit, axis=axis)
    return lax.sort(piece, dimension=axis)

  # Assert that pivot point is equal:
  self.assertArraysEqual(
      lax.index_in_dim(jnp_values, axis=axis, index=kth),
      lax.index_in_dim(np_values, axis=axis, index=kth))

  # Assert remaining values are correctly partitioned:
  self.assertArraysEqual(
      sorted_slice(jnp_values, 0, kth),
      sorted_slice(np_values, 0, kth))
  self.assertArraysEqual(
      sorted_slice(jnp_values, kth + 1, shape[axis]),
      sorted_slice(np_values, kth + 1, shape[axis]))
@jtu.sample_product(
[dict(shifts=shifts, axis=axis)
for shifts, axis in [
@ -5183,6 +5224,7 @@ class NumpySignaturesTest(jtu.JaxTestCase):
# TODO(jakevdp): fix some of the following signatures. Some are due to wrong argument names.
unsupported_params = {
'argpartition': ['kind', 'order'],
'asarray': ['like'],
'broadcast_to': ['subok'],
'clip': ['kwargs'],