mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Merge pull request #15397 from jakevdp:fix-split-annotation
PiperOrigin-RevId: 522341314
This commit is contained in:
commit
492b9c1455
@ -1154,7 +1154,8 @@ def broadcast_to(array: ArrayLike, shape: Shape) -> Array:
|
|||||||
return util._broadcast_to(array, shape)
|
return util._broadcast_to(array, shape)
|
||||||
|
|
||||||
|
|
||||||
def _split(op: str, ary: ArrayLike, indices_or_sections: Union[int, ArrayLike],
|
def _split(op: str, ary: ArrayLike,
|
||||||
|
indices_or_sections: Union[int, Sequence[int], ArrayLike],
|
||||||
axis: int = 0) -> List[Array]:
|
axis: int = 0) -> List[Array]:
|
||||||
util.check_arraylike(op, ary)
|
util.check_arraylike(op, ary)
|
||||||
ary = asarray(ary)
|
ary = asarray(ary)
|
||||||
@ -1193,12 +1194,13 @@ def _split(op: str, ary: ArrayLike, indices_or_sections: Union[int, ArrayLike],
|
|||||||
for start, end in zip(split_indices[:-1], split_indices[1:])]
|
for start, end in zip(split_indices[:-1], split_indices[1:])]
|
||||||
|
|
||||||
@util._wraps(np.split, lax_description=_ARRAY_VIEW_DOC)
|
@util._wraps(np.split, lax_description=_ARRAY_VIEW_DOC)
|
||||||
def split(ary: ArrayLike, indices_or_sections: Union[int, ArrayLike], axis: int = 0) -> List[Array]:
|
def split(ary: ArrayLike, indices_or_sections: Union[int, Sequence[int], ArrayLike],
|
||||||
|
axis: int = 0) -> List[Array]:
|
||||||
return _split("split", ary, indices_or_sections, axis=axis)
|
return _split("split", ary, indices_or_sections, axis=axis)
|
||||||
|
|
||||||
def _split_on_axis(op: str, axis: int) -> Callable[[ArrayLike, Union[int, ArrayLike]], List[Array]]:
|
def _split_on_axis(op: str, axis: int) -> Callable[[ArrayLike, Union[int, ArrayLike]], List[Array]]:
|
||||||
@util._wraps(getattr(np, op), update_doc=False)
|
@util._wraps(getattr(np, op), update_doc=False)
|
||||||
def f(ary: ArrayLike, indices_or_sections: Union[int, ArrayLike]) -> List[Array]:
|
def f(ary: ArrayLike, indices_or_sections: Union[int, Sequence[int], ArrayLike]) -> List[Array]:
|
||||||
# for 1-D array, hsplit becomes vsplit
|
# for 1-D array, hsplit becomes vsplit
|
||||||
nonlocal axis
|
nonlocal axis
|
||||||
util.check_arraylike(op, ary)
|
util.check_arraylike(op, ary)
|
||||||
@ -1213,7 +1215,8 @@ hsplit = _split_on_axis("hsplit", axis=1)
|
|||||||
dsplit = _split_on_axis("dsplit", axis=2)
|
dsplit = _split_on_axis("dsplit", axis=2)
|
||||||
|
|
||||||
@util._wraps(np.array_split)
|
@util._wraps(np.array_split)
|
||||||
def array_split(ary: ArrayLike, indices_or_sections: Union[int, ArrayLike], axis: int = 0) -> List[Array]:
|
def array_split(ary: ArrayLike, indices_or_sections: Union[int, Sequence[int], ArrayLike],
|
||||||
|
axis: int = 0) -> List[Array]:
|
||||||
return _split("array_split", ary, indices_or_sections, axis=axis)
|
return _split("array_split", ary, indices_or_sections, axis=axis)
|
||||||
|
|
||||||
@util._wraps(np.clip, skip_params=['out'])
|
@util._wraps(np.clip, skip_params=['out'])
|
||||||
|
Loading…
x
Reference in New Issue
Block a user