mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Fix an incorrect output for jnp.cumsum.
If dtype=bool but a non-bool input is passed, we should test for non-equality with zero rather than performing a cast to integer.
This commit is contained in:
parent
80cb821a79
commit
562e9e8dff
@ -35,6 +35,9 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
|
||||
* {class}`jax.ShapeDtypeStruct` no longer accepts the `named_shape` argument.
|
||||
The argument was only used by `xmap` which was removed in 0.4.31.
|
||||
|
||||
* Bug fixes
|
||||
* Fixed a bug where {func}`jax.numpy.cumsum` would produce incorrect outputs
|
||||
if a non-boolean input was provided and `dtype=bool` was specified.
|
||||
|
||||
## jax 0.4.33 (September 16, 2024)
|
||||
|
||||
|
@ -1810,16 +1810,20 @@ def _cumulative_reduction(
|
||||
if fill_nan:
|
||||
a = _where(lax_internal._isnan(a), _lax_const(a, fill_value), a)
|
||||
|
||||
a_type: DType = dtypes.dtype(a)
|
||||
result_type: DTypeLike = dtypes.dtype(dtype or a)
|
||||
if dtype is None and promote_integers or dtypes.issubdtype(result_type, np.bool_):
|
||||
result_type = _promote_integer_dtype(result_type)
|
||||
result_type = dtypes.canonicalize_dtype(result_type)
|
||||
|
||||
if a_type != np.bool_ and dtype == np.bool_:
|
||||
a = lax_internal.asarray(a).astype(np.bool_)
|
||||
|
||||
a = lax.convert_element_type(a, result_type)
|
||||
result = reduction(a, axis)
|
||||
|
||||
# We downcast to boolean because we accumulate in integer types
|
||||
if dtypes.issubdtype(dtype, np.bool_):
|
||||
if dtype is not None and dtypes.issubdtype(dtype, np.bool_):
|
||||
result = lax.convert_element_type(result, np.bool_)
|
||||
return result
|
||||
|
||||
|
@ -861,6 +861,10 @@ class JaxNumpyReducerTests(jtu.JaxTestCase):
|
||||
with self.assertRaisesRegex(ValueError, msg):
|
||||
jnp.cumulative_sum(x, include_initial=include_initial)
|
||||
|
||||
def testCumulativeSumBool(self):
|
||||
out = jnp.cumulative_sum(jnp.array([[0.1], [0.1], [0.0]]), axis=-1,
|
||||
dtype=jnp.bool_)
|
||||
np.testing.assert_array_equal(np.array([[True], [True], [False]]), out)
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user