Pallas GPU no longer assumes that all slices have stride 1

Fixes #20895.

PiperOrigin-RevId: 639031975
This commit is contained in:
Sergei Lebedev 2024-05-31 07:43:31 -07:00 committed by jax authors
parent d26819d6cd
commit d6a84cc5f3
2 changed files with 27 additions and 4 deletions

View File

@ -1505,14 +1505,22 @@ def _compute_pointers_from_indices(
else:
index = next(indexer_iter)
if isinstance(index, primitives.Slice):
# Handle slices with static and dynamic indices and static sizes
if isinstance(index.start, int):
ptr_dim_offset = _make_range(index.start, index.start + index.size)
else:
if index.is_dynamic_start:
# Compute the offset as start + range(0, size).
ptr_dim_offset = _add(
_bcast_to(index.start, [index.size]),
_ir_cast(_make_range(0, index.size), index.start.type, signed=False),
)
elif index.stride > 1:
# Compute the offset as start + range(0, size) * stride.
iota = _make_range(0, index.size)
ptr_dim_offset = _add(
_bcast_to(_i32_constant(index.start), [index.size]),
_mul(iota, _full(iota.type, index.stride)),
)
else:
ptr_dim_offset = _make_range(index.start, index.start + index.size)
# We need to add broadcastable dimensions for the advanced int indexing
# and for previous slices
num_left_expand_dims = len(int_indexer_shape) + other_shape_idx

View File

@ -437,6 +437,21 @@ class PallasCallTest(PallasTest):
x = random.normal(key, (size,))
np.testing.assert_allclose(add_one(x), x + 1., atol=1e-5, rtol=1e-5)
def test_strided_load(self):
  """Checks that a stride-4 load from a ref produces the same values as ``x[::4]``.

  Reproducer from https://github.com/google/jax/issues/20895.
  """
  # TODO(b/329733289): Remove this once the bug is fixed.
  if self.INTERPRET:
    self.skipTest("Strided load not yet supported in interpreter mode")

  def kernel(x_ref, o_ref):
    # Read every fourth element of the input ref and store it to the output.
    o_ref[...] = x_ref[::4]

  # 16 inputs sliced with step 4 leave exactly 4 outputs.
  result_spec = jax.ShapeDtypeStruct((4,), jnp.float32)
  strided_load = self.pallas_call(kernel, out_shape=result_spec)

  values = jnp.arange(16, dtype=jnp.float32)
  expected = values[::4]
  np.testing.assert_array_equal(strided_load(values), expected)
def test_broadcasted_load_store(self):
m, n = 16, 32
@functools.partial(