[Pallas] Fix array indexing error when dimension size is not a multiple of stride

This commit is contained in:
Ayaka 2024-09-10 02:46:14 +01:00
parent 7d2f0a75c1
commit e0faa596b3
2 changed files with 28 additions and 18 deletions

View File

@ -169,7 +169,7 @@ def _maybe_convert_to_slice(
return None
start = i.start
end = i.start + i.size * i.stride
end = i.start + (i.size - 1) * i.stride + 1
stride = i.stride
# cannot convert to static `slice` if `start` or `end` is dynamic

View File

@ -647,27 +647,37 @@ class IndexerOpsInterpretTest(IndexerOpsTest):
# TODO(ayx): Fix all test cases here
_ADVANCED_INDEXER_TEST_CASES = [
((8, 2), lambda arr, a, b, c, d: arr[2]),
((16, 3, 6, 2), lambda arr, a, b, c, d: arr[::4, a, 1::2, b]),
((16, 3), lambda arr, a, b, c, d: arr[a, a]),
((16, 16), lambda arr, a, b, c, d: arr[::4, ::4]),
# integer
((3, 2), lambda arr, a, b, c, d: arr[2]),
# slice
((12, 12), lambda arr, a, b, c, d: arr[::4, ::4]),
((16, 16), lambda arr, a, b, c, d: arr[1:14:2, 2:13:4]),
((16, 3), lambda arr, a, b, c, d: arr[a, :]),
# ((16, 3), lambda arr, a, b, c, d: arr[:, a]),
((16, 3), lambda arr, a, b, c, d: arr[a, ::4]),
# ((16, 3), lambda arr, a, b, c, d: arr[::4, a]),
((8, 2), lambda arr, a, b, c, d: arr[1::3, :]),
# array
((4, 3), lambda arr, a, b, c, d: arr[a]),
((4, 3, 2), lambda arr, a, b, c, d: arr[c, c]),
# integer + 1-D array
((4, 3), lambda arr, a, b, c, d: arr[2, a]),
((4, 3), lambda arr, a, b, c, d: arr[a, 2]),
# slice + 1-D array
((4, 3), lambda arr, a, b, c, d: arr[a, :]),
# ((4, 3), lambda arr, a, b, c, d: arr[:, a]),
((6, 8, 3), lambda arr, a, b, c, d: arr[c, ::3]),
# ((8, 6, 3), lambda arr, a, b, c, d: arr[::3, c]),
# ((8, 8, 3), lambda arr, a, b, c, d: arr[::4, ::2, a]),
# ((8, 8, 3), lambda arr, a, b, c, d: arr[::4, a, ::2]),
# ((8, 8, 3, 7), lambda arr, a, b, c, d: arr[::4, b, ::2, ::2]),
# ((8, 8, 3, 7), lambda arr, a, b, c, d: arr[b, ::4, ::2, ::2]),
# ((8, 8, 3, 7), lambda arr, a, b, c, d: arr[b, ::4, a, ::2]),
# ((3, 8, 8, 7), lambda arr, a, b, c, d: arr[b, a, ::4, ::2]),
((8, 8, 3, 7), lambda arr, a, b, c, d: arr[b, ::4, a, ::2]),
((3, 8, 8, 7), lambda arr, a, b, c, d: arr[b, a, ::4, ::2]),
# ((8, 8, 3, 7), lambda arr, a, b, c, d: arr[::4, b, a, ::2]),
# ((8, 8, 3, 6), lambda arr, a, b, c, d: arr[b, ::4, a, c]),
((8, 6, 4), lambda arr, a, b, c, d: arr[a]),
((6, 8, 4), lambda arr, a, b, c, d: arr[c, c]),
((6, 8, 4), lambda arr, a, b, c, d: arr[c, ::3]),
# ((8, 6, 4), lambda arr, a, b, c, d: arr[::3, c]),
((16, 3, 6, 2), lambda arr, a, b, c, d: arr[::4, a, 1::2, b]),
((8, 8, 3, 6), lambda arr, a, b, c, d: arr[b, ::4, a, a]),
# slice + array w/ broadcasting
((8, 8, 3, 6), lambda arr, a, b, c, d: \
arr[b[:, None], ::4, a[None], a[:, None]]),
# integer + slice + 1-D array
((5, 8, 8, 3), lambda arr, a, b, c, d: arr[2, ::4, ::2, a]),
((5, 8, 8, 3), lambda arr, a, b, c, d: arr[2, ::4, a, ::2]),
# boolean
# ((6, 2), lambda arr, a, b, c, d: arr[d]),
# ((8, 6), lambda arr, a, b, c, d: arr[::4, d]),
]