mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 12:26:07 +00:00
Merge pull request #23459 from jakevdp:split-doc
PiperOrigin-RevId: 672532876
This commit is contained in:
commit
c28b3de599
@ -2593,30 +2593,207 @@ def _split(op: str, ary: ArrayLike,
|
||||
return [lax.slice(ary, _subval(starts, axis, start), _subval(ends, axis, end))
|
||||
for start, end in zip(split_indices[:-1], split_indices[1:])]
|
||||
|
||||
@util.implements(np.split, lax_description=_ARRAY_VIEW_DOC)
|
||||
|
||||
def split(ary: ArrayLike, indices_or_sections: int | Sequence[int] | ArrayLike,
|
||||
axis: int = 0) -> list[Array]:
|
||||
"""Split an array into sub-arrays.
|
||||
|
||||
JAX implementation of :func:`numpy.split`.
|
||||
|
||||
Args:
|
||||
ary: N-dimensional array-like object to split
|
||||
indices_or_sections: either a single integer or a sequence of indices.
|
||||
|
||||
- if ``indices_or_sections`` is an integer *N*, then *N* must evenly divide
|
||||
``ary.shape[axis]`` and ``ary`` will be divided into *N* equally-sized
|
||||
chunks along ``axis``.
|
||||
- if ``indices_or_sections`` is a sequence of integers, then these integers
|
||||
specify the boundary between unevenly-sized chunks along ``axis``; see
|
||||
examples below.
|
||||
|
||||
axis: the axis along which to split; defaults to 0.
|
||||
|
||||
Returns:
|
||||
A list of arrays. If ``indices_or_sections`` is an integer *N*, then the list is
|
||||
of length *N*. If ``indices_or_sections`` is a sequence *seq*, then the list is
|
||||
is of length *len(seq) + 1*.
|
||||
|
||||
Examples:
|
||||
Splitting a 1-dimensional array:
|
||||
|
||||
>>> x = jnp.array([1, 2, 3, 4, 5, 6, 7, 8, 9])
|
||||
|
||||
Split into three equal sections:
|
||||
|
||||
>>> chunks = jnp.split(x, 3)
|
||||
>>> print(*chunks)
|
||||
[1 2 3] [4 5 6] [7 8 9]
|
||||
|
||||
Split into sections by index:
|
||||
|
||||
>>> chunks = jnp.split(x, [2, 7]) # [x[0:2], x[2:7], x[7:]]
|
||||
>>> print(*chunks)
|
||||
[1 2] [3 4 5 6 7] [8 9]
|
||||
|
||||
Splitting a two-dimensional array along axis 1:
|
||||
|
||||
>>> x = jnp.array([[1, 2, 3, 4],
|
||||
... [5, 6, 7, 8]])
|
||||
>>> x1, x2 = jnp.split(x, 2, axis=1)
|
||||
>>> print(x1)
|
||||
[[1 2]
|
||||
[5 6]]
|
||||
>>> print(x2)
|
||||
[[3 4]
|
||||
[7 8]]
|
||||
|
||||
See also:
|
||||
- :func:`jax.numpy.array_split`: like ``split``, but allows ``indices_or_sections``
|
||||
to be an integer that does not evenly divide the size of the array.
|
||||
- :func:`jax.numpy.vsplit`: split vertically, i.e. along axis=0
|
||||
- :func:`jax.numpy.hsplit`: split horizontally, i.e. along axis=1
|
||||
- :func:`jax.numpy.dsplit`: split depth-wise, i.e. along axis=2
|
||||
"""
|
||||
return _split("split", ary, indices_or_sections, axis=axis)
|
||||
|
||||
def _split_on_axis(op: str, axis: int) -> Callable[[ArrayLike, int | ArrayLike], list[Array]]:
|
||||
@util.implements(getattr(np, op), update_doc=False)
|
||||
def f(ary: ArrayLike, indices_or_sections: int | Sequence[int] | ArrayLike) -> list[Array]:
|
||||
# for 1-D array, hsplit becomes vsplit
|
||||
nonlocal axis
|
||||
util.check_arraylike(op, ary)
|
||||
a = asarray(ary)
|
||||
if axis == 1 and len(a.shape) == 1:
|
||||
axis = 0
|
||||
return _split(op, ary, indices_or_sections, axis=axis)
|
||||
return f
|
||||
|
||||
vsplit = _split_on_axis("vsplit", axis=0)
|
||||
hsplit = _split_on_axis("hsplit", axis=1)
|
||||
dsplit = _split_on_axis("dsplit", axis=2)
|
||||
def vsplit(ary: ArrayLike, indices_or_sections: int | Sequence[int] | ArrayLike) -> list[Array]:
|
||||
"""Split an array into sub-arrays vertically.
|
||||
|
||||
JAX implementation of :func:`numpy.vsplit`.
|
||||
|
||||
Refer to the documentation of :func:`jax.numpy.split` for details; ``vsplit`` is
|
||||
equivalent to ``split`` with ``axis=0``.
|
||||
|
||||
Examples:
|
||||
1D array:
|
||||
|
||||
>>> x = jnp.array([1, 2, 3, 4, 5, 6])
|
||||
>>> x1, x2 = jnp.vsplit(x, 2)
|
||||
>>> print(x1, x2)
|
||||
[1 2 3] [4 5 6]
|
||||
|
||||
2D array:
|
||||
|
||||
>>> x = jnp.array([[1, 2, 3, 4],
|
||||
... [5, 6, 7, 8]])
|
||||
>>> x1, x2 = jnp.vsplit(x, 2)
|
||||
>>> print(x1, x2)
|
||||
[[1 2 3 4]] [[5 6 7 8]]
|
||||
|
||||
See also:
|
||||
- :func:`jax.numpy.split`: split an array along any axis.
|
||||
- :func:`jax.numpy.hsplit`: split horizontally, i.e. along axis=1
|
||||
- :func:`jax.numpy.dsplit`: split depth-wise, i.e. along axis=2
|
||||
- :func:`jax.numpy.array_split`: like ``split``, but allows ``indices_or_sections``
|
||||
to be an integer that does not evenly divide the size of the array.
|
||||
"""
|
||||
return _split("vsplit", ary, indices_or_sections, axis=0)
|
||||
|
||||
|
||||
def hsplit(ary: ArrayLike, indices_or_sections: int | Sequence[int] | ArrayLike) -> list[Array]:
|
||||
"""Split an array into sub-arrays horizontally.
|
||||
|
||||
JAX implementation of :func:`numpy.hsplit`.
|
||||
|
||||
Refer to the documentation of :func:`jax.numpy.split` for details. ``hsplit`` is
|
||||
equivalent to ``split`` with ``axis=1``, or ``axis=0`` for one-dimensional arrays.
|
||||
|
||||
Examples:
|
||||
1D array:
|
||||
|
||||
>>> x = jnp.array([1, 2, 3, 4, 5, 6])
|
||||
>>> x1, x2 = jnp.hsplit(x, 2)
|
||||
>>> print(x1, x2)
|
||||
[1 2 3] [4 5 6]
|
||||
|
||||
2D array:
|
||||
|
||||
>>> x = jnp.array([[1, 2, 3, 4],
|
||||
... [5, 6, 7, 8]])
|
||||
>>> x1, x2 = jnp.hsplit(x, 2)
|
||||
>>> print(x1)
|
||||
[[1 2]
|
||||
[5 6]]
|
||||
>>> print(x2)
|
||||
[[3 4]
|
||||
[7 8]]
|
||||
|
||||
See also:
|
||||
- :func:`jax.numpy.split`: split an array along any axis.
|
||||
- :func:`jax.numpy.vsplit`: split vertically, i.e. along axis=0
|
||||
- :func:`jax.numpy.dsplit`: split depth-wise, i.e. along axis=2
|
||||
- :func:`jax.numpy.array_split`: like ``split``, but allows ``indices_or_sections``
|
||||
to be an integer that does not evenly divide the size of the array.
|
||||
"""
|
||||
util.check_arraylike("hsplit", ary)
|
||||
a = asarray(ary)
|
||||
return _split("hsplit", a, indices_or_sections, axis=0 if a.ndim == 1 else 1)
|
||||
|
||||
|
||||
def dsplit(ary: ArrayLike, indices_or_sections: int | Sequence[int] | ArrayLike) -> list[Array]:
|
||||
"""Split an array into sub-arrays depth-wise.
|
||||
|
||||
JAX implementation of :func:`numpy.dsplit`.
|
||||
|
||||
Refer to the documentation of :func:`jax.numpy.split` for details. ``dsplit`` is
|
||||
equivalent to ``split`` with ``axis=2``.
|
||||
|
||||
Examples:
|
||||
|
||||
>>> x = jnp.arange(12).reshape(3, 1, 4)
|
||||
>>> print(x)
|
||||
[[[ 0 1 2 3]]
|
||||
<BLANKLINE>
|
||||
[[ 4 5 6 7]]
|
||||
<BLANKLINE>
|
||||
[[ 8 9 10 11]]]
|
||||
>>> x1, x2 = jnp.dsplit(x, 2)
|
||||
>>> print(x1)
|
||||
[[[0 1]]
|
||||
<BLANKLINE>
|
||||
[[4 5]]
|
||||
<BLANKLINE>
|
||||
[[8 9]]]
|
||||
>>> print(x2)
|
||||
[[[ 2 3]]
|
||||
<BLANKLINE>
|
||||
[[ 6 7]]
|
||||
<BLANKLINE>
|
||||
[[10 11]]]
|
||||
|
||||
See also:
|
||||
- :func:`jax.numpy.split`: split an array along any axis.
|
||||
- :func:`jax.numpy.vsplit`: split vertically, i.e. along axis=0
|
||||
- :func:`jax.numpy.hsplit`: split horizontally, i.e. along axis=1
|
||||
- :func:`jax.numpy.array_split`: like ``split``, but allows ``indices_or_sections``
|
||||
to be an integer that does not evenly divide the size of the array.
|
||||
"""
|
||||
return _split("dsplit", ary, indices_or_sections, axis=2)
|
||||
|
||||
|
||||
@util.implements(np.array_split)
|
||||
def array_split(ary: ArrayLike, indices_or_sections: int | Sequence[int] | ArrayLike,
|
||||
axis: int = 0) -> list[Array]:
|
||||
"""Split an array into sub-arrays.
|
||||
|
||||
JAX implementation of :func:`numpy.array_split`.
|
||||
|
||||
Refer to the documentation of :func:`jax.numpy.split` for details; ``array_split``
|
||||
is equivalent to ``split``, but allows integer ``indices_or_sections`` which does
|
||||
not evenly divide the split axis.
|
||||
|
||||
Examples:
|
||||
>>> x = jnp.array([1, 2, 3, 4, 5, 6, 7, 8, 9])
|
||||
>>> chunks = jnp.array_split(x, 4)
|
||||
>>> print(*chunks)
|
||||
[1 2 3] [4 5] [6 7] [8 9]
|
||||
|
||||
See also:
|
||||
- :func:`jax.numpy.split`: split an array along any axis.
|
||||
- :func:`jax.numpy.vsplit`: split vertically, i.e. along axis=0
|
||||
- :func:`jax.numpy.hsplit`: split horizontally, i.e. along axis=1
|
||||
- :func:`jax.numpy.dsplit`: split depth-wise, i.e. along axis=2
|
||||
"""
|
||||
return _split("array_split", ary, indices_or_sections, axis=axis)
|
||||
|
||||
|
||||
|
@ -6289,6 +6289,7 @@ class NumpyDocTests(jtu.JaxTestCase):
|
||||
unimplemented = ['fromfile', 'fromiter']
|
||||
aliases = ['abs', 'acos', 'acosh', 'asin', 'asinh', 'atan', 'atanh', 'atan2',
|
||||
'amax', 'amin', 'around', 'bitwise_right_shift', 'divide', 'round_']
|
||||
skip_args_check = ['vsplit', 'hsplit', 'dsplit', 'array_split']
|
||||
|
||||
for name in dir(jnp):
|
||||
if name.startswith('_') or name in unimplemented:
|
||||
@ -6313,7 +6314,7 @@ class NumpyDocTests(jtu.JaxTestCase):
|
||||
raise Exception(f"jnp.{name} does not have a wrapped docstring.")
|
||||
elif name in aliases:
|
||||
assert "Alias of" in obj.__doc__
|
||||
else:
|
||||
elif name not in skip_args_check:
|
||||
# Other functions should have nontrivial docs including "Args" and "Returns".
|
||||
doc = obj.__doc__
|
||||
self.assertNotEmpty(doc)
|
||||
|
Loading…
x
Reference in New Issue
Block a user