Change the default out-of-bounds behavior for jax.ops.segment_... to FILL_OR_DROP.

This matches the documented behavior.

Fixes https://github.com/google/jax/issues/8634

PiperOrigin-RevId: 411635687
This commit is contained in:
Peter Hawkins 2021-11-22 13:32:25 -08:00 committed by jax authors
parent 5415306257
commit 4679f455f9
3 changed files with 53 additions and 16 deletions

View File

@ -13,6 +13,11 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
## jax 0.2.26 (Unreleased)
* [GitHub
commits](https://github.com/google/jax/compare/jax-v0.2.25...main).
* Bug fixes:
* Out-of-bounds indices to `jax.ops.segment_sum` will now be handled with
`FILL_OR_DROP` semantics, as documented. This primarily affects the
reverse-mode derivative, where gradients corresponding to out-of-bounds
indices will now be returned as 0. (#8634).
## jaxlib 0.1.74 (Nov 17, 2021)
* Enabled peer-to-peer copies between GPUs. Previously, GPU copies were bounced via

View File

@ -412,8 +412,10 @@ def _segment_update(name: str,
indices_are_sorted: bool = False,
unique_indices: bool = False,
bucket_size: Optional[int] = None,
reducer: Optional[Callable] = None) -> Array:
reducer: Optional[Callable] = None,
mode: Optional[lax.GatherScatterMode] = None) -> Array:
jnp._check_arraylike(name, data, segment_ids)
mode = lax.GatherScatterMode.FILL_OR_DROP if mode is None else mode
data = jnp.asarray(data)
segment_ids = jnp.asarray(segment_ids)
dtype = data.dtype
@ -430,7 +432,7 @@ def _segment_update(name: str,
if num_buckets == 1:
return _scatter_update(
out, segment_ids, data, scatter_op, indices_are_sorted,
unique_indices, normalize_indices=False)
unique_indices, normalize_indices=False, mode=mode)
# Bucketize indices and perform segment_update on each bucket to improve
# numerical stability for operations like product and sum.
@ -450,7 +452,8 @@ def segment_sum(data: Array,
num_segments: Optional[int] = None,
indices_are_sorted: bool = False,
unique_indices: bool = False,
bucket_size: Optional[int] = None) -> Array:
bucket_size: Optional[int] = None,
mode: Optional[lax.GatherScatterMode] = None) -> Array:
"""Computes the sum within segments of an array.
Similar to TensorFlow's `segment_sum
@ -460,8 +463,7 @@ def segment_sum(data: Array,
data: an array with the values to be summed.
segment_ids: an array with integer dtype that indicates the segments of
`data` (along its leading axis) to be summed. Values can be repeated and
need not be sorted. Values outside of the range [0, num_segments) are
dropped and do not contribute to the sum.
need not be sorted.
num_segments: optional, an int with nonnegative value indicating the number
of segments. The default is set to be the minimum number of segments that
would support all indices in ``segment_ids``, calculated as
@ -473,6 +475,9 @@ def segment_sum(data: Array,
bucket_size: size of bucket to group indices into. ``segment_sum`` is
performed on each bucket separately to improve numerical stability of
addition. Default ``None`` means no bucketing.
mode: a :class:`jax.lax.GatherScatterMode` value describing how
out-of-bounds indices should be handled. By default, values outside of the
range [0, num_segments) are dropped and do not contribute to the sum.
Returns:
An array with shape :code:`(num_segments,) + data.shape[1:]` representing the
@ -492,8 +497,9 @@ def segment_sum(data: Array,
>>> jit(segment_sum, static_argnums=2)(data, segment_ids, 3)
DeviceArray([1, 5, 4], dtype=int32)
"""
return _segment_update("segment_sum", data, segment_ids, lax.scatter_add, num_segments,
indices_are_sorted, unique_indices, bucket_size, jnp.sum)
return _segment_update(
"segment_sum", data, segment_ids, lax.scatter_add, num_segments,
indices_are_sorted, unique_indices, bucket_size, jnp.sum, mode=mode)
def segment_prod(data: Array,
@ -501,7 +507,8 @@ def segment_prod(data: Array,
num_segments: Optional[int] = None,
indices_are_sorted: bool = False,
unique_indices: bool = False,
bucket_size: Optional[int] = None) -> Array:
bucket_size: Optional[int] = None,
mode: Optional[lax.GatherScatterMode] = None) -> Array:
"""Computes the product within segments of an array.
Similar to TensorFlow's `segment_prod
@ -524,6 +531,9 @@ def segment_prod(data: Array,
bucket_size: size of bucket to group indices into. ``segment_prod`` is
performed on each bucket separately to improve numerical stability of
multiplication. Default ``None`` means no bucketing.
mode: a :class:`jax.lax.GatherScatterMode` value describing how
out-of-bounds indices should be handled. By default, values outside of the
range [0, num_segments) are dropped and do not contribute to the product.
Returns:
An array with shape :code:`(num_segments,) + data.shape[1:]` representing the
@ -543,8 +553,9 @@ def segment_prod(data: Array,
>>> jit(segment_prod, static_argnums=2)(data, segment_ids, 3)
DeviceArray([ 0, 6, 20], dtype=int32)
"""
return _segment_update("segment_prod", data, segment_ids, lax.scatter_mul, num_segments,
indices_are_sorted, unique_indices, bucket_size, jnp.prod)
return _segment_update(
"segment_prod", data, segment_ids, lax.scatter_mul, num_segments,
indices_are_sorted, unique_indices, bucket_size, jnp.prod, mode=mode)
def segment_max(data: Array,
@ -552,7 +563,8 @@ def segment_max(data: Array,
num_segments: Optional[int] = None,
indices_are_sorted: bool = False,
unique_indices: bool = False,
bucket_size: Optional[int] = None) -> Array:
bucket_size: Optional[int] = None,
mode: Optional[lax.GatherScatterMode] = None) -> Array:
"""Computes the maximum within segments of an array.
Similar to TensorFlow's `segment_max
@ -574,6 +586,9 @@ def segment_max(data: Array,
unique_indices: whether `segment_ids` is known to be free of duplicates.
bucket_size: size of bucket to group indices into. ``segment_max`` is
performed on each bucket separately. Default ``None`` means no bucketing.
mode: a :class:`jax.lax.GatherScatterMode` value describing how
out-of-bounds indices should be handled. By default, values outside of the
range [0, num_segments) are dropped and do not contribute to the maximum.
Returns:
An array with shape :code:`(num_segments,) + data.shape[1:]` representing the
@ -593,8 +608,9 @@ def segment_max(data: Array,
>>> jit(segment_max, static_argnums=2)(data, segment_ids, 3)
DeviceArray([1, 3, 5], dtype=int32)
"""
return _segment_update("segment_max", data, segment_ids, lax.scatter_max, num_segments,
indices_are_sorted, unique_indices, bucket_size, jnp.max)
return _segment_update(
"segment_max", data, segment_ids, lax.scatter_max, num_segments,
indices_are_sorted, unique_indices, bucket_size, jnp.max, mode=mode)
def segment_min(data: Array,
@ -602,7 +618,8 @@ def segment_min(data: Array,
num_segments: Optional[int] = None,
indices_are_sorted: bool = False,
unique_indices: bool = False,
bucket_size: Optional[int] = None) -> Array:
bucket_size: Optional[int] = None,
mode: Optional[lax.GatherScatterMode] = None) -> Array:
"""Computes the minimum within segments of an array.
Similar to TensorFlow's `segment_min
@ -624,6 +641,9 @@ def segment_min(data: Array,
unique_indices: whether `segment_ids` is known to be free of duplicates.
bucket_size: size of bucket to group indices into. ``segment_min`` is
performed on each bucket separately. Default ``None`` means no bucketing.
mode: a :class:`jax.lax.GatherScatterMode` value describing how
out-of-bounds indices should be handled. By default, values outside of the
range [0, num_segments) are dropped and do not contribute to the minimum.
Returns:
An array with shape :code:`(num_segments,) + data.shape[1:]` representing the
@ -643,5 +663,6 @@ def segment_min(data: Array,
>>> jit(segment_min, static_argnums=2)(data, segment_ids, 3)
DeviceArray([0, 2, 4], dtype=int32)
"""
return _segment_update("segment_min", data, segment_ids, lax.scatter_min, num_segments,
indices_are_sorted, unique_indices, bucket_size, jnp.min)
return _segment_update(
"segment_min", data, segment_ids, lax.scatter_min, num_segments,
indices_are_sorted, unique_indices, bucket_size, jnp.min, mode=mode)

View File

@ -1155,6 +1155,17 @@ class IndexedUpdateTest(jtu.JaxTestCase):
expected = jnp.array([0, 0, 0, 13, 2, 7])
self.assertAllClose(ans, expected, check_dtypes=False)
def testSegmentSumOutOfBounds(self):
def fn(data, segment_ids):
return jax.ops.segment_sum(data, segment_ids, num_segments).sum()
data = np.array([0, 0], dtype=np.float32)
num_segments = 2
segment_ids = np.array([2, 3])
val, grad = jax.value_and_grad(fn)(data, segment_ids)
self.assertAllClose(val, np.array(0., np.float32))
self.assertAllClose(grad, np.array([0., 0.], np.float32))
@parameterized.named_parameters(itertools.chain.from_iterable(
jtu.cases_from_list({