Merge pull request #15397 from jakevdp:fix-split-annotation

PiperOrigin-RevId: 522341314
This commit is contained in:
jax authors 2023-04-06 08:23:54 -07:00
commit 492b9c1455

View File

@ -1154,7 +1154,8 @@ def broadcast_to(array: ArrayLike, shape: Shape) -> Array:
return util._broadcast_to(array, shape)
def _split(op: str, ary: ArrayLike, indices_or_sections: Union[int, ArrayLike],
def _split(op: str, ary: ArrayLike,
indices_or_sections: Union[int, Sequence[int], ArrayLike],
axis: int = 0) -> List[Array]:
util.check_arraylike(op, ary)
ary = asarray(ary)
@ -1193,12 +1194,13 @@ def _split(op: str, ary: ArrayLike, indices_or_sections: Union[int, ArrayLike],
for start, end in zip(split_indices[:-1], split_indices[1:])]
@util._wraps(np.split, lax_description=_ARRAY_VIEW_DOC)
def split(ary: ArrayLike, indices_or_sections: Union[int, ArrayLike], axis: int = 0) -> List[Array]:
def split(ary: ArrayLike, indices_or_sections: Union[int, Sequence[int], ArrayLike],
axis: int = 0) -> List[Array]:
return _split("split", ary, indices_or_sections, axis=axis)
def _split_on_axis(op: str, axis: int) -> Callable[[ArrayLike, Union[int, ArrayLike]], List[Array]]:
@util._wraps(getattr(np, op), update_doc=False)
def f(ary: ArrayLike, indices_or_sections: Union[int, ArrayLike]) -> List[Array]:
def f(ary: ArrayLike, indices_or_sections: Union[int, Sequence[int], ArrayLike]) -> List[Array]:
# for 1-D array, hsplit becomes vsplit
nonlocal axis
util.check_arraylike(op, ary)
@ -1213,7 +1215,8 @@ hsplit = _split_on_axis("hsplit", axis=1)
dsplit = _split_on_axis("dsplit", axis=2)
@util._wraps(np.array_split)
def array_split(ary: ArrayLike, indices_or_sections: Union[int, ArrayLike], axis: int = 0) -> List[Array]:
def array_split(ary: ArrayLike, indices_or_sections: Union[int, Sequence[int], ArrayLike],
axis: int = 0) -> List[Array]:
return _split("array_split", ary, indices_or_sections, axis=axis)
@util._wraps(np.clip, skip_params=['out'])