Add new unstack function to numpy/array_api namespaces

This commit is contained in:
Meekail Zain 2024-04-15 21:03:26 +00:00
parent 64775d02a3
commit 6bdc83c680
9 changed files with 46 additions and 0 deletions

View File

@ -8,6 +8,10 @@ Remember to align the itemized text with the first line of an item within a list
## jax 0.4.27
* New Functionality
* Added {func}`jax.numpy.unstack`, following the addition of this function in
the array API 2023 standard, soon to be adopted by NumPy.
* Changes
* {func}`jax.pure_callback` and {func}`jax.experimental.io_callback`
now use {class}`jax.Array` instead of {class}`np.ndarray`. You can recover

View File

@ -417,6 +417,7 @@ namespace; they are listed below.
unique_values
unpackbits
unravel_index
unstack
unsignedinteger
unwrap
vander

View File

@ -1887,6 +1887,18 @@ def stack(arrays: np.ndarray | Array | Sequence[ArrayLike],
new_arrays.append(expand_dims(a, axis))
return concatenate(new_arrays, axis=axis, dtype=dtype)
@util.implements(getattr(np, 'unstack', None))
@partial(jit, static_argnames="axis")
def unstack(x: ArrayLike, /, *, axis: int = 0) -> tuple[Array, ...]:
util.check_arraylike("unstack", x)
x = asarray(x)
if x.ndim == 0:
raise ValueError(
"Unstack requires arrays with rank > 0, however a scalar array was "
"passed."
)
return tuple(moveaxis(x, axis, 0))
@util.implements(np.tile)
def tile(A: ArrayLike, reps: DimSize | Sequence[DimSize]) -> Array:
util.check_arraylike("tile", A)

View File

@ -180,6 +180,7 @@ from jax.experimental.array_api._manipulation_functions import (
roll as roll,
squeeze as squeeze,
stack as stack,
unstack as unstack,
)
from jax.experimental.array_api._searching_functions import (

View File

@ -82,3 +82,7 @@ def stack(arrays: tuple[Array, ...] | list[Array], /, *, axis: int = 0) -> Array
"""Joins a sequence of arrays along a new axis."""
dtype = _result_type(*arrays)
return jax.numpy.stack(arrays, axis=axis, dtype=dtype)
def unstack(x: Array, /, *, axis: int = 0) -> tuple[Array, ...]:
"""Splits an array in a sequence of arrays along the given axis."""
return jax.numpy.unstack(x, axis=axis)

View File

@ -253,6 +253,7 @@ from jax._src.numpy.lax_numpy import (
unpackbits as unpackbits,
unravel_index as unravel_index,
unsignedinteger as unsignedinteger,
unstack as unstack,
unwrap as unwrap,
vander as vander,
vdot as vdot,

View File

@ -859,6 +859,7 @@ def unpackbits(
) -> Array: ...
def unravel_index(indices: ArrayLike, shape: Shape) -> tuple[Array, ...]: ...
unsignedinteger = _np.unsignedinteger
def unstack(x: ArrayLike , /, *, axis: int = ...) -> tuple[Array, ...]: ...
def unwrap(p: ArrayLike, discont: Optional[ArrayLike] = ...,
axis: int = ..., period: ArrayLike = ...) -> Array: ...
def vander(

View File

@ -171,6 +171,7 @@ MAIN_NAMESPACE = {
'unique_counts',
'unique_inverse',
'unique_values',
'unstack',
'var',
'vecdot',
'where',

View File

@ -173,6 +173,27 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
for a in out]
return f
@jtu.sample_product(
[dict(shape=shape, axis=axis)
for shape in all_shapes
for axis in list(range(-len(shape), len(shape)))],
dtype=all_dtypes,
)
def testUnstack(self, shape, axis, dtype):
rng = jtu.rand_default(self.rng())
x = rng(shape, dtype)
if jnp.asarray(x).ndim == 0:
with self.assertRaisesRegex(ValueError, "Unstack requires arrays with"):
jnp.unstack(x, axis=axis)
return
y = jnp.unstack(x, axis=axis)
if shape[axis] == 0:
self.assertEqual(y, ())
else:
self.assertArraysEqual(jnp.moveaxis(jnp.array(y), 0, axis), x)
@parameterized.parameters(
[dtype for dtype in [jnp.bool, jnp.uint8, jnp.uint16, jnp.uint32,
jnp.uint64, jnp.int8, jnp.int16, jnp.int32, jnp.int64,