mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
add tests for np.sort (c.f. #221)
This commit is contained in:
parent
dfa2cb821f
commit
5a4713f108
@ -391,7 +391,7 @@ def sort(operand, dimension=-1):
|
||||
return sort_p.bind(operand, dimension=dimension)
|
||||
|
||||
def sort_key_val(keys, values, dimension=-1):
|
||||
# TODO new sort_key_val is variadic
|
||||
# TODO(mattjj): new sort_key_val is variadic
|
||||
result = sort_key_val_p.bind(keys, values, dimension=dimension)
|
||||
sorted_keys, sorted_values = result
|
||||
return sorted_keys, sorted_values
|
||||
|
@ -235,7 +235,6 @@ absolute = abs = _one_to_one_unop(onp.absolute, lax.abs)
|
||||
fabs = _one_to_one_unop(onp.fabs, lax.abs, True)
|
||||
bitwise_not = _one_to_one_unop(onp.bitwise_not, lax.bitwise_not)
|
||||
negative = _one_to_one_unop(onp.negative, lax.neg)
|
||||
sort = _one_to_one_unop(onp.sort, lax.sort)
|
||||
sign = _one_to_one_unop(onp.sign, lax.sign)
|
||||
|
||||
floor = _one_to_one_unop(onp.floor, lax.floor, True)
|
||||
@ -1523,6 +1522,20 @@ def _argminmax(op, a, axis):
|
||||
mask_idxs = where(lax._eq_meet(a, op(a, axis, keepdims=True)), idxs, maxval)
|
||||
return min(mask_idxs, axis)
|
||||
|
||||
|
||||
@_wraps(onp.sort)
|
||||
def sort(a, axis=-1, kind='quicksort', order=None):
|
||||
if kind != 'quicksort':
|
||||
warnings.warn("'kind' argument to sort is ignored.")
|
||||
if order is not None:
|
||||
msg = "'order' argument to sort is not supported."
|
||||
raise ValueError(msg)
|
||||
if axis is None:
|
||||
return lax.sort(a.ravel(), 0)
|
||||
else:
|
||||
return lax.sort(a, axis % ndim(a))
|
||||
|
||||
|
||||
### Indexing
|
||||
|
||||
|
||||
@ -1722,7 +1735,8 @@ def _static_idx(idx, size):
|
||||
def _not_implemented(fun):
|
||||
@_wraps(fun)
|
||||
def wrapped(*args, **kwargs):
|
||||
raise Exception("Numpy function {} not yet implemented".format(fun))
|
||||
msg = "Numpy function {} not yet implemented"
|
||||
raise NotImplementedError(msg.format(fun))
|
||||
return wrapped
|
||||
|
||||
# Build a set of all unimplemented NumPy functions.
|
||||
|
@ -932,7 +932,6 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
# TODO(mattjj): test infix operator overrides
|
||||
|
||||
def testRavel(self):
|
||||
# TODO(mattjj): support this method-based syntax?
|
||||
rng = onp.random.RandomState(0)
|
||||
args_maker = lambda: [rng.randn(3, 4).astype("float32")]
|
||||
self._CompileAndCheck(lambda x: x.ravel(), args_maker, check_dtypes=True)
|
||||
@ -959,6 +958,28 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
ans = lnp.arange(0.0, 1.0, 0.1)
|
||||
self.assertAllClose(expected, ans, check_dtypes=True)
|
||||
|
||||
def testSortManual(self):
|
||||
# manual tests for sort are nice because we don't have to worry about ties.
|
||||
# lax.sort is tested combinatorially.
|
||||
ans = lnp.sort(onp.array([16, 15, 23, 42, 8, 4]))
|
||||
expected = onp.array([4, 8, 15, 16, 23, 42])
|
||||
self.assertAllClose(expected, ans, check_dtypes=True)
|
||||
|
||||
a = onp.array([[1, 4], [3, 1]])
|
||||
ans = lnp.sort(a, axis=None)
|
||||
expected = onp.array([[1, 1, 3, 4]])
|
||||
self.assertAllClose(expected, ans, check_dtypes=True)
|
||||
|
||||
a = onp.array([[1, 4], [3, 1]])
|
||||
ans = lnp.sort(a) # last axis
|
||||
expected = onp.array([[1, 4], [1, 3]])
|
||||
self.assertAllClose(expected, ans, check_dtypes=True)
|
||||
|
||||
a = onp.array([[1, 4], [3, 1]])
|
||||
ans = lnp.sort(a, axis=0)
|
||||
expected = onp.array([[1, 1], [3, 4]])
|
||||
self.assertAllClose(expected, ans, check_dtypes=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main()
|
||||
|
Loading…
x
Reference in New Issue
Block a user