Forbid collapsing of size-0 dimensions in gather() operations.

The shape rule for gather should not allow collapsing size-0 dimensions because it is nonsensical: "collapsing" a size 0 dimension might turn an empty array into a non-empty array. And it's quite unclear what that non-empty array should contain. Forbid such collapsing in the JAX shape rule.

This appears to have arisen in practice when the size of the array is known to be 0 in another dimension, e.g., batching with a size 0 batch dimension. Instead, avoid using a gather to create these arrays. This isn't an ideal solution because it isn't polymorphic in the shape, but I think to do better we would need to change the definition of `gather` more extensively.

PiperOrigin-RevId: 406346374
This commit is contained in:
Peter Hawkins 2021-10-29 06:33:47 -07:00 committed by jax authors
parent 345ab50963
commit d0065d8a76
3 changed files with 25 additions and 7 deletions

View File

@ -4723,9 +4723,9 @@ def _gather_shape_rule(operand, indices, *, dimension_numbers,
for i in range(len(collapsed_slice_dims)):
bound = slice_sizes[collapsed_slice_dims[i]]
if bound > 1:
raise TypeError(f"Gather op can only collapse slice dims with bound 1 "
f"or 0, but bound is {bound} for index "
if bound != 1:
raise TypeError(f"Gather op can only collapse slice dims with bound 1, "
f"but bound is {bound} for index "
f"{collapsed_slice_dims[i]} at position {i}.")
expanded_indices_shape.pop(index_vector_dim)
@ -4857,6 +4857,20 @@ def _gather_batching_rule(batched_args, batch_dims, *, dimension_numbers,
operand = batching.moveaxis(operand, operand_bdim, 0)
indices = batching.moveaxis(indices, indices_bdim, 0)
# This slightly awkward special case is needed because the shape rule for
# gather does not allow size-1 slices out of a size-0 dimension, even if
# the number of slices is zero. Likely the best fix would be to change the
# definition of gather() so it can be batched without the construction of
# an explicit iota of size-1 slices.
if core.symbolic_equal_dim(operand.shape[0], 0):
output_shape = _gather_shape_rule(
core.ShapedArray(operand.shape[1:], operand.dtype),
core.ShapedArray(indices.shape[1:], indices.dtype),
dimension_numbers=dimension_numbers, slice_sizes=slice_sizes,
unique_indices=unique_indices, indices_are_sorted=indices_are_sorted,
mode=mode, fill_value=fill_value)
return full((0,) + output_shape, _zero(operand)), 0
# Example: user code had indices shape (3, 4, 5), and we have to deal with
# indices shape (7, 3, 4, 5). We transform that to indices of shape
# (7, 3, 4, 6) where we concatenated an iota that counts along our batch
@ -4866,8 +4880,7 @@ def _gather_batching_rule(batched_args, batch_dims, *, dimension_numbers,
counts = broadcasted_iota(indices.dtype, tuple(count_shape), 0)
indices = concatenate([counts, indices], len(count_shape) - 1)
batch_slice_size = 1 if core.greater_equal_dim(operand.shape[0], 1) else 0
slice_sizes = (batch_slice_size,) + slice_sizes
slice_sizes = (1,) + slice_sizes
collapsed_slice_dims = (0,) + tuple(np.add(1, dimension_numbers.collapsed_slice_dims))
offset_dims = tuple(np.add(1, dimension_numbers.offset_dims))
start_index_map = (0,) + tuple(np.add(1, dimension_numbers.start_index_map))

View File

@ -5449,7 +5449,12 @@ def _take(a, indices, axis: Optional[int] = None, out=None, mode=None):
raise IndexError("Cannot do a non-empty jnp.take() from an empty axis.")
return a
slice_sizes[axis_idx] = _min(indices.size, 1)
if indices.size == 0:
out_shape = (slice_sizes[:axis_idx] + list(indices.shape) +
slice_sizes[axis_idx + 1:])
return full_like(a, 0, shape=out_shape)
slice_sizes[axis_idx] = 1
dnums = lax.GatherDimensionNumbers(
offset_dims=tuple(
list(range(axis_idx)) +

View File

@ -2041,7 +2041,7 @@ class LaxTest(jtu.JaxTestCase):
"Gather op must have one slice size for every input dimension"),
("WindowBoundsNot1ForElidedDim", (10, 9, 8, 7, 6), (5, 4, 3, 2, 5),
(4, 5, 6, 7), (1,), (0, 1, 2, 3, 4), (10, 9, 8, 7, 6),
("Gather op can only collapse slice dims with bound 1 or 0, but bound "
("Gather op can only collapse slice dims with bound 1, but bound "
"is 9 for index 1 at position 0."))
]
))