mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Error message and docstring updates RE: dynamic_slice (#3795)
This should clarify the underlying issues from #1007 and #3794. It might be worth mentioning masking, but that's a little big for fitting into an error message. Maybe once the masking transformation is non-experimental or if we had a dedicated doc page.
This commit is contained in:
parent
dd3cb82135
commit
fe99a06ddf
@ -721,9 +721,12 @@ def dynamic_slice(operand: Array, start_indices: Sequence[Array],
|
||||
|
||||
Args:
|
||||
operand: an array to slice.
|
||||
start_indices: a list of scalar indices, one per dimension.
|
||||
start_indices: a list of scalar indices, one per dimension. These values
|
||||
may be dynamic.
|
||||
slice_sizes: the size of the slice. Must be a sequence of non-negative
|
||||
integers with length equal to `ndim(operand)`.
|
||||
integers with length equal to `ndim(operand)`. Inside a JIT compiled
|
||||
function, only static values are supported (all JAX arrays inside JIT
|
||||
must have statically known size).
|
||||
|
||||
Returns:
|
||||
An array containing the slice.
|
||||
|
@ -3751,8 +3751,10 @@ def _index_to_gather(x_shape, idx):
|
||||
or type(core.get_aval(elt)) is ConcreteArray
|
||||
for elt in (i.start, i.stop, i.step)):
|
||||
msg = ("Array slice indices must have static start/stop/step to be used "
|
||||
"with Numpy indexing syntax. Try lax.dynamic_slice/"
|
||||
"dynamic_update_slice instead.")
|
||||
"with NumPy indexing syntax. To index a statically sized "
|
||||
"array at a dynamic position, try lax.dynamic_slice/"
|
||||
"dynamic_update_slice (JAX does not support dynamically sized "
|
||||
"arrays).")
|
||||
raise IndexError(msg)
|
||||
start, limit, stride, needs_rev = _static_idx(i, x_shape[x_axis])
|
||||
if needs_rev:
|
||||
|
Loading…
x
Reference in New Issue
Block a user