Fix an incorrect output for jnp.cumsum.

If dtype=bool but a non-bool input is passed, we should test for
non-equality with zero rather than performing a cast to integer.
This commit is contained in:
Peter Hawkins 2024-09-24 13:13:29 +00:00
parent 80cb821a79
commit 562e9e8dff
3 changed files with 12 additions and 1 deletions

View File

@ -35,6 +35,9 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
* {class}`jax.ShapeDtypeStruct` no longer accepts the `named_shape` argument.
The argument was only used by `xmap` which was removed in 0.4.31.
* Bug fixes
* Fixed a bug where {func}`jax.numpy.cumsum` would produce incorrect outputs
if a non-boolean input was provided and `dtype=bool` was specified.
## jax 0.4.33 (September 16, 2024)

View File

@ -1810,16 +1810,20 @@ def _cumulative_reduction(
if fill_nan:
a = _where(lax_internal._isnan(a), _lax_const(a, fill_value), a)
a_type: DType = dtypes.dtype(a)
result_type: DTypeLike = dtypes.dtype(dtype or a)
if dtype is None and promote_integers or dtypes.issubdtype(result_type, np.bool_):
result_type = _promote_integer_dtype(result_type)
result_type = dtypes.canonicalize_dtype(result_type)
if a_type != np.bool_ and dtype == np.bool_:
a = lax_internal.asarray(a).astype(np.bool_)
a = lax.convert_element_type(a, result_type)
result = reduction(a, axis)
# We downcast to boolean because we accumulate in integer types
if dtypes.issubdtype(dtype, np.bool_):
if dtype is not None and dtypes.issubdtype(dtype, np.bool_):
result = lax.convert_element_type(result, np.bool_)
return result

View File

@ -861,6 +861,10 @@ class JaxNumpyReducerTests(jtu.JaxTestCase):
with self.assertRaisesRegex(ValueError, msg):
jnp.cumulative_sum(x, include_initial=include_initial)
def testCumulativeSumBool(self):
out = jnp.cumulative_sum(jnp.array([[0.1], [0.1], [0.0]]), axis=-1,
dtype=jnp.bool_)
np.testing.assert_array_equal(np.array([[True], [True], [False]]), out)
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())