diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 55d6f6aa6..f40a3726c 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -1154,7 +1154,8 @@ def broadcast_to(array: ArrayLike, shape: Shape) -> Array: return util._broadcast_to(array, shape) -def _split(op: str, ary: ArrayLike, indices_or_sections: Union[int, ArrayLike], +def _split(op: str, ary: ArrayLike, + indices_or_sections: Union[int, Sequence[int], ArrayLike], axis: int = 0) -> List[Array]: util.check_arraylike(op, ary) ary = asarray(ary) @@ -1193,12 +1194,13 @@ def _split(op: str, ary: ArrayLike, indices_or_sections: Union[int, ArrayLike], for start, end in zip(split_indices[:-1], split_indices[1:])] @util._wraps(np.split, lax_description=_ARRAY_VIEW_DOC) -def split(ary: ArrayLike, indices_or_sections: Union[int, ArrayLike], axis: int = 0) -> List[Array]: +def split(ary: ArrayLike, indices_or_sections: Union[int, Sequence[int], ArrayLike], + axis: int = 0) -> List[Array]: return _split("split", ary, indices_or_sections, axis=axis) def _split_on_axis(op: str, axis: int) -> Callable[[ArrayLike, Union[int, ArrayLike]], List[Array]]: @util._wraps(getattr(np, op), update_doc=False) - def f(ary: ArrayLike, indices_or_sections: Union[int, ArrayLike]) -> List[Array]: + def f(ary: ArrayLike, indices_or_sections: Union[int, Sequence[int], ArrayLike]) -> List[Array]: # for 1-D array, hsplit becomes vsplit nonlocal axis util.check_arraylike(op, ary) @@ -1213,7 +1215,8 @@ hsplit = _split_on_axis("hsplit", axis=1) dsplit = _split_on_axis("dsplit", axis=2) @util._wraps(np.array_split) -def array_split(ary: ArrayLike, indices_or_sections: Union[int, ArrayLike], axis: int = 0) -> List[Array]: +def array_split(ary: ArrayLike, indices_or_sections: Union[int, Sequence[int], ArrayLike], + axis: int = 0) -> List[Array]: return _split("array_split", ary, indices_or_sections, axis=axis) @util._wraps(np.clip, skip_params=['out'])