jax.numpy ndim/shape/size: deprecate non-array input

This commit is contained in:
Jake VanderPlas 2025-02-20 12:31:31 -08:00
parent 8af6f70fe0
commit 8cec6e636a
5 changed files with 171 additions and 7 deletions

View File

@ -54,6 +54,8 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
A downstream effect of this several other internal functions need debug
info. This change does not affect public APIs.
See https://github.com/jax-ml/jax/issues/26480 for more detail.
* In {func}`jax.numpy.ndim`, {func}`jax.numpy.shape`, and {func}`jax.numpy.size`,
non-arraylike inputs (such as lists, tuples, etc.) are now deprecated.
* Bug fixes
* TPU runtime startup and shutdown time should be significantly improved on

View File

@ -26,7 +26,7 @@ from jax._src import dtypes
from jax._src.lax import lax
from jax._src.lib import xla_client as xc
from jax._src.sharding_impls import SingleDeviceSharding
from jax._src.util import safe_zip, safe_map
from jax._src.util import safe_zip, safe_map, set_module
from jax._src.typing import Array, ArrayLike, DimSize, DType, DTypeLike, Shape
from jax.sharding import Sharding
@ -35,6 +35,8 @@ import numpy as np
zip, unsafe_zip = safe_zip, zip
map, unsafe_map = safe_map, map
export = set_module('jax.numpy')
_dtype = partial(dtypes.dtype, canonicalize=True)
def promote_shapes(fun_name: str, *args: ArrayLike) -> list[Array]:
@ -308,3 +310,124 @@ def normalize_device_to_sharding(device: xc.Device | Sharding | None) -> Shardin
return SingleDeviceSharding(device)
else:
return device
@export
def ndim(a: ArrayLike) -> int:
"""Return the number of dimensions of an array.
JAX implementation of :func:`numpy.ndim`. Unlike ``np.ndim``, this function
raises a :class:`TypeError` if the input is a collection such as a list or
tuple.
Args:
a: array-like object.
Returns:
An integer specifying the number of dimensions of ``a``.
Examples:
Number of dimensions for arrays:
>>> x = jnp.arange(10)
>>> jnp.ndim(x)
1
>>> y = jnp.ones((2, 3))
>>> jnp.ndim(y)
2
This also works for scalars:
>>> jnp.ndim(3.14)
0
For arrays, this can also be accessed via the :attr:`jax.Array.ndim` property:
>>> x.ndim
1
"""
# Deprecation warning added 2025-2-20.
check_arraylike("ndim", a, emit_warning=True)
return np.ndim(a) # NumPy dispatches to a.ndim if available.
@export
def shape(a: ArrayLike) -> tuple[int, ...]:
"""Return the shape an array.
JAX implementation of :func:`numpy.shape`. Unlike ``np.shape``, this function
raises a :class:`TypeError` if the input is a collection such as a list or
tuple.
Args:
a: array-like object.
Returns:
An tuple of integers representing the shape of ``a``.
Examples:
Shape for arrays:
>>> x = jnp.arange(10)
>>> jnp.shape(x)
(10,)
>>> y = jnp.ones((2, 3))
>>> jnp.shape(y)
(2, 3)
This also works for scalars:
>>> jnp.shape(3.14)
()
For arrays, this can also be accessed via the :attr:`jax.Array.shape` property:
>>> x.shape
(10,)
"""
# Deprecation warning added 2025-2-20.
check_arraylike("shape", a, emit_warning=True)
return np.shape(a) # NumPy dispatches to a.shape if available.
@export
def size(a: ArrayLike, axis: int | None = None) -> int:
"""Return number of elements along a given axis.
JAX implementation of :func:`numpy.size`. Unlike ``np.size``, this function
raises a :class:`TypeError` if the input is a collection such as a list or
tuple.
Args:
a: array-like object
axis: optional integer along which to count elements. By default, return
the total number of elements.
Returns:
An integer specifying the number of elements in ``a``.
Examples:
Size for arrays:
>>> x = jnp.arange(10)
>>> jnp.size(x)
10
>>> y = jnp.ones((2, 3))
>>> jnp.size(y)
6
>>> jnp.size(y, axis=1)
3
This also works for scalars:
>>> jnp.size(3.14)
1
For arrays, this can also be accessed via the :attr:`jax.Array.size` property:
>>> y.size
6
"""
# Deprecation warning added 2025-2-20.
check_arraylike("size", a, emit_warning=True)
return np.size(a, axis=axis) # NumPy dispatches to a.size if available.

View File

@ -254,6 +254,12 @@ from jax._src.numpy.tensor_contractions import (
vdot as vdot,
)
from jax._src.numpy.util import (
ndim as ndim,
shape as shape,
size as size,
)
from jax._src.numpy.window_functions import (
bartlett as bartlett,
blackman as blackman,
@ -279,15 +285,12 @@ from numpy import (
integer as integer,
iterable as iterable,
nan as nan,
ndim as ndim,
newaxis as newaxis,
number as number,
object_ as object_,
pi as pi,
save as save,
savez as savez,
shape as shape,
size as size,
signedinteger as signedinteger,
unsignedinteger as unsignedinteger,
)

View File

@ -728,7 +728,7 @@ def nanvar(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ...,
ddof: int = 0, keepdims: builtins.bool = False,
where: ArrayLike | None = ...) -> Array: ...
ndarray = Array
ndim = _np.ndim
def ndim(a: ArrayLike) -> int: ...
def negative(x: ArrayLike, /) -> Array: ...
newaxis = None
def nextafter(x: ArrayLike, y: ArrayLike, /) -> Array: ...
@ -842,7 +842,7 @@ def setdiff1d(
fill_value: ArrayLike | None = ...,
) -> Array: ...
def setxor1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: builtins.bool = ...) -> Array: ...
shape = _np.shape
def shape(a: ArrayLike) -> tuple[int, ...]: ...
def sign(x: ArrayLike, /) -> Array: ...
def signbit(x: ArrayLike, /) -> Array: ...
signedinteger = _np.signedinteger
@ -850,7 +850,7 @@ def sin(x: ArrayLike, /) -> Array: ...
def sinc(x: ArrayLike, /) -> Array: ...
single: Any
def sinh(x: ArrayLike, /) -> Array: ...
size = _np.size
def size(a: ArrayLike, axis: int | None = None) -> int: ...
def sort(
a: ArrayLike,
axis: int | None = ...,

View File

@ -6140,6 +6140,42 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
self._CompileAndCheck(jnp_fun, args_maker, atol=tol, rtol=tol,
check_dtypes=False)
@jtu.sample_product(
shape=all_shapes,
dtype=default_dtypes,
op=['ndim', 'shape', 'size'],
)
def testNdimShapeSize(self, shape, dtype, op):
rng = jtu.rand_default(self.rng())
jnp_op = getattr(jnp, op)
np_op = getattr(np, op)
x = rng(shape, dtype)
expected = np_op(x)
self.assertEqual(expected, jnp_op(x)) # np.ndarray or scalar input.
self.assertEqual(expected, jnp_op(jnp.asarray(x))) # jax.Array input.
self.assertEqual(expected, jax.jit(jnp_op)(x)) # Traced input.
@jtu.sample_product(
shape=nonzerodim_shapes,
dtype=default_dtypes,
)
def testSizeAlongAxis(self, shape, dtype):
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(shape, dtype)]
axis = self.rng().randint(-len(shape), len(shape))
np_op = partial(np.size, axis=axis)
jnp_op = partial(jnp.size, axis=axis)
self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
self._CompileAndCheck(jnp_op, args_maker)
@jtu.sample_product(
op=[jnp.ndim, jnp.shape, jnp.size],
)
def testNdimShapeSizeNonArrayInput(self, op):
msg = f"{op.__name__} requires ndarray or scalar arguments"
with self.assertWarnsRegex(DeprecationWarning, msg):
op([1, 2, 3])
# Most grad tests are at the lax level (see lax_test.py), but we add some here
# as needed for e.g. particular compound ops of interest.