mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 20:36:05 +00:00
Change the default out-of-bounds behavior for jax.ops.segment_... to FILL_OR_DROP.
This matches the documented behavior. Fixes https://github.com/google/jax/issues/8634 PiperOrigin-RevId: 411635687
This commit is contained in:
parent
5415306257
commit
4679f455f9
@ -13,6 +13,11 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
|
||||
## jax 0.2.26 (Unreleased)
|
||||
* [GitHub
|
||||
commits](https://github.com/google/jax/compare/jax-v0.2.25...main).
|
||||
* Bug fixes:
|
||||
* Out-of-bounds indices to `jax.ops.segment_sum` will now be handled with
|
||||
`FILL_OR_DROP` semantics, as documented. This primarily affects the
|
||||
reverse-mode derivative, where gradients corresponding to out-of-bounds
|
||||
indices will now be returned as 0. (#8634).
|
||||
|
||||
## jaxlib 0.1.74 (Nov 17, 2021)
|
||||
* Enabled peer-to-peer copies between GPUs. Previously, GPU copies were bounced via
|
||||
|
@ -412,8 +412,10 @@ def _segment_update(name: str,
|
||||
indices_are_sorted: bool = False,
|
||||
unique_indices: bool = False,
|
||||
bucket_size: Optional[int] = None,
|
||||
reducer: Optional[Callable] = None) -> Array:
|
||||
reducer: Optional[Callable] = None,
|
||||
mode: Optional[lax.GatherScatterMode] = None) -> Array:
|
||||
jnp._check_arraylike(name, data, segment_ids)
|
||||
mode = lax.GatherScatterMode.FILL_OR_DROP if mode is None else mode
|
||||
data = jnp.asarray(data)
|
||||
segment_ids = jnp.asarray(segment_ids)
|
||||
dtype = data.dtype
|
||||
@ -430,7 +432,7 @@ def _segment_update(name: str,
|
||||
if num_buckets == 1:
|
||||
return _scatter_update(
|
||||
out, segment_ids, data, scatter_op, indices_are_sorted,
|
||||
unique_indices, normalize_indices=False)
|
||||
unique_indices, normalize_indices=False, mode=mode)
|
||||
|
||||
# Bucketize indices and perform segment_update on each bucket to improve
|
||||
# numerical stability for operations like product and sum.
|
||||
@ -450,7 +452,8 @@ def segment_sum(data: Array,
|
||||
num_segments: Optional[int] = None,
|
||||
indices_are_sorted: bool = False,
|
||||
unique_indices: bool = False,
|
||||
bucket_size: Optional[int] = None) -> Array:
|
||||
bucket_size: Optional[int] = None,
|
||||
mode: Optional[lax.GatherScatterMode] = None) -> Array:
|
||||
"""Computes the sum within segments of an array.
|
||||
|
||||
Similar to TensorFlow's `segment_sum
|
||||
@ -460,8 +463,7 @@ def segment_sum(data: Array,
|
||||
data: an array with the values to be summed.
|
||||
segment_ids: an array with integer dtype that indicates the segments of
|
||||
`data` (along its leading axis) to be summed. Values can be repeated and
|
||||
need not be sorted. Values outside of the range [0, num_segments) are
|
||||
dropped and do not contribute to the sum.
|
||||
need not be sorted.
|
||||
num_segments: optional, an int with nonnegative value indicating the number
|
||||
of segments. The default is set to be the minimum number of segments that
|
||||
would support all indices in ``segment_ids``, calculated as
|
||||
@ -473,6 +475,9 @@ def segment_sum(data: Array,
|
||||
bucket_size: size of bucket to group indices into. ``segment_sum`` is
|
||||
performed on each bucket separately to improve numerical stability of
|
||||
addition. Default ``None`` means no bucketing.
|
||||
mode: a :class:`jax.lax.GatherScatterMode` value describing how
|
||||
out-of-bounds indices should be handled. By default, values outside of the
|
||||
range [0, num_segments) are dropped and do not contribute to the sum.
|
||||
|
||||
Returns:
|
||||
An array with shape :code:`(num_segments,) + data.shape[1:]` representing the
|
||||
@ -492,8 +497,9 @@ def segment_sum(data: Array,
|
||||
>>> jit(segment_sum, static_argnums=2)(data, segment_ids, 3)
|
||||
DeviceArray([1, 5, 4], dtype=int32)
|
||||
"""
|
||||
return _segment_update("segment_sum", data, segment_ids, lax.scatter_add, num_segments,
|
||||
indices_are_sorted, unique_indices, bucket_size, jnp.sum)
|
||||
return _segment_update(
|
||||
"segment_sum", data, segment_ids, lax.scatter_add, num_segments,
|
||||
indices_are_sorted, unique_indices, bucket_size, jnp.sum, mode=mode)
|
||||
|
||||
|
||||
def segment_prod(data: Array,
|
||||
@ -501,7 +507,8 @@ def segment_prod(data: Array,
|
||||
num_segments: Optional[int] = None,
|
||||
indices_are_sorted: bool = False,
|
||||
unique_indices: bool = False,
|
||||
bucket_size: Optional[int] = None) -> Array:
|
||||
bucket_size: Optional[int] = None,
|
||||
mode: Optional[lax.GatherScatterMode] = None) -> Array:
|
||||
"""Computes the product within segments of an array.
|
||||
|
||||
Similar to TensorFlow's `segment_prod
|
||||
@ -524,6 +531,9 @@ def segment_prod(data: Array,
|
||||
bucket_size: size of bucket to group indices into. ``segment_prod`` is
|
||||
performed on each bucket separately to improve numerical stability of
|
||||
addition. Default ``None`` means no bucketing.
|
||||
mode: a :class:`jax.lax.GatherScatterMode` value describing how
|
||||
out-of-bounds indices should be handled. By default, values outside of the
|
||||
range [0, num_segments) are dropped and do not contribute to the product.
|
||||
|
||||
Returns:
|
||||
An array with shape :code:`(num_segments,) + data.shape[1:]` representing the
|
||||
@ -543,8 +553,9 @@ def segment_prod(data: Array,
|
||||
>>> jit(segment_prod, static_argnums=2)(data, segment_ids, 3)
|
||||
DeviceArray([ 0, 6, 20], dtype=int32)
|
||||
"""
|
||||
return _segment_update("segment_prod", data, segment_ids, lax.scatter_mul, num_segments,
|
||||
indices_are_sorted, unique_indices, bucket_size, jnp.prod)
|
||||
return _segment_update(
|
||||
"segment_prod", data, segment_ids, lax.scatter_mul, num_segments,
|
||||
indices_are_sorted, unique_indices, bucket_size, jnp.prod, mode=mode)
|
||||
|
||||
|
||||
def segment_max(data: Array,
|
||||
@ -552,7 +563,8 @@ def segment_max(data: Array,
|
||||
num_segments: Optional[int] = None,
|
||||
indices_are_sorted: bool = False,
|
||||
unique_indices: bool = False,
|
||||
bucket_size: Optional[int] = None) -> Array:
|
||||
bucket_size: Optional[int] = None,
|
||||
mode: Optional[lax.GatherScatterMode] = None) -> Array:
|
||||
"""Computes the maximum within segments of an array.
|
||||
|
||||
Similar to TensorFlow's `segment_max
|
||||
@ -574,6 +586,9 @@ def segment_max(data: Array,
|
||||
unique_indices: whether `segment_ids` is known to be free of duplicates.
|
||||
bucket_size: size of bucket to group indices into. ``segment_max`` is
|
||||
performed on each bucket separately. Default ``None`` means no bucketing.
|
||||
mode: a :class:`jax.lax.GatherScatterMode` value describing how
|
||||
out-of-bounds indices should be handled. By default, values outside of the
|
||||
range [0, num_segments) are dropped and do not contribute to the maximum.
|
||||
|
||||
Returns:
|
||||
An array with shape :code:`(num_segments,) + data.shape[1:]` representing the
|
||||
@ -593,8 +608,9 @@ def segment_max(data: Array,
|
||||
>>> jit(segment_max, static_argnums=2)(data, segment_ids, 3)
|
||||
DeviceArray([1, 3, 5], dtype=int32)
|
||||
"""
|
||||
return _segment_update("segment_max", data, segment_ids, lax.scatter_max, num_segments,
|
||||
indices_are_sorted, unique_indices, bucket_size, jnp.max)
|
||||
return _segment_update(
|
||||
"segment_max", data, segment_ids, lax.scatter_max, num_segments,
|
||||
indices_are_sorted, unique_indices, bucket_size, jnp.max, mode=mode)
|
||||
|
||||
|
||||
def segment_min(data: Array,
|
||||
@ -602,7 +618,8 @@ def segment_min(data: Array,
|
||||
num_segments: Optional[int] = None,
|
||||
indices_are_sorted: bool = False,
|
||||
unique_indices: bool = False,
|
||||
bucket_size: Optional[int] = None) -> Array:
|
||||
bucket_size: Optional[int] = None,
|
||||
mode: Optional[lax.GatherScatterMode] = None) -> Array:
|
||||
"""Computes the minimum within segments of an array.
|
||||
|
||||
Similar to TensorFlow's `segment_min
|
||||
@ -624,6 +641,9 @@ def segment_min(data: Array,
|
||||
unique_indices: whether `segment_ids` is known to be free of duplicates.
|
||||
bucket_size: size of bucket to group indices into. ``segment_min`` is
|
||||
performed on each bucket separately. Default ``None`` means no bucketing.
|
||||
mode: a :class:`jax.lax.GatherScatterMode` value describing how
|
||||
out-of-bounds indices should be handled. By default, values outside of the
|
||||
range [0, num_segments) are dropped and do not contribute to the minimum.
|
||||
|
||||
Returns:
|
||||
An array with shape :code:`(num_segments,) + data.shape[1:]` representing the
|
||||
@ -643,5 +663,6 @@ def segment_min(data: Array,
|
||||
>>> jit(segment_min, static_argnums=2)(data, segment_ids, 3)
|
||||
DeviceArray([0, 2, 4], dtype=int32)
|
||||
"""
|
||||
return _segment_update("segment_min", data, segment_ids, lax.scatter_min, num_segments,
|
||||
indices_are_sorted, unique_indices, bucket_size, jnp.min)
|
||||
return _segment_update(
|
||||
"segment_min", data, segment_ids, lax.scatter_min, num_segments,
|
||||
indices_are_sorted, unique_indices, bucket_size, jnp.min, mode=mode)
|
||||
|
@ -1155,6 +1155,17 @@ class IndexedUpdateTest(jtu.JaxTestCase):
|
||||
expected = jnp.array([0, 0, 0, 13, 2, 7])
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
def testSegmentSumOutOfBounds(self):
|
||||
def fn(data, segment_ids):
|
||||
return jax.ops.segment_sum(data, segment_ids, num_segments).sum()
|
||||
|
||||
data = np.array([0, 0], dtype=np.float32)
|
||||
num_segments = 2
|
||||
segment_ids = np.array([2, 3])
|
||||
val, grad = jax.value_and_grad(fn)(data, segment_ids)
|
||||
self.assertAllClose(val, np.array(0., np.float32))
|
||||
self.assertAllClose(grad, np.array([0., 0.], np.float32))
|
||||
|
||||
|
||||
@parameterized.named_parameters(itertools.chain.from_iterable(
|
||||
jtu.cases_from_list({
|
||||
|
Loading…
x
Reference in New Issue
Block a user