mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 20:06:05 +00:00
Merge pull request #15397 from jakevdp:fix-split-annotation
PiperOrigin-RevId: 522341314
This commit is contained in:
commit
492b9c1455
@ -1154,7 +1154,8 @@ def broadcast_to(array: ArrayLike, shape: Shape) -> Array:
|
||||
return util._broadcast_to(array, shape)
|
||||
|
||||
|
||||
def _split(op: str, ary: ArrayLike, indices_or_sections: Union[int, ArrayLike],
|
||||
def _split(op: str, ary: ArrayLike,
|
||||
indices_or_sections: Union[int, Sequence[int], ArrayLike],
|
||||
axis: int = 0) -> List[Array]:
|
||||
util.check_arraylike(op, ary)
|
||||
ary = asarray(ary)
|
||||
@ -1193,12 +1194,13 @@ def _split(op: str, ary: ArrayLike, indices_or_sections: Union[int, ArrayLike],
|
||||
for start, end in zip(split_indices[:-1], split_indices[1:])]
|
||||
|
||||
@util._wraps(np.split, lax_description=_ARRAY_VIEW_DOC)
|
||||
def split(ary: ArrayLike, indices_or_sections: Union[int, ArrayLike], axis: int = 0) -> List[Array]:
|
||||
def split(ary: ArrayLike, indices_or_sections: Union[int, Sequence[int], ArrayLike],
|
||||
axis: int = 0) -> List[Array]:
|
||||
return _split("split", ary, indices_or_sections, axis=axis)
|
||||
|
||||
def _split_on_axis(op: str, axis: int) -> Callable[[ArrayLike, Union[int, ArrayLike]], List[Array]]:
|
||||
@util._wraps(getattr(np, op), update_doc=False)
|
||||
def f(ary: ArrayLike, indices_or_sections: Union[int, ArrayLike]) -> List[Array]:
|
||||
def f(ary: ArrayLike, indices_or_sections: Union[int, Sequence[int], ArrayLike]) -> List[Array]:
|
||||
# for 1-D array, hsplit becomes vsplit
|
||||
nonlocal axis
|
||||
util.check_arraylike(op, ary)
|
||||
@ -1213,7 +1215,8 @@ hsplit = _split_on_axis("hsplit", axis=1)
|
||||
dsplit = _split_on_axis("dsplit", axis=2)
|
||||
|
||||
@util._wraps(np.array_split)
|
||||
def array_split(ary: ArrayLike, indices_or_sections: Union[int, ArrayLike], axis: int = 0) -> List[Array]:
|
||||
def array_split(ary: ArrayLike, indices_or_sections: Union[int, Sequence[int], ArrayLike],
|
||||
axis: int = 0) -> List[Array]:
|
||||
return _split("array_split", ary, indices_or_sections, axis=axis)
|
||||
|
||||
@util._wraps(np.clip, skip_params=['out'])
|
||||
|
Loading…
x
Reference in New Issue
Block a user