mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Add new unstack function to numpy/array_api namespaces
This commit is contained in:
parent
64775d02a3
commit
6bdc83c680
@ -8,6 +8,10 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
|
||||
## jax 0.4.27
|
||||
|
||||
* New Functionality
|
||||
* Added {func}`jax.numpy.unstack`, following the addition of this function in
|
||||
the array API 2023 standard, soon to be adopted by NumPy.
|
||||
|
||||
* Changes
|
||||
* {func}`jax.pure_callback` and {func}`jax.experimental.io_callback`
|
||||
now use {class}`jax.Array` instead of {class}`np.ndarray`. You can recover
|
||||
|
@ -417,6 +417,7 @@ namespace; they are listed below.
|
||||
unique_values
|
||||
unpackbits
|
||||
unravel_index
|
||||
unstack
|
||||
unsignedinteger
|
||||
unwrap
|
||||
vander
|
||||
|
@ -1887,6 +1887,18 @@ def stack(arrays: np.ndarray | Array | Sequence[ArrayLike],
|
||||
new_arrays.append(expand_dims(a, axis))
|
||||
return concatenate(new_arrays, axis=axis, dtype=dtype)
|
||||
|
||||
@util.implements(getattr(np, 'unstack', None))
|
||||
@partial(jit, static_argnames="axis")
|
||||
def unstack(x: ArrayLike, /, *, axis: int = 0) -> tuple[Array, ...]:
|
||||
util.check_arraylike("unstack", x)
|
||||
x = asarray(x)
|
||||
if x.ndim == 0:
|
||||
raise ValueError(
|
||||
"Unstack requires arrays with rank > 0, however a scalar array was "
|
||||
"passed."
|
||||
)
|
||||
return tuple(moveaxis(x, axis, 0))
|
||||
|
||||
@util.implements(np.tile)
|
||||
def tile(A: ArrayLike, reps: DimSize | Sequence[DimSize]) -> Array:
|
||||
util.check_arraylike("tile", A)
|
||||
|
@ -180,6 +180,7 @@ from jax.experimental.array_api._manipulation_functions import (
|
||||
roll as roll,
|
||||
squeeze as squeeze,
|
||||
stack as stack,
|
||||
unstack as unstack,
|
||||
)
|
||||
|
||||
from jax.experimental.array_api._searching_functions import (
|
||||
|
@ -82,3 +82,7 @@ def stack(arrays: tuple[Array, ...] | list[Array], /, *, axis: int = 0) -> Array
|
||||
"""Joins a sequence of arrays along a new axis."""
|
||||
dtype = _result_type(*arrays)
|
||||
return jax.numpy.stack(arrays, axis=axis, dtype=dtype)
|
||||
|
||||
def unstack(x: Array, /, *, axis: int = 0) -> tuple[Array, ...]:
|
||||
"""Splits an array in a sequence of arrays along the given axis."""
|
||||
return jax.numpy.unstack(x, axis=axis)
|
||||
|
@ -253,6 +253,7 @@ from jax._src.numpy.lax_numpy import (
|
||||
unpackbits as unpackbits,
|
||||
unravel_index as unravel_index,
|
||||
unsignedinteger as unsignedinteger,
|
||||
unstack as unstack,
|
||||
unwrap as unwrap,
|
||||
vander as vander,
|
||||
vdot as vdot,
|
||||
|
@ -859,6 +859,7 @@ def unpackbits(
|
||||
) -> Array: ...
|
||||
def unravel_index(indices: ArrayLike, shape: Shape) -> tuple[Array, ...]: ...
|
||||
unsignedinteger = _np.unsignedinteger
|
||||
def unstack(x: ArrayLike , /, *, axis: int = ...) -> tuple[Array, ...]: ...
|
||||
def unwrap(p: ArrayLike, discont: Optional[ArrayLike] = ...,
|
||||
axis: int = ..., period: ArrayLike = ...) -> Array: ...
|
||||
def vander(
|
||||
|
@ -171,6 +171,7 @@ MAIN_NAMESPACE = {
|
||||
'unique_counts',
|
||||
'unique_inverse',
|
||||
'unique_values',
|
||||
'unstack',
|
||||
'var',
|
||||
'vecdot',
|
||||
'where',
|
||||
|
@ -173,6 +173,27 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
for a in out]
|
||||
return f
|
||||
|
||||
|
||||
@jtu.sample_product(
|
||||
[dict(shape=shape, axis=axis)
|
||||
for shape in all_shapes
|
||||
for axis in list(range(-len(shape), len(shape)))],
|
||||
dtype=all_dtypes,
|
||||
)
|
||||
def testUnstack(self, shape, axis, dtype):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
x = rng(shape, dtype)
|
||||
if jnp.asarray(x).ndim == 0:
|
||||
with self.assertRaisesRegex(ValueError, "Unstack requires arrays with"):
|
||||
jnp.unstack(x, axis=axis)
|
||||
return
|
||||
y = jnp.unstack(x, axis=axis)
|
||||
if shape[axis] == 0:
|
||||
self.assertEqual(y, ())
|
||||
else:
|
||||
self.assertArraysEqual(jnp.moveaxis(jnp.array(y), 0, axis), x)
|
||||
|
||||
|
||||
@parameterized.parameters(
|
||||
[dtype for dtype in [jnp.bool, jnp.uint8, jnp.uint16, jnp.uint32,
|
||||
jnp.uint64, jnp.int8, jnp.int16, jnp.int32, jnp.int64,
|
||||
|
Loading…
x
Reference in New Issue
Block a user