Merge pull request #1088 from fehiepsi/median

Add numpy.median and support ddof for numpy.var
This commit is contained in:
Matthew Johnson 2019-08-01 20:57:28 -07:00 committed by GitHub
commit fd98f957a9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 46 additions and 8 deletions

View File

@ -1083,8 +1083,6 @@ def var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False):
if out is not None:
raise ValueError("var does not support the `out` argument.")
if ddof != 0:
raise NotImplementedError("Only implemented for ddof=0.")
if dtype is None:
if (onp.issubdtype(_dtype(a), onp.bool_) or
onp.issubdtype(_dtype(a), onp.integer)):
@ -1092,7 +1090,16 @@ def var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False):
centered = subtract(a, mean(a, axis, dtype=dtype, keepdims=True))
if iscomplexobj(centered):
centered = lax.abs(centered)
return mean(lax.mul(centered, centered), axis, dtype=dtype, keepdims=keepdims)
if axis is None:
normalizer = size(a)
else:
normalizer = onp.prod(onp.take(shape(a), axis))
normalizer = normalizer - ddof
return lax.div(
sum(lax.mul(centered, centered), axis, dtype=dtype, keepdims=keepdims),
lax.convert_element_type(normalizer, dtype))
@_wraps(onp.std)
@ -2720,6 +2727,13 @@ def percentile(a, q, axis=None, out=None, overwrite_input=False,
return quantile(a, q, axis=axis, out=out, overwrite_input=overwrite_input,
interpolation=interpolation, keepdims=keepdims)
@_wraps(onp.median)
def median(a, axis=None, out=None, overwrite_input=False, keepdims=False):
q = 0.5
return quantile(a, q, axis=axis, out=out, overwrite_input=overwrite_input,
keepdims=keepdims)
### track unimplemented functions
def _not_implemented(fun):

View File

@ -1521,6 +1521,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
for (op, q_rng) in (
("percentile", jtu.rand_uniform(low=0., high=100.)),
("quantile", jtu.rand_uniform(low=0., high=1.)),
("median", jtu.rand_uniform(low=0., high=1.)),
)
for a_dtype in float_dtypes
for a_shape, axis in (
@ -1536,7 +1537,10 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
axis, keepdims):
if op == "quantile" and numpy_version < (1, 15):
raise SkipTest("Numpy < 1.15 does not have np.quantile")
args_maker = lambda: [a_rng(a_shape, a_dtype), q_rng(q_shape, q_dtype)]
if op == "median":
args_maker = lambda: [a_rng(a_shape, a_dtype)]
else:
args_maker = lambda: [a_rng(a_shape, a_dtype), q_rng(q_shape, q_dtype)]
onp_fun = partial(getattr(onp, op), axis=axis, keepdims=keepdims)
lnp_fun = partial(getattr(lnp, op), axis=axis, keepdims=keepdims)
# TODO(phawkins): we currently set dtype=False because we aren't as
@ -1739,12 +1743,32 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
def testIssue956(self):
self.assertRaises(TypeError, lambda: lnp.ndarray((1, 1)))
@parameterized.named_parameters(
jtu.cases_from_list(
{"testcase_name": "_shape={}_dtype={}_axis={}_ddof={}_keepdims={}"
.format(shape, dtype, axis, ddof, keepdims),
"shape": shape, "dtype": dtype, "out_dtype": out_dtype, "axis": axis,
"ddof": ddof, "keepdims": keepdims, "rng": rng}
for shape in [(5,), (10, 5)]
for dtype in all_dtypes
for out_dtype in number_dtypes
for axis in [None, 0, -1]
for ddof in [0, 1, 2]
for keepdims in [False, True]
for rng in [jtu.rand_default()]))
def testVar(self, shape, dtype, out_dtype, axis, ddof, keepdims, rng):
args_maker = self._GetArgsMaker(rng, [shape], [dtype])
onp_fun = partial(onp.var, dtype=out_dtype, axis=axis, ddof=ddof, keepdims=keepdims)
lnp_fun = partial(lnp.var, dtype=out_dtype, axis=axis, ddof=ddof, keepdims=keepdims)
self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)
@parameterized.named_parameters(
jtu.cases_from_list(
{"testcase_name": "_shape={}_dtype={}_rowvar={}_ddof={}_bias={}".format(
shape, dtype, rowvar, ddof, bias),
"shape":shape, "dtype":dtype, "rowvar":rowvar, "ddof":ddof,
"bias":bias, "rng": rng}
"shape": shape, "dtype": dtype, "rowvar": rowvar, "ddof": ddof,
"bias": bias, "rng": rng}
for shape in [(5,), (10, 5), (3, 10)]
for dtype in all_dtypes
for rowvar in [True, False]
@ -1766,8 +1790,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
jtu.cases_from_list(
{"testcase_name": "_shape={}_dtype={}_rowvar={}_ddof={}_bias={}".format(
shape, dtype, rowvar, ddof, bias),
"shape":shape, "dtype":dtype, "rowvar":rowvar, "ddof":ddof,
"bias":bias, "rng": rng}
"shape": shape, "dtype": dtype, "rowvar": rowvar, "ddof": ddof,
"bias": bias, "rng": rng}
for shape in [(5,), (10, 5), (3, 10)]
for dtype in number_dtypes
for rowvar in [True, False]