mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Forbid collapsing of size-0 dimensions in gather() operations.
The shape rule for gather should not allow collapsing size-0 dimensions because it is nonsensical: "collapsing" a size-0 dimension might turn an empty array into a non-empty array, and it is quite unclear what that non-empty array should contain.

Forbid such collapsing in the JAX shape rule. This appears to have arisen in practice when the size of the array is known to be 0 in another dimension, e.g., batching with a size-0 batch dimension. In those cases, avoid using a gather to create these arrays and construct the empty result directly instead. This isn't an ideal solution because it isn't polymorphic in the shape, but I think to do better we would need to change the definition of `gather` more extensively.

PiperOrigin-RevId: 406346374
This commit is contained in:
parent
345ab50963
commit
d0065d8a76
@ -4723,9 +4723,9 @@ def _gather_shape_rule(operand, indices, *, dimension_numbers,
|
||||
|
||||
for i in range(len(collapsed_slice_dims)):
|
||||
bound = slice_sizes[collapsed_slice_dims[i]]
|
||||
if bound > 1:
|
||||
raise TypeError(f"Gather op can only collapse slice dims with bound 1 "
|
||||
f"or 0, but bound is {bound} for index "
|
||||
if bound != 1:
|
||||
raise TypeError(f"Gather op can only collapse slice dims with bound 1, "
|
||||
f"but bound is {bound} for index "
|
||||
f"{collapsed_slice_dims[i]} at position {i}.")
|
||||
|
||||
expanded_indices_shape.pop(index_vector_dim)
|
||||
@ -4857,6 +4857,20 @@ def _gather_batching_rule(batched_args, batch_dims, *, dimension_numbers,
|
||||
operand = batching.moveaxis(operand, operand_bdim, 0)
|
||||
indices = batching.moveaxis(indices, indices_bdim, 0)
|
||||
|
||||
# This slightly awkward special case is needed because the shape rule for
|
||||
# gather does not allow size-1 slices out of a size-0 dimension, even if
|
||||
# the number of slices is zero. Likely the best fix would be to change the
|
||||
# definition of gather() so it can be batched without the construction of
|
||||
# an explicit iota of size-1 slices.
|
||||
if core.symbolic_equal_dim(operand.shape[0], 0):
|
||||
output_shape = _gather_shape_rule(
|
||||
core.ShapedArray(operand.shape[1:], operand.dtype),
|
||||
core.ShapedArray(indices.shape[1:], indices.dtype),
|
||||
dimension_numbers=dimension_numbers, slice_sizes=slice_sizes,
|
||||
unique_indices=unique_indices, indices_are_sorted=indices_are_sorted,
|
||||
mode=mode, fill_value=fill_value)
|
||||
return full((0,) + output_shape, _zero(operand)), 0
|
||||
|
||||
# Example: user code had indices shape (3, 4, 5), and we have to deal with
|
||||
# indices shape (7, 3, 4, 5). We transform that to indices of shape
|
||||
# (7, 3, 4, 6) where we concatenated an iota that counts along our batch
|
||||
@ -4866,8 +4880,7 @@ def _gather_batching_rule(batched_args, batch_dims, *, dimension_numbers,
|
||||
counts = broadcasted_iota(indices.dtype, tuple(count_shape), 0)
|
||||
indices = concatenate([counts, indices], len(count_shape) - 1)
|
||||
|
||||
batch_slice_size = 1 if core.greater_equal_dim(operand.shape[0], 1) else 0
|
||||
slice_sizes = (batch_slice_size,) + slice_sizes
|
||||
slice_sizes = (1,) + slice_sizes
|
||||
collapsed_slice_dims = (0,) + tuple(np.add(1, dimension_numbers.collapsed_slice_dims))
|
||||
offset_dims = tuple(np.add(1, dimension_numbers.offset_dims))
|
||||
start_index_map = (0,) + tuple(np.add(1, dimension_numbers.start_index_map))
|
||||
|
@ -5449,7 +5449,12 @@ def _take(a, indices, axis: Optional[int] = None, out=None, mode=None):
|
||||
raise IndexError("Cannot do a non-empty jnp.take() from an empty axis.")
|
||||
return a
|
||||
|
||||
slice_sizes[axis_idx] = _min(indices.size, 1)
|
||||
if indices.size == 0:
|
||||
out_shape = (slice_sizes[:axis_idx] + list(indices.shape) +
|
||||
slice_sizes[axis_idx + 1:])
|
||||
return full_like(a, 0, shape=out_shape)
|
||||
|
||||
slice_sizes[axis_idx] = 1
|
||||
dnums = lax.GatherDimensionNumbers(
|
||||
offset_dims=tuple(
|
||||
list(range(axis_idx)) +
|
||||
|
@ -2041,7 +2041,7 @@ class LaxTest(jtu.JaxTestCase):
|
||||
"Gather op must have one slice size for every input dimension"),
|
||||
("WindowBoundsNot1ForElidedDim", (10, 9, 8, 7, 6), (5, 4, 3, 2, 5),
|
||||
(4, 5, 6, 7), (1,), (0, 1, 2, 3, 4), (10, 9, 8, 7, 6),
|
||||
("Gather op can only collapse slice dims with bound 1 or 0, but bound "
|
||||
("Gather op can only collapse slice dims with bound 1, but bound "
|
||||
"is 9 for index 1 at position 0."))
|
||||
]
|
||||
))
|
||||
|
Loading…
x
Reference in New Issue
Block a user