diff --git a/jax/_src/state/indexing.py b/jax/_src/state/indexing.py index 716520c55..2544644d9 100644 --- a/jax/_src/state/indexing.py +++ b/jax/_src/state/indexing.py @@ -139,11 +139,12 @@ class NDIndexer: if value := _maybe_concretize(start): if value >= s: raise ValueError(f"Out of bound slice: start={value}, dim={s}.") - if value + (idx.size - 1) * idx.stride >= s: - raise ValueError( - f"Out of bound slice: start={value}, size={idx.size}," - f" stride={idx.stride}, dim={s}." - ) + if size := _maybe_concretize(idx.size): + if value + (size - 1) * idx.stride >= s: + raise ValueError( + f"Out of bound slice: start={value}, size={size}," + f" stride={idx.stride}, dim={s}." + ) continue # The shape of indexer integers should be broadcastable up to the # int_indexer_shape of the whole NDIndexer