mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Fix rank promotion error in jnp.cov
This commit is contained in:
parent
651ddb5aa2
commit
3902515ef2
@ -5538,7 +5538,7 @@ def cov(m, y=None, rowvar=True, bias=False, ddof=None, fweights=None,
|
||||
f = w_sum - ddof * sum(w * aweights) / w_sum
|
||||
|
||||
X = X - avg[:, None]
|
||||
X_T = X.T if w is None else (X * w).T
|
||||
X_T = X.T if w is None else (X * lax.broadcast_to_rank(w, X.ndim)).T
|
||||
return true_divide(dot(X, X_T.conj()), f).squeeze()
|
||||
|
||||
|
||||
|
@ -4679,6 +4679,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
for ddof in [None, 2, 3]
|
||||
for fweights in [True, False]
|
||||
for aweights in [True, False]))
|
||||
@jax.numpy_rank_promotion('raise')
|
||||
def testCov(self, shape, dtype, y_shape, y_dtype, rowvar, ddof, bias, fweights, aweights):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
wrng = jtu.rand_positive(self.rng())
|
||||
|
Loading…
x
Reference in New Issue
Block a user