mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[Pallas] Fix array indexing error when dimension size is not a multiple of stride
This commit is contained in:
parent
7d2f0a75c1
commit
e0faa596b3
@ -169,7 +169,7 @@ def _maybe_convert_to_slice(
|
||||
return None
|
||||
|
||||
start = i.start
|
||||
end = i.start + i.size * i.stride
|
||||
end = i.start + (i.size - 1) * i.stride + 1
|
||||
stride = i.stride
|
||||
|
||||
# cannot convert to static `slice` if `start` or `end` is dynamic
|
||||
|
@ -647,27 +647,37 @@ class IndexerOpsInterpretTest(IndexerOpsTest):
|
||||
|
||||
# TODO(ayx): Fix all test cases here
|
||||
_ADVANCED_INDEXER_TEST_CASES = [
|
||||
((8, 2), lambda arr, a, b, c, d: arr[2]),
|
||||
((16, 3, 6, 2), lambda arr, a, b, c, d: arr[::4, a, 1::2, b]),
|
||||
((16, 3), lambda arr, a, b, c, d: arr[a, a]),
|
||||
((16, 16), lambda arr, a, b, c, d: arr[::4, ::4]),
|
||||
# integer
|
||||
((3, 2), lambda arr, a, b, c, d: arr[2]),
|
||||
# slice
|
||||
((12, 12), lambda arr, a, b, c, d: arr[::4, ::4]),
|
||||
((16, 16), lambda arr, a, b, c, d: arr[1:14:2, 2:13:4]),
|
||||
((16, 3), lambda arr, a, b, c, d: arr[a, :]),
|
||||
# ((16, 3), lambda arr, a, b, c, d: arr[:, a]),
|
||||
((16, 3), lambda arr, a, b, c, d: arr[a, ::4]),
|
||||
# ((16, 3), lambda arr, a, b, c, d: arr[::4, a]),
|
||||
((8, 2), lambda arr, a, b, c, d: arr[1::3, :]),
|
||||
# array
|
||||
((4, 3), lambda arr, a, b, c, d: arr[a]),
|
||||
((4, 3, 2), lambda arr, a, b, c, d: arr[c, c]),
|
||||
# integer + 1-D array
|
||||
((4, 3), lambda arr, a, b, c, d: arr[2, a]),
|
||||
((4, 3), lambda arr, a, b, c, d: arr[a, 2]),
|
||||
# slice + 1-D array
|
||||
((4, 3), lambda arr, a, b, c, d: arr[a, :]),
|
||||
# ((4, 3), lambda arr, a, b, c, d: arr[:, a]),
|
||||
((6, 8, 3), lambda arr, a, b, c, d: arr[c, ::3]),
|
||||
# ((8, 6, 3), lambda arr, a, b, c, d: arr[::3, c]),
|
||||
# ((8, 8, 3), lambda arr, a, b, c, d: arr[::4, ::2, a]),
|
||||
# ((8, 8, 3), lambda arr, a, b, c, d: arr[::4, a, ::2]),
|
||||
# ((8, 8, 3, 7), lambda arr, a, b, c, d: arr[::4, b, ::2, ::2]),
|
||||
# ((8, 8, 3, 7), lambda arr, a, b, c, d: arr[b, ::4, ::2, ::2]),
|
||||
# ((8, 8, 3, 7), lambda arr, a, b, c, d: arr[b, ::4, a, ::2]),
|
||||
# ((3, 8, 8, 7), lambda arr, a, b, c, d: arr[b, a, ::4, ::2]),
|
||||
((8, 8, 3, 7), lambda arr, a, b, c, d: arr[b, ::4, a, ::2]),
|
||||
((3, 8, 8, 7), lambda arr, a, b, c, d: arr[b, a, ::4, ::2]),
|
||||
# ((8, 8, 3, 7), lambda arr, a, b, c, d: arr[::4, b, a, ::2]),
|
||||
# ((8, 8, 3, 6), lambda arr, a, b, c, d: arr[b, ::4, a, c]),
|
||||
((8, 6, 4), lambda arr, a, b, c, d: arr[a]),
|
||||
((6, 8, 4), lambda arr, a, b, c, d: arr[c, c]),
|
||||
((6, 8, 4), lambda arr, a, b, c, d: arr[c, ::3]),
|
||||
# ((8, 6, 4), lambda arr, a, b, c, d: arr[::3, c]),
|
||||
((16, 3, 6, 2), lambda arr, a, b, c, d: arr[::4, a, 1::2, b]),
|
||||
((8, 8, 3, 6), lambda arr, a, b, c, d: arr[b, ::4, a, a]),
|
||||
# slice + array w/ broadcasting
|
||||
((8, 8, 3, 6), lambda arr, a, b, c, d: \
|
||||
arr[b[:, None], ::4, a[None], a[:, None]]),
|
||||
# integer + slice + 1-D array
|
||||
((5, 8, 8, 3), lambda arr, a, b, c, d: arr[2, ::4, ::2, a]),
|
||||
((5, 8, 8, 3), lambda arr, a, b, c, d: arr[2, ::4, a, ::2]),
|
||||
# boolean
|
||||
# ((6, 2), lambda arr, a, b, c, d: arr[d]),
|
||||
# ((8, 6), lambda arr, a, b, c, d: arr[::4, d]),
|
||||
]
|
||||
|
Loading…
x
Reference in New Issue
Block a user