mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #1088 from fehiepsi/median
Add numpy.median and support ddof for numpy.var
This commit is contained in:
commit
fd98f957a9
@ -1083,8 +1083,6 @@ def var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False):
|
||||
if out is not None:
|
||||
raise ValueError("var does not support the `out` argument.")
|
||||
|
||||
if ddof != 0:
|
||||
raise NotImplementedError("Only implemented for ddof=0.")
|
||||
if dtype is None:
|
||||
if (onp.issubdtype(_dtype(a), onp.bool_) or
|
||||
onp.issubdtype(_dtype(a), onp.integer)):
|
||||
@ -1092,7 +1090,16 @@ def var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False):
|
||||
centered = subtract(a, mean(a, axis, dtype=dtype, keepdims=True))
|
||||
if iscomplexobj(centered):
|
||||
centered = lax.abs(centered)
|
||||
return mean(lax.mul(centered, centered), axis, dtype=dtype, keepdims=keepdims)
|
||||
|
||||
if axis is None:
|
||||
normalizer = size(a)
|
||||
else:
|
||||
normalizer = onp.prod(onp.take(shape(a), axis))
|
||||
normalizer = normalizer - ddof
|
||||
|
||||
return lax.div(
|
||||
sum(lax.mul(centered, centered), axis, dtype=dtype, keepdims=keepdims),
|
||||
lax.convert_element_type(normalizer, dtype))
|
||||
|
||||
|
||||
@_wraps(onp.std)
|
||||
@ -2720,6 +2727,13 @@ def percentile(a, q, axis=None, out=None, overwrite_input=False,
|
||||
return quantile(a, q, axis=axis, out=out, overwrite_input=overwrite_input,
|
||||
interpolation=interpolation, keepdims=keepdims)
|
||||
|
||||
|
||||
@_wraps(onp.median)
|
||||
def median(a, axis=None, out=None, overwrite_input=False, keepdims=False):
|
||||
q = 0.5
|
||||
return quantile(a, q, axis=axis, out=out, overwrite_input=overwrite_input,
|
||||
keepdims=keepdims)
|
||||
|
||||
### track unimplemented functions
|
||||
|
||||
def _not_implemented(fun):
|
||||
|
@ -1521,6 +1521,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
for (op, q_rng) in (
|
||||
("percentile", jtu.rand_uniform(low=0., high=100.)),
|
||||
("quantile", jtu.rand_uniform(low=0., high=1.)),
|
||||
("median", jtu.rand_uniform(low=0., high=1.)),
|
||||
)
|
||||
for a_dtype in float_dtypes
|
||||
for a_shape, axis in (
|
||||
@ -1536,7 +1537,10 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
axis, keepdims):
|
||||
if op == "quantile" and numpy_version < (1, 15):
|
||||
raise SkipTest("Numpy < 1.15 does not have np.quantile")
|
||||
args_maker = lambda: [a_rng(a_shape, a_dtype), q_rng(q_shape, q_dtype)]
|
||||
if op == "median":
|
||||
args_maker = lambda: [a_rng(a_shape, a_dtype)]
|
||||
else:
|
||||
args_maker = lambda: [a_rng(a_shape, a_dtype), q_rng(q_shape, q_dtype)]
|
||||
onp_fun = partial(getattr(onp, op), axis=axis, keepdims=keepdims)
|
||||
lnp_fun = partial(getattr(lnp, op), axis=axis, keepdims=keepdims)
|
||||
# TODO(phawkins): we currently set dtype=False because we aren't as
|
||||
@ -1739,12 +1743,32 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
def testIssue956(self):
|
||||
self.assertRaises(TypeError, lambda: lnp.ndarray((1, 1)))
|
||||
|
||||
@parameterized.named_parameters(
|
||||
jtu.cases_from_list(
|
||||
{"testcase_name": "_shape={}_dtype={}_axis={}_ddof={}_keepdims={}"
|
||||
.format(shape, dtype, axis, ddof, keepdims),
|
||||
"shape": shape, "dtype": dtype, "out_dtype": out_dtype, "axis": axis,
|
||||
"ddof": ddof, "keepdims": keepdims, "rng": rng}
|
||||
for shape in [(5,), (10, 5)]
|
||||
for dtype in all_dtypes
|
||||
for out_dtype in number_dtypes
|
||||
for axis in [None, 0, -1]
|
||||
for ddof in [0, 1, 2]
|
||||
for keepdims in [False, True]
|
||||
for rng in [jtu.rand_default()]))
|
||||
def testVar(self, shape, dtype, out_dtype, axis, ddof, keepdims, rng):
|
||||
args_maker = self._GetArgsMaker(rng, [shape], [dtype])
|
||||
onp_fun = partial(onp.var, dtype=out_dtype, axis=axis, ddof=ddof, keepdims=keepdims)
|
||||
lnp_fun = partial(lnp.var, dtype=out_dtype, axis=axis, ddof=ddof, keepdims=keepdims)
|
||||
self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
jtu.cases_from_list(
|
||||
{"testcase_name": "_shape={}_dtype={}_rowvar={}_ddof={}_bias={}".format(
|
||||
shape, dtype, rowvar, ddof, bias),
|
||||
"shape":shape, "dtype":dtype, "rowvar":rowvar, "ddof":ddof,
|
||||
"bias":bias, "rng": rng}
|
||||
"shape": shape, "dtype": dtype, "rowvar": rowvar, "ddof": ddof,
|
||||
"bias": bias, "rng": rng}
|
||||
for shape in [(5,), (10, 5), (3, 10)]
|
||||
for dtype in all_dtypes
|
||||
for rowvar in [True, False]
|
||||
@ -1766,8 +1790,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
jtu.cases_from_list(
|
||||
{"testcase_name": "_shape={}_dtype={}_rowvar={}_ddof={}_bias={}".format(
|
||||
shape, dtype, rowvar, ddof, bias),
|
||||
"shape":shape, "dtype":dtype, "rowvar":rowvar, "ddof":ddof,
|
||||
"bias":bias, "rng": rng}
|
||||
"shape": shape, "dtype": dtype, "rowvar": rowvar, "ddof": ddof,
|
||||
"bias": bias, "rng": rng}
|
||||
for shape in [(5,), (10, 5), (3, 10)]
|
||||
for dtype in number_dtypes
|
||||
for rowvar in [True, False]
|
||||
|
Loading…
x
Reference in New Issue
Block a user