Only perform checks on slice sizes if they're static.

PiperOrigin-RevId: 627560765
This commit is contained in:
jax authors 2024-04-23 18:01:20 -07:00
parent 8239674dab
commit 26a3d3dc02

View File

@ -139,11 +139,12 @@ class NDIndexer:
if value := _maybe_concretize(start):
if value >= s:
raise ValueError(f"Out of bound slice: start={value}, dim={s}.")
if value + (idx.size - 1) * idx.stride >= s:
raise ValueError(
f"Out of bound slice: start={value}, size={idx.size},"
f" stride={idx.stride}, dim={s}."
)
if size := _maybe_concretize(idx.size):
if value + (size - 1) * idx.stride >= s:
raise ValueError(
f"Out of bound slice: start={value}, size={size},"
f" stride={idx.stride}, dim={s}."
)
continue
# The shape of indexer integers should be broadcastable up to the
# int_indexer_shape of the whole NDIndexer