Fix rank promotion error in jnp.cov

This commit is contained in:
Jake VanderPlas 2021-07-12 13:24:02 -07:00
parent 651ddb5aa2
commit 3902515ef2
2 changed files with 2 additions and 1 deletions

View File

@ -5538,7 +5538,7 @@ def cov(m, y=None, rowvar=True, bias=False, ddof=None, fweights=None,
f = w_sum - ddof * sum(w * aweights) / w_sum
X = X - avg[:, None]
X_T = X.T if w is None else (X * w).T
X_T = X.T if w is None else (X * lax.broadcast_to_rank(w, X.ndim)).T
return true_divide(dot(X, X_T.conj()), f).squeeze()

View File

@ -4679,6 +4679,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
for ddof in [None, 2, 3]
for fweights in [True, False]
for aweights in [True, False]))
@jax.numpy_rank_promotion('raise')
def testCov(self, shape, dtype, y_shape, y_dtype, rowvar, ddof, bias, fweights, aweights):
rng = jtu.rand_default(self.rng())
wrng = jtu.rand_positive(self.rng())