mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[mosaic_gpu] Error on static out of bounds indices in utils.parse_indices
It would also be nice to optionally insert runtime assertions for dynamic indices, but we don't have a way of doing that just yet. PiperOrigin-RevId: 707532787
This commit is contained in:
parent
5beb4794b7
commit
98067fc10e
@ -470,8 +470,9 @@ class LaunchContext:
|
||||
" multiple of 16 bytes"
|
||||
)
|
||||
|
||||
# TMA supports OOB indices, so we skip the check.
|
||||
base_indices, slice_shape, is_squeezed = utils.parse_indices(
|
||||
gmem_slice, ir.MemRefType(gmem_ref.type).shape
|
||||
gmem_slice, ir.MemRefType(gmem_ref.type).shape, check_oob=False
|
||||
)
|
||||
dyn_base_indices = tuple(
|
||||
c(i, index) if not isinstance(i, ir.Value) else i for i in base_indices
|
||||
|
@ -329,6 +329,12 @@ class DynamicSlice:
|
||||
base: ir.Value | int
|
||||
length: int
|
||||
|
||||
def __post_init__(self):
|
||||
if isinstance(self.base, int) and self.base < 0:
|
||||
raise ValueError(f"base must be non-negative, got {self.base}")
|
||||
if self.length < 0:
|
||||
raise ValueError(f"length must be non-negative, got {self.length}")
|
||||
|
||||
|
||||
ds = DynamicSlice
|
||||
|
||||
@ -569,7 +575,7 @@ def memref_transpose(ref: ir.Value, permutation: Sequence[int]) -> ir.Value:
|
||||
|
||||
|
||||
def parse_indices(
|
||||
index, shape: tuple[int, ...]
|
||||
index, shape: tuple[int, ...], *, check_oob: bool = True
|
||||
) -> tuple[list[ir.Value | int], list[int], list[bool]]:
|
||||
if not isinstance(index, tuple):
|
||||
index = (index,)
|
||||
@ -578,20 +584,42 @@ def parse_indices(
|
||||
base_indices = []
|
||||
slice_shape = []
|
||||
is_squeezed = []
|
||||
for idx, bound in zip(index, shape):
|
||||
for axis, (idx, bound) in enumerate(zip(index, shape)):
|
||||
if isinstance(idx, (ir.Operation, ir.OpView)):
|
||||
idx = idx.result
|
||||
if isinstance(idx, int):
|
||||
base_indices.append(idx)
|
||||
if check_oob and (idx >= bound or (idx < 0 and -idx > bound)):
|
||||
raise IndexError(
|
||||
f"Index {idx} along axis {axis} is out of bounds for shape {shape}"
|
||||
)
|
||||
base_indices.append(idx if idx >= 0 else bound + idx)
|
||||
slice_shape.append(1)
|
||||
is_squeezed.append(True)
|
||||
elif isinstance(idx, slice):
|
||||
if idx.step is not None and idx.step != 1:
|
||||
raise NotImplementedError("Strided slices not implemented")
|
||||
base_indices.append(idx.start or 0)
|
||||
slice_shape.append((idx.stop or bound) - (idx.start or 0))
|
||||
start = idx.start or 0
|
||||
if start < 0:
|
||||
start = bound + start
|
||||
stop = idx.stop or bound
|
||||
if stop < 0:
|
||||
stop = bound + stop
|
||||
if check_oob and (
|
||||
start < 0 or start >= bound or stop < 0 or stop > bound
|
||||
):
|
||||
raise IndexError(
|
||||
f"Slice {idx} along axis {axis} is out of bounds for shape {shape}"
|
||||
)
|
||||
base_indices.append(start)
|
||||
slice_shape.append(stop - start)
|
||||
is_squeezed.append(False)
|
||||
elif isinstance(idx, DynamicSlice):
|
||||
if check_oob and (
|
||||
isinstance(idx.base, int) and idx.base + idx.length > bound
|
||||
):
|
||||
raise IndexError(
|
||||
f"Slice {idx} along axis {axis} is out of bounds for shape {shape}"
|
||||
)
|
||||
base_indices.append(idx.base)
|
||||
slice_shape.append(idx.length)
|
||||
is_squeezed.append(False)
|
||||
|
@ -327,7 +327,8 @@ class MemRefTest(TestCase):
|
||||
("strided_bot", (4, 4, 4, 4), (256, 16, 4, 1), 1, 2, False),
|
||||
("strided_top", (4, 4, 4, 4), (256, 64, 4, 1), 1, 2, True),
|
||||
("strided_mid", (4, 4, 4, 4), (265, 64, 16, 1), 1, 3, True),
|
||||
("overap", (2, 4, 4), (16, 1, 1), 0, 3, True),
|
||||
# TODO(cperivol): Investigate why this is indexing OOB and uncomment.
|
||||
# ("overap", (2, 4, 4), (16, 1, 1), 0, 3, True),
|
||||
])
|
||||
def test_fold_strided(
|
||||
self, shape, strides, dim, fold_rank, throws_not_impl
|
||||
@ -1911,5 +1912,33 @@ class MosaicGpuDialectTest(TestCase, jtu.JaxTestCase):
|
||||
self.assertArraysEqual(jax.jit(kernel)(x, y), x + y)
|
||||
|
||||
|
||||
class UtilsTest(TestCase):
|
||||
@parameterized.parameters(
|
||||
(1,),
|
||||
(-1,),
|
||||
(slice(2), slice(3),),
|
||||
(slice(1), slice(1, 3)),
|
||||
(slice(-2, 0),),
|
||||
(slice(-2, -1),),
|
||||
*([(utils.DynamicSlice(0, 2),)] if HAS_MOSAIC_GPU else []),
|
||||
)
|
||||
def test_parse_indices(self, *indices):
|
||||
# We are simply making sure this does not raise.
|
||||
_, _, _ = utils.parse_indices(indices, (2, 3, 4))
|
||||
|
||||
@parameterized.parameters(
|
||||
(42,),
|
||||
(-42,),
|
||||
(slice(42),),
|
||||
(slice(0, 42),),
|
||||
(slice(-42, 0),),
|
||||
(slice(-4, -42),),
|
||||
*([(utils.DynamicSlice(0, 4),)] if HAS_MOSAIC_GPU else []),
|
||||
)
|
||||
def test_parse_indices_oob(self, indices):
|
||||
with self.assertRaisesRegex(IndexError, "out of bounds"):
|
||||
utils.parse_indices(indices, (2, 3, 4))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user