Fixing weird behavior in segment_sum when num_segments is None (#4034)

Co-authored-by: alvarosg <alvarosg@google.com>
This commit is contained in:
Alvaro 2020-09-11 18:51:42 +01:00 committed by GitHub
parent 3b7329c92e
commit ca1d8f4109
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 24 additions and 3 deletions

View File

@ -316,8 +316,10 @@ def segment_sum(data, segment_ids, num_segments=None,
need not be sorted. Values outside of the range [0, num_segments) are
wrapped into that range by applying jnp.mod.
num_segments: optional, an int with positive value indicating the number of
segments. The default is ``max(segment_ids % data.shape[0]) + 1`` but
since `num_segments` determines the size of the output, a static value
segments. The default is set to be the minimum number of segments that
would support all positive and negative indices in `segment_ids`
calculated as ``max(max(segment_ids) + 1, max(-segment_ids))``.
Since `num_segments` determines the size of the output, a static value
must be provided to use `segment_sum` in a `jit`-compiled function.
indices_are_sorted: whether `segment_ids` is known to be sorted
unique_indices: whether `segment_ids` is known to be free of duplicates
@ -327,7 +329,7 @@ def segment_sum(data, segment_ids, num_segments=None,
segment sums.
"""
if num_segments is None:
num_segments = jnp.max(jnp.mod(segment_ids, data.shape[0])) + 1
num_segments = max(jnp.max(segment_ids) + 1, jnp.max(-segment_ids))
num_segments = int(num_segments)
out = jnp.zeros((num_segments,) + data.shape[1:], dtype=data.dtype)

View File

@ -1058,11 +1058,30 @@ class IndexedUpdateTest(jtu.JaxTestCase):
expected = np.array([13, 2, 7, 4])
self.assertAllClose(ans, expected, check_dtypes=False)
# test with explicit num_segments larger than the higher index.
ans = ops.segment_sum(data, segment_ids, num_segments=5)
expected = np.array([13, 2, 7, 4, 0])
self.assertAllClose(ans, expected, check_dtypes=False)
# test without explicit num_segments
ans = ops.segment_sum(data, segment_ids)
expected = np.array([13, 2, 7, 4])
self.assertAllClose(ans, expected, check_dtypes=False)
# test with negative segment ids and segment ids larger than num_segments,
# that will be wrapped with the `mod`.
segment_ids = np.array([0, 4, 8, 1, 2, -6, -1, 3])
ans = ops.segment_sum(data, segment_ids, num_segments=4)
expected = np.array([13, 2, 7, 4])
self.assertAllClose(ans, expected, check_dtypes=False)
# test with negative segment ids and without without explicit num_segments
# such as num_segments is defined by the smaller index.
segment_ids = np.array([3, 3, 3, 4, 5, 5, -7, -6])
ans = ops.segment_sum(data, segment_ids)
expected = np.array([1, 3, 0, 13, 2, 7, 0])
self.assertAllClose(ans, expected, check_dtypes=False)
def testIndexDtypeError(self):
# https://github.com/google/jax/issues/2795
jnp.array(1) # get rid of startup warning