diff --git a/jax/ops/scatter.py b/jax/ops/scatter.py index 4131a45f5..c9dc196a6 100644 --- a/jax/ops/scatter.py +++ b/jax/ops/scatter.py @@ -316,8 +316,10 @@ def segment_sum(data, segment_ids, num_segments=None, need not be sorted. Values outside of the range [0, num_segments) are wrapped into that range by applying jnp.mod. num_segments: optional, an int with positive value indicating the number of - segments. The default is ``max(segment_ids % data.shape[0]) + 1`` but - since `num_segments` determines the size of the output, a static value + segments. The default is set to be the minimum number of segments that + would support all positive and negative indices in `segment_ids` + calculated as ``max(max(segment_ids) + 1, max(-segment_ids))``. + Since `num_segments` determines the size of the output, a static value must be provided to use `segment_sum` in a `jit`-compiled function. indices_are_sorted: whether `segment_ids` is known to be sorted unique_indices: whether `segment_ids` is known to be free of duplicates @@ -327,7 +329,7 @@ def segment_sum(data, segment_ids, num_segments=None, segment sums. """ if num_segments is None: - num_segments = jnp.max(jnp.mod(segment_ids, data.shape[0])) + 1 + num_segments = max(jnp.max(segment_ids) + 1, jnp.max(-segment_ids)) num_segments = int(num_segments) out = jnp.zeros((num_segments,) + data.shape[1:], dtype=data.dtype) diff --git a/tests/lax_numpy_indexing_test.py b/tests/lax_numpy_indexing_test.py index f22d3e70d..3e41f7395 100644 --- a/tests/lax_numpy_indexing_test.py +++ b/tests/lax_numpy_indexing_test.py @@ -1058,11 +1058,30 @@ class IndexedUpdateTest(jtu.JaxTestCase): expected = np.array([13, 2, 7, 4]) self.assertAllClose(ans, expected, check_dtypes=False) + # test with explicit num_segments larger than the higher index. 
+ ans = ops.segment_sum(data, segment_ids, num_segments=5) + expected = np.array([13, 2, 7, 4, 0]) + self.assertAllClose(ans, expected, check_dtypes=False) + # test without explicit num_segments ans = ops.segment_sum(data, segment_ids) expected = np.array([13, 2, 7, 4]) self.assertAllClose(ans, expected, check_dtypes=False) + # test with negative segment ids and segment ids larger than num_segments, + # that will be wrapped with the `mod`. + segment_ids = np.array([0, 4, 8, 1, 2, -6, -1, 3]) + ans = ops.segment_sum(data, segment_ids, num_segments=4) + expected = np.array([13, 2, 7, 4]) + self.assertAllClose(ans, expected, check_dtypes=False) + + # test with negative segment ids and without explicit num_segments, + # so that num_segments is determined by the smallest (most negative) index. + segment_ids = np.array([3, 3, 3, 4, 5, 5, -7, -6]) + ans = ops.segment_sum(data, segment_ids) + expected = np.array([1, 3, 0, 13, 2, 7, 0]) + self.assertAllClose(ans, expected, check_dtypes=False) + def testIndexDtypeError(self): # https://github.com/google/jax/issues/2795 jnp.array(1) # get rid of startup warning