mirror of
https://github.com/ROCm/jax.git
synced 2025-04-27 06:26:08 +00:00
Add test coverage for jnp.cov aweights & fweights
This commit is contained in:
parent
fa1133885b
commit
a51a4d91b3
@ -4256,17 +4256,15 @@ def compress(condition, a, axis=None, out=None):
|
||||
@_wraps(np.cov)
|
||||
def cov(m, y=None, rowvar=True, bias=False, ddof=None, fweights=None,
|
||||
aweights=None):
|
||||
_check_arraylike("cov", m)
|
||||
msg = ("jax.numpy.cov not implemented for nontrivial {}. "
|
||||
"Open a feature request at https://github.com/google/jax/issues !")
|
||||
if y is not None: raise NotImplementedError(msg.format('y'))
|
||||
# These next two are actually implemented, just not tested.
|
||||
if fweights is not None: raise NotImplementedError(msg.format('fweights'))
|
||||
if aweights is not None: raise NotImplementedError(msg.format('aweights'))
|
||||
if y is not None: raise NotImplementedError(
|
||||
"jax.numpy.cov not implemented for nontrivial y. "
|
||||
"Open a feature request at https://github.com/google/jax/issues !")
|
||||
|
||||
m, = _promote_args_inexact("cov", m)
|
||||
|
||||
if m.ndim > 2:
|
||||
raise ValueError("m has more than 2 dimensions") # same as numpy error
|
||||
X = array(m, ndmin=2, dtype=dtypes.canonicalize_dtype(result_type(m, float_)))
|
||||
X = atleast_2d(m)
|
||||
if not rowvar and X.shape[0] != 1:
|
||||
X = X.T
|
||||
if X.shape[0] == 0:
|
||||
@ -4276,16 +4274,23 @@ def cov(m, y=None, rowvar=True, bias=False, ddof=None, fweights=None,
|
||||
|
||||
w = None
|
||||
if fweights is not None:
|
||||
if np.ndim(fweights) > 1:
|
||||
_check_arraylike("cov", fweights)
|
||||
if ndim(fweights) > 1:
|
||||
raise RuntimeError("cannot handle multidimensional fweights")
|
||||
if np.shape(fweights)[0] != X.shape[1]:
|
||||
if shape(fweights)[0] != X.shape[1]:
|
||||
raise RuntimeError("incompatible numbers of samples and fweights")
|
||||
w = asarray(fweights)
|
||||
if not issubdtype(_dtype(fweights), integer):
|
||||
raise TypeError("fweights must be integer.")
|
||||
# Ensure positive fweights; note that numpy raises an error on negative fweights.
|
||||
w = asarray(abs(fweights))
|
||||
if aweights is not None:
|
||||
if np.ndim(aweights) > 1:
|
||||
_check_arraylike("cov", aweights)
|
||||
if ndim(aweights) > 1:
|
||||
raise RuntimeError("cannot handle multidimensional aweights")
|
||||
if np.shape(aweights)[0] != X.shape[1]:
|
||||
if shape(aweights)[0] != X.shape[1]:
|
||||
raise RuntimeError("incompatible numbers of samples and aweights")
|
||||
# Ensure positive aweights: note that numpy raises an error for negative aweights.
|
||||
aweights = abs(aweights)
|
||||
w = aweights if w is None else w * aweights
|
||||
|
||||
avg, w_sum = average(X, axis=1, weights=w, returned=True)
|
||||
|
@ -3748,23 +3748,31 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
|
||||
@parameterized.named_parameters(
|
||||
jtu.cases_from_list(
|
||||
{"testcase_name": "_shape={}_dtype={}_rowvar={}_ddof={}_bias={}".format(
|
||||
shape, dtype, rowvar, ddof, bias),
|
||||
{"testcase_name":
|
||||
"_shape={}_dtype={}_rowvar={}_ddof={}_bias={}_fweights={}_aweights={}".format(
|
||||
shape, dtype, rowvar, ddof, bias, fweights, aweights),
|
||||
"shape": shape, "dtype": dtype, "rowvar": rowvar, "ddof": ddof,
|
||||
"bias": bias, "rng_factory": rng_factory}
|
||||
"bias": bias, "fweights": fweights, "aweights": aweights}
|
||||
for shape in [(5,), (10, 5), (5, 10)]
|
||||
for dtype in all_dtypes
|
||||
for rowvar in [True, False]
|
||||
for bias in [True, False]
|
||||
for ddof in [None, 2, 3]
|
||||
for rng_factory in [jtu.rand_default]))
|
||||
def testCov(self, shape, dtype, rowvar, ddof, bias, rng_factory):
|
||||
rng = rng_factory(self.rng())
|
||||
args_maker = self._GetArgsMaker(rng, [shape], [dtype])
|
||||
np_fun = partial(np.cov, rowvar=rowvar, ddof=ddof, bias=bias)
|
||||
jnp_fun = partial(jnp.cov, rowvar=rowvar, ddof=ddof, bias=bias)
|
||||
tol = {np.float32: 1e-5, np.complex64: 1e-5,
|
||||
np.float64: 1e-13, np.complex128: 1e-13}
|
||||
for fweights in [True, False]
|
||||
for aweights in [True, False]))
|
||||
def testCov(self, shape, dtype, rowvar, ddof, bias, fweights, aweights):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
wrng = jtu.rand_positive(self.rng())
|
||||
wdtype = np.real(dtype(0)).dtype
|
||||
wshape = shape[-1:] if rowvar or shape[0] == 1 else shape[:1]
|
||||
args_maker = lambda: [rng(shape, dtype),
|
||||
wrng(wshape, int) if fweights else None,
|
||||
wrng(wshape, wdtype) if aweights else None]
|
||||
kwargs = dict(rowvar=rowvar, ddof=ddof, bias=bias)
|
||||
np_fun = lambda m, f, a: np.cov(m, fweights=f, aweights=a, **kwargs)
|
||||
jnp_fun = lambda m, f, a: jnp.cov(m, fweights=f, aweights=a, **kwargs)
|
||||
tol = {jnp.bfloat16: 5E-2, np.float16: 1E-2, np.float32: 1e-5,
|
||||
np.float64: 1e-13, np.complex64: 1e-5, np.complex128: 1e-13}
|
||||
tol = 7e-2 if jtu.device_under_test() == "tpu" else tol
|
||||
tol = jtu.join_tolerance(tol, jtu.tolerance(dtype))
|
||||
self._CheckAgainstNumpy(
|
||||
|
Loading…
x
Reference in New Issue
Block a user