mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 12:56:07 +00:00
Fixing weird behavior in segment_sum when num_segments is None (#4034)
Co-authored-by: alvarosg <alvarosg@google.com>
This commit is contained in:
parent
3b7329c92e
commit
ca1d8f4109
@ -316,8 +316,10 @@ def segment_sum(data, segment_ids, num_segments=None,
|
||||
need not be sorted. Values outside of the range [0, num_segments) are
|
||||
wrapped into that range by applying jnp.mod.
|
||||
num_segments: optional, an int with positive value indicating the number of
|
||||
segments. The default is ``max(segment_ids % data.shape[0]) + 1`` but
|
||||
since `num_segments` determines the size of the output, a static value
|
||||
segments. The default is set to be the minimum number of segments that
|
||||
would support all positive and negative indices in `segment_ids`
|
||||
calculated as ``max(max(segment_ids) + 1, max(-segment_ids))``.
|
||||
Since `num_segments` determines the size of the output, a static value
|
||||
must be provided to use `segment_sum` in a `jit`-compiled function.
|
||||
indices_are_sorted: whether `segment_ids` is known to be sorted
|
||||
unique_indices: whether `segment_ids` is known to be free of duplicates
|
||||
@ -327,7 +329,7 @@ def segment_sum(data, segment_ids, num_segments=None,
|
||||
segment sums.
|
||||
"""
|
||||
if num_segments is None:
|
||||
num_segments = jnp.max(jnp.mod(segment_ids, data.shape[0])) + 1
|
||||
num_segments = max(jnp.max(segment_ids) + 1, jnp.max(-segment_ids))
|
||||
num_segments = int(num_segments)
|
||||
|
||||
out = jnp.zeros((num_segments,) + data.shape[1:], dtype=data.dtype)
|
||||
|
@ -1058,11 +1058,30 @@ class IndexedUpdateTest(jtu.JaxTestCase):
|
||||
expected = np.array([13, 2, 7, 4])
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
# test with explicit num_segments larger than the higher index.
|
||||
ans = ops.segment_sum(data, segment_ids, num_segments=5)
|
||||
expected = np.array([13, 2, 7, 4, 0])
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
# test without explicit num_segments
|
||||
ans = ops.segment_sum(data, segment_ids)
|
||||
expected = np.array([13, 2, 7, 4])
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
# test with negative segment ids and segment ids larger than num_segments,
|
||||
# that will be wrapped with the `mod`.
|
||||
segment_ids = np.array([0, 4, 8, 1, 2, -6, -1, 3])
|
||||
ans = ops.segment_sum(data, segment_ids, num_segments=4)
|
||||
expected = np.array([13, 2, 7, 4])
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
# test with negative segment ids and without without explicit num_segments
|
||||
# such as num_segments is defined by the smaller index.
|
||||
segment_ids = np.array([3, 3, 3, 4, 5, 5, -7, -6])
|
||||
ans = ops.segment_sum(data, segment_ids)
|
||||
expected = np.array([1, 3, 0, 13, 2, 7, 0])
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
def testIndexDtypeError(self):
|
||||
# https://github.com/google/jax/issues/2795
|
||||
jnp.array(1) # get rid of startup warning
|
||||
|
Loading…
x
Reference in New Issue
Block a user