mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #14370 from jakevdp:argpartition-impl
PiperOrigin-RevId: 508194466
This commit is contained in:
commit
ccb974a150
@ -68,6 +68,7 @@ namespace; they are listed below.
|
||||
arctanh
|
||||
argmax
|
||||
argmin
|
||||
argpartition
|
||||
argsort
|
||||
argwhere
|
||||
around
|
||||
|
@ -3529,11 +3529,17 @@ def msort(a):
|
||||
|
||||
|
||||
@_wraps(np.partition, lax_description="""
|
||||
The jax version requires the ``kth`` argument to be a static integer rather than
|
||||
a general array. This is implemented via two calls to :func:`jax.lax.top_k`.
|
||||
The JAX version requires the ``kth`` argument to be a static integer rather than
|
||||
a general array. This is implemented via two calls to :func:`jax.lax.top_k`. If
|
||||
you're only accessing the top or bottom k values of the output, it may be more
|
||||
efficient to call :func:`jax.lax.top_k` directly.
|
||||
|
||||
The JAX version differs from the NumPy version in the treatment of NaN entries;
|
||||
NaNs which have the negative bit set are sorted to the beginning of the array.
|
||||
""")
|
||||
@partial(jit, static_argnames=['kth', 'axis'])
|
||||
def partition(a: ArrayLike, kth: int, axis: int = -1) -> Array:
|
||||
# TODO(jakevdp): handle NaN values like numpy.
|
||||
_check_arraylike("partition", a)
|
||||
arr = asarray(a)
|
||||
if issubdtype(arr.dtype, np.complexfloating):
|
||||
@ -3548,6 +3554,38 @@ def partition(a: ArrayLike, kth: int, axis: int = -1) -> Array:
|
||||
return swapaxes(out, -1, axis)
|
||||
|
||||
|
||||
@_wraps(np.argpartition, lax_description="""
|
||||
The JAX version requires the ``kth`` argument to be a static integer rather than
|
||||
a general array. This is implemented via two calls to :func:`jax.lax.top_k`. If
|
||||
you're only accessing the top or bottom k values of the output, it may be more
|
||||
efficient to call :func:`jax.lax.top_k` directly.
|
||||
|
||||
The JAX version differs from the NumPy version in the treatment of NaN entries;
|
||||
NaNs which have the negative bit set are sorted to the beginning of the array.
|
||||
""")
|
||||
@partial(jit, static_argnames=['kth', 'axis'])
|
||||
def argpartition(a: ArrayLike, kth: int, axis: int = -1) -> Array:
|
||||
# TODO(jakevdp): handle NaN values like numpy.
|
||||
_check_arraylike("partition", a)
|
||||
arr = asarray(a)
|
||||
if issubdtype(arr.dtype, np.complexfloating):
|
||||
raise NotImplementedError("jnp.argpartition for complex dtype is not implemented.")
|
||||
axis = _canonicalize_axis(axis, arr.ndim)
|
||||
kth = _canonicalize_axis(kth, arr.shape[axis])
|
||||
|
||||
arr = swapaxes(arr, axis, -1)
|
||||
bottom_ind = lax.top_k(-arr, kth + 1)[1]
|
||||
|
||||
# To avoid issues with duplicate values, we compute the top indices via a proxy
|
||||
set_to_zero = lambda a, i: a.at[i].set(0)
|
||||
for _ in range(arr.ndim - 1):
|
||||
set_to_zero = jax.vmap(set_to_zero)
|
||||
proxy = set_to_zero(ones(arr.shape), bottom_ind)
|
||||
top_ind = lax.top_k(proxy, arr.shape[-1] - kth - 1)[1]
|
||||
out = lax.concatenate([bottom_ind, top_ind], dimension=arr.ndim - 1)
|
||||
return swapaxes(out, -1, axis)
|
||||
|
||||
|
||||
@partial(jit, static_argnums=(2,))
|
||||
def _roll(a, shift, axis):
|
||||
a_shape = shape(a)
|
||||
@ -4947,19 +4985,6 @@ def _notimplemented_flat(self):
|
||||
raise NotImplementedError("JAX DeviceArrays do not implement the arr.flat property: "
|
||||
"consider arr.flatten() instead.")
|
||||
|
||||
### track unimplemented functions
|
||||
|
||||
_NOT_IMPLEMENTED_DESC = """
|
||||
*** This function is not yet implemented by jax.numpy, and will raise NotImplementedError ***
|
||||
"""
|
||||
|
||||
def _not_implemented(fun, module=None):
|
||||
@_wraps(fun, module=module, update_doc=False, lax_description=_NOT_IMPLEMENTED_DESC)
|
||||
def wrapped(*args, **kwargs):
|
||||
msg = "Numpy function {} not yet implemented"
|
||||
raise NotImplementedError(msg.format(fun))
|
||||
return wrapped
|
||||
|
||||
|
||||
@_wraps(np.place, lax_description="""
|
||||
Numpy function :func:`numpy.place` is not available in JAX and will raise a
|
||||
@ -5086,12 +5111,6 @@ _diff_methods = ["choose", "conj", "conjugate", "copy", "cumprod", "cumsum",
|
||||
"ravel", "repeat", "sort", "squeeze", "std", "sum",
|
||||
"swapaxes", "take", "trace", "var"]
|
||||
|
||||
# These methods are mentioned explicitly by nondiff_methods, so we create
|
||||
# _not_implemented implementations of them here rather than in __init__.py.
|
||||
# TODO(phawkins): implement these.
|
||||
argpartition = _not_implemented(np.argpartition)
|
||||
|
||||
|
||||
# Experimental support for NumPy's module dispatch with NEP-37.
|
||||
# Currently requires https://github.com/seberg/numpy-dispatch
|
||||
_JAX_ARRAY_TYPES = (device_array.DeviceArray, core.Tracer, ArrayImpl)
|
||||
|
@ -174,6 +174,7 @@ from jax._src.numpy.lax_numpy import (
|
||||
nan_to_num as nan_to_num,
|
||||
nanargmax as nanargmax,
|
||||
nanargmin as nanargmin,
|
||||
argpartition as argpartition,
|
||||
nanmedian as nanmedian,
|
||||
nanpercentile as nanpercentile,
|
||||
nanquantile as nanquantile,
|
||||
|
@ -3599,10 +3599,12 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
jnp_output = jnp.partition(arg, axis=axis, kth=kth)
|
||||
np_output = np.partition(arg, axis=axis, kth=kth)
|
||||
|
||||
# Assert that pivot point is equal
|
||||
# Assert that pivot point is equal:
|
||||
self.assertArraysEqual(
|
||||
lax.index_in_dim(jnp_output, axis=axis, index=kth),
|
||||
lax.index_in_dim(np_output, axis=axis, index=kth))
|
||||
|
||||
# Assert remaining values are correctly partitioned:
|
||||
self.assertArraysEqual(
|
||||
lax.sort(lax.slice_in_dim(jnp_output, start_index=0, limit_index=kth, axis=axis), dimension=axis),
|
||||
lax.sort(lax.slice_in_dim(np_output, start_index=0, limit_index=kth, axis=axis), dimension=axis))
|
||||
@ -3610,6 +3612,45 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
lax.sort(lax.slice_in_dim(jnp_output, start_index=kth + 1, limit_index=shape[axis], axis=axis), dimension=axis),
|
||||
lax.sort(lax.slice_in_dim(np_output, start_index=kth + 1, limit_index=shape[axis], axis=axis), dimension=axis))
|
||||
|
||||
@jtu.sample_product(
|
||||
[{'shape': shape, 'axis': axis, 'kth': kth}
|
||||
for shape in nonzerodim_shapes
|
||||
for axis in range(-len(shape), len(shape))
|
||||
for kth in range(-shape[axis], shape[axis])],
|
||||
dtype=default_dtypes,
|
||||
)
|
||||
def testArgpartition(self, shape, dtype, axis, kth):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
arg = rng(shape, dtype)
|
||||
|
||||
jnp_output = jnp.argpartition(arg, axis=axis, kth=kth)
|
||||
np_output = np.argpartition(arg, axis=axis, kth=kth)
|
||||
|
||||
# Assert that all indices are present
|
||||
self.assertArraysEqual(jnp.sort(jnp_output, axis), np.sort(np_output, axis), check_dtypes=False)
|
||||
|
||||
# Because JAX & numpy may treat duplicates differently, we must compare values
|
||||
# rather than indices.
|
||||
getvals = lambda x, ind: x[ind]
|
||||
for ax in range(arg.ndim):
|
||||
if ax != range(arg.ndim)[axis]:
|
||||
getvals = jax.vmap(getvals, in_axes=ax, out_axes=ax)
|
||||
jnp_values = getvals(arg, jnp_output)
|
||||
np_values = getvals(arg, np_output)
|
||||
|
||||
# Assert that pivot point is equal:
|
||||
self.assertArraysEqual(
|
||||
lax.index_in_dim(jnp_values, axis=axis, index=kth),
|
||||
lax.index_in_dim(np_values, axis=axis, index=kth))
|
||||
|
||||
# Assert remaining values are correctly partitioned:
|
||||
self.assertArraysEqual(
|
||||
lax.sort(lax.slice_in_dim(jnp_values, start_index=0, limit_index=kth, axis=axis), dimension=axis),
|
||||
lax.sort(lax.slice_in_dim(np_values, start_index=0, limit_index=kth, axis=axis), dimension=axis))
|
||||
self.assertArraysEqual(
|
||||
lax.sort(lax.slice_in_dim(jnp_values, start_index=kth + 1, limit_index=shape[axis], axis=axis), dimension=axis),
|
||||
lax.sort(lax.slice_in_dim(np_values, start_index=kth + 1, limit_index=shape[axis], axis=axis), dimension=axis))
|
||||
|
||||
@jtu.sample_product(
|
||||
[dict(shifts=shifts, axis=axis)
|
||||
for shifts, axis in [
|
||||
@ -5183,6 +5224,7 @@ class NumpySignaturesTest(jtu.JaxTestCase):
|
||||
|
||||
# TODO(jakevdp): fix some of the following signatures. Some are due to wrong argument names.
|
||||
unsupported_params = {
|
||||
'argpartition': ['kind', 'order'],
|
||||
'asarray': ['like'],
|
||||
'broadcast_to': ['subok'],
|
||||
'clip': ['kwargs'],
|
||||
|
Loading…
x
Reference in New Issue
Block a user