mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
BUG: fix reduction for scalar input
This commit is contained in:
parent
88f5e26482
commit
c70c3d5063
@ -2065,7 +2065,7 @@ def mean(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
|
||||
else:
|
||||
normalizer = _axis_size(a, axis)
|
||||
else:
|
||||
normalizer = sum(broadcast_to(where, a.shape), axis, dtype=dtype, keepdims=keepdims)
|
||||
normalizer = sum(broadcast_to(where, shape(a)), axis, dtype=dtype, keepdims=keepdims)
|
||||
|
||||
if dtype is None:
|
||||
if issubdtype(_dtype(a), bool_) or issubdtype(_dtype(a), integer):
|
||||
@ -2150,7 +2150,7 @@ def var(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
|
||||
else:
|
||||
normalizer = _axis_size(a, axis)
|
||||
else:
|
||||
normalizer = sum(broadcast_to(where, a.shape), axis, dtype=dtype, keepdims=keepdims)
|
||||
normalizer = sum(broadcast_to(where, shape(a)), axis, dtype=dtype, keepdims=keepdims)
|
||||
normalizer = normalizer - ddof
|
||||
|
||||
result = sum(centered, axis, keepdims=keepdims, where=where)
|
||||
|
Loading…
x
Reference in New Issue
Block a user