From 6bdc83c6807e70e54982c373ee9883b46bd6e437 Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Mon, 15 Apr 2024 21:03:26 +0000 Subject: [PATCH] Add new unstack function to numpy/array_api namespaces --- CHANGELOG.md | 4 ++++ docs/jax.numpy.rst | 1 + jax/_src/numpy/lax_numpy.py | 12 +++++++++++ jax/experimental/array_api/__init__.py | 1 + .../array_api/_manipulation_functions.py | 4 ++++ jax/numpy/__init__.py | 1 + jax/numpy/__init__.pyi | 1 + tests/array_api_test.py | 1 + tests/lax_numpy_test.py | 21 +++++++++++++++++++ 9 files changed, 46 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1310aa3fa..331554c67 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,10 @@ Remember to align the itemized text with the first line of an item within a list ## jax 0.4.27 +* New Functionality + * Added {func}`jax.numpy.unstack`, following the addition of this function in + the array API 2023 standard, soon to be adopted by NumPy. + * Changes * {func}`jax.pure_callback` and {func}`jax.experimental.io_callback` now use {class}`jax.Array` instead of {class}`np.ndarray`. You can recover diff --git a/docs/jax.numpy.rst b/docs/jax.numpy.rst index 41c87603b..d866df52d 100644 --- a/docs/jax.numpy.rst +++ b/docs/jax.numpy.rst @@ -417,6 +417,7 @@ namespace; they are listed below. unique_values unpackbits unravel_index + unstack unsignedinteger unwrap vander diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 4b23ca210..080d67591 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -1887,6 +1887,18 @@ def stack(arrays: np.ndarray | Array | Sequence[ArrayLike], new_arrays.append(expand_dims(a, axis)) return concatenate(new_arrays, axis=axis, dtype=dtype) +@util.implements(getattr(np, 'unstack', None)) +@partial(jit, static_argnames="axis") +def unstack(x: ArrayLike, /, *, axis: int = 0) -> tuple[Array, ...]: + util.check_arraylike("unstack", x) + x = asarray(x) + if x.ndim == 0: + raise ValueError( + "Unstack requires arrays with rank > 0, however a scalar array was " + "passed." + ) + return tuple(moveaxis(x, axis, 0)) + @util.implements(np.tile) def tile(A: ArrayLike, reps: DimSize | Sequence[DimSize]) -> Array: util.check_arraylike("tile", A) diff --git a/jax/experimental/array_api/__init__.py b/jax/experimental/array_api/__init__.py index 14405e67f..dfb2ff988 100644 --- a/jax/experimental/array_api/__init__.py +++ b/jax/experimental/array_api/__init__.py @@ -180,6 +180,7 @@ from jax.experimental.array_api._manipulation_functions import ( roll as roll, squeeze as squeeze, stack as stack, + unstack as unstack, ) from jax.experimental.array_api._searching_functions import ( diff --git a/jax/experimental/array_api/_manipulation_functions.py b/jax/experimental/array_api/_manipulation_functions.py index fdc83fc83..3a1845909 100644 --- a/jax/experimental/array_api/_manipulation_functions.py +++ b/jax/experimental/array_api/_manipulation_functions.py @@ -82,3 +82,7 @@ def stack(arrays: tuple[Array, ...] | list[Array], /, *, axis: int = 0) -> Array """Joins a sequence of arrays along a new axis.""" dtype = _result_type(*arrays) return jax.numpy.stack(arrays, axis=axis, dtype=dtype) + +def unstack(x: Array, /, *, axis: int = 0) -> tuple[Array, ...]: + """Splits an array in a sequence of arrays along the given axis.""" + return jax.numpy.unstack(x, axis=axis) diff --git a/jax/numpy/__init__.py b/jax/numpy/__init__.py index 273c5a2aa..0b8a8cdd0 100644 --- a/jax/numpy/__init__.py +++ b/jax/numpy/__init__.py @@ -253,6 +253,7 @@ from jax._src.numpy.lax_numpy import ( unpackbits as unpackbits, unravel_index as unravel_index, unsignedinteger as unsignedinteger, + unstack as unstack, unwrap as unwrap, vander as vander, vdot as vdot, diff --git a/jax/numpy/__init__.pyi b/jax/numpy/__init__.pyi index 5a2046f6f..7c7e68a06 100644 --- a/jax/numpy/__init__.pyi +++ b/jax/numpy/__init__.pyi @@ -859,6 +859,7 @@ def unpackbits( ) -> Array: ... def unravel_index(indices: ArrayLike, shape: Shape) -> tuple[Array, ...]: ... unsignedinteger = _np.unsignedinteger +def unstack(x: ArrayLike , /, *, axis: int = ...) -> tuple[Array, ...]: ... def unwrap(p: ArrayLike, discont: Optional[ArrayLike] = ..., axis: int = ..., period: ArrayLike = ...) -> Array: ... def vander( diff --git a/tests/array_api_test.py b/tests/array_api_test.py index f4dcfb74c..91c64e749 100644 --- a/tests/array_api_test.py +++ b/tests/array_api_test.py @@ -171,6 +171,7 @@ MAIN_NAMESPACE = { 'unique_counts', 'unique_inverse', 'unique_values', + 'unstack', 'var', 'vecdot', 'where', diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 24bd01247..231807e46 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -173,6 +173,27 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): for a in out] return f + + @jtu.sample_product( + [dict(shape=shape, axis=axis) + for shape in all_shapes + for axis in list(range(-len(shape), len(shape)))], + dtype=all_dtypes, + ) + def testUnstack(self, shape, axis, dtype): + rng = jtu.rand_default(self.rng()) + x = rng(shape, dtype) + if jnp.asarray(x).ndim == 0: + with self.assertRaisesRegex(ValueError, "Unstack requires arrays with"): + jnp.unstack(x, axis=axis) + return + y = jnp.unstack(x, axis=axis) + if shape[axis] == 0: + self.assertEqual(y, ()) + else: + self.assertArraysEqual(jnp.moveaxis(jnp.array(y), 0, axis), x) + + @parameterized.parameters( [dtype for dtype in [jnp.bool, jnp.uint8, jnp.uint16, jnp.uint32, jnp.uint64, jnp.int8, jnp.int16, jnp.int32, jnp.int64,