add tests for np.sort (c.f. #221)

This commit is contained in:
Matthew Johnson 2019-01-13 09:01:01 -08:00
parent dfa2cb821f
commit 5a4713f108
3 changed files with 39 additions and 4 deletions

View File

@ -391,7 +391,7 @@ def sort(operand, dimension=-1):
return sort_p.bind(operand, dimension=dimension)
def sort_key_val(keys, values, dimension=-1):
# TODO new sort_key_val is variadic
# TODO(mattjj): new sort_key_val is variadic
result = sort_key_val_p.bind(keys, values, dimension=dimension)
sorted_keys, sorted_values = result
return sorted_keys, sorted_values

View File

@ -235,7 +235,6 @@ absolute = abs = _one_to_one_unop(onp.absolute, lax.abs)
fabs = _one_to_one_unop(onp.fabs, lax.abs, True)
bitwise_not = _one_to_one_unop(onp.bitwise_not, lax.bitwise_not)
negative = _one_to_one_unop(onp.negative, lax.neg)
sort = _one_to_one_unop(onp.sort, lax.sort)
sign = _one_to_one_unop(onp.sign, lax.sign)
floor = _one_to_one_unop(onp.floor, lax.floor, True)
@ -1523,6 +1522,20 @@ def _argminmax(op, a, axis):
mask_idxs = where(lax._eq_meet(a, op(a, axis, keepdims=True)), idxs, maxval)
return min(mask_idxs, axis)
@_wraps(onp.sort)
def sort(a, axis=-1, kind='quicksort', order=None):
if kind != 'quicksort':
warnings.warn("'kind' argument to sort is ignored.")
if order is not None:
msg = "'order' argument to sort is not supported."
raise ValueError(msg)
if axis is None:
return lax.sort(a.ravel(), 0)
else:
return lax.sort(a, axis % ndim(a))
### Indexing
@ -1722,7 +1735,8 @@ def _static_idx(idx, size):
def _not_implemented(fun):
@_wraps(fun)
def wrapped(*args, **kwargs):
raise Exception("Numpy function {} not yet implemented".format(fun))
msg = "Numpy function {} not yet implemented"
raise NotImplementedError(msg.format(fun))
return wrapped
# Build a set of all unimplemented NumPy functions.

View File

@ -932,7 +932,6 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
# TODO(mattjj): test infix operator overrides
def testRavel(self):
# TODO(mattjj): support this method-based syntax?
rng = onp.random.RandomState(0)
args_maker = lambda: [rng.randn(3, 4).astype("float32")]
self._CompileAndCheck(lambda x: x.ravel(), args_maker, check_dtypes=True)
@ -959,6 +958,28 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
ans = lnp.arange(0.0, 1.0, 0.1)
self.assertAllClose(expected, ans, check_dtypes=True)
def testSortManual(self):
# manual tests for sort are nice because we don't have to worry about ties.
# lax.sort is tested combinatorially.
ans = lnp.sort(onp.array([16, 15, 23, 42, 8, 4]))
expected = onp.array([4, 8, 15, 16, 23, 42])
self.assertAllClose(expected, ans, check_dtypes=True)
a = onp.array([[1, 4], [3, 1]])
ans = lnp.sort(a, axis=None)
expected = onp.array([[1, 1, 3, 4]])
self.assertAllClose(expected, ans, check_dtypes=True)
a = onp.array([[1, 4], [3, 1]])
ans = lnp.sort(a) # last axis
expected = onp.array([[1, 4], [1, 3]])
self.assertAllClose(expected, ans, check_dtypes=True)
a = onp.array([[1, 4], [3, 1]])
ans = lnp.sort(a, axis=0)
expected = onp.array([[1, 1], [3, 4]])
self.assertAllClose(expected, ans, check_dtypes=True)
if __name__ == "__main__":
absltest.main()