From 562e9e8dff57129d8ba10298aa2dbb4ba0d6b4e0 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Tue, 24 Sep 2024 13:13:29 +0000 Subject: [PATCH] Fix an incorrect output for jnp.cumsum. If dtype=bool but a non-bool input is passed, we should test for non-equality with zero rather than performing a cast to integer. --- CHANGELOG.md | 3 +++ jax/_src/numpy/reductions.py | 6 +++++- tests/lax_numpy_reducers_test.py | 4 ++++ 3 files changed, 12 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 43db6e197..5bdcd1c20 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -35,6 +35,9 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. * {class}`jax.ShapeDtypeStruct` no longer accepts the `named_shape` argument. The argument was only used by `xmap` which was removed in 0.4.31. +* Bug fixes + * Fixed a bug where {func}`jax.numpy.cumsum` would produce incorrect outputs + if a non-boolean input was provided and `dtype=bool` was specified. ## jax 0.4.33 (September 16, 2024) diff --git a/jax/_src/numpy/reductions.py b/jax/_src/numpy/reductions.py index 043c976ef..3436b00cf 100644 --- a/jax/_src/numpy/reductions.py +++ b/jax/_src/numpy/reductions.py @@ -1810,16 +1810,20 @@ def _cumulative_reduction( if fill_nan: a = _where(lax_internal._isnan(a), _lax_const(a, fill_value), a) + a_type: DType = dtypes.dtype(a) result_type: DTypeLike = dtypes.dtype(dtype or a) if dtype is None and promote_integers or dtypes.issubdtype(result_type, np.bool_): result_type = _promote_integer_dtype(result_type) result_type = dtypes.canonicalize_dtype(result_type) + if a_type != np.bool_ and dtype == np.bool_: + a = lax_internal.asarray(a).astype(np.bool_) + a = lax.convert_element_type(a, result_type) result = reduction(a, axis) # We downcast to boolean because we accumulate in integer types - if dtypes.issubdtype(dtype, np.bool_): + if dtype is not None and dtypes.issubdtype(dtype, np.bool_): result = lax.convert_element_type(result, np.bool_) return result diff --git a/tests/lax_numpy_reducers_test.py b/tests/lax_numpy_reducers_test.py index 33830c541..623c11a51 100644 --- a/tests/lax_numpy_reducers_test.py +++ b/tests/lax_numpy_reducers_test.py @@ -861,6 +861,10 @@ class JaxNumpyReducerTests(jtu.JaxTestCase): with self.assertRaisesRegex(ValueError, msg): jnp.cumulative_sum(x, include_initial=include_initial) + def testCumulativeSumBool(self): + out = jnp.cumulative_sum(jnp.array([[0.1], [0.1], [0.0]]), axis=-1, + dtype=jnp.bool_) + np.testing.assert_array_equal(np.array([[True], [True], [False]]), out) if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader())