Error message and docstring updates RE: dynamic_slice (#3795)

This should clarify the underlying issues from #1007 and #3794.

It might be worth mentioning masking, but that's a little big for fitting into
an error message. Maybe once the masking transformation is non-experimental or
if we had a dedicated doc page.
This commit is contained in:
Stephan Hoyer 2020-07-20 06:08:54 -07:00 committed by GitHub
parent dd3cb82135
commit fe99a06ddf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 9 additions and 4 deletions

View File

@ -721,9 +721,12 @@ def dynamic_slice(operand: Array, start_indices: Sequence[Array],
Args:
operand: an array to slice.
start_indices: a list of scalar indices, one per dimension.
start_indices: a list of scalar indices, one per dimension. These values
may be dynamic.
slice_sizes: the size of the slice. Must be a sequence of non-negative
integers with length equal to `ndim(operand)`.
integers with length equal to `ndim(operand)`. Inside a JIT compiled
function, only static values are supported (all JAX arrays inside JIT
must have statically known size).
Returns:
An array containing the slice.

View File

@ -3751,8 +3751,10 @@ def _index_to_gather(x_shape, idx):
or type(core.get_aval(elt)) is ConcreteArray
for elt in (i.start, i.stop, i.step)):
msg = ("Array slice indices must have static start/stop/step to be used "
"with Numpy indexing syntax. Try lax.dynamic_slice/"
"dynamic_update_slice instead.")
"with NumPy indexing syntax. To index a statically sized "
"array at a dynamic position, try lax.dynamic_slice/"
"dynamic_update_slice (JAX does not support dynamically sized "
"arrays).")
raise IndexError(msg)
start, limit, stride, needs_rev = _static_idx(i, x_shape[x_axis])
if needs_rev: