mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
Merge pull request #26641 from jakevdp:jnp-ndim
PiperOrigin-RevId: 733484459
This commit is contained in:
commit
c145102ef4
@ -54,6 +54,8 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
|
||||
A downstream effect of this several other internal functions need debug
|
||||
info. This change does not affect public APIs.
|
||||
See https://github.com/jax-ml/jax/issues/26480 for more detail.
|
||||
* In {func}`jax.numpy.ndim`, {func}`jax.numpy.shape`, and {func}`jax.numpy.size`,
|
||||
non-arraylike inputs (such as lists, tuples, etc.) are now deprecated.
|
||||
|
||||
* Bug fixes
|
||||
* TPU runtime startup and shutdown time should be significantly improved on
|
||||
|
@ -26,7 +26,7 @@ from jax._src import dtypes
|
||||
from jax._src.lax import lax
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.sharding_impls import SingleDeviceSharding
|
||||
from jax._src.util import safe_zip, safe_map
|
||||
from jax._src.util import safe_zip, safe_map, set_module
|
||||
from jax._src.typing import Array, ArrayLike, DimSize, DType, DTypeLike, Shape
|
||||
from jax.sharding import Sharding
|
||||
|
||||
@ -35,6 +35,8 @@ import numpy as np
|
||||
zip, unsafe_zip = safe_zip, zip
|
||||
map, unsafe_map = safe_map, map
|
||||
|
||||
export = set_module('jax.numpy')
|
||||
|
||||
_dtype = partial(dtypes.dtype, canonicalize=True)
|
||||
|
||||
def promote_shapes(fun_name: str, *args: ArrayLike) -> list[Array]:
|
||||
@ -308,3 +310,124 @@ def normalize_device_to_sharding(device: xc.Device | Sharding | None) -> Shardin
|
||||
return SingleDeviceSharding(device)
|
||||
else:
|
||||
return device
|
||||
|
||||
|
||||
@export
|
||||
def ndim(a: ArrayLike) -> int:
|
||||
"""Return the number of dimensions of an array.
|
||||
|
||||
JAX implementation of :func:`numpy.ndim`. Unlike ``np.ndim``, this function
|
||||
raises a :class:`TypeError` if the input is a collection such as a list or
|
||||
tuple.
|
||||
|
||||
Args:
|
||||
a: array-like object.
|
||||
|
||||
Returns:
|
||||
An integer specifying the number of dimensions of ``a``.
|
||||
|
||||
Examples:
|
||||
Number of dimensions for arrays:
|
||||
|
||||
>>> x = jnp.arange(10)
|
||||
>>> jnp.ndim(x)
|
||||
1
|
||||
>>> y = jnp.ones((2, 3))
|
||||
>>> jnp.ndim(y)
|
||||
2
|
||||
|
||||
This also works for scalars:
|
||||
|
||||
>>> jnp.ndim(3.14)
|
||||
0
|
||||
|
||||
For arrays, this can also be accessed via the :attr:`jax.Array.ndim` property:
|
||||
|
||||
>>> x.ndim
|
||||
1
|
||||
"""
|
||||
# Deprecation warning added 2025-2-20.
|
||||
check_arraylike("ndim", a, emit_warning=True)
|
||||
return np.ndim(a) # NumPy dispatches to a.ndim if available.
|
||||
|
||||
|
||||
@export
|
||||
def shape(a: ArrayLike) -> tuple[int, ...]:
|
||||
"""Return the shape an array.
|
||||
|
||||
JAX implementation of :func:`numpy.shape`. Unlike ``np.shape``, this function
|
||||
raises a :class:`TypeError` if the input is a collection such as a list or
|
||||
tuple.
|
||||
|
||||
Args:
|
||||
a: array-like object.
|
||||
|
||||
Returns:
|
||||
An tuple of integers representing the shape of ``a``.
|
||||
|
||||
Examples:
|
||||
Shape for arrays:
|
||||
|
||||
>>> x = jnp.arange(10)
|
||||
>>> jnp.shape(x)
|
||||
(10,)
|
||||
>>> y = jnp.ones((2, 3))
|
||||
>>> jnp.shape(y)
|
||||
(2, 3)
|
||||
|
||||
This also works for scalars:
|
||||
|
||||
>>> jnp.shape(3.14)
|
||||
()
|
||||
|
||||
For arrays, this can also be accessed via the :attr:`jax.Array.shape` property:
|
||||
|
||||
>>> x.shape
|
||||
(10,)
|
||||
"""
|
||||
# Deprecation warning added 2025-2-20.
|
||||
check_arraylike("shape", a, emit_warning=True)
|
||||
return np.shape(a) # NumPy dispatches to a.shape if available.
|
||||
|
||||
|
||||
@export
|
||||
def size(a: ArrayLike, axis: int | None = None) -> int:
|
||||
"""Return number of elements along a given axis.
|
||||
|
||||
JAX implementation of :func:`numpy.size`. Unlike ``np.size``, this function
|
||||
raises a :class:`TypeError` if the input is a collection such as a list or
|
||||
tuple.
|
||||
|
||||
Args:
|
||||
a: array-like object
|
||||
axis: optional integer along which to count elements. By default, return
|
||||
the total number of elements.
|
||||
|
||||
Returns:
|
||||
An integer specifying the number of elements in ``a``.
|
||||
|
||||
Examples:
|
||||
Size for arrays:
|
||||
|
||||
>>> x = jnp.arange(10)
|
||||
>>> jnp.size(x)
|
||||
10
|
||||
>>> y = jnp.ones((2, 3))
|
||||
>>> jnp.size(y)
|
||||
6
|
||||
>>> jnp.size(y, axis=1)
|
||||
3
|
||||
|
||||
This also works for scalars:
|
||||
|
||||
>>> jnp.size(3.14)
|
||||
1
|
||||
|
||||
For arrays, this can also be accessed via the :attr:`jax.Array.size` property:
|
||||
|
||||
>>> y.size
|
||||
6
|
||||
"""
|
||||
# Deprecation warning added 2025-2-20.
|
||||
check_arraylike("size", a, emit_warning=True)
|
||||
return np.size(a, axis=axis) # NumPy dispatches to a.size if available.
|
||||
|
@ -254,6 +254,12 @@ from jax._src.numpy.tensor_contractions import (
|
||||
vdot as vdot,
|
||||
)
|
||||
|
||||
from jax._src.numpy.util import (
|
||||
ndim as ndim,
|
||||
shape as shape,
|
||||
size as size,
|
||||
)
|
||||
|
||||
from jax._src.numpy.window_functions import (
|
||||
bartlett as bartlett,
|
||||
blackman as blackman,
|
||||
@ -279,15 +285,12 @@ from numpy import (
|
||||
integer as integer,
|
||||
iterable as iterable,
|
||||
nan as nan,
|
||||
ndim as ndim,
|
||||
newaxis as newaxis,
|
||||
number as number,
|
||||
object_ as object_,
|
||||
pi as pi,
|
||||
save as save,
|
||||
savez as savez,
|
||||
shape as shape,
|
||||
size as size,
|
||||
signedinteger as signedinteger,
|
||||
unsignedinteger as unsignedinteger,
|
||||
)
|
||||
|
@ -728,7 +728,7 @@ def nanvar(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ...,
|
||||
ddof: int = 0, keepdims: builtins.bool = False,
|
||||
where: ArrayLike | None = ...) -> Array: ...
|
||||
ndarray = Array
|
||||
ndim = _np.ndim
|
||||
def ndim(a: ArrayLike) -> int: ...
|
||||
def negative(x: ArrayLike, /) -> Array: ...
|
||||
newaxis = None
|
||||
def nextafter(x: ArrayLike, y: ArrayLike, /) -> Array: ...
|
||||
@ -842,7 +842,7 @@ def setdiff1d(
|
||||
fill_value: ArrayLike | None = ...,
|
||||
) -> Array: ...
|
||||
def setxor1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: builtins.bool = ...) -> Array: ...
|
||||
shape = _np.shape
|
||||
def shape(a: ArrayLike) -> tuple[int, ...]: ...
|
||||
def sign(x: ArrayLike, /) -> Array: ...
|
||||
def signbit(x: ArrayLike, /) -> Array: ...
|
||||
signedinteger = _np.signedinteger
|
||||
@ -850,7 +850,7 @@ def sin(x: ArrayLike, /) -> Array: ...
|
||||
def sinc(x: ArrayLike, /) -> Array: ...
|
||||
single: Any
|
||||
def sinh(x: ArrayLike, /) -> Array: ...
|
||||
size = _np.size
|
||||
def size(a: ArrayLike, axis: int | None = None) -> int: ...
|
||||
def sort(
|
||||
a: ArrayLike,
|
||||
axis: int | None = ...,
|
||||
|
@ -6140,6 +6140,42 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
self._CompileAndCheck(jnp_fun, args_maker, atol=tol, rtol=tol,
|
||||
check_dtypes=False)
|
||||
|
||||
@jtu.sample_product(
|
||||
shape=all_shapes,
|
||||
dtype=default_dtypes,
|
||||
op=['ndim', 'shape', 'size'],
|
||||
)
|
||||
def testNdimShapeSize(self, shape, dtype, op):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
jnp_op = getattr(jnp, op)
|
||||
np_op = getattr(np, op)
|
||||
x = rng(shape, dtype)
|
||||
expected = np_op(x)
|
||||
self.assertEqual(expected, jnp_op(x)) # np.ndarray or scalar input.
|
||||
self.assertEqual(expected, jnp_op(jnp.asarray(x))) # jax.Array input.
|
||||
self.assertEqual(expected, jax.jit(jnp_op)(x)) # Traced input.
|
||||
|
||||
@jtu.sample_product(
|
||||
shape=nonzerodim_shapes,
|
||||
dtype=default_dtypes,
|
||||
)
|
||||
def testSizeAlongAxis(self, shape, dtype):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
axis = self.rng().randint(-len(shape), len(shape))
|
||||
np_op = partial(np.size, axis=axis)
|
||||
jnp_op = partial(jnp.size, axis=axis)
|
||||
self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
|
||||
self._CompileAndCheck(jnp_op, args_maker)
|
||||
|
||||
@jtu.sample_product(
|
||||
op=[jnp.ndim, jnp.shape, jnp.size],
|
||||
)
|
||||
def testNdimShapeSizeNonArrayInput(self, op):
|
||||
msg = f"{op.__name__} requires ndarray or scalar arguments"
|
||||
with self.assertWarnsRegex(DeprecationWarning, msg):
|
||||
op([1, 2, 3])
|
||||
|
||||
|
||||
# Most grad tests are at the lax level (see lax_test.py), but we add some here
|
||||
# as needed for e.g. particular compound ops of interest.
|
||||
|
Loading…
x
Reference in New Issue
Block a user