Add test coverage for jnp.cov aweights & fweights

This commit is contained in:
Jake VanderPlas 2020-09-28 15:34:57 -07:00
parent fa1133885b
commit a51a4d91b3
2 changed files with 37 additions and 24 deletions

View File

@ -4256,17 +4256,15 @@ def compress(condition, a, axis=None, out=None):
@_wraps(np.cov)
def cov(m, y=None, rowvar=True, bias=False, ddof=None, fweights=None,
aweights=None):
_check_arraylike("cov", m)
msg = ("jax.numpy.cov not implemented for nontrivial {}. "
"Open a feature request at https://github.com/google/jax/issues !")
if y is not None: raise NotImplementedError(msg.format('y'))
# These next two are actually implemented, just not tested.
if fweights is not None: raise NotImplementedError(msg.format('fweights'))
if aweights is not None: raise NotImplementedError(msg.format('aweights'))
if y is not None: raise NotImplementedError(
"jax.numpy.cov not implemented for nontrivial y. "
"Open a feature request at https://github.com/google/jax/issues !")
m, = _promote_args_inexact("cov", m)
if m.ndim > 2:
raise ValueError("m has more than 2 dimensions") # same as numpy error
X = array(m, ndmin=2, dtype=dtypes.canonicalize_dtype(result_type(m, float_)))
X = atleast_2d(m)
if not rowvar and X.shape[0] != 1:
X = X.T
if X.shape[0] == 0:
@ -4276,16 +4274,23 @@ def cov(m, y=None, rowvar=True, bias=False, ddof=None, fweights=None,
w = None
if fweights is not None:
if np.ndim(fweights) > 1:
_check_arraylike("cov", fweights)
if ndim(fweights) > 1:
raise RuntimeError("cannot handle multidimensional fweights")
if np.shape(fweights)[0] != X.shape[1]:
if shape(fweights)[0] != X.shape[1]:
raise RuntimeError("incompatible numbers of samples and fweights")
w = asarray(fweights)
if not issubdtype(_dtype(fweights), integer):
raise TypeError("fweights must be integer.")
# Ensure positive fweights; note that numpy raises an error on negative fweights.
w = asarray(abs(fweights))
if aweights is not None:
if np.ndim(aweights) > 1:
_check_arraylike("cov", aweights)
if ndim(aweights) > 1:
raise RuntimeError("cannot handle multidimensional aweights")
if np.shape(aweights)[0] != X.shape[1]:
if shape(aweights)[0] != X.shape[1]:
raise RuntimeError("incompatible numbers of samples and aweights")
# Ensure positive aweights: note that numpy raises an error for negative aweights.
aweights = abs(aweights)
w = aweights if w is None else w * aweights
avg, w_sum = average(X, axis=1, weights=w, returned=True)

View File

@ -3748,23 +3748,31 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
@parameterized.named_parameters(
jtu.cases_from_list(
{"testcase_name": "_shape={}_dtype={}_rowvar={}_ddof={}_bias={}".format(
shape, dtype, rowvar, ddof, bias),
{"testcase_name":
"_shape={}_dtype={}_rowvar={}_ddof={}_bias={}_fweights={}_aweights={}".format(
shape, dtype, rowvar, ddof, bias, fweights, aweights),
"shape": shape, "dtype": dtype, "rowvar": rowvar, "ddof": ddof,
"bias": bias, "rng_factory": rng_factory}
"bias": bias, "fweights": fweights, "aweights": aweights}
for shape in [(5,), (10, 5), (5, 10)]
for dtype in all_dtypes
for rowvar in [True, False]
for bias in [True, False]
for ddof in [None, 2, 3]
for rng_factory in [jtu.rand_default]))
def testCov(self, shape, dtype, rowvar, ddof, bias, rng_factory):
rng = rng_factory(self.rng())
args_maker = self._GetArgsMaker(rng, [shape], [dtype])
np_fun = partial(np.cov, rowvar=rowvar, ddof=ddof, bias=bias)
jnp_fun = partial(jnp.cov, rowvar=rowvar, ddof=ddof, bias=bias)
tol = {np.float32: 1e-5, np.complex64: 1e-5,
np.float64: 1e-13, np.complex128: 1e-13}
for fweights in [True, False]
for aweights in [True, False]))
def testCov(self, shape, dtype, rowvar, ddof, bias, fweights, aweights):
rng = jtu.rand_default(self.rng())
wrng = jtu.rand_positive(self.rng())
wdtype = np.real(dtype(0)).dtype
wshape = shape[-1:] if rowvar or shape[0] == 1 else shape[:1]
args_maker = lambda: [rng(shape, dtype),
wrng(wshape, int) if fweights else None,
wrng(wshape, wdtype) if aweights else None]
kwargs = dict(rowvar=rowvar, ddof=ddof, bias=bias)
np_fun = lambda m, f, a: np.cov(m, fweights=f, aweights=a, **kwargs)
jnp_fun = lambda m, f, a: jnp.cov(m, fweights=f, aweights=a, **kwargs)
tol = {jnp.bfloat16: 5E-2, np.float16: 1E-2, np.float32: 1e-5,
np.float64: 1e-13, np.complex64: 1e-5, np.complex128: 1e-13}
tol = 7e-2 if jtu.device_under_test() == "tpu" else tol
tol = jtu.join_tolerance(tol, jtu.tolerance(dtype))
self._CheckAgainstNumpy(