Mirror of https://github.com/ROCm/jax.git (synced 2025-04-16 11:56:07 +00:00)
Pallas GPU no longer assumes that all slices have stride 1
Fixes #20895. PiperOrigin-RevId: 639031975
commit d6a84cc5f3
parent d26819d6cd
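
For context, a minimal sketch (not part of the commit; adapted from the reproducer in https://github.com/google/jax/issues/20895 and the new test below) of the kind of strided read this change makes work under the Pallas GPU (Triton) lowering. The name strided_copy_kernel is illustrative only.

    import jax
    import jax.numpy as jnp
    from jax.experimental import pallas as pl

    def strided_copy_kernel(x_ref, o_ref):
      # Read every 4th element of the input ref. Before this commit, the
      # Triton lowering assumed the slice had stride 1 and computed wrong offsets.
      o_ref[...] = x_ref[::4]

    strided_copy = pl.pallas_call(
        strided_copy_kernel,
        out_shape=jax.ShapeDtypeStruct((4,), jnp.float32),
    )

    x = jnp.arange(16, dtype=jnp.float32)
    print(strided_copy(x))  # expected: [ 0.  4.  8. 12.]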
@@ -1505,14 +1505,22 @@ def _compute_pointers_from_indices(
     else:
       index = next(indexer_iter)
       if isinstance(index, primitives.Slice):
-        # Handle slices with static and dynamic indices and static sizes
-        if isinstance(index.start, int):
-          ptr_dim_offset = _make_range(index.start, index.start + index.size)
-        else:
+        if index.is_dynamic_start:
+          # Compute the offset as start + range(0, size).
           ptr_dim_offset = _add(
               _bcast_to(index.start, [index.size]),
               _ir_cast(_make_range(0, index.size), index.start.type, signed=False),
           )
+        elif index.stride > 1:
+          # Compute the offset as start + range(0, size) * stride.
+          iota = _make_range(0, index.size)
+          ptr_dim_offset = _add(
+              _bcast_to(_i32_constant(index.start), [index.size]),
+              _mul(iota, _full(iota.type, index.stride)),
+          )
+        else:
+          ptr_dim_offset = _make_range(index.start, index.start + index.size)
+
         # We need to add broadcastable dimensions for the advanced int indexing
         # and for previous slices
         num_left_expand_dims = len(int_indexer_shape) + other_shape_idx
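
Illustrative only (not part of the commit): the offset arithmetic performed by the new index.stride > 1 branch above, written with NumPy in place of the Triton IR helpers (_make_range, _bcast_to, _mul, _full). The function name slice_offsets is hypothetical.

    import numpy as np

    def slice_offsets(start: int, size: int, stride: int) -> np.ndarray:
      # start + range(0, size) * stride: one pointer offset per loaded element.
      return start + np.arange(size) * stride

    # x_ref[::4] on a 16-element ref corresponds to start=0, size=4, stride=4.
    print(slice_offsets(0, 4, 4))  # [ 0  4  8 12]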
@@ -437,6 +437,21 @@ class PallasCallTest(PallasTest):
     x = random.normal(key, (size,))
     np.testing.assert_allclose(add_one(x), x + 1., atol=1e-5, rtol=1e-5)

+  def test_strided_load(self):
+    if self.INTERPRET:
+      # TODO(b/329733289): Remove this once the bug is fixed.
+      self.skipTest("Strided load not yet supported in interpreter mode")
+
+    # Reproducer from https://github.com/google/jax/issues/20895.
+    @functools.partial(
+        self.pallas_call, out_shape=jax.ShapeDtypeStruct((4,), jnp.float32),
+    )
+    def kernel(x_ref, o_ref):
+      o_ref[...] = x_ref[::4]
+
+    x = jnp.arange(16, dtype=jnp.float32)
+    np.testing.assert_array_equal(kernel(x), x[::4])
+
   def test_broadcasted_load_store(self):
     m, n = 16, 32
     @functools.partial(
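
The new test skips interpreter mode per the TODO above. As a hedged sketch of that caveat, the same kernel driven through the interpreter backend was not expected to handle the strided read at the time of this commit.

    import jax
    import jax.numpy as jnp
    from jax.experimental import pallas as pl

    def kernel(x_ref, o_ref):
      o_ref[...] = x_ref[::4]

    # interpret=True runs the kernel with the Pallas interpreter rather than the
    # Triton backend; per TODO(b/329733289) it did not yet support strided loads.
    interpreted = pl.pallas_call(
        kernel,
        out_shape=jax.ShapeDtypeStruct((4,), jnp.float32),
        interpret=True,
    )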