From 26a3d3dc020406ca87ede46c5ad8a27952d28a19 Mon Sep 17 00:00:00 2001 From: jax authors Date: Tue, 23 Apr 2024 18:01:20 -0700 Subject: [PATCH] Only perform checks on slice sizes if they're static. PiperOrigin-RevId: 627560765 --- jax/_src/state/indexing.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/jax/_src/state/indexing.py b/jax/_src/state/indexing.py index 716520c55..2544644d9 100644 --- a/jax/_src/state/indexing.py +++ b/jax/_src/state/indexing.py @@ -139,11 +139,12 @@ class NDIndexer: if value := _maybe_concretize(start): if value >= s: raise ValueError(f"Out of bound slice: start={value}, dim={s}.") - if value + (idx.size - 1) * idx.stride >= s: - raise ValueError( - f"Out of bound slice: start={value}, size={idx.size}," - f" stride={idx.stride}, dim={s}." - ) + if size := _maybe_concretize(idx.size): + if value + (size - 1) * idx.stride >= s: + raise ValueError( + f"Out of bound slice: start={value}, size={size}," + f" stride={idx.stride}, dim={s}." + ) continue # The shape of indexer integers should be broadcastable up to the # int_indexer_shape of the whole NDIndexer